くれなゐの雑記

例を上げて 自分で手を動かして学習できる入門記事を多めに書いています

AGC002 D. Stamp Rally

問題

agc002.contest.atcoder.jp

やること

  • 二分木探索で, 最大のスコアを探索する
  • midまでの辺のfromとtoのノードを, union_findでuniteする
  • union_findでは同じグループの数をuniteするときに一緒に計算しておく
  • x_i, y_i を含むグループの数と, z_iを比較し, high, lowを更新する

これを Q 回処理すると, 間に合わないので並列に行う.

以降実装とコツ

  • 多分再帰で実装したほうが良い.
    • f(depth, low, high, その範囲で計算する人たち) みたいな感じで実装する
    • depthがなぜ必要なのかは後で
  • x_iy_i が同じグループに属している時, 片方のグループのみで数える
    • 重複は数えないため
    • カウントして, 同じグループに属していたら2で割るみたいな処理をすれば良い
  • 0人のグループになれば, 計算をしないようにする(1敗)
  • union_findは, あらかじめlog2(N)+1個作っておいて, 再利用する
    • コンストラクタもO(N)かかるためかかる時間がバカにならない
    • 再帰をかける順番を工夫して, midが小さい方から探索するようにする
    • それぞれの深さで使用するunion_findを分ける. 深さはlog2(N)+1くらいなので, それをあらかじめ用意しておく

ソースコード

今回のソースコードは特にやばい

#include <iostream>
#include <queue>
#include <map>
#include <list>
#include <vector>
#include <string>
#include <stack>
#include <limits>
#include <cassert>
#include <fstream>
#include <cstring>
#include <cmath>
#include <bitset>
#include <iomanip>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <ciso646>

using namespace std;

#define FOR(i,a,b) for (int i=(a);i<(b);i++)
#define RFOR(i,a,b) for (int i=(b)-1;i>=(a);i--)
#define REP(i,n) for (int i=0;i<(n);i++)
#define RREP(i,n) for (int i=(n)-1;i>=0;i--)

#define inf 0x3f3f3f3f
#define INF INT_MAX/3
#define PB push_back
#define MP make_pair
#define ALL(a) (a).begin(),(a).end()
#define SET(a,c) memset(a,c,sizeof a)
#define CLR(a) memset(a,0,sizeof a)
#define pii pair<int,int>
#define pcc pair<char,char>
#define pic pair<int,char>
#define pci pair<char,int>
#define VS vector<string>
#define VI vector<int>
#define DEBUG(x) cout<<#x<<": "<<x<<endl
#define MIN(a,b) (a>b?b:a)
#define MAX(a,b) (a>b?a:b)
#define pi 2*acos(0.0)
#define INFILE() freopen("in0.txt","r",stdin)
#define OUTFILE()freopen("out0.txt","w",stdout)
#define in scanf
#define out printf
#define ll long long
#define ull unsigned long long
#define eps 1e-14
#define FST first
#define SEC second

class union_find {
private:
    vector<int> par;
    vector<int> rank;
    vector<int> count;
public:
    union_find(int N):par(N), rank(N, 0), count(N, 1) {
        for (int i = 0; i < N; ++i) {
            par[i] = i;
        }
    }

    int find(int x) {
        if (par[x] == x) {
            return x;
        }
        else {
            return par[x] = find(par[x]);
        }
    }

    void unite(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return;

        if (rank[x] < rank[y]) {
            count[y] += count[x];
            par[x] = y;
        }
        else {
            count[x] += count[y];
            par[y] = x;
            if (rank[x] == rank[y]) rank[x]++;
        }
    }

    bool same(int x, int y) {
        return find(x) == find(y);
    }

    int getCount(int x) {
        return count[find(x)];
    }

    void clean() {
        par = vector<int>(par.size());
    }
};

struct Edge {
    int from;
    int to;
    Edge(int f, int t) :from(f), to(t) {}
};

vector<int> res;
vector<Edge> e;
int memo_f[100001] = {};
int N, M, Q;
vector<tuple<int, int, int> > q;

vector<pair<int, union_find> > uf;

void f(int depth, int low, int high, vector<int> people) {
    if (depth == uf.size()) uf.push_back({ 0,union_find(N) });
    if (depth < uf.size()) assert(1);
        
    if (high - low <= 1) {
        for (auto &p : people) {
            res[p] = high;
        }
        return;
    }
    int mid = (high + low) / 2;

    FOR(i, uf[depth].first, mid) uf[depth].second.unite(e[i].from, e[i].to);

    vector<int> OKPeople;
    vector<int> NGPeople;

    for (auto &p : people) {
        int x, y, z; tie(x, y, z) = q[p];
        int div = (uf[depth].second.same(x, y) + 1);
        if ((uf[depth].second.getCount(x) + uf[depth].second.getCount(y))/div >= z) OKPeople.push_back(p);
        else NGPeople.push_back(p);
    }
    uf[depth].first = mid;
    if(not OKPeople.empty()) f(depth+1, low, mid, OKPeople);
    if(not NGPeople.empty()) f(depth+1, mid, high, NGPeople);
}


int main() {
    memset(memo_f, -1, sizeof(memo_f));
    cin >> N >> M;
    REP(i, M) {
        int a, b; cin >> a >> b;
        --a; --b;
        e.emplace_back(a, b);
    }
    int Q; cin >> Q;
    res.resize(Q);
    REP(i, Q) {
        int x, y, z; cin >> x >> y >> z;
        --x, --y;
        q.push_back(make_tuple(x, y, z));
    }

    vector<int> people;
    REP(i, Q) people.push_back(i);
    f(0, 0, M, people);

    for (auto &a : res) {
        cout << a << endl;
    }

    return 0;
}

感想

最近はあまり時間がないので解けた問題+1って感じだけどこの問題を解いた
永続配列を使ってunion_find解もあるらしく, 研究が辛くなったら実装しようと思う