24を計算する問題#

from z3 import *

「24を計算する」問題とは、与えられた4つの数値を +, -, *, / の演算子を使って組み合わせ、結果が 24 になる数式を求めるパズルです。
例えば、[1, 5, 5, 5] という入力に対して (5 - (1 / 5)) * 5 のような数式を見つけることを目的とします。

この問題を解くためには:

  1. 4つの数の順列(permutation)を考慮する。

  2. 3つの演算子の組み合わせを考える。

  3. 計算結果が 24 になるような組み合わせを探索する。

数値の順序#

この関数は、Z3を用いて「variablesvalues の順列になる制約を追加」します。

  • variables: 例えば [x0, x1, x2, x3] のようなZ3変数(配列)。

  • values: 例えば [1, 5, 5, 5] のような整数のリスト。

この関数が行うこと:

  1. variables の各要素は values のいずれかの値であることを制約 (Or(x == v for v in values_set))

  2. 各値が values に登場する回数と一致することを制約 (Sum([x == v for x in variables]) == values.count(v))

例えば [1, 5, 5, 5] の場合、variables に含まれる数値は {1, 5} のみで、1 は1回、5 は3回登場する必要があります。

def permutation(variables, values):
    values_set = set(values)
    exprs = []
    for x in variables:
        exprs.append(Or([x == v for v in values_set]))

    for v in values_set:
        exprs.append(Sum([x == v for x in variables]) == values.count(v))

    return And(exprs)

from helper.z3 import all_solutions

xs = IntVector('x', 4)
values = [1, 5, 5, 5]

solver = Solver()
solver.add(permutation(xs, values))

for m in all_solutions(solver):
    print([m[x] for x in xs])
[5, 5, 5, 1]
[1, 5, 5, 5]
[5, 5, 1, 5]
[5, 1, 5, 5]

コード#

次の関数は、与えられた4つの数値の並びと演算子の組み合わせを考慮し、24 を作る計算式を求めます。

def output(numbers, ops):
    def _output_op(n1, n2, op):
        op_str = '+-*/-/'[op]
        if op in [4, 5]:
            n1, n2 = n2, n1

        if isinstance(n1, str):
            n1 = f'({n1})'

        if isinstance(n2, str):
            n2 = f'({n2})'
            
        return f'{n1} {op_str} {n2}'

    state = numbers[0]
    for n2, op in zip(numbers[1:], ops):
        state = _output_op(state, n2, op)

    return state
    
def solve_24_by_permutation(*numbers):
    solver = Solver()
    n = RealVector('n', 4)
    op = IntVector('op', 3)
    s = RealVector('s', 3)

    solver.add(permutation(n, numbers))
    
    for x in op:
        solver.add(0 <= x, x < 6)
   
    def calc_op(n1, op, n2, res):
        # res == n1 op n2 の制約条件を計算する
        return And([
            Implies(n2 == 0, op != 3),
            Implies(n1 == 0, op != 5),
            Implies(op == 0, res == n1 + n2),
            Implies(op == 1, res == n1 - n2),
            Implies(op == 2, res == n1 * n2),
            Implies(op == 3, res == n1 / n2),
            Implies(op == 4, res == n2 - n1),
            Implies(op == 5, res == n2 / n1),
        ])
        
    solver.add(calc_op(n[0], op[0], n[1], s[0])) # s0 = n0 op0 n1
    solver.add(calc_op(s[0], op[1], n[2], s[1])) # s1 = s0 op1 n2
    solver.add(calc_op(s[1], op[2], n[3], s[2])) # s2 = s1 op3 n3
    
    solver.add(s[-1] == 24)
    
    if solver.check() == sat:
        m = solver.model()
        return output([m.eval(x).as_long() for x in n], [m.eval(x).as_long() for x in op])
    else:
        return "no solution"

解析#

  1. 変数の設定

    • n[0] ~ n[3]: 4つの数の順列を表す実数 (RealVector)

    • op[0] ~ op[2]: 3つの演算子を表す整数 (IntVector)

    • s[0] ~ s[2]: 中間計算結果を格納する変数 (RealVector)

  2. 数値の順列を制約

    • solver.add(permutation(n, numbers)) により、nnumbers の順列であることを保証。

  3. 演算子の範囲を制約

    • solver.add(0 <= x, x < 6) により、演算子は 0 から 5 の範囲に制限(+, -, *, /, - (逆順), / (逆順))。

  4. calc_op(n1, op, n2, res) 関数

    • res = n1 op n2 の関係を Z3 の制約として表現する。

    • op の値に応じて、res が正しい結果になるようにする。

    • n2 == 0 のとき op != 3(ゼロ除算禁止)

    • n1 == 0 のとき op != 5(逆除算のゼロ禁止)

  5. 計算の流れ

    • s[0] = n[0] op[0] n[1]

    • s[1] = s[0] op[1] n[2]

    • s[2] = s[1] op[2] n[3]

    • s[2] == 24 という制約を追加。

