EvbCFfp1XB

イーブイ進化系 C7BMkOO7Qbmcwck7(AtCoder) dtwYhl0YJqzOnZKi(CodinGame) AnlJ8vIySM8j8Nfq(Codeforces) j8LsXlKzJPRosaXt(Kaggle)




Approach 焼きなまし法を使いました。

  • 近傍1 : 隣の領域に1マス渡す(数の和が最大のとき) or 1マス受け取る(数の和が最小のとき)。
    • 数の和が最小または最大の領域のみ選ぶ。
    • 数の和が最小または最大の領域で、隣の領域に隣接しているマスと方向を、追加削除参照がO(1)のSetで保持する。
  • 近傍2 : 隣の領域に1マス渡して、同じ領域から別の1マス受け取る。
    • 隣接している2個所を選んで、2通りの受け渡しの組合せの内、数の差が小さい方を使う。
  • 多スタート : 6回
    • 焼きなました後の領域を見ると、多スタートしてもあまり効果なさそうだけど、ちょっと良くなる。


感想 1回に1マスずつ受け渡しする近傍では、スコアが「悪くなる -> 良くなる」という遷移が受理されにくい。そこで1回に2マス受け渡しする近傍にして悪くなる遷移を飛ばすと、スコアが良くなる、という典型的な問題だった。

Source Code
import java.util.Arrays;
import java.util.HashMap;
import java.util.Scanner;

/**
 * Partitions an N x N grid of numbers into K 4-connected regions so that the difference between the
 * largest and the smallest region sum is as small as possible.
 *
 * <p>A snake-order greedy builds the initial partition, then simulated annealing with restarts
 * improves it using two neighbourhoods: move one boundary cell ({@link #change()}) or swap one cell
 * each way with a neighbouring region ({@link #change2()}).
 */
public class Main {
    // Grid direction offsets: up, left, right, down. Directions d and 3 - d are opposite.
    private static final int[] dr = { -1, 0, 0, 1, };
    private static final int[] dc = { 0, -1, 1, 0, };

    /** Side length of the square grid. */
    private int N;
    /** Number of regions. */
    private int K;
    /** Cell values, numbers[r][c]. */
    private int[][] numbers;

    final static XorShift rng = new XorShift(System.nanoTime());
    final static Watch watch = new Watch();
    private SAState sa = new SAState();

    /** Current score: max region sum minus min region sum (1e9 for an invalid partition). */
    private double score;
    private double bestScore;
    /** Region id of every cell in the current partition. */
    private int[][] ids;
    private int[][] bestIds;
    /** Sum of cell values per region id. */
    private int[] sumNumbers;
    /**
     * Per region: the encoded (row, column, direction) triples (see {@link #compose}) of its cells
     * whose neighbour in that direction belongs to a different region.
     */
    private IntSet[] availableSet;

    /** Region sum -> ids of the regions that currently have exactly that sum. */
    private HashMap<Integer, IntSet> sumNumberToIds;
    private int minSumNumber;
    private int maxSumNumber;

    /** Reusable visited marks for the incremental connectivity check. */
    private InitializedBooleanArray_v2 used = new InitializedBooleanArray_v2();
    /** Lower bound on the score: 0 if the total divides evenly into K parts, otherwise 1. */
    private int minScore;

    public static void main(String[] args) {
        new Main().run();
    }

    private void run() {
        read();
        init();
        solve();
        write();
    }

