競技プログラミング日記

主に AtCoder の記事です

AtCoder Beginner Contest 306F

ABC306F

\(\def\cnt #1{\lvert {#1} \rvert}\)

解法

\(A,B\) を集合,\(A \cap B = \emptyset\) とする. 集合 \(C\) に対して,\(x\) 以下の元全体を \(C^{\leq x}\) とおく. \begin{align} f_{A,B} &= \sum_{x \in A} \cnt{A \cup B}^{\leq x} \\ &= \sum_{x \in A} \cnt{A^{\leq x}} + \cnt{B^{\leq x}} \end{align} であるから, \begin{align} ans &= \sum_{l,r \in N \\ l < r} \sum_{x \in S_{l}} \cnt{S_{l}^{\leq x}} + \cnt{S_{r}^{\leq x}}. \end{align} Intersection が空より \(\cnt{{\cup_{r > l} S_{r}}^{\leq x}} = \sum_{r>l} \cnt{{S_{r}}^{\leq x}}\) であるから, \begin{align} \sum_{l,r \in N \\ l < r} \sum_{x \in S_{l}} \cnt{S_{r}^{\leq x}} &= \sum_{l \in N} \sum_{x \in S_{l}} \cnt{\cup_{r \in N \\ r > r} S_{r}^{\leq x}}. \end{align} また, \begin{align} \sum_{l,r \in N \\ l < r} \sum_{x \in S_{l}} \cnt{S_{l}^{\leq x}} &= {}_{N}C_{2} \cdot {}_{M+1}C_{2}, \end{align} であることも用いると,\(ans\) が求まる.

実装

\(x\) 以下の元の個数は,segtree で数えることができる. ただし,\(\cup_{r > l} {S_{r}}^{\leq x}\) において \(S_{l}\) 内の \(x\) は数えない. よって,\(S_{l}\) の元を大きい順に調べる. また,\(r > l\) を調べるため,\(l\) も大きい順に調べる.
最後に,\(S_{i}\) の元は大小だけが重要なので座標圧縮しておく.

使っている記号,マクロ等 "https://ecsmtlir.hatenablog.com/entry/2022/12/23/131925"

template<typename P>
struct Points {
  vector<P> ps;
  bool end_init;

  Points() {}
  void add(P x) {
    ps.emplace_back(x);
  }
  void init() {
    sort(ps.begin(), ps.end());
    ps.erase(unique(ps.begin(), ps.end()), ps.end());
    end_init = true;
  }
  P operator[](int i) {
    if(!end_init) init();
    return ps[i];
  }
  int operator()(const P& x) {
    if(!end_init) init();
    return lower_bound(ps.begin(),ps.end(), x)-ps.begin();
  }
  int size() const { return ps.size(); }
};

ll op(ll a, ll b) { return a+b; }
ll e() { return 0LL; }

int main() {
  ll n,m;
  cin >> n >> m;
  vvll a(n, vll(m));
  cinvv(a);

  Points<ll> po;
  rep(i,n) rep(j,m) po.add(a[i][j]);
  po.init();

  ll ans = 0;
  segtree<ll,op,e> st(po.size());
  drep(l,n) {
    sort(all(a[l]), greater<ll>()); // 無いと WA
    for(auto x: a[l]) {
      x = po(x);
      ans += st.prod(0, x+1);
      st.set(x, st.get(x)+1);
    }
  }
  ans += nC2(n)*nC2(m+1);
  cout << ans << endl;

  return 0;
}