Grey Cat The Flag 2024 Qualifiers

發表於
分類於 CTF

這禮拜自己用 nyahello solo 了這場,把 crypto 的題目全解了,題目有幾題也很好玩,所以來寫個 writeup。

Filter Ciphertext

from Crypto.Cipher import AES
import os

# Challenge setup: read the flag and create the IV and AES key used for the
# whole session.
with open("flag.txt", "r") as f:
    flag = f.read()

BLOCK_SIZE = 16
iv = os.urandom(BLOCK_SIZE)

# Byte-wise XOR; zip() truncates to the shorter argument.
xor = lambda x, y: bytes(a^b for a,b in zip(x,y))

key = os.urandom(16)

def encrypt(pt):
    """Encrypt ``pt`` with AES in PCBC mode, built by hand on top of ECB.

    Each plaintext block is XORed with the running chaining value before
    encryption; the next chaining value is ``plaintext_block ^ ciphertext_block``.
    No padding is applied.
    """
    ecb = AES.new(key=key, mode=AES.MODE_ECB)
    chain = iv
    out = []
    for start in range(0, len(pt), BLOCK_SIZE):
        plain = pt[start:start + BLOCK_SIZE]
        enc = ecb.encrypt(xor(plain, chain))
        out.append(enc)
        chain = xor(plain, enc)
    return b"".join(out)

    
def decrypt(ct):
    """PCBC decryption oracle that tries to drop blocks of the encrypted secret.

    Intentional flaw (the vulnerability of this challenge): ``blocks`` is
    mutated while it is being iterated. ``list.remove`` deletes the first
    equal element, and the element after the current position shifts into the
    iterator's next slot and is skipped. As a result, not every forbidden
    block is removed.
    """
    cipher = AES.new(key=key, mode=AES.MODE_ECB)
    blocks = [ct[i:i+BLOCK_SIZE] for i in range(0, len(ct), BLOCK_SIZE)]

    # `block in secret_enc` is a bytes substring test, not an aligned block compare.
    for block in blocks:
        if block in secret_enc:
            blocks.remove(block)
    
    tmp = iv
    ret = b""
    
    # Standard PCBC decryption: P_i = D(C_i) ^ chain, chain = C_i ^ P_i.
    for block in blocks:
        res = xor(cipher.decrypt(block), tmp)
        ret += res
        tmp = xor(block, res)
    
    return ret
    
# Encrypt a fresh random 80-byte (5-block) secret and publish the ciphertext.
secret = os.urandom(80)
secret_enc = encrypt(secret)

print(f"Encrypted secret: {secret_enc.hex()}")

print("Enter messages to decrypt (in hex): ")

# Decryption-oracle loop: the flag is released only if a ciphertext other
# than secret_enc itself decrypts to exactly the secret.
while True:
    res = input("> ")

    try:
        enc = bytes.fromhex(res)

        # Only the exact published ciphertext is rejected here; block-level
        # filtering happens inside decrypt().
        if (enc == secret_enc):
            print("Nice try.")
            continue
        
        dec = decrypt(enc)
        if (dec == secret):
            print(f"Wow! Here's the flag: {flag}")
            break

        else:
            print(dec.hex())
        
    # Any error (e.g. invalid hex) is printed and the loop continues.
    except Exception as e:
        print(e)
        continue

簡單來說它有個解密 oracle,然後分成 block 的時候如果檢查到某個 block 是 secret_enc 的一部份就會把它移除。然而仔細一看那個部分會發現它好像不太對:

# Excerpt from decrypt(): `blocks` is modified while it is being iterated.
for block in blocks:
	if block in secret_enc:
		blocks.remove(block)

這邊它在 iterating blocks 的時候同時在修改 blocks,這樣在 python 中是會出問題的。具體來說當 i=0 的時候它會把 blocks[0] 移除,然後原本在 blocks[1] 的資料就會跑到 blocks[0],這樣就會造成 blocks[1] 沒有被檢查到。所以假設每個 block 都符合 block in secret_enc,那代表 0,2,4,6,... 的 blocks 會被移除掉,而 1,3,5,7,... 的 blocks 就會被保留下來。

因此要解這題就把 secret_enc 分成 blocks,把每個 block 重複兩次然後 join 起來即可:

# Duplicate every ciphertext block. The buggy filter in the server's decrypt()
# removes only every other copy, so exactly one copy of each block survives
# and the oracle decrypts the original secret.
secret_enc = bytes.fromhex(input())
blks = [secret_enc[i : i + 16] for i in range(0, len(secret_enc), 16)]
# `blk * 2` repeats the 16-byte block. Joining directly avoids the quadratic
# sum(list_of_lists, []) flattening.
new_secret_enc = b"".join(blk * 2 for blk in blks)
print(new_secret_enc.hex())
# grey{00ps_n3v3r_m0d1fy_wh1l3_1t3r4t1ng}

Filter Plaintext

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from hashlib import md5
import os

# Challenge setup: read the flag and create the IV and AES key used for the
# whole session.
with open("flag.txt", "r") as f:
    flag = f.read()

BLOCK_SIZE = 16
iv = os.urandom(BLOCK_SIZE)

# Byte-wise XOR; zip() truncates to the shorter argument.
xor = lambda x, y: bytes(a^b for a,b in zip(x,y))

key = os.urandom(16)

def encrypt(pt):
    """Encrypt ``pt`` with AES-PCBC (hand-rolled on ECB); no padding is added.

    Chaining value for block i+1 is ``plaintext_i ^ ciphertext_i``.
    """
    ecb = AES.new(key=key, mode=AES.MODE_ECB)
    chunks = [pt[off:off + BLOCK_SIZE] for off in range(0, len(pt), BLOCK_SIZE)]
    chain = iv
    pieces = []
    for chunk in chunks:
        sealed = ecb.encrypt(xor(chunk, chain))
        pieces.append(sealed)
        chain = xor(chunk, sealed)
    return b"".join(pieces)

    
def decrypt(ct):
    """Decrypt ``ct`` with hand-rolled AES-PCBC, omitting "secret" plaintext.

    Any decrypted block that occurs as a substring of ``secret`` is left out
    of the returned bytes. It still feeds the chaining value, so the blocks
    after it are decrypted normally.
    """
    ecb = AES.new(key=key, mode=AES.MODE_ECB)
    chain = iv
    kept = []
    for start in range(0, len(ct), BLOCK_SIZE):
        enc = ct[start:start + BLOCK_SIZE]
        plain = xor(ecb.decrypt(enc), chain)
        chain = xor(enc, plain)
        if plain in secret:
            continue
        kept.append(plain)
    return b"".join(kept)
    
# Random 80-byte secret, published only in PCBC-encrypted form.
secret = os.urandom(80)
secret_enc = encrypt(secret)

print(f"Encrypted secret: {secret_enc.hex()}")

# The flag is AES-CBC encrypted under md5(secret), so recovering the secret
# is enough to decrypt it.
secret_key = md5(secret).digest()
secret_iv = os.urandom(BLOCK_SIZE)
cipher = AES.new(key = secret_key, iv = secret_iv, mode = AES.MODE_CBC)
flag_enc = cipher.encrypt(pad(flag.encode(), BLOCK_SIZE))