    /** Reads N, K and the N x N grid of numbers from standard input. */
    private void read() {
        try (Scanner sc = new Scanner(System.in)) {
            N = sc.nextInt();
            K = sc.nextInt();
            numbers = new int[N][N];
            for (int r = 0; r < N; r++) {
                for (int c = 0; c < N; c++) {
                    numbers[r][c] = sc.nextInt();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        Utils.debug("N", N, "K", K);
    }

    /** Allocates all per-run state; must run after {@link #read()}. */
    private void init() {
        ids = new int[N][N];
        bestIds = new int[N][N];
        sumNumbers = new int[K];

        score = 1e9;
        bestScore = 1e9;

        availableSet = new IntSet[K];
        for (int id = 0; id < K; id++) {
            // Capacity covers every (cell, direction) code produced by compose().
            availableSet[id] = new IntSet(N * N * 4);
        }

        sumNumberToIds = new HashMap<>();

        used.init(N * N);
    }

    private void solve() {
        greedy();
        multiSA();
    }

    /**
     * Builds the initial partition and all derived bookkeeping.
     *
     * <p>The grid is walked in bands of 5 rows. Bands alternate between left-to-right and
     * right-to-left column order, and each column of a band is scanned top to bottom. A new region
     * starts whenever the running sum reaches {@code (int) mean}. The cell that crosses the threshold
     * moves to the next region only if the current region ends up closer to the target without it,
     * and a region's only cell is never moved, so no region is left empty by this step.
     */
    private void greedy() {
        int total = 0;
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                total += numbers[r][c];
            }
        }
        double mean = (double) total / K;
        Utils.debug("mean", mean);

        // All K sums can be equal only if the total is divisible by K; otherwise max - min >= 1.
        minScore = total % K == 0 ? 0 : 1;

        final int target = (int) mean;
        final int bandHeight = 5;
        for (int r = 0; r < N; r++) {
            Arrays.fill(ids[r], -1);
        }

        int id = 0;
        int sum = 0;
        int regionCells = 0;
        boolean startLeft = true;
        for (int r = 0; r < N; r += bandHeight) {
            for (int i = 0; i < N; i++) {
                int c = startLeft ? i : N - 1 - i;
                if (ids[r][c] >= 0) {
                    continue;
                }
                for (int r2 = 0; r2 < bandHeight; r2++) {
                    int row = r + r2;
                    if (!isValid(row, 0, N)) {
                        continue;
                    }
                    int value = numbers[row][c];
                    sum += value;
                    regionCells++;
                    ids[row][c] = id;
                    if (sum >= target) {
                        // Hand the crossing cell to the next region only when that leaves the
                        // current region closer to the target (and the region keeps a cell).
                        boolean moveToNext = regionCells > 1
                                && Math.abs(sum - value - target) < Math.abs(sum - target);
                        sum = 0;
                        regionCells = 0;
                        id = Math.min(id + 1, K - 1);
                        if (moveToNext) {
                            sum += value;
                            regionCells = 1;
                            ids[row][c] = id;
                        }
                    }
                }
            }
            startLeft = !startLeft;
        }

        initSumNumbers();
        initMinMaxSumNumber();
        initSumNumberToIds();
        initAvailableSet();

        score = calculateScore();
        saveBest();
        Utils.debug("greedy", "score", score);
    }

    /** Rebuilds {@link #availableSet} from scratch for the current partition. */
    private void initAvailableSet() {
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                updateAvailableSet(r, c);
            }
        }
    }

    /** Recomputes {@link #minSumNumber} and {@link #maxSumNumber} by scanning all K region sums. */
    private void initMinMaxSumNumber() {
        minSumNumber = (int) 1e9;
        maxSumNumber = (int) -1e9;
        for (int id = 0; id < K; id++) {
            if (sumNumbers[id] < minSumNumber) {
                minSumNumber = sumNumbers[id];
            }
            if (sumNumbers[id] > maxSumNumber) {
                maxSumNumber = sumNumbers[id];
            }
        }
    }

    /** Rebuilds the sum -> ids index from {@link #sumNumbers}. */
    private void initSumNumberToIds() {
        sumNumberToIds.clear();
        for (int id = 0; id < K; id++) {
            sumNumberToIds.computeIfAbsent(sumNumbers[id], key -> new IntSet(K)).add(id);
        }
    }

    /**
     * Moves {@code id} in the sum -> ids index from its previous sum to its current sum.
     *
     * @param id             region whose sum changed
     * @param oldSumNumbers  the region's sum before the change
     */
    private void updateSumNumberToIds(int id, int oldSumNumbers) {
        IntSet previous = sumNumberToIds.get(oldSumNumbers);
        if (previous != null) {
            previous.remove(id);
        }
        sumNumberToIds.computeIfAbsent(sumNumbers[id], key -> new IntSet(K)).add(id);
    }

    /** Recomputes every region sum from {@link #ids}. */
    private void initSumNumbers() {
        Arrays.fill(sumNumbers, 0);
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                sumNumbers[ids[r][c]] += numbers[r][c];
            }
        }
    }

    /**
     * Runs {@code numRestart} annealing runs that split the time until 9.5 s equally. Each run
     * continues from the current partition. Its start temperature falls linearly from 10 (first run)
     * to 10 / numRestart (last run). Finally the best partition found is restored.
     */
    private void multiSA() {
        int numRestart = 6;

        double endTime = 9.5;
        double startTime = watch.getSecond();
        double remainTime = endTime - startTime;

        double startStartTemperature = 10;
        double endStartTemperature = 0;
        for (double restart = 0; restart < numRestart; restart++) {
            sa.startTime = startTime + remainTime * restart / numRestart;
            sa.endTime = startTime + remainTime * (restart + 1) / numRestart;
            sa.startTemperature = endStartTemperature + (startStartTemperature - endStartTemperature) * ((numRestart - restart) / numRestart);
            score = calculateScore();

            SA();
        }
        loadBest();

    }

    /** One annealing run; stops at the time limit or when the score lower bound is reached. */
    private void SA() {
        sa.init();
        for (;; ++sa.numIterations) {
            // Refresh time/temperature and test the stop conditions every 8 iterations.
            if ((sa.numIterations & ((1 << 3) - 1)) == 0) {
                sa.update();

                if (bestScore == minScore || sa.isTLE()) {
                    logProgress();
                    break;
                }
            }

            mutate();
        }
    }

    /** Prints iteration count, valid/accept ratios, scores, temperatures and elapsed time to stderr. */
    private void logProgress() {
        Utils.debug(sa.numIterations, String.format("%.2f%%", 100.0 * sa.validIterations / sa.numIterations), String.format("%.2f%%", 100.0 * sa.acceptIterations / sa.validIterations), String.format("%5.0f", score), String.format("%5.0f", bestScore), String.format("%.6f", 1.0 / sa.inverseTemperature), String.format("%.6f", 1.0 / sa.lastAcceptTemperature), "time", watch.getSecondString());
    }

    /** Applies one of the two neighbourhoods with equal probability. */
    private void mutate() {
        double random = 2 * rng.nextDouble();
        if (random < 1) {
            change();
        } else if (random < 2) {
            change2();
        }
    }

    /**
     * Neighbourhood 1: a region with the minimum or maximum sum (50/50) and one of its boundary
     * (cell, direction) pairs are chosen. If the region's sum equals the current minimum, it takes
     * the neighbouring cell. Otherwise it gives its own boundary cell to that neighbour. The move is
     * kept only if SA accepts the new score and the region that lost the cell stays connected;
     * otherwise it is undone.
     */
    private void change() {
        IntSet ids2 = sumNumberToIds.get(rng.nextDouble() < 0.5 ? minSumNumber : maxSumNumber);
        int id = ids2.get((int) (ids2.size() * rng.nextDouble()));

        // Decode a compose(r, c, d) value.
        int v = availableSet[id].get((int) (availableSet[id].size() * rng.nextDouble()));
        int r = v / (N * 4);
        int c = (v >>> 2) % N;
        int d = v & 3;

        int r2 = r + dr[d];
        int c2 = c + dc[d];
        int id2 = ids[r2][c2];

        int currentMinSumNumber = minSumNumber;
        int currentMaxSumNumber = maxSumNumber;
        boolean increaseMin = sumNumbers[id] == minSumNumber;
        if (increaseMin) {
            ids[r2][c2] = id;
            sumNumbers[id] += numbers[r2][c2];
            sumNumbers[id2] -= numbers[r2][c2];
        } else {
            ids[r][c] = id2;
            sumNumbers[id] -= numbers[r][c];
            sumNumbers[id2] += numbers[r][c];
        }
        initMinMaxSumNumber();
        double deltaScore = (maxSumNumber - minSumNumber) - score;
        if (sa.accept(deltaScore) && ((increaseMin && isConnected(id2, r2, c2)) || (!increaseMin && isConnected(id, r, c)))) {
            score += deltaScore;
            if (increaseMin) {
                // Second argument is the region's sum before the move.
                updateSumNumberToIds(id, sumNumbers[id] - numbers[r2][c2]);
                updateSumNumberToIds(id2, sumNumbers[id2] + numbers[r2][c2]);

                removeAvailableSet(r2, c2, id2);
                updateAvailableSet(r2, c2);
            } else {
                updateSumNumberToIds(id, sumNumbers[id] + numbers[r][c]);
                updateSumNumberToIds(id2, sumNumbers[id2] - numbers[r][c]);

                removeAvailableSet(r, c, id);
                updateAvailableSet(r, c);
            }

            saveBest();
        } else {
            // Roll back the tentative move.
            if (increaseMin) {
                ids[r2][c2] = id2;
                sumNumbers[id] -= numbers[r2][c2];
                sumNumbers[id2] += numbers[r2][c2];
            } else {
                ids[r][c] = id;
                sumNumbers[id] += numbers[r][c];
                sumNumbers[id2] -= numbers[r][c];
            }
            minSumNumber = currentMinSumNumber;
            maxSumNumber = currentMaxSumNumber;
        }
    }

    /**
     * Neighbourhood 2: swaps one cell each way between a min/max-sum region {@code id} and a
     * neighbouring region {@code id2}.
     *
     * <p>Two distinct boundary pairs of {@code id} are sampled: (r,c)->(r2,c2) and (r3,c3)->(r4,c4).
     * Both must face region {@code id2}, and the four cells must be pairwise usable. Of the two
     * possible swaps, the one with the smaller value difference is used:
     * <ul>
     *   <li>"23": (r2,c2) joins id and (r3,c3) joins id2;</li>
     *   <li>otherwise (r4,c4) joins id and (r,c) joins id2.</li>
     * </ul>
     * The swap is skipped if it would move {@code id}'s sum in the wrong direction. It is kept only
     * if SA accepts it and both regions stay connected.
     */
    private void change2() {
        IntSet ids2 = sumNumberToIds.get(rng.nextDouble() < 0.5 ? minSumNumber : maxSumNumber);
        int id = ids2.get((int) (ids2.size() * rng.nextDouble()));
        if (availableSet[id].size() < 2) {
            return;
        }
        int v = availableSet[id].get((int) (availableSet[id].size() * rng.nextDouble()));
        int v3 = availableSet[id].get((int) (availableSet[id].size() * rng.nextDouble()));
        while (v3 == v) {
            v3 = availableSet[id].get((int) (availableSet[id].size() * rng.nextDouble()));
        }

        int r = v / (N * 4);
        int c = (v >>> 2) % N;
        int d = v & 3;

        int r2 = r + dr[d];
        int c2 = c + dc[d];
        int id2 = ids[r2][c2];

        int r3 = v3 / (N * 4);
        int c3 = (v3 >>> 2) % N;
        int d3 = v3 & 3;
        if (r3 == r && c3 == c) {
            return;
        }

        int r4 = r3 + dr[d3];
        int c4 = c3 + dc[d3];
        if (r4 == r2 && c4 == c2) {
            return;
        }
        int id4 = ids[r4][c4];
        if (id4 != id2) {
            return;
        }

        boolean increaseMin = sumNumbers[id] == minSumNumber;
        boolean use23 = Math.abs(numbers[r2][c2] - numbers[r3][c3]) < Math.abs(numbers[r4][c4] - numbers[r][c]);
        if (increaseMin) {
            // The minimum region must not receive less than it gives away.
            if (use23) {
                if (numbers[r2][c2] >= numbers[r3][c3]) {
                    ids[r2][c2] = id;
                    sumNumbers[id] += numbers[r2][c2];
                    sumNumbers[id2] -= numbers[r2][c2];

                    ids[r3][c3] = id2;
                    sumNumbers[id] -= numbers[r3][c3];
                    sumNumbers[id2] += numbers[r3][c3];
                } else {
                    return;
                }
            } else {
                if (numbers[r4][c4] >= numbers[r][c]) {
                    ids[r4][c4] = id;
                    sumNumbers[id] += numbers[r4][c4];
                    sumNumbers[id2] -= numbers[r4][c4];

                    ids[r][c] = id2;
                    sumNumbers[id] -= numbers[r][c];
                    sumNumbers[id2] += numbers[r][c];
                } else {
                    return;
                }
            }
        } else {
            // The maximum region must not receive more than it gives away.
            if (use23) {
                if (numbers[r2][c2] > numbers[r3][c3]) {
                    return;
                } else {
                    ids[r2][c2] = id;
                    sumNumbers[id] += numbers[r2][c2];
                    sumNumbers[id2] -= numbers[r2][c2];

                    ids[r3][c3] = id2;
                    sumNumbers[id] -= numbers[r3][c3];
                    sumNumbers[id2] += numbers[r3][c3];
                }
            } else {
                if (numbers[r4][c4] > numbers[r][c]) {
                    return;
                } else {
                    ids[r4][c4] = id;
                    sumNumbers[id] += numbers[r4][c4];
                    sumNumbers[id2] -= numbers[r4][c4];

                    ids[r][c] = id2;
                    sumNumbers[id] -= numbers[r][c];
                    sumNumbers[id2] += numbers[r][c];
                }
            }
        }

        int currentMinSumNumber = minSumNumber;
        int currentMaxSumNumber = maxSumNumber;
        initMinMaxSumNumber();
        double deltaScore = (maxSumNumber - minSumNumber) - score;
        if (sa.accept(deltaScore) && ((use23 && isConnected(id2, r2, c2) && isConnected(id, r3, c3)) || (!use23 && isConnected(id, r, c) && isConnected(id2, r4, c4)))) {
            score += deltaScore;

            if (use23) {
                // Second argument is the region's sum before the swap.
                updateSumNumberToIds(id, sumNumbers[id] - numbers[r2][c2] + numbers[r3][c3]);
                updateSumNumberToIds(id2, sumNumbers[id2] + numbers[r2][c2] - numbers[r3][c3]);

                removeAvailableSet(r2, c2, id2);
                updateAvailableSet(r2, c2);

                removeAvailableSet(r3, c3, id);
                updateAvailableSet(r3, c3);
            } else {
                updateSumNumberToIds(id, sumNumbers[id] + numbers[r][c] - numbers[r4][c4]);
                updateSumNumberToIds(id2, sumNumbers[id2] - numbers[r][c] + numbers[r4][c4]);

                removeAvailableSet(r4, c4, id2);
                updateAvailableSet(r4, c4);

                removeAvailableSet(r, c, id);
                updateAvailableSet(r, c);
            }

            saveBest();
        } else {
            // Roll back the tentative swap.
            if (use23) {
                ids[r2][c2] = id2;
                sumNumbers[id] -= numbers[r2][c2];
                sumNumbers[id2] += numbers[r2][c2];

                ids[r3][c3] = id;
                sumNumbers[id] += numbers[r3][c3];
                sumNumbers[id2] -= numbers[r3][c3];
            } else {
                ids[r4][c4] = id2;
                sumNumbers[id] -= numbers[r4][c4];
                sumNumbers[id2] += numbers[r4][c4];

                ids[r][c] = id;
                sumNumbers[id] += numbers[r][c];
                sumNumbers[id2] -= numbers[r][c];
            }
            minSumNumber = currentMinSumNumber;
            maxSumNumber = currentMaxSumNumber;
        }
    }

    /**
     * Re-registers the boundary pairs of cell (r, c) for its current region. A side is added, for
     * both the cell and the facing neighbour, if the neighbour is in another region; otherwise it
     * is removed.
     */
    private void updateAvailableSet(int r, int c) {
        int id = ids[r][c];
        for (int d = 0; d < dr.length; d++) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            if (!isValid(nr, 0, N) || !isValid(nc, 0, N)) {
                continue;
            }
            if (ids[nr][nc] == id) {
                availableSet[id].remove(compose(r, c, d));
                availableSet[ids[nr][nc]].remove(compose(nr, nc, 3 - d));
                continue;
            }
            availableSet[id].add(compose(r, c, d));
            availableSet[ids[nr][nc]].add(compose(nr, nc, 3 - d));
        }
    }

    /**
     * Removes every boundary pair of cell (r, c) from {@code oldId}'s set. It also removes each
     * neighbour's pair pointing back at (r, c) from that neighbour's current region set.
     */
    private void removeAvailableSet(int r, int c, int oldId) {
        for (int d = 0; d < dr.length; d++) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            if (!isValid(nr, 0, N) || !isValid(nc, 0, N)) {
                continue;
            }
            availableSet[oldId].remove(compose(r, c, d));
            availableSet[ids[nr][nc]].remove(compose(nr, nc, 3 - d));
        }
    }

    /** Encodes (row, column, direction) as r * N * 4 + c * 4 + d. */
    private int compose(int r, int c, int d) {
        return r * N * 4 + c * 4 + d;
    }

    /** Returns true if min <= value < minUpper. */
    private boolean isValid(int value, int min, int minUpper) {
        return value >= min && value < minUpper;
    }

    /** Copies the current partition into the best one if the current score is strictly better. */
    private void saveBest() {
        if (score < bestScore) {
            bestScore = score;
            for (int r = 0; r < N; r++) {
                System.arraycopy(ids[r], 0, bestIds[r], 0, N);
            }
        }
    }

    /** Restores the best partition into {@link #ids} (derived bookkeeping is not rebuilt). */
    private void loadBest() {
        score = bestScore;
        for (int r = 0; r < N; r++) {
            System.arraycopy(bestIds[r], 0, ids[r], 0, N);
        }
    }

    /**
     * Computes max region sum minus min region sum from scratch.
     *
     * @return the score, or 1e9 if a region is disconnected or an id is outside [0, K)
     */
    private double calculateScore() {
        if (!isConnected()) {
            return 1e9;
        }

        int[] sums = new int[K];
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                int id = ids[r][c];
                if (id < 0 || id >= K) {
                    return 1e9;
                }
                sums[id] += numbers[r][c];
            }
        }
        Arrays.sort(sums);

        return sums[K - 1] - sums[0];
    }

    /**
     * Returns true if all cells of region {@code id} adjacent to (r0, c0) lie in one connected
     * component of that region (or there are none). It is used after (r0, c0) has been taken away
     * from {@code id}, to check that the region was not split.
     */
    private boolean isConnected(int id, int r0, int c0) {
        used.clear();
        int count = 0;
        for (int d = 0; d < dr.length; d++) {
            int r = r0 + dr[d];
            int c = c0 + dc[d];
            if (!isValid(r, 0, N) || !isValid(c, 0, N)) {
                continue;
            }
            if (ids[r][c] != id) {
                continue;
            }
            if (used.get(r * N + c)) {
                continue;
            }
            dfs(r, c, ids[r][c], used);
            count++;
            if (count > 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if every region id occurring in {@link #ids} forms a single 4-connected
     * component. An id outside [0, K) throws ArrayIndexOutOfBoundsException here.
     */
    private boolean isConnected() {
        boolean[][] used = new boolean[N][N];
        int[] count = new int[K];
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                if (used[r][c]) {
                    continue;
                }
                dfs(r, c, ids[r][c], used);
                count[ids[r][c]]++;
                if (count[ids[r][c]] > 1) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Recursive 4-neighbour flood fill of region {@code id} from (r, c), marking cells in
     * {@code used}. Recursion depth can grow to the region's size.
     * NOTE(review): may need an explicit stack for very large grids — confirm against N's limits.
     */
    private void dfs(int r, int c, int id, boolean[][] used) {
        used[r][c] = true;
        for (int d = 0; d < dr.length; d++) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            if (!isValid(nr, 0, N) || !isValid(nc, 0, N)) {
                continue;
            }
            if (ids[nr][nc] != id) {
                continue;
            }
            if (used[nr][nc]) {
                continue;
            }
            dfs(nr, nc, id, used);
        }
    }

    /** Same flood fill as above, using the epoch-cleared mark array indexed by r * N + c. */
    private void dfs(int r, int c, int id, InitializedBooleanArray_v2 used) {
        used.set(r * N + c, true);
        for (int d = 0; d < dr.length; d++) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            if (!isValid(nr, 0, N) || !isValid(nc, 0, N)) {
                continue;
            }
            if (ids[nr][nc] != id) {
                continue;
            }
            if (used.get(nr * N + nc)) {
                continue;
            }
            dfs(nr, nc, id, used);
        }
    }

    /** Prints the region id of every cell, one grid row per line, space-separated. */
    private void write() {
        StringBuilder sb = new StringBuilder();
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                if (c > 0) {
                    sb.append(" ");
                }
                sb.append(ids[r][c]);
            }
            sb.append("\n");
        }
        System.out.print(sb.toString());
        System.out.flush();
    }

}

final class Utils {
    private Utils() {
    }

    public static final void debug(Object... o) {
        System.err.println(toString(o));
    }

