题解 P3380 【【模板】二逼平衡树(树套树)】

这道题最经典的做法便是线段树套平衡树。

也就是说,我们用一颗线段树管理下标区间 $[L,R]$ ,然后每个线段树的节点都是一颗平衡树,平衡树中存储了所有下标在 $[L,R]$ 之间的所有数的信息。这样的话,所有区间都可以被分为 $O(\log N)$ 个线段树的节点,区间问题便转化为了几个平衡树上的问题。

然后我们考虑如何具体实现各个操作:

1、查询区间内一个数的排名:在线段树上找到区间对应的节点,然后每个节点的平衡树内查询对应数的排名并求和。时间复杂度 $O(\log^2N)$

2、查询区间内排名为k的数是几:由于这项操作在线段树上不可加。所以我们考虑转换为判定一个数是不是区间内排名为k的。这个可以在 $O(\log^2N)$ 的时间内通过操作1完成。那么我们考虑二分答案,就可以解决这个问题。时间复杂度 $O(\log^3N)$

3、单点修改:我们在线段树上找到所有覆盖这个点的区间,然后在所有区间对应的平衡树中删除原数,加入新数即可。时间复杂度 $O(\log^2N)$

4,5、 查询区间内一个数前驱、后继:这是平衡树上的经典问题。我们只需要对于所有区间分别查询,然后答案相应的取 $\max \min$ 即可。时间复杂度 $O(\log^2N)$

当然,由于序列中的一个数最多被插入 $O(\log N)$ 个平衡树内,考虑到每次修改操作最多增加并删除一个节点,所以空间复杂度为 $O(N\log N)$ 。经过我们的分析,时间复杂度为 $O(N\log^3N)$。当然这是一个较松的上界。

对于实现来说,我使用了 Treap 作为平衡树,主要原因是因为好写而且常数较小。(我这份代码不开O2竟然也能过)而且要注意操作1与操作2的一些边界问题。

代码写的很丑。。

#include <algorithm>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdlib>
#include <complex>
#include <string>
#include <cstdio>
#include <vector>
#include <bitset>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <map>
#include <set>
using namespace std;

const int MAXN = 1e5 + 100;
int n, m;
int a[MAXN];

namespace Treap
{
    struct balanced
    {
        int w;
        int sz;
        int num;
        int fix;
        int ch[2];
    };
    int tot;
    balanced tree[MAXN * 20];
    int newnode(int w)
    {
        ++tot;
        tree[tot].w = w;
        tree[tot].fix = rand();
        tree[tot].num = 1;
        tree[tot].ch[0] = tree[tot].ch[1] = 0;
        tree[tot].sz = 1;
        return tot;
    }
    void pushup(int p)
    {
        tree[p].sz = tree[tree[p].ch[0]].sz + tree[tree[p].ch[1]].sz + tree[p].num;
    }
    void rotate(int &p, int d)
    {
        int y = tree[p].ch[d];
        tree[p].ch[d] = tree[y].ch[d ^ 1];
        tree[y].ch[d ^ 1] = p;
        pushup(p);
        pushup(y);
        p = y;
    }
    void insert(int &p, int w)
    {
        if (!p)
            p = newnode(w);
        else if (tree[p].w == w)
            ++tree[p].num;
        else 
        {
            if (tree[p].w > w)
            {
                insert(tree[p].ch[0], w);
                if (tree[tree[p].ch[0]].fix > tree[p].fix)
                    rotate(p, 0);
            }
            else
            {
                insert(tree[p].ch[1], w);
                if (tree[tree[p].ch[1]].fix > tree[p].fix)
                    rotate(p, 1);
            }
        }
        pushup(p);
    }
    void remove(int &p, int w)
    {
        if (tree[p].w > w)
            remove(tree[p].ch[0], w);
        else if (tree[p].w < w)
            remove(tree[p].ch[1], w);
        else
        {
            if (tree[p].num > 1)
                --tree[p].num;
            else
            {
                if (!tree[p].ch[0] && !tree[p].ch[1])
                    p = 0;
                else if (!tree[p].ch[0])
                {
                    rotate(p, 1);
                    remove(tree[p].ch[0], w);
                }
                else if (!tree[p].ch[1])
                {
                    rotate(p, 0);
                    remove(tree[p].ch[1], w);
                }
                else
                {
                    if (tree[tree[p].ch[0]].fix > tree[tree[p].ch[1]].fix)
                    {
                        rotate(p, 0);
                        remove(tree[p].ch[1], w);
                    }
                    else
                    {
                        rotate(p, 1);
                        remove(tree[p].ch[0], w);
                    }
                }
            }
        }
        if (p)
            pushup(p);
    }
    int queryrank(int p, int k) // return the highest rank of value 'k'
    {
        if (!p)
            return 0;
        if (tree[p].w > k)
            return queryrank(tree[p].ch[0], k);
        else if (tree[p].w == k)
            return tree[tree[p].ch[0]].sz;
        else
            return tree[tree[p].ch[0]].sz + tree[p].num + queryrank(tree[p].ch[1], k);
    }
    int querynum(int p, int k) // return the value of kth rank node
    {
        if (tree[tree[p].ch[0]].sz + 1 == k)
            return tree[p].w;
        else if (tree[tree[p].ch[0]].sz + 1 < k)
            return querynum(tree[p].ch[1], k - 1 - tree[tree[p].ch[0]].sz);
        else    
            return querynum(tree[p].ch[0], k);
    }
    int querypre(int p, int k) // return the prefix of value k
    {
        if (!p) 
            return -2147483647;
        if (tree[p].w >= k)
            return querypre(tree[p].ch[0], k);
        else
            return max(tree[p].w, querypre(tree[p].ch[1], k));
    }
    int querysuf(int p, int k) // return the suffix of value k
    {
        if (!p)
            return 2147483647;
        if (tree[p].w <= k)
            return querysuf(tree[p].ch[1], k);
        else
            return min(tree[p].w, querysuf(tree[p].ch[0], k));
    }
    void listall(int p)
    {
        if (tree[p].ch[0])
            listall(tree[p].ch[0]);
        cerr << tree[p].w << ",sz=" << tree[p].num << "   ";
        if (tree[p].ch[1])
            listall(tree[p].ch[1]);
    }
}
using Treap::listall;

