EvbCFfp1XB

Problem and my solution.

August 2017


Approach

問題を読んでSAで良さそうだなと思った。

Visualizer の int[][] component が何回も初期化されていたのを、一回だけ初期化して使いまわして高速化した。
中央の Component 1個だけ score を計算して高速化した。

近傍は、Visualizer で言うと int[] perm の要素をランダムに2つ選んで交換するだけ。

seed3 のスコアは約1.6e6

ここまでが submission1 SA

制限時間を10倍にすると、seed3 のスコアは約2e6 になった。
高速化に効果があることが分かったが、
Visualizer で見ると、思ったより小さいエリアしか使えていなかった。

スコアを計算する範囲を中央から広げながら、greedy に解を作ることにする。（以下の [a,b]*[a,b] は「行 a..b × 列 a..b」の正方形領域を表す。）
0回目のループ  perm[S/2] と perm[j] を交換して、[S/2,S/2]*[S/2,S/2] のスコアが最もよい j を選ぶ (0<=j<S)。
1回目のループ  perm[S/2-1] と perm[j] を交換して、[S/2-1,S/2]*[S/2-1,S/2] のスコアが最もよい j を選ぶ (0<=j<S)。
2回目のループ  perm[S/2+1] と perm[j] を交換して、[S/2-1,S/2+1]*[S/2-1,S/2+1] のスコアが最もよい j を選ぶ (0<=j<S)。
.
.
.
S-1回目のループ  perm[0またはS-1] と perm[j] を交換して、[0,S-1]*[0,S-1] のスコアが最もよい j を選ぶ (0<=j<S)。

seed3 のスコアは約5e6

「スコアが最もよい j」が既に決めた位置から選ばれることはほとんどなかったので、既に決めた位置は候補から外した。
そうすると、スコアを差分だけ計算できるようになり、高速化できた。

ここまでが submission2 greedy

greedy より beam search の方がいいはず。

制限時間を無視して、ビーム幅 10 で seed3 のスコアは7e6以上、
ビーム幅 100 で 8e6以上出ることもあった。

制限時間内に収まるよう調整すると、平均ビーム幅は約8で、seed3 のスコアは約6e6。

ここまでが submission3 beam search

バグを修正したりしてコーディングフェーズ終わり、submission4

結局、seed1 は beam search より SA の方が良かったので、もう一週間あればSAの高速化を頑張ったかも。

以上。


追記: beam search の順序を中央からではなく左上から右下にしたら、約7%良くなった。








source code (submission4)。左上から右下バージョンはその次に載せています。

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/**
 * Searches for a permutation {@code solution} of {@code 0..S-1} that is applied to both the rows
 * and the columns of the input: cell (r, c) of the arranged matrix shows
 * {@code original[solution[r]][solution[c]]}. A component is a 4-connected set of non-zero cells
 * and is scored as {@code sum * sqrt(size)}.
 *
 * <p>This variant grows the evaluated square outward from index S/2. A beam search fixes one
 * permutation position per depth, then simulated annealing over random transpositions uses any
 * time left before the 9.5 s deadline measured on {@link #watch}.
 *
 * <p>NOTE(review): the scorers index S/2 - 1 .. S/2 + 1 and the BFS queue holds 500 * 500 packed
 * cells, so this presumably expects 3 <= S <= 500 — confirm against the problem constraints.
 */
public class ConnectedComponent {
    /** Row offsets of the four orthogonal neighbours: left, down, right, up. */
    private static final int[] dr = { 0, 1, 0, -1, };
    /** Column offsets paired with {@link #dr}. */
    private static final int[] dc = { -1, 0, 1, 0, };

    /** Clock started when this class is initialised; every deadline below is read from it. */
    static final Watch watch = new Watch();
    /** Random source used by the annealer, seeded from System.nanoTime(). */
    static final XorShift rng = new XorShift(System.nanoTime());

    /** Annealing schedule parameters and counters. */
    private SAState sa = new SAState();

    /** Side length of the square input matrix. */
    private int S;
    /** Input matrix: {@code original[r][c] == matrix[r * S + c]}. */
    private int[][] original;

    /** Score of {@link #solution} as computed by {@link #calculateScoreFromCenter()}. */
    private double score;
    /** Best score seen since the beam-search result was scored. */
    private double bestScore;
    /** Current permutation, applied to both rows and columns. */
    private int[] solution;
    /** Copy of the permutation that achieved {@link #bestScore}. */
    private int[] bestSolution;

    /**
     * Computes the permutation for {@code matrix}.
     *
     * @param matrix S*S values in row-major order
     * @return the permutation to apply to rows and columns (this object's internal array)
     */
    public int[] permute(int[] matrix) {
        init(matrix);

        solve();

        Utils.debug("time", watch.getSecond(), "score", score, "countSets", countSets);

        return solution;
    }

    /** Copies {@code matrix} into {@link #original}, allocates search buffers and logs the non-zero density. */
    private void init(int[] matrix) {
        S = (int) Math.sqrt(matrix.length);

        original = new int[S][S];
        for (int r = 0; r < S; r++) {
            for (int c = 0; c < S; c++) {
                original[r][c] = matrix[r * S + c];
            }
        }

        int numElements = 0;
        for (int r = 0; r < S; r++) {
            for (int c = 0; c < S; c++) {
                if (original[r][c] != 0) {
                    numElements++;
                }
            }
        }

        component = new int[S][S];

        moveHistory = new int[S];

        hashHistory = new long[S];

        Utils.debug("S", S, "numElements", numElements, String.format("%.1f%%", 100.0 * numElements / (S * S)));
    }

    /**
     * Builds a permutation with the beam search, scores it, then refines it with annealing.
     * Starts from the identity permutation, which is kept if the beam search returns nothing.
     */
    private void solve() {
        solution = new int[S];
        for (int i = 0; i < solution.length; i++) {
            solution[i] = i;
        }

        sa.startTemperature = 1.0 / S;

        {
            // Beam width falls linearly from 10000 at S = 50 to 50 at S = 500.
            int maxBeamWidth = (int) (50.0 + (10000.0 - 50.0) * (500.0 - S) / (500.0 - 50.0));
            ArrayList<State> moves = beamsearch(S, maxBeamWidth);
            assert turn == 0;

            if (moves == null) {
                // No state was generated: keep the identity permutation instead of iterating over null.
                Utils.debug("moves == null");
            } else {
                // Replay the best move sequence onto solution.
                for (State state : moves) {
                    next(state.j);
                }
            }

            sa.startTemperature = 1e-3;
        }

        score = calculateScoreFromCenter();

        bestSolution = Arrays.copyOf(solution, solution.length);
        bestScore = score;

        Utils.debug();
        Utils.debug("score", score);
        Utils.debug();

        SA();
    }

    /** Position fixed at each beam-search depth: {@code is[depth]}. */
    private int[] is;
    /** {@code minRanges[d] == min(is[0..d])}. */
    private int[] minRanges;
    /** {@code maxRanges[d] == max(is[0..d])}. */
    private int[] maxRanges;
    /** Number of set() calls made by the beam search (debug statistic). */
    private int countSets = 0;

