Problem Statement in English
You’re given an array of positive integers w where w[i] describes the weight of index i.
You need to implement the Solution class:
Solution(int[] w) Initializes the object with the given weight array w.
int pickIndex() Picks an index randomly, where the probability of picking index i is $w[i] / sum(w)$.
Approach
This is probably not something that you would come up with on your own during an interview.
Since we need to pick an index with probability proportional to its weight, we need a way to sample a random number according to weights. Drawing a uniform random number is easy: Python's built-in random module does that. The hard part is turning that uniform draw into one that respects the weights.
So the idea is to first build the running total of the weights up to and including each index (the prefix sums).
For example, we can represent the input [1,3] as: [1, 4].
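The running totals can be built in a single pass. Here is a minimal sketch using itertools.accumulate from the standard library; the helper name build_prefix_sums is made up for illustration and is not part of the problem.

```python
from itertools import accumulate
from typing import List


def build_prefix_sums(w: List[int]) -> List[int]:
    """Return running totals: entry i is w[0] + ... + w[i]."""
    return list(accumulate(w))


print(build_prefix_sums([1, 3]))     # [1, 4]
print(build_prefix_sums([2, 5, 3]))  # [2, 7, 10]
```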
Now if we randomly pick an integer from 1 to the total weight (the last running total, in this case 4), we can determine which index it belongs to by checking which range it falls into. The index whose range contains that random number is our answer.
Think about it: the number 1 belongs to index 0, and the numbers 2, 3, 4 belong to index 1. So index 0 has a probability of $1/4$ and index 1 has a probability of $3/4$, which is exactly what we want.
This means that for any given random number, we can use a linear scan to find the index whose weight range that number belongs to, that is, the first index whose running total is greater than or equal to the number.
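As a rough sketch of that linear scan, assuming the random number is an integer between 1 and the total weight (the function name find_index_linear is hypothetical):

```python
from typing import List


def find_index_linear(prefix_sums: List[int], num: int) -> int:
    """Return the first index whose running total is >= num.

    Assumes 1 <= num <= prefix_sums[-1].
    """
    for i, total in enumerate(prefix_sums):
        if num <= total:
            return i
    raise ValueError("num is larger than the total weight")


prefix_sums = [1, 4]  # running totals for w = [1, 3]
for num in range(1, 5):
    print(num, "->", find_index_linear(prefix_sums, num))
# 1 -> 0
# 2 -> 1
# 3 -> 1
# 4 -> 1
```

One of the four possible draws lands on index 0 and three land on index 1, matching the weights 1 and 3.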
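But we can do better than a linear scan. Because every weight is positive, the running totals are strictly increasing, so the list is already sorted and we can binary search it. The only thing that changes is that instead of scanning linearly, we do a binary search to find the index.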
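Python lets us do that easily using the bisect module. With the random number drawn from 1 to the total, bisect_left returns the right index. bisect_right also works, but only if the number is drawn from 0 to total - 1 instead; with a draw equal to the total it would return an index past the end of the array.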
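Which bisect variant is safe depends on the range the random number is drawn from. The small check below, using the [1, 4] running totals from the example, enumerates every possible draw for both pairings; it is only an illustration, not part of the solution.

```python
from bisect import bisect_left, bisect_right

prefix_sums = [1, 4]  # running totals for w = [1, 3]
total = prefix_sums[-1]

# Draws from 1..total pair with bisect_left.
left = [bisect_left(prefix_sums, num) for num in range(1, total + 1)]
print(left)   # [0, 1, 1, 1]

# Draws from 0..total-1 pair with bisect_right.
right = [bisect_right(prefix_sums, num) for num in range(0, total)]
print(right)  # [0, 1, 1, 1]

# Mixing them up: bisect_right on a draw equal to total runs past the end.
print(bisect_right(prefix_sums, total))  # 2, not a valid index
```

Both pairings give index 0 one slot out of four and index 1 three slots out of four.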
Solution in Python
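import random
from bisect import bisect_left
from typing import List

class Solution:
    def __init__(self, w: List[int]):
        # prefix_sums[i] is the total weight of indices 0..i
        self.prefix_sums = [w[0]]
        for v in w[1:]:
            self.prefix_sums.append(self.prefix_sums[-1] + v)

    def pickIndex(self) -> int:
        # Uniform integer in [1, total]; bisect_left finds the first prefix sum >= num
        num = random.randint(1, self.prefix_sums[-1])
        return bisect_left(self.prefix_sums, num)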
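To sanity check the approach, we can sample many times and compare the observed frequencies against $w[i] / sum(w)$. This is a standalone sketch that repeats the same pick logic in a plain function; the seed and the trial count are arbitrary choices made only so the run is repeatable.

```python
import random
from bisect import bisect_left
from collections import Counter
from itertools import accumulate

w = [1, 3]
prefix_sums = list(accumulate(w))  # [1, 4]
rng = random.Random(0)             # fixed seed, only to make the run repeatable


def pick_index() -> int:
    # Same logic as pickIndex: draw from 1..total, then binary search.
    num = rng.randint(1, prefix_sums[-1])
    return bisect_left(prefix_sums, num)


trials = 100_000
counts = Counter(pick_index() for _ in range(trials))
for i, weight in enumerate(w):
    print(i, counts[i] / trials, "expected", weight / sum(w))
```

The observed frequencies should land close to 0.25 and 0.75, though the exact values depend on the random draws.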
Complexity
Time: $O(\log n)$ per pick
Since we use a binary search to find the index for each pick. Building the prefix sums in the constructor is a one-time $O(n)$ cost.
Space: $O(n)$
Since we store one prefix sum per index.
Mistakes I Made
I had to look this up :(
And we are done.