print(f"iv: {secret_iv.hex()}")

print(f"ct: {flag_enc.hex()}")

print("Enter messages to decrypt (in hex): ")

# Decryption oracle; the only restriction is the per-block filter in decrypt().
while True:
    res = input("> ")

    try:
        enc = bytes.fromhex(res)
        dec = decrypt(enc)
        print(dec.hex())
        
    # Any error (e.g. invalid hex) is printed and the loop continues.
    except Exception as e:
        print(e)
        continue

這題實作了 AES-PCBC mode,然後在解密的時候如果 plaintext 在 secret 中就會被跳過,其他部分都是正常的 decryption oracle。

AES-PCBC encryption

AES-PCBC decryption

首先因為沒有 iv,要想辦法拿到 iv 才行。我們知道 $ct_0=E(iv \oplus pt_0)$,如果把 $(ct_0,ct_0)$ 送進 decryption oracle,第一個 block 會解出 $pt_0$,所以會被 filter 掉。但是解密第二個 block 時用的 $iv_1=ct_0 \oplus pt_0$,因此 oracle 給的第二個 block 的輸出會是 $o=(pt_0 \oplus ct_0) \oplus (iv \oplus pt_0)=ct_0 \oplus iv$,xor 一下就拿到 iv 了。

接下來要想辦法拿到第一個 block 的 plaintext $pt_0$,這部分我是送 $(ct_0,ct_1,ct_1)$ 進去。第一和第二個 block 一樣會被 filter,不過這邊只須關注最後一個 block 就好。它的輸出會是 $o=(pt_1 \oplus ct_1) \oplus (iv_1 \oplus pt_1)=iv_1 \oplus ct_1$,這邊的 $iv_1$ 指的是加密時用來加密第二個 block 的 iv,也就是 $iv_1=pt_0 \oplus ct_0$,所以做一些 xor 就能拿到 $pt_0$ 了。

有了 $pt_0$ 之後,後面幾個 block 的解密就不用這麼複雜了,只要依序把要解密的 $ct_i \, (i>0)$ 單獨送進 oracle 就好。注意 $ct_i=E(pt_i \oplus iv_i)=E(pt_i \oplus ct_{i-1} \oplus pt_{i-1})$,所以輸出會是 $o=iv \oplus pt_i \oplus ct_{i-1} \oplus pt_{i-1}$,然後拿前一個 block 已知的資訊 xor 掉就能求出 $pt_i$ 了。

from pwn import process, remote
from Crypto.Cipher import AES
from hashlib import md5

# Connect (the commented-out local process is for testing) and parse the
# encrypted secret, the CBC IV and the encrypted flag from the banner.
# io = process(["python", "filter_plaintext.py"])
io = remote("challs.nusgreyhats.org", 32223)
io.recvuntil(b"secret: ")
secret_enc = bytes.fromhex(io.recvlineS().strip())
io.recvuntil(b"iv: ")
secret_iv = bytes.fromhex(io.recvlineS().strip())
io.recvuntil(b"ct: ")
flag_enc = bytes.fromhex(io.recvlineS().strip())


def decrypt(m: bytes):
    """Send ``m`` to the remote PCBC oracle and return the bytes it prints back."""
    io.sendlineafter(b"> ", m.hex().encode())
    reply = io.recvlineS().strip()
    return bytes.fromhex(reply)


def xor(x, y):
    """XOR two byte strings, truncating to the shorter one."""
    return bytes(a ^ b for a, b in zip(x, y))


# Split the encrypted 80-byte secret into its five 16-byte PCBC blocks.
ct = [secret_enc[off : off + 16] for off in range(0, len(secret_enc), 16)]

# (ct0, ct0): block 0 decrypts to pt0 and is filtered; block 1 leaks ct0 ^ iv.
iv0 = xor(decrypt(ct[0] * 2)[-16:], ct[0])
# (ct0, ct1, ct1): only the last block survives the filter. It leaks iv1 ^ ct1,
# where iv1 = pt0 ^ ct0 is the chaining value used to encrypt block 1.
iv1 = xor(decrypt(ct[0] + ct[1] * 2)[-16:], ct[1])
pt0 = xor(iv1, ct[0])

# A lone ct_i decrypts to pt_i ^ ct_{i-1} ^ pt_{i-1} ^ iv0; strip the known terms.
pts = [pt0]
for i in range(1, 5):
    leak = decrypt(ct[i])
    pts.append(xor(xor(xor(leak, iv0), pts[-1]), ct[i - 1]))
secret_rec = b"".join(pts)
assert len(secret_rec) == 80

# The flag was AES-CBC encrypted under md5(secret).
secret_key = md5(secret_rec).digest()
cipher = AES.new(key=secret_key, iv=secret_iv, mode=AES.MODE_CBC)
print(cipher.decrypt(flag_enc))
# grey{pcbc_d3crypt10n_0r4cl3_3p1c_f41l}

AES

from secrets import token_bytes
from aes import AES

FLAG = 'REDACTED'
# Random 16-byte password; the player must recover it from its encryption.
password = token_bytes(16)
key = token_bytes(16)

# NOTE(review): this rebinds the imported class name `AES` to an instance.
# It is harmless here because the class is not used again, but easy to trip over.
AES = AES(key)
# One chosen-plaintext query of at most 4096 bytes (4096 = 16 * 256).
m = bytes.fromhex(input("m: "))
if (len(m) > 4096): exit(0)
print("c:", AES.encrypt(m).hex())

print("c_p:", AES.encrypt(password).hex())
check = input("password: ")
if check == password.hex():
    print('flag:', FLAG)

這題的 AES 是個自己改過的 AES,它的 mix columns 被變成了 NO-OP,所以代表說 input 的每個 byte 是不互相影響的。即每個 input byte index 都有對應到一個 output byte index,而每個 input 上的 byte 都有個固定的轉換到 output 的 byte。

因此這邊只要用 encryption oracle 送一個 $16 \times 256$ bytes 的輸入,涵蓋所有 byte 值在每個位置出現的情況,就能建出 encryption 和 decryption 的 table,然後就能透過查表做加解密了。至於 input 位置和 output 位置的對應關係和 key 無關,可以先用本地的同一份 AES 實作隨便拿一把 key 求出來。

from secrets import token_bytes
from aes import AES
from pwn import process, remote

# io = process(["python", "server.py"])
io = remote("challs.nusgreyhats.org", 35100)


# Learn the byte-position permutation locally with a throwaway key. Flipping
# one input byte changes a single output position; take the first one that differs.
key = token_bytes(16)
aes = AES(key)

idx_map = {}
for src in range(16):
    probe = bytearray(b"a" * 16)
    before = aes.encrypt(probe)
    probe[src] ^= 1
    after = aes.encrypt(probe)
    idx_map[src] = next(pos for pos, (u, v) in enumerate(zip(before, after)) if u != v)
print(idx_map)


