Problem Statement in English

You’re given the root of a binary tree with n nodes. Each node is uniquely assigned a value from 1 to n. You are also given an integer start, which is the value of the node where an infection starts.

At minute 0, the node with value start becomes infected. Each minute, every infected node spreads the infection to all of its uninfected neighbors (its left child, right child, and parent). Return the number of minutes it takes for the entire tree to become infected.


Approach

Working directly on the tree is awkward: the infection also spreads upward, but a tree node only stores references to its children, not to its parent.

To simplify the process, we can convert the binary tree into an undirected graph using an adjacency list representation. This allows us to easily traverse the tree in all directions (to children and to parents).
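
As a rough sketch of that conversion, the snippet below walks the tree with an explicit stack and records every parent-child edge in both directions. The `TreeNode` class and the function name `build_graph` are assumptions made for illustration; they are not part of the code the problem provides.

```python
from collections import defaultdict


class TreeNode:
    # Minimal stand-in for the usual binary tree node (assumed shape).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_graph(root):
    """Return an adjacency list {value: [neighbor values]} for the tree."""
    adj = defaultdict(list)
    if root is None:
        return adj
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is None:
                continue
            # Store the edge in both directions so the parent is reachable too.
            adj[node.val].append(child.val)
            adj[child.val].append(node.val)
            stack.append(child)
    return adj


# Tree:   1
#        / \
#       2   3
#      /
#     4
root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
graph = build_graph(root)
```

Using an explicit stack here is a design choice: it avoids hitting Python's recursion limit on a very deep, skewed tree.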

From there, we can run a breadth-first search or a depth-first search from the starting node to simulate the spread, tracking the minute at which each node becomes infected.

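The answer is the maximum time at which any node is reached, which equals the number of edges between the start node and the node farthest from it. Because a tree has exactly one path between any two nodes, a depth-first search finds the same distances as a breadth-first search.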
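
The breadth-first variant might look like the sketch below. It assumes the graph is already available as a plain dictionary mapping each node value to its neighbors; the function name `minutes_to_infect` and the sample graph are hypothetical and only serve the illustration.

```python
from collections import deque


def minutes_to_infect(adj, start):
    """Breadth-first search from start; returns the largest distance found."""
    seen = {start}
    queue = deque([(start, 0)])  # (node value, minute it became infected)
    longest = 0
    while queue:
        node, minute = queue.popleft()
        longest = max(longest, minute)
        for nei in adj.get(node, []):
            if nei not in seen:
                seen.add(nei)
                queue.append((nei, minute + 1))
    return longest


# Hand-written graph for the tree 1 -> (2, 3), 2 -> (4).
adj = {1: [2, 3], 2: [1, 4], 3: [1], 4: [2]}
```

Starting at node 4 in this graph, the infection reaches node 2 after one minute, node 1 after two, and node 3 after three, so the farthest node determines the result.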


Solution in Python


from collections import defaultdict
from typing import Optional


# TreeNode is supplied by the judge environment.
class Solution:
    def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
        # Undirected adjacency list keyed by node value.
        adj = defaultdict(list)

        def traverse(node):
            # Link the node with each existing child in both directions.
            for child in (node.left, node.right):
                if child:
                    adj[node.val].append(child.val)
                    adj[child.val].append(node.val)
                    traverse(child)

        if root:
            traverse(root)

        # A single-node tree has no edges, so nothing else can be infected.
        if len(adj) == 0:
            return 0

        res = 0

        # Iterative DFS; each stack entry is (node value, minute infected).
        stack = [(start, 0)]
        seen = {start}

        while stack:
            node, time = stack.pop()
            res = max(res, time)

            for nei in adj[node]:
                if nei in seen:
                    continue
                stack.append((nei, time + 1))
                seen.add(nei)

        return res

Complexity

  • Time: $O(n)$
    Building the adjacency list visits every node once, and the DFS then visits every node and each of the n - 1 edges a constant number of times.

  • Space: $O(n)$
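    The adjacency list, the seen set, and the DFS stack each hold O(n) entries. The recursive pass that builds the graph also uses call-stack depth equal to the tree's height, which is n in the worst case.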


And we are done.