picoCTF 2022 Sequences writeup

picoCTF 2022の本番中に解けなかった問題 "Sequences" の復習writeupです。

picoCTF 2022全体のwriteupはこちら

問題の概要

Pythonスクリプトが与えられます。

スクリプトの中では、以下の関数が定義されています。

@functools.cache
def m_func(i):
    if i == 0: return 1
    if i == 1: return 2
    if i == 2: return 3
    if i == 3: return 4

    return 55692*m_func(i-4) - 9549*m_func(i-3) + 301*m_func(i-2) + 21*m_func(i-1)

スクリプト全体としては、この関数に int(2e7), つまり 20000000 を渡した結果を使ってフラグを復号するスクリプトです。

この関数は functools.cache によってメモ化されていますが、それでも 20000000 を渡すと計算は十分に遅いです。 これを高速化する必要があります。

本番中の考察&試したこと

(全体writeupからのコピペです)

問題のヒントに行列の対角化とあったので、この関数で行っている計算を行列を使って言い換えることを考えます。

この関数を f と書くことにして、 f(i)f(i+3) を使って f(i+1)f(i+4) を表現すると、

f(i+1) = f(i+1)
f(i+2) = f(i+2)
f(i+3) = f(i+3)
f(i+4) = 55692*f(i) - 9549*f(i+1) + 301*f(i+2) + 21*f(i+3)

となるので、f(i) - f(i+3)f(i+1) - f(i+4) に変換する捜査は

 \begin{pmatrix}
f(i+1)\\
f(i+2)\\
f(i+3)\\
f(i+4)
\end{pmatrix} =
\begin{pmatrix}
0&1&0&0\\
0&0&1&0\\
0&0&0&1\\
55692&-9549&301&21
\end{pmatrix}
\begin{pmatrix}
f(i)\\
f(i+1)\\
f(i+2)\\
f(i+3)
\end{pmatrix}

と書けます。 これを繰り返し使って、

 \begin{pmatrix}
f(n)\\
f(n+1)\\
f(n+2)\\
f(n+3)
\end{pmatrix} =
\begin{pmatrix}
0&1&0&0\\
0&0&1&0\\
0&0&0&1\\
55692&-9549&301&21
\end{pmatrix}^n
\begin{pmatrix}
f(0)\\
f(1)\\
f(2)\\
f(3)
\end{pmatrix} =
\begin{pmatrix}
0&1&0&0\\
0&0&1&0\\
0&0&0&1\\
55692&-9549&301&21
\end{pmatrix}^n
\begin{pmatrix}
1\\
2\\
3\\
4
\end{pmatrix}

が言えます。

つまり f(n) の値は、f(4), f(5), f(6), ...n まで順番に計算していくのではなく、行列

 \begin{pmatrix}
0&1&0&0\\
0&0&1&0\\
0&0&0&1\\
55692&-9549&301&21
\end{pmatrix}

の累乗( 20000000 乗)を計算することによって求められます。

行列の累乗は対角化によって高速に計算できるので、ここまではヒントに沿った考察ができているような気がします。

ここからは少し試行錯誤しましたが、現実的なレベルまでの高速化はできませんでした。 行列を対角化したとしても、そもそも行列でなく整数の 20000000 乗ですら現実的な時間では計算できないので、何か根本的にもう1ステップの考察が必要なんだと思います。

スクリプトを見直して、答えそのものではなく、答えを 10**10000 で割った余りのみを求めるのでも良いことはわかりましたが、 それをどう使って計算量を減らせるかはすぐに思いつかず、とりあえずパス。

解法

Pythonではデフォルトでいくらでも大きい整数を扱えますが、 桁数がとても大きい整数の計算をしたい場合は gmpy2.mpz を使うほうが高速のようです。 gmpy2.mpzの説明はこちら

原理はわかりませんが、確認してみると、たしかにはるかに高速です。

>>> import gmpy2
>>> import time
>>>
>>> def a(n):
...   3**n
...
>>> def b(n):
...   gmpy2.mpz(3)**n
...
>>> def stopwatch(f, n):
...   start = time.perf_counter()
...   f(n)
...   return time.perf_counter() - start
...
>>> stopwatch(a, int(2e7))
8.492703900003107
>>> stopwatch(b, int(2e7))
0.12415110000074492

上述した考察をもとに、整数の累乗の計算にはこの gmpy2.mpz を使うようにして、 m_func 関数を書き換えたものが以下です。

import sympy
from gmpy2 import mpz

def m_func(i):
    m = sympy.Matrix([
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [55692, -9549, 301, 21]
    ])

    p, d = m.diagonalize()
    eigs = [d[j, j] for j in range(4)]
    eigs_powered = [int(pow(mpz(e), i)) for e in eigs]
    d_powered = sympy.Matrix([
        [eigs_powered[0], 0, 0, 0],
        [0, eigs_powered[1], 0, 0],
        [0, 0, eigs_powered[2], 0],
        [0, 0, 0, eigs_powered[3]],
    ])

    m_powered = p * d_powered * p**(-1)
    f = m_powered * sympy.Matrix([[1], [2], [3], [4]])
    return f[0, 0]

この関数は、手元の環境では8秒程度で結果を返しました。

m_func 関数を上記のものに置き換えてからスクリプト全体を実行した結果は以下の通りです。

wn@wsl:~/workspace/ctf/picoCTF/picoCTF2022/sequences$ pipenv run python sequences.py
picoCTF{b1g_numb3rs_1e4c686b}

得られた教訓

桁数がとても大きい整数の計算には gmpy2.mpz を使う。