BETA

Exact difference triangleを解いてみた

投稿日:2019-07-29
最終更新:2019-07-29

Exact difference triangleとは

Exact difference triangleは、1から15までの整数を以下のような逆三角形に整列させる数学パズルのような問題です。

並べるときのルールは1つだけ。
最小の逆三角形を構成する3つの数を考えたとき、下の数が上2つの数の差になっていることです。

と、書いてもなんのことやら、という感じなので、図を使います。

赤く括った3つの数に対して、a_6はa_1とa_2の差になっている、ということです。
これはすべての逆三角形(10通り)で成立します。

ちなみに、深さが5のときこれを満たす並べ方は(線対称の並べ方を1つとすると)1通りしかないことが証明されています。
参考: https://w3.math.sinica.edu.tw/bulletin/bulletin_old/d51/5120.pdf
(名前もここから取ってきました)

で。今回はこれを解くアルゴリズムを実装しました。
きっかけはたまたま(職場で)話題に上がったからです。

方針

以下、各位置をa[深さ][左から何番目か]で表します(例えばa_8 = a[2][3])。

問題を数式にすると、a[i][j] = |a[i-1][j] - a[i-1][j+1]| (2 =< i =< 5, 1 =< j =< 4)と表せます。
つまり、最上段5つを埋めると残りは自動的に決まる、ということで最上段を考える方針としました。
ここで最下段から固定しなかったのは、全探索するときにO(n!)になりそうだったから。
最上段から決めると愚直にやってもO(n^5)なので、安全を取りました。

それがこれ

# 前略  
        for i in range(1, 16):  
            for j in range(1, 16):  
                for k in range(1, 16):  
                    for l in range(1, 16):  
                        for m in range(1, 16):  
                            tmp = [i, j, k, l, m]  
                            if len(tmp) == len(set(tmp)):  
                                f.write(','.join([str(n) for n in tmp]) + '\n')  
                                for n, num in enumerate(tmp):  
                                    ans[0][n] = num  
                                ans = calc_all(ans)  
                                if check_correct(ans):  
                                    print(ans)  
                                    break  
# 以下略  

マシンパワーにものを言わせた感がすごい。
手元の環境では大体1.5 sec前後で答えが出力されました。

さて、ここからが本番。

15の位置を決めた

最上段以外は、1から15の間の2つの数の差が入ります。
ここから、15-1=14なので15は最上段にしかいられないということがわかります。
これを実装するとこんな感じ
ついでに後ろ半分は線対称になるので、探索範囲からカットしました(1つでも答えが見つかったらbreakするので意味ないですが)。

# 前略  
        for i in range(1, 8):  
            for j in range(1, 15):  
                for k in range(1, 15):  
                    for l in range(1, 15):  
                        use_list = [i, j, k, l]  
                        if len(use_list) == len(set(use_list)):  
                            for m in range(0, 3):  
                                ans_0_n = use_list[:]  
                                ans_0_n.insert(m, 15)  
                                f.write(','.join([str(n) for n in ans_0_n]) + '\n')  
                                ans[0] = ans_0_n  
                                ans = calc_all(ans)  
                                if check_correct(ans):  
                                    print(ans)  
                                    break  
# 以下略  

このとき、15の位置を決めるのはO(√n)となります(なるはず……)。
手元の環境では大体0.35 secくらい。計算通り減っててちょっと感動。

枝刈りしてみた

今のやり方だと、最上段5通りを決めて全マスを計算しています。
しかし、例えば[15, 1, 2, x, y]と当てはめると、|1-2| = 1よりxとyに関わらず不正解となります。
このためにはその下の段を計算する必要があります。
つまり、純粋な幅優先探索から深さ方向に先読みする処理を入れました(15pyramid_3.py)。

# 前略  
        for m in range(0, 3):  
            for i in range(1, 8):  
                for j in range(1, 15):  
                    ans_n_0 = [i, j, -1, -1] # -1は使わないことを示すflag  
                    if j == i:  
                        continue  
                    else:  
                        if not check_skip(ans_n_0, m, ans): # check_skip()が先読みするメソッド。可能性がなければFalseを返す。  
                            f.write(','.join([str(n) for n in ans_n_0]) + '\n')  
                            continue  
                    for k in range(1, 15):  
# 以下略  

探索範囲は30000 -> 8000、時間は0.10~0.15secほどになりました。
(枝刈りってO(N)表記だとどう表すんだっけ……)

そもそも5つ選んだ段階で無理なのあるじゃん?

1から15までの整数の性質を考えると、奇数が8つ、偶数が7つあります。
そして、|奇数 - 偶数| = 奇数、それ以外は偶数となるのは自明です。
これを使って、最上段の候補が出たときにパターン認識的に不正解かどうかを判定するロジックを入れました(15pyramid_4.py)。

ちなみに、偶奇性が保たれるのは全32通り中10通りしかなく、実行時間は1/3ほどになることが予測できます。

# 前略  
def parity_accepted_list():  
    accepted_list = []  
    check_parity = init()  
    for i in range(2):  
        for j in range(2):  
            for k in range(2):  
                for l in range(2):  
                    for m in range(2):  
                        check_parity[0] = [i, j, k, l, m]  
                        calc_all(check_parity)  
                        check_parity_1d = sum(check_parity, [])  
                        if check_parity_1d.count(1) == 8 and check_parity_1d.count(0) == 7:  
                            accepted_list.append(check_parity[0])  

    return accepted_list  

# 中略  

                                ans_n_0_mod2 = []  
                                for n in range(len(ans_n_0)):  
                                    ans_n_0_mod2.append(ans_n_0[n] % 2)  
                                if ans_n_0_mod2 not in accepted_list:  
                                    continue  
# 以下略  

結果、探索範囲は 8000 -> 4500ほどに、実行時間は0.07 secほどとなりました。
恐らく、偶奇性確認のための準備がイケてないのがあまり効果が見えない原因ですね……。

まとめ

なんやかんや手を施して、実行時間を1/20まで減らすことができました。
もう少し数学的に条件を絞ったり、2分探索的なことをしたりすればもっといける気がします。
というか、O(n^2)よりでかいと使い物にならん、と教わってきたのでもっと詰めないと怒られます(誰に?)。
夏休みにでも頑張ろうと思います。

余談ですが、手計算で答えを求めるならループを降順にするだけで相当早く求められると思います。

技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
駆け出しエンジニアからエキスパートまで全ての方々のアウトプットを歓迎しております!
or 外部アカウントで 登録 / ログイン する
クランチについてもっと詳しく

この記事が掲載されているブログ

@tmyoasの技術ブログ

よく一緒に読まれる記事

0件のコメント

ブログ開設 or ログイン してコメントを送ってみよう
目次をみる
技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
or 外部アカウントではじめる
10秒で技術ブログが作れます!