# A single query covers every (position, byte) pair: block v is 16 copies of byte v.
pts = [bytes([v]) * 16 for v in range(256)]
io.sendline(b"".join(pts).hex().encode())
io.recvuntil(b"c: ")
ct = bytes.fromhex(io.recvlineS().strip())
cts = [ct[off : off + 16] for off in range(0, len(ct), 16)]

# Lookup tables keyed and valued by (position, byte).
enc_map = {}  # (idx, val) -> (idx, val)
dec_map = {}  # (idx, val) -> (idx, val)
for pblk, cblk in zip(pts, cts):
    for src_pos, src_byte in enumerate(pblk):
        dst_pos = idx_map[src_pos]
        dst_byte = cblk[dst_pos]
        enc_map[src_pos, src_byte] = (dst_pos, dst_byte)
        dec_map[dst_pos, dst_byte] = (src_pos, src_byte)


def encrypt(enc_map, pt):
    """Encrypt one block by per-position lookup.

    ``enc_map`` maps ``(input_index, input_byte)`` to ``(output_index, output_byte)``.
    Output positions not written stay 0.
    """
    out = bytearray(16)
    for src_pos, src_byte in enumerate(pt):
        dst_pos, dst_byte = enc_map[src_pos, src_byte]
        out[dst_pos] = dst_byte
    return bytes(out)


def decrypt(dec_map, ct):
    """Invert one block using ``dec_map``: (output_index, byte) -> (input_index, byte)."""
    plain = bytearray(16)
    for dst_pos, dst_byte in enumerate(ct):
        src_pos, src_byte = dec_map[dst_pos, dst_byte]
        plain[src_pos] = src_byte
    return bytes(plain)


io.recvuntil(b"c_p: ")
pwd_ct = bytes.fromhex(io.recvlineS().strip())
# Decrypt every returned block by table lookup; only the first 16 bytes are the password.
pwd = b"".join(
    decrypt(dec_map, pwd_ct[off : off + 16]) for off in range(0, len(pwd_ct), 16)
)[:16]
io.sendline(pwd.hex().encode())
io.interactive()
# grey{mix_column_is_important_in_AES_ExB3Hf9q9I3m}

PRG

from secrets import token_bytes, randbits
from param import A 
import numpy as np

FLAG = 'REDACTED'

# Matrix from param.py; prg() multiplies it with a 64-bit state vector.
A = np.array(A)

def print_art():
    """Print the ASCII-art banner shown at the start of the game."""
    print(r"""
            />_________________________________
    [########[]_________________________________>
            \>
    """)
    
def bytes_to_bits(s):
    """Expand bytes into a flat list of 0/1 ints, most significant bit first."""
    return [int(bit) for byte in s for bit in format(byte, '08b')]

def bits_to_bytes(b):
    """Pack an MSB-first bit sequence into bytes, 8 bits per byte (last chunk may be short)."""
    chunks = (b[start:start + 8] for start in range(0, len(b), 8))
    return bytes(int("".join(str(bit) for bit in chunk), 2) for chunk in chunks)

def prg(length):
    """Return ``length`` pseudo-random bytes, hex-encoded, from an affine GF(2) recurrence.

    The 64-bit state ``x`` and the offsets ``r`` and ``k`` are fresh random
    values. Each step emits the parity of ``x`` and then sets
    ``x <- A x + c (mod 2)``, where ``c`` cycles through ``r``, ``k``, ``r + k``.
    """
    x, r, k = (np.array(bytes_to_bits(token_bytes(8))) for _ in range(3))
    offsets = (r, k, r + k)
    bits = []
    for step in range(length * 8):
        bits.append(sum(x) % 2)
        x = (A @ x + offsets[step % 3]) % 2
    return bits_to_bytes(bits).hex()
    
def true_random(length):
    """Return ``length`` bytes from the ``secrets`` CSPRNG, hex-encoded."""
    raw = token_bytes(length)
    return raw.hex()

def main():
    """Run 100 rounds of the PRG-vs-random game; print FLAG only if every guess is right."""
    try:
        print_art()
        print("I try to create my own PRG")
        print("This should be secure...")
        print("If you can win my security game for 100 times, then I will give you the flag")
        for i in range(100):
            print(f"Game {i}")
            print("Output: ", end="")
            # game == 1: output comes from prg(); game == 0: from true_random().
            game = randbits(1)
            if (game): print(prg(16))
            else: print(true_random(16))
            guess = int(input("What's your guess? (0/1): "))
            if guess != game:
                print("You lose")
                return
        print(f"Congrats! Here is your flag: {FLAG}")
    # NOTE(review): any error (e.g. a non-integer guess) silently ends the
    # game; `e` is unused.
    except Exception as e:
        return

if __name__ == "__main__":
    main()

上面的 `param.A` 是個 $64 \times 64$ 的 binary matrix,題目目標是要分辨(distinguish)`prg` 和 `true_random` 生成出來的輸出。

因為 prg 看起來就很 linear,我的想法是直接 symbolic 模擬它的 function,然後會得到一個 $M [x \, r \, k]^T = \text{output}$ 的系統,看它有沒有解就是了。雖然輸出只有 128 bits,而 $x,r,k$ 的未知數總共有 192 bits,可能會以為 $\operatorname{span}(M) = \mathbb{F}_2^{128}$ 所以沒辦法分辨,但實際上會發現 $\operatorname{rank}(M)=63$。也就是說 PRG 的輸出只會落在 $\mathbb{F}_2^{128}$ 中一個 63 維的子空間裡,而真隨機的 128 bits 剛好落在這個子空間的機率是 $2^{63}/2^{128} = 2^{-65}$,所以分辨錯誤的機率是 $2^{-65}$。

from sage.all import *
from pwn import process, remote
from param import A

# quick fix: https://github.com/sagemath/sage/issues/37837
# Workaround for the linked Sage issue: make the GF(2) polynomial sequence
# class use the generic coefficients_monomials implementation.
from sage.rings.polynomial.multi_polynomial_sequence import (
    PolynomialSequence_generic,
    PolynomialSequence_gf2,
)

PolynomialSequence_gf2.coefficients_monomials = (
    PolynomialSequence_generic.coefficients_monomials
)


def bytes_to_bits(s):
    """Expand bytes into a flat list of 0/1 ints, most significant bit first."""
    return [int(bit) for byte in s for bit in format(byte, "08b")]


# Run the server's prg() symbolically. The state x starts as unknowns x0..x63,
# and r, k are unknown vectors, so every output bit is a GF(2)-linear form in
# the 192 unknowns (no constant term).
F = GF(2)
Asage = matrix(F, A)
PR = PolynomialRing(
    F,
    [f"x{i}" for i in range(64)]
    + [f"r{i}" for i in range(64)]
    + [f"k{i}" for i in range(64)],
)
x = vector(PR.gens()[:64])
r = vector(PR.gens()[64:128])
k = vector(PR.gens()[128:192])
syms = []
# Same update schedule as the server: emit parity first, then x <- A x + (r | k | r + k).
for i in range(16 * 8):
    syms.append(sum(x))
    if i % 3 == 0:
        x = Asage * x + r
    if i % 3 == 1:
        x = Asage * x + k
    if i % 3 == 2:
        x = Asage * x + r + k