次はいくつかの問題を解いてみます。

questions = [
    [1, 5, 5, 5],
    [3, 3, 8, 8],
    [3, 3, 7, 7],
    [1, 4, 5, 6],
    [2, 2, 2, 9],
    [2, 7, 8, 9],
    [6, 9, 9, 10],
    [1, 2, 7, 7],
    [4, 4, 10, 10],
    [2, 5, 5, 10],
]

for q in questions:
    expr = solve_24_by_permutation(*q)
    if expr != "no solution":
        print(f'{expr} = {eval(expr):g}')
(5 - (1 / 5)) * 5 = 24
8 / (3 - (8 / 3)) = 24
((3 / 7) + 3) * 7 = 24
6 / ((5 / 4) - 1) = 24
((9 + 2) * 2) + 2 = 24
((9 + 7) * 2) - 8 = 24
((9 * 10) / 6) + 9 = 24
((7 * 7) - 1) / 2 = 24
((10 * 10) - 4) / 4 = 24
(5 - (2 / 10)) * 5 = 24

関数で順序から数値へのマッピング#

次の関数は、solve_24_by_permutation() とは異なり、Z3の関数 (Function) を使用して順序から数値へ変換します。この方法では、インデックスに Distinct() 制約を適用することで、すべての数値の順列を求めることができます。

def solve_24_by_function(*numbers):
    number_map = Function('number_map', IntSort(), RealSort())
    solver = Solver()
    
    for i, n in enumerate(numbers):
        solver.add(number_map(i) == n)
        
    indices = IntVector('i', 4)
    op = IntVector('op', 3)
    s = RealVector('s', 3)
    
    for x in indices:
        solver.add(0 <= x, x < 4)
        
    solver.add(Distinct(indices))
    
    for x in op:
        solver.add(0 <= x, x < 6)
   
    def calc_op(n1, op, n2, res):
        "res == n1 op n2"
        return And([
            Implies(n2 == 0, op != 3),
            Implies(n1 == 0, op != 5),
            Implies(op == 0, res == n1 + n2),
            Implies(op == 1, res == n1 - n2),
            Implies(op == 2, res == n1 * n2),
            Implies(op == 3, res == n1 / n2),
            Implies(op == 4, res == n2 - n1),
            Implies(op == 5, res == n2 / n1),
        ])
        
    n = [number_map(i) for i in indices]
    solver.add(calc_op(n[0], op[0], n[1], s[0])) # s0 = n0 op0 n1
    solver.add(calc_op(s[0], op[1], n[2], s[1])) # s1 = s0 op1 n2
    solver.add(calc_op(s[1], op[2], n[3], s[2])) # s2 = s1 op3 n3
    
    solver.add(s[-1] == 24)
    
    if solver.check() == sat:
        m = solver.model()
        return output([m.eval(x).as_long() for x in n], [m.eval(x).as_long() for x in op])
    else:
        return "no solution"
  1. number_map による順序管理

number_map は、整数インデックス (IntSort()) を実数 (RealSort()) にマッピングする Z3 の関数です。
solvernumber_map(i) == n を追加することで、number_map(0) = numbers[0] から number_map(3) = numbers[3] まで、各インデックスに対応する数値を設定します。

number_map = Function('number_map', IntSort(), RealSort())
for i, n in enumerate(numbers):
    solver.add(number_map(i) == n)
  1. indices による順列の管理

indices は 4 つの整数インデックスを持ち、数値の順序を決定します。

indices = IntVector('i', 4)
for x in indices:
    solver.add(0 <= x, x < 4)
solver.add(Distinct(indices))

インデックスの制約 Distinct(indices) を追加することで、各インデックスが一意であることを保証し、異なる順列を考慮できるようにしています。

for q in questions:
    expr = solve_24_by_function(*q)
    print(f'{expr} = {eval(expr):g}')
(5 - (1 / 5)) * 5 = 24
8 / (3 - (8 / 3)) = 24
((3 / 7) + 3) * 7 = 24
4 / (1 - (5 / 6)) = 24
((9 + 2) * 2) + 2 = 24
((7 + 9) * 2) - 8 = 24
((9 / 6) * 10) + 9 = 24
((7 * 7) - 1) / 2 = 24
((10 * 10) - 4) / 4 = 24
(5 - (2 / 10)) * 5 = 24