    /**
     * Beam search that fixes the permutation one position at a time.
     *
     * <p>Depth d decides position {@code is[d]}; {@code is} visits S/2, S/2-1, S/2+1, S/2-2, ...
     * so the fixed positions form the range [minRanges[d], maxRanges[d]]. A child swaps position
     * {@code is[d]} with itself or with a position j outside that range, and is scored by the
     * component of the centre cell restricted to the square [minRanges[d], maxRanges[d]]^2.
     * Each depth gets an equal share of the time between the call and 9.5 s on {@link #watch}.
     *
     * @param maxDepth number of depths to search (called with S)
     * @param maxBeamWidth maximum number of states expanded per depth
     * @return the moves of the best state seen, root first, or null if no state was generated;
     *         all applied moves are undone ({@link #turn} is 0) before returning
     */
    private ArrayList<State> beamsearch(int maxDepth, int maxBeamWidth) {
        {
            is = new int[S];
            minRanges = new int[S];
            maxRanges = new int[S];

            int minRange = S - 1;
            int maxRange = 0;
            // delta runs 0, -1, +2, -3, +4, ...: alternate sides of S/2, one step further each time.
            for (int isi = 0, i = S / 2, delta = 0; i >= 0 && i < S; delta = (1 + Math.abs(delta)) * (Integer.signum(delta) >= 0 ? -1 : 1), i += delta) {
                is[isi] = i;
                minRange = Math.min(minRange, i);
                maxRange = Math.max(maxRange, i);
                minRanges[isi] = minRange;
                maxRanges[isi] = maxRange;
                isi++;
            }
        }

        // Poll the clock only every (mask + 1) beams; smaller S means cheaper beams, so poll less often.
        int mask = (1 << (int) (Math.sqrt(500 / S))) - 1;

        ArrayList<State> currents = new ArrayList<>();
        ArrayList<State> nexts = new ArrayList<>();
        State best = null;
        // The root (no moves) is represented by null.
        currents.add(null);

        double startTime = watch.getSecond();
        double endTime = 9.5;

        // Cumulative per-depth deadlines, split evenly between startTime and endTime. They include
        // startTime because they are compared with watch.getSecond(), which counts from class
        // initialisation rather than from the start of this search.
        double[] timelimits = new double[maxDepth];
        for (int i = 0; i < maxDepth; i++) {
            timelimits[i] = startTime + (endTime - startTime) * (i + 1) / maxDepth;
        }

        for (int depth = 0; depth < maxDepth; depth++) {
            if (currents.size() == 0) {
                break;
            }

            int beamWidth = Math.min(maxBeamWidth, currents.size());

            // Move the beamWidth best states to the front, then order just those (best first).
            CollectionsUtils.select(currents, 0, currents.size(), beamWidth);
            CollectionsUtils.sort(currents, 0, beamWidth - 1);

            for (int beam = 0; beam < beamWidth; beam++) {
                State currentState = currents.get(beam);

                if (currentState != null) {
                    if (best == null || currentState.compareTo(best) < 0) {
                        best = currentState;
                    }
                }

                if (depth == maxDepth - 1) {
                    break;
                }

                // Bring solution to currentState, reusing the longest already-applied prefix.
                set(reverse(toList(currentState)), 0);
                countSets++;

                // Sum/size of the centre component inside the square fixed so far (none at depth 0).
                int currentSum = 0;
                int currentSize = 0;
                if (depth > 0) {
                    calculateScoreFromCenterUsingField2(minRanges[depth - 1], maxRanges[depth - 1]);
                    currentSum = sumBase;
                    currentSize = sizeBase;
                }

                // Label of the component just computed; each child extends it incrementally.
                int bestNComp = nComp;

                int i = is[depth];
                for (int j = 0; j < S; j++) {
                    // Positions fixed at earlier depths are off limits; j == i keeps the current entry.
                    if (j >= minRanges[depth] && j <= maxRanges[depth] && j != i) {
                        continue;
                    }

                    next(j);

                    calculateScoreForGreedyUpdateField(minRanges[depth], maxRanges[depth], currentSum, currentSize, depth == 0 ? 0 : (is[depth] - is[depth - 1]), bestNComp);
                    double childScore = (sum3 * Math.sqrt(size3));

                    State next = new State();
                    next.parent = currentState;
                    next.j = j;
                    next.score = childScore;
                    next.hash = hash;
                    nexts.add(next);

                    previous();
                }
                if ((beam & mask) == 0) {
                    if (watch.getSecond() >= timelimits[depth]) {
                        break;
                    }
                }
            }
            {
                ArrayList<State> swap = currents;
                currents = nexts;
                nexts = swap;
            }
            nexts.clear();
        }

        // Undo every applied move so solution is back to its state on entry.
        while (turn > 0) {
            previous();
        }

        if (best == null) {
            return null;
        }

        return reverse(toList(best));
    }

    /** Reverses {@code list} in place and returns it. */
    private ArrayList<State> reverse(ArrayList<State> list) {
        for (int l = 0, r = list.size() - 1; l < r; l++, r--) {
            list.set(r, list.set(l, list.get(r)));
        }
        return list;
    }

    /** Collects {@code state2} and its ancestors, deepest first; empty for the root (null). */
    private ArrayList<State> toList(State state2) {
        ArrayList<State> res = new ArrayList<>();
        for (State current = state2; current != null; current = current.parent) {
            res.add(current);
        }
        return res;
    }

    /** {@code moveHistory[t]}: the j passed to next() at turn t; meaningful only for t < turn. */
    private int[] moveHistory;
    /** Number of moves currently applied to {@link #solution}. */
    private int turn = 0;
    /** {@code hashHistory[t]}: value of {@link #hash} before the move at turn t. */
    private long[] hashHistory;
    /** Rolling hash of the applied moves ({@code hash * 503 + j} per move). */
    private long hash;

    /** Applies move j at the current turn: swaps solution[is[turn]] with solution[j]. */
    private void next(int j) {
        moveHistory[turn] = j;
        hashHistory[turn] = hash;

        swap(is[turn], j);

        hash = hash * 503 + j;

        turn++;
    }

    /** Undoes the most recent {@link #next(int)}. */
    private void previous() {
        turn--;
        swap(is[turn], moveHistory[turn]);
        hash = hashHistory[turn];
    }

    /**
     * Makes the applied moves equal to {@code list} placed at turn {@code startIndex}: keeps the
     * longest prefix already applied, undoes everything after it, then applies the rest of
     * {@code list}.
     */
    private void set(ArrayList<State> list, int startIndex) {
        int startIndexMinus1 = startIndex - 1;
        // Only entries below `turn` are moves that are actually applied. Entries at or above it are
        // leftovers from undone moves, and matching them would skip moves that still need applying.
        for (int i = 0; i < list.size() && startIndex + i < turn; i++) {
            int j = moveHistory[startIndex + i];
            State state = list.get(i);
            if (state.j == j) {
                startIndexMinus1 = startIndex + i;
                continue;
            }
            break;
        }

        for (; turn > startIndexMinus1 + 1;) {
            previous();
        }
        for (int i = (startIndexMinus1 + 1) - (startIndex); i < list.size(); i++) {
            State state2 = list.get(i);
            next(state2.j);
        }
    }

    /**
     * Simulated annealing over random transpositions until sa.endTime (9.5 s on {@link #watch}).
     * Every 32 iterations it refreshes the clock and temperature and logs once per elapsed
     * second; on timeout it restores the best permutation seen.
     */
    private void SA() {
        double second = 1;
        int mask = (1 << 5) - 1;

        sa.startTime = watch.getSecond();
        sa.endTime = 9.5;

        for (sa.loop = 0;; sa.loop++) {

            if ((sa.loop & mask) == 0) {
                sa.updateTime();

                if (sa.isTLE()) {
                    saveBest();
                    loadBest();
                    Utils.debug(sa.loop, String.format("%.2f%%", 100.0 * sa.countChange / sa.loop), String.format("%.2f%%", 100.0 * sa.countAccept / sa.countChange), String.format("%.4f", score), String.format("%.4f", bestScore), String.format("%.6f", sa.temperature));
                    break;
                }

                sa.updateTemperature();
                if (sa.time > second) {
                    second++;
                    Utils.debug(sa.loop, String.format("%.2f%%", 100.0 * sa.countChange / sa.loop), String.format("%.2f%%", 100.0 * sa.countAccept / sa.countChange), String.format("%.4f", score), String.format("%.4f", bestScore), String.format("%.6f", sa.temperature));
                }
            }

            swap();

        }
    }

    /** One annealing step: swap two distinct random positions, keep it if accepted, otherwise undo it. */
    private void swap() {
        int i = (int) (S * rng.nextDouble());
        int j = (int) ((S - 1) * rng.nextDouble());
        if (j >= i) {
            j++;
        }

        assert j != i;

        swap(i, j);

        double newScore = calculateScoreFromCenter();

        sa.countChange++;
        if (newScore >= score || sa.accept(newScore, score)) {
            sa.countAccept++;

            score = newScore;

            saveBest();

        } else {
            swap(i, j);
        }
    }

    /** Exchanges solution[i] and solution[j]. */
    private void swap(int i, int j) {
        int swap = solution[i];
        solution[i] = solution[j];
        solution[j] = swap;
    }

    /**
     * {@code component[r][c]}: label of the flood fill that last visited (r, c). Labels come from
     * the ever-increasing {@link #nComp}, so the array is never cleared between fills.
     */
    private int[][] component;
    /** BFS work list; cells are packed as (row << 10) | col. Sized for 500 * 500 cells. */
    private static final int[] queue = new int[500 * 500];
    /** Last label handed out to a flood fill. */
    private int nComp = 0;

    /**
     * Scores the current permutation. For each non-zero cell of the central 3x3 window (rows and
     * columns S/2-1 .. S/2+1) not yet reached in this call, flood-fills its component over the
     * whole grid and returns the largest {@code sum * sqrt(size)}; -1e9 if the window is all zero.
     */
    private double calculateScoreFromCenter() {
        int currentNumComponents = nComp;
        double maxSum = -1e9;
        for (int i = S / 2 - 1; i <= S / 2 + 1; i++) {
            for (int j = S / 2 - 1; j <= S / 2 + 1; j++) {
                if (original[solution[i]][solution[j]] == 0)
                    continue;
                // Already part of a component found earlier in this call.
                if (component[i][j] > currentNumComponents)
                    continue;

                int size = 0;
                double sum = 0;
                nComp++;

                component[i][j] = nComp;
                queue[size] = (i << 10) | (j);
                size++;
                sum += original[solution[i]][solution[j]];

                for (int ind = 0; ind < size; ind++) {
                    for (int d = 0; d < 4; d++) {
                        int newr = ((queue[ind] >>> 10) & 1023) + dr[d];
                        int newc = (queue[ind] & 1023) + dc[d];
                        if (newr < 0 || newc < 0 || newr >= S || newc >= S)
                            continue;
                        if (component[newr][newc] > currentNumComponents || original[solution[newr]][solution[newc]] == 0) {
                            continue;
                        }
                        component[newr][newc] = nComp;
                        queue[size] = (newr << 10) | (newc);
                        size++;
                        sum += original[solution[newr]][solution[newc]];
                    }
                }
                sum *= Math.sqrt(size);
                if (sum > maxSum) {
                    maxSum = sum;
                }
            }
        }
        return maxSum;
    }

