競技プログラミング日記

主に AtCoder の記事です

AtCoder Regular Contest 033C問題

ARC033C

問題概要

整数の集合 \(S\) に対して,以下のクエリを処理する.

  • \(S\) に整数 \(x\) を追加する.
  • \(S\) に整数 \(x\) を削除する.
  • \(S\) のうち \(x\) 番目に小さい数を答える.

\(S\) に追加する元 \(x\) は \(1\) 以上 \(2\cdot 10^{5} =: D\) 以下.

解法0: 平方分割

\(D\) を 2次元に分割する.\(D \leq quo \cdot mod\) となるように \(quo, mod\) をとる. \(S\) に元 \(x\) が入っているかを判定する配列 \(cnt\) を用意する. \(cnt[x] := cnt[quo][mod] = 1\) のとき入っていて,\(0\) のときに入ってないとする.

走査するとき,\([1,D]\) の 1次元でなく,\([0,quo] \times [0,mod)\) の 2次元になる.


大まかに言えば,以下のようになる.

追加クエリ

\(x = q * mod + r\) となる \(q \in [0,quo], r \in [0,mod)\) に対して \(cnt[q][r] = 1\) にすればよく,\(O(1)\).

削除クエリ

追加クエリと同様に, \(cnt[q][r] = 0\) にすればよく,\(O(1)\).

答えるクエリ

答えたいのは \(\sum_{y \leq x} cnt[y]\). 愚直に 1次元を走査すると \(O(D)\,つまり \(O(quo \cdot mod)\) 程度掛かってしまう. そこで,2次元に分けて走査するのが本解法.

各 \(q \in [0,quo]\) に対して,\(sum[q] := \sum_{r \in [0,mod)} cnt[q][r]\) を前計算しておけば,\(O(quo + mod)\) で求まる.


相加相乗平均の不等式から,\(mod\) と \(quo\) が同じくらいになるときに一番小さくなる. よって,\(mod\) と \(quo\) は \(D^{1/2}\) 位に取ると一番速い.
答えるクエリで前計算が必要になったので,追加クエリと削除クエリも修正する.

追加クエリ'

\(x = q * mod + r\) となる \(q \in [0,quo], r \in [0,mod)\) に対して \(cnt[q][r] = 1\) にする. さらに,

\(sum[q][r]++\)

をする. 合わせて \(O(1)\).

削除クエリ'

追加クエリと同様に,\(cnt[q][r] = 0\) と

\(\sum[q][r]--\)

を合わせて \(O(1)\).

答えるクエリ

各 \(q \in [0,quo]\) に対して,\(sum[q] := \sum_{r \in [0,mod)} cnt[q][r]\) を前計算しておけば,\(O(quo + mod)\) で \(\sum_{y \leq x} sum[y]\) が求まる.

使っている記号,マクロ等 "https://ecsmtlir.hatenablog.com/entry/2022/12/23/131925"

// devide square
class Kth{
  using CNT = int; // count
  using VAL = int; // value
  vector<vector<CNT>> cnt; // cnt[q][r] is the count of q*mod + r. cnt is the (quo)x(mod)-matrix.
  vector<CNT> sum_cnt; // sum[q] is the sum of cnt[q]. (quo -> CNT)
  int quo, mod;
public:
 
  Kth(){}
  Kth(int _d2){
    mod = sqrt(_d2);
    quo = _d2 / mod;
    cnt.resize(quo + 1);
    for(int i = 0; i < cnt.size(); i++){
      cnt[i].resize(mod);
    }
    sum_cnt.resize(quo + 1);

    assert(quo+1 >= _d2/mod);
  }

  void add(VAL x){
    assert(x >= 0);
    VAL q = x / mod;
    VAL r = x % mod;
    cnt[q][r]++;
    sum_cnt[q]++;
  }

  void remove(VAL x){
    assert(x >= 0);
    VAL q = x / mod;
    VAL r = x % mod;
    assert(cnt[q][r] > 0);
    assert(sum_cnt[q] > 0);

    cnt[q][r]--;
    sum_cnt[q]--;
  }

  // return the k-th element
  VAL query(int k){
    CNT t = 0;
    for(int q = 0; q < quo+1; q++){
      if(t + sum_cnt[q] < k) {
        t += sum_cnt[q];
        continue;
      }
     
      for(int r = 0; r < mod; r++){
        t += cnt[q][r];
        if(t >= k) return q*mod + r;
      }
    }

    return -1; // fail
  }

};

int main() {
  Kth k_th((ll)2e5 + 5);
  ll q; cin >> q;
  rep(qi,q){
    ll t,x ; cin >> t >> x;
    if(t == 1){
      k_th.add(x);
    }else{
      ll v = k_th.query(x);
      cout << v << endl;
      k_th.remove(v);
    }
  }

  return 0;
}
#endif