# One row per output bit and one column per monomial; PRG outputs are exactly
# the vectors in the column space of M.
M, _ = Sequence(syms).coefficients_monomials(sparse=False)
print(M.dimensions())
print(M.rank())

# io = process(["python", "server.py"])
io = remote("challs.nusgreyhats.org", 35101)
for rnd in range(100):
    print(rnd)
    io.recvuntil(b"Output: ")
    out = bytes.fromhex(io.recvlineS().strip())
    vs = vector(F, bytes_to_bits(out))
    # Answer 1 (PRG) iff M * v = vs is solvable, i.e. vs lies in M's column
    # space. solve_right raises when there is no solution.
    try:
        M.solve_right(vs)
        io.sendline(b"1")
    except Exception:
        io.sendline(b"0")
io.interactive()
# grey{Not_so_easy_to_construct_a_secure_PRG_LaQSqprzmTjBZs8ygMkGuw}

IPFE

server.py:

from IPFE import IPFE, _FeDDH_C
from secrets import randbits

FLAG = 'REDACTED'

# Prime from generate_prime()
# To save server resources, a fixed safe prime p = 2q + 1 is used.
p = 16288504871510480794324762135579703649765856535591342922567026227471362965149586884658054200933438380903297812918052138867605188042574409051996196359653039
q = (p - 1) // 2

# Vector length of the inner-product scheme.
n = 5
key = IPFE.generate(n, (p, q))
print("p:", key.p)
print("g:", key.g)
print("mpk:", list(map(int, key.mpk)))

while True:
    '''
    0. Exit
    1. Encrypt (You can do this yourself honestly)
    2. Generate Decryption Key
    3. Challenge
    '''
    option = int(input("Option: "))
    if (option == 0):
        exit(0)
    elif (option == 1):
        x = list(map(int, input("x: ").split()))
        c = IPFE.encrypt(x, key)
        print("g_r:", c.g_r)
        print("c:", list(map(int, c.c)))
    elif (option == 2):
        # Keygen oracle: both y and g_r are client-chosen. g_r is wrapped in a
        # dummy ciphertext and used without any validation.
        y = list(map(int, input("y: ").split()))
        g_r = int(input("g_r: "))
        dummy_c = _FeDDH_C(g_r, [])
        dk = IPFE.keygen(y, key, dummy_c)
        print("s_k:", int(dk.sk))
    elif (option == 3):
        # Single attempt: recover all n 40-bit entries of the encrypted vector.
        challenge = [randbits(40) for _ in range(n)]
        c = IPFE.encrypt(challenge, key)
        print("g_r:", c.g_r)
        print("c:", list(map(int, c.c)))
        check = list(map(int, input("challenge: ").split()))
        if (len(check) == n and all([x == y for x, y in zip(challenge, check)])):
            print("flag:", FLAG)
        exit(0)

IPFE.py:

from Crypto.Util.number import getPrime, isPrime, inverse
from secrets import randbelow
from gmpy2 import mpz
from typing import List, Tuple

# References:
# https://eprint.iacr.org/2015/017.pdf

def generate_prime():
    """Return a random safe prime ``p = 2q + 1`` (with 512-bit prime ``q``) as ``(mpz(p), mpz(q))``."""
    while True:
        q = getPrime(512)
        p = 2 * q + 1
        if not isPrime(p):
            continue
        return mpz(p), mpz(q)
        
def discrete_log_bound(a, g, bounds, p):
    """Brute-force the exponent e in ``[bounds[0], bounds[1]]`` with ``g^e == a (mod p)``.

    Returns the smallest such e. Raises ``Exception`` if none exists in range.
    """
    lo, hi = bounds[0], bounds[1]
    cur = pow(g, lo, p)
    for e in range(lo, hi + 1):
        if cur == a:
            return e
        cur = cur * g % p
    raise Exception(f"Discrete log for {a} under base {g} not found in bounds ({bounds[0]}, {bounds[1]})")

class _FeDDH_MK:
    def __init__(self, g, n: int, p: int, q: int, mpk: List[int], msk: List[int]=None):
        self.g = g
        self.n = n
        self.p = p
        self.q = q
        self.msk = msk
        self.mpk = mpk

    def has_private_key(self) -> bool:
        return self.msk is not None

    def get_public_key(self):
        return _FeDDH_MK(self.g, self.n, self.p, self.q, self.mpk)
    
class _FeDDH_SK:
    def __init__(self, y: List[int], sk: int):
        self.y = y
        self.sk = sk

class _FeDDH_C:
    def __init__(self, g_r: int, c: List[int]):
        self.g_r = g_r
        self.c = c

    
class IPFE:
    """DDH-based inner-product functional encryption (see the eprint 2015/017 reference above).

    All methods are static; keys and ciphertexts use the _FeDDH_* containers.
    """
    @staticmethod
    def generate(n: int, prime: Tuple[int, int] = None):
        # Use the supplied (p, q) pair or generate a fresh safe prime.
        if (prime == None): p, q = generate_prime()
        else: p, q = prime
        # Squaring puts g among the quadratic residues (the order-q subgroup
        # when p = 2q + 1).
        # NOTE(review): randbelow(p) may return 0 or +-1, which gives a
        # degenerate g (0 or 1); this is not rejected. `prime == None` would
        # idiomatically be `prime is None`.
        g = mpz(randbelow(p) ** 2) % p
        msk = [randbelow(q) for _ in range(n)]
        mpk = [pow(g, msk[i], p) for i in range(n)]

        return _FeDDH_MK(g, n, p, q, mpk=mpk, msk=msk)

    @staticmethod
    def encrypt(x: List[int], pub: _FeDDH_MK) -> _FeDDH_C:
        # Ciphertext (g^r, [mpk_i^r * g^{x_i} mod p]) for a fresh random r.
        if len(x) != pub.n:
            raise Exception("Encrypt vector must be of length n")
        
        r = randbelow(pub.q)
        g_r = pow(pub.g, r, pub.p)
        c = [(pow(pub.mpk[i], r, pub.p) * pow(pub.g, x[i], pub.p)) % pub.p for i in range(pub.n)]

        return _FeDDH_C(g_r, c)
    
    @staticmethod
    def decrypt(c: _FeDDH_C, pub: _FeDDH_MK, sk: _FeDDH_SK, bound: Tuple[int, int]) -> int:
        # prod(c_i^{y_i}) / sk = g^{<x, y>}; the inner product is then found by
        # brute-force discrete log inside `bound`.
        cul = 1
        for i in range(pub.n):
            cul = (cul * pow(c.c[i], sk.y[i], pub.p)) % pub.p
        cul = (cul * inverse(sk.sk, pub.p)) % pub.p
        return discrete_log_bound(cul, pub.g, bound, pub.p)
    
    @staticmethod
    def keygen(y: List[int], key: _FeDDH_MK, c: _FeDDH_C) -> _FeDDH_SK:
        # Key for function y: sk = g_r^t with t = <msk, y> mod q.
        if len(y) != key.n:
            raise Exception(f"Function vector must be of length {key.n}")
        if not key.has_private_key():
            raise Exception("Private key not found in master key")
        
        # NOTE(review): c.g_r is used as-is. Nothing checks that it lies in the
        # order-q subgroup, and server.py's option 2 lets the client choose it.
        t = sum([key.msk[i] * y[i] for i in range(key.n)]) % key.q
        sk = pow(c.g_r, t, key.p)
        return _FeDDH_SK(y, sk)
    
