AlpacaHack Round 3 WriteUps

發表於
分類於 CTF

因為最近有段時間都沒怎麼打 CTF,也因此沒在這邊發文,所以這周就花了一小段時間 Solo 參加了這個只有 4 題 Crypto 題目的 AlpacaHack Round 3

qrime

import os
from Crypto.Util.number import bytes_to_long, getRandomNBitInteger, isPrime

def nextPrime(n):
    while not isPrime(n := n + 1):
        continue
    return n

def gen():
    while True:
        q = getRandomNBitInteger(256)
        r = getRandomNBitInteger(256)
        p = q * nextPrime(r) + nextPrime(q) * r
        if isPrime(p) and isPrime(q):
            return p, q, r

flag = os.environ.get("FLAG", "fakeflag").encode()
m = bytes_to_long(flag)

p, q, r = gen()
n = p * q

phi = (p - 1) * (q - 1)
e = 0x10001
d = pow(e, -1, phi)
c = pow(m, e, n)

print(f"{n=}")
print(f"{e=}")
print(f"{c=}")
print(f"{r=}")

簡單來說這題的 RSA 的符合:

n=pqp=q(r+d1)+(q+d2)r\begin{aligned} n &= p \cdot q \\ p &= q(r+d_1)+(q+d_2)r \end{aligned}

其中 d1,d2d_1, d_2 是兩個很小,可爆搜範圍的數字。然後題目有提供 n,rn, r 要我們分解 nn

我的做法是直接考慮 modr\mod{r},得到 pqd1(modr)p \equiv qd_1 \pmod{r},也就有 nq2d1(modr)n \equiv q^2 d_1 \pmod{r}

假設 gcd(d1,r)=1\gcd(d_1,r)=1 那麼代表 nd11n d_1^{-1}Zr\mathbb{Z}_r 下是個 quadratic residue,所以可以減少不少錯誤的 candidate。如果 gcd(d1,r)1\gcd(d_1,r) \neq 1 那就取 r=r/gcd(d1,r)r'=r/\gcd(d_1,r) 去判斷 QR 就好。

