ARC033C
問題概要
整数の集合 \(S\) に対して,以下のクエリを処理する.
- \(S\) に整数 \(x\) を追加する.
- \(S\) に整数 \(x\) を削除する.
- \(S\) のうち \(x\) 番目に小さい数を答える.
\(S\) に追加する元 \(x\) は \(1\) 以上 \(2\cdot 10^{5} =: D\) 以下.
解法0: 平方分割
\(D\) を 2次元に分割する.\(D \leq quo \cdot mod\) となるように \(quo, mod\) をとる. \(S\) に元 \(x\) が入っているかを判定する配列 \(cnt\) を用意する. \(cnt[x] := cnt[quo][mod] = 1\) のとき入っていて,\(0\) のときに入ってないとする.
走査するとき,\([1,D]\) の 1次元でなく,\([0,quo] \times [0,mod)\) の 2次元になる.
大まかに言えば,以下のようになる.
追加クエリ
\(x = q * mod + r\) となる \(q \in [0,quo], r \in [0,mod)\) に対して \(cnt[q][r] = 1\) にすればよく,\(O(1)\).
削除クエリ
追加クエリと同様に, \(cnt[q][r] = 0\) にすればよく,\(O(1)\).
答えるクエリ
答えたいのは \(\sum_{y \leq x} cnt[y]\). 愚直に 1次元を走査すると \(O(D)\,つまり \(O(quo \cdot mod)\) 程度掛かってしまう. そこで,2次元に分けて走査するのが本解法.
各 \(q \in [0,quo]\) に対して,\(sum[q] := \sum_{r \in [0,mod)} cnt[q][r]\) を前計算しておけば,\(O(quo + mod)\) で求まる.
相加相乗平均の不等式から,\(mod\) と \(quo\) が同じくらいになるときに一番小さくなる. よって,\(mod\) と \(quo\) は \(D^{1/2}\) 位に取ると一番速い.
答えるクエリで前計算が必要になったので,追加クエリと削除クエリも修正する.
追加クエリ'
\(x = q * mod + r\) となる \(q \in [0,quo], r \in [0,mod)\) に対して \(cnt[q][r] = 1\) にする. さらに,
\(sum[q][r]++\)
をする. 合わせて \(O(1)\).
削除クエリ'
追加クエリと同様に,\(cnt[q][r] = 0\) と
\(\sum[q][r]--\)
を合わせて \(O(1)\).
答えるクエリ
各 \(q \in [0,quo]\) に対して,\(sum[q] := \sum_{r \in [0,mod)} cnt[q][r]\) を前計算しておけば,\(O(quo + mod)\) で \(\sum_{y \leq x} sum[y]\) が求まる.
解法1: seg tree と binary search
使っている記号,マクロ等 "https://ecsmtlir.hatenablog.com/entry/2022/12/23/131925"
// devide square
class Kth{
using CNT = int; // count
using VAL = int; // value
vector<vector<CNT>> cnt; // cnt[q][r] is the count of q*mod + r. cnt is the (quo)x(mod)-matrix.
vector<CNT> sum_cnt; // sum[q] is the sum of cnt[q]. (quo -> CNT)
int quo, mod;
public:
Kth(){}
Kth(int _d2){
mod = sqrt(_d2);
quo = _d2 / mod;
cnt.resize(quo + 1);
for(int i = 0; i < cnt.size(); i++){
cnt[i].resize(mod);
}
sum_cnt.resize(quo + 1);
assert(quo+1 >= _d2/mod);
}
void add(VAL x){
assert(x >= 0);
VAL q = x / mod;
VAL r = x % mod;
cnt[q][r]++;
sum_cnt[q]++;
}
void remove(VAL x){
assert(x >= 0);
VAL q = x / mod;
VAL r = x % mod;
assert(cnt[q][r] > 0);
assert(sum_cnt[q] > 0);
cnt[q][r]--;
sum_cnt[q]--;
}
// return the k-th element
VAL query(int k){
CNT t = 0;
for(int q = 0; q < quo+1; q++){
if(t + sum_cnt[q] < k) {
t += sum_cnt[q];
continue;
}
for(int r = 0; r < mod; r++){
t += cnt[q][r];
if(t >= k) return q*mod + r;
}
}
return -1; // fail
}
};
int main() {
Kth k_th((ll)2e5 + 5);
ll q; cin >> q;
rep(qi,q){
ll t,x ; cin >> t >> x;
if(t == 1){
k_th.add(x);
}else{
ll v = k_th.query(x);
cout << v << endl;
k_th.remove(v);
}
}
return 0;
}
#endif