# Self-test: decrypting with the key for y recovers the inner product <x, y>.
if __name__ == "__main__":
    n = 10
    key = IPFE.generate(n)
    x = [i for i in range(n)]
    y = [i + 10 for i in range(n)]
    c = IPFE.encrypt(x, key)
    sk = IPFE.keygen(y, key, c)
    m = IPFE.decrypt(c, key.get_public_key(), sk, (0, 1000))
    expected = sum([a * b for a, b in zip(x, y)])
    assert m == expected

總之這題用 IPFE(一個 inner product 的 functional encryption scheme)出了一個題目。目標是拿到 encrypt 的結果之後恢復原本的 challenge。整個 scheme 是在 $\mathbb{F}_p^*$ 中的一個 prime order subgroup 上做的,其 order 和 generator 記為 $(q,g)$。

Generate master key:

$pk_i = g^{sk_i}$

Encrypt:

$$enc = (g_r, c) = (g^r, [pk_i^r \cdot g^{x_i}]) = (g^r, [g^{r \cdot sk_i + x_i}])$$

Keygen:

$$sk = g_r^t = g^{r \sum sk_i y_i}, \quad t = \sum sk_i y_i \bmod q$$

Decrypt:

$$\begin{gather} u = sk^{-1} \cdot \prod c_i^{y_i} = sk^{-1} \cdot g^{\sum (y_i r sk_i + x_i y_i)} = g^{\sum x_i y_i} \\ \sum x_i y_i = \log_g u \end{gather}$$

然後 server 提供了 keygen 的 oracle,可以自己指定 $g_r, y$ 去 keygen。這個 oracle 的攻擊方法是指定一個非 quadratic residue 的 $g_r$(order 為 $2q=p-1$),這樣 $g_r^q = -1$,所以 $sk^q = (g_r^t)^q = (-1)^t$,可以得到 $t$ 的 parity。配合 $y=[2^{-i} \, 0 \, 0 \, 0 \, 0]$(其中 $2^{-i}$ 是 $\bmod{q}$ 下的反元素),就能用 $\bmod{q}$ 下的 LSB oracle 逐 bit 得到 $sk_0$。用同樣的方法就能求得所有的 $sk_i$。

之後 challenge 時可以用 $c_i \cdot (g_r^{sk_i})^{-1}$ 得到 $g^{x_i}$,因為 $x_i$ 只有 40 bit,所以直接用 sage 的 discrete_log 搞定。

from pwn import process, remote
import ast
from sage.all import GF, discrete_log

# Same fixed safe prime p = 2q + 1 and vector length as the server.
p = 16288504871510480794324762135579703649765856535591342922567026227471362965149586884658054200933438380903297812918052138867605188042574409051996196359653039
q = (p - 1) // 2
n = 5

# io = process(["python", "server.py"])
io = remote("challs.nusgreyhats.org", 35102)
io.recvuntil(b"g: ")
g = int(io.recvlineS().strip())
io.recvuntil(b"mpk: ")
mpk = ast.literal_eval(io.recvlineS().strip())


def oracle(g_r, y):
    """Ask the server (option 2) for the decryption key of ``y`` under ``g_r``; return ``s_k``."""
    io.sendline(b"2")
    io.sendline(" ".join(str(v) for v in y).encode())
    io.sendline(str(g_r).encode())
    io.recvuntil(b"s_k: ")
    sk = io.recvlineS().strip()
    return int(sk)


def recover(oracle, idx):
    """Recover ``msk[idx]`` bit by bit through the keygen oracle's parity leak.

    ``g_r = 7`` must be a quadratic non-residue mod p (asserted), so
    ``g_r^q == -1`` and the returned key ``sk = g_r^t`` satisfies
    ``sk^q == (-1)^t``. The oracle therefore leaks the parity of
    ``t_i = msk[idx] * 2^{-i} mod q``. That parity equals
    ``(bit_i + k_i) mod 2``, where ``k_i = sum_{j<i} bit_j * 2^{-(i-j)} mod q``
    depends only on already-recovered bits.

    ``k_i`` is maintained incrementally as ``k_{i+1} = (k_i + bit_i) * 2^{-1} mod q``.
    The original re-summed all previous bits for every i (O(n^2) modular
    exponentiations); the result is the same.
    """
    g_r = 7
    assert pow(g_r, q, p) != 1
    half = pow(2, -1, q)  # 2^{-1} mod q
    bits = []
    low = 0  # k_i: contribution of the already-recovered low bits, mod q
    for i in range(q.bit_length()):
        y = [0] * n
        y[idx] = pow(2, -i, q)
        parity = int(pow(oracle(g_r, y), q, p) != 1)  # lsb of t_i
        bit = (parity - low) % 2
        bits.append(bit)
        low = (low + bit) * half % q
    return sum(b << j for j, b in enumerate(bits))


def batch_oracle(g_r_ar, y_ar):
    """Pipelined ``oracle``: send every keygen request first, then read ``len(g_r_ar)`` replies."""
    for g_r, y in zip(g_r_ar, y_ar):
        io.sendline(b"2")
        io.sendline(" ".join(str(v) for v in y).encode())
        io.sendline(str(g_r).encode())

    def read_one():
        io.recvuntil(b"s_k: ")
        return int(io.recvlineS().strip())

    return [read_one() for _ in range(len(g_r_ar))]


def recover_batch(batch_oracle, idx):
    """Recover ``msk[idx]`` like ``recover``, but issue all parity queries in one batch.

    With the non-residue ``g_r = 7`` (asserted), the key ``g_r^t`` reveals the
    parity of ``t_i = msk[idx] * 2^{-i} mod q``. Bit i is that parity minus
    the contribution ``k_i`` of the lower bits (mod 2).

    ``k_i`` is kept incrementally (``k_{i+1} = (k_i + bit_i) * 2^{-1} mod q``)
    instead of being re-summed from scratch for every bit, which removes the
    O(n^2) cost without changing the result.
    """
    g_r = 7
    assert pow(g_r, q, p) != 1
    half = pow(2, -1, q)  # 2^{-1} mod q
    nbits = q.bit_length()
    ys = []
    for i in range(nbits):
        y = [0] * n
        y[idx] = pow(2, -i, q)
        ys.append(y)
    sk_ar = batch_oracle([g_r] * len(ys), ys)
    bits = []
    low = 0  # k_i: contribution of the already-recovered low bits, mod q
    for i in range(nbits):
        parity = int(pow(sk_ar[i], q, p) != 1)  # lsb of t_i
        bit = (parity - low) % 2
        bits.append(bit)
        low = (low + bit) * half % q
    return sum(b << j for j, b in enumerate(bits))


