Posted in: 算法讲稿

平衡树 & LCT 模板

内容纲要

Splay

// P3369.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity of the node pool; index 0 is reserved as the null node.
constexpr int MAX_N = 100000 + 200;

// Number of queries.
int q;
// Splay node pool: children, stored value, multiplicity, subtree size (with multiplicity), parent.
int ch[MAX_N][2], val[MAX_N], cnt[MAX_N], siz[MAX_N], fa[MAX_N];
// root: current root node (0 = empty tree); ptot: number of nodes allocated so far.
int root, ptot;

#define lson (ch[p][0])
#define rson (ch[p][1])

// Splay;

// Which side of its parent p hangs on: 1 for the right child, 0 otherwise.
int check(int p)
{
    const int parent = fa[p];
    return ch[parent][1] == p ? 1 : 0;
}

void pushup(int p) { siz[p] = siz[lson] + siz[rson] + cnt[p]; }

// Rotates x above its parent y while preserving in-order (BST) order.
// x's subtree on the side facing y is handed over to y; sizes of y and x are recomputed.
void rotate(int x)
{
    const int y = fa[x];
    const int z = fa[y];
    const int side = check(x);          // which child of y x currently is
    const int inner = ch[x][side ^ 1];  // x's subtree that moves under y

    // Put x where y used to hang under z (only if y had a parent).
    fa[x] = z;
    if (z != 0)
        ch[z][check(y)] = x;

    // y becomes x's child on the opposite side.
    fa[y] = x;
    ch[x][side ^ 1] = y;

    // The inner subtree takes x's old slot under y.
    fa[inner] = y;
    ch[y][side] = inner;

    // y is now below x, so refresh it first.
    pushup(y);
    pushup(x);
}

void splay(int p, int goal = 0)
{
    if (p == 0)
        return;
    for (int fat = fa[p]; fat = fa[p], fat != goal; rotate(p))
        if (fa[fat] != goal)
            rotate(check(p) == check(fat) ? fat : p);
    if (goal == 0)
        root = p;
}

// Descends from the root toward v and stops at the node holding v or, if v is absent,
// at the last node before the search would fall off the tree; that node is splayed
// to the root and returned (0 for an empty tree).
int find(int v)
{
    int cur = root;
    while (cur && val[cur] != v)
    {
        const int next = ch[cur][v > val[cur]];
        if (!next)
            break;
        cur = next;
    }
    splay(cur);
    return cur;
}

void insert(int v)
{
    int p = root, pre = 0;
    while (p && val[p] != v)
        pre = p, p = ch[p][v > val[p]];
    if (p)
        cnt[p]++;
    else
    {
        p = ++ptot, val[p] = v, cnt[p] = siz[p] = 1;
        ch[p][0] = ch[p][1] = 0, fa[p] = pre;
        if (pre)
            ch[pre][v > val[pre]] = p;
    }
    splay(p);
}

int previous(int v)
{
    int p = find(v);
    if (val[p] < v)
        return p;
    p = lson;
    while (rson)
        p = rson;
    return p;
}

int succ(int v)
{
    int p = find(v);
    if (val[p] > v)
        return p;
    p = rson;
    while (lson)
        p = lson;
    return p;
}

void remove(int val)
{
    int pre = previous(val), suc = succ(val);
    splay(pre), splay(suc, pre);
    int p = ch[suc][0];
    cnt[p]--;
    if (cnt[p] == 0)
        ch[suc][0] = fa[p] = 0;
    else
        splay(p);
}

int getRank(int v)
{
    int p = find(v);
    return siz[lson] + 1;
}

// Returns the k-th smallest stored value (1-based, duplicates counted),
// or 0 if k is outside [1, siz[root]]. Does not splay.
int kth(int k)
{
    int p = root;
    // Without the k < 1 check the loop would keep stepping left into the null
    // node (k <= siz[0] == 0 stays true) and never terminate.
    if (k < 1 || siz[p] < k)
        return 0;
    while (true)
    {
        const int lsiz = siz[lson];
        if (k <= lsiz)
            p = lson;
        else if (k <= lsiz + cnt[p])
            return val[p];
        else
            k -= lsiz + cnt[p], p = rson;
    }
}

// Reads q operations "opt x" and answers them with the splay tree.
int main()
{
    // Sentinels so that every query strictly between them has a predecessor and a successor.
    insert(-0x7fffffff);
    insert(0x7fffffff);
    scanf("%d", &q);
    while (q--)
    {
        int opt, x;
        scanf("%d%d", &opt, &x);
        switch (opt)
        {
        case 1:
            insert(x);
            break;
        case 2:
            remove(x);
            break;
        case 3:
            printf("%d\n", getRank(x) - 1);  // discount the -inf sentinel
            break;
        case 4:
            printf("%d\n", kth(x + 1));      // skip over the -inf sentinel
            break;
        case 5:
            printf("%d\n", val[previous(x)]);
            break;
        default:
            printf("%d\n", val[succ(x)]);
            break;
        }
    }
    return 0;
}