    public static final String toString(Object... o) {
        return Arrays.deepToString(o);
    }

}

/** Wall-clock stopwatch based on {@link System#nanoTime()}. */
class Watch {
    /** nanoTime() value at which timing started. */
    private long start;

    /** Creates a stopwatch that starts timing immediately. */
    public Watch() {
        init();
    }

    /** Returns the seconds elapsed since construction or the last {@link #init()}. */
    public double getSecond() {
        final long elapsedNanos = System.nanoTime() - start;
        return elapsedNanos * 1e-9;
    }

    /** Restarts timing from now. */
    public void init() {
        start = System.nanoTime();
    }

    /** Returns the elapsed time formatted by {@link #toString(double)}. */
    public String getSecondString() {
        final double elapsed = getSecond();
        return toString(elapsed);
    }

    /**
     * Formats a duration: below one minute as seconds with two decimals ("%5.2fs"), below one hour
     * as minutes and seconds, otherwise as hours, minutes and seconds.
     */
    public static final String toString(double second) {
        if (second < 60) {
            return String.format("%5.2fs", second);
        }
        final int wholeSeconds = (int) (second % 60);
        if (second < 60 * 60) {
            final int minutes = (int) (second / 60);
            return String.format("%2dm%2ds", minutes, wholeSeconds);
        }
        final int hours = (int) (second / (60 * 60));
        final int totalMinutes = (int) (second / 60);
        return String.format("%2dh%2dm%2ds", hours, totalMinutes % (60), wholeSeconds);
    }

}

class XorShift {
    private int w = 88675123;
    private int x = 123456789;
    private int y = 362436069;
    private int z = 521288629;

    public XorShift(long l) {
        x = (int) l;
    }

    public int nextInt() {
        final int t = x ^ (x << 11);
        x = y;
        y = z;
        z = w;
        w = w ^ (w >>> 19) ^ (t ^ (t >>> 8));
        return w;
    }

    public long nextLong() {
        return ((long) nextInt() << 32) ^ (long) nextInt();
    }

    public double nextDouble() {
        return (nextInt() >>> 1) * 4.6566128730773926E-10;
    }

    public int nextInt(int n) {
        return (int) (n * nextDouble());
    }

}

/**
 * Simulated-annealing schedule and acceptance tests.
 *
 * <p>Time comes from {@code Main.watch} (because {@link #useTime} is true). The temperature is
 * interpolated linearly from {@link #startTemperature} to {@link #endTemperature} between
 * {@link #startTime} and {@link #endTime}. Schedule progress is clamped to [0, 1], so the temperature
 * never leaves that range even if {@link #time} overshoots {@link #endTime} or the window is empty.
 * {@link #inverseTemperature} and {@link #lastAcceptTemperature} both hold inverse temperatures.
 */
class SAState {

    public static final boolean useTime = true;

    public double startTime = 0;
    public double endTime = 9.5;
    public double time = startTime;
    public double startTemperature = 30;
    public double endTemperature = 1e-9;
    public double inverseTemperature = 1.0 / startTemperature;
    /** Inverse temperature at init() or at the last accepted worsening move. */
    public double lastAcceptTemperature = 1.0 / startTemperature;

    /** Size parameter interpolated linearly from startRange down to endRange over the window. */
    public double startRange = 1000;
    public double endRange = 3;
    public double range = startRange;

    public int numIterations;
    /** Number of accept*() calls since init(). */
    public int validIterations;
    /** Number of accept*() calls since init() that returned true. */
    public int acceptIterations;

    // log[i] = ln((i + 0.5) / 32768): ln(u) for u drawn from 32768 equal buckets of (0, 1).
    private static final double[] log = new double[32768];
    static {
        for (int i = 0; i < log.length; i++) {
            log[i] = Math.log((i + 0.5) / log.length);
        }
    }

    /** Resets counters, starts the window at the current time and refreshes the schedule. */
    public void init() {
        numIterations = 0;
        validIterations = 0;
        acceptIterations = 0;

        startTime = useTime ? Main.watch.getSecond() : numIterations;

        update();
        lastAcceptTemperature = inverseTemperature;
    }

    /** Refreshes time, temperature and range. */
    public void update() {
        updateTime();
        updateTemperature();
        updateRange();
    }

    /** Sets inverseTemperature from the clamped schedule progress. */
    public void updateTemperature() {
        double time0to1 = clamp01(elapsedPercentage(startTime, endTime, time));
        double startY = startTemperature;
        double endY = endTemperature;
        double temperature = interpolate(startY, endY, time0to1);
        inverseTemperature = 1.0 / temperature;
    }

    /** Fraction of [min, max] covered by v; an empty or inverted window counts as finished (1). */
    private double elapsedPercentage(double min, double max, double v) {
        if (!(max > min)) {
            return 1.0;
        }
        return (v - min) / (max - min);
    }

    /** Clamps v into [0, 1]. */
    private static double clamp01(double v) {
        return Math.max(0.0, Math.min(1.0, v));
    }

    private double interpolate(double v0, double v1, double d0to1) {
        return v0 + (v1 - v0) * d0to1;
    }

    /** Sets range from the clamped fraction of the window still remaining. */
    public void updateRange() {
        double span = endTime - startTime;
        double remaining0to1 = span > 0 ? clamp01((endTime - time) / span) : 0.0;
        range = endRange + (startRange - endRange) * remaining0to1;
    }

    public void updateTime() {
        time = useTime ? Main.watch.getSecond() : numIterations;
    }

    /** Returns true once the current time has reached endTime. */
    public boolean isTLE() {
        return time >= endTime;
    }

    /** Default acceptance test: Metropolis rule for a score that is minimised. */
    public boolean accept(double deltaScore) {
        return acceptS(deltaScore);
    }

    /** Default hill-climbing test (minimisation): accept only non-worsening moves. */
    public boolean acceptHC(double deltaScore) {
        return acceptHCS(deltaScore);
    }