之後對每個 d1d_1 的 candidate 我們都可以嘗試求 qq,但這個的數量和 rr 的 prime factors 數量相比是指數成長的。因此我這邊只考慮 rr 的最大的 prime factor rr',求 q±q(modr)q' \equiv \pm q \pmod{r'}

因為這題 qq 只有 256,rr' 有 186 bits 超過一半,所以可以直接套 coppersmith 求解。

proof.all(False)

n = 200697881793620389197751143658858424075492240536004468937396825699483210280999214674828938407830171522000573896259413231953182108686782019862906633259090814783111593304404356927145683840948437835426703183742322171552269964159917779
e = 65537
c = 77163248552037496974551155836778067107086838375316358094323022740486805320709021643760063197513767812819672431278113945221579920669369599456818771428377647053211504958874209832487794913919451387978942636428157218259130156026601708
r = 30736331670163278077316573297195977299089049174626053101058657011068283335270

rp = factor(r)[-1][0]
F = GF(rp)

for guess_d1 in range(1, 1000):
    t = Zmod(r // gcd(r, guess_d1))(n) / guess_d1
    if not t.is_square():
        continue
    print(f"{guess_d1 = }")
    qp = ZZ(F(t).sqrt())

    x = polygen(Zmod(n))
    f = qp + x * rp
    rs = f.monic().small_roots(X=2**256 // rp, beta=0.33, epsilon=0.02)
    if rs:
        g = gcd(ZZ(f(rs[0])), n)
        if g != 1 and g != n:
            print(g)
            q = g
            # q = 57138703210086603216917938147752779170509477993762976004506899310197198907231
            p = n // q
            assert p * q == n
            phi = (p - 1) * (q - 1)
            d = inverse_mod(e, phi)
            m = pow(c, d, n)
            flag = int(m).to_bytes(100, "big").strip(b"\x00")
            print(flag)
            break
# Alpaca{q_and_r_have_nothing_to_do_with_QR_code}

不過後來發現說我這個解法其實很 overkill,因為其實可以注意到:

n=pq=q2(r+d1)+q2r+qd2rnrq2+q2+qd2\begin{aligned} n &= pq=q^2(r+d_1)+q^2r+qd_2r \\ \frac{n}{r} &\approx q^2+q^2+qd_2 \end{aligned}

所以有

n2rq\lfloor \sqrt{\frac{n}{2r}} \rfloor \approx q

而且這個接近是非常的接近,它們的差值和 d1,d2d_1, d_2 是同個量級的,所以直接爆即可:

n = 200697881793620389197751143658858424075492240536004468937396825699483210280999214674828938407830171522000573896259413231953182108686782019862906633259090814783111593304404356927145683840948437835426703183742322171552269964159917779
e = 65537
c = 77163248552037496974551155836778067107086838375316358094323022740486805320709021643760063197513767812819672431278113945221579920669369599456818771428377647053211504958874209832487794913919451387978942636428157218259130156026601708
r = 30736331670163278077316573297195977299089049174626053101058657011068283335270

q = (n // r // 2).isqrt()
while n % q:
    q -= 1
p = n // q
assert p * q == n
phi = (p - 1) * (q - 1)
d = inverse_mod(e, phi)
m = pow(c, d, n)
flag = int(m).to_bytes(100, "big").strip(b"\x00")
print(flag)

Rainbow Sweet Alchemist

import os
import random
from math import prod
from Crypto.Util.number import isPrime, bytes_to_long

r = random.Random(0)
def deterministicGetPrime():
  while True:
    if isPrime(p := r.getrandbits(64)):
      return p

# This is not likely to fail
assert deterministicGetPrime() == 2710959347947821323, "Your Python's random module is not compatible with this challenge."

def getPrime(bit):
  factors = [deterministicGetPrime() for _ in range(bit // 64)]
  while True:
    p = 2 * prod(factors) + 1
    if isPrime(p):
      return p
    factors.remove(random.choice(factors))
    factors.append(deterministicGetPrime())

flag = os.environ.get("FLAG", "fakeflag").encode()
m = bytes_to_long(flag)

p, q = getPrime(1024), getPrime(1024)
n = p * q
e = 0x10001
c = pow(m, e, n)

print(f"{n=}")
print(f"{e=}")
print(f"{c=}")

也是 RSA,不過這題 p,qp, q 是用很多 64 bits 的 deterministic prime rir_i 加上額外的 random 來生成的。

因為它的 p,qp,q 都符合 2ri+12 \cdot \prod r_i + 1,理論上可以用 pollard p-1 去 enumerate 所有 2642^{64} 以內的質數來分解。不過這實際上需要的計算量非常大所以不可行。

不過這題它的 rir_i 都是 deterministic 生成的,其實只要拿它生成的數字去做 pollard p-1 就可以了。

import os
import random
from math import prod
from Crypto.Util.number import isPrime, bytes_to_long
import gmpy2

r = random.Random(0)


def deterministicGetPrime():
    while True:
        if isPrime(p := r.getrandbits(64)):
            return p


n = 2350478429681099659482802009772446082018100644248516135321613920512468478639125995627622723613436514363575959981129347545346377683616601997652559989194209421585293503204692287227768734043407645110784759572198774750930099526115866644410725881688186477790001107094553659510391748347376557636648685171853839010603373478663706118665850493342775539671166315233110564897483927720435690486237018231160348429442602322737086330061842505643074752650924036094256703773247700173034557490511259257339056944624783261440335003074769966389878838392473674878449536592166047002406250295311924149998650337286245273761909
e = 65537
c = 945455686374900611982512983855180418093086799652768743864445887891673833536194784436479986018226808021869459762652060495495939514186099959619150594580806928854502608487090614914226527710432592362185466014910082946747720345943963459584430804168801787831721882743415735573097846726969566369857274720210999142004037914646773788750511310948953348263288281876918925575402242949315439533982980005949680451780931608479641161670505447003036276496409290185385863265908516453044673078999800497412772426465138742141279302235558029258772175141248590241406152365769987248447302410223052788101550323890531305166459


x = 48763**2
while True:
    x = gmpy2.powmod(x, deterministicGetPrime(), n)
    p = gmpy2.gcd(x - 1, n)
    if p != 1 and p != n:
        print(p)
        break
q = n // p
assert p * q == n
phi = (p - 1) * (q - 1)
d = gmpy2.invert(e, phi)
m = gmpy2.powmod(c, d, n)
flag = int(m).to_bytes(100, "big").strip(b"\x00")
print(flag)
# Alpaca{n0t_s0_sm00th_y3t_n0t_s0_s4f3}

A Dance of Add and Mul

import os
import random
from Crypto.Util.number import bytes_to_long

flag = os.environ.get("FLAG", "fakeflag").encode()
bit_length = len(flag) * 8

# BLS12-381 curve
p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab
K = GF(p)
E = EllipticCurve(K, (0, 4))

G1, G2 = E.gens()
o1, o2 = G1.order(), G2.order()

xs = [random.randrange(0, o1) for _ in range(bit_length + 1)]
m = bytes_to_long(flag)

cs = []
for c, (x1, x2) in zip(bin(m)[2:].zfill(bit_length), zip(xs[:-1], xs[1:])):
  if c == "1":
    x1, x2 = x2, x1
  cs.append(x1 * G1 + x2 * G2)

print([P.xy() for P in cs])

題目本身是生成一堆 xix_i,然後對於第 ii 個 bit 拿

(x1,x2)={(xi,xi+1)if bi=0(xi+1,xi)if bi=1(x'_1, x'_2)=\begin{cases} (x_i,x_{i+1}) & \text{if } b_i=0 \\ (x_{i+1},x_i) & \text{if } b_i=1 \end{cases}

去計算 Pi=x1G1+x2G2P_i=x'_1 G_1 + x'_2 G_2。目標是透過 PiP_i 求回原本的 bits。

因為這題用的是 BLS12-381,一個 pairing-friendly curve,所以可以很輕易的計算 pairing。

這邊會需要用到三個等式,這都可以用 pairing 的性質推導出來:

e(G1,G2)=e(G2,G1)1e(x1G1+x2G2,G1)=e(G2,G1)x2e(x1G1+x2G2,G2)=e(G1,G2)x1\begin{aligned} e(G_1,G_2) &= e(G_2,G_1)^{-1} \\ e(x_1 G_1 + x_2 G_2, G_1) &= e(G_2, G_1)^{x_2} \\ e(x_1 G_1 + x_2 G_2, G_2) &= e(G_1, G_2)^{x_1} \end{aligned}

因此取 g=e(G2,G1)g=e(G_2,G_1) 的話對於每個 PiP_i 我們可以計算:

w1=e(Pi,G1)=gx2w2=e(Pi,G2)1=gx1\begin{aligned} w_1 &= e(P_i, G_1) = g^{x'_2} \\ w_2 &= e(P_i, G_2)^{-1} = g^{x'_1} \end{aligned}

然後可以透過判斷每個 PiP_i 對應到的 w1,w2w_1, w_2 有沒有被 swap 過就能知道原本的 bits。

import ast
from binteger import Bin

p = 0x1A0111EA397FE69A4B1BA7B6434BACD764774B84F38512BF6730D2A0F6B0F6241EABFFFEB153FFFFB9FEFFFFFFFFAAAB
K = GF(p)
E = EllipticCurve(K, (0, 4))

G1, G2 = E.gens()
o1, o2 = G1.order(), G2.order()

with open("chall.txt") as f:
    cs = [E(*xy) for xy in ast.literal_eval(f.read())]

# e(G1,G2)=e(G2,G1)^-1

pairs = []
for P in cs:
    w1 = P.weil_pairing(G1, o1)  # e(G2,G1)^x2
    w2i = P.weil_pairing(G2, o2)  # e(G1,G2)^x1
    # print(w1, w2i ^ -1)
    pairs.append((w1, w2i ^ -1))

bs = [0]
w1_prev, w2_prev = pairs[0]
for w1, w2 in pairs[1:]:
    if w2 == w1_prev:
        bs.append(0)
        w1_prev, w2_prev = w1, w2
    else:
        bs.append(1)
        w1_prev, w2_prev = w2, w1
flag = Bin(bs).bytes
print(flag)
# Alpaca{this_title_is_inpired_by_a_rhythm_game}

Hat Trick

import json
import os
import random
import signal
import string
from Crypto.Util.number import getPrime, getRandomInteger

class RollingHash:
  def __init__(self, p=None, base=None) -> None:
    self.p = getPrime(64) if p is None else p
    self.base = (getRandomInteger(64) if base is None else base) % self.p
  def hash(self, s: str):
    res = 0
    for i, c in enumerate(s):
      res += ord(c) * (self.base ** i)
      res %= self.p
    return res

def check_str(s: str, max_len: int):
  assert len(s) <= max_len, "too long!"
  for i, c in enumerate(s):
    assert c in string.ascii_lowercase, f"found invalid char {c} at {i}"

signal.alarm(3 * 60)

flag = os.environ.get("FLAG", "fakeflag")
MAX_LEN = 54

rhs = [RollingHash() for _ in range(3)]
print("params:",json.dumps([{ "base": rh.base, "p": rh.p } for rh in rhs]))

for _ in range(3):
  target_hash = [random.randrange(0, rh.p) for rh in rhs]
  print('target:', target_hash)
  s = input("> ")
  check_str(s, MAX_LEN)

  actual_hash = [rh.hash(s) for rh in rhs]
  if target_hash != actual_hash:
    print("Oops! You missed the target hash. Better luck next time!")
    exit(1)

print("Congratulations! Here is your flag:", flag)

簡單來說這題是要同時對三個 rolling hash function 找的三個 target 找個共同的 preimage。而 rolling hash 本身對於一個訊息 mm 是定義為:

H(m)=i=0n1mibimodpH(m) = \sum_{i=0}^{n-1} m_i \cdot b^i \mod p

其中 mim_i 是 ascii code。而這題限制 mm 不能超過 54 字元,且只能是小寫字母。

因為顯然是個很 linear 的東西,直接列出三條等式可以發現它可以寫成 lattice,然後用 CVP 求解。不過這題它題目壓的比較緊一點,我用預設的 LLL/flatter reduced basis 求出來的東西要連續成功 3 個 round 有點困難,所以就換成 BKZ 求解就成功了。

from sage.all import *
from Crypto.Util.number import getPrime, getRandomInteger
from lll_cvp import solve_inequality, kannan_cvp, BKZ
from functools import partial
from pwn import process, remote, context
import json, ast


class RollingHash:
    def __init__(self, p=None, base=None) -> None:
        self.p = getPrime(64) if p is None else p
        self.base = (getRandomInteger(64) if base is None else base) % self.p

    def hash(self, s: str):
        res = 0
        for i, c in enumerate(s):
            res += ord(c) * (self.base**i)
            res %= self.p
        return res


def solve(params, target_hash, n=54):
    m = len(target_hash)
    assert m == len(params), "???"
    L = matrix(ZZ, n + m, n + m)
    for i in range(m):
        L[i, i] = params[i]["p"]
    for i in range(m, n + m):
        for j in range(m):
            L[i, j] = pow(params[j]["base"], i - m, params[j]["p"])
        L[i, i] = 1
    lb = target_hash + [97] * n
    ub = target_hash + [122] * n
    res = solve_inequality(L, lb, ub, cvp=partial(kannan_cvp, reduction=BKZ))
    s = bytes(res[m:]).decode()
    rhs = [RollingHash(params[i]["p"], params[i]["base"]) for i in range(m)]
    for i in range(m):
        assert rhs[i].hash(s) == target_hash[i]
    return s


# context.log_level = "debug"

io = process(["python", "server.py"])
# io = remote("34.170.146.252", 53764)
io.recvuntil(b"params: ")
params = json.loads(io.recvline())

for _ in range(3):
    io.recvuntil(b"target: ")
    target_hash = ast.literal_eval(io.recvlineS())
    print(target_hash)
    s = solve(params, target_hash)
    print(f"{s = !r}")
    io.sendline(s.encode())
print(io.recvallS())
# Alpaca{i_st1ll_h4v3_n0_id3a_why_r0ll1ng_h4sh_is_c4ll3d_th4t}