LCT

// P3690.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity of the node pool; vertices are numbered from 1 and index 0 acts as the null node.
constexpr int MAX_N = 300000 + 200;

// n: number of vertices, q: number of operations.
int n, q;
// Per-node splay data: children, parent, XOR of the subtree's weights, own weight, subtree node count.
int ch[MAX_N][2], fa[MAX_N], sum[MAX_N], val[MAX_N], siz[MAX_N];
// Pending "swap the children" (reversal) flag of each node.
bool tag[MAX_N];

#define lson (ch[p][0])
#define rson (ch[p][1])

// LCT;

// 1 if p is the right child of fa[p], 0 otherwise.
int check(int p)
{
    const int parent = fa[p];
    return ch[parent][1] == p ? 1 : 0;
}

// True when p is the root of its splay tree: fa[p] (if any) does not list p as a child,
// i.e. the link is only a path-parent pointer.
bool isRoot(int p)
{
    const int parent = fa[p];
    return !(ch[parent][0] == p || ch[parent][1] == p);
}

void pushup(int p) { sum[p] = sum[lson] ^ sum[rson] ^ val[p], siz[p] = siz[lson] + siz[rson] + 1; }

// Rotates x above its parent y inside one splay tree. If y is the splay root,
// x only inherits y's path-parent pointer (z does not list x as a child).
void rotate(int x)
{
    int y = fa[x], z = fa[y], dir = check(x), w = ch[x][dir ^ 1];
    fa[x] = z;
    if (!isRoot(y))
        ch[z][check(y)] = x;
    fa[y] = x, ch[x][dir ^ 1] = y;
    fa[w] = y, ch[y][dir] = w;
    // Only y and x change subtree contents; z's subtree keeps the same nodes, so it needs
    // no refresh. Calling pushup(z) when z == 0 (y is the top splay root) would compute
    // siz[0] = 2 * siz[0] + 1 and corrupt every size that reads the null node.
    pushup(y), pushup(x);
}

void pushdown(int p)
{
    if (tag[p])
    {
        tag[lson] ^= 1, tag[rson] ^= 1;
        swap(lson, rson), tag[p] = 0;
    }
}

void update(int p)
{
    if (!isRoot(p))
        update(fa[p]);
    pushdown(p);
}

void splay(int p)
{
    update(p);
    for (int fat = fa[p]; fat = fa[p], !isRoot(p); rotate(p))
        if (!isRoot(fat))
            rotate(check(p) == check(fat) ? fat : p);
    pushup(p);
}

void access(int p)
{
    for (int pre = 0; p; pre = p, p = fa[p])
        splay(p), rson = pre, pushup(p);
}

// Re-roots p's tree at p: expose p, splay it, then flip the orientation of its splay tree.
void makeRoot(int p)
{
    access(p);
    splay(p);
    tag[p] ^= 1;
    pushdown(p);
}

int find(int p)
{
    access(p), splay(p), pushdown(p);
    while (lson)
        pushdown(p = lson);
    splay(p);
    return p;
}

// Adds edge (x, y) unless x and y are already in the same tree.
void link(int x, int y)
{
    makeRoot(x);
    if (find(x) == find(y))
        return;  // already connected
    fa[x] = y;   // x's tree now hangs below y through a path-parent pointer
}

// Deletes edge (x, y) if it exists; otherwise leaves the forest unchanged.
void cut(int x, int y)
{
    makeRoot(x);
    // After find(y) returns x, x is the splay root and leftmost node of the splay tree
    // holding the x..y path, with y its deepest node. The edge exists iff that path is just
    // {x, y}, i.e. y is x's right child and has no left child. This structural test does not
    // depend on siz[], which is unreliable if the null node's size ever gets overwritten.
    if (find(y) != x || fa[y] != x || ch[y][0])
        return;
    fa[y] = ch[x][1] = 0, pushup(x);
}

// Makes x the root, then exposes the x..y path as the splay tree rooted at y.
void split(int x, int y)
{
    makeRoot(x);
    access(y);
    splay(y);
}

// Reads n vertex weights, then q operations "opt x y" on the dynamic forest.
int main()
{
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &val[i]);
        sum[i] = val[i];  // a lone vertex's subtree XOR is its own weight
    }
    while (q--)
    {
        int opt, x, y;
        scanf("%d%d%d", &opt, &x, &y);
        switch (opt)
        {
        case 0:
            // XOR of weights on the x..y path.
            split(x, y);
            printf("%d\n", sum[y] ^ sum[ch[y][1]]);
            break;
        case 1:
            link(x, y);
            break;
        case 2:
            cut(x, y);
            break;
        default:
            // Set x's weight to y. After splay(x), x is a splay root, so no other node's
            // sum[] includes it and only x needs a refresh.
            splay(x);
            val[x] = y;
            pushup(x);
            break;
        }
    }
    return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

Back to Top