    /**
     * Metropolis test for a maximised score: non-negative deltas (within 1e-9) are always accepted.
     * Otherwise the move is accepted with probability exp(delta / T). Moves with delta / T < -10 are
     * rejected without drawing a random number.
     */
    public boolean acceptB(double deltaScore) {
        validIterations++;

        if (deltaScore > -1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore < 0 : Utils.toString(deltaScore);
        assert 1.0 / inverseTemperature >= 0;

        double d = deltaScore * inverseTemperature;
        if (d < -10) {
            return false;
        }
        // ln(u) < d  <=>  u < exp(d)
        if (log[Main.rng.nextInt() & 32767] < d) {
            acceptIterations++;
            lastAcceptTemperature = inverseTemperature;
            return true;
        }
        return false;
    }

    /**
     * Metropolis test for a minimised score: deltas below 1e-9 are always accepted. Otherwise the
     * move is accepted with probability exp(-delta / T). Moves with -delta / T < -10 are rejected
     * without drawing a random number.
     */
    public boolean acceptS(double deltaScore) {
        validIterations++;

        if (deltaScore < 1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore > 0;
        assert 1.0 / inverseTemperature >= 0;

        double d = -deltaScore * inverseTemperature;
        if (d < -10) {
            return false;
        }
        // ln(u) < d  <=>  u < exp(-delta / T)
        if (log[Main.rng.nextInt() & 32767] < d) {
            acceptIterations++;
            lastAcceptTemperature = inverseTemperature;
            return true;
        }
        return false;
    }

    /** Hill-climbing test for a minimised score: accepts only deltas below 1e-9. */
    public boolean acceptHCS(double deltaScore) {
        validIterations++;

        if (deltaScore < 1e-9) {
            acceptIterations++;
            return true;
        }
        return false;
    }

}

class InitializedBooleanArray_v2 {
    private boolean[] values;
    private int[] epoch;
    private int current_epoch;

    public void init(int size) {
        current_epoch = 0;
        values = new boolean[size];
        epoch = new int[size];
    }

    public void clear() {
        current_epoch++;
    }

    public boolean get(int at) {
        assert (at < values.length);
        if (epoch[at] != current_epoch) {
            epoch[at] = current_epoch;
            values[at] = false;
        }
        return values[at];
    }

    public void set(int at, boolean value) {
        assert (at < values.length);
        epoch[at] = current_epoch;
        values[at] = value;
    }
}

/**
 * Set of ints in {@code [0, capacity)} with O(1) add, remove, contains and access by index.
 *
 * <p>Members are stored densely in {@code indexToValue[0, size)}, and {@code valueToIndex} maps each
 * member back to its position ({@code EMPTY} for non-members). Removal moves the last element into
 * the freed slot, so element order changes after removals.
 */
class IntSet {
    private static final int EMPTY = -1;
    private int size;
    private int[] indexToValue;
    private int[] valueToIndex;

    /** Creates an empty set that can hold values in [0, capacity). */
    public IntSet(int capacity) {
        this.size = 0;
        indexToValue = new int[capacity];
        valueToIndex = new int[capacity];
        Arrays.fill(valueToIndex, EMPTY);
    }

    /** Adds {@code value}; returns false if it was already present. */
    public boolean add(int value) {
        if (valueToIndex[value] != EMPTY) {
            return false;
        }
        indexToValue[size] = value;
        valueToIndex[value] = size;
        size++;
        return true;
    }

    /** Removes {@code value}; returns false if it was not present. */
    public boolean remove(int value) {
        int index = indexOf(value);
        if (index == EMPTY) {
            return false;
        }
        removeByIndex(index);
        return true;
    }

    /** Removes the element at {@code index} by moving the last element into its slot. */
    private boolean removeByIndex(int index) {
        if (size == 0) {
            return false;
        }
        assert index < size;
        size--;
        int value = indexToValue[index];
        int value2 = indexToValue[size];
        indexToValue[index] = value2;
        valueToIndex[value2] = index;

        indexToValue[size] = value;
        valueToIndex[value] = EMPTY;

        return true;
    }

    /** Exchanges the elements at two positions. */
    public void swap(int index, int index2) {
        assert index < size;
        assert index2 < size;

        int swap = indexToValue[index];
        indexToValue[index] = indexToValue[index2];
        indexToValue[index2] = swap;

        valueToIndex[indexToValue[index]] = index;
        valueToIndex[indexToValue[index2]] = index2;

    }

    /** Exchanges the positions of two members; both values must already be in the set. */
    public void swapValue(int value, int value2) {
        // Values are not positions, so the precondition is membership, not "value < size".
        assert contains(value) : value;
        assert contains(value2) : value2;

        int swap = valueToIndex[value];
        valueToIndex[value] = valueToIndex[value2];
        valueToIndex[value2] = swap;

        indexToValue[valueToIndex[value]] = value;
        indexToValue[valueToIndex[value2]] = value2;

    }

    /** Returns the element at {@code index} (0 <= index < size()). */
    public int get(int index) {
        assert index < size;
        return indexToValue[index];
    }

    /** Returns the position of {@code value}, or -1 if it is not a member. */
    public int indexOf(int value) {
        return valueToIndex[value];
    }

    public int size() {
        return size;
    }

    public boolean isEmpty() {
        return size() <= 0;
    }

    /** Removes every element. */
    public void clear() {
        for (; size() > 0;) {
            removeByIndex(0);
        }
    }

    public boolean contains(int value) {
        return indexOf(value) != EMPTY;
    }

    /** Formats the members in their current order, e.g. "[5, 2]". */
    @Override
    public String toString() {
        return Arrays.toString(Arrays.copyOf(indexToValue, size()));
    }
}




Approach 焼きなまし法を使いました。

  • 温度 : 30 -> 0 (線形)
  • 近傍1 : 端点を削除、端点に点を追加(経路上の点を追加する時は、閉路にならないように辺をカットする)
  • 近傍2 : 端点以外を追加、削除
    • 0.1% 未満の改善
  • 近傍3 : 端点以外を移動
    • 0.1% ほど改善
  • 多スタート(スタート時の温度を線形に下げる) : 10回
    • 0.1% ほど改善

感想
  • ビームサーチでほとんどすべての点を使った経路を作ろうとしたけど難しかった。
  • 端点付近しか変化しないので、経路の一部を削除して再構成する近傍も試したけどあまりよくなかった。
  • 端点付近しか変化しないので、複数の経路になるのを許容して、端点以外も変化させてみたら、1つの経路になりにくかった。


Source Code 

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;

public class Main {

    private int N;
    private int M;
    private int[] x;
    private int[] y;
    private int[] a;
    private int[] b;

    final static XorShift rng = new XorShift(System.nanoTime());
    final static Watch watch = new Watch();
    private SAState sa = new SAState();

    // adjacentVertexes[v]: neighbours of v, one entry per input edge. Typed so that get() yields
    // an Integer; with a raw ArrayList the ".get(i).intValue()" calls below do not compile.
    private ArrayList<Integer>[] adjacentVertexes;
    // Per-vertex round-robin cursor into adjacentVertexes, advanced by changeEndPoint().
    private int[] adjacentVertexesIndex;

    // Vertices that are not on the current path.
    private IntSet unusedVertexes;

    // Current solution: total Euclidean length of the path (maximised) and the path itself.
    private double score;
    private Path path;

    // Best solution seen so far.
    private double bestScore;
    private Path bestPath;

    // distances[i][j]: Euclidean distance between vertices i and j.
    private double[][] distances;

    // Second round-robin cursor, shared by addVertex() and moveVertex().
    private int[] adjacentVertexesIndex2;

    public static void main(String[] args) {
        new Main().run();
    }

    private void run() {
        read();
        init();
        solve();
        write();
    }

    /**
     * Reads N vertices (x y) and M undirected edges (a b) from standard input.
     * Input errors are only printed; the fields then stay partially initialised.
     */
    private void read() {
        try (Scanner sc = new Scanner(System.in)) {
            this.N = sc.nextInt();
            this.M = sc.nextInt();
            this.x = new int[N];
            this.y = new int[N];
            for (int i = 0; i < N; i++) {
                x[i] = sc.nextInt();
                y[i] = sc.nextInt();
            }
            this.a = new int[M];
            this.b = new int[M];
            for (int i = 0; i < M; i++) {
                a[i] = sc.nextInt();
                b[i] = sc.nextInt();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the adjacency lists, marks every vertex unused, allocates the current/best paths and
     * precomputes the pairwise distance table.
     */
    private void init() {
        @SuppressWarnings("unchecked") // generic array creation; every slot gets an ArrayList<Integer>
        ArrayList<Integer>[] lists = new ArrayList[N];
        adjacentVertexes = lists;
        for (int i = 0; i < N; i++) {
            adjacentVertexes[i] = new ArrayList<>();
        }
        for (int i = 0; i < M; i++) {
            adjacentVertexes[a[i]].add(b[i]);
            adjacentVertexes[b[i]].add(a[i]);
        }

        unusedVertexes = new IntSet(N);
        for (int i = 0; i < N; i++) {
            unusedVertexes.add(i);
        }

        path = new Path(N);
        bestPath = new Path(N);

        distances = new double[N][N];
        for (int i = 0; i < N; i++) {
            for (int j = i + 1; j < N; j++) {
                int dx = x[i] - x[j];
                int dy = y[i] - y[j];
                distances[i][j] = Math.sqrt(dx * dx + dy * dy);
                distances[j][i] = distances[i][j];
            }
        }

        adjacentVertexesIndex = new int[N];
        adjacentVertexesIndex2 = new int[N];

        Utils.debug("N", N);
    }

    public void solve() {
        greedy();
        multiSA();
    }

    /**
     * Initial solution: the single vertex closest to (500, 500). It is always recorded as the best
     * solution so far.
     */
    private void greedy() {
        int bestVertex = -1;
        double best = 1e9;
        for (int vertex = 0; vertex < N; vertex++) {
            int dx = x[vertex] - 500;
            int dy = y[vertex] - 500;
            double distance = Math.sqrt(dx * dx + dy * dy);
            if (distance < best) {
                best = distance;
                bestVertex = vertex;
            }
        }
        if (bestVertex < 0) {
            // N == 0: there is no vertex to start from; leave the path empty.
            return;
        }

        unusedVertexes.remove(bestVertex);
        path.add(bestVertex);

        score = calculateScore();
        // saveBest() only stores strictly better scores and bestScore starts at 0, which equals the
        // score of a one-vertex path. Without this reset bestPath would stay empty, and loadBest()
        // could later restore an empty path.
        bestScore = Double.NEGATIVE_INFINITY;
        saveBest();
    }

    /**
     * Splits the time up to 9.5 s into numRestart equal annealing runs. Each run starts from the
     * best path so far (SA() reloads it at the end of every run). Run k starts at temperature
     * 30 * (numRestart - k) / numRestart, i.e. 30, 27, ..., 3.
     */
    private void multiSA() {
        int numRestart = 10;

        double endTime = 9.5;
        double startTime = watch.getSecond();
        double remainTime = endTime - startTime;

        double startStartTemperature = 30;
        double endStartTemperature = 0;
        for (int restart = 0; restart < numRestart; restart++) {
            sa.startTime = startTime + remainTime * restart / numRestart;
            sa.endTime = startTime + remainTime * (restart + 1) / numRestart;
            double remainingFraction = (double) (numRestart - restart) / numRestart;
            sa.startTemperature = endStartTemperature + (startStartTemperature - endStartTemperature) * remainingFraction;
            SA();
        }
    }

    /**
     * One annealing run. Every 1024 iterations it refreshes the schedule and checks the time limit.
     * At the limit it restores the best solution and logs statistics.
     */
    private void SA() {
        sa.init();
        for (;; ++sa.numIterations) {
            if ((sa.numIterations & ((1 << 10) - 1)) == 0) {
                sa.update();

                if (sa.isTLE()) {
                    loadBest();
                    Utils.debug(sa.numIterations, String.format("%.2f%%", 100.0 * sa.validIterations / sa.numIterations), String.format("%.2f%%", 100.0 * sa.acceptIterations / sa.validIterations), String.format("%5.0f", score), String.format("%5.0f", bestScore), String.format("%.6f", 1.0 / sa.inverseTemperature), String.format("%.6f", 1.0 / sa.lastAcceptTemperature), "time", watch.getSecondString());
                    break;
                }
            }

            mutate();
        }
    }

    /**
     * Picks a neighbourhood: add, remove or move an interior vertex with probability 1/13 each, and
     * change an endpoint with probability 10/13.
     */
    private void mutate() {
        double random = 13 * rng.nextDouble();
        if (random < 1) {
            addVertex();
        } else if (random < 2) {
            removeVertex();
        } else if (random < 3) {
            moveVertex();
        } else if (random < 13) {
            changeEndPoint();
        }
    }

    /**
     * Endpoint neighbourhood. Picks the start or the end of the path with equal probability, then:
     * if the endpoint has at most one neighbour, tries to drop it. Otherwise it advances that
     * endpoint's round-robin cursor to a neighbour vertex2 and either extends the path with vertex2
     * (if unused), or, if vertex2 is already on the path, links the endpoint to vertex2. In that
     * case it cuts the path edge on the endpoint's side of vertex2 by reversing the segment in
     * between, so the result is still a simple path.
     */
    private void changeEndPoint() {
        boolean changeStartPoint = rng.nextDouble() < 0.5;
        if (changeStartPoint) {
            if (path.size() == 0) {
                return;
            }
            int vertex = path.get(0);

            if (adjacentVertexes[vertex].size() <= 1) {
                if (path.size() < 2) {
                    return;
                }

                double deltaScore = -distances[path.get(0)][path.get(1)];

                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    int v = path.remove(0);
                    unusedVertexes.add(v);

                    saveBest();
                }
                return;
            }

            adjacentVertexesIndex[vertex] = (adjacentVertexesIndex[vertex] + 1) % adjacentVertexes[vertex].size();
            int vertex2 = adjacentVertexes[vertex].get(adjacentVertexesIndex[vertex]).intValue();

            if (!unusedVertexes.contains(vertex2)) {
                // Add edge vertex-vertex2, drop edge path[index-1]-vertex2; path[index-1] becomes the new start.
                int index = path.indexOf(vertex2);
                double deltaScore = distances[vertex][vertex2] - distances[vertex2][path.get(index - 1)];
                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    path.reverse(0, index - 1);

                    saveBest();
                }
            } else {
                double deltaScore = distances[vertex][vertex2];
                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    unusedVertexes.remove(vertex2);
                    path.add(0, vertex2);

                    saveBest();
                }
            }
        } else {
            if (path.size() == 0) {
                return;
            }
            int vertex = path.get(path.size() - 1);

            if (adjacentVertexes[vertex].size() <= 1) {
                if (path.size() < 2) {
                    return;
                }

                double deltaScore = -distances[path.get(path.size() - 2)][path.get(path.size() - 1)];
                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    int v = path.remove(path.size() - 1);
                    unusedVertexes.add(v);

                    saveBest();
                }
                return;
            }

            adjacentVertexesIndex[vertex] = (adjacentVertexesIndex[vertex] + 1) % adjacentVertexes[vertex].size();
            int vertex2 = adjacentVertexes[vertex].get(adjacentVertexesIndex[vertex]).intValue();

            if (!unusedVertexes.contains(vertex2)) {
                // Add edge vertex-vertex2, drop edge vertex2-path[index+1]; path[index+1] becomes the new end.
                int index = path.indexOf(vertex2);
                double deltaScore = distances[vertex][vertex2] - distances[vertex2][path.get(index + 1)];
                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    path.reverse(index + 1, path.size() - 1);

                    saveBest();
                }
            } else {
                double deltaScore = distances[vertex][vertex2];
                if (sa.accept(deltaScore)) {
                    score += deltaScore;

                    unusedVertexes.remove(vertex2);
                    path.add(vertex2);

                    saveBest();
                }
            }
        }
    }

    /**
     * Removes a random interior vertex, provided its two path neighbours are adjacent to each other.
     * The adjacency test runs only after sa.accept(), so moves that turn out invalid still count in
     * the SA statistics.
     */
    private void removeVertex() {
        if (path.size() <= 2) {
            return;
        }
        int index = 1 + (int) ((path.size() - 2) * rng.nextDouble());
        int vertexPrev = path.get(index - 1);
        int vertex = path.get(index);
        int vertexNext = path.get(index + 1);
        double deltaScore = distances[vertexPrev][vertexNext] - distances[vertexPrev][vertex] - distances[vertex][vertexNext];
        if (sa.accept(deltaScore)) {
            if (!adjacentVertexes[vertexPrev].contains(vertexNext)) {
                return;
            }

            score += deltaScore;

            int v = path.remove(index);
            unusedVertexes.add(v);

            saveBest();
        }
    }

    /**
     * Inserts a random unused vertex directly after one of its on-path neighbours (chosen
     * round-robin). The insertion is applied only if the vertex is also adjacent to that
     * neighbour's successor.
     */
    private void addVertex() {
        if (unusedVertexes.size() == 0) {
            return;
        }
        int vertex = unusedVertexes.get((int) (unusedVertexes.size() * rng.nextDouble()));
        ArrayList<Integer> vs = adjacentVertexes[vertex];
        if (vs.size() == 0) {
            return;
        }
        adjacentVertexesIndex2[vertex] = (adjacentVertexesIndex2[vertex] + 1) % vs.size();
        int vertexPrev = vs.get(adjacentVertexesIndex2[vertex]).intValue();
        if (unusedVertexes.contains(vertexPrev)) {
            return;
        }
        int indexPrev = path.indexOf(vertexPrev);
        if (indexPrev == path.size() - 1) {
            return;
        }
        int vertexNext = path.get(indexPrev + 1);
        double deltaScore = distances[vertexPrev][vertex] + distances[vertex][vertexNext] - distances[vertexPrev][vertexNext];
        if (sa.accept(deltaScore)) {
            if (!vs.contains(vertexNext)) {
                return;
            }

            score += deltaScore;

            path.add(indexPrev + 1, vertex);
            unusedVertexes.remove(vertex);

            saveBest();
        }
    }

    /**
     * Moves a random interior vertex so that it sits directly after one of its other on-path
     * neighbours (chosen round-robin). The move is applied only if the vertex's old neighbours are
     * adjacent to each other and the vertex is adjacent to its new successor.
     */
    private void moveVertex() {
        if (path.size() <= 2) {
            return;
        }
        int index = 1 + (int) ((path.size() - 2) * rng.nextDouble());
        int vertexPrev = path.get(index - 1);
        int vertex = path.get(index);
        int vertexNext = path.get(index + 1);

        ArrayList<Integer> vs = adjacentVertexes[vertex];
        adjacentVertexesIndex2[vertex] = (adjacentVertexesIndex2[vertex] + 1) % vs.size();
        int vertexPrev2 = vs.get(adjacentVertexesIndex2[vertex]).intValue();
        if (unusedVertexes.contains(vertexPrev2) || vertexPrev2 == vertexPrev) {
            return;
        }
        int indexPrev2 = path.indexOf(vertexPrev2);
        if (indexPrev2 == path.size() - 1) {
            return;
        }
        int vertexNext2 = path.get(indexPrev2 + 1);
        double deltaScore = distances[vertexPrev][vertexNext] - distances[vertexPrev][vertex] - distances[vertex][vertexNext];
        deltaScore += distances[vertexPrev2][vertex] + distances[vertex][vertexNext2] - distances[vertexPrev2][vertexNext2];
        if (sa.accept(deltaScore)) {
            if (!adjacentVertexes[vertexPrev].contains(vertexNext)) {
                return;
            }
            if (!vs.contains(vertexNext2)) {
                return;
            }

            score += deltaScore;

            path.remove(index);
            // After the removal, vertexPrev2 shifts one slot left if it was after the removed position.
            if (index < indexPrev2 + 1) {
                path.add(indexPrev2, vertex);
            } else {
                path.add(indexPrev2 + 1, vertex);
            }

            saveBest();
        }
    }

