TopCoder SRM 641 Hard

2017-02-15 | [Competitive programming]

概要

n (1 <= n <= 20) ビットのビット列に対し,「一様ランダムに 1 ビット選んで反転する」を繰り返す. ビット列にはカーソルがあり,ビット反転を行うためには反転したいビットまでカーソルを動かす必要がある. カーソルの移動には,位置の差の絶対値に等しいコストがかかる. ビット列のすべてのビットが同じになったら操作は終了である.

ビット列 (n は固定),最初のカーソルの位置が与えられたとき,操作が終了するまでにかかるコストの期待値を求めるクエリを 1000 個処理せよ.

解法

「位置 i にいる状態から位置 j に移動してビット j を反転する」が行われる回数の期待値を求められればよい(各 i, j に対し,この期待値を重み |i - j| を掛けて足し合わせる).

2^n 通りのビット,現在のカーソル位置 n 通り全部に対して遷移を考えて,連立方程式を解けばこれは求められるが,明らかに計算量が大きすぎる. ここで,この期待値を求めるためには,「位置 i のビットの値」「位置 j のビットの値」「カーソル位置が i か?」「位置 i, j 以外のビットで 1 が立っている個数」にだけ注目すれば十分である. この場合の状態数は 8(n+1) になり,遷移をすべて考慮して連立方程式を解いても (1 回だけなら) 十分間に合う.

さらに,この連立方程式を 1 回解けば,複数の i, j や複数の初期ビット列に対しても結果を使いまわすことができる. 最初の連立方程式の計算は O(n^3) (ただし,定数 512 がある) であり,各クエリの処理は O(n^2) なので,十分間に合う.

コード

int sz;
double matr[200][200];
double v[200];

void solve()
{
	for (int i = 0; i < sz; ++i) {
		int piv = i;
		for (int j = i + 1; j < sz; ++j) {
			if (fabs(matr[piv][i]) < fabs(matr[j][i])) piv = j;
		}
		if (piv != i) {
			for (int j = 0; j < sz; ++j) swap(matr[i][j], matr[piv][j]);
			swap(v[i], v[piv]);
		}
		for (int j = i + 1; j < sz; ++j) matr[i][j] /= matr[i][i];
		v[i] /= matr[i][i];
		matr[i][i] = 1.0;
		for (int j = 0; j < sz; ++j) if (j != i) {
			for (int k = i + 1; k < sz; ++k) {
				matr[j][k] -= matr[i][k] * matr[j][i];
			}
			v[j] -= v[i] * matr[j][i];
			matr[j][i] = 0;
		}
	}
}

class BitToggler {
public:
    vector <double> expectation(int n, vector <string> bits, vector <int> pos) {
		auto idx = [&](int ist, int jst, int isi, int cnt) {
			return cnt * 8 + ist * 4 + jst * 2 + isi;
		};

		sz = 8 * (n + 1);
		for (int s = 0; s < sz; ++s) {
			for (int j = 0; j < sz; ++j) matr[s][j] = 0;
			v[s] = 0.0;

			int ist = (s & 4) >> 2;
			int jst = (s & 2) >> 1;
			int isi = (s & 1);
			int cnt = s >> 3;

			matr[s][s] = 1.0;
			if (cnt == 0 && ist == 0 && jst == 0) continue;
			if (cnt == n - 2 && ist == 1 && jst == 1) continue;

			// to i

			int t = idx(ist ^ 1, jst, 1, cnt);
			matr[s][t] -= 1.0 / n;

			// to j

			t = idx(ist, jst ^ 1, 0, cnt);
			matr[s][t] -= 1.0 / n;
			if (isi == 1) v[s] += 1.0 / n;

			// incl

			if (cnt < n - 2) {
				t = idx(ist, jst, 0, cnt + 1);
				matr[s][t] -= (double)(n - 2 - cnt) / n;
			}

			// decl

			if (cnt > 0) {
				t = idx(ist, jst, 0, cnt - 1);
				matr[s][t] -= (double)cnt / n;
			}
		}
		solve();

		vector<double> ans;
		for (int i = 0; i < bits.size(); ++i) {
			int tot = 0;
			for (int j = 0; j < n; ++j) if (bits[i][j] == '1') ++tot;

			double sol = 0;
			for (int j = 0; j < n; ++j) {
				for (int k = 0; k < n; ++k) {
					int bj = (bits[i][j] == '1' ? 1 : 0);
					int bk = (bits[i][k] == '1' ? 1 : 0);

					int rem = tot - bj - bk;
					sol += abs(j - k) * v[idx(bj, bk, j == pos[i] ? 1 : 0, rem)];
				}
			}
			ans.push_back(sol);
		}
		return ans;
    }
};