競技プログラミング日記

主に AtCoder の記事です

AtCoder Beginner Contest 159F

ABC159F

O(NS)は間に合うので,次の二つが考えられる. 今見ている index を \(i \in N\),部分和をとる区間を \([l,r)\) とする.

  • (i): \(i \in N,\ s \in S\) を全探索,
  • (ii): \(r \in N,\ s \in S\) を全探索,
  • (iii): 多項式を使った数え上げ.



解法(i)
まず,簡単のために通常の部分和問題を考えると, 今調べている \(i \in N\) に対して,使うか使わないかの2択を全探索. ここで,今までの和だけ決まっていれば,今までの選び方は区別する必要がなかったので, 状態をまとめることで高速化していた.
次に,今回の問題設定で考える. 遷移するために何が必要だろうか. \(i\)番目を使うかの2択は変わらない. 異なる点は, \(i\) 番目を使う場合は,区間 \([l,r)\) の 左側か,内側(一番右ではない),内側の一番右側か, の3択.これらを状態0,1,2とする.
実装では,区間 \([l,r)\) が\(a_{i}\) だけの場合があることにも注意. つまり,状態 \(0 \rightarrow 2\) の遷移のこと. 遷移は,全部で \(x \rightarrow y\) , \(y \geq x\) だけある.

解法(ii)
\(r \in N\) を全探索して固定し, \(l \in N\) with \(l < r\) を 高速に集計する. 各 \(l\) に対して dp 配列の init 分を追加してから遷移. 別々に遷移していた部分をまとめることで,遷移の回数を減らせる. 別の例としては, \(n\) 次多項式の値を求めるために, \(O(n^2)\)だったのを \(O(N)\) にするアルゴリズムと同じ.

解法(iii)
\(a_{i}\) に対する遷移は,\(p_{i} := (1+x^{a_{i}})\) を掛けることに対応する. 答えは, 多項式の\(s\) に対する係数. 求めるべき多項式は, \begin{align} p_{0}\\ + p_{1}(1 + p_{0}) \\ + p_{2}(1 + p_{1}(1 + p_{0})) \\ + \cdots \\ + p_{n-1}(1 + p_{n-2}(1 + \cdots)). \end{align} つまり, \(q_{i+1} := 1 + q_{i}, \ \ \ q_{0} = 0\) とおいたときの \(\sum_{i \in N} p_{i}q_{i+1}\).

使っている記号,マクロ等 "https://ecsmtlir.hatenablog.com/entry/2022/12/23/131925"

解法(i)

int main() {
  ll n, ms ;
  cin >> n >> ms;
  vll a(n); rep(i,n) { cin >> a[i]; }

  vector<vector<mint>> dp(ms+1, vector<mint>(3));
  dp[0][0] = 1;
  rep(i,n)  {
    vector<vector<mint>> old(ms+1, vector<mint>(3));
    swap(old, dp);

    rep(s, ms+1) rep(x,3){
      srep(y,x,3){
        ll t = s + a[i];
        dp[s][y] += old[s][x];
        if(t <= ms && y > 0 && x < 2) dp[t][y] += old[s][x];
      }
    }
  }
  cout << dp[ms][2].val() << endl;

  return 0;
}

解法(ii)

int main() {
  ll n, s ;
  cin >> n >> s;
  vll a(n); rep(i,n) { cin >> a[i]; }

 
  vector<mint> dp(s+1);
  mint ans;
  // ans += dp[s];
  rep(i,n) { // i : r-1 にあたる
    dp[0] += 1;

    drep(x,s+1){
      ll t = x + a[i];
      if(t <= s) dp[t] += dp[x];
    }

    ans += dp[s];
  }
  cout << ans.val() << endl;
 
  return 0;
}

解法(iii)

int main() {
  ll n, s ;
  cin >> n >> s;
  vll a(n); rep(i,n) { cin >> a[i]; }

  vector<mint> q(s+1);
  mint ans;
  rep(i,n){
    q[0] ++;
    {
      vector<mint> t = q;
      rep(j,s+1){
        if(j+a[i] <= s) t[j+a[i]] += q[j];
      }
      swap(t,q);
    }
    ans += q[s];
  }
  cout << ans.val() << endl;

  return 0;
}