    /** Prints the path length followed by one vertex per line. */
    private void write() {
        StringBuilder sb = new StringBuilder();
        sb.append(path.size()).append("\n");
        for (int i = 0; i < path.size(); i++) {
            sb.append(path.get(i)).append("\n");
        }
        System.out.print(sb.toString());
        System.out.flush();
    }

    /** Returns the total Euclidean length of the current path. */
    private double calculateScore() {
        double score = 0;
        for (int i = 1; i < path.size(); i++) {
            int vertex = path.get(i - 1);
            int nextVertex = path.get(i);
            score += distances[vertex][nextVertex];
        }
        return score;
    }

    /** Copies the current path into bestPath when its score is strictly better. */
    private void saveBest() {
        if (score > bestScore) {
            bestScore = score;
            bestPath.clear();
            for (int i = 0; i < path.size(); i++) {
                bestPath.add(path.get(i));
            }
        }
    }

    /** Restores the best path and rebuilds the unused-vertex set to match it. */
    private void loadBest() {
        score = bestScore;
        path.clear();
        unusedVertexes.clear();
        for (int i = 0; i < N; i++) {
            unusedVertexes.add(i);
        }
        for (int i = 0; i < bestPath.size(); i++) {
            int v = bestPath.get(i);
            path.add(v);
            unusedVertexes.remove(v);
        }
    }

}

final class Utils {
    private Utils() {
    }

    public static final void debug(Object... o) {
        System.err.println(toString(o));
    }

    public static final String toString(Object... o) {
        return Arrays.deepToString(o);
    }

}

/** Wall-clock stopwatch based on System.nanoTime(). */
class Watch {
    /** System.nanoTime() value that elapsed time is measured from. */
    private long start;

    /** Creates a stopwatch that starts counting immediately. */
    public Watch() {
        init();
    }

    /** Returns the seconds elapsed since construction or the last {@link #init()}. */
    public double getSecond() {
        long elapsedNanos = System.nanoTime() - start;
        return elapsedNanos * 1e-9;
    }

    /** Restarts the stopwatch from the current instant. */
    public void init() {
        init(System.nanoTime());
    }

    private void init(long start) {
        this.start = start;
    }

    /** Returns the elapsed time formatted by {@link #toString(double)}. */
    public String getSecondString() {
        return toString(getSecond());
    }

    /**
     * Formats a duration: "ss.ssS" below one minute, "MmSs" below one hour, otherwise "HhMmSs"
     * (each field padded to width 2, seconds truncated).
     */
    public static final String toString(double second) {
        if (second < 60) {
            return String.format("%5.2fs", second);
        }
        int wholeSeconds = (int) (second % 60);
        int totalMinutes = (int) (second / 60);
        if (second < 60 * 60) {
            return String.format("%2dm%2ds", totalMinutes, wholeSeconds);
        }
        int hours = (int) (second / (60 * 60));
        return String.format("%2dh%2dm%2ds", hours, totalMinutes % (60), wholeSeconds);
    }

}

/**
 * Marsaglia xorshift128 pseudo-random generator. Only the x word is seeded from the constructor
 * argument; the other three keep the standard default constants.
 */
class XorShift {
    private int w = 88675123;
    private int x = 123456789;
    private int y = 362436069;
    private int z = 521288629;

    public XorShift(long l) {
        x = (int) l;
    }

    /** Advances the state and returns the next 32 random bits. */
    public int nextInt() {
        final int t = x ^ (x << 11);
        // Shift the state words down by one; w receives the freshly mixed value.
        x = y;
        y = z;
        z = w;
        w ^= (w >>> 19) ^ t ^ (t >>> 8);
        return w;
    }

    /** Combines two draws: the first supplies the high word, the second is sign-extended and XORed in. */
    public long nextLong() {
        long high = (long) nextInt() << 32;
        long low = nextInt();
        return high ^ low;
    }

    /** Returns a value in [0, 1): 31 random bits scaled by 2^-31. */
    public double nextDouble() {
        int nonNegative = nextInt() >>> 1;
        return nonNegative * 4.6566128730773926E-10;
    }

    /** Returns a value in [0, n) for positive n. */
    public int nextInt(int n) {
        double scaled = n * nextDouble();
        return (int) scaled;
    }

}

/**
 * Simulated-annealing schedule and acceptance test.
 * The temperature falls linearly from startTemperature to endTemperature between startTime and
 * endTime. "time" is wall-clock seconds from Main.watch when useTime is set, otherwise the
 * iteration counter. Note that inverseTemperature and lastAcceptTemperature both hold 1/T.
 */
class SAState {

    public static final boolean useTime = true;

    public double startTime = 0;
    public double endTime = 9.5;
    public double time = startTime;
    public double startTemperature = 30;
    public double endTemperature = 1e-9;
    public double inverseTemperature = 1.0 / startTemperature;
    // Every other write (init, acceptB, acceptS) stores an inverse temperature, and callers print
    // 1.0 / lastAcceptTemperature. Seed it with the inverse as well rather than with
    // startTemperature itself.
    public double lastAcceptTemperature = inverseTemperature;

    public double startRange = 700;
    public double endRange = 1;
    public double range = startRange;

    public int numIterations;
    public int validIterations;
    public int acceptIterations;

    // log[i] = ln((i + 0.5) / 32768): the logarithm of a uniform sample from one of 32768 buckets.
    private static final double[] log = new double[32768];
    static {
        for (int i = 0; i < log.length; i++) {
            log[i] = Math.log((i + 0.5) / log.length);
        }
    }

    /**
     * Resets the counters and starts the schedule now. Any startTime set by the caller is replaced
     * by the current time (or 0 when not using time).
     */
    public void init() {
        numIterations = 0;
        validIterations = 0;
        acceptIterations = 0;

        startTime = useTime ? Main.watch.getSecond() : numIterations;

        update();
        lastAcceptTemperature = inverseTemperature;
    }

    /** Refreshes time, temperature and range. */
    public void update() {
        updateTime();
        updateTemperature();
        updateRange();
    }

    /** Linear interpolation from startTemperature to endTemperature over [startTime, endTime]. */
    public void updateTemperature() {
        double time0to1 = elapsedPercentage(startTime, endTime, time);
        double startY = startTemperature;
        double endY = endTemperature;
        double temperature = interpolate(startY, endY, time0to1);
        inverseTemperature = 1.0 / temperature;
    }

    private double elapsedPercentage(double min, double max, double v) {
        return (v - min) / (max - min);
    }

    private double interpolate(double v0, double v1, double d0to1) {
        return v0 + (v1 - v0) * d0to1;
    }

    /** Shrinks range linearly from startRange (at startTime) to endRange (at endTime). */
    public void updateRange() {
        range = endRange + (startRange - endRange) * Math.pow((endTime - time) / (endTime - startTime), 1.0);
    }

    public void updateTime() {
        time = useTime ? Main.watch.getSecond() : numIterations;
    }

    /** Returns true once the schedule's end time has been reached. */
    public boolean isTLE() {
        return time >= endTime;
    }

    /** Acceptance test used by callers; delegates to the maximisation variant. */
    public boolean accept(double deltaScore) {
        return acceptB(deltaScore);
    }

