Problem Statement in English

You’re given a binary string s and an integer k.

In one operation, you must choose exactly k distinct indices and flip the character at each of them: every 0 becomes a 1 and every 1 becomes a 0.

The goal is to make all characters in the string equal to 1. You need to return the minimum number of operations required to achieve this. If it’s not possible, return -1.


Approach

Whew. Buckle up, kids, this one is going to be messy.

So there are two phases to this problem.

The first is figuring out how to solve the problem at all. The second is optimising that solution.

The crazy thing about this question is that it’s a graph problem. Yeah. Who would have even thought? We’ll work our way up to that, though.

So, let’s first handle the modelling part. How do we model the problem? One option is to represent the state as the number of 0s in the string. For example, if s = "11000", the state is 3, since there are three 0s. This is enough information because we may pick any k indices, so where the 0s sit never matters, only how many there are. The goal is to reach state 0, since we want every character to be 1.
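A quick way to convince yourself that only the count matters is to brute-force tiny cases over whole strings. The sketch below is an exponential checker (the name `brute_force_min_ops` is hypothetical, not part of the solution): it runs BFS over actual strings and confirms that, for n = 4 and k = 2, all strings with the same number of 0s need the same number of operations.

```python
from collections import deque
from itertools import combinations, product


def brute_force_min_ops(s: str, k: int) -> int:
    """Exhaustive BFS over whole strings; exponential, so tiny inputs only."""
    n = len(s)
    target = "1" * n
    seen = {s}
    q = deque([(s, 0)])
    while q:
        cur, dist = q.popleft()
        if cur == target:
            return dist
        # Try every choice of exactly k distinct indices
        for idx in combinations(range(n), k):
            chars = list(cur)
            for i in idx:
                chars[i] = "1" if chars[i] == "0" else "0"
            nxt = "".join(chars)
            if nxt not in seen:
                seen.add(nxt)
                q.append((nxt, dist + 1))
    return -1


# Group every length-4 string by its number of 0s and record the answers (k = 2)
n, k = 4, 2
answers: dict[int, set[int]] = {}
for bits in product("01", repeat=n):
    s = "".join(bits)
    answers.setdefault(s.count("0"), set()).add(brute_force_min_ops(s, k))

# Each zero-count maps to exactly one answer, so positions never matter
assert all(len(v) == 1 for v in answers.values())
```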

Now comes the math. Let’s define some variables.

Let n be the length of the string, c the current state (number of 0s), x the number of 0s we choose to flip in a given operation, and k the number of indices we must flip in every operation.

So if we have c 0s and flip x of them, the other k - x flipped indices must hold 1s. The x flipped 0s become 1s and the k - x flipped 1s become 0s, so the new state is $$c_\text{new} = c - x + (k - x) = c + k - 2x$$
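To sanity-check the formula, the sketch below flips every possible set of k indices in s = "11000" with k = 2 and compares the resulting number of 0s with c + k - 2x. The helpers `flip` and `predicted_zeros` are illustrative names, not from the solution.

```python
from itertools import combinations


def flip(s: str, indices) -> str:
    """Flip the characters of s at the given distinct indices."""
    chars = list(s)
    for i in indices:
        chars[i] = "1" if chars[i] == "0" else "0"
    return "".join(chars)


def predicted_zeros(c: int, k: int, x: int) -> int:
    """New number of 0s after flipping x zeros and k - x ones."""
    return c + k - 2 * x


s, k = "11000", 2
c = s.count("0")  # 3
for idx in combinations(range(len(s)), k):
    x = sum(1 for i in idx if s[i] == "0")  # how many chosen indices hold a 0
    assert flip(s, idx).count("0") == predicted_zeros(c, k, x)
```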

We can’t flip more 0s than we have, so $$x \leq c$$ We also can’t flip more 1s than we have, and the number of 1s we’re flipping is k - x, so $$k - x \leq n - c$$ which simplifies to $$x \geq k - n + c$$

Combining these with the trivial bounds $0 \leq x \leq k$, we finally arrive at the range of x for a single operation: $$\max(0, k - n + c) \leq x \leq \min(c, k)$$
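Here is a small sketch that compares the closed-form bounds with a direct check of both constraints for every small case. It assumes 1 ≤ k ≤ n, which is what makes the range non-empty; `x_range` and `x_range_brute` are hypothetical helper names.

```python
def x_range(n: int, c: int, k: int) -> tuple[int, int]:
    """Inclusive bounds on x (0s flipped per operation), assuming 1 <= k <= n."""
    return max(0, k - n + c), min(c, k)


def x_range_brute(n: int, c: int, k: int) -> tuple[int, int]:
    """Same bounds, found by testing every x in 0..k against both constraints."""
    ok = [x for x in range(k + 1) if x <= c and k - x <= n - c]
    return min(ok), max(ok)


# Exhaustive agreement check for all small n, k, c
for n in range(1, 9):
    for k in range(1, n + 1):
        for c in range(n + 1):
            assert x_range(n, c, k) == x_range_brute(n, c, k)

print(x_range(5, 3, 2))  # (0, 2)
```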

And this is where we finally arrive at the graph part of the problem.

We can represent each state (0 through n) as a node in a graph. From node c there is an edge to $c + k - 2x$ for every x in the aforementioned range, that is, to every state reachable in one operation.
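A minimal sketch of the edge list for one node, using the hypothetical helper name `neighbours`:

```python
def neighbours(n: int, c: int, k: int) -> list[int]:
    """Every state reachable from c in one operation: c + k - 2x for each valid x."""
    lo, hi = max(0, k - n + c), min(c, k)
    return sorted(c + k - 2 * x for x in range(lo, hi + 1))


# s = "11000": n = 5, c = 3, k = 2
print(neighbours(5, 3, 2))  # [1, 3, 5]
```

Note that 3 appears in its own neighbour list (flip one 0 and one 1), so the graph has self-loops and cycles.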

Now we run a breadth-first search from the starting node (the initial number of 0s in s) towards node 0; the number of edges on the shortest path is the answer, and if node 0 is never reached we return -1. Breadth-first search is the right choice because every operation costs the same, so minimising the number of operations means finding a shortest path in an unweighted graph.

Remember that we also need to keep track of visited nodes, because the graph has cycles. Once a node is visited we mark it and never enqueue it again, so if we reach a number of 0s we’ve already seen, we skip it.
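Before any optimisation, the plain BFS can be sketched like this (`min_operations_bfs` is a hypothetical name). It examines up to k + 1 neighbours per state, so it does O(n·k) work overall.

```python
from collections import deque


def min_operations_bfs(s: str, k: int) -> int:
    """Unoptimised BFS over zero-count states 0..n."""
    n = len(s)
    start = s.count("0")
    dist = [-1] * (n + 1)  # -1 means "not visited yet"
    dist[start] = 0
    q = deque([start])
    while q:
        c = q.popleft()
        if c == 0:
            return dist[c]
        # Try every valid number x of 0s to flip
        for x in range(max(0, k - n + c), min(c, k) + 1):
            nxt = c + k - 2 * x
            if dist[nxt] == -1:
                dist[nxt] = dist[c] + 1
                q.append(nxt)
    return -1


print(min_operations_bfs("0101", 3))  # 2
```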

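Now for the optimisation part. The plain BFS checks up to k + 1 neighbours for every state, which is O(n·k) overall and quadratic when k is close to n. So what can we optimise to begin with?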

Let me draw your attention to that equation again: $$c_\text{new} = c + k - 2x$$

Notice that c_new always has the same parity as c + k: 2x is always even, and subtracting an even number never changes parity.
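A small exhaustive check of the parity claim, together with the bounds the solution later calls `l` (obtained with the largest x) and `r` (obtained with the smallest x). Because x takes consecutive values, the neighbours turn out to be exactly l, l + 2, ..., r. `neighbour_bounds` is a hypothetical helper name.

```python
def neighbour_bounds(n: int, c: int, k: int) -> tuple[int, int]:
    """Smallest reachable state (largest x) and largest reachable state (smallest x)."""
    l = c + k - 2 * min(c, k)
    r = c + k - 2 * max(0, k - n + c)
    return l, r


for n in range(1, 9):
    for k in range(1, n + 1):
        for c in range(n + 1):
            l, r = neighbour_bounds(n, c, k)
            explicit = {c + k - 2 * x for x in range(max(0, k - n + c), min(c, k) + 1)}
            # Every neighbour has the parity of c + k ...
            assert all(v % 2 == (c + k) % 2 for v in explicit)
            # ... and together they are exactly l, l + 2, ..., r
            assert explicit == set(range(l, r + 1, 2))
```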

This brings us to the first optimisation. We keep the unvisited states in two separate sets, one for even states and one for odd states. Every next state from c has the parity of c + k, so we only need to look in that one set. Moreover, because x takes consecutive values, those next states are every second number between two bounds, so inside the right parity set they sit next to each other.

This lets us skip every state of the wrong parity outright, since none of them can be reached from c in a single operation.

The second optimisation is to keep each parity set sorted and delete a state the moment it is visited. From state c, a binary search finds the first unvisited state at or above the lower bound, and we walk forward until we pass the upper bound, enqueueing and removing each state we meet. Visited states are no longer in the set, so we never scan them again, and each state is removed at most once over the whole search.
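The "remove everything in [lo, hi] from a sorted collection" step can be sketched with the standard library's `bisect` module on a plain sorted list. This is a stand-in for the sorted set, not the solution's data structure: `del` on a list slice is O(n), whereas `SortedSet` from `sortedcontainers` keeps removals roughly logarithmic. `pop_range` is a hypothetical name.

```python
import bisect


def pop_range(unvisited: list[int], lo: int, hi: int) -> list[int]:
    """Remove and return every value v with lo <= v <= hi from a sorted list."""
    i = bisect.bisect_left(unvisited, lo)
    j = bisect.bisect_right(unvisited, hi)
    taken = unvisited[i:j]
    del unvisited[i:j]
    return taken


n = 7
unvisited = [[c for c in range(n + 1) if c % 2 == p] for p in (0, 1)]  # evens, odds
# Suppose the reachable bounds are 1 and 5 (both odd), so only the odd list is searched
print(pop_range(unvisited[1], 1, 5))  # [1, 3, 5]
print(unvisited[1])                   # [7]
print(pop_range(unvisited[1], 1, 5))  # []  (already visited, nothing to rescan)
```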

You can now read the code and try to put the pieces together.

I would also like to point out that the values in our sets range from 0 to n, since the number of 0s in the string can never be negative and can never exceed the length of the string.


Solution in Python


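```python
from collections import deque

from sortedcontainers import SortedSet  # third-party package: sortedcontainers


class Solution:
    def minOperations(self, s: str, k: int) -> int:
        n = len(s)
        # ts[0] holds the unvisited even states, ts[1] the unvisited odd states
        ts = [SortedSet() for _ in range(2)]

        for i in range(n + 1):
            ts[i % 2].add(i)

        # Start from the initial number of 0s and mark it as visited
        cnt0 = s.count('0')
        ts[cnt0 % 2].remove(cnt0)
        q = deque([cnt0])
        ans = 0

        while q:
            # Process one BFS level; every state in it needs exactly `ans` operations
            for _ in range(len(q)):
                cur = q.popleft()

                if cur == 0:
                    return ans

                # Reachable states are l, l + 2, ..., r
                l = cur + k - 2 * min(cur, k)
                r = cur + k - 2 * max(k - n + cur, 0)
                t = ts[l % 2]

                # Pull every still-unvisited state in [l, r] out of the set
                j = t.bisect_left(l)

                while j < len(t) and t[j] <= r:
                    q.append(t[j])
                    t.remove(t[j])  # the next candidate slides into index j

            ans += 1

        return -1
```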
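If `sortedcontainers` isn't available, the same BFS can be written with only the standard library. The sketch below swaps the sorted sets for a different technique, a union-find style "next unvisited state" pointer array (a common trick, not the approach used above): `parent[v]` jumps over visited states two at a time, so parity is preserved automatically. `min_operations_dsu` is a hypothetical name.

```python
from collections import deque


def min_operations_dsu(s: str, k: int) -> int:
    """BFS over zero-counts with a "next unvisited state of the same parity" pointer.

    find(v) returns the smallest unvisited state >= v with v's parity, or one of
    the sentinels n + 1 / n + 2, which are never visited.
    """
    n = len(s)
    parent = list(range(n + 3))

    def find(v: int) -> int:
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    def mark_visited(v: int) -> None:
        parent[v] = v + 2  # skip ahead to the next state of the same parity

    start = s.count("0")
    mark_visited(start)
    q = deque([start])
    ops = 0
    while q:
        for _ in range(len(q)):
            c = q.popleft()
            if c == 0:
                return ops
            l = c + k - 2 * min(c, k)
            r = c + k - 2 * max(k - n + c, 0)
            v = find(l)
            while v <= r:  # visit each unvisited state in l, l + 2, ..., r
                mark_visited(v)
                q.append(v)
                v = find(v)
        ops += 1
    return -1


print(min_operations_dsu("110", 1))  # 1
```

With path halving alone, each find is amortised O(log n), so this keeps the O(n log n) bound while dropping the third-party dependency.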

Complexity

  • Time: $O(n \log n)$
    Each of the n + 1 states is added to a sorted set once at the start and removed at most once, and every state taken off the queue performs one binary search. Each of these sorted-set operations costs roughly $O(\log n)$, so the overall time complexity is $O(n \log n)$.

  • Space: $O(n)$
    The two sorted sets together hold at most n + 1 states, and the queue never holds more than n + 1 states either. Therefore, the overall space complexity is $O(n)$.


Mistakes I Made

I had to look this one up.


And we are done.