    /** Value sum of the component found by calculateScoreFromCenterUsingField2. */
    private int sumBase;
    /** Cell count of the component found by calculateScoreFromCenterUsingField2. */
    private int sizeBase;

    /**
     * Flood-fills, inside the square [minRange, maxRange]^2, the component containing the centre
     * cell (S/2, S/2) under a fresh label {@link #nComp}, storing its totals in
     * {@link #sumBase} / {@link #sizeBase}. If the centre cell is zero, the label is still consumed
     * and both totals are 0.
     */
    private void calculateScoreFromCenterUsingField2(int minRange, int maxRange) {

        int currentNumComponents = nComp;
        int i = S / 2;
        int j = S / 2;

        sizeBase = 0;
        sumBase = 0;
        nComp++;

        if (original[solution[i]][solution[j]] == 0) {
            return;
        }

        component[i][j] = nComp;
        queue[sizeBase] = (i << 10) | (j);
        sizeBase++;
        sumBase += original[solution[i]][solution[j]];

        for (int ind = 0; ind < sizeBase; ind++) {
            int r = (queue[ind] >>> 10) & 1023;
            int c = queue[ind] & 1023;
            for (int d = 0; d < 4; d++) {
                int newr = r + dr[d];
                if (newr < minRange || newr > maxRange) {
                    continue;
                }
                int newc = c + dc[d];
                if (newc < minRange || newc > maxRange) {
                    continue;
                }
                if (component[newr][newc] > currentNumComponents || original[solution[newr]][solution[newc]] == 0) {
                    continue;
                }

                component[newr][newc] = nComp;
                queue[sizeBase] = (newr << 10) | (newc);
                sizeBase++;
                sumBase += original[solution[newr]][solution[newc]];
            }
        }
    }

    /** Value sum computed by calculateScoreForGreedyUpdateField. */
    private int sum3;
    /** Cell count computed by calculateScoreForGreedyUpdateField. */
    private int size3;

    /**
     * Incrementally scores the centre component after the evaluated square grew from the previous
     * range to [minRange, maxRange]^2; results go to {@link #sum3} / {@link #size3}.
     *
     * <p>Non-zero cells on the newly added column and row that touch a {@code parentNComp} cell on
     * their inner side are seeded with a fresh label. They are then flood-filled inside the new
     * square without entering {@code parentNComp} cells, and the parent totals are added at the end.
     *
     * @param minRange smallest fixed position after this depth
     * @param maxRange largest fixed position after this depth
     * @param currentSum value sum of the parent component (0 at depth 0)
     * @param currentSize cell count of the parent component (0 at depth 0)
     * @param sign where the square grew: {@code < 0} new row/column at minRange, {@code > 0} at
     *             maxRange, {@code 0} first depth (centre cell only)
     * @param parentNComp label of the parent component from calculateScoreFromCenterUsingField2
     */
    private void calculateScoreForGreedyUpdateField(int minRange, int maxRange, int currentSum, int currentSize, int sign, int parentNComp) {
        size3 = 0;
        sum3 = 0;
        nComp++;

        if (sign < 0) {
            {
                // New column c = minRange: seed cells whose right neighbour is in the parent.
                int c = minRange;
                for (int r = minRange + 1; r <= maxRange; r++) {
                    if (original[solution[r]][solution[c]] == 0) {
                        continue;
                    }
                    if ((c + 1 < S && (component[r][c + 1] == parentNComp))) {
                        component[r][c] = nComp;

                        queue[size3] = (r << 10) | (c);
                        sum3 += original[solution[r]][solution[c]];
                        size3++;
                    }
                }
            }
            {
                // New row r = minRange: seed cells whose lower neighbour is in the parent.
                int r = minRange;
                for (int c = minRange; c <= maxRange; c++) {
                    if (original[solution[r]][solution[c]] == 0) {
                        continue;
                    }
                    if ((r + 1 < S && (component[r + 1][c] == parentNComp))) {
                        component[r][c] = nComp;

                        queue[size3] = (r << 10) | (c);
                        sum3 += original[solution[r]][solution[c]];
                        size3++;
                    }
                }
            }
            for (int ind = 0; ind < size3; ind++) {
                int r = (queue[ind] >>> 10) & 1023;
                int c = queue[ind] & 1023;
                for (int d = 0; d < 4; d++) {
                    int newr = r + dr[d];
                    int newc = c + dc[d];

                    if (newr < minRange || newr > maxRange || newc < minRange || newc > maxRange) {
                        continue;
                    }
                    if (original[solution[newr]][solution[newc]] == 0 || component[newr][newc] == nComp || component[newr][newc] == parentNComp) {
                        continue;
                    }

                    component[newr][newc] = nComp;
                    queue[size3] = (newr << 10) | (newc);
                    sum3 += original[solution[newr]][solution[newc]];
                    size3++;
                }
            }
        } else if (sign > 0) {
            {
                // New column c = maxRange: seed cells whose left neighbour is in the parent.
                int c = maxRange;
                for (int r = minRange; r <= maxRange - 1; r++) {
                    if (original[solution[r]][solution[c]] == 0) {
                        continue;
                    }
                    if ((c - 1 >= 0 && (component[r][c - 1] == parentNComp))) {
                        component[r][c] = nComp;

                        queue[size3] = (r << 10) | (c);
                        sum3 += original[solution[r]][solution[c]];
                        size3++;
                    }
                }
            }
            {
                // New row r = maxRange: seed cells whose upper neighbour is in the parent.
                int r = maxRange;
                for (int c = minRange; c <= maxRange; c++) {
                    if (original[solution[r]][solution[c]] == 0) {
                        continue;
                    }
                    if ((r - 1 >= 0 && (component[r - 1][c] == parentNComp))) {
                        component[r][c] = nComp;

                        queue[size3] = (r << 10) | (c);
                        sum3 += original[solution[r]][solution[c]];
                        size3++;
                    }
                }
            }
            for (int ind = 0; ind < size3; ind++) {
                int r = (queue[ind] >>> 10) & 1023;
                int c = queue[ind] & 1023;
                for (int d = 0; d < 4; d++) {
                    int newr = r + dr[d];
                    int newc = c + dc[d];

                    if (newr < minRange || newr > maxRange || newc < minRange || newc > maxRange) {
                        continue;
                    }
                    if (original[solution[newr]][solution[newc]] == 0 || component[newr][newc] == nComp || component[newr][newc] == parentNComp) {
                        continue;
                    }

                    component[newr][newc] = nComp;
                    queue[size3] = (newr << 10) | (newc);
                    sum3 += original[solution[newr]][solution[newc]];
                    size3++;
                }
            }
        } else {
            // First depth: the square is just the centre cell.
            if (original[solution[S / 2]][solution[S / 2]] != 0) {
                component[S / 2][S / 2] = nComp;
                sum3 += original[solution[S / 2]][solution[S / 2]];
                size3++;
            }
        }
        sum3 += currentSum;
        size3 += currentSize;
    }

    /** Restores the best permutation and its score. */
    private void loadBest() {
        score = bestScore;
        for (int i = 0; i < solution.length; i++) {
            solution[i] = bestSolution[i];
        }
    }

    /** Records the current permutation if its score beats the best so far. */
    private void saveBest() {
        if (score > bestScore) {
            bestScore = score;
            for (int i = 0; i < solution.length; i++) {
                bestSolution[i] = solution[i];
            }
        }
    }