    /**
     * Metropolis test for maximisation. A delta above -1e-9 is always accepted. A worse delta is
     * accepted with probability about exp(delta / T), using the precomputed log table, and is
     * rejected outright when delta / T < -10.
     */
    public boolean acceptB(double deltaScore) {
        validIterations++;

        if (deltaScore > -1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore < 0 : Utils.toString(deltaScore);
        assert 1.0 / inverseTemperature >= 0;

        double d = deltaScore * inverseTemperature;
        if (d < -10) {
            return false;
        }
        if (log[Main.rng.nextInt() & 32767] < d) {
            acceptIterations++;
            lastAcceptTemperature = inverseTemperature;
            return true;
        }
        return false;
    }

    /** Metropolis test for minimisation; the mirror image of {@link #acceptB(double)}. */
    public boolean acceptS(double deltaScore) {
        validIterations++;

        if (deltaScore < 1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore > 0;
        assert 1.0 / inverseTemperature >= 0;

        double d = -deltaScore * inverseTemperature;
        if (d < -10) {
            return false;
        }
        if (log[Main.rng.nextInt() & 32767] < d) {
            acceptIterations++;
            lastAcceptTemperature = inverseTemperature;
            return true;
        }
        return false;
    }

}

/**
 * Ordered vertex sequence with a reverse index (vertex -> position), so both get(position) and
 * indexOf(vertex) are O(1). Vertices must lie in [0, capacity); indexOf returns -1 for vertices
 * that are not on the path.
 */
class Path {
    private int[] solution;
    private int[] indexOf;
    private int size;

    public Path(int capacity) {
        solution = new int[capacity];
        indexOf = new int[capacity];
        Arrays.fill(indexOf, -1);
        size = 0;
    }

    /** Empties the path. */
    public void clear() {
        size = 0;
        Arrays.fill(indexOf, -1);
    }

    /** Reverses the positions i..j (inclusive) in place. */
    public void reverse(int i, int j) {
        int lo = i;
        int hi = j;
        while (lo < hi) {
            swap(solution, lo, hi);
            indexOf[solution[lo]] = lo;
            indexOf[solution[hi]] = hi;
            lo++;
            hi--;
        }
    }

    private void swap(int[] a, int i, int j) {
        int held = a[j];
        a[j] = a[i];
        a[i] = held;
    }

    /** Returns the position of vertex, or -1 if it is not on the path. */
    public int indexOf(int vertex) {
        return indexOf[vertex];
    }

    /** Removes and returns the vertex at index; later vertices shift one position left. */
    public int remove(int index) {
        int removed = solution[index];
        indexOf[removed] = -1;
        size--;
        int k = index;
        while (k < size) {
            int moved = solution[k + 1];
            solution[k] = moved;
            indexOf[moved] = k;
            k++;
        }
        return removed;
    }

    public int get(int index) {
        return solution[index];
    }

    public int size() {
        return size;
    }

    /** Appends vertex at the end. */
    public void add(int vertex) {
        solution[size] = vertex;
        indexOf[vertex] = size++;
    }

    /** Inserts vertex at index; vertices from index onward shift one position right. */
    public void add(int index, int vertex) {
        int k = size;
        while (k > index) {
            int moved = solution[k - 1];
            solution[k] = moved;
            indexOf[moved] = k;
            k--;
        }
        solution[index] = vertex;
        indexOf[vertex] = index;
        size++;
    }
}

/**
 * Set of ints in [0, capacity) with O(1) add, remove and membership test.
 * Members are stored densely in indexToValue[0, size). valueToIndex maps each member to its slot
 * and holds EMPTY for non-members. Slot order is arbitrary and changes on removal, because the
 * last member is moved into the hole.
 */
class IntSet {
    private static final int EMPTY = -1;
    private int size;
    private int[] indexToValue;
    private int[] valueToIndex;

    public IntSet(int capacity) {
        this.size = 0;
        indexToValue = new int[capacity];
        valueToIndex = new int[capacity];
        Arrays.fill(valueToIndex, EMPTY);
    }

    /** Adds value; returns false if it was already present. */
    public boolean add(int value) {
        if (valueToIndex[value] != EMPTY) {
            return false;
        }
        indexToValue[size] = value;
        valueToIndex[value] = size;
        size++;
        return true;
    }

    /** Removes value; returns false if it was not present. */
    public boolean remove(int value) {
        int index = indexOf(value);
        if (index == EMPTY) {
            return false;
        }
        removeByIndex(index);
        return true;
    }

    /**
     * Removes the member at slot index by moving the last member into that slot.
     * The removed value is parked in the now-unused slot {@code size}.
     */
    private boolean removeByIndex(int index) {
        if (size == 0) {
            return false;
        }
        assert index < size;
        size--;
        int value = indexToValue[index];
        int value2 = indexToValue[size];
        indexToValue[index] = value2;
        valueToIndex[value2] = index;

        indexToValue[size] = value;
        valueToIndex[value] = EMPTY;

        return true;
    }

    /** Exchanges the members stored at two slots (both below size()). */
    public void swap(int index, int index2) {
        assert index < size;
        assert index2 < size;

        int swap = indexToValue[index];
        indexToValue[index] = indexToValue[index2];
        indexToValue[index2] = swap;

        valueToIndex[indexToValue[index]] = index;
        valueToIndex[indexToValue[index2]] = index2;

    }

    /** Exchanges the slots occupied by two members of the set. */
    public void swapValue(int value, int value2) {
        // The precondition is membership. The former "value < size" checks compared a stored value
        // with the element count: they rejected valid members and let absent values through
        // (whose slot is EMPTY, an invalid array index below).
        assert indexOf(value) != EMPTY : value;
        assert indexOf(value2) != EMPTY : value2;

        int swap = valueToIndex[value];
        valueToIndex[value] = valueToIndex[value2];
        valueToIndex[value2] = swap;

        indexToValue[valueToIndex[value]] = value;
        indexToValue[valueToIndex[value2]] = value2;

    }

    /** Returns the member at slot index (0 <= index < size()). */
    public int get(int index) {
        assert index < size;
        return indexToValue[index];
    }

    /** Returns the slot of value, or EMPTY (-1) when it is not a member. */
    public int indexOf(int value) {
        return valueToIndex[value];
    }

    public int size() {
        return size;
    }

    public boolean isEmpty() {
        return size() <= 0;
    }

    public void clear() {
        for (; size() > 0;) {
            removeByIndex(0);
        }
    }

    public boolean contains(int value) {
        return indexOf(value) != EMPTY;
    }

    @Override
    public String toString() {
        return Arrays.toString(Arrays.copyOf(indexToValue, size()));
    }
}




Approach

焼きなまし法を使いました。

  • 初期解 : 入力された順に並べて、ランダムに回転する。
  • 近傍 : ランダムに2点swap して、回転は4つ試して一番良いものを選ぶ。
    • 3点swap を試したら悪くなった。
  • 温度 : 線形に 0.25 -> 0.1
  • 4辺が一致していないセルを保持する。そこから交換するセルを選ぶ。
    • 1% ほど改善した。
  • ランダムに2点swap -> 4辺が一致していないセルの隣のパネルを色が一致するパネルと swap する。
    • 3% ほど改善した。
    • 回転は片方だけでよいので、高速化もできた。
  • 各セルのスコアを保持して高速化する。
    • 0.5% ほど改善した。
  • 4辺が一致していないセルの隣のパネルを色が一致するパネルと swap する。隣のパネルは色が一致していないものに限定する。
    • N が小さいケースで悪くなったが、平均 1% ほど改善した。
  • 4辺が一致していないセルの隣のパネルを色が一致するパネルと swap する。隣のパネルは色が一致していないものに限定する。(N <=10 のときは限定しない)
    • 0.3% ほど改善した。(N <=10 のとき)



Source Code
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    private static final int[] dr = { -1, 0, 1, 0, };
    private static final int[] dc = { 0, 1, 0, -1, };

    private int N;
    private int C;

    final static XorShift rng = new XorShift(System.nanoTime());
    final static Watch watch = new Watch();
    private SAState sa = new SAState();

    private double score;
    private int[][] rxcToIndex;
    private int[][] rxcToNumRotates;

    private double bestScore;
    private int[][] bestRxcToIndex;
    private int[][] bestRxcToNumRotates;

    private int[][] panelxDirectionToColor;

    private IntSet notPerfectSet;
    private int[] indexToRxc;

    private ArrayList[] colorToPanels;

    private int[][] rxcToScore;

    /** Entry point: solves one instance read from standard input. */
    public static void main(String[] args) {
        Main solver = new Main();
        solver.run();
    }

    /** Pipeline: parse the input, optimise, then print the answer. */
    private void run() {
        read();
        solve();
        write();
    }

    /**
     * Reads N and C, then N*N panels given as their up, down, left, right edge colours, and
     * allocates the per-cell working arrays. Colours are stored clockwise from up
     * (0 = up, 1 = right, 2 = down, 3 = left), matching the dr/dc tables. Every (panel, side)
     * pair is also indexed by the colour on that side in colorToPanels.
     */
    private void read() {
        try (Scanner sc = new Scanner(System.in)) {

            N = sc.nextInt();
            C = sc.nextInt();
            int panelCount = N * N;
            panelxDirectionToColor = new int[panelCount][4];
            for (int panel = 0; panel < panelCount; panel++) {
                int up = sc.nextInt();
                int down = sc.nextInt();
                int left = sc.nextInt();
                int right = sc.nextInt();
                int[] side = panelxDirectionToColor[panel];
                side[0] = up;
                side[1] = right;
                side[2] = down;
                side[3] = left;
            }

            rxcToIndex = new int[N][N];
            rxcToNumRotates = new int[N][N];
            bestRxcToIndex = new int[N][N];
            bestRxcToNumRotates = new int[N][N];

            notPerfectSet = new IntSet(panelCount);
            indexToRxc = new int[panelCount];

            colorToPanels = new ArrayList[C];
            for (int color = 0; color < C; color++) {
                colorToPanels[color] = new ArrayList<>();
            }
            for (int panel = 0; panel < panelCount; panel++) {
                int[] side = panelxDirectionToColor[panel];
                for (int d = 0; d < 4; d++) {
                    colorToPanels[side[d]].add(new NearPanel(panel, d, 0));
                }
            }

            rxcToScore = new int[N][N];

        } catch (Exception e) {
            e.printStackTrace();
        }
        Utils.debug("N", N, "C", C);
    }

    /** Builds an initial arrangement, then improves it by simulated annealing. */
    public void solve() {
        greedy();
        SA();
    }

    /**
     * Initial arrangement: panel i sits in cell i (row-major) with a uniformly random rotation in
     * [0, 4). The arrangement is scored and recorded as the best so far.
     */
    private void greedy() {
        int cells = N * N;
        for (int cell = 0; cell < cells; cell++) {
            int r = cell / N;
            int c = cell % N;
            rxcToIndex[r][c] = cell;
            rxcToNumRotates[r][c] = (int) (4 * rng.nextDouble());
            indexToRxc[cell] = cell;
        }

        score = calculateScore();
        saveBest();
        Utils.debug("greedy", "score", score, "time", watch.getSecondString());
    }

    /**
     * Runs simulated annealing until the SAState time limit. Every 1024 iterations it refreshes
     * the schedule, logs statistics when a new whole second has elapsed, and at the limit
     * restores the best solution found.
     */
    private void SA() {
        double nextReportSecond = (int) watch.getSecond();
        sa.init();
        while (true) {
            boolean checkpoint = (sa.numIterations & ((1 << 10) - 1)) == 0;
            if (checkpoint) {
                sa.update();

                if (sa.isTLE()) {
                    loadBest();
                    Utils.debug(sa.numIterations, String.format("%.2f%%", 100.0 * sa.validIterations / sa.numIterations), String.format("%.2f%%", 100.0 * sa.acceptIterations / sa.validIterations), String.format("%5.0f", score), String.format("%5.0f", bestScore), String.format("%.6f", 1.0 / sa.inverseTemperature), String.format("%.6f", 1.0 / sa.lastAcceptTemperature), "time", watch.getSecondString());
                    break;
                }

                if (watch.getSecond() > nextReportSecond) {
                    nextReportSecond++;
                    Utils.debug(sa.numIterations, String.format("%.2f%%", 100.0 * sa.validIterations / sa.numIterations), String.format("%.2f%%", 100.0 * sa.acceptIterations / sa.validIterations), String.format("%5.0f", score), String.format("%5.0f", bestScore), String.format("%.6f", 1.0 / sa.inverseTemperature), String.format("%.6f", 1.0 / sa.lastAcceptTemperature), "time", watch.getSecondString());
                }
            }

            mutate();
            ++sa.numIterations;
        }
        Utils.debug("SA", "score", score, "time", watch.getSecondString());
    }

    /**
     * Chooses the neighbourhood. Boards with N > 10 only use neighbours whose shared edge is
     * mismatched (swap2selectMismatch); smaller boards use the unrestricted swap2.
     */
    private void mutate() {
        if (N > 10) {
            swap2selectMismatch();
            return;
        }
        swap2();
    }

    /**
     * Neighbourhood used when N <= 10. The move proceeds as follows:
     * <ol>
     * <li>Take a random cell (r0, c0) from notPerfectSet and a random neighbour (r, c); whether
     * their shared edge already matches is not checked.</li>
     * <li>Pick a random (panel, side) pair carrying the colour (r0, c0) shows towards (r, c).</li>
     * <li>Swap that panel into (r, c) and rotate it so that side (presumably) faces (r0, c0); the
     * panel displaced from (r, c) goes to the vacated cell with its best rotation.</li>
     * <li>Keep or undo the swap according to the SA criterion.</li>
     * </ol>
     */
    private void swap2() {
        // notPerfectSet holds the cells whose four edges do not all match. Once it is empty there is
        // nothing to repair, and get() on an empty IntSet is outside its contract.
        if (notPerfectSet.isEmpty()) {
            return;
        }
        int v = notPerfectSet.get((int) (notPerfectSet.size() * rng.nextDouble()));
        int r0 = v / N;
        int c0 = v % N;

        int direction0 = (int) (4 * rng.nextDouble());
        int r = r0 + dr[direction0];
        int c = c0 + dc[direction0];
        if (r < 0 || r >= N || c < 0 || c >= N) {
            return;
        }

        // Colour that (r0, c0) currently shows towards (r, c).
        int panel0 = rxcToIndex[r0][c0];
        int rotate0 = rxcToNumRotates[r0][c0];
        int color0 = panelxDirectionToColor[panel0][(direction0 - rotate0 + 4) % 4];

        // Typed local: with a raw list, get() returns Object and the assignment below does not compile.
        ArrayList<NearPanel> panels = colorToPanels[color0];

        NearPanel p = panels.get((int) (panels.size() * rng.nextDouble()));

        int v2 = indexToRxc[p.panel];
        int r2 = v2 / N;
        int c2 = v2 % N;
        assert rxcToIndex[r2][c2] == p.panel;
        if (r == r2 && c == c2) {
            return;
        }
        if (r0 == r2 && c0 == c2) {
            return;
        }

        int currentRotate = rxcToNumRotates[r][c];
        int currentRotate2 = rxcToNumRotates[r2][c2];

        // -1 is treated as "(r, c) and (r2, c2) are not adjacent".
        int direction12 = getDirection(r, c, r2, c2);

        // Score of the two touched cells before the move. When they are adjacent, their shared edge
        // appears in both cached per-cell scores, so a matching shared edge is subtracted once.
        int before = rxcToScore[r][c] + rxcToScore[r2][c2];
        if (direction12 != -1) {
            int panel = rxcToIndex[r][c];
            int rotate = rxcToNumRotates[r][c];
            int npanel = rxcToIndex[r2][c2];
            int nrotate = rxcToNumRotates[r2][c2];
            if (panelxDirectionToColor[panel][(direction12 - rotate + 4) % 4] == panelxDirectionToColor[npanel][((direction12 + 2) - nrotate + 4) % 4]) {
                before--;
            }
        }

        swap(r, c, r2, c2);
        rxcToNumRotates[r][c] = ((direction0 + 2) - p.rotate + 4) % 4;
        calculateBestRotate(r2, c2);

        int scoreRC = calculateScore(r, c);
        int scoreRC2 = calculateScore(r2, c2);
        int after = scoreRC + scoreRC2;
        if (direction12 != -1) {
            int panel = rxcToIndex[r][c];
            int rotate = rxcToNumRotates[r][c];
            int npanel = rxcToIndex[r2][c2];
            int nrotate = rxcToNumRotates[r2][c2];
            if (panelxDirectionToColor[panel][(direction12 - rotate + 4) % 4] == panelxDirectionToColor[npanel][((direction12 + 2) - nrotate + 4) % 4]) {
                after--;
            }
        }

        int deltaScore = after - before;

        if (sa.accept(deltaScore)) {
            score += deltaScore;

            update(r, c);
            update(r2, c2);

            updateScore(r, c, scoreRC);
            updateScore(r2, c2, scoreRC2);

            saveBest();
        } else {
            // Undo: restore the original rotations, then swap the panels back.
            rxcToNumRotates[r][c] = currentRotate;
            rxcToNumRotates[r2][c2] = currentRotate2;
            swap(r, c, r2, c2);
        }
    }

    /**
     * Neighbourhood used when N > 10. It works like swap2, except that the neighbour (r, c) of the
     * chosen cell (r0, c0) must be inside the board and must not already match (r0, c0) on their
     * shared edge. Directions are tried in uniformly random order without repetition; if none
     * qualifies, the move is skipped.
     */
    private void swap2selectMismatch() {
        // notPerfectSet holds the cells whose four edges do not all match. Once it is empty there is
        // nothing to repair, and get() on an empty IntSet is outside its contract.
        if (notPerfectSet.isEmpty()) {
            return;
        }
        int v = notPerfectSet.get((int) (notPerfectSet.size() * rng.nextDouble()));
        int r0 = v / N;
        int c0 = v % N;

        // The previous nested selection skipped already-tried directions with "if (d >= tried) d++"
        // checks applied in pick order instead of ascending order. It could therefore re-pick a
        // rejected direction and never try a valid one (e.g. after 2 then 0, the third pick was
        // 1 or 2, never 3). This loop draws the same random numbers but maps them correctly.
        int direction0 = -1;
        int r = -1;
        int c = -1;
        int triedMask = 0;
        for (int remaining = 4; remaining > 0; remaining--) {
            int d = pickUntriedDirection(triedMask, remaining);
            triedMask |= 1 << d;
            int nr = r0 + dr[d];
            int nc = c0 + dc[d];
            if (nr >= 0 && nr < N && nc >= 0 && nc < N && !isSameColor(r0, c0, d, nr, nc)) {
                direction0 = d;
                r = nr;
                c = nc;
                break;
            }
        }
        if (direction0 < 0) {
            return;
        }

        // Colour that (r0, c0) currently shows towards (r, c).
        int panel0 = rxcToIndex[r0][c0];
        int rotate0 = rxcToNumRotates[r0][c0];
        int color0 = panelxDirectionToColor[panel0][(direction0 - rotate0 + 4) % 4];

        // Typed local: with a raw list, get() returns Object and the assignment below does not compile.
        ArrayList<NearPanel> panels = colorToPanels[color0];

        NearPanel p = panels.get((int) (panels.size() * rng.nextDouble()));

        int v2 = indexToRxc[p.panel];
        int r2 = v2 / N;
        int c2 = v2 % N;
        assert rxcToIndex[r2][c2] == p.panel;
        if (r == r2 && c == c2) {
            return;
        }
        if (r0 == r2 && c0 == c2) {
            return;
        }

        int currentRotate = rxcToNumRotates[r][c];
        int currentRotate2 = rxcToNumRotates[r2][c2];

        // -1 is treated as "(r, c) and (r2, c2) are not adjacent".
        int direction12 = getDirection(r, c, r2, c2);

        // Score of the two touched cells before the move. When they are adjacent, their shared edge
        // appears in both cached per-cell scores, so a matching shared edge is subtracted once.
        int before = rxcToScore[r][c] + rxcToScore[r2][c2];
        if (direction12 != -1) {
            int panel = rxcToIndex[r][c];
            int rotate = rxcToNumRotates[r][c];
            int npanel = rxcToIndex[r2][c2];
            int nrotate = rxcToNumRotates[r2][c2];
            if (panelxDirectionToColor[panel][(direction12 - rotate + 4) % 4] == panelxDirectionToColor[npanel][((direction12 + 2) - nrotate + 4) % 4]) {
                before--;
            }
        }

        swap(r, c, r2, c2);
        rxcToNumRotates[r][c] = ((direction0 + 2) - p.rotate + 4) % 4;
        calculateBestRotate(r2, c2);

        int scoreRC = calculateScore(r, c);
        int scoreRC2 = calculateScore(r2, c2);
        int after = scoreRC + scoreRC2;
        if (direction12 != -1) {
            int panel = rxcToIndex[r][c];
            int rotate = rxcToNumRotates[r][c];
            int npanel = rxcToIndex[r2][c2];
            int nrotate = rxcToNumRotates[r2][c2];
            if (panelxDirectionToColor[panel][(direction12 - rotate + 4) % 4] == panelxDirectionToColor[npanel][((direction12 + 2) - nrotate + 4) % 4]) {
                after--;
            }
        }

        int deltaScore = after - before;

        if (sa.accept(deltaScore)) {
            score += deltaScore;

            update(r, c);
            update(r2, c2);

            updateScore(r, c, scoreRC);
            updateScore(r2, c2, scoreRC2);

            saveBest();
        } else {
            // Undo: restore the original rotations, then swap the panels back.
            rxcToNumRotates[r][c] = currentRotate;
            rxcToNumRotates[r2][c2] = currentRotate2;
            swap(r, c, r2, c2);
        }
    }

    /**
     * Returns a uniformly random direction in [0, 4) whose bit is clear in triedMask.
     *
     * @param triedMask bit d is set when direction d has already been tried
     * @param remaining number of clear bits among the low four bits of triedMask (1..4)
     */
    private int pickUntriedDirection(int triedMask, int remaining) {
        int skip = (int) (remaining * rng.nextDouble());
        for (int d = 0; d < 4; d++) {
            if ((triedMask & (1 << d)) == 0) {
                if (skip == 0) {
                    return d;
                }
                skip--;
            }
        }
        throw new IllegalStateException("no untried direction left: mask=" + triedMask);
    }

    /**
     * Compares the colors of two touching edges.
     *
     * The edge of the panel at (r0, c0) looked up with {@code direction0} is compared with the edge
     * of the panel at (r, c) looked up with {@code direction0 + 2}. Both lookups subtract the
     * cell's current rotation (mod 4) before indexing {@code panelxDirectionToColor}.
     *
     * @param r0 row of the first cell
     * @param c0 column of the first cell
     * @param direction0 direction index from (r0, c0) toward (r, c)
     * @param r row of the second cell
     * @param c column of the second cell
     * @return {@code true} when both edge colors are equal
     */
    private boolean isSameColor(int r0, int c0, int direction0, int r, int c) {
        int edgeHere = (direction0 - rxcToNumRotates[r0][c0] + 4) % 4;
        int edgeThere = ((direction0 + 2) - rxcToNumRotates[r][c] + 4) % 4;
        return panelxDirectionToColor[rxcToIndex[r0][c0]][edgeHere] == panelxDirectionToColor[rxcToIndex[r][c]][edgeThere];
    }

    /**
     * Stores an already-computed score for (r, c) in {@code rxcToScore}.
     * Then recomputes the cached score of every on-board neighbour reached through the dr/dc offsets.
     *
     * @param r row of the changed cell
     * @param c column of the changed cell
     * @param scoreRC score to store for (r, c)
     */
    private void updateScore(int r, int c, int scoreRC) {
        rxcToScore[r][c] = scoreRC;
        int d = 0;
        while (d < dr.length) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            boolean onBoard = 0 <= nr && nr < N && 0 <= nc && nc < N;
            if (onBoard) {
                rxcToScore[nr][nc] = calculateScore(nr, nc);
            }
            d++;
        }
    }

