BZOJ 4763 / 洛谷新春2017联欢赛 雪辉

查询树链的并中不同权值的个数与 mex

题目大意

给出一个 nnn 个点的树,树上每个点有点权。每个查询给出 qqq 个树链,求这些树链的并中出现的不同的权值的个数和 mex

题目链接

BZOJ 4763

Luogu P3603

题解

先膜一下出题人nzhtl1477和他的官方题解,感觉能出这样的题也算是功力深厚...了吧?然而没有写过树分块...打这场比赛的时候听Claris说好像可以点分治做?众所周知点分治是一个好写好调的结构...于是先去乱搞一下啦!

首先,不同权值的个数与 mex 是非常难以维护的信息,并且注意到这个题中权值大小不超过 C=30000C = 30000C=30000,因此我们可以选择一个颇为暴力的维护方式:用 bitset 维护每种权值是否出现,则这样的信息可以以 A=C/64A = C/64A=C/64 的复杂度(为什么是 646464 而不是 323232?因为你得手写 bitset)完成合并。注意到点分治可以将任何路径表示为预处理过的 nlognn\log nnlogn 中两条的并,因此使用点分治实现这道题看起来是非常合适的,每次取出树链的时间复杂度可以达到 O(A)O(A)O(A)

...好像有哪里不对?如果我们要预处理 nlognn\log nnlogn条链的信息,那么复杂度已经达到了 O(nAlogn)O(nA\log n)O(nAlogn),已经几乎无法承受了。注意到点分治的过程中,在计算某个联通块的信息的时候,每次最多会有一个位置的信息被修改,因此我们可以把 bitset 分块并可持久化。如果将其分为大小为 O(A1/2)O(A^{1/2})O(A1/2) 的块,则每次更新信息的复杂度也变为 O(A1/2)O(A^{1/2})O(A1/2),预处理的时间复杂度减小到 O(nA1/2logn)O(nA^{1/2} \log n)O(nA1/2logn),可以接受了。

...(码完之后才发现)好像又有哪里不对?这个做法中预处理的时空复杂度是同阶的,即使题目给出了 512MB 的宽松的空间要求,可是即便这样空间仍然太大了,在本地运行时使用了 900MB 的空间。有没有什么办法可以减少空间呢?注意到,因为现在每次取出树链的复杂度至少是 O(A)O(A)O(A) 的,如果两条链都很短的话,那么预处理他们的信息实际上是没有什么意义的。因此在这里我只预处理了前 λlogn\lambda\log nλlogn 层点分治的信息,而容易发现未处理的信息所在的联通块大小最多为 n1−λn^{1-\lambda}n1λ,可以考虑暴力计算。这里我们使 n1−λ=An^{1-\lambda}=An1λ=A,就可以在保证查询复杂度的情况下,减少预处理的时空复杂度了。于是这样做就可以成功A掉这道题了。

不过好像BZOJ排名第一的小哥另有高论的样子...很好奇到底是怎么做的呀...

代码

这里就放 C++11 的版本了。

#include <bits/stdc++.h>
#define debug(x) std::cerr << #x << " = " << x << std::endl

typedef unsigned long long ULL;

const int MAXN = 100031, MAXL = 10, MAXC = 30001, SIZE = 32, COUNT = MAXC / (SIZE * 64) + 1;

struct B {
    ULL a[SIZE + 1];

    void init() {
        std::fill(a, a + SIZE, 0LL);
    }

    void set1(int pos) {
        a[pos >> 6] |= 1ULL << (pos & 0x3F);
    }
} pool[12 * MAXN], *top = pool;

struct S {
    B *ind[COUNT];

    S() {}

    void init() {
        for (int i = 0; i < COUNT; i++) (ind[i] = top++)->init();
    }

    void clear() {
        for (int i = 0; i < COUNT; i++) ind[i]->init();
    }

    void set1(int pos) {
        ind[pos >> 11]->set1(pos & 0x7FF);
    }

    void mf(const S &p, int pos) {
        *this = p;
        (*(ind[pos >> 11] = top++) = *p.ind[pos >> 11]).set1(pos & 0x7FF);
    }
} z[MAXN][MAXL], t;

std::vector<int> g[MAXN];
int sz[MAXN], l[MAXN], f[MAXN], x[MAXN];
bool vis[MAXN];

int gs(int u, int p) {
    sz[u] = 1;
    for (auto &v : g[u]) if (!vis[v] && v != p) sz[u] += gs(v, u);
    return sz[u];
}

int ctrd(int u, int p, int s) {
    for (auto &v : g[u]) if (!vis[v] && v != p && sz[v] > s / 2) return ctrd(v, u, s);
    return u;
}

void gbs(int u, int p, int ly) {
    for (auto &v : g[u]) if (!vis[v] && v != p) z[v][ly].mf(z[u][ly], x[v]), gbs(v, u, ly);
}

int dcomp(int u, int ly) {
    int ctr = ctrd(u, 0, gs(u, 0));
    vis[ctr] = 1;
    l[ctr] = ly;
    if (ly < MAXL) z[ctr][ly].init(), z[ctr][ly].set1(x[ctr]), gbs(ctr, 0, ly);
    for (auto &v : g[ctr]) if (!vis[v]) f[dcomp(v, ly + 1)] = ctr;
    return ctr;
}

int lca(int u, int v) {
    for (; u != v; u = f[u]) if (l[u] < l[v]) std::swap(u, v);
    return l[u];
}

bool gc(int u, int p, int tg, int ly) {
    if (u == tg) return t.set1(x[u]), 1;
    for (auto &v : g[u]) if (l[v] >= ly && v != p && gc(v, u, tg, ly)) return t.set1(x[u]), 1;
    return 0;
}

void query(int u, int v) {
    int ly = lca(u, v);
    if (ly < MAXL) for (int i = 0; i < COUNT; i++) for (int j = 0; j < SIZE; j++) t.ind[i]->a[j] |= (z[u][ly].ind[i]->a[j] | z[v][ly].ind[i]->a[j]);
    else gc(u, 0, v, ly);
}

int mex(const S &t) {
    for (int i = 0; i < COUNT; i++) for (int j = 0; j < SIZE; j++) if (t.ind[i]->a[j] != 0xFFFFFFFFFFFFFFFFULL) return (i << 11) | (j << 6) | __builtin_ctzll(~t.ind[i]->a[j]);
}

int cnt(const S &t) {
    int ans = 0;
    for (int i = 0; i < COUNT; i++) for (int j = 0; j < SIZE; j++) ans += __builtin_popcountll(t.ind[i]->a[j]);
    return ans;
}

int n = 0, q = 0, k = 0;

int main() {
    scanf("%d%d%d", &n, &q, &k);
    for (int i = 1; i <= n; i++) scanf("%d", x + i);
    for (int i = 1, u = 0, v = 0; i < n; i++) scanf("%d%d", &u, &v), g[u].push_back(v), g[v].push_back(u);
    dcomp(1, 0);
    t.init();
    for (int u = 0, v = 0, m = 0, lans = 0; q--; ) {
        scanf("%d", &m);
        t.clear();
        while (m--) scanf("%d%d", &u, &v), query(u ^ lans, v ^ lans);
        int ct = cnt(t), mx = mex(t);
        printf("%d %d\n", ct, mx);
        lans = (k ? ct + mx : 0);
    }
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注