悪い子のためのビットベクターによる定数倍最適化

※UTPC 2009 Problem F(天使の階段)ネタバレあり。


この日記は、Competitive Programming Advent Calendarのために書かれました。

この記事は、計算量的にTLEになる想定誤解法を、無理やり最適化して通してしまおうという悪い子のみんなのための記事です。

題材は Angel Stairs (UTPC 2009 Problem F: 天使の階段) です。

問題概要

  • n段の階段と、m音からなる曲があって、階段を踏むごとに音が鳴る。
  • 階段のi段目には音T[i]が書かれている
  • 音はCからBまでの12種類。以下ここでは0〜11と表記し、+-した場合も12で割った余りを取る(11+1→0)とする。
  • 1段次、2段次、1段前に行くことができて、行った先の段がi段目だとすると、それぞれ T[i], T[i]+1, T[i]-1 の音が鳴る
  • 指定された曲をぴったり鳴らしてスタートからゴールまで行くことができるか

想定解法・想定誤解法概要

  • 問題の正しい解法の解説はこちら(pptx)UTPC公式サイトで「過去の東京大学プログラミングコンテストのページは,現在一時的にアクセスできなくなっています.」となっているので、Web Archiveから。
  • 想定解法は O(n+m)。詳しくはスライドを見てね。
  • 想定誤解法はたとえば「これは、DP_raw[t][i] = t音目まで鳴らした時点で、i段目にいることができるかどうか、というDPで解ける」というO(nm)。n = m ≦ 50,000 なので、nm ≦ 2,500,000,000 となりちょっとこのままでは大きすぎる気がします。
  • というか、実際に想定誤解法を無理やり通して怒られました(てへぺろ

さて、実際にやってみましょう。

定数倍最適化の方法としては、上のDP_rawの各要素を1つ1つintとかにするのではなく、32ビットとか64ビットとかまとめてint(32-bit整数)やlong long(64-bit整数)に放り込んでしまい、途中演算も全てビット演算で一括で(各ビットごとに処理しないで)行う、ということをします。

ソースコードはこんな感じになります。

#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<map>
using namespace std;

unsigned long long T[12][50002/64+1];
int S[50002];
unsigned long long DP[2][50002/64+10];
int main()
{
  map<string,int> dic;
  {
    int i = 0;
    dic["C"]  = i ++; dic["C#"] = i ++; dic["D"]  = i ++;
    dic["D#"] = i ++; dic["E"]  = i ++; dic["F"]  = i ++;
    dic["F#"] = i ++; dic["G"]  = i ++; dic["G#"] = i ++;
    dic["A"]  = i ++; dic["A#"] = i ++; dic["B"]  = i ++;
  }
  int TT;
  cin >> TT;
  for( int CC = 0; CC < TT; CC ++ ){
    int n, m;
    cin >> n >> m;

    // 階段の音の入力をパーズ
    // i段目に音jが書かれている = T[j][i/64+1]の下から(i%64)ビット目が1
    // i/64+1の+1はメモリ範囲外アクセスを防ぐ番兵用
    memset( T, 0x00, sizeof(T) );
    for( int i = 1; i <= n; i ++ ){
      string str; cin >> str; int j = dic[str];
      T[j][i/64+1] |= 1LL<<(i%64);
    }

    // 曲の音の入力をパーズ
    for( int i = 0; i < m; i ++ ){
      string str; cin >> str; S[i] = dic[str];
    }

    // DP本体
    int nmax = (n+1)/64+1;
    // 初期化: 0段目にいる
    memset( DP, 0x00, sizeof(DP) ); // (a)
    DP[0][0+1] = 1; // (a)
    for( int t = 0; t < m; t ++ ){ // メインループ
      int tnext = (t&1)^1;
      memset( DP[tnext], 0x00, sizeof(long long) * (nmax+2) );
      for( int i = 1; i <= nmax; i ++ ){
        unsigned long long v = DP[t&1][i];
        // 1段次へ; 行った先の段には S[t] が書かれている必要がある
        DP[tnext][i]   |= (v << 1)  & T[S[t]][i];
        DP[tnext][i+1] |= (v >> 63) & T[S[t]][i+1];
        // 2段次へ; 行った先の段には S[t]-1 が書かれている必要がある
        DP[tnext][i]   |= (v << 2)  & T[(S[t]+11)%12][i];
        DP[tnext][i+1] |= (v >> 62) & T[(S[t]+11)%12][i+1];
        // 1段前へ; 行った先の段には S[t]+1 が書かれている必要がある
        DP[tnext][i]   |= (v >> 1)  & T[(S[t]+1)%12][i];
        DP[tnext][i-1] |= (v << 63) & T[(S[t]+1)%12][i-1];
      }
    }

    // m音目を鳴らした後でn段目かn-1段目に居ることができればYes
    printf( "%s\n", ( (DP[m&1][n/64+1] & (1LL<<(n&63))) || 
      (DP[m&1][(n-1)/64+1] & (1LL<<((n-1)&63)))) ? "Yes" : "No" );
  }
}

やり方解説

まず、各ビットをつめて、この場合はunsigned long long(64-bit整数)に詰めます。

  • DP_raw[t][i] は、コード内では DP[t&1][i/64+1]の(i%64)ビット目( DP[t&1][i/64+1]&(1LL<<(i%64)) )に格納されています。
  • t&1になっているのはビットベクターとは関係ないメモリ領域削減のためです。そのままlong long DP[50002][50002/64]とかとすると300MBくらい食うのでMemory Limit Exceededになります。
  • +1とついてるのは番兵の分だけずらしただけです。
  • unsigned long long になっているのは、今回は右シフトした時に上位には0ビットが入ってほしいのでこうなっています。問題によって、右シフト時に上位に1が入ってほしい時にはsigned long longにすると良いでしょう。

次に、途中計算(少なくとも一番重い部分)を、各ビットごとにばらさずに、64ビット一括のビット演算・シフト演算等だけでできるように設計・実装します。

  • まず、(a)と書かれた部分でDP_rawを初期化します。
    • t=0の時点では、0段目にいるので DP_raw[0][0] = 1で他は0です。
    • 初期化は別に各ビットごとにアクセスしても問題ない場合が多いです。
  • 次に、t音目を鳴らす時にDP_rawをupdateしていきます。ここがメインなので、この処理の中では各ビットへの個別アクセスを極力減らします。
    • 今回の場合、
  // 1段次へ; 行った先の段には S[t] が書かれている必要がある
  DP[tnext][i]   |= (v << 1)  & T[S[t]][i];
  DP[tnext][i+1] |= (v >> 63) & T[S[t]][i+1];

という部分では、

    • v = DP[t&1][i] には、DP_raw[t][i*64-64〜i*64-1]、つまり t音目を鳴らす前に(i-1)*64〜i*64-1段目にいることができるか、を示すビットが64個入っています。
    • こんな感じ: (上位)(DP_raw[t][i*64-1])(DP_raw[t][i*64-2])...(DP_raw[t][i*64-63])(DP_raw[t][i*64-64])(下位)
    • 1段次に行く、ということで左に1ビットシフトします(v << 1)。最下位ビットには0が入るのでOKです。
    • その行った先の段には S[t] という音が書かれている必要があります。この条件判定のために、あらかじめ配列Tに、「段iに音jが書かれていたら、T[j][i/64+1]の(i%64)ビット目が1、それ以外は0」というデータ(マスク)を入れておきます。そうすると、個別ビットについてif文を書かずとも、T[S[t]][i] とビットごとAND(&)を取るだけで、64ビット分一括で判定できます。
    • その結果をビットごとOR(|)で書き込みます。
    • これだけだと、最初のvの最上位ビットがどこかに行ってしまう(繰り上がりした)ので、その分は特別に処理する必要があります。v << 1 した時に 64ビット目以上にいくはずだったビットを、DP[tnext][i+1] (64ビット目以上なので次の要素になる = i+1)に書き込む感じです。それが下の行です。
    • 他のケースについても同様に処理します。
    • このようにすると、メインループの中では各ビットへの個別アクセスを繰り返していないことに注意。これで、演算量がおおざっぱにいって最小で1/64になります。
  • 最後にビット演算でDP_raw[m][n-1]とDP_raw[m][n]をアクセスして終了。

さて、これだけで2.4秒くらいで通ります(AOJでのTime Limitは8秒)!! やったね!

更に、もうちょっと小細工すると0.76秒で通りました。

  • ループ範囲の限定; 時間が足りなくて明らかに0段目から到達できない、あるいはn+1段目まで届かない部分は処理しなくていいよね
  • vが0だったら処理しなくていいよね

ソースの差分はこんな感じ。

      int imin = max(1, (n+2-(m-t+1)*2)/64), imax = min(nmax, 2*t/64+1);
      for( int i = imin; i <= imax; i ++ ){
        unsigned long long v = DP[t&1][i];
        if( v ){
          // 1段次へ; 行った先の段には S[t] が書かれている必要がある
          DP[tnext][i]   |= (v << 1)  & T[S[t]][i];
          DP[tnext][i+1] |= (v >> 63) & T[S[t]][i+1];
          // 2段次へ; 行った先の段には S[t]-1 が書かれている必要がある
          DP[tnext][i]   |= (v << 2)  & T[(S[t]+11)%12][i];
          DP[tnext][i+1] |= (v >> 62) & T[(S[t]+11)%12][i+1];
          // 1段前へ; 行った先の段には S[t]+1 が書かれている必要がある
          DP[tnext][i]   |= (v >> 1)  & T[(S[t]+1)%12][i];
          DP[tnext][i-1] |= (v << 63) & T[(S[t]+1)%12][i-1];
        }
      }

実行時間ランキングで4位の人の4倍遅いくらいならバレないんじゃねwww

ビットベクターは、うまく使うと64倍とか高速化できるので、侮れないパワーがあります。
やりすぎるとそんなの無理に決まってるじゃないかという定数化最適化地獄に嵌はまるので、計算量落ちないか考える方が正攻法なことが多いですが、どうしても思いつかない時とか3分で書けるので試してみるとかにはいいかもしれません。

※おまけ: ビットの格納順番を変えると、シフト演算を再内ループから排除できるのですが、境界条件処理が面倒だったり、ループ範囲を狭める最適化ができなかったり、素でなんか遅かったりしました。ちなみに、実際に提出したのはこっちのバージョンだったと思います。

#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<map>
using namespace std;

unsigned long long T[12][4096];
int S[50002];
unsigned long long M[2][4096];
int main()
{
  map<string,int> dic;
  {
    int i = 0;
    dic["C"]  = i ++; dic["C#"] = i ++; dic["D"]  = i ++;
    dic["D#"] = i ++; dic["E"]  = i ++; dic["F"]  = i ++;
    dic["F#"] = i ++; dic["G"]  = i ++; dic["G#"] = i ++;
    dic["A"]  = i ++; dic["A#"] = i ++; dic["B"]  = i ++;
  }
  int TT;
  cin >> TT;
  for( int CC = 0; CC < TT; CC ++ ){
    int n, m;
    cin >> n >> m;
    memset( T, 0x00, sizeof(T) );

    int MASK = 32, SHIFT = 5;
    while( MASK * 64 < n+3 ){
      MASK *= 2;
      SHIFT ++;
    }
    -- MASK;

    for( int i = 1; i <= n; i ++ ){
      string str; cin >> str; int j = dic[str];
      T[j][i&MASK] |= (1LL<<(i>>SHIFT));
    }
    for( int i = 0; i < m; i ++ ){
      string str; cin >> str; S[i] = dic[str];
    }
    memset( M, 0x00, sizeof(M) );
    M[0][0] = 1;
    for( int t = 0; t < m; t ++ ){
      memset( M[(t&1)^1], 0x00, sizeof(M[0][0]) * (MASK+2) );
      {
        int i = 0;
        M[(t&1)^1][i+1] |= M[t&1][i]   & T[S[t]][i+1];
        M[(t&1)^1][i+2] |= M[t&1][i]   & T[(S[t]+11)%12][i+2];
        M[(t&1)^1][MASK] |= (M[t&1][i]>>1)  & T[(S[t]+1)%12][MASK];
      }
      for( int i = 1; i <= MASK; i ++ ){
        unsigned long long v = M[t&1][i];
        M[(t&1)^1][i+1] |= v & T[S[t]][i+1];
        M[(t&1)^1][i+2] |= v & T[(S[t]+11)%12][i+2];
        M[(t&1)^1][i-1] |= v & T[(S[t]+1)%12][i-1];
      }
      M[(t&1)^1][0] |= (M[t&1][MASK]<<1)     & T[S[t]][0];
      M[(t&1)^1][1] |= (M[t&1][MASK]<<1)     & T[(S[t]+11)%12][1];
      M[(t&1)^1][0] |= (M[t&1][MASK-1]<<1)   & T[(S[t]+11)%12][0];
    }
    printf( "%s\n", ( (M[m&1][(n)&MASK] & (1LL<<(n>>SHIFT))) || 
      (M[m&1][(n-1)&MASK] & (1LL<<((n-1)>>SHIFT)))) ? "Yes" : "No" );
  }
}