    /**
     * Tries rotations 0..3 for the panel at (r, c) and leaves it in the one with the highest local
     * score. Only a strictly higher score replaces the current choice, so ties keep the lowest
     * rotation.
     *
     * @param r row of the cell
     * @param c column of the cell
     */
    private void calculateBestRotate(int r, int c) {
        int chosenRotate = -1;
        int chosenScore = (int) -1e9;
        for (int candidate = 0; candidate < 4; candidate++) {
            rxcToNumRotates[r][c] = candidate;
            int candidateScore = calculateScore(r, c);
            if (candidateScore > chosenScore) {
                chosenRotate = candidate;
                chosenScore = candidateScore;
            }
        }
        rxcToNumRotates[r][c] = chosenRotate;
    }

    /**
     * Returns the direction index leading from (r, c) to an orthogonally adjacent (r2, c2).
     * The indices are: 0 = one row up, 1 = one column right, 2 = one row down, 3 = one column left.
     *
     * @return the direction index, or -1 when the cells are not adjacent (including identical cells)
     */
    private int getDirection(int r, int c, int r2, int c2) {
        int rowDelta = r2 - r;
        int colDelta = c2 - c;
        if (rowDelta == 0 && colDelta == 1) {
            return 1;
        }
        if (rowDelta == 0 && colDelta == -1) {
            return 3;
        }
        if (colDelta == 0 && rowDelta == 1) {
            return 2;
        }
        if (colDelta == 0 && rowDelta == -1) {
            return 0;
        }
        return -1;
    }

    /**
     * Re-evaluates {@code notPerfectSet} membership for (r, c) and its on-board neighbours, keyed
     * as {@code row * N + col}.
     *
     * A cell is added while fewer than four of its edges match (see calculateScore(int, int)) and
     * removed otherwise.
     *
     * @param r row of the changed cell (assumed on the board)
     * @param c column of the changed cell (assumed on the board)
     */
    private void update(int r, int c) {
        // d == -1 is the cell itself: handled first and without a bounds check; d >= 0 are neighbours.
        for (int d = -1; d < dr.length; d++) {
            int nr = r;
            int nc = c;
            if (d >= 0) {
                nr += dr[d];
                nc += dc[d];
                if (nr < 0 || nr >= N || nc < 0 || nc >= N) {
                    continue;
                }
            }
            int key = nr * N + nc;
            if (calculateScore(nr, nc) < 4) {
                notPerfectSet.add(key);
            } else {
                notPerfectSet.remove(key);
            }
        }
    }

    /**
     * Exchanges the panels at (r, c) and (r2, c2). It also updates {@code indexToRxc}, the reverse
     * lookup from panel index to {@code row * N + col}. Rotations are left untouched.
     */
    private void swap(int r, int c, int r2, int c2) {
        final int first = rxcToIndex[r][c];
        final int second = rxcToIndex[r2][c2];
        rxcToIndex[r][c] = second;
        rxcToIndex[r2][c2] = first;

        indexToRxc[second] = r * N + c;
        indexToRxc[first] = r2 * N + c2;
    }

    /**
     * Prints one line per panel index, in ascending order: "column row rotation".
     * Each line gives where that panel sits and how it is rotated in the current placement.
     */
    private void write() {
        int[][] placement = new int[N * N][3];
        for (int cell = 0; cell < N * N; cell++) {
            int r = cell / N;
            int c = cell % N;
            int[] entry = placement[rxcToIndex[r][c]];
            entry[0] = c;
            entry[1] = r;
            entry[2] = rxcToNumRotates[r][c];
        }

        for (int[] entry : placement) {
            System.out.println(entry[0] + " " + entry[1] + " " + entry[2]);
        }
        System.out.flush();
    }

    /**
     * Recomputes the whole board score from scratch.
     *
     * Refreshes the per-cell cache {@code rxcToScore} and brings {@code notPerfectSet} fully in line
     * with it. Cells with fewer than four matching edges are added and all others are removed,
     * which is the same rule update(int, int) applies incrementally. Previously perfect-now entries
     * could linger.
     *
     * @return number of matching adjacent edge pairs; every pair is counted once from each of its
     *         two cells, so the per-cell sum is halved
     */
    private int calculateScore() {
        int total = 0;
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                int cellScore = calculateScore(r, c);
                rxcToScore[r][c] = cellScore;
                total += cellScore;
                int key = r * N + c;
                if (cellScore < 4) {
                    notPerfectSet.add(key);
                } else {
                    // Drop stale entries so the set mirrors the current board.
                    notPerfectSet.remove(key);
                }
            }
        }

        return total / 2;
    }

    /**
     * Counts the edges of the panel at (r, c) whose color equals the facing edge of an on-board
     * neighbour. Neighbours are reached through the dr/dc offsets, and rotations are respected.
     *
     * @param r row of the cell
     * @param c column of the cell
     * @return number of matching edges (0..dr.length)
     */
    private int calculateScore(int r, int c) {
        int matches = 0;
        for (int d = 0; d < dr.length; d++) {
            int nr = r + dr[d];
            int nc = c + dc[d];
            boolean offBoard = nr < 0 || nr >= N || nc < 0 || nc >= N;
            if (!offBoard && isSameColor(r, c, d, nr, nc)) {
                matches++;
            }
        }
        return matches;
    }

    /**
     * Records the current placement and rotations as the best so far.
     * This happens only when {@code score} strictly exceeds {@code bestScore}.
     */
    private void saveBest() {
        if (!(score > bestScore)) {
            return;
        }
        bestScore = score;
        for (int cell = 0; cell < N * N; cell++) {
            int row = cell / N;
            int col = cell % N;
            bestRxcToIndex[row][col] = rxcToIndex[row][col];
            bestRxcToNumRotates[row][col] = rxcToNumRotates[row][col];
        }
    }

    /**
     * Restores the placement and rotations recorded by saveBest() and resets {@code score} to
     * {@code bestScore}.
     *
     * It also rebuilds the derived state that the incremental moves depend on:
     * - {@code indexToRxc}, the inverse of {@code rxcToIndex} that swap keeps in sync;
     * - the per-cell {@code rxcToScore} cache;
     * - {@code notPerfectSet}.
     * Without this, any search that continues after a reload would work on stale data.
     */
    private void loadBest() {
        score = bestScore;
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                rxcToIndex[r][c] = bestRxcToIndex[r][c];
                rxcToNumRotates[r][c] = bestRxcToNumRotates[r][c];
                indexToRxc[rxcToIndex[r][c]] = r * N + c;
            }
        }

        // Cell scores depend on neighbours, so refresh them only once the whole board is restored.
        for (int r = 0; r < N; r++) {
            for (int c = 0; c < N; c++) {
                int cellScore = calculateScore(r, c);
                rxcToScore[r][c] = cellScore;
                if (cellScore < 4) {
                    notPerfectSet.add(r * N + c);
                } else {
                    notPerfectSet.remove(r * N + c);
                }
            }
        }
    }
}

