TopCoder SRM 641 Hard
2017-02-15 | [Competitive programming]概要
n (1 <= n <= 20) ビットのビット列に対し,「一様ランダムに 1 ビット選んで反転する」を繰り返す. ビット列にはカーソルがあり,ビット反転を行うためには反転したいビットまでカーソルを動かす必要がある. カーソルの移動には,位置の差の絶対値に等しいコストがかかる. ビット列のすべてのビットが同じになったら操作は終了である.
ビット列 (n は固定),最初のカーソルの位置が与えられたとき,操作が終了するまでにかかるコストの期待値を求めるクエリを 1000 個処理せよ.
解法
「位置 i にいる状態から位置 j に移動してビット j を反転する」が行われる回数の期待値を求められればよい(各 i, j に対し,この期待値を重み |i - j| を掛けて足し合わせる).
2^n 通りのビット,現在のカーソル位置 n 通り全部に対して遷移を考えて,連立方程式を解けばこれは求められるが,明らかに計算量が大きすぎる. ここで,この期待値を求めるためには,「位置 i のビットの値」「位置 j のビットの値」「カーソル位置が i か?」「位置 i, j 以外のビットで 1 が立っている個数」にだけ注目すれば十分である. この場合の状態数は 8(n+1) になり,遷移をすべて考慮して連立方程式を解いても (1 回だけなら) 十分間に合う.
さらに,この連立方程式を 1 回解けば,複数の i, j や複数の初期ビット列に対しても結果を使いまわすことができる. 最初の連立方程式の計算は O(n^3) (ただし,定数 512 がある) であり,各クエリの処理は O(n^2) なので,十分間に合う.
コード
int sz;
double matr[200][200];
double v[200];
void solve()
{
for (int i = 0; i < sz; ++i) {
int piv = i;
for (int j = i + 1; j < sz; ++j) {
if (fabs(matr[piv][i]) < fabs(matr[j][i])) piv = j;
}
if (piv != i) {
for (int j = 0; j < sz; ++j) swap(matr[i][j], matr[piv][j]);
swap(v[i], v[piv]);
}
for (int j = i + 1; j < sz; ++j) matr[i][j] /= matr[i][i];
v[i] /= matr[i][i];
matr[i][i] = 1.0;
for (int j = 0; j < sz; ++j) if (j != i) {
for (int k = i + 1; k < sz; ++k) {
matr[j][k] -= matr[i][k] * matr[j][i];
}
v[j] -= v[i] * matr[j][i];
matr[j][i] = 0;
}
}
}
class BitToggler {
public:
vector <double> expectation(int n, vector <string> bits, vector <int> pos) {
auto idx = [&](int ist, int jst, int isi, int cnt) {
return cnt * 8 + ist * 4 + jst * 2 + isi;
};
sz = 8 * (n + 1);
for (int s = 0; s < sz; ++s) {
for (int j = 0; j < sz; ++j) matr[s][j] = 0;
v[s] = 0.0;
int ist = (s & 4) >> 2;
int jst = (s & 2) >> 1;
int isi = (s & 1);
int cnt = s >> 3;
matr[s][s] = 1.0;
if (cnt == 0 && ist == 0 && jst == 0) continue;
if (cnt == n - 2 && ist == 1 && jst == 1) continue;
// to i
int t = idx(ist ^ 1, jst, 1, cnt);
matr[s][t] -= 1.0 / n;
// to j
t = idx(ist, jst ^ 1, 0, cnt);
matr[s][t] -= 1.0 / n;
if (isi == 1) v[s] += 1.0 / n;
// incl
if (cnt < n - 2) {
t = idx(ist, jst, 0, cnt + 1);
matr[s][t] -= (double)(n - 2 - cnt) / n;
}
// decl
if (cnt > 0) {
t = idx(ist, jst, 0, cnt - 1);
matr[s][t] -= (double)cnt / n;
}
}
solve();
vector<double> ans;
for (int i = 0; i < bits.size(); ++i) {
int tot = 0;
for (int j = 0; j < n; ++j) if (bits[i][j] == '1') ++tot;
double sol = 0;
for (int j = 0; j < n; ++j) {
for (int k = 0; k < n; ++k) {
int bj = (bits[i][j] == '1' ? 1 : 0);
int bk = (bits[i][k] == '1' ? 1 : 0);
int rem = tot - bj - bk;
sol += abs(j - k) * v[idx(bj, bk, j == pos[i] ? 1 : 0, rem)];
}
}
ans.push_back(sol);
}
return ans;
}
};