# Recover every master-secret component and check it against the public key.
msk = []
for i in range(n):
    # Unbatched alternative (one round trip per bit):
    # sk = recover(oracle, i)
    sk = recover_batch(batch_oracle, i)
    assert pow(g, sk, p) == mpk[i]
    msk.append(sk)
    print(i, sk)

# Request the challenge ciphertext.
io.sendline(b"3")
io.recvuntil(b"g_r: ")
g_r = int(io.recvlineS().strip())
io.recvuntil(b"c: ")
c = ast.literal_eval(io.recvlineS().strip())

# Strip the mask: g^{x_i} = c_i / g_r^{msk_i} (mod p). pow(..., 1, p) just reduces mod p.
gxs = []
for i in range(n):
    gsr = pow(g_r, msk[i], p)
    gx = pow(c[i] * pow(gsr, -1, p), 1, p)
    gxs.append(gx)
    print(i, gx)

F = GF(p)

# Each x_i is a 40-bit value, so a bounded discrete log is feasible.
challenge = []
for i in range(n):
    x = discrete_log(F(gxs[i]), F(g), bounds=(0, 2**40))
    challenge.append(x)
    print(i, x)

io.sendline(" ".join(map(str, challenge)).encode())
io.interactive()
# grey{catostrophic_failure_7eE37WLLdYgg}

Coding

#!/usr/local/bin/python

from secrets import randbelow
from numpy.linalg import matrix_rank
import numpy as np

FLAG = 'REDACTED'

# n: message length in bits; k: codeword length (note: swapped w.r.t. the usual
# [n, k] coding notation); threshold: per-bit noise probability.
n = 100
k = int(n * 2)
threshold = 0.05

# Public matrix, set by setupMatrix().
M = []

def matrix_to_bits(G):
    """Serialize a numpy array into one string of its entries, row-major."""
    return "".join(map(str, G.ravel()))

def bits_to_matrix(s):
    """Parse an ``n * k`` character 0/1 string into an n x k numpy array reduced mod 2."""
    assert len(s) == n * k
    rows = [[int(ch) for ch in s[r * k:(r + 1) * k]] for r in range(n)]
    return np.array(rows) % 2

def setupMatrix(G):
    """Set the global public matrix M = S @ G @ P (mod 2) for a random scrambler S and permutation P."""
    assert G.shape == (n, k)
    global M

    # Random k x k permutation matrix: row i has its single 1 in column perm[i].
    # NOTE(review): np.random.shuffle uses numpy's global (non-cryptographic) RNG.
    perm = np.array([i for i in range(k)])
    np.random.shuffle(perm)
    PermMatrix = []
    for i in range(k):
        row = [0 for _ in range(k)]
        row[perm[i]] = 1
        PermMatrix.append(row)
    PermMatrix = np.array(PermMatrix)

    # Rejection-sample a random binary n x n matrix S.
    # NOTE(review): numpy's matrix_rank computes the rank over the reals, not
    # GF(2), so S can still be singular mod 2 and M may lose rank.
    while True:
        S = np.array([[randbelow(2) for _ in range(n)] for i in range(n)])        
        if matrix_rank(S) == n:
            break

    M = (S @ G @ PermMatrix) % 2

def initialize():
    """Draw a uniformly random n x k binary G and derive the public M from it."""
    setupMatrix(np.array([[randbelow(2) for _col in range(k)] for _row in range(n)]))

def encrypt(m):
    """Return ``m @ M (mod 2)`` with each of the k bits flipped independently with probability ``threshold``."""
    codeword = (m @ M) % 2
    flips = np.array([1 if randbelow(1000) < threshold * 1000 else 0 for _ in range(k)])
    return (codeword + flips) % 2

initialize()
print("M:", matrix_to_bits(M))

while True:
    '''
    0. Exit
    1. Set Matrix
    2. Encrypt (You can do this yourself honestly)
    3. Challenge
    '''
    option = int(input("Option: "))
    if (option == 0):
        exit(0)
    elif (option == 1):
        # The player chooses G; the server still applies its own random S and P.
        G = bits_to_matrix(input("G: ").strip()) % 2
        setupMatrix(G)
        print("M:", matrix_to_bits(M))
    elif (option == 2):
        m = np.array([randbelow(2) for _ in range(n)])
        print("m:", matrix_to_bits(m))
        print("c:", matrix_to_bits(encrypt(m)))
    elif (option == 3):
        # 200 rounds with 20 guesses each; at least 120 correct rounds are needed.
        count = 0
        for _ in range(200):
            print("Attempt:", _)
            challenge = np.array([randbelow(2) for _ in range(n)])
            check_arr = []
            print("c:", matrix_to_bits(encrypt(challenge)))
            for i in range(20):
                check = input("challenge: ").strip()
                check_arr.append(check)
            if matrix_to_bits(challenge) in check_arr:
                count += 1
                print("Correct!")
            else:
                print("Incorrect!")
        print(f"You got {count} out of 200")
        if (count >= 120):
            print("flag:", FLAG)
        else:
            print("Failed")
        exit(0)
    else:
        print("Invalid option")

這題就是 McEliece 的題目,generator $G$ 可以自己指定,但是會搭配 server 隨機產生的 $S,P$ 生成 $M=SGP$。目標是要能夠有效率地對 $c=mM+e$ 做解碼。noise 數量的期望值大概是 $200 \times 0.05 = 10$,標準差 $\sigma = \sqrt{200 \cdot 0.05 \cdot 0.95} \approx 3.1$,用二項分布抓 $\pm 2\sigma$ 的話大概是 $4$ 到 $16$ 左右,下面的 code 取 $5$ 到 $15$。

我這邊是直接把 $M$ 視為一個 $[n,k]=[200,100]$ 的 linear code(注意 code 裡的變數 `n=100`、`k=200` 剛好和 coding theory 的慣例相反)。結果發現 sage 的 LeeBrickellISDAlgorithm 其實已經能在這個 error 數量下很有效率地解碼了,最多加個 timeout 就好。

from sage.all import *
from sage.coding.linear_code import LinearCode
from sage.coding.information_set_decoder import LeeBrickellISDAlgorithm
from pwn import process, remote, context
import signal

# Mirror the server's parameters: n = 100 message bits, k = 200 codeword bits,
# 5% noise rate; all arithmetic is over GF(2).
n = 100
k = int(n * 2)
threshold = 0.05
F = GF(2)


def bits_to_matrix(s):
    """Parse the server's ``n * k`` character bit string into an n x k matrix over GF(2)."""
    assert len(s) == n * k
    return matrix(F, [[int(ch) for ch in s[row * k : row * k + k]] for row in range(n)])


def timed_call(fn, args, timeout=1):
    """Call ``fn(*args)``; raise ``TimeoutError`` if it runs longer than ``timeout`` seconds.

    Uses SIGALRM, so it only works on Unix and in the main thread.
    ``timeout`` is whole seconds (``signal.alarm`` takes an int); 0 disables it.
    The alarm is always cancelled. The SIGALRM handler that was installed
    before the call is restored afterwards; the original left its own handler
    installed permanently.
    """

    def handler(signum, frame):
        raise TimeoutError()

    previous = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return fn(*args)
    finally:
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; SIG_DFL is the closest restorable value in that case.
        signal.signal(signal.SIGALRM, previous if previous is not None else signal.SIG_DFL)


