🌈C@T :)

CorCTF2022 corrupted curves+

corrupted curves+

task.py

#!/opt/homebrew/bin/python3
from secrets import randbits
from Crypto.Util.number import getPrime
from random import randrange

def square_root(a, p):
    """Return a square root of ``a`` modulo ``p``, or 0 if none exists.

    ``p`` is expected to be prime (this is not checked).  Uses the
    ``p % 4 == 3`` shortcut when possible and Tonelli-Shanks otherwise.

    Note: 0 is returned both when ``a`` is a quadratic non-residue and when
    ``p`` divides ``a`` (whose square root really is 0), so the return value
    alone cannot distinguish the two cases.
    """
    if p == 2:
        # Every residue mod 2 is its own square root.  This must be handled
        # before legendre_symbol, which evaluates pow(a, 0, 2) == 1 == p - 1
        # and therefore reports -1 ("non-residue") for every a when p == 2.
        return a % 2
    if a % p == 0:
        return 0
    if legendre_symbol(a, p) != 1:
        # Non-residue: no square root exists.
        return 0
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)

    # Tonelli-Shanks.  Write p - 1 = s * 2**e with s odd.
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1
    # Find any quadratic non-residue n.
    n = 2
    while legendre_symbol(n, p) != -1:
        n += 1
    x = pow(a, (s + 1) // 2, p)  # candidate root; invariant x*x == a*b (mod p)
    b = pow(a, s, p)             # correction factor; x is a root once b == 1
    g = pow(n, s, p)             # has order 2**e since n is a non-residue
    r = e
    while True:
        # Find the least m < r with b**(2**m) == 1.  r >= 1 on every pass
        # (e >= 1 for odd p, and r = m >= 1 afterwards), so m is always bound.
        t = b
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m

def legendre_symbol(a, p):
    """Evaluate the Legendre symbol (a | p) using Euler's criterion.

    For an odd prime ``p`` the result is 1 when ``a`` is a nonzero quadratic
    residue, -1 when it is a non-residue, and 0 when ``p`` divides ``a``.
    Only meaningful for odd primes: for ``p == 2`` the exponent is 0, so the
    power is 1 == p - 1 and the function always returns -1.
    """
    euler = pow(a, (p - 1) // 2, p)
    if euler == p - 1:
        return -1
    return euler

class EllipticCurve:
    """Short Weierstrass curve y^2 = x^3 + a*x + b modulo p.

    The constructor raises ``Exception("Not an elliptic curve!")`` when the
    discriminant is divisible by p.
    """

    def __init__(self, p, a, b):
        self.a = a
        self.b = b
        self.p = p
        if not self.check_curve():
            raise Exception("Not an elliptic curve!")

    def check_curve(self):
        """Return 1 if the discriminant is nonzero modulo p, otherwise 0."""
        delta = -16 * (4 * self.a ** 3 + 27 * self.b ** 2)
        return 1 if delta % self.p != 0 else 0

    def lift_x(self, px):
        """Return a y-coordinate for the point with x-coordinate ``px``.

        Only the y value is returned.  Raises
        ``Exception("No point on elliptic curve.")`` whenever square_root
        yields 0.  That covers a non-residue right-hand side and also a
        right-hand side of 0, so points with y == 0 are rejected as well.
        """
        rhs = (px ** 3 + self.a * px + self.b) % self.p
        y = square_root(rhs, self.p)
        if y == 0:
            raise Exception("No point on elliptic curve.")
        return y

# Sample flag; the flag is used as the x-coordinate of a curve point.
flag = b'this_is_a_sample_flag'
flag = int.from_bytes(flag, 'big')

print("Generating parameters...")
# Retry until (p, a, b) gives a non-singular curve on which the flag is a
# liftable x-coordinate.  Only p and the flag's y-coordinate are printed.
# `except Exception` (not a bare except) so Ctrl-C / SystemExit still stop
# the script instead of silently triggering another retry.
while True:
    p = getPrime(512)
    a, b = randbits(384), randbits(384)
    try:
        E = EllipticCurve(p, a, b)
        fy = E.lift_x(flag)
        print(f"p = {p}")
        print(f"flag y = {fy}")
        break
    except Exception:
        continue

checked = set()
count = 0
while count < 2022:
    x = randrange(2, p)
    # Reject repeated x values and x within 2**384 of 0 or of p.
    if x in checked or x < 2**384 or abs(x - p) < 2**384:
        print(">:(")
        continue
    try:
        # Each query XORs a fresh 48-bit mask e into both a and b.
        e = randbits(48)
        print(f"e = {e}")
        E = EllipticCurve(p, a^e, b^e)
        py = E.lift_x(x)
        checked.add(x)
        print(f"x = {x}")
        print(f"y = {py}")
        count += 1
    except Exception:
        # Singular masked curve or x not liftable on it.
        print(":(")
    more = input("more> ")
    if more.strip() == "no":
        break
print("bye!")

Solution

I noticed that the challenge e2D1p, which I wrote for N1CTF, is based on the same idea as this one.

First, we are given $p$ and $y_0$ (printed as `flag y`). They satisfy $y_0^2 \equiv m^3 + a m + b \pmod p$, where $m$ is the flag encoded as an integer.

Then we need to recover $a$ and $b$ using at most 2022 queries.

Every successful query reveals $x, y, e$ satisfying $y^2 \equiv x^3 + (a \oplus e)\,x + (b \oplus e) \pmod p$.

Let $a_l$ and $b_l$ denote $a$ and $b$ with their low 48 bits cleared, and let $a_i, b_i$ be their $i$-th bits. Then the equation can be rewritten as $$ y^2 \equiv x^3 + a_l x + \sum_{i=0}^{47} 2^i (a_i \oplus e_i)\, x + b_l + \sum_{i=0}^{47} 2^i (b_i \oplus e_i) \pmod p $$ For single bits $u$ and $v$, $$ u \oplus v = \begin{cases} u & \text{if } v = 0 \\ 1 - u & \text{if } v = 1 \end{cases} $$ Every $e_i$ is known, so the equation is linear in the unknowns $a_l,\ b_l,\ a_0, \dots, a_{47},\ b_0, \dots, b_{47}$.

There are $2 + 2 \cdot 48 = 98$ unknowns. Collecting 98 equations therefore lets us solve the linear system modulo $p$, provided it is non-singular.

Once $a$ and $b$ are known, we only need to find the roots of $x^3 + a x + b - y_0^2$ modulo $p$; one of them is the flag. That's easy!

from pwn import *

io = process("./corruptedcurvesplus.py")

# Public data: the prime p and y0, the y-coordinate of the point whose
# x-coordinate is the flag.
io.recvuntil(b"p = ")
p = int(io.recvline())
io.recvuntil(b"flag y = ")
flag_y = int(io.recvline())

NBITS = 48                 # width of the XOR mask e used by the server
nvars = 2 + 2 * NBITS      # a_high, b_high, a_0..a_47, b_0..b_47
maxb = nvars               # one equation per successful query
para = []

bound = 0
while bound < maxb:
    io.recvuntil(b"e = ")
    e = int(io.recvline())
    sign = io.recvline()
    # The line after "e = ..." is either "x = <x>" or ":(" (lift failed).
    if b':(' not in sign:
        x = int(sign[4:])  # strip the "x = " prefix
        io.recvuntil(b"y = ")
        y = int(io.recvline())
        bound += 1
        para.append((e, x, y))
    io.sendlineafter(b"more> ", b'y')


# Row i encodes y^2 - x^3 = (a ^ e) * x + (b ^ e)  (mod p), linear in the
# unknowns.  Columns: 0 -> a_high, 1 -> b_high, 2..2+NBITS-1 -> a_j,
# 2+NBITS.. -> b_j.
mat = matrix(Zmod(p), maxb, nvars)
u = vector(Zmod(p), maxb)
for i, (e, x, y) in enumerate(para):
    r = y**2 - x**3
    mat[i, 0] = x
    mat[i, 1] = 1
    for j in range(NBITS):
        w = 2**j
        if (e >> j) & 1:
            # Flipped bit: w*(1 - a_j)*x + w*(1 - b_j); constants move to r.
            r -= w * (x + 1)
            mat[i, j + 2] = -x * w % p
            mat[i, j + 2 + NBITS] = -w % p
        else:
            mat[i, j + 2] = x * w % p
            mat[i, j + 2 + NBITS] = w % p
    u[i] = r % p

io.close()
# A unique solution needs a non-singular system; presumably true for random
# queries -- rerun the script if this assertion fails.
assert mat.det() != 0
sol = mat.solve_right(u)
a = sol[0] + sum(sol[j + 2] * 2**j for j in range(NBITS))
b = sol[1] + sum(sol[j + 2 + NBITS] * 2**j for j in range(NBITS))

# The flag m is a root of x^3 + a*x + b - y0^2 over GF(p).
PR.<x> = PolynomialRing(GF(p))
F = x^3 + a*x + b - flag_y^2
for root, _mult in F.roots():
    m = int(root)
    # int.to_bytes instead of bytes.fromhex(hex(...)): an odd-length hex
    # string makes fromhex raise, aborting before every root is printed.
    print(m.to_bytes((m.bit_length() + 7) // 8, 'big'))

Afterwards, I found another way to solve the problem using LLL that needs only 2 queries (see joseph's writeup).