iJoe's Blog
Published on

算法005-线段树(优化-矩阵中的局部最大值 II)

Authors

线段树

线段树用于查询区间最大最小值,对于一维数组在预处理时时间复杂度为O(nlogn),空间复杂度也是O(nlogn)。下面表格对常用算法作比较。

数据结构主要作用查询修改建立空间特点
树状数组单点修改+区间和O(logn)O(logn)O(n)O(n)简洁
ST表静态区间最值O(1)不支持O(nlogn)O(nlogn)查询最快
线段树动态区间问题O(logn)O(logn)O(n)O(4n)万能

原理

原理很简单,就是将数组区间以树方式呈现。比如[l, r]区间的节点包含它的子区间,也就是对应的左右子树[l, mid]和[mid + 1, r],mid是中间值。这样建立的树也是完全二叉树,所以查询修改是O(logn)。

代码

#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
private:
    int n;
    vector<long long> tree;

    void build(const vector<int>& a, int node, int l, int r) {
        if (l == r) {
            tree[node] = a[l];
            return;
        }

        int mid = (l + r) / 2;

        build(a, node * 2, l, mid);
        build(a, node * 2 + 1, mid + 1, r);

        tree[node] = tree[node * 2] + tree[node * 2 + 1];
    }

    void update(int node, int l, int r, int index, int value) {
        if (l == r) {
            tree[node] = value;
            return;
        }

        int mid = (l + r) / 2;

        if (index <= mid) {
            update(node * 2, l, mid, index, value);
        } else {
            update(node * 2 + 1, mid + 1, r, index, value);
        }

        tree[node] = tree[node * 2] + tree[node * 2 + 1];
    }

    long long query(int node, int l, int r, int ql, int qr) {
        // 当前区间完全被查询区间包含
        if (ql <= l && r <= qr) {
            return tree[node];
        }

        int mid = (l + r) / 2;
        long long ans = 0;

        // 查询区间和左子树有交集
        if (ql <= mid) {
            ans += query(node * 2, l, mid, ql, qr);
        }

        // 查询区间和右子树有交集
        if (qr > mid) {
            ans += query(node * 2 + 1, mid + 1, r, ql, qr);
        }

        return ans;
    }

public:
    SegmentTree(const vector<int>& a) {
        n = a.size();
        tree.assign(4 * n, 0);
        build(a, 1, 0, n - 1);
    }

    // 单点修改:把 a[index] 改成 value
    void update(int index, int value) {
        update(1, 0, n - 1, index, value);
    }

    // 查询闭区间 [l, r] 的和
    long long query(int l, int r) {
        return query(1, 0, n - 1, l, r);
    }
};

int main() {
    vector<int> a = {2, 1, 5, 3, 4};

    SegmentTree seg(a);

    cout << seg.query(1, 3) << endl;
    // 查询 a[1] + a[2] + a[3] = 1 + 5 + 3 = 9

    seg.update(2, 10);
    // 把 a[2] 从 5 改成 10

    cout << seg.query(1, 3) << endl;
    // 查询 a[1] + a[2] + a[3] = 1 + 10 + 3 = 14

    return 0;
}

矩阵中的局部最大值 II

给你一个 n x m 的整数矩阵 matrix ,所有元素均为非负整数。

一个 非零 单元格 (row, col) 会按如下方式检查其附近的单元格:

1.令 x = matrix[row][col] 。
2.考虑在 (row, col) 的 x 行和 x 列范围内的每个单元格。
3.忽略矩阵外的单元格。
4.忽略行距离和列距离都恰好等于 x 的 单元格。

如果单元格 (row, col) 是 非零 的,并且所有考虑的单元格中没有一个值 大于 x ,那么该单元格就是一个 局部最大值 。

返回一个整数,表示 matrix 中 局部最大值 的数量。

  1. 示例 1:

输入:matrix = [[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,2,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0]]

输出: 1

解释:
对于非零单元格 (3, 3) ,x = matrix[3][3] = 2
高亮的单元格是在 (3, 3) 的 x 行和 x 列范围内被考虑的单元格。
行距离和列距离都等于 x = 2 的四个单元格被忽略。
没有一个被考虑的单元格的值大于 2 ,因此 (3, 3) 是一个局部最大值。
没有其他非零单元格,所以答案是 1 。

  1. 示例 2:

输入: matrix = [[1,2],[3,4]]

输出: 1

