冷滟泽的个人博客冷滟泽的个人博客

CF840E. In a Trap

这题也是一道指令集优化的练手 ao 题(

因为点数和点权都是 5\cdot 10^4 的,可以使用 256 位指令集存储 short 类型做到常数除以 16。对每个节点预处理一个指令集 w_u 存储 u 的前 16 个父亲的权值。每次询问时可以从 v 开始往根的方向跳,每次跳 16 个节点,同时维护一个指令集 cur 存储当前节点 p 的前 16 个父亲到节点 v 的距离。将 cur 初始赋为 0,\cdots,16,每跳一次全局 +16 就可以了。再维护一个指令集 ret 表示答案,每次将它与 cur\oplus w_p\max 即可。最后剩下的不超过 16 个节点暴力处理。

这样的复杂度是 O\left(\frac{nq}{\omega}\right),w=16,不一定能过这道题。但发现询问比点数多,我们可以考虑离线,将 v 相同的询问的答案在一次暴力跳父亲的过程中求出来。这样复杂度就变成了 O\left(\frac{n^2}{\omega}+q\omega\right),就可以稳过了。实测最慢的点在 1 秒内跑过。

也许你想问,这题正解和指令集的做法差不多(还更优美),正式比赛还不能用指令集,那它还有什么用呢?

因为我们可以魔改这道题,比如改成求 a_i\oplus sum(i,v) 的最大值, sum(i,v) 表示 iv 的权值和模 m(m<65536)。那么正解就不工作了,此时指令集就可以暴力碾标算了。

最后放上代码:

#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
#include <immintrin.h>
#include <emmintrin.h>
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <vector>
using namespace std;
typedef unsigned short ust;
typedef __m256i oak;
const int MAXN=55000;
const int MAXQ=160000;
inline char nc()
{
    static char buf[100000], *p1=buf, *p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf, 1, 100000, stdin), p1==p2)?EOF:*p1++;
}
inline int read()
{
    int x=0; char ch=0;
    while (!isdigit(ch)) ch=nc();
    while (isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48), ch=nc();
    return x;
}
ust a[MAXN];
vector<int> g[MAXN];
int fa[MAXN], fa16[MAXN];
int dep[MAXN];
oak w[MAXN];
struct Query
{
    int u, id;
    Query(int a, int b): u(a), id(b) {}
    bool operator < (const Query& rhs) const
    {
        return dep[u]>dep[rhs.u];
    }
};
vector<Query> h[MAXN];
ust ans[MAXQ];
void dfs(int u, int f, int d)
{
    static int t[16];
    fa[u]=f, dep[u]=d, t[0]=u;
    for (int i=1; i<16; i++) t[i]=fa[t[i-1]];
    fa16[u]=fa[t[15]];
    w[u]=_mm256_set_epi16(a[t[15]], a[t[14]], a[t[13]], a[t[12]], a[t[11]], a[t[10]],
    a[t[9]], a[t[8]], a[t[7]], a[t[6]], a[t[5]], a[t[4]], a[t[3]], a[t[2]], a[t[1]], a[t[0]]);
    for (int v:g[u]) if (v!=f) dfs(v, u, d+1);
}
void solve(int v)
{
    oak cur=_mm256_set_epi16(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
    oak delta=_mm256_set1_epi16(16), ret=_mm256_setzero_si256();
    int k=0, t=v;
    while (v&&k<h[t].size())
    {
        if (dep[fa16[v]]<dep[h[t][k].u])
        {
            ust mx=0, *p=(ust*)&ret;
            for (int i=0; i<16; i++) mx=max(mx, p[i]);
            oak tmp=_mm256_xor_si256(cur, w[v]); p=(ust*)&tmp;
            while (k<h[t].size()&&dep[fa16[v]]<dep[h[t][k].u])
            {
                for (int w=v, i=0; dep[w]>=dep[h[t][k].u]; w=fa[w], i++)
                    mx=max(mx, p[i]);
                ans[h[t][k++].id]=mx;
            }
        }
        ret=_mm256_max_epu16(ret, _mm256_xor_si256(cur, w[v]));
        cur=_mm256_add_epi16(cur, delta);
        v=fa16[v];
    }
}
int main()
{
//  freopen("E.in", "r", stdin);
//  freopen("E.out", "w", stdout);
    int n=read(), q=read();
    for (int i=1; i<=n; i++) a[i]=read();
    for (int i=1; i<n; i++)
    {
        int u=read(), v=read();
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0, 1);
    for (int i=1; i<=q; i++)
    {
        int u=read(), v=read();
        h[v].push_back(Query(u, i));
    }
    for (int i=1; i<=n; i++)
        if (!h[i].empty())
        {
            sort(h[i].begin(), h[i].end());
            solve(i);
        }
    for (int i=1; i<=q; i++) printf("%d\n", ans[i]);
    return 0;
}

解释一下代码中用到的指令集操作:

  • _mm256_set_epi16() 初始化一个 256 位指令集,将其赋为 16 个指定的 short 类型变量。
  • _mm256_set1_epi16() 同上,但是将 16 个位置都赋为同一值。
  • _mm256_setzero_si256() 返回全 0 指令集。
  • _mm256_add_epi16() 返回两个 short 指令集对应位相加之和。
  • _mm256_xor_si256() 返回两个指令集的按位异或值。
  • _mm256_max_epu16() 将两个 unsigned short 指令集对应位取 \max。注意是 epu 而不是 epi,这体现了无符号。

参考 https://www.luogu.com.cn/blog/ouuan/avx-optimizehttps://wenku.baidu.com/view/33776d1c59eef8c75fbfb310.html

未经允许不得转载:冷滟泽的个人博客 » CF840E. In a Trap

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址