Problem Statement in English

You’re given a 2D matrix named matrix. Your task is to support $O(1)$ queries for the sum of the elements in any submatrix of matrix.


Approach

Full disclosure: this is not my code, I got this online from NeetCode (brilliant channel by the way, definitely check him out).

I think it’ll be much easier to explain this through an example.

Initialisation

Consider the following matrix:

$$
\begin{array}{ccc}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{array}
$$

If we had to calculate the sum of this submatrix (the first column, in bold):

$$
\begin{array}{ccc}
\mathbf{1} & 2 & 3 \\
\mathbf{4} & 5 & 6 \\
\mathbf{7} & 8 & 9
\end{array}
$$

That’s easy if, during initialisation of the class, we add the cell just above the current cell to it, working from top to bottom. The backing matrix would then look something like this:

$$
\begin{array}{ccc}
1 & 2 & 3 \\
4+1=5 & 5+2=7 & 6+3=9 \\
7+5=12 & 8+7=15 & 9+9=18
\end{array}
$$

Now we can directly check the value at index backingMatrix[2][0] to get the sum of the required submatrix: $12$.
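As a minimal sketch of this column-wise accumulation (the function name column_prefix_sums is hypothetical and not part of the original solution):

```python
from typing import List


def column_prefix_sums(matrix: List[List[int]]) -> List[List[int]]:
    """Return a copy where each cell holds the sum of itself and all cells above it."""
    backing = [row[:] for row in matrix]  # copy so the input is not modified
    for i in range(1, len(backing)):
        for j in range(len(backing[i])):
            backing[i][j] += backing[i - 1][j]  # add the accumulated cell above
    return backing


backing_matrix = column_prefix_sums([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(backing_matrix)  # [[1, 2, 3], [5, 7, 9], [12, 15, 18]]
```

The last row holds the totals of each full column.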

But notice that this doesn’t help us with horizontal submatrices.

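For that we need to add the element in the cell just to the left of the current cell to it. So for this matrix, where the submatrix we want is the first row (in bold):

$$
\begin{array}{ccc}
\mathbf{1} & \mathbf{2} & \mathbf{3} \\
4 & 5 & 6 \\
7 & 8 & 9
\end{array}
$$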

A potential backing matrix would look like this:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
4 & 9 & 15 \\
7 & 15 & 24
\end{array}
$$

And now, we can check the index backingMatrix[0][2] to get the sum of the required submatrix: $6$.
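The row-wise version can be sketched the same way. This sketch uses itertools.accumulate from the standard library; the function name row_prefix_sums is hypothetical:

```python
from itertools import accumulate
from typing import List


def row_prefix_sums(matrix: List[List[int]]) -> List[List[int]]:
    """Return a matrix where each cell holds the sum of itself and all cells to its left."""
    return [list(accumulate(row)) for row in matrix]


row_backing = row_prefix_sums([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(row_backing)  # [[1, 3, 6], [4, 9, 15], [7, 15, 24]]
```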

Finally, what about a combination of the previous cases, like:

$$
\begin{array}{ccc}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{array}
$$

It’s quite clear that the previous approach of separate kinds of indexing won’t work. So then, what can be done?

Combine the indexes!

We do a row-wise prefix sum calculation and add the sum just above the current cell to this cell.

But then notice that there’s an overlap between the two kinds of indexing. So we need to subtract the sum of the cell just above and to the left (diagonally left) of the current cell, which gets added twice to this cell.
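Written as a formula, with $M$ as the input matrix and $P$ as the backing matrix (both symbols are introduced here purely for illustration), the combined rule is:

$$ P[i][j] = M[i][j] + P[i][j-1] + P[i-1][j] - P[i-1][j-1] $$

Any term whose index falls outside the matrix is treated as $0$.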

I know this sounds real nasty, but just bear with me for a bit.

Let’s approach this through an example.

The easy part:

  • for the first row, we do a row-wise prefix sum calculation
  • for the first column, we do a column-wise prefix sum calculation

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & & \\
12 & &
\end{array}
$$

Now let’s see the slightly more involved calculation:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & ? & \\
12 & &
\end{array}
$$

Here we go:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & 5 (\text{this}) + 5 (\text{left}) + 3 (\text{above}) - 1 (\text{diag left}) = 12 & \\
12 & &
\end{array}
$$

We’re removing the diagonally left cell because the cell above is $2+1=3$, the cell to the left is $4+1=5$, and the cell diagonally left is $1$. So both the cell above and the cell to the left already contain the diagonally left cell.

Notice how that means we’d be adding the diagonally left cell twice, which we don’t want.

Let’s work through another example:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & 12 & ? \\
12 & &
\end{array}
$$

Here’s the calculation:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & 12 & 6 (\text{this}) + 12 (\text{left}) + 6 (\text{above}) - 3 (\text{diag left}) = 21 \\
12 & &
\end{array}
$$

And, finishing the rest in a similar manner, we have:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & 12 & 21 \\
12 & 27 & 45
\end{array}
$$
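A quick way to sanity-check a backing matrix like this one is to compare it against a brute-force sum. The sketch below is only a check, not part of the solution, and the helper name brute_force_sum is hypothetical:

```python
from typing import List


def brute_force_sum(matrix: List[List[int]], row: int, col: int) -> int:
    """Sum every cell of the submatrix from (0, 0) to (row, col), inclusive."""
    return sum(matrix[i][j] for i in range(row + 1) for j in range(col + 1))


matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
expected = [
    [brute_force_sum(matrix, i, j) for j in range(len(matrix[i]))]
    for i in range(len(matrix))
]
print(expected)  # [[1, 3, 6], [5, 12, 21], [12, 27, 45]]
```

Each cell of the backing matrix should equal the brute-force sum of everything above and to the left of it, itself included.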

Querying

This is the easy part. We just need to calculate the sum of the submatrix using the backing matrix.

Notice that the sum of any submatrix that has its top left at index [0][0] is just the value in the backing matrix at the bottom right of that submatrix. Here’s what I mean by that:

$$
\begin{array}{ccc}
1 & 3 & 6 \\
5 & 12 & 21 \\
12 & 27 & 45
\end{array}
$$

The sum of the submatrix from [0][0] to [1][1] is at the cell [1][1]: $12$.

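But what about submatrices that don’t start at [0][0], like the one in bold here?

$$
\begin{array}{ccc}
1 & 2 & 3 \\
4 & \mathbf{5} & \mathbf{6} \\
7 & \mathbf{8} & \mathbf{9}
\end{array}
$$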

The question will indicate this submatrix as:

  • $\text{row1} = 1, \text{col1} = 1$
  • $\text{row2} = 2, \text{col2} = 2$

To solve this, we need to remove the row and column just before the start of the submatrix.

But we also need to add the cell diagonally left of the start of the submatrix, because otherwise, that’d be removed twice. Here’s what it looks like visually:

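$$
\begin{array}{ccc}
1 & 2 & 3 \\
4 & \mathbf{5} & \mathbf{6} \\
7 & \mathbf{8} & \mathbf{9}
\end{array}
$$

We start from the total that ends at the bottom right of the submatrix (in bold), remove everything above it (the row $1, 2, 3$) and everything to its left (the column $1, 4, 7$), and then add the cell $1$ back because it was removed twice.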

We can generalise this as:

$$ \text{sum} = \text{(row2, col2)} - \text{(row1-1, col2)} - \text{(row2, col1-1)} + \text{(row1-1, col1-1)} $$

  • (row2, col2) is the sum of the submatrix from [0][0] to the [row2][col2] cell
  • (row1-1, col2) is the sum of the submatrix that starts from [0][0] and ends before the row of the submatrix we’re interested in
  • (row2, col1-1) is the sum of the submatrix that starts from [0][0] and ends to the left of the submatrix we’re interested in
  • (row1-1, col1-1) is the sum of the submatrix from [0][0] to the cell diagonally left of the start of our submatrix; it gets removed twice, so we add it back once
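
Plugging the example ($\text{row1} = 1, \text{col1} = 1, \text{row2} = 2, \text{col2} = 2$) into this formula, with the values taken from the backing matrix computed during initialisation:

$$ \text{sum} = 45 - 6 - 12 + 1 = 28 $$

This matches adding the cells directly: $5 + 6 + 8 + 9 = 28$.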

Just keep in mind that sometimes the row or column before the start of the submatrix might not exist, so we need to factor that in.
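
One common way to factor that in, sketched here as an alternative and not what the solution below does, is to pad the backing matrix with an extra row and column of zeros so that the "previous" row and column always exist. The class name PaddedNumMatrix and the attribute prefix are hypothetical:

```python
from typing import List


class PaddedNumMatrix:
    def __init__(self, matrix: List[List[int]]):
        rows = len(matrix)
        cols = len(matrix[0]) if matrix else 0
        # prefix[i + 1][j + 1] holds the sum of matrix[0..i][0..j];
        # row 0 and column 0 stay zero and stand in for "nothing before".
        self.prefix = [[0] * (cols + 1) for _ in range(rows + 1)]
        for i in range(rows):
            for j in range(cols):
                self.prefix[i + 1][j + 1] = (
                    matrix[i][j]
                    + self.prefix[i + 1][j]  # left
                    + self.prefix[i][j + 1]  # above
                    - self.prefix[i][j]      # diagonally left, counted twice
                )

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        # Same remove / remove / add-back formula, shifted by one for the padding.
        return (
            self.prefix[row2 + 1][col2 + 1]
            - self.prefix[row1][col2 + 1]
            - self.prefix[row2 + 1][col1]
            + self.prefix[row1][col1]
        )


padded = PaddedNumMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(padded.sumRegion(1, 1, 2, 2))  # 28
print(padded.sumRegion(0, 0, 2, 0))  # 12
```

The trade-off is one extra row and column of storage in exchange for a query with no branches.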


Whew, that was something.


Solution in Python


from typing import List


class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        self.index = []

        for i in range(len(matrix)):
            tempRow = []

            for j in range(len(matrix[i])):
                # if both prev col and prev row exist
                if i - 1 >= 0 and j - 1 >= 0:
                    # cell on left - cell on diagonal left + cell above + this
                    tempRow.append(
                        tempRow[-1]
                        - self.index[i - 1][j - 1]
                        + self.index[i - 1][j]
                        + matrix[i][j]
                    )
                # if only prev row exists
                elif i - 1 >= 0:
                    # cell above + this
                    tempRow.append(self.index[i - 1][j] + matrix[i][j])
                # if only prev col exists
                elif j - 1 >= 0:
                    # cell on left + this
                    tempRow.append(tempRow[-1] + matrix[i][j])
                # if neither
                else:
                    # put this cell as is into index
                    # it's the (0,0) cell
                    tempRow.append(matrix[i][j])

            # add the finished row to the index
            # (tempRow is re-created at the top of the next iteration)
            self.index.append(tempRow)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        # if both prev row and prev col exist
        # (row2, col2) - prev col - prev row + intersection
        # => (row2, col2) - (row2, col1 - 1) - (row1 - 1, col2) + (row1 - 1, col1 - 1)
        if row1 - 1 >= 0 and col1 - 1 >= 0:
            return (
                self.index[row2][col2]
                - self.index[row2][col1 - 1]
                - self.index[row1 - 1][col2]
                + self.index[row1 - 1][col1 - 1]
            )

        # if only prev col exists
        # (row2, col2) - prev col
        # (row2, col2) - (row2, col1-1)
        elif col1 - 1 >= 0:
            return self.index[row2][col2] - self.index[row2][col1 - 1]

        # if only prev row exists
        # (row2, col2) - prev row
        # (row2, col2) - (row1 - 1, col2)
        elif row1 - 1 >= 0:
            return self.index[row2][col2] - self.index[row1 - 1][col2]

        # neither exist
        else:
            # (row2, col2)
            return self.index[row2][col2]

Complexity

  • Time: $O(1)$ per query, and $O(m \times n)$ for the initialisation, where $m$ and $n$ are the number of rows and columns of matrix.

  • Space: $O(m \times n)$
    Since we’re storing the backing matrix.


Mistakes I Made

I had to look this one up :(


And we are done.