Problem Statement in English

You are given a tree (i.e. a connected, undirected graph with no cycles) consisting of n nodes numbered from 0 to n - 1 and exactly n - 1 edges, given as a 2D integer array edges where edges[i] = [a_i, b_i] means there is an edge between nodes a_i and b_i. The root of the tree is node 0. Each node has a value associated with it, given in the 0-indexed integer array values of length n, where values[i] is the value of the i-th node.

You are also given an integer k, and it is guaranteed that the sum of all values is divisible by k. You may delete any set of edges (possibly none), splitting the tree into connected components.

A component is considered k-divisible if the sum of the values of all nodes in that component is divisible by k.

A split is valid if every resulting component is k-divisible. Return the maximum number of components in any valid split.
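To pin down what a valid split means, here is a brute-force sketch, assuming the same input format as the solution further down (edges as [a, b] pairs). It tries every subset of edges to delete, so it is only practical for tiny trees, but it is handy for cross-checking the greedy approach. The function names are hypothetical.

```python
from typing import List


def _find(parent: List[int], x: int) -> int:
    # Union-find lookup with path halving.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def max_components_brute_force(
    n: int, edges: List[List[int]], values: List[int], k: int
) -> int:
    best = 0
    for mask in range(1 << len(edges)):
        # Bit i set means edge i is deleted; union the endpoints of kept edges.
        parent = list(range(n))
        for i, (a, b) in enumerate(edges):
            if not (mask >> i) & 1:
                parent[_find(parent, a)] = _find(parent, b)
        # Sum the values in each resulting component.
        sums = {}
        for node in range(n):
            root = _find(parent, node)
            sums[root] = sums.get(root, 0) + values[node]
        # The split counts only if every component is k-divisible.
        if all(s % k == 0 for s in sums.values()):
            best = max(best, len(sums))
    return best
```

With 2^(n-1) edge subsets, this is exponential and meant only as a reference for small inputs.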


Approach

The first step is to realize that a split in which every component is k-divisible can only exist if the whole tree's sum is divisible by k. The problem guarantees this, so we can rely on the sum of all values in the tree being divisible by k.
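A short derivation of this necessary condition (the component names $C_j$ and the shorthand $S(\cdot)$ are introduced here just for the argument). Suppose a split yields components $C_1, \dots, C_m$ and write $S(C) = \sum_{v \in C} \text{values}[v]$. Every node lies in exactly one component, so

$$\sum_{v=0}^{n-1} \text{values}[v] = \sum_{j=1}^{m} S(C_j).$$

If every $S(C_j) \equiv 0 \pmod{k}$, the right-hand side is a sum of multiples of $k$, so the total must be a multiple of $k$ as well.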

Next, notice that as soon as any subtree's sum is divisible by k, we can cut the edge above it and count that subtree as a valid component, because the rest of the tree is guaranteed to remain divisible by k.
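The arithmetic behind the cut, with $T$ denoting the sum of the tree before the cut and $s$ the sum of the subtree being cut (both symbols introduced just for this step):

$$T \equiv 0 \pmod{k}, \quad s \equiv 0 \pmod{k} \;\Longrightarrow\; T - s \equiv 0 \pmod{k}.$$

Cutting the single edge above a subtree also leaves the rest of the tree connected, so what remains is again a tree whose sum is divisible by k, and the same reasoning applies to it.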

So we can work bottom-up (a post-order DFS from the root): evaluate each subtree only after all of its children, and cut it off as soon as its sum is divisible by k.
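The same bottom-up idea can be written without recursion. This is a sketch, not the original solution: a BFS from node 0 records each node's parent, and processing nodes in reverse BFS order guarantees that every child is handled before its parent. The function name is hypothetical.

```python
from collections import defaultdict, deque
from typing import List


def max_k_divisible_components_iterative(
    n: int, edges: List[List[int]], values: List[int], k: int
) -> int:
    # Build an undirected adjacency list.
    adj = defaultdict(list)
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # BFS from the root (node 0) records each node's parent and a top-down order.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = []
    queue = deque([0])
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in adj[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                queue.append(nxt)

    # carry[v] starts as values[v] and accumulates the uncut sums of v's children.
    carry = list(values)
    count = 0
    # Reverse BFS order visits every child before its parent (bottom-up).
    for node in reversed(order):
        if carry[node] % k == 0:
            count += 1  # cut here: this subtree passes nothing upward
        elif parent[node] != -1:
            carry[parent[node]] += carry[node]
    return count
```

Because there is no recursion, this version is not affected by Python's recursion depth limit on deep, path-shaped trees.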

The implementation relies on this: when a subtree's sum is divisible by k, we count one valid component and return 0 to the parent, since that subtree has already been counted and detached. Otherwise we return the subtree's sum so the parent can add it to its own.
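To see where the cuts actually happen, here is a small variant of the same DFS that records the node at which each cut is made instead of only counting. The helper name `cut_subtree_roots` is hypothetical and the sketch is meant for small trees (it keeps Python's default recursion limit).

```python
from collections import defaultdict
from typing import List


def cut_subtree_roots(
    n: int, edges: List[List[int]], values: List[int], k: int
) -> List[int]:
    adj = defaultdict(list)
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    cuts = []

    def dfs(node: int, parent: int) -> int:
        total = values[node]
        for child in adj[node]:
            if child != parent:
                total += dfs(child, node)
        if total % k == 0:
            cuts.append(node)  # this subtree becomes its own component
            return 0  # already counted, so the parent receives nothing from it
        return total

    dfs(0, -1)
    return cuts  # in post-order; len(cuts) is the answer
```

Each recorded node is the top of one component: that component is the node's subtree minus any deeper subtrees that were already cut off. For the tree with edges [[0, 2], [1, 2], [1, 3], [2, 4]], values [1, 8, 1, 4, 4] and k = 6, the cuts happen at node 1 (subtree {1, 3}, sum 12) and then at node 0 (the remaining {0, 2, 4}, sum 6).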


Solution in Python


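```python
import sys
from collections import defaultdict
from typing import List


class Solution:
    def maxKDivisibleComponents(
        self, n: int, edges: List[List[int]], values: List[int], k: int
    ) -> int:
        # The recursion can go n levels deep on a path-shaped tree,
        # so raise Python's default limit of 1000.
        sys.setrecursionlimit(max(1000, n + 100))

        # Build an undirected adjacency list.
        adjList = defaultdict(list)
        for a, b in edges:
            adjList[a].append(b)
            adjList[b].append(a)

        res = 0

        def dfs(node, parent):
            """Return the part of node's subtree sum that has not been cut off."""
            nonlocal res
            total = values[node]

            for child in adjList[node]:
                if child != parent:
                    total += dfs(child, node)

            if total % k == 0:
                # This subtree is k-divisible: cut it off and count it.
                res += 1
                return 0

            return total

        dfs(0, -1)

        return res
```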

Complexity

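  • Time: $O(n)$
    The DFS visits each node exactly once and examines each of the n - 1 edges twice (once from each endpoint).

  • Space: $O(n)$
    The adjacency list stores 2(n - 1) entries, and the recursion stack can be up to n frames deep when the tree is a path.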


Mistakes I Made

I had to watch the NeetCode solution video to understand the approach.


And we are done.