解释:
只有值为 4 的单元格是局部最大值。其他每个非零单元格都考虑到了一个具有更大值的单元格。

  1. 示例 3:

输入:matrix = [[1,0,1],[0,1,0],[1,0,1]]

输出:5

解释:
对于值为 1 的单元格,考虑的单元格是其自身及其在矩阵内的 4 个方向上相邻的单元格。
这五个值为 1 的单元格中,每一个都只考虑到值为 0 或 1 的单元格,所以这五个单元格都是局部最大值。

  1. 示例 4:

输入:matrix = [[1,1],[1,1]]

输出:4

解释:
所有单元格都具有相同的值。因此,没有任何一个单元格会考虑到具有更大值的其他单元格,所以所有 4 个单元格都是局部最大值。

提示:
1 <= n == matrix.length <= 200
1 <= m == matrix[i].length <= 200
0 <= matrix[i][j] <= 200

解题思路

线段树+st表优化的核心思路是,将列用st表构建,行用线段树构建,即时间复杂度为O(n*mlogm)。随后再通过查询,其中线段树时间复杂度是O(mnlogn),因此整体时间复杂度为O(mn(logn+logm))。
在上述基础上,有两种不同构建方式。这里我把最基本的单列表构建的st表称为基础st表。

  1. 基于构建的整个基础st表来构建线段树,即:任意st[i][j]都可以构建出一个列范围为[i, i + 2^j -1]的表。
  2. 构建线段树时,只依托于子节点的st[i][0]生成一个新的列,然后再通过st构建出全新的表。

上面两个其实时间复杂度都一样,但第一个代码实现上较复杂,所以选用第二个。

代码

struct STTable {
public:
    vector<vector<int>> st;

    STTable() {}

    STTable(const vector<int>& table) {
        int n = table.size();
        int kn = bit_width((unsigned)n);
        st.resize(n, vector<int>(kn));

        for (int i = 0; i < n; i++) {
            st[i][0] = table[i];
        }

        for (int ki = 1; ki < kn; ki++) {
            for (int i = 0; i <= n - (1 << ki); i++) {
                st[i][ki] = max(st[i][ki - 1], st[i + (1 << (ki - 1))][ki - 1]);
            }
        }
    }

    int query(int l, int r) {
        int k = bit_width(1u * (r - l)) - 1;
        return max(st[l][k], st[r - (1 << k)][k]);
    }
};

struct SegmentTree {
public:
    vector<STTable> t;

    void build(const vector<vector<int>>& a, int node, int l, int r) {
        if (l == r) {
            t[node] = STTable(a[l]); // 这里时对叶节点进行构建
            return;
        }

        int mid = (l + r) / 2;
        build(a, node * 2, l, mid);
        build(a, node * 2 + 1, mid + 1, r);

        vector<int> merge(a[0].size());

        // 这里是先基于st[i][0]生成一个全新的列
        for (int i = 0; i < a[0].size(); i++) {
            merge[i] = max(t[node * 2].st[i][0], t[node * 2 + 1].st[i][0]);
        }

        t[node] = STTable(merge); // 然后再将其列构建一个全新的st表
    }

    SegmentTree(const vector<vector<int>>& a) : t(2 << bit_width(a.size())) {
        build(a, 1, 0, a.size() - 1);
    }

    int query(int node, int l, int r, int r1, int r2, int c1, int c2) {
        if (r1 <= l && r <= r2) {
            return t[node].query(c1, c2);
        }
        int m = (l + r) / 2;
        if (r2 <= m) {
            return query(node * 2, l, m, r1, r2, c1, c2);
        }

        if (r1 > m) {
            return query(node * 2 + 1, m + 1, r, r1, r2, c1, c2);
        }

        return max(query(node * 2, l, m, r1, r2, c1, c2), query(node * 2 + 1, m + 1, r, r1, r2, c1, c2));
    }
};

class Solution {
public:
    int countLocalMaximums(vector<vector<int>>& matrix) {
        int n = matrix.size();
        int m = matrix[0].size();
        SegmentTree t(matrix);

        int ans = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int x = matrix[i][j];
                if (x > 0 && max(t.query(1, 0, n - 1, max(i - x, 0), min(i + x, n - 1), max(j - x + 1, 0), min(j + x, m)),
                                 t.query(1, 0, n - 1, max(i - x + 1, 0), min(i + x - 1, n - 1), max(j - x, 0), min(j + x + 1, m))) <= x)
                    ans++;
            }
        }

        return ans;
    }
};