final class Utils {
    private Utils() {
    }

    public static final void debug(Object... o) {
        System.err.println(toString(o));
    }

    public static final String toString(Object... o) {
        return Arrays.deepToString(o);
    }

}

/**
 * Stopwatch measuring wall-clock time with {@link System#nanoTime()}.
 * It is started on construction and can be restarted with {@link #init()}.
 */
class Watch {
    private long origin;

    public Watch() {
        init();
    }

    /** Restarts the stopwatch from the current instant. */
    public void init() {
        origin = System.nanoTime();
    }

    /** @return seconds elapsed since construction or the last {@link #init()} */
    public double getSecond() {
        long elapsedNanos = System.nanoTime() - origin;
        return elapsedNanos * 1e-9;
    }

    /** @return the elapsed time formatted by {@link #toString(double)} */
    public String getSecondString() {
        return toString(getSecond());
    }

    /**
     * Formats a duration in seconds for logging. The output is:
     * - under a minute: seconds with two decimals, e.g. {@code " 5.00s"} (decimal separator follows
     *   the default locale);
     * - under an hour: whole minutes and seconds, e.g. {@code " 2m 5s"};
     * - otherwise: hours, minutes and seconds, e.g. {@code " 1h 2m 5s"}.
     *
     * @param seconds duration in seconds
     * @return formatted duration
     */
    public static final String toString(double seconds) {
        if (seconds < 60) {
            return String.format("%5.2fs", seconds);
        }
        int wholeSeconds = (int) (seconds % 60);
        int totalMinutes = (int) (seconds / 60);
        if (seconds < 60 * 60) {
            return String.format("%2dm%2ds", totalMinutes, wholeSeconds);
        }
        int hours = (int) (seconds / (60 * 60));
        return String.format("%2dh%2dm%2ds", hours, totalMinutes % 60, wholeSeconds);
    }

}

class XorShift {
    private int w = 88675123;
    private int x = 123456789;
    private int y = 362436069;
    private int z = 521288629;

    public XorShift(long l) {
        x = (int) l;
    }

    public int nextInt() {
        final int t = x ^ (x << 11);
        x = y;
        y = z;
        z = w;
        w = w ^ (w >>> 19) ^ (t ^ (t >>> 8));
        return w;
    }

    public long nextLong() {
        return ((long) nextInt() << 32) ^ (long) nextInt();
    }

    public double nextDouble() {
        return (nextInt() >>> 1) * 4.6566128730773926E-10;
    }

    public int nextInt(int n) {
        return (int) (n * nextDouble());
    }

}

/**
 * Bookkeeping for simulated annealing. It tracks:
 * - the current time (wall clock via Main.watch, or the iteration count when useTime is false);
 * - a temperature interpolated linearly from startTemperature to endTemperature;
 * - a range shrinking linearly from startRange to endRange;
 * - iteration counters.
 * It also provides Metropolis-style acceptance tests for maximization (acceptB) and minimization
 * (acceptS).
 */
class SAState {

    public static final boolean useTime = true;

    public double startTime = 0;
    public double endTime = 9.5;
    public double time = startTime;
    public double startTemperature = 0.25;
    public double endTemperature = 0.1;
    public double inverseTemperature = 1.0 / startTemperature;
    // NOTE(review): initialised with a temperature, but init() and the accept methods store the
    // *inverse* temperature here.
    public double lastAcceptTemperature = startTemperature;

    public double startRange = 700;
    public double endRange = 1;
    public double range = startRange;

    public int numIterations;
    public int validIterations;
    public int acceptIterations;

    /** Number of buckets; a power of two so a random int can be masked into an index. */
    private static final int LOG_TABLE_SIZE = 32768;
    private static final int LOG_TABLE_MASK = LOG_TABLE_SIZE - 1;

    /** LOG_UNIFORM[i] = ln((i + 0.5) / size): log of the midpoint of the i-th bucket of (0, 1). */
    private static final double[] LOG_UNIFORM = buildLogTable();

    private static double[] buildLogTable() {
        double[] table = new double[LOG_TABLE_SIZE];
        for (int i = 0; i < table.length; i++) {
            table[i] = Math.log((i + 0.5) / table.length);
        }
        return table;
    }

    /** Resets the counters, starts the schedule at the current time and refreshes derived values. */
    public void init() {
        numIterations = 0;
        validIterations = 0;
        acceptIterations = 0;

        startTime = useTime ? Main.watch.getSecond() : numIterations;

        update();
        lastAcceptTemperature = inverseTemperature;
    }

    /** Refreshes time, then temperature, then range (the latter two depend on time). */
    public void update() {
        updateTime();
        updateTemperature();
        updateRange();
    }

    /** Sets inverseTemperature from a linear interpolation of the temperature over [startTime, endTime]. */
    public void updateTemperature() {
        double progress = (time - startTime) / (endTime - startTime);
        double temperature = startTemperature + (endTemperature - startTemperature) * progress;
        inverseTemperature = 1.0 / temperature;
    }

    /** Sets range by moving linearly from startRange at startTime to endRange at endTime. */
    public void updateRange() {
        double remaining = (endTime - time) / (endTime - startTime);
        // An exponent of 1.0 keeps the schedule linear.
        range = endRange + (startRange - endRange) * Math.pow(remaining, 1.0);
    }

    public void updateTime() {
        time = useTime ? Main.watch.getSecond() : numIterations;
    }

    /** @return true once the schedule's end time has been reached */
    public boolean isTLE() {
        return time >= endTime;
    }

    /** Acceptance test used by the solver; delegates to the maximizing variant. */
    public boolean accept(double deltaScore) {
        return acceptB(deltaScore);
    }

    /**
     * Acceptance test for maximization. Gains (deltaScore > -1e-9) are always accepted; losses are
     * accepted with probability about exp(deltaScore / temperature).
     */
    public boolean acceptB(double deltaScore) {
        validIterations++;

        if (deltaScore > -1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore < 0 : Utils.toString(deltaScore);
        assert 1.0 / inverseTemperature >= 0;

        return acceptWorse(deltaScore * inverseTemperature);
    }

    /**
     * Acceptance test for minimization. Decreases (deltaScore < 1e-9) are always accepted;
     * increases are accepted with probability about exp(-deltaScore / temperature).
     */
    public boolean acceptS(double deltaScore) {
        validIterations++;

        if (deltaScore < 1e-9) {
            acceptIterations++;
            return true;
        }

        assert deltaScore > 0;
        assert 1.0 / inverseTemperature >= 0;

        return acceptWorse(-deltaScore * inverseTemperature);
    }

    /**
     * Shared tail of the acceptance tests. Exponents below -10 are rejected outright.
     * Otherwise the move is accepted when ln(U) < exponent for a tabulated uniform U.
     */
    private boolean acceptWorse(double exponent) {
        if (exponent < -10) {
            return false;
        }
        if (!(LOG_UNIFORM[Main.rng.nextInt() & LOG_TABLE_MASK] < exponent)) {
            return false;
        }
        acceptIterations++;
        lastAcceptTemperature = inverseTemperature;
        return true;
    }

}

/**
 * Set of ints in [0, capacity) with O(1) add, remove, contains and access by position.
 *
 * Members are stored densely in {@code indexToValue[0, size)}, and {@code valueToIndex} maps each
 * member back to its position ({@code EMPTY} when absent). Removal moves the last member into the
 * freed slot, so positions are not stable across removals.
 */
class IntSet {
    private static final int EMPTY = -1;
    private int size;
    private int[] indexToValue;
    private int[] valueToIndex;

    /** Creates an empty set able to hold the values {@code 0 .. capacity - 1}. */
    public IntSet(int capacity) {
        this.size = 0;
        indexToValue = new int[capacity];
        valueToIndex = new int[capacity];
        Arrays.fill(valueToIndex, EMPTY);
    }

    /** @return {@code true} if {@code value} was added, {@code false} if it was already present */
    public boolean add(int value) {
        if (valueToIndex[value] != EMPTY) {
            return false;
        }
        indexToValue[size] = value;
        valueToIndex[value] = size;
        size++;
        return true;
    }

    /** @return {@code true} if {@code value} was present and has been removed */
    public boolean remove(int value) {
        int index = indexOf(value);
        if (index == EMPTY) {
            return false;
        }
        removeByIndex(index);
        return true;
    }

    /** Removes the member at {@code index} by moving the last member into its slot. */
    private boolean removeByIndex(int index) {
        if (size == 0) {
            return false;
        }
        assert index < size;
        size--;
        int value = indexToValue[index];
        int value2 = indexToValue[size];
        indexToValue[index] = value2;
        valueToIndex[value2] = index;

        indexToValue[size] = value;
        valueToIndex[value] = EMPTY;

        return true;
    }

    /** Exchanges the members stored at two positions. */
    public void swap(int index, int index2) {
        assert index < size;
        assert index2 < size;

        int swap = indexToValue[index];
        indexToValue[index] = indexToValue[index2];
        indexToValue[index2] = swap;

        valueToIndex[indexToValue[index]] = index;
        valueToIndex[indexToValue[index2]] = index2;

    }

    /**
     * Exchanges the positions of two members identified by value.
     * Both values must already be in the set.
     */
    public void swapValue(int value, int value2) {
        // These are values, not positions: require membership instead of comparing them with size.
        assert contains(value) : value;
        assert contains(value2) : value2;

        int swap = valueToIndex[value];
        valueToIndex[value] = valueToIndex[value2];
        valueToIndex[value2] = swap;

        indexToValue[valueToIndex[value]] = value;
        indexToValue[valueToIndex[value2]] = value2;

    }

    /** @return the member stored at position {@code index} */
    public int get(int index) {
        assert index < size;
        return indexToValue[index];
    }

    /** @return the position of {@code value}, or {@code -1} when absent */
    public int indexOf(int value) {
        return valueToIndex[value];
    }

    public int size() {
        return size;
    }

    public boolean isEmpty() {
        return size() <= 0;
    }

    /** Removes every member. */
    public void clear() {
        for (; size() > 0;) {
            removeByIndex(0);
        }
    }

    public boolean contains(int value) {
        return indexOf(value) != EMPTY;
    }

    /** @return the members in position order, e.g. {@code "[7, 5]"} */
    @Override
    public String toString() {
        return Arrays.toString(Arrays.copyOf(indexToValue, size()));
    }
}

/**
 * Mutable record of a panel index plus two auxiliary ints.
 *
 * NOTE(review): inferred from use elsewhere in this file:
 * - {@code rotate} is subtracted from a direction to derive a rotation when the panel is placed;
 * - {@code sameColors} is not read in the visible code — presumably a count of matching colors,
 *   to be confirmed where instances are built.
 */
class NearPanel {
    int panel;
    int rotate;
    int sameColors;

    public NearPanel(int panelIndex, int rotation, int matchingColors) {
        panel = panelIndex;
        rotate = rotation;
        sameColors = matchingColors;
    }

}

↑このページのトップへ