namespace SEG
{
    struct segment
    {
        int l;
        int r;
        int root;
    };
    segment tree[MAXN * 8];
    void build(int p, int l, int r)
    {
        tree[p].l = l;
        tree[p].r = r;
        for (int i = l; i < r + 1; ++i)
            Treap::insert(tree[p].root, a[i]);
        if (l != r)
        {
            int mid = (l + r) / 2;
            build(p * 2, l, mid);
            build(p * 2 + 1, mid + 1, r);
        }
    }
    void modify(int p, int x, int y)
    {
        Treap::remove(tree[p].root, a[x]);
        Treap::insert(tree[p].root, y);
        if (tree[p].l == tree[p].r)
            return;
        int mid = (tree[p].l + tree[p].r) / 2;
        if (x > mid)
            modify(p * 2 + 1, x, y);
        else
            modify(p * 2, x, y);
    }
    int queryrank(int p, int l, int r, int k) // query the highest rank of value 'k'
    {
        if (tree[p].l > r || tree[p].r < l)
            return 0;
        if (tree[p].l >= l && tree[p].r <= r)
            return Treap::queryrank(tree[p].root, k);
        else
            return queryrank(p * 2, l, r, k) + queryrank(p * 2 + 1, l, r, k);
    }
    int querynum(int u, int v, int k) // query the value of kth num
    {
        int l = 0, r = 1e8;
        while (l < r)
        {
            int mid = (l + r + 1) / 2;
            if (queryrank(1, u, v, mid) < k)
                l = mid;
            else
                r = mid - 1;
        }
        return r;
    }
    int querypre(int p, int l, int r, int k)
    {
        if (tree[p].l > r || tree[p].r < l)
            return -2147483647;
        if (tree[p].l >= l && tree[p].r <= r)
            return Treap::querypre(tree[p].root, k);
        else
            return max(querypre(p * 2, l, r, k), querypre(p * 2 + 1, l, r, k));
    }
    int querysuf(int p, int l, int r, int k)
    {
        if (tree[p].l > r || tree[p].r < l)
            return 2147483647;
        if (tree[p].l >= l && tree[p].r <= r)
            return Treap::querysuf(tree[p].root, k);
        else
            return min(querysuf(p * 2, l, r, k), querysuf(p * 2 + 1, l, r, k));
    }
}

int read()
{
    char ch = getchar();
    int x = 0, flag = 1;
    while (ch != '-' && (ch < '0' || ch > '9'))
        ch = getchar();
    if (ch == '-')
    {
        ch = getchar();
        flag = -1;
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * flag;
}

int main()
{
    n = read();
    m = read();
    for (int i = 1; i < n + 1; ++i)
        a[i] = read();
    SEG::build(1, 1, n);
    for (int i = 0; i < m; ++i)
    {
        int opt = read();
        if (opt == 3)
        {
            int x = read(), y = read();
            SEG::modify(1, x, y);
            a[x] = y;
        }
        else
        {
            int l = read(), r = read(), k = read();
            if (opt == 1)
                printf("%d\n", SEG::queryrank(1, l, r, k) + 1);
            else if (opt == 2)
                printf("%d\n", SEG::querynum(l, r, k));
            else if (opt == 4)
                printf("%d\n", SEG::querypre(1, l, r, k));
            else
                printf("%d\n", SEG::querysuf(1, l, r, k));
        }
    }
    return 0;
}

发表于 2018-04-05 21:22:57 in 题解