ångstromCTF 2024 WriteUps
This time I solved some problems in Cystick. Since it coincided with the AIS3 Pre-Exam over the weekend, I only solved the three most difficult Crypto problems.
tss1
from hashlib import sha256
import fastecdsa.curve
import fastecdsa.keys
import fastecdsa.point
# Message the server refuses to co-sign; a valid aggregate signature on it
# is what releases the flag.
TARGET = b'flag'
# Group used for every key, nonce and signature in this file.
curve = fastecdsa.curve.secp256k1
def input_point():
    """Read a curve point from standard input.

    Prompts for two decimal integers (``x: `` then ``y: ``) and returns
    ``fastecdsa.point.Point(x, y, curve=curve)`` on the module-level curve.
    """
    x_coord, y_coord = (int(input(prompt)) for prompt in ('x: ', 'y: '))
    return fastecdsa.point.Point(x_coord, y_coord, curve=curve)
def input_sig():
    """Read a signature from standard input.

    Prompts for two decimal integers (``c: `` then ``s: ``) and returns them
    as the tuple ``(c, s)``.
    """
    challenge, response = (int(input(prompt)) for prompt in ('c: ', 's: '))
    return (challenge, response)
def hash_transcript(pk, R, msg):
    """Compute the challenge scalar for public key ``pk``, nonce ``R`` and ``msg``.

    SHA-256 is fed the text ``(x,y)`` of ``pk``, the text ``(x,y)`` of ``R``
    and finally the raw ``msg`` bytes. The digest is read as a big-endian
    integer and reduced modulo ``curve.q``.
    """
    transcript = b''.join((
        f'({pk.x},{pk.y})'.encode(),
        f'({R.x},{R.y})'.encode(),
        msg,
    ))
    digest = sha256(transcript).digest()
    return int.from_bytes(digest, 'big') % curve.q
def verify(pk, msg, sig):
    """Return True iff ``sig = (c, s)`` is a valid signature on ``msg`` under ``pk``.

    The nonce point is recomputed as ``s*G + c*pk``; the signature is
    accepted when hashing the transcript with that point gives back ``c``.
    """
    challenge, response = sig
    recovered_nonce = response * curve.G + challenge * pk
    expected = hash_transcript(pk, recovered_nonce, msg)
    return challenge == expected
if __name__ == '__main__':
    import sys

    # With the single argument `setup`, create the server's long-term key
    # pair and store it in key.txt (secret key, then public x and y).
    if sys.argv[1:] == ['setup']:
        sk1, pk1 = fastecdsa.keys.gen_keypair(curve)
        with open('key.txt', 'w') as f:
            f.write(f'{sk1}\n{pk1.x}\n{pk1.y}\n')
        exit()

    # Every regular run reloads the same key pair from disk.
    with open('key.txt') as f:
        sk1, x, y = (int(line) for line in f.readlines())
    pk1 = fastecdsa.point.Point(x, y, curve=curve)
    print(f'my public key: {(pk1.x, pk1.y)}')

    # Key aggregation: the peer's claimed key is simply added to ours.
    print('gimme your public key')
    pk2 = input_point()
    apk = pk1 + pk2
    print(f'aggregate public key: {(apk.x, apk.y)}')

    print('what message do you want to sign?')
    msg = bytes.fromhex(input('message: '))
    if msg == TARGET:
        # The signing oracle refuses the target message itself.
        print('anything but that')
        exit()

    # Nonce exchange: a fresh nonce per run, aggregated by addition.
    k1, R1 = fastecdsa.keys.gen_keypair(curve)
    print(f'my nonce: {(R1.x, R1.y)}')
    print('gimme your nonce')
    R2 = input_point()
    R = R1 + R2
    print(f'aggregate nonce: {(R.x, R.y)}')

    # Our share of the response for the challenge over the aggregates.
    c = hash_transcript(apk, R, msg)
    s = (k1 - c * sk1) % curve.q
    print(f'my share of the signature: {s}')

    # The flag is released for a valid signature on TARGET under apk.
    print(f'gimme me the aggregate signature for "{TARGET}"')
    sig = input_sig()
    if verify(apk, TARGET, sig):
        with open('flag.txt') as f:
            flag = f.read().strip()
        print(flag)
The problem is essentially a two-party threshold Schnorr signature scheme, where the aggregate public key is simply the sum of both parties' public keys, $apk = pk_1 + pk_2$. After exchanging keys, each party generates a nonce $R_i = k_i G$, exchanges it, and then calculates its respective share $s_i = k_i - c \cdot sk_i$ with $c = H(apk, R_1 + R_2, m)$; the shares together form part of the signature $(c, s_1 + s_2)$.
The attack method is also very simple. Since the server doesn't check whether the $pk_2$ you send is really your public key, you can generate a key pair $(ask, apk)$ yourself and then send $pk_2 = apk - pk_1$. This way, you know the secret key of the aggregate key $apk$, so you can sign anything yourself.
from pwn import process, remote, context
from fastecdsa import keys, point
import ast
from tss1 import hash_transcript, verify, curve, TARGET
# Choose the *aggregate* key pair ourselves, so its secret key is known to us.
ask, apk = keys.gen_keypair(curve)
# io = process(["python", "tss1.py"])
io = remote("challs.actf.co", 31301)
# Parse the server's public key tuple "(x, y)" from its greeting line.
io.recvuntil(b"my public key: ")
pk_tpl = ast.literal_eval(io.recvline().strip().decode())
pk1 = point.Point(*pk_tpl, curve=curve)
# Rogue key: claim pk2 = apk - pk1 so the server's pk1 + pk2 equals our apk.
pk2 = apk - pk1
io.sendlineafter(b"x: ", str(pk2.x).encode())
io.sendlineafter(b"y: ", str(pk2.y).encode())
# The co-signing round is not needed; play along with a dummy message and nonce.
io.sendlineafter(b"sign?", b"peko".hex().encode())
k2, R2 = keys.gen_keypair(curve)
io.sendlineafter(b"x: ", str(R2.x).encode())
io.sendlineafter(b"y: ", str(R2.y).encode())
def sign(sk, pk, msg):
    """Sign ``msg`` with secret key ``sk``, binding ``pk`` into the challenge.

    Draws a fresh nonce pair ``(k, R)``, computes
    ``c = hash_transcript(pk, R, msg)`` and ``s = k - c*sk (mod curve.q)``,
    and returns ``(c, s)``.
    """
    nonce, nonce_point = keys.gen_keypair(curve)
    challenge = hash_transcript(pk, nonce_point, msg)
    response = (nonce - challenge * sk) % curve.q
    return challenge, response
# Sign the target message directly with the aggregate secret key.
c, s = sign(ask, apk, TARGET)
assert verify(apk, TARGET, (c, s))  # local sanity check before submitting
io.sendlineafter(b"c: ", str(c).encode())
io.sendlineafter(b"s: ", str(s).encode())
io.interactive()
# actf{r0gu3_4ggr3g4t1on_632d50edb72d34d3}
BLÅHAJ
# RSA primes, each drawn with the bound 2**1024.
p = random_prime(2**1024)
q = random_prime(2**1024)
# Random multipliers (bound 2**1024) that blind the two hints printed at the end.
a = randint(0, 2**1024)
b = randint(0, 2**1024)
def read_flag(file='flag.txt'):
    """Return the raw bytes stored in ``file`` (``flag.txt`` by default)."""
    with open(file, mode='rb') as handle:
        return handle.read()
def pad_flag(flag, bits=1024):
    """Pad ``flag`` on the right with random bytes and return it as an integer.

    The result is the big-endian integer of ``flag`` followed by
    ``bits // 8 - len(flag)`` bytes from ``os.urandom``.
    """
    total_len = bits // 8
    filler = os.urandom(total_len - len(flag))
    padded = flag + filler
    return int.from_bytes(padded, "big")
def generate_keys(p, q):
    """Return the RSA public key ``(n, e)`` with ``n = p * q`` and ``e = 0x10001``."""
    modulus = p * q
    public_exponent = 0x10001
    return modulus, public_exponent
def encrypt_message(m, e, n):
    """Return the textbook RSA ciphertext ``m**e mod n``."""
    ciphertext = pow(m, e, n)
    return ciphertext
flag = read_flag()
m = pad_flag(flag)
n, e = generate_keys(p, q)
assert m < n  # the padded flag must be smaller than the modulus
c = encrypt_message(m, e, n)
# Published values: ciphertext, modulus, and the two hints p + b*q and a*p + q.
print(c)
print(n)
print(p + b * q)
print(a * p + q)
In short, you need to factorize $n = pq$, and the problem additionally provides $x = p + bq$ and $y = ap + q$, where $p, q, a, b$ are all 1024-bit numbers.
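My approach is as follows. First, let
$$f(t, u) = (x - t)(y - u)$$
We know $f(p, q) = (bq)(ap) = ab \cdot n \equiv 0 \pmod{n}$, so we can expand it to get:
$$f(p, q) = xy - qx - py + pq \equiv xy - qx - py \pmod{n}$$
Thus, we have:
$$xy - qx - py \equiv 0 \pmod{n}$$
Since $x, y, n$ are known and $p, q$ are small compared with $n$, we can use LLL to find $p, q$.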
from sage.all import *
from Crypto.Util.number import long_to_bytes
from lll_cvp import solve_inequality
c = 2084015642966578282323320320430355169303796428932452813616522534642993911885394832889877216337047505539910273209203092431502448659110975363836148150333450975665054754794923282649866834797382152974537871637107602980242305752842881056832372729032719323542397731147821640234181272537527934691830138813076673120558790685225545176722374373891797444430162049094331866632430714818874728168815552807267537023134887147693780578518721495129472480845244586874832711656372700340333956532029558890132512276495501200465374040059705096729628227299657939763878581814774444365099164760074348408594002376240830368408586554430584367239
n = 5054759650831149212497593612117505449996534385400799412730981223889391367155509695417999090910848750197375100341995228567935542100150279063805945642486626676563744817810946769932250256245882026085508665771131635110277852237029849869509707637709162425864373201102790718464450861787668024958787992996715870537273213999452948215564503811414448019655761914535646338206972306327692776792357656461421160012732874929753400746379737081861110621084719920237555759371086828748564360804144357313056202764375401039317055078455329940265973052825170224804217772389893623634664438855142967283765161646402518311874260437023116525027
x = 14624164038828170251441254789590748299059493407127408167381909039718004816732842597998978394090418984661609925783012574638050052448399168193536431334288702858151820090630198056959727167341057230779720998603705567821824977324339182361973824850629918120718796161913475916523630822110582289148982548632694537423158073771024615682852893559402668337273941915650824074226258103607152649283649471317242038305999409041659019944804416273274057125302350242520142492153556081164794428695882962406952831655367215074021698796576010502956218034153621531037983312180680329891048894826768653762528273739711501728033951545847806707848
y = 1510897008373983998701686017209922960816127339466860789588606160332147878962564913406764611385229470849971288077374239278171471973749602656414838820558074700119192355231338991849403695036849047083282822255108161230346034624768585764808760248402952634113444599722515269998684427399124583405247144573527109975222081257008029790260559860849979421716752798127739488994858649526260912935259888277927955180906055358279957687314992052797112774530972530343550573798533793482820201610887225201749606682312677680359451371010180070950980849389904579670739506805226809592052548210734843148295456283103192907412693805696885095055
# f(t,u)=(x-t)(y-u), where x = p + b*q and y = a*p + q are the two hints
# f(p,q)=(b*q)(a*p)=0 (mod n)
# f(t,u)=xy-ux-ty+tu=xy-ux-ty (mod n) at (t,u)=(p,q), because tu = pq = n
# the unknowns p, q are bounded by 2^1024, small compared with n -> LLL
# The row combination (k, 1, p, q) of L is the vector (k*n + x*y - p*y - q*x, 1, p, q).
L = matrix([[n, 0, 0, 0], [x * y, 1, 0, 0], [-y, 0, 1, 0], [-x, 0, 0, 1]])
# Per-coordinate bounds: first entry exactly 0, second exactly 1, p and q in [0, 2**1024].
# NOTE(review): presumably solve_inequality returns a lattice vector of L within [lb, ub] — confirm against lll_cvp.
lb = [0, 1, 0, 0]
ub = [0, 1, 2**1024, 2**1024]
sol = solve_inequality(L, lb, ub)
_, _, p, q = map(int, sol)
assert p * q == n
# Standard RSA decryption with the recovered factors.
phi = (p - 1) * (q - 1)
d = pow(0x10001, -1, phi)
m = pow(c, d, n)
print(long_to_bytes(m))
The author’s intended solution is slightly different, but the final result is the same as mine:
tss2
#!/usr/local/bin/python
from hashlib import sha256
import fastecdsa.curve
import fastecdsa.keys
import fastecdsa.point
# Message the server refuses to co-sign; a valid aggregate signature on it
# is what releases the flag.
TARGET = b'flag'
# Group used for every key, nonce and signature in this file.
curve = fastecdsa.curve.secp256k1
def input_point():
    """Read a curve point from standard input.

    Prompts for two decimal integers (``x: `` then ``y: ``) and returns
    ``fastecdsa.point.Point(x, y, curve=curve)`` on the module-level curve.
    """
    x_coord, y_coord = (int(input(prompt)) for prompt in ('x: ', 'y: '))
    return fastecdsa.point.Point(x_coord, y_coord, curve=curve)
def input_sig():
    """Read a signature from standard input.

    Prompts for two decimal integers (``c: `` then ``s: ``) and returns them
    as the tuple ``(c, s)``.
    """
    challenge, response = (int(input(prompt)) for prompt in ('c: ', 's: '))
    return (challenge, response)
def hash_transcript(pk, R, msg):
    """Compute the challenge scalar for public key ``pk``, nonce ``R`` and ``msg``.

    SHA-256 is fed the text ``(x,y)`` of ``pk``, the text ``(x,y)`` of ``R``
    and finally the raw ``msg`` bytes. The digest is read as a big-endian
    integer and reduced modulo ``curve.q``.
    """
    transcript = b''.join((
        f'({pk.x},{pk.y})'.encode(),
        f'({R.x},{R.y})'.encode(),
        msg,
    ))
    digest = sha256(transcript).digest()
    return int.from_bytes(digest, 'big') % curve.q
def verify(pk, msg, sig):
    """Return True iff ``sig = (c, s)`` is a valid signature on ``msg`` under ``pk``.

    The nonce point is recomputed as ``s*G + c*pk``; the signature is
    accepted when hashing the transcript with that point gives back ``c``.
    """
    challenge, response = sig
    recovered_nonce = response * curve.G + challenge * pk
    expected = hash_transcript(pk, recovered_nonce, msg)
    return challenge == expected
if __name__ == '__main__':
    import sys

    # With the single argument `setup`, create the server's long-term key
    # pair and store it in key.txt (secret key, then public x and y).
    if sys.argv[1:] == ['setup']:
        sk1, pk1 = fastecdsa.keys.gen_keypair(curve)
        with open('key.txt', 'w') as f:
            f.write(f'{sk1}\n{pk1.x}\n{pk1.y}\n')
        exit()

    # Every regular run reloads the same key pair from disk.
    with open('key.txt') as f:
        sk1, x, y = (int(line) for line in f.readlines())
    pk1 = fastecdsa.point.Point(x, y, curve=curve)
    print(f'my public key: {(pk1.x, pk1.y)}')

    # The peer must back its key with a signature on b'foo' made under it.
    print('gimme your public key')
    pk2 = input_point()
    print('prove it!')
    sig = input_sig()
    if not verify(pk2, b'foo', sig):
        print('boo')
        exit()

    # Key aggregation: the accepted key is added to ours.
    apk = pk1 + pk2
    print(f'aggregate public key: {(apk.x, apk.y)}')

    print('what message do you want to sign?')
    msg = bytes.fromhex(input('message: '))
    if msg == TARGET:
        # The signing oracle refuses the target message itself.
        print('anything but that')
        exit()

    # Nonce exchange: a fresh nonce per run, aggregated by addition.
    k1, R1 = fastecdsa.keys.gen_keypair(curve)
    print(f'my nonce: {(R1.x, R1.y)}')
    print('gimme your nonce')
    R2 = input_point()
    R = R1 + R2
    print(f'aggregate nonce: {(R.x, R.y)}')

    # Our share of the response for the challenge over the aggregates.
    c = hash_transcript(apk, R, msg)
    s = (k1 - c * sk1) % curve.q
    print(f'my share of the signature: {s}')

    # The flag is released for a valid signature on TARGET under apk.
    print(f'gimme me the aggregate signature for "{TARGET}"')
    sig = input_sig()
    if verify(apk, TARGET, sig):
        with open('flag.txt') as f:
            flag = f.read().strip()
        print(flag)
Similar to tss1, but this time, after accepting $pk_2$, it checks that you have a signature of the string `foo` under $pk_2$, as proof that you really know the secret key of $pk_2$. Therefore, the previous rogue public key attack doesn't work here.
My initial thought was that since it requires a signature of `foo`, I could use its signing oracle to sign `foo` under the aggregate key, and then use the obtained $apk$ (with that signature as proof) as $pk_2$ for the next connection.
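Thinking this through, you'll find that the form of the public key you get after $i$ rounds is $apk_i = i \cdot pk_1 + pk_2$, so after repeating this $q$ times (where $q$ is the group order), the $pk_1$ term will disappear because $q \cdot pk_1 = O$, meaning we know the secret key of $apk$. However, this attack is impractical because it requires about $q \approx 2^{256}$ connections, making it faster to brute-force the secret key directly.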
In any case, I still wrote a script to demonstrate the feasibility of making the server produce $apk_i = i \cdot pk_1 + pk_2$:
from pwn import process, remote, context
from fastecdsa import keys, point
import ast
from tss2 import hash_transcript, verify, curve, TARGET
# Our own honest key pair: sk2 is known, so the `foo` proof for pk2 is genuine.
sk2, pk2 = keys.gen_keypair(curve)
# context.log_level = "debug"
def sign(sk, pk, msg):
    """Sign ``msg`` with secret key ``sk``, binding ``pk`` into the challenge.

    Draws a fresh nonce pair ``(k, R)``, computes
    ``c = hash_transcript(pk, R, msg)`` and ``s = k - c*sk (mod curve.q)``,
    and returns ``(c, s)``.
    """
    nonce, nonce_point = keys.gen_keypair(curve)
    challenge = hash_transcript(pk, nonce_point, msg)
    response = (nonce - challenge * sk) % curve.q
    return challenge, response
# Round 1: register our honest key pk2 and co-sign b"foo" under apk = pk1 + pk2.
io = process(["python", "tss2.py"])
io.recvuntil(b"my public key: ")
pk_tpl = ast.literal_eval(io.recvline().strip().decode())
pk1 = point.Point(*pk_tpl, curve=curve)
io.sendlineafter(b"x: ", str(pk2.x).encode())
io.sendlineafter(b"y: ", str(pk2.y).encode())
# Proof requested by the server: a signature on b"foo" under pk2.
c, s = sign(sk2, pk2, b"foo")
io.sendlineafter(b"c: ", str(c).encode())
io.sendlineafter(b"s: ", str(s).encode())
apk = pk1 + pk2
# Ask the server to co-sign b"foo"; only b"flag" is refused.
msg = b"foo"
io.sendlineafter(b"sign?", msg.hex().encode())
io.recvuntil(b"my nonce: ")
R1_tpl = ast.literal_eval(io.recvline().strip().decode())
R1 = point.Point(*R1_tpl, curve=curve)
k2, R2 = keys.gen_keypair(curve)
io.sendlineafter(b"x: ", str(R2.x).encode())
io.sendlineafter(b"y: ", str(R2.y).encode())
io.recvuntil(b"my share of the signature: ")
s1 = int(io.recvline().strip().decode())
# Combine the server's share with ours into a full signature under apk.
R = R1 + R2
c = hash_transcript(apk, R, msg)
s2 = (k2 - c * sk2) % curve.q
s = (s1 + s2) % curve.q
assert verify(apk, msg, (c, s))
io.close()
def cont(i, apk, c, s):
    """Register the previous aggregate key as ours and co-sign ``foo`` again.

    ``(c, s)`` is sent as the proof signature on ``b"foo"`` for ``apk``.
    Returns ``(apk2, c, s)`` where ``apk2 = pk1 + apk`` and the new
    ``(c, s)`` is a signature on ``b"foo"`` under ``apk2`` (asserted
    locally), ready to be passed to the next round.

    ``i`` is the round index: the nonce sent is ``(i - 1)`` times the
    server's nonce. The module-level ``sk2`` is used for our share.
    """
    conn = process(["python", "tss2.py"])
    conn.recvuntil(b"my public key: ")
    server_xy = ast.literal_eval(conn.recvline().strip().decode())
    server_pk = point.Point(*server_xy, curve=curve)

    # Present the old aggregate key together with its signature on "foo".
    for prompt, value in (
        (b"x: ", apk.x),
        (b"y: ", apk.y),
        (b"c: ", c),
        (b"s: ", s),
    ):
        conn.sendlineafter(prompt, str(value).encode())
    next_apk = server_pk + apk  # secret key: sk1 + ((i - 1) * sk1 + sk2)

    msg = b"foo"
    conn.sendlineafter(b"sign?", msg.hex().encode())
    conn.recvuntil(b"my nonce: ")
    nonce_xy = ast.literal_eval(conn.recvline().strip().decode())
    server_nonce = point.Point(*nonce_xy, curve=curve)

    # Our nonce is a known multiple of the server's nonce; the derivation
    # of our share below depends on that choice.
    our_nonce = (i - 1) * server_nonce
    conn.sendlineafter(b"x: ", str(our_nonce.x).encode())
    conn.sendlineafter(b"y: ", str(our_nonce.y).encode())
    conn.recvuntil(b"my share of the signature: ")
    s1 = int(conn.recvline().strip().decode())

    challenge = hash_transcript(next_apk, server_nonce + our_nonce, msg)
    # Server share:  s1 = k1 - c * sk1
    # Needed share:  s2 = k2 - c * ((i - 1) * sk1 + sk2)
    # With k2 = (i - 1) * k1 this becomes s2 = (i - 1) * s1 - c * sk2.
    s2 = ((i - 1) * s1 - challenge * sk2) % curve.q
    s_total = (s1 + s2) % curve.q
    assert verify(next_apk, msg, (challenge, s_total))
    conn.close()
    return next_apk, challenge, s_total
# Server secret key, hard-coded so the claimed key structure can be checked.
# NOTE(review): presumably copied from the local key.txt — confirm.
sk1 = 76955832679704021796451678563974610015930082877934961757640258812995135898232
assert (sk1 + sk2) * curve.G == apk

# apks[i] is the aggregate key after i rounds, i.e. (i * sk1 + sk2) * G.
apks = [None, apk]
for i in range(2, 5):
    print(i)
    apk, c, s = cont(i, apk, c, s)
    assert (i * sk1 + sk2) * curve.G == apk
    apks.append(apk)
After spending quite some time, I realized that since $sk_1$ is static in the problem, repeated connections use the same key. I also remembered reading somewhere that some MPC protocols are insecure under concurrent conditions.
So I searched and found "On the (in)security of ROS". It shows that the security of many such protocols under concurrent sessions reduces to a problem called ROS, and the paper proposes an attack that solves ROS in polynomial time, which means many similar protocols are insecure under concurrent conditions.
The ROS attack generally involves receiving the opponent's commitments $R_i$ over $\ell$ concurrent sessions, then calculating some magical coefficients. By processing the target you want to forge, you can get a linear combination that produces some special challenges $c_i$ (the hash value in the Fiat–Shamir transform). Sending the $c_i$ to all connections at once gets you the responses $s_i$, and performing the linear combination on them forges the signature.
In this threshold signature case, $R_1$ is the opponent's nonce, and $c$ is `hash_transcript(apk, R, msg)`, so we can use a similar method to find an appropriate $R_2$ in each session to make $c$ the value we want, and then forge the signature.
However, the mathematical details of this attack are a bit complex, and I didn’t fully understand them. So I just used the sage code provided in the paper’s appendix, combined with examples from another article on the same paper, made some random modifications, and got it to work. I might need to spend more time studying it in the future.
from sage.all import GF
from pwn import process, remote
from fastecdsa import keys, point
import ast, random
from tss2 import hash_transcript, verify, curve, TARGET
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Our own honest key pair, reused for every concurrent session.
sk2, pk2 = keys.gen_keypair(curve)
def connect(i):
    """Open one connection to the remote tss2 service and return the tube.

    ``i`` is not used in the body.
    """
    # For a local run, use process(["python", "tss2.py"]) instead.
    return remote("challs.actf.co", 31302)
def sign(sk, pk, msg):
    """Sign ``msg`` with secret key ``sk``, binding ``pk`` into the challenge.

    Draws a fresh nonce pair ``(k, R)``, computes
    ``c = hash_transcript(pk, R, msg)`` and ``s = k - c*sk (mod curve.q)``,
    and returns ``(c, s)``.
    """
    nonce, nonce_point = keys.gen_keypair(curve)
    challenge = hash_transcript(pk, nonce_point, msg)
    response = (nonce - challenge * sk) % curve.q
    return challenge, response
def establish(io, sk2, pk2):
    """Run the key-registration phase on ``io`` and return the aggregate key.

    Parses the server's public key, sends ``pk2`` followed by a signature on
    ``b"foo"`` made with ``sk2``, and returns ``pk1 + pk2``.
    """
    io.recvuntil(b"my public key: ")
    server_xy = ast.literal_eval(io.recvline().strip().decode())
    server_pk = point.Point(*server_xy, curve=curve)

    for prompt, coord in ((b"x: ", pk2.x), (b"y: ", pk2.y)):
        io.sendlineafter(prompt, str(coord).encode())

    # Signature on "foo" under pk2, sent in response to the c/s prompts.
    proof = sign(sk2, pk2, b"foo")
    for prompt, value in zip((b"c: ", b"s: "), proof):
        io.sendlineafter(prompt, str(value).encode())

    return server_pk + pk2
def do_sign_get_nonce_R1(io, msg):
    """Request a co-signature on ``msg`` and return the server's nonce point.

    Sends the hex-encoded message and parses the ``(x, y)`` tuple that
    follows ``my nonce: ``.
    """
    io.sendlineafter(b"sign?", msg.hex().encode())
    io.recvuntil(b"my nonce: ")
    nonce_xy = ast.literal_eval(io.recvline().strip().decode())
    return point.Point(*nonce_xy, curve=curve)
def send_nonce_R2_get_s(io, R2):
    """Send our nonce point ``R2`` and return the server's signature share.

    The share is parsed as an integer from the line that follows
    ``my share of the signature: ``.
    """
    for prompt, coord in ((b"x: ", R2.x), (b"y: ", R2.y)):
        io.sendlineafter(prompt, str(coord).encode())
    io.recvuntil(b"my share of the signature: ")
    share_line = io.recvline().strip().decode()
    return int(share_line)
Zp = GF(curve.q)
ell = 256  # number of concurrent sessions; 256 bits are extracted below
messages = [f"message{i}".encode() for i in range(ell)]

# Open every session and drive each one up to the point where the server has
# revealed its nonce R1 and is waiting for ours. The same three steps could
# be run sequentially over the sessions.
with ThreadPoolExecutor(max_workers=32) as executor:
    ios = list(tqdm(executor.map(connect, range(ell)), desc="Init", total=ell))
    apks = list(
        tqdm(
            executor.map(establish, ios, [sk2] * ell, [pk2] * ell),
            desc="Establish",
            total=ell,
        )
    )
    apk = apks[0]
    R1 = list(
        tqdm(
            executor.map(do_sign_get_nonce_R1, ios, messages),
            desc="Get nonce",
            total=ell,
        )
    )
print("done, forging signature...")

# idk what am I doing here, but blindly modifying the code from following sources works...
# https://eprint.iacr.org/2020/945.pdf (Appendix)
# https://qsang.xin/2023/05/23/Onthe-in-securityofROS/
# it appears to find some magic coefficients that can forge signatures with
# a linear combination of existing signatures...

# Two candidate nonces per session, and the challenge each would produce.
k2 = [[random.randrange(curve.q), random.randrange(curve.q)] for i in range(ell)]
R2 = [[x * curve.G, y * curve.G] for x, y in k2]
c = [
    [hash_transcript(apk, R1[i] + R2[i][b], messages[i]) for b in range(2)]
    for i in range(ell)
]


def g_func(x, z=0):
    # Linear form with coefficients 2^i / (c[i][1] - c[i][0]) computed in Zp,
    # applied to the sequence x and accumulated starting from z.
    return sum(
        [int(Zp(2) ** i / (c[i][1] - c[i][0])) * x[i] for i in range(ell)], z
    )


forged_R = g_func(R1, z=point.Point.IDENTITY_ELEMENT)
forged_message = b"flag"
forged_c = hash_transcript(apk, forged_R, forged_message)

# Bit i (least significant first) selects which candidate nonce session i gets.
selector = (forged_c - g_func([c[i][0] for i in range(ell)])) % curve.q
bits = [int(b) for b in bin(selector)[2:].rjust(256, "0")][::-1]
chosen_R2 = [R2[i][b] for (i, b) in enumerate(bits)]

with ThreadPoolExecutor(max_workers=32) as executor:
    s1 = list(
        tqdm(
            executor.map(send_nonce_R2_get_s, ios, chosen_R2),
            desc="Get signature share",
            total=ell,
        )
    )

# Per-session signatures, built only so they can be checked below.
signatures = [
    (int(c[i][bits[i]]), int(s1[i] + k2[i][bits[i]] - c[i][bits[i]] * sk2))
    for i in range(ell)
]
forged_signature = (
    forged_c,
    (g_func(s1) - forged_c * sk2) % curve.q,
)
print([verify(apk, messages[i], signatures[i]) for i in range(ell)])
print(verify(apk, forged_message, forged_signature))

# Each session is now at its final c/s prompt; submit the forgery on the first.
c, s = forged_signature
io = ios[0]
io.sendlineafter(b"c: ", str(c).encode())
io.sendlineafter(b"s: ", str(s).encode())
io.interactive()
# actf{th1nk_0uts1de_th3_c0nn3ct1on_d953f18e8c0870e8}
# actf{th1nk_0uts1de_th3_c0nn3ct1on_d953f18e8c0870e8}