🌈C@T :)

BalsnCTF2022 Vss

vss

task.py

#!/opt/homebrew/bin/python3
from Crypto.Util.number import *
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
import os
import random
from hashlib import sha256
FLAG = b'this_is_a_test_flag'

class ShareScheme:
    """Split a 128-byte key into two 512-bit integers and issue linear shares of them.

    Every share is reduced modulo its own freshly generated 512-bit prime.
    """

    def __init__(self, key: bytes):
        assert len(key) == 128
        # Left 64 bytes -> key1, right 64 bytes -> key2 (converted with bytes_to_long).
        first_half, second_half = key[:64], key[64:]
        self.key1 = bytes_to_long(first_half)
        self.key2 = bytes_to_long(second_half)

    def getShare(self):
        """Return a new share ``(p, a, b, c, y)``.

        ``p`` is a fresh 512-bit prime; ``a``, ``b``, ``c`` are drawn from
        ``random.randint(2, p - 1)`` in that order; and
        ``y = (a + key1 * b + key2 * c) mod p``.
        """
        prime = getPrime(512)
        a, b, c = [random.randint(2, prime - 1) for _ in range(3)]
        y = (a + b * self.key1 + c * self.key2) % prime
        return prime, a, b, c, y
        
def commit(val: int):
    """Print a commitment to ``val`` in the form ``Commitment: p g g^val mod p``.

    ``p`` is a fresh 512-bit prime and ``g`` comes from ``random.randint(2, p - 1)``.
    Only ``g^val mod p`` is revealed, not ``val`` itself.
    """
    modulus = getPrime(512)
    base = random.randint(2, modulus - 1)
    hidden = pow(base, val, modulus)
    print(f"Commitment: {modulus} {base} {hidden}")

key = os.urandom(128)
ss = ShareScheme(key)

# The flag is encrypted with AES-ECB under the first 16 bytes of SHA-256(key),
# so recovering the full 128-byte key is enough to decrypt it.
real_key = sha256(key).digest()[:16]
cipher = AES.new(real_key, AES.MODE_ECB)
enc_flag = cipher.encrypt(pad(FLAG, 16))
print(f"flag = {enc_flag.hex()}")

# Menu: option 1 prints one share (p, a, b, c) plus a commitment to its y.
# Any other integer ends the session; non-integer input raises ValueError.
while int(input("Option: ")) == 1:
    p, a, b, c, y = ss.getShare()
    print(f"{p = }", f"{a = }", f"{b = }", f"{c = }", sep="\n")
    commit(y)
exit(0)

Solution

Each share gives us $a, b, c$, and $y = a + b\cdot x_1 + c\cdot x_2 \pmod p$, where $x_1, x_2$ are the two 512-bit halves of the key.

The commitment $(p,\ g,\ com)$ is also known, and it satisfies $com = g^y \pmod p$.

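If $p-1$ has small prime factors whose product is $n$, Pohlig–Hellman recovers $s = y \bmod n$. Only the primes $q$ whose order-$q$ subgroup contains a nontrivial power of $g$ contribute.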

Writing $y = s + k\cdot n$, the share equation becomes $$ s+k\cdot n \equiv a+b\cdot x_1+c\cdot x_2 \pmod p $$ We keep only the shares for which $n \ge 2^{t}$.

The bounds are:

- $x_1, x_2$ have at most 512 bits.
- $0 \le s < n$.
- $0 \le k < \frac{p}{n}$, so $k$ has at most $512 - t$ bits.

We then combine enough of these share relations with the CRT.

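The combined relation modulo $\prod_i p_i$ is $$ S+\sum_{i=1}^{l} n_i'\cdot k_i \equiv A+B\cdot x_1+C\cdot x_2 \pmod{\textstyle\prod_i p_i} $$ Here $A, B, C, S$ are the CRT combinations of the $a_i, b_i, c_i, s_i$. Each $n_i'$ satisfies $n_i' \equiv n_i \pmod{p_i}$ and $n_i' \equiv 0 \pmod{p_j}$ for $j \ne i$. The unknowns $k_i, x_1, x_2$ can be found with LLL on the lattice spanned by the rows of $$ L=\begin{bmatrix} 1 & & & & & & B \\ & 1 & & & & & C \\ & & 2^{512} & & & & A-S \\ & & & 2^{t} & & & n_1' \\ & & & & \ddots & & \vdots \\ & & & & & 2^{t} & n_l' \\ & & & & & & \prod_i p_i \end{bmatrix} $$ which has dimension $l+4$.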

The vector $(x_1,\ x_2,\ 2^{512},\ -2^t\cdot k_1,\ \dots,\ -2^t\cdot k_l,\ 0)$ lies in this lattice, and each of its entries is at most about $2^{512}$.

We want this vector to be shorter than the Minkowski bound $\det(L)^{1/(l+4)}$. Since $\det(L) \approx 2^{512}\cdot 2^{t\cdot l}\cdot 2^{512\cdot l}$, this requires

$2^{512\cdot l+t\cdot l+512}>2^{512\cdot(l+4)}$

$\iff t>\frac{512\cdot3}{l}$

I use the parameters $t=30,\ l=55$, which satisfy $t\cdot l = 1650 > 1536$.

from pwn import *
from Crypto.Util.number import *
from hashlib import sha256
from Crypto.Cipher import AES
#context(log_level="debug")
# def matrix_overview(BB):
#     for ii in range(BB.dimensions()[0]):
#         a = ('%02d '%ii)
#         for jj in range(BB.dimensions()[1]):
#             if BB[ii,jj] == 0:
#                 a += ' '
#             else:
#                 a += 'X'
#             if BB.dimensions()[0] < 60:
#                 a += ' '
#         print(a)

# Every prime below 2^23: the candidate small factors of p - 1 tried by PH().
# A single prime_range() sieve replaces ~560k successive next_prime() calls.
DB = prime_range(2**23)

    
def bsgs(g, y, bound, p):
    """Solve g**x == y (mod p) by baby-step giant-step.

    Let step = isqrt(bound) + 1. The search covers every exponent
    0 <= x < step**2, and therefore every x < bound.

    Args:
        g: the base.
        y: the target value.
        bound: exclusive upper bound on the exponent that must be covered.
        p: the modulus; g must be invertible mod p.

    Returns:
        An exponent x with g**x == y (mod p), as a Python int, or None if no
        exponent in [0, step**2) works.
    """
    from math import isqrt as _isqrt  # stdlib isqrt: works in Python and Sage

    g, y, p = int(g), int(y), int(p)
    step = _isqrt(int(bound)) + 1

    # Baby steps: y * g^(-j) -> j for 0 <= j < step, built incrementally
    # instead of one modular exponentiation and inversion per entry.
    g_inv = pow(g, -1, p)
    table = {}
    cur = y % p
    for j in range(step):
        table[cur] = j
        cur = cur * g_inv % p

    # Giant steps: g^(i*step). A hit means g^(i*step) == y * g^(-j), i.e. x = i*step + j.
    giant = pow(g, step, p)
    cur = 1
    for i in range(step):
        j = table.get(cur)
        if j is not None:
            return i * step + j
        cur = cur * giant % p
    return None
        

def test_ord(g, p, t):
    n = p-1
    while n%t == 0:
        n = n//t
    if pow(g, n, p) == 1:
        return False
    return True

        
def PH(p, g, y, min_modulus=2**30):
    """Pohlig-Hellman on the smooth part of p - 1.

    For each prime q in DB that divides p - 1, both g and y are raised to
    (p - 1) / q, which projects them into the subgroup of order q. The small
    discrete log there is solved with bsgs(), and the per-prime results are
    combined with CRT.

    Args:
        p: prime modulus from the commitment.
        g: commitment base.
        y: commitment value, g**val mod p.
        min_modulus: smallest acceptable product of usable primes. The default
            2**30 matches the 2^30 weight used by solver().

    Returns:
        (s, n): n is the product of the primes used, and s is the CRT
        combination of the per-prime logs, so s == log_g(y) (mod n).
        Returns None when n < min_modulus.
    """
    order = p - 1
    residues = []
    moduli = []
    for q in DB:
        if order % q:
            continue
        cofactor = order // q
        g_q = pow(g, cofactor, p)
        if g_q == 1:
            # A trivial projection carries no information about the log mod q.
            continue
        y_q = pow(y, cofactor, p)
        d = bsgs(g_q, y_q, q, p)
        assert pow(g_q, d, p) == y_q  # sanity check on the small log
        residues.append(d)
        moduli.append(q)

    # Keep only primes that divide ord(g). The (residue, prime) pairs are
    # filtered together. The old code removed from both lists while
    # iterating, which skipped the entry after each removal; it also removed
    # the residue by value, which could drop the wrong residue when two
    # residues were equal.
    pairs = [(d, q) for d, q in zip(residues, moduli) if test_ord(g, p, q)]
    residues = [d for d, _ in pairs]
    moduli = [q for _, q in pairs]

    modulus = prod(moduli)
    if modulus < min_modulus:
        return None
    return (crt(residues, moduli), modulus)

def solver(al, bl, cl, yl, nl, kl):
    """Recover the two 512-bit key halves (key1, key2) from partial share leaks with LLL.

    For share i the lists hold:
      - ``al``, ``bl``, ``cl``: the coefficients a_i, b_i, c_i;
      - ``nl``: the share prime p_i;
      - ``kl``: the smooth modulus n_i returned by PH();
      - ``yl``: s_i = y_i mod n_i;
    where y_i = a_i + b_i*key1 + c_i*key2 mod p_i.

    Writing y_i = s_i + k_i*n_i and CRT-combining every share modulo
    P = prod(p_i) gives
        B*x1 + C*x2 + (A - S) - sum(k_i * n_i') == 0 (mod P).
    The vector (x1, x2, 2^512, -2^30*k_1, ..., -2^30*k_l, 0) is then a short
    vector of the lattice built below.

    Returns:
        (x1, x2) read from the first reduced basis row of that shape, with its
        sign normalised so the 2^512 entry is positive.

    Raises:
        ValueError: if the Minkowski heuristic predicts failure, or if no
            reduced row has the expected shape.
    """
    weight = 2**30          # 2^t in the writeup; PH() keeps only n_i >= 2^30, so 2^30*k_i < 2^512
    key_scale = 2**512      # diagonal entry of the constant (A - S) row
    # Large weight on the modular-relation column: lattice vectors that do not
    # satisfy the relation exactly (last entry != 0) become far longer than the
    # target. This deviates from the writeup's unweighted matrix.
    relation_weight = 2**1024
    count = len(al)         # l in the writeup (previously hard-coded as 55)
    dim = count + 4
    last = dim - 1
    primes = list(nl)
    modulus = prod(primes)

    # CRT-combine the per-share coefficients modulo P = prod(p_i).
    B = crt(list(bl), primes)
    C = crt(list(cl), primes)
    A_minus_S = crt([a - s for a, s in zip(al, yl)], primes)
    # n_i' == n_i (mod p_i) and == 0 (mod p_j) for j != i.
    lifted = []
    for i in range(count):
        residues = [0] * count
        residues[i] = kl[i]
        lifted.append(crt(residues, primes))

    L = matrix(ZZ, dim, dim)
    L[0, 0] = 1
    L[1, 1] = 1
    L[2, 2] = key_scale
    L[0, last] = B * relation_weight
    L[1, last] = C * relation_weight
    L[2, last] = A_minus_S * relation_weight
    for i, n_lift in enumerate(lifted):
        L[3 + i, 3 + i] = weight
        L[3 + i, last] = n_lift * relation_weight
    L[last, last] = modulus * relation_weight

    # Without the relation weight, L is upper triangular, so
    # det = 2^512 * weight^count * P. floor(det^(1/dim)) has exactly
    # (bits(det) - 1) // dim + 1 bits. Computing it this way avoids taking a
    # symbolic root of a ~30000-bit integer.
    det_bits = int(key_scale * weight**count * modulus).bit_length()
    bound_bits = (det_bits - 1) // dim + 1
    print(bound_bits)
    if bound_bits <= 512:
        raise ValueError(
            f"Minkowski bound has only {bound_bits} bits; collect more shares"
        )

    # Pick the reduced vector of the form (x1, x2, +-2^512, ..., 0) instead of
    # trusting that it is row 0.
    for row in L.LLL():
        if row[last] == 0 and abs(row[2]) == key_scale:
            sign = 1 if row[2] > 0 else -1
            return sign * row[0], sign * row[1]
    raise ValueError("LLL found no vector of the form (x1, x2, +-2^512, ..., 0)")


io = process("./chall.py")

# The challenge prints the AES-ECB encrypted flag once, before the menu.
io.recvuntil(b"flag = ")
flag_enc = bytes.fromhex(io.recvline().strip().decode())

# Request shares until 55 of them yield a usable smooth modulus (PH returns
# None for the rest). solver() builds one lattice row per kept share.
al, bl, cl, yl, nl, kl = [[] for _ in range(6)]
while len(al) < 55:
    io.sendlineafter(b"Option: ", b"1")
    io.recvuntil(b"p = ")
    p = int(io.recvline())
    io.recvuntil(b"a = ")
    a = int(io.recvline())
    io.recvuntil(b"b = ")
    b = int(io.recvline())
    io.recvuntil(b"c = ")
    c = int(io.recvline())
    io.recvuntil(b"Commitment: ")
    com = [int(v) for v in io.recvline().split()]
    k = PH(*com)
    if k is not None:
        al.append(a)
        bl.append(b)
        cl.append(c)
        yl.append(k[0])  # s_i = y_i mod n_i
        nl.append(p)     # share prime p_i
        kl.append(k[1])  # smooth modulus n_i
io.close()

print("Start solver")
x1, x2 = solver(al, bl, cl, yl, nl, kl)
# Each key half is exactly 64 bytes, but long_to_bytes drops leading zero
# bytes. Pad every half back to 64 bytes, or SHA-256 hashes the wrong key.
key = long_to_bytes(int(abs(x1)), 64) + long_to_bytes(int(abs(x2)), 64)
real_key = sha256(key).digest()[:16]
cipher = AES.new(real_key, AES.MODE_ECB)
flag = cipher.decrypt(flag_enc)
print(flag)