# io = process(["python", "server.py"])
io = remote("challs.nusgreyhats.org", 35103)
io.recvuntil(b"M: ")
M_str = io.recvlineS().strip()
M = bits_to_matrix(M_str)
# The server only checks S's rank over the reals, so M can be rank-deficient
# over GF(2); abort in that case.
assert M.rank() == n, ":("
C = LinearCode(M)
# Information-set decoding targeting error weights in [5, 15].
D = LeeBrickellISDAlgorithm(C, decoding_interval=(5, 15))
io.sendline(b"3")

correct_cnt = 0
for rnd in range(200):
    print(rnd)
    io.recvuntil(b"c: ")
    c_str = io.recvlineS().strip()
    ct = vector(F, [int(x) for x in c_str])
    try:
        # 120 correct rounds already secure the flag; skip decoding the rest.
        if correct_cnt >= 120:
            raise TimeoutError()
        # Decode to a nearby codeword, giving up after 1 second.
        dec = timed_call(D.decode, (ct,), timeout=1)
        print("noise", ct - dec)
        # Recover m from the codeword m * M.
        sol = M.solve_left(dec)
        sol_str = "".join([str(x) for x in sol])
        # The server accepts 20 guesses per round; all 20 are identical here.
        sols = [sol_str] * 20
        print(sol_str)
    except TimeoutError:
        print("timeout")
        dummy = "1" * n
        sols = [dummy] * 20
    # Send all 20 guesses at once, then skip the remaining 19 prompts before
    # reading the verdict line.
    io.recvuntil(b"challenge: ")
    io.sendline("\n".join(sols).encode())
    io.recvuntil(b"challenge: " * 19)
    res = io.recvline()
    print(res)
    correct_cnt += int(res == b"Correct!\n")
    print(f"{correct_cnt}/{rnd + 1}")
io.interactive()
# grey{what_linear_code_did_you_use?_zcUQenJEv4wXNRYxB5pPkH}

Curve

from param import p, b, n
from secrets import randbelow
from hashlib import shake_128

def encrypt(msg, key):
    """XOR msg with a SHAKE-128 keystream derived from the concatenated key bits.

    `^^` is the Sage-preparser spelling of Python's bitwise XOR.
    """
    y = shake_128("".join(map(str, key)).encode()).digest(len(msg))
    return bytes([msg[i] ^^ y[i] for i in range(len(msg))])
    
FLAG = b'REDACTED'
# Length of the secret GF(2) vector.
m = 150

F1 = GF(p)
F2.<u> = GF(p^2)

# Each sample leaks one parity <hidden, factor> mod 2 through a (non-)DDH tuple.
hidden = [randbelow(2) for _ in range(m)]
factors = []
output = []

for _ in range(900):
    # E1 is fixed (y^2 = x^3 + b over GF(p)); E2 has a random constant over GF(p^2).
    E1 = EllipticCurve(GF(p), [0, b])
    E2 = EllipticCurve(F2, [0, F2.random_element()])

    g = E1.random_point()
    h = E2.random_point()

    factor = [randbelow(2) for _ in range(m)]
    k = sum([hidden[i] * factor[i] for i in range(m)]) % 2
    factors.append(factor)
    
    # k = 1: independent random z; k = 0: z = x * y.
    if (k):
        x, y, z = [randbelow(n) for _ in range(3)]
    else:
        x, y = [randbelow(n) for _ in range(2)]
        z = x * y

    output.append([g, x * g, y * g, z * g, h, x * h])

# Keep only the affine coordinates (x, y) of every point.
output = [[(point[0], point[1]) for point in row] for row in output]

# NOTE(review): `f` is never closed; flushing relies on interpreter exit.
f = open("output.txt", "w")
f.write(f"c='{encrypt(FLAG, hidden).hex()}'\n")
f.write(f"{factors=}\n")
f.write(f"{output=}\n")

這題是這場 CTF 最有趣也最難的 crypto 題目。首先有

$$\begin{gather} E_1(\mathbb{F}_p): y^2 = x^3 + b \\ E_2(\mathbb{F}_{p^2}): y^2 = x^3 + r \end{gather}$$

還有隨機的 $g \in E_1, h \in E_2$。

其中 $b=6$ 固定,$r \in \mathbb{F}_{p^2}$ 是隨機的。之後 factor, hidden 那邊是個 $\mathbb{F}_2$ 上的 linear system,可以先不管它。只需要關注 $k=0$ 和 $k=1$ 的情況,分別對應 $z=xy$ 和 $z$ 隨機(幾乎必然 $z\neq xy$)。

然後它會提供給你 $g,xg,yg,zg,h,xh$,要想辦法分辨 $z=xy$ 和 $z\neq xy$ 這兩種 case 來得到 $k$,然後就可以解 linear system 得到 hidden。

要判斷這個的話我的想法是透過 pairing 來做,因為:

$$e(xg,yg)=e(g,xyg) \overset{?}{=} e(g,zg)$$

然而對於線性相依的 $P,Q$(即 $\exists x ,\, P=xQ$)來說 $e(P,Q)=e(Q,Q)^x=1$,所以只靠 $E_1$ 上的 $g$ 是做不到的。所以我希望 $E_2$ 的 $h$ 某種程度上和 $g$ 是 linearly independent 的,但是要做 pairing 的話兩個點都需要在同一條曲線上才行。

總之先對題目的參數做一些檢查:

p = 2956673455706017732726395787961404603421884201335599271480572766387023983052387933523497453145098834977111426152922969191412599940148797425165972377716091055492258826885933065623708772008210452334640417645070465335307327936488470721870990368244784905723647386940034701846717718881287895717648756082302396599141
n = 2956673455706017732726395787961404603421884201335599271480572766387023983052387933523497453145098834977111426152922969191412599940148797425165972377716091001116956936174395309176601513634561168168290944226676456373026396891854015965003822945171727063054140667291090754603076704996926282408059289968479821894541
b = 6

首先可知 $j(E_1)=j(E_2)=0$,所以在夠大的 field $K$ 上 $E_1,E_2$ 是同構的。再來是 $n=\# E_1$ 符合 $p^{12}-1 \equiv 0 \pmod{n}$,也就是說 $E_1$ 的 embedding degree $k=12$。因此取 $K=\mathbb{F}_{p^{12}}$ 的話,就能有效率地在 $E_1(K)$ 上做 pairing。

看到這邊如果有經驗的話可能會發現它和 BLS12-381 很類似

把 $E_1(\mathbb{F}_p)$ 上的點轉換到 $E_1(K)$ 很容易,因為 $(x,y) \in E_1(\mathbb{F}_p)$ 在 $K$ 上自然也符合 $y^2=x^3+b$。困難的地方在於 $E_2$ 上的點要怎麼轉換到 $K$ 上。