    /**
     * Reads M followed by M matrix values (one per line) from standard input, then prints the
     * permutation length followed by its entries, one per line.
     */
    public static void main(String[] args) {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {

            int M = Integer.parseInt(br.readLine());
            int[] matrix = new int[M];
            for (int i = 0; i < M; ++i) {
                matrix[i] = Integer.parseInt(br.readLine());
            }

            ConnectedComponent cc = new ConnectedComponent();
            int[] ret = cc.permute(matrix);

            System.out.println(ret.length);
            for (int i = 0; i < ret.length; ++i) {
                System.out.println(ret[i]);
            }
            System.out.flush();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

/**
 * Mutable simulated-annealing schedule: current progress (seconds or loop count), linearly
 * interpolated temperature and range, and acceptance counters.
 */
class SAState {

    /** When true, progress is read from ConnectedComponent.watch; otherwise the loop counter is used. */
    public static final boolean useTime = true;

    public double startTime = 0;
    public double endTime = 9.5;
    public double time = startTime;

    public double startTemperature = 1e-1;
    public double endTemperature = 0;
    public double temperature = startTemperature;

    public double startRange = 128;
    public double endRange = 1;
    public double range = startRange;

    public int loop;
    public int countChange;
    public int countAccept;

    /** Sets temperature by linear interpolation from startTemperature (at startTime) to endTemperature (at endTime). */
    public void updateTemperature() {
        temperature = interpolate(startTemperature, endTemperature);
    }

    /** Sets range by linear interpolation from startRange (at startTime) to endRange (at endTime). */
    public void updateRange() {
        range = interpolate(startRange, endRange);
    }

    /** Refreshes {@code time} from the clock, or from {@code loop} when useTime is false. */
    public void updateTime() {
        if (useTime) {
            time = ConnectedComponent.watch.getSecond();
        } else {
            time = loop;
        }
    }

    /** @return true once {@code time} has reached {@code endTime} */
    public boolean isTLE() {
        return endTime <= time;
    }

    /**
     * Metropolis test for a worse candidate: accepts with probability
     * exp((newScore - currentScore) / (currentScore * temperature)), i.e. the loss is taken
     * relative to the current score.
     */
    public boolean accept(double newScore, double currentScore) {
        assert newScore - currentScore < 0;
        assert temperature >= 0;
        double draw = ConnectedComponent.rng.nextDouble();
        double relativeLoss = (newScore - currentScore) / (currentScore * temperature);
        return draw < StrictMath.exp(relativeLoss);
    }

    /** atEnd + (atStart - atEnd) * (fraction of [startTime, endTime] still remaining). */
    private double interpolate(double atStart, double atEnd) {
        double remainingFraction = (endTime - time) / (endTime - startTime);
        return atEnd + (atStart - atEnd) * remainingFraction;
    }
}

final class Utils {
private Utils() {
}

public static final void debug(Object... o) {
System.err.println(toString(o));
}

public static final String toString(Object... o) {
return Arrays.deepToString(o);
}

}

/** Stopwatch based on System.nanoTime(). */
class Watch {
    private long start;

    /** Creates a watch that starts counting immediately. */
    public Watch() {
        init();
    }

    /** @return seconds elapsed since construction or the last {@link #init()} */
    public double getSecond() {
        long elapsed = System.nanoTime() - start;
        return elapsed * 1e-9;
    }

    /** Restarts the watch from now. */
    public void init() {
        long now = System.nanoTime();
        init(now);
    }

    private void init(long start) {
        this.start = start;
    }

    /** @return the elapsed time formatted by {@link #toString(double)} */
    public String getSecondString() {
        double elapsed = getSecond();
        return toString(elapsed);
    }

    /**
     * Formats a duration: below one minute as {@code "%5.2fs"}, below one hour as minutes and
     * seconds, otherwise as hours, minutes and seconds (each a width-2 integer).
     */
    public static final String toString(double second) {
        if (second < 60) {
            return String.format("%5.2fs", second);
        }
        if (second < 60 * 60) {
            int minutes = (int) (second / 60);
            int seconds = (int) (second % 60);
            return String.format("%2dm%2ds", minutes, seconds);
        }
        int hours = (int) (second / (60 * 60));
        int totalMinutes = (int) (second / 60);
        int seconds = (int) (second % 60);
        return String.format("%2dh%2dm%2ds", hours, totalMinutes % (60), seconds);
    }
}

/**
 * Marsaglia xorshift128 generator. The seed's low 32 bits replace the default x state word;
 * y, z and w keep their standard initial values.
 */
class XorShift {
    private int x = 123456789;
    private int y = 362436069;
    private int z = 521288629;
    private int w = 88675123;

    public XorShift(long l) {
        x = (int) l;
    }

    /** Advances the four-word state and returns the new w word. */
    public int nextInt() {
        final int t = x ^ (x << 11);
        x = y;
        y = z;
        z = w;
        w ^= (w >>> 19) ^ t ^ (t >>> 8);
        return w;
    }

    /** XORs the first nextInt() (shifted into the high word) with the second one, sign-extended. */
    public long nextLong() {
        long high = (long) nextInt() << 32;
        long low = nextInt();
        return high ^ low;
    }

    /** @return a double in [0, 1): the top 31 bits of nextInt() scaled by 2^-31 */
    public double nextDouble() {
        int top31 = nextInt() >>> 1;
        return top31 * 4.6566128730773926E-10;
    }

    /** @return {@code (int) (nextDouble() * n)}, in [0, n) for positive n */
    public int nextInt(int n) {
        return (int) (nextDouble() * n);
    }
}

/** Beam-search node: move {@code j} applied on top of {@code parent}, with its score and path hash. */
class State implements Comparable<State> {
    State parent;
    int j;
    double score;
    long hash;

    /** Orders by descending score; ties are broken by ascending hash. */
    @Override
    public int compareTo(State o) {
        return score > o.score ? -1
                : score < o.score ? 1
                : Long.compare(hash, o.hash);
    }

    @Override
    public String toString() {
        return Utils.toString("j", j, "score", score);
    }
}

/**
 * Index-based helpers for ArrayList: shuffling, quickselect, quicksort, reversal and swaps.
 * Each generic operation comes in two flavours: one taking a Comparator and one using the
 * elements' natural ordering.
 */
class CollectionsUtils {
private CollectionsUtils() {
}

/** Fisher-Yates shuffle of the whole list driven by java.util.Random. */
public static final <T> void shuffle(ArrayList<T> list, Random rng) {
for (int i = list.size() - 1; i >= 0; i--) {
int j = (int) ((i + 1) * rng.nextDouble());
CollectionsUtils.swap(list, i, j);
}
}

/** Fisher-Yates shuffle of the whole list driven by XorShift. */
public static final <T> void shuffle(ArrayList<T> list, XorShift rng) {
for (int i = list.size() - 1; i >= 0; i--) {
int j = (int) ((i + 1) * rng.nextDouble());
CollectionsUtils.swap(list, i, j);
}
}

/**
 * Quickselect on a[l..r] (both inclusive). For l <= k <= r, afterwards a[k] is the element that
 * would be at index k if the range were sorted, with nothing before it comparing greater and
 * nothing after it comparing smaller. If k lies outside [l, r] the loop still terminates and
 * only partially reorders the range.
 */
public static final <T> void select0(ArrayList<T> a, int l, int r, int k, Comparator<T> comparator) {
while (l < r) {
int i = CollectionsUtils.partition3(a, l, r, comparator);
if (i >= k)
r = i - 1;
if (i <= k)
l = i + 1;
}
}

/** Quickselect over [startInclusive, endExclusive); see select0. */
public static final <T> void select(ArrayList<T> a, int startInclusive, int endExclusive, int k, Comparator<T> comparator) {
select0(a, startInclusive, endExclusive - 1, k, comparator);
}

/** Natural-order variant of select0 (inclusive bounds). */
public static final <T extends Comparable<T>> void select0(ArrayList<T> a, int l, int r, int k) {
while (l < r) {
int i = CollectionsUtils.partition3(a, l, r);
if (i >= k)
r = i - 1;
if (i <= k)
l = i + 1;
}
}

/** Natural-order quickselect over [startInclusive, endExclusive). */
public static final <T extends Comparable<T>> void select(ArrayList<T> a, int startInclusive, int endExclusive, int k) {
select0(a, startInclusive, endExclusive - 1, k);
}

/** Exchanges a[i] and a[j]. */
public static final <T> void swap(ArrayList<T> a, int i, int j) {
T swap = a.set(i, a.get(j));
a.set(j, swap);
}

/** Orders a[i] <= a[j] <= a[k] with three compare-and-swap steps. */
public static final <T> void sort3(ArrayList<T> a, int i, int j, int k, Comparator<T> comparator) {
if (comparator.compare(a.get(i), a.get(j)) > 0) {
swap(a, i, j);
}
if (comparator.compare(a.get(i), a.get(k)) > 0) {
swap(a, i, k);
}
if (comparator.compare(a.get(j), a.get(k)) > 0) {
swap(a, j, k);
}
}

/** Natural-order variant of sort3. */
public static final <T extends Comparable<T>> void sort3(ArrayList<T> a, int i, int j, int k) {
if (a.get(i).compareTo(a.get(j)) > 0) {
swap(a, i, j);
}
if (a.get(i).compareTo(a.get(k)) > 0) {
swap(a, i, k);
}
if (a.get(j).compareTo(a.get(k)) > 0) {
swap(a, j, k);
}
}

/**
 * Median-of-three partition of a[l..r] (inclusive). It sorts a[l], a[center] and a[r], moves the
 * median to r - 1 as the pivot, then scans inward using a[l] and the pivot as sentinels. Returns
 * the pivot's final index p with a[l..p-1] <= a[p] <= a[p+1..r]. Ranges of at most three elements
 * are already sorted by sort3, so l is returned for them.
 * Requires r > l (with l == r it would touch index l - 1); the callers here guard with l < r.
 */
public static final <T> int partition3(ArrayList<T> a, int l, int r, Comparator<T> comparator) {
int center = (l + r) >>> 1;
sort3(a, l, center, r, comparator);
swap(a, center, r - 1);
if (r - l < 3) {
return l;
}
int r1 = r - 1;
T v = a.get(r1);
int i = l - 1;
int j = r1;
for (;;) {
for (; comparator.compare(a.get(++i), v) < 0;) {
}
for (; comparator.compare(a.get(--j), v) > 0;) {
}
if (i >= j) {
break;
}
swap(a, i, j);
}
swap(a, i, r1);
return i;
}

/** Natural-order variant of partition3; same contract and preconditions. */
public static final <T extends Comparable<T>> int partition3(ArrayList<T> a, int l, int r) {
int center = (l + r) >>> 1;
sort3(a, l, center, r);
swap(a, center, r - 1);
if (r - l < 3) {
return l;
}
int r1 = r - 1;
T v = a.get(r1);
int i = l - 1;
int j = r1;
for (;;) {
for (; a.get(++i).compareTo(v) < 0;) {
}
for (; a.get(--j).compareTo(v) > 0;) {
}
if (i >= j) {
break;
}
swap(a, i, j);
}
swap(a, i, r1);
return i;
}

/**
 * Single-pivot partition of a[l..r] (inclusive) using a[r] as the pivot; returns the pivot's
 * final index. Not called from the code shown in this file.
 */
public static final <T extends Comparable<T>> int partition(ArrayList<T> a, int l, int r) {
int i = l - 1, j = r;
T v = a.get(r);
for (;;) {
while (a.get(++i).compareTo(v) < 0)
;
while (v.compareTo(a.get(--j)) < 0)
if (j == l)
break;
if (i >= j)
break;
swap(a, i, j);
}
swap(a, i, r);
return i;
}

/** Recursive quicksort of a[lInclusive..rInclusive] built on partition3. */
public static final <T> void sort(ArrayList<T> a, int lInclusive, int rInclusive, Comparator<T> comparator) {
if (lInclusive >= rInclusive) {
return;
}
int k = partition3(a, lInclusive, rInclusive, comparator);
sort(a, lInclusive, k - 1, comparator);
sort(a, k + 1, rInclusive, comparator);
}

/** Natural-order quicksort of a[lInclusive..rInclusive]. */
public static final <T extends Comparable<T>> void sort(ArrayList<T> a, int lInclusive, int rInclusive) {
if (lInclusive >= rInclusive) {
return;
}
int k = partition3(a, lInclusive, rInclusive);
sort(a, lInclusive, k - 1);
sort(a, k + 1, rInclusive);
}

/** Reverses {@code list} in place and returns the same list. */
public static final <T> ArrayList<T> reverse(ArrayList<T> list) {
for (int l = 0, r = list.size() - 1; l < r; l++, r--) {
list.set(r, list.set(l, list.get(r)));
}
return list;
}

/** Returns a new list holding the elements of {@code list} in reverse order; the input is untouched. */
public static final <T> ArrayList<T> newReverse(ArrayList<T> list) {
ArrayList<T> res = new ArrayList<>();
for (int i = list.size() - 1; i >= 0; i--) {
res.add(list.get(i));
}
return res;
}

}


source code (左上から右下バージョン)


import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

public class ConnectedComponent {
private static final int[] dr = { 0, 1, 0, -1, };
private static final int[] dc = { -1, 0, 1, 0, };

static final Watch watch = new Watch();
static final XorShift rng = new XorShift(System.nanoTime());

private SAState sa = new SAState();

private int S;
private int[][] original;

private double score;
private double bestScore;
private int[] solution;
private int[] bestSolution;

public int[] permute(int[] matrix) {
init(matrix);

solve();

Utils.debug("time", watch.getSecond(), "score", score, "countSets", countSets);

return solution;
}

private void init(int[] matrix) {
S = (int) Math.sqrt(matrix.length);

original = new int[S][S];
for (int r = 0; r < S; r++) {
for (int c = 0; c < S; c++) {
original[r][c] = matrix[r * S + c];
}
}

int numElements = 0;
for (int r = 0; r < S; r++) {
for (int c = 0; c < S; c++) {
if (original[r][c] != 0) {
numElements++;
}
}
}

component = new int[S][S];

moveHistory = new int[S];

hashHistory = new long[S];

Utils.debug("S", S, "numElements", numElements, String.format("%.1f%%", 100.0 * numElements / (S * S)));
}

private void solve() {
solution = new int[S];
for (int i = 0; i < solution.length; i++) {
solution[i] = i;
}

sa.startTemperature = 1.0 / S;

{
int maxBeamWidth = (int) (50.0 + (10000.0 - 50.0) * (500.0 - S) / (500.0 - 50.0));
ArrayList<State> moves = beamsearch(S, maxBeamWidth);
assert turn == 0;

if (moves == null) {
Utils.debug("moves == null");
}
for (State state : moves) {
next(state.j);
}

sa.startTemperature = 1e-3;
}

score = calculateScoreFromCenter();

bestSolution = Arrays.copyOf(solution, solution.length);
bestScore = score;

Utils.debug();
Utils.debug("score", score);
Utils.debug();

SA();
}

private int[] is;
private int[] minRanges;
private int[] maxRanges;
private int countSets = 0;

private ArrayList<State> beamsearch(int maxDepth, int maxBeamWidth) {
{
is = new int[S];
minRanges = new int[S];
maxRanges = new int[S];

int minRange = S - 1;
int maxRange = 0;
// for (int isi = 0, i = S / 2, delta = 0; i >= 0 && i < S; delta = (1 + Math.abs(delta)) * (Integer.signum(delta) >= 0 ? -1 : 1), i += delta) {
for (int isi = 0, i = 0; i < S; i += 1) {
is[isi] = i;
minRange = Math.min(minRange, i);
maxRange = Math.max(maxRange, i);
minRanges[isi] = minRange;
maxRanges[isi] = maxRange;
isi++;
}
}

int mask = (1 << (int) (Math.sqrt(500 / S))) - 1;

ArrayList<State> currents = new ArrayList<>();
ArrayList<State> nexts = new ArrayList<>();
State best = null;
currents.add(null);

double startTime = watch.getSecond();
double endTime = 9.5;

double sum = (1.0 + 1.0) * (maxDepth) / 2.0;

double[] timelimits = new double[maxDepth];
for (int i = 0; i < maxDepth; i++) {
if (i == 0) {
timelimits[i] = (endTime - startTime) * (1.0 + (1.0 - 1.0) * (i - 0.0) / (maxDepth - 1 - 0.0)) / sum;
} else {
timelimits[i] = timelimits[i - 1] + (endTime - startTime) * (1.0 + (1.0 - 1.0) * (i - 0.0) / (maxDepth - 1 - 0.0)) / sum;
}
}

for (int depth = 0; depth < maxDepth; depth++) {
if (currents.size() == 0) {
break;
}

int beamWidth = Math.min(maxBeamWidth, currents.size());

CollectionsUtils.select(currents, 0, currents.size(), beamWidth);
CollectionsUtils.sort(currents, 0, beamWidth - 1);

for (int beam = 0; beam < beamWidth; beam++) {
State currentState = currents.get(beam);

if (currentState == null) {
} else {
if (best == null || currentState.compareTo(best) < 0) {
best = currentState;
}
}

if (depth == maxDepth - 1) {
break;
}

set(reverse(toList(currentState)), 0);
countSets++;

int currentSum = 0;
int currentSize = 0;
if (depth > 0) {
calculateScoreFromCenterUsingField2(minRanges[depth - 1], maxRanges[depth - 1]);
currentSum = sumBase;
currentSize = sizeBase;
} else {
}

int bestNComp = nComp;

int i = is[depth];
for (int j = 0; j < S; j++) {
if (j >= minRanges[depth] && j <= maxRanges[depth]) {
if (j == i) {
} else {
continue;
}
}

next(j);

calculateScoreForGreedyUpdateField(minRanges[depth], maxRanges[depth], currentSum, currentSize, depth == 0 ? 0 : (is[depth] - is[depth - 1]), bestNComp);
double score = (sum3 * Math.sqrt(size3));

State next = new State();
next.parent = currentState;
next.j = j;
next.score = score;
next.hash = hash;
nexts.add(next);

previous();
}
if ((beam & mask) == 0) {
if (watch.getSecond() >= timelimits[depth]) {
break;
}
}
}
{
ArrayList<State> swap = currents;
currents = nexts;
nexts = swap;
}
nexts.clear();
}

for (; turn > 0;) {
previous();
}

if (best == null) {
return null;
}

return reverse(toList(best));
}

private ArrayList<State> reverse(ArrayList<State> list) {
for (int l = 0, r = list.size() - 1; l < r; l++, r--) {
list.set(r, list.set(l, list.get(r)));
}
return list;
}

private ArrayList<State> toList(State state2) {
ArrayList<State> res = new ArrayList<>();
for (State current = state2; current != null; current = current.parent) {
res.add(current);
}
return res;
}

private int[] moveHistory;
private int turn = 0;
private long[] hashHistory;
private long hash;

private void next(int j) {
moveHistory[turn] = j;
hashHistory[turn] = hash;

swap(is[turn], j);

hash = hash * 503 + j;

turn++;
}

private void previous() {
turn--;
swap(is[turn], moveHistory[turn]);
hash = hashHistory[turn];
}

private void set(ArrayList<State> list, int startIndex) {
int startIndexMinus1 = startIndex - 1;
for (int i = 0; i < list.size() && startIndex + i < moveHistory.length; i++) {
int j = moveHistory[startIndex + i];
State state = list.get(i);
if (state.j == j) {
startIndexMinus1 = startIndex + i;
continue;
}
break;
}

for (; turn > startIndexMinus1 + 1;) {
previous();
}
for (int i = (startIndexMinus1 + 1) - (startIndex); i < list.size(); i++) {
State state2 = list.get(i);
next(state2.j);
}
}

private void SA() {
double second = 1;
int mask = (1 << 5) - 1;

sa.startTime = watch.getSecond();
sa.endTime = 9.5;

for (sa.loop = 0;; sa.loop++) {

if ((sa.loop & mask) == 0) {
sa.updateTime();

if (sa.isTLE()) {
saveBest();
loadBest();
Utils.debug(sa.loop, String.format("%.2f%%", 100.0 * sa.countChange / sa.loop), String.format("%.2f%%", 100.0 * sa.countAccept / sa.countChange), String.format("%.4f", score), String.format("%.4f", bestScore), String.format("%.6f", sa.temperature));
break;
}

sa.updateTemperature();
if (sa.time > second) {
second++;
Utils.debug(sa.loop, String.format("%.2f%%", 100.0 * sa.countChange / sa.loop), String.format("%.2f%%", 100.0 * sa.countAccept / sa.countChange), String.format("%.4f", score), String.format("%.4f", bestScore), String.format("%.6f", sa.temperature));
}
}

swap();

}
}

private void swap() {
int i = (int) (S * rng.nextDouble());
int j = (int) ((S - 1) * rng.nextDouble());
if (j >= i) {
j++;
}

assert j != i;

swap(i, j);

double newScore = calculateScoreFromCenter();

sa.countChange++;
if (newScore >= score || sa.accept(newScore, score)) {
sa.countAccept++;

score = newScore;

saveBest();

} else {
swap(i, j);
}
}

private void swap(int i, int j) {
int swap = solution[i];
solution[i] = solution[j];
solution[j] = swap;
}

private int[][] component;
private static final int[] queue = new int[500 * 500];
private int nComp = 0;

private double calculateScoreFromCenter() {
int currentNumComponents = nComp;
double maxSum = -1e9;
for (int i = 1 - 1; i <= 1 + 1; i++) {
for (int j = 1 - 1; j <= 1 + 1; j++) {
if (original[solution[i]][solution[j]] == 0)
continue;
if (component[i][j] > currentNumComponents)
continue;

int size = 0;
double sum = 0;
nComp++;

component[i][j] = nComp;
queue[size] = (i << 10) | (j);
size++;
sum += original[solution[i]][solution[j]];

for (int ind = 0; ind < size; ind++) {
for (int d = 0; d < 4; d++) {
int newr = ((queue[ind] >>> 10) & 1023) + dr[d];
int newc = (queue[ind] & 1023) + dc[d];
if (newr < 0 || newc < 0 || newr >= S || newc >= S)
continue;
if (component[newr][newc] > currentNumComponents || original[solution[newr]][solution[newc]] == 0) {
continue;
}
component[newr][newc] = nComp;
queue[size] = (newr << 10) | (newc);
size++;
sum += original[solution[newr]][solution[newc]];
}
}
sum *= Math.sqrt(size);
if (sum > maxSum) {
maxSum = sum;
}
}
}
return maxSum;
}

private int sumBase;
private int sizeBase;

private void calculateScoreFromCenterUsingField2(int minRange, int maxRange) {

int currentNumComponents = nComp;
int i = 0;
int j = 0;

sizeBase = 0;
sumBase = 0;
nComp++;

if (original[solution[i]][solution[j]] == 0) {
return;
}

component[i][j] = nComp;
queue[sizeBase] = (i << 10) | (j);
sizeBase++;
sumBase += original[solution[i]][solution[j]];

for (int ind = 0; ind < sizeBase; ind++) {
int r = (queue[ind] >>> 10) & 1023;
int c = queue[ind] & 1023;
for (int d = 0; d < 4; d++) {
int newr = r + dr[d];
if (newr < minRange || newr > maxRange) {
continue;
}
int newc = c + dc[d];
if (newc < minRange || newc > maxRange) {
continue;
}
if (component[newr][newc] > currentNumComponents || original[solution[newr]][solution[newc]] == 0) {
continue;
}

component[newr][newc] = nComp;
queue[sizeBase] = (newr << 10) | (newc);
sizeBase++;
sumBase += original[solution[newr]][solution[newc]];
}
}
}

// Outputs of calculateScoreForGreedyUpdateField: the value sum (sum3) and cell count (size3)
// of the component it labels, with the caller-supplied currentSum / currentSize added in.
private int sum3;
private int size3;

/**
 * Labels the non-zero cells that connect to the component labelled parentNComp.
 *
 * The labelled cells lie inside the square [minRange, maxRange] x [minRange, maxRange] and
 * connect through the square's new border row/column. They get a fresh id (nComp, incremented
 * on entry), and their values are accumulated into sum3 / size3.
 *
 * Cells already labelled parentNComp are neither relabelled nor recounted; their totals are
 * expected to arrive via currentSum / currentSize.
 *
 * A cell (r, c) has value original[solution[r]][solution[c]]; zero-valued cells are skipped.
 * Queue entries pack (r << 10) | c and are decoded with "& 1023", so this assumes r, c < 1024.
 *
 * NOTE(review): presumably this is the incremental score update for the greedy / beam
 * construction that grows the square by one row and one column per step; the caller is not
 * visible here — confirm.
 *
 * @param minRange    inclusive lower bound of the square (rows and columns)
 * @param maxRange    inclusive upper bound of the square (rows and columns)
 * @param currentSum  value sum already attributed to parentNComp; added to sum3 at the end
 * @param currentSize cell count already attributed to parentNComp; added to size3 at the end
 * @param sign        &lt; 0: the new border is row/column minRange; &gt; 0: the new border is
 *                    row/column maxRange; 0: only cell (0, 0) is considered, no flood fill
 * @param parentNComp label of the existing component that new cells must connect to
 */
private void calculateScoreForGreedyUpdateField(int minRange, int maxRange, int currentSum, int currentSize, int sign, int parentNComp) {
size3 = 0;
sum3 = 0;
nComp++;

if (sign < 0) {
// Seed column minRange, rows minRange+1..maxRange: non-zero cells whose neighbour
// (r, c + 1) carries parentNComp. The corner (minRange, minRange) is left to the row loop.
{
int c = minRange;
for (int r = minRange + 1; r <= maxRange; r++) {
if (original[solution[r]][solution[c]] == 0) {
continue;
}
if ((c + 1 < S && (component[r][c + 1] == parentNComp))) {
component[r][c] = nComp;

queue[size3] = (r << 10) | (c);
sum3 += original[solution[r]][solution[c]];
size3++;
}
}
}
// Seed row minRange, columns minRange..maxRange (corner included): non-zero cells whose
// neighbour (r + 1, c) carries parentNComp.
{
int r = minRange;
for (int c = minRange; c <= maxRange; c++) {
if (original[solution[r]][solution[c]] == 0) {
continue;
}
if ((r + 1 < S && (component[r + 1][c] == parentNComp))) {
component[r][c] = nComp;

queue[size3] = (r << 10) | (c);
sum3 += original[solution[r]][solution[c]];
size3++;
}
}
}
// BFS from the seeds over 4-neighbours, confined to the square, through non-zero cells not
// yet labelled nComp or parentNComp. Cells carrying any other (older) label are absorbed.
for (int ind = 0; ind < size3; ind++) {
int r = (queue[ind] >>> 10) & 1023;
int c = queue[ind] & 1023;
for (int d = 0; d < 4; d++) {
int newr = r + dr[d];
int newc = c + dc[d];

if (newr < minRange || newr > maxRange || newc < minRange || newc > maxRange) {
continue;
}
if (original[solution[newr]][solution[newc]] == 0 || component[newr][newc] == nComp || component[newr][newc] == parentNComp) {
continue;
}

component[newr][newc] = nComp;
queue[size3] = (newr << 10) | (newc);
sum3 += original[solution[newr]][solution[newc]];
size3++;
}
}
} else if (sign > 0) {
// Mirror image of the sign < 0 case. Seed column maxRange, rows minRange..maxRange-1, from
// neighbour (r, c - 1); the corner (maxRange, maxRange) is left to the row loop.
{
int c = maxRange;
for (int r = minRange; r <= maxRange - 1; r++) {
if (original[solution[r]][solution[c]] == 0) {
continue;
}
if ((c - 1 >= 0 && (component[r][c - 1] == parentNComp))) {
component[r][c] = nComp;

queue[size3] = (r << 10) | (c);
sum3 += original[solution[r]][solution[c]];
size3++;
}
}
}
// Seed row maxRange, columns minRange..maxRange (corner included), from neighbour (r - 1, c).
{
int r = maxRange;
for (int c = minRange; c <= maxRange; c++) {
if (original[solution[r]][solution[c]] == 0) {
continue;
}
if ((r - 1 >= 0 && (component[r - 1][c] == parentNComp))) {
component[r][c] = nComp;

queue[size3] = (r << 10) | (c);
sum3 += original[solution[r]][solution[c]];
size3++;
}
}
}
// Same confined BFS as in the sign < 0 branch.
for (int ind = 0; ind < size3; ind++) {
int r = (queue[ind] >>> 10) & 1023;
int c = queue[ind] & 1023;
for (int d = 0; d < 4; d++) {
int newr = r + dr[d];
int newc = c + dc[d];

if (newr < minRange || newr > maxRange || newc < minRange || newc > maxRange) {
continue;
}
if (original[solution[newr]][solution[newc]] == 0 || component[newr][newc] == nComp || component[newr][newc] == parentNComp) {
continue;
}

component[newr][newc] = nComp;
queue[size3] = (newr << 10) | (newc);
sum3 += original[solution[newr]][solution[newc]];
size3++;
}
}
} else {
// sign == 0: only cell (0, 0) is examined, regardless of minRange / maxRange (presumably
// the caller passes minRange == maxRange == 0 here — confirm). It is labelled but not queued.
if (original[solution[0]][solution[0]] != 0) {
component[0][0] = nComp;
sum3 += original[solution[0]][solution[0]];
size3++;
}
}
// Fold in the totals the caller already attributed to the parent component.
sum3 += currentSum;
size3 += currentSize;
}

/**
 * Restores the best solution recorded so far into the working state.
 *
 * Resets {@code score} to {@code bestScore} and copies the first {@code solution.length}
 * entries of {@code bestSolution} into {@code solution}.
 */
private void loadBest() {
    score = bestScore;
    // Bulk copy instead of a hand-written element loop.
    System.arraycopy(bestSolution, 0, solution, 0, solution.length);
}

/**
 * Records the working solution as the best one when its score is strictly higher than
 * {@code bestScore}; otherwise leaves the best solution unchanged.
 */
private void saveBest() {
    // Keep the strict '>' test: a NaN score must never replace the stored best.
    if (score > bestScore) {
        bestScore = score;
        System.arraycopy(solution, 0, bestSolution, 0, solution.length);
    }
}

/**
 * Command-line driver for the problem's judge protocol.
 *
 * Reads an element count M from standard input, followed by M integers, one per line.
 * It runs {@code permute} on them, then prints the result's length followed by its elements,
 * one per line. Any exception is reported to standard error via its stack trace.
 */
public static void main(String[] args) {
    try (BufferedReader in = new BufferedReader(new InputStreamReader(System.in))) {
        final int count = Integer.parseInt(in.readLine());
        final int[] matrix = new int[count];
        for (int idx = 0; idx < count; idx++) {
            matrix[idx] = Integer.parseInt(in.readLine());
        }

        final int[] permutation = new ConnectedComponent().permute(matrix);

        System.out.println(permutation.length);
        for (int value : permutation) {
            System.out.println(value);
        }
        System.out.flush();
    } catch (Exception e) {
        e.printStackTrace();
    }
}
}

/**
 * Simulated-annealing schedule and bookkeeping.
 *
 * {@link #time} runs from {@link #startTime} to {@link #endTime}. It is measured in wall-clock
 * seconds when {@link #useTime} is true, otherwise it is the {@link #loop} counter.
 * Temperature and neighbourhood range decay linearly with it, from their start values to
 * their end values.
 */
class SAState {

    /** true: time comes from {@code ConnectedComponent.watch}; false: time is {@link #loop}. */
    public static final boolean useTime = true;

    public double startTime = 0;
    public double endTime = 9.5;
    public double time = startTime;

    public double startTemperature = 1e-1;
    public double endTemperature = 0;
    public double temperature = startTemperature;

    public double startRange = 128;
    public double endRange = 1;
    public double range = startRange;

    // Counters maintained outside this class; loop is read by updateTime() when useTime is false.
    public int loop;
    public int countChange;
    public int countAccept;

    /**
     * Returns the fraction of the schedule still to run: 1 at startTime, 0 at endTime,
     * clamped to [0, 1].
     *
     * Without the clamp, a time past endTime produced a negative temperature. accept() then
     * evaluated exp of a positive number (&gt;= 1) and accepted every worsening move; range
     * likewise dropped below endRange.
     */
    private double remainingFraction() {
        double fraction = (endTime - time) / (endTime - startTime);
        return Math.max(0.0, Math.min(1.0, fraction));
    }

    /** Sets temperature by linear interpolation from startTemperature down to endTemperature. */
    public void updateTemperature() {
        temperature = endTemperature + (startTemperature - endTemperature) * remainingFraction();
    }

    /** Sets range by linear interpolation from startRange down to endRange. */
    public void updateRange() {
        range = endRange + (startRange - endRange) * remainingFraction();
    }

    /** Refreshes time from the wall clock (useTime) or from the loop counter. */
    public void updateTime() {
        time = useTime ? ConnectedComponent.watch.getSecond() : loop;
    }

    /** Returns true once time has reached endTime. */
    public boolean isTLE() {
        return time >= endTime;
    }

    /**
     * Decides whether to take a worsening move.
     *
     * The move is accepted with probability exp((newScore - currentScore) / (currentScore *
     * temperature)), i.e. the change is taken relative to the current score. At temperature 0
     * with a positive currentScore the exponent is -Infinity, so the move is rejected.
     * Only worsening moves are expected here (asserted).
     */
    public boolean accept(double newScore, double currentScore) {
        assert newScore - currentScore < 0;
        assert temperature >= 0;
        return ConnectedComponent.rng.nextDouble() < StrictMath.exp((newScore - currentScore) / (currentScore * temperature));
    }
}

final class Utils {
private Utils() {
}

public static final void debug(Object... o) {
System.err.println(toString(o));
}

public static final String toString(Object... o) {
return Arrays.deepToString(o);
}

}

/** Stopwatch reporting elapsed wall-clock time since construction or the last {@link #init()}. */
class Watch {
    private long start;

    public Watch() {
        init();
    }

    /** Returns the seconds elapsed since the reference instant. */
    public double getSecond() {
        final long elapsedNanos = System.nanoTime() - start;
        return elapsedNanos * 1e-9;
    }

    /** Moves the reference instant to now. */
    public void init() {
        init(System.nanoTime());
    }

    private void init(long start) {
        this.start = start;
    }

    /** Returns the elapsed time formatted by {@link #toString(double)}. */
    public String getSecondString() {
        final double elapsed = getSecond();
        return toString(elapsed);
    }

    /**
     * Formats a duration in seconds.
     *
     * Durations under a minute print as seconds with two decimals, under an hour as minutes
     * and seconds, and otherwise as hours, minutes and seconds.
     */
    public static final String toString(double second) {
        if (second < 60) {
            return String.format("%5.2fs", second);
        }
        final int wholeSeconds = (int) (second % 60);
        if (second < 60 * 60) {
            return String.format("%2dm%2ds", (int) (second / 60), wholeSeconds);
        }
        final int hour = (int) (second / (60 * 60));
        final int minuteOfHour = (int) (second / 60) % 60;
        return String.format("%2dh%2dm%2ds", hour, minuteOfHour, wholeSeconds);
    }
}

/**
 * Marsaglia-style xorshift generator with 128 bits of state.
 *
 * Only the first state word is seeded; the other three keep fixed non-zero constants, so the
 * state can never be all zero.
 */
class XorShift {
    private int s3 = 88675123;
    private int s0 = 123456789;
    private int s1 = 362436069;
    private int s2 = 521288629;

    public XorShift(long seed) {
        s0 = (int) seed;
    }

    /** Advances the state and returns the new last state word. */
    public int nextInt() {
        final int t = s0 ^ (s0 << 11);
        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 ^= (s3 >>> 19) ^ t ^ (t >>> 8);
        return s3;
    }

    /**
     * Combines two draws into a long: the first forms the high half, and the second is
     * sign-extended before being XORed in.
     */
    public long nextLong() {
        final long high = (long) nextInt() << 32;
        final long low = nextInt();
        return high ^ low;
    }

    /** Returns a value in [0, 1) built from the top 31 bits of a draw. */
    public double nextDouble() {
        final int bits31 = nextInt() >>> 1;
        return bits31 * 4.6566128730773926E-10;
    }

    /** Returns floor(n * nextDouble()), i.e. a value in [0, n) for positive n. */
    public int nextInt(int n) {
        final double scaled = n * nextDouble();
        return (int) scaled;
    }
}

class State implements Comparable<State> {
State parent;
int j;
double score;
long hash;

@Override
public int compareTo(State o) {
if (score > o.score) {
return -1;
}
if (score < o.score) {
return 1;
}
if (hash < o.hash) {
return -1;
}
if (hash > o.hash) {
return 1;
}
return 0;
}

@Override
public String toString() {
return Utils.toString("j", j, "score", score);
}
}

/**
 * Static helpers for in-place shuffling, selection and sorting of {@link ArrayList}s.
 * Ranges are inclusive unless a parameter is named {@code ...Exclusive}.
 */
class CollectionsUtils {
    private CollectionsUtils() {
    }

    /** Shuffles {@code list} in place, walking from the back and swapping with a random earlier index. */
    public static final <T> void shuffle(ArrayList<T> list, Random rng) {
        for (int i = list.size() - 1; i >= 0; i--) {
            int j = (int) ((i + 1) * rng.nextDouble());
            CollectionsUtils.swap(list, i, j);
        }
    }

    /** Same as {@link #shuffle(ArrayList, Random)}, drawing from an {@code XorShift} instead. */
    public static final <T> void shuffle(ArrayList<T> list, XorShift rng) {
        for (int i = list.size() - 1; i >= 0; i--) {
            int j = (int) ((i + 1) * rng.nextDouble());
            CollectionsUtils.swap(list, i, j);
        }
    }

    /**
     * Quickselect over the inclusive range [l, r] ordered by {@code comparator}.
     *
     * Rearranges the range so that index k holds the element it would hold if the range were
     * sorted.
     */
    public static final <T> void select0(ArrayList<T> a, int l, int r, int k, Comparator<T> comparator) {
        while (l < r) {
            int i = CollectionsUtils.partition3(a, l, r, comparator);
            if (i >= k) {
                r = i - 1;
            }
            if (i <= k) {
                l = i + 1;
            }
        }
    }

    /** {@link #select0(ArrayList, int, int, int, Comparator)} over [startInclusive, endExclusive). */
    public static final <T> void select(ArrayList<T> a, int startInclusive, int endExclusive, int k, Comparator<T> comparator) {
        select0(a, startInclusive, endExclusive - 1, k, comparator);
    }

    /** Quickselect over the inclusive range [l, r] using natural ordering. */
    public static final <T extends Comparable<T>> void select0(ArrayList<T> a, int l, int r, int k) {
        while (l < r) {
            int i = CollectionsUtils.partition3(a, l, r);
            if (i >= k) {
                r = i - 1;
            }
            if (i <= k) {
                l = i + 1;
            }
        }
    }

    /** {@link #select0(ArrayList, int, int, int)} over [startInclusive, endExclusive). */
    public static final <T extends Comparable<T>> void select(ArrayList<T> a, int startInclusive, int endExclusive, int k) {
        select0(a, startInclusive, endExclusive - 1, k);
    }

    /** Swaps the elements at indices i and j. */
    public static final <T> void swap(ArrayList<T> a, int i, int j) {
        T swap = a.set(i, a.get(j));
        a.set(j, swap);
    }

    /** Puts the elements at indices i, j, k into ascending order (by {@code comparator}). */
    public static final <T> void sort3(ArrayList<T> a, int i, int j, int k, Comparator<T> comparator) {
        if (comparator.compare(a.get(i), a.get(j)) > 0) {
            swap(a, i, j);
        }
        if (comparator.compare(a.get(i), a.get(k)) > 0) {
            swap(a, i, k);
        }
        if (comparator.compare(a.get(j), a.get(k)) > 0) {
            swap(a, j, k);
        }
    }

    /** Puts the elements at indices i, j, k into ascending natural order. */
    public static final <T extends Comparable<T>> void sort3(ArrayList<T> a, int i, int j, int k) {
        if (a.get(i).compareTo(a.get(j)) > 0) {
            swap(a, i, j);
        }
        if (a.get(i).compareTo(a.get(k)) > 0) {
            swap(a, i, k);
        }
        if (a.get(j).compareTo(a.get(k)) > 0) {
            swap(a, j, k);
        }
    }

    /**
     * Median-of-three partition of the inclusive range [l, r].
     *
     * @return the pivot's final index; everything before it compares &lt;= the pivot and
     *         everything after it compares &gt;= the pivot. Ranges of at most three elements
     *         are simply sorted, and l is returned.
     */
    public static final <T> int partition3(ArrayList<T> a, int l, int r, Comparator<T> comparator) {
        int center = (l + r) >>> 1;
        sort3(a, l, center, r, comparator);
        if (r - l < 3) {
            // sort3 has already sorted the whole range. Returning before the pivot swap matters
            // when l == r: r - 1 would then be l - 1, outside the range (and -1 when l == 0).
            // For two or three elements that swap was a no-op (center == r - 1).
            return l;
        }
        // Park the median pivot at r - 1; a[l] <= pivot <= a[r] serve as scan sentinels.
        int r1 = r - 1;
        swap(a, center, r1);
        T v = a.get(r1);
        int i = l - 1;
        int j = r1;
        for (;;) {
            for (; comparator.compare(a.get(++i), v) < 0;) {
            }
            for (; comparator.compare(a.get(--j), v) > 0;) {
            }
            if (i >= j) {
                break;
            }
            swap(a, i, j);
        }
        swap(a, i, r1);
        return i;
    }

    /** Natural-order version of {@link #partition3(ArrayList, int, int, Comparator)}. */
    public static final <T extends Comparable<T>> int partition3(ArrayList<T> a, int l, int r) {
        int center = (l + r) >>> 1;
        sort3(a, l, center, r);
        if (r - l < 3) {
            // See the comparator overload: skip the pivot swap, which is a no-op for two or
            // three elements and out of range when l == r.
            return l;
        }
        int r1 = r - 1;
        swap(a, center, r1);
        T v = a.get(r1);
        int i = l - 1;
        int j = r1;
        for (;;) {
            for (; a.get(++i).compareTo(v) < 0;) {
            }
            for (; a.get(--j).compareTo(v) > 0;) {
            }
            if (i >= j) {
                break;
            }
            swap(a, i, j);
        }
        swap(a, i, r1);
        return i;
    }

    /**
     * Partitions the inclusive range [l, r] around the pivot a.get(r).
     *
     * @return the pivot's final index
     */
    public static final <T extends Comparable<T>> int partition(ArrayList<T> a, int l, int r) {
        int i = l - 1, j = r;
        T v = a.get(r);
        for (;;) {
            while (a.get(++i).compareTo(v) < 0) {
            }
            while (v.compareTo(a.get(--j)) < 0) {
                if (j == l) {
                    break;
                }
            }
            if (i >= j) {
                break;
            }
            swap(a, i, j);
        }
        swap(a, i, r);
        return i;
    }

    /** Quicksorts the inclusive range [lInclusive, rInclusive] by {@code comparator}. */
    public static final <T> void sort(ArrayList<T> a, int lInclusive, int rInclusive, Comparator<T> comparator) {
        if (lInclusive >= rInclusive) {
            return;
        }
        int k = partition3(a, lInclusive, rInclusive, comparator);
        sort(a, lInclusive, k - 1, comparator);
        sort(a, k + 1, rInclusive, comparator);
    }

    /** Quicksorts the inclusive range [lInclusive, rInclusive] by natural order. */
    public static final <T extends Comparable<T>> void sort(ArrayList<T> a, int lInclusive, int rInclusive) {
        if (lInclusive >= rInclusive) {
            return;
        }
        int k = partition3(a, lInclusive, rInclusive);
        sort(a, lInclusive, k - 1);
        sort(a, k + 1, rInclusive);
    }

    /** Reverses {@code list} in place and returns the same list. */
    public static final <T> ArrayList<T> reverse(ArrayList<T> list) {
        for (int l = 0, r = list.size() - 1; l < r; l++, r--) {
            list.set(r, list.set(l, list.get(r)));
        }
        return list;
    }

    /** Returns a new list holding the elements of {@code list} in reverse order. */
    public static final <T> ArrayList<T> newReverse(ArrayList<T> list) {
        ArrayList<T> res = new ArrayList<>();
        for (int i = list.size() - 1; i >= 0; i--) {
            res.add(list.get(i));
        }
        return res;
    }

}





Approach

684468.79 RandomForest
689847.94 RandomForest + medianfilter +-1
693728.97 RandomForest + medianfilter +-5

RandomForest では 1 画像あたり 15625 サンプルを使用した。
特徴量は、中心から ±delta (delta = 4, 8) の正方形領域における色の {平均, 歪度, 尖度} × 正方形の中心 {(r,c), (r-delta,c-delta), (r-delta,c+delta), (r+delta,c-delta), (r+delta,c+delta)} × 4色 {gray, red, green, blue} の組み合わせ。


Chainer は VOC2012 で練習していたが、間に合わなかった。

このページのトップへ