BalsnCTF 2022 — VSS
vss
task.py
#!/opt/homebrew/bin/python3
from Crypto.Util.number import *
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
import os
import random
from hashlib import sha256
FLAG = b"this_is_a_test_flag"  # placeholder flag used when running the challenge locally
class ShareScheme:
    """Produce linear shares of a 128-byte key split into two 64-byte integers.

    A share is (p, a, b, c, y) with y = a + key1*b + key2*c (mod p), where p is
    a fresh 512-bit prime and a, b, c are drawn with random.randint(2, p - 1).
    """

    def __init__(self, key: bytes):
        """Store the first and second 64-byte halves of *key* as key1 and key2."""
        assert len(key) == 128
        first_half, second_half = key[:64], key[64:]
        self.key1 = bytes_to_long(first_half)
        self.key2 = bytes_to_long(second_half)

    def getShare(self):
        """Return a new share (p, a, b, c, y) over a freshly generated prime p."""
        modulus = getPrime(512)
        # Drawn in the order a, b, c, exactly one randint call each.
        a, b, c = [random.randint(2, modulus - 1) for _ in range(3)]
        y = (self.key1 * b + self.key2 * c + a) % modulus
        return modulus, a, b, c, y
def commit(val: int):
    """Print a commitment to *val*: a fresh 512-bit prime p, a random base g, and g^val mod p."""
    modulus = getPrime(512)
    base = random.randint(2, modulus - 1)
    committed = pow(base, val, modulus)
    print(f"Commitment: {modulus} {base} {committed}")
key = os.urandom(128)
ss = ShareScheme(key)

# AES-128 key derived from the secret; the flag is encrypted once up front.
real_key = sha256(key).digest()[:16]
cipher = AES.new(real_key, AES.MODE_ECB)
enc_flag = cipher.encrypt(pad(FLAG, 16))
print(f"flag = {enc_flag.hex()}")

# Menu loop: option 1 prints a fresh share plus a commitment to its y;
# any other option exits.
while True:
    op = int(input("Option: "))
    if op != 1:
        exit(0)
    p, a, b, c, y = ss.getShare()
    print(f"{p = }")
    print(f"{a = }")
    print(f"{b = }")
    print(f"{c = }")
    commit(y)
Solution
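Each share gives us $p, a, b, c$. The hidden value is $y = a + b\cdot x_1 + c\cdot x_2 \bmod p$, where $x_1, x_2$ are the two 512-bit halves of the key.
The commitment $(p',\ g,\ com)$ is also known and satisfies $com = g^{y} \bmod p'$. Note that $p'$ is a fresh prime, different from the share prime $p$.
If $p'-1$ has small prime factors whose product is $n$, Pohlig–Hellman recovers $s = y \bmod n$.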
Writing $y = s + k\cdot n$ for an unknown integer $k$, we get $$ s+k\cdot n \equiv a+b\cdot x_1+c\cdot x_2 \pmod p. $$ We keep only the shares for which $n>2^{t}$.
Here $x_1, x_2$ are at most 512 bits long, $s<n$, and $k<\frac{p}{n}$.
After collecting enough such shares, we combine their relations with the CRT.
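The combined relation is $$ S+\sum_{i=1}^{l} n_i'\cdot k_i \equiv A+B\cdot x_1+C\cdot x_2 \pmod{\prod_i p_i}. $$ Here $A, B, C, S$ are the CRT combinations of the per-share $a_i, b_i, c_i, s_i$. Each $n_i'$ is the CRT lift that is $\equiv n_i \pmod{p_i}$ and $\equiv 0$ modulo every other $p_j$.
Only $k_i, x_1, x_2$ are unknown, so we can apply LLL to the lattice spanned by the rows of $$ L=\begin{pmatrix} 1 & 0 & 0 & 0 & \cdots & 0 & B\\ 0 & 1 & 0 & 0 & \cdots & 0 & C\\ 0 & 0 & 2^{512} & 0 & \cdots & 0 & A-S\\ 0 & 0 & 0 & 2^{t} & \cdots & 0 & n_1'\\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots\\ 0 & 0 & 0 & 0 & \cdots & 2^{t} & n_l'\\ 0 & 0 & 0 & 0 & \cdots & 0 & \prod_i p_i \end{pmatrix} $$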
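The vector $(x_1,\ x_2,\ 2^{512},\ -2^t k_1,\ \ldots,\ -2^t k_l,\ 0)$ lies in this lattice: take $x_1, x_2, 1, -k_1, \ldots, -k_l$ times the first $l+3$ rows, plus a suitable multiple of the last row. Every entry is at most about $2^{512}$ in absolute value, because $k_i < p_i/n_i < 2^{512-t}$.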
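We want this vector to be shorter than the Minkowski bound $\det(L)^{1/(l+4)}$. Since $\det(L) = 2^{512}\cdot 2^{t\cdot l}\cdot\prod_i p_i \approx 2^{512+t\cdot l+512\cdot l}$, this requires
$\Rightarrow 2^{512\cdot l+t\cdot l+512}>2^{512\cdot(l+4)}$
$\Rightarrow t>\frac{512\cdot3}{l}$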
I use the parameters $t=30$ and $l=55$ (so $t\cdot l = 1650 > 1536$).
from pwn import *
from Crypto.Util.number import *
from hashlib import sha256
from Crypto.Cipher import AES
#context(log_level="debug")
# def matrix_overview(BB):
# for ii in range(BB.dimensions()[0]):
# a = ('%02d '%ii)
# for jj in range(BB.dimensions()[1]):
# if BB[ii,jj] == 0:
# a += ' '
# else:
# a += 'X'
# if BB.dimensions()[0] < 60:
# a += ' '
# print(a)
# Candidate small primes (all primes below 2^23) tried as factors of p' - 1 in PH().
# The bound keeps each baby-step giant-step search to at most isqrt(2^23) + 1 = 2897 steps.
# prime_range yields the same list the former next_prime() loop built, without ~564k calls.
DB = prime_range(2, 2**23)
def bsgs(g, y, bound, p):
    """Baby-step giant-step: find x with g^x == y (mod p).

    With step = isqrt(bound) + 1, the search covers every exponent in
    [0, step^2), which includes all exponents below ``bound``.

    Args:
        g: base; must be invertible modulo p.
        y: target value (reduced modulo p before searching).
        bound: upper bound on the exponent (PH() passes the prime order q of g).
        p: modulus.

    Returns:
        An exponent x with pow(g, x, p) == y % p, or None if no exponent in the
        searched range matches.
    """
    import math

    # Work on plain Python ints so table keys hash and compare consistently
    # whatever numeric type the caller passes (callers may pass Sage numbers).
    g, y, p = int(g), int(y), int(p)
    step = math.isqrt(bound) + 1

    # Baby steps: table[y * g^-j mod p] = j.  On collisions the later j wins.
    # Multiply by g^-1 incrementally instead of a fresh pow + inverse per step.
    g_inv = int(pow(g, -1, p))
    table = {}
    cur = y % p
    for j in range(step):
        table[cur] = j
        cur = cur * g_inv % p

    # Giant steps: g^(i*step) == y * g^-j  =>  x = i*step + j.
    giant = int(pow(g, step, p))
    cur = 1
    for i in range(step):
        j = table.get(cur)
        if j is not None:
            return i * step + j
        cur = cur * giant % p
    return None
def test_ord(g, p, t):
    """Check whether the prime t divides the multiplicative order of g modulo p.

    Assumes p is prime and p does not divide g, so ord(g) divides p - 1.
    Let m be p - 1 with every factor t removed.  Then t divides ord(g)
    exactly when g^m != 1 (mod p).

    Args:
        g: element whose order is examined.
        p: modulus (assumed prime).
        t: prime to test for; must be at least 2.

    Returns:
        True if g^m != 1 (mod p), i.e. t divides ord(g); False otherwise.

    Raises:
        ValueError: if t < 2 or p < 2.  Otherwise the factor-stripping loop
            would never terminate (t == 1, or p - 1 == 0).
    """
    if t < 2:
        raise ValueError("t must be at least 2")
    if p < 2:
        raise ValueError("p must be at least 2")
    m = p - 1
    while m % t == 0:
        m //= t
    return pow(g, m, p) != 1


# This helper's ``test_`` prefix would make pytest collect it as a test case;
# opt out explicitly.
test_ord.__test__ = False
def PH(p, g, y, bound=2**30):
    """Pohlig-Hellman on the smooth part of p - 1: recover log_g(y) modulo small primes.

    For each prime q in DB dividing p - 1, g and y are projected into the
    order-q subgroup by raising them to (p - 1) / q.  The small discrete log
    there is solved with bsgs().  Primes whose projection of g is 1 give no
    information and are skipped.

    Args:
        p: commitment modulus.
        g: commitment base.
        y: commitment value (g^val mod p).
        bound: minimum product of recovered moduli for the result to be kept.
            The default 2**30 matches the k_i scaling used by solver().

    Returns:
        (r, M) with val == r (mod M), where M is the product of the usable
        primes; or None when M < bound.
    """
    n = p - 1
    residues = []
    moduli = []
    for q in DB:
        if n % q != 0:
            continue
        cofactor = n // q
        g_q = pow(g, cofactor, p)
        if g_q == 1:
            continue
        y_q = pow(y, cofactor, p)
        d = bsgs(g_q, y_q, q, p)
        assert d is not None and pow(g_q, d, p) == y_q
        residues.append(d)
        moduli.append(q)

    # Keep only primes dividing ord(g).  Build new lists instead of removing
    # entries from the list being iterated (which skips elements) and removing
    # residues by value (which can drop the wrong one when residues repeat).
    kept = [(d, q) for d, q in zip(residues, moduli) if test_ord(g, p, q)]
    residues = [d for d, _ in kept]
    moduli = [q for _, q in kept]

    modulus = prod(moduli)
    if modulus < bound:
        return None
    return (crt(residues, moduli), modulus)
def solver(al, bl, cl, yl, nl, kl, bound=2**30):
    """Recover (x1, x2) = (key1, key2) from the collected shares with LLL.

    For share i, a_i + b_i*x1 + c_i*x2 == yl[i] + kl[i]*k_i (mod nl[i]) for
    an unknown k_i.  All relations are merged with the CRT modulo prod(nl).
    The short vector (x1, x2, 2^512, -bound*k_1, ..., -bound*k_t, 0) is then
    searched for in the lattice built below.

    Args:
        al, bl, cl: share coefficients a, b, c.
        yl: residues of y (first element of each PH() result).
        nl: share primes p.
        kl: PH() moduli (second element of each PH() result).
        bound: scaling of the k_i columns (2^t in the write-up, t = 30).

    Returns:
        (x1, x2) read from the reduced basis.

    Raises:
        ValueError: if det(L)^(1/dim) < 2^512, i.e. too few or too weak shares
            for the target vector to be expected to be short.
    """
    t = len(nl)  # one lattice row per share
    dim = t + 4
    last = dim - 1
    scale = 2**512

    B = crt(bl, nl)
    C = crt(cl, nl)
    A = crt([a - y for a, y in zip(al, yl)], nl)
    # n_i': congruent to kl[i] mod nl[i] and to 0 mod every other share prime.
    D = [crt([0] * i + [kl[i]] + [0] * (t - 1 - i), nl) for i in range(t)]

    L = matrix(ZZ, dim, dim)
    L[0, 0] = 1
    L[1, 1] = 1
    L[2, 2] = scale
    L[0, last] = B
    L[1, last] = C
    L[2, last] = A
    for i in range(t):
        L[3 + i, 3 + i] = bound
        L[3 + i, last] = D[i]
    L[last, last] = prod(nl)

    # det^(1/dim) >= 2^512  <=>  det >= 2^(512*dim): compared exactly in
    # integers rather than through a symbolic real root.
    det = int(L.det())
    print(f"Minkowski estimate ~ 2^{det.bit_length() // dim}")
    if det < 2 ** (512 * dim):
        raise ValueError("lattice determinant too small; collect more shares")

    reduced = L.LLL()
    # Prefer the row carrying the +-2^512 marker with a vanishing last column;
    # normalise its sign so x1, x2 come out positive.
    for row in reduced.rows():
        if abs(row[2]) == scale and row[last] == 0:
            sign = 1 if row[2] > 0 else -1
            return sign * row[0], sign * row[1]
    first = reduced[0]
    return first[0], first[1]
def recv_int(tube, label):
    """Read the decimal integer printed after *label* on the current line."""
    tube.recvuntil(label)
    return int(tube.recvline())


NUM_SHARES = 55  # l in the write-up; t * l must exceed 512 * 3 with t = 30

io = process("./chall.py")
io.recvuntil(b"flag = ")
flag_enc = bytes.fromhex(io.recvline().strip().decode())

al, bl, cl, yl, nl, kl = [[] for _ in range(6)]
while len(al) < NUM_SHARES:
    io.sendlineafter(b"Option: ", b"1")
    p = recv_int(io, b"p = ")
    a = recv_int(io, b"a = ")
    b = recv_int(io, b"b = ")
    c = recv_int(io, b"c = ")
    io.recvuntil(b"Commitment: ")
    com = list(map(int, io.recvline().decode().split()))
    k = PH(*com)
    if k is None:
        # Smooth part of p' - 1 too small: this share is useless.
        continue
    al.append(a)
    bl.append(b)
    cl.append(c)
    yl.append(k[0])
    nl.append(p)
    kl.append(k[1])
io.close()

print("Start solver")
x1, x2 = solver(al, bl, cl, yl, nl, kl)
# key1/key2 were read from 64-byte halves of the key, so pad each back to
# 64 bytes; otherwise a half starting with a zero byte loses it and the
# derived AES key is wrong.
key = long_to_bytes(int(abs(x1)), 64) + long_to_bytes(int(abs(x2)), 64)
real_key = sha256(key).digest()[:16]
cipher = AES.new(real_key, AES.MODE_ECB)
flag = cipher.decrypt(flag_enc)
print(flag)