這部分我是這麼做的:因為 $\mathbb{F}_{p^2}=\mathbb{F}_p[x]/f(x)$,其中 $f(x)$ 是個 irreducible polynomial,而 field 的 generator $u$ 會符合 $f(u)=0$。例如按照題目的 sage code,它這邊生出來的是 $f(x)=x^2+x-3$,所以 $u^2+u-3=0$。

然後 $K=\mathbb{F}_{p^{12}}=\mathbb{F}_p[x]/g(x)$,其中 $g(x)$ 是另一個 degree 12 的 irreducible polynomial。因為 $2 \mid 12$,$\mathbb{F}_{p^2}$ 是 $K$ 的 subfield,所以要把 $u$ 轉換到 $K$ 上,就直接在 $K[t]$ 上對 $f(t)=t^2+t-3$ 求根 $u' \in K$,然後把原本 $E_2$ 上的 $h,xh$ 做替換 $u \rightarrow u'$。同時也會有一個 $r'$,是把原本的 $r$ 做 $u \rightarrow u'$ 替換的結果,那條 curve 就是 $E_2(\mathbb{F}_{p^{12}})$。

接下來 $y^2=x^3+b_1$ 和 $y^2=x^3+b_2$ 這兩條曲線因為 j-invariant 相同所以是同構的。這部分可以用 sage 的 isomorphism_to 做,或是直接用 $(x,y) \rightarrow (\lambda^2 x, \lambda^3 y)$(其中 $b_1 \lambda^6=b_2$)做轉換搞定。

所以我們可以把所有的 $g,xg,yg,zg,h,xh$ 都轉換到 $E_1(K)$ 上,然後用 tate pairing 檢查 $e(yg, xh) \overset{?}{=} e(zg, h)$(兩邊分別是 $e(g,h)^{xy}$ 和 $e(g,h)^{z}$),就能判斷 $z \overset{?}{=} xy$ 得到 $k$ 了。

還有一點要注意:轉換過去的 $g,h$ 不一定是線性獨立的,所以需要檢查 $e(\phi(g), \pi(h)) \neq 1$ 才行,其中 $\phi, \pi$ 分別是把 $E_1$ 和 $E_2$ 上的點轉換到 $E_1(K)$ 的 map。

from hashlib import shake_128

def encrypt(msg, key):
    """XOR msg with a SHAKE-128 keystream derived from the concatenated key bits (self-inverse).

    `^^` is the Sage-preparser spelling of Python's bitwise XOR.
    """
    y = shake_128("".join(map(str, key)).encode()).digest(len(msg))
    return bytes([msg[i] ^^ y[i] for i in range(len(msg))])

# Skip Sage's provable-correctness checks in arithmetic (e.g. primality proofs) for speed.
proof.arithmetic(False)

p = 2956673455706017732726395787961404603421884201335599271480572766387023983052387933523497453145098834977111426152922969191412599940148797425165972377716091055492258826885933065623708772008210452334640417645070465335307327936488470721870990368244784905723647386940034701846717718881287895717648756082302396599141
n = 2956673455706017732726395787961404603421884201335599271480572766387023983052387933523497453145098834977111426152922969191412599940148797425165972377716091001116956936174395309176601513634561168168290944226676456373026396891854015965003822945171727063054140667291090754603076704996926282408059289968479821894541
b = 6

F1 = GF(p)
F2.<u> = GF(p^2)
E1 = EllipticCurve(F1, [0, b])

# Embedding degree of E1: pairings are computed over K = GF(p^12).
k = 12
K.<a> = GF(p^k)
PR.<t> = PolynomialRing(K)
# Image of F2's generator u in K: a root of F2's modulus, lifted to ZZ[t] and
# evaluated in K[t]. The round-trip through str() presumably re-coerces the
# root into K — TODO confirm it is needed.
uk = K(str(F2.modulus().change_ring(ZZ)(t).roots(multiplicities=False)[0]))
EK = EllipticCurve(K, [0, b])

# NOTE(review): exec() runs output.txt as code (defining c, factors and
# output); only do this with the trusted challenge file.
with open('output.txt') as f:
    exec(f.read())
flag_ct = bytes.fromhex(c)

lhs = []
rhs = []
# Only used by the commented-out isomorphism_to approach below.
iso_cache = {}
for i, (g, xg, yg, zg, h, xh) in enumerate(output):
    g = E1(g)
    xg = E1(xg)
    yg = E1(yg)
    zg = E1(zg)
    # E2 has the form y^2 = x^3 + b1; recover b1 from both points and check they agree.
    b1 = h[1] ** 2 - h[0] ** 3
    b2 = xh[1] ** 2 - xh[0] ** 3
    assert b1 == b2
    E2 = EllipticCurve(F2, [0, b1])
    h = E2(h)
    xh = E2(xh)


    # Move a point's GF(p^2) coordinates into K by substituting u -> uk.
    def E2_to_EK(P):
        x, y = P.xy()
        x = x.polynomial().change_ring(ZZ)(uk)
        y = y.polynomial().change_ring(ZZ)(uk)
        return x, y

    x1, y1 = E2_to_EK(h)
    x2, y2 = E2_to_EK(xh)
    # Alternative kept for reference: derive the image curve and use Sage's isomorphism_to.
    # aa, bb = matrix([[x1, 1], [x2, 1]]).solve_right(
    #     vector([y1**2 - x1**3, y2**2 - x2**3])
    # )
    # EK2 = EllipticCurve(K, [aa, bb])
    # phi = EK2.isomorphism_to(EK) if (aa, bb) not in iso_cache else iso_cache[(aa, bb)]
    # iso_cache[(aa, bb)] = phi
    # hk = phi(EK2(x1, y1))
    # xhk = phi(EK2(x2, y2))
    aa = 0
    # Twist isomorphism y^2 = x^3 + bb -> y^2 = x^3 + b: (x, y) -> (w^2 x, w^3 y) with w^6 = b / bb.
    bb = y1 ** 2 - x1 ** 3
    uuu = (b / bb).nth_root(6)
    hk = EK(x1 * uuu**2, y1 * uuu**3)
    xhk = EK(x2 * uuu**2, y2 * uuu**3)
    gk = EK(g)
    xgk = EK(xg)
    ygk = EK(yg)
    zgk = EK(zg)
    # Skip samples where the pairing of g with the image of h is trivial.
    if gk.tate_pairing(hk, n, 12) == 1:
        print(":(", i)
        continue
    # e(yg, xh) = e(g, h)^(xy) and e(zg, h) = e(g, h)^z, so a mismatch means z != xy (k = 1).
    # NOTE(review): this rebinds k, which previously held the embedding degree 12.
    k = int(ygk.tate_pairing(xhk, n, 12) != zgk.tate_pairing(hk, n, 12))
    lhs.append(factors[i])
    rhs.append(k)
    print("collected", i)

# Solve factors * hidden = k over GF(2) and decrypt the flag with the recovered key.
sol = list(matrix(GF(2), lhs).solve_right(vector(rhs)))
print(encrypt(flag_ct, sol).decode())
# grey{tate_ate_weil_VfWZTKzMmgYhpEL7xvRwFu}