DownUnderCTF 2023 Writeups

發表於
分類於 CTF

今年在 ${CyStick} 中參加了 DownUnderCTF 2023,只解了幾個題目而已。

Crypto

apbq rsa i

from Crypto.Util.number import getPrime, bytes_to_long
from random import randint

# Challenge source for "apbq rsa i": textbook RSA plus two leaked linear
# combinations of the secret primes.
p = getPrime(1024)
q = getPrime(1024)
n = p * q
e = 0x10001

# Each hint is a*p + b*q, with a drawn from [0, 2**12] and b from [0, 2**312]
# (randint includes both endpoints).
hints = []
for _ in range(2):
    a, b = randint(0, 2**12), randint(0, 2**312)
    hints.append(a * p + b * q)

# Encrypt the flag and publish the modulus, the ciphertext and the hints.
FLAG = open('flag.txt', 'rb').read().strip()
c = pow(bytes_to_long(FLAG), e, n)
print(f'{n = }')
print(f'{c = }')
print(f'{hints = }')

RSA,並額外提供了:

$$\begin{aligned} h_1&=a_1 p+b_1 q \\ h_2&=a_2 p+b_2 q \end{aligned}$$

其中 $a_1, a_2$ 很小,所以爆一下之後求

$$a_2h_1-a_1h_2=(a_2b_1-a_1b_2)q$$

就能得到 $q$ 的倍數,和 $n=pq$ 取 gcd 就結束了。

from output import *
from itertools import product
from math import gcd
from tqdm import tqdm

# Solver for "apbq rsa i": brute-force the two small coefficients of p so that
# the p terms cancel, then take a gcd with n to expose a prime factor.
h1, h2 = hints

# The coefficients come from randint(0, 2**12), which includes both endpoints,
# so the search range has to contain 2**12 itself.
for a, b in product(range(2**12 + 1), repeat=2):
    # When (a, b) == (a2, a1): a2*h1 - a1*h2 == (a2*b1 - a1*b2) * q.
    q = gcd(a * h1 - b * h2, n)
    # gcd == n happens for degenerate pairs such as a == b == 0; skip those.
    if q != 1 and q < n:
        print(q, n)
        break
# NOTE(review): presumably the factor printed by the search above, pinned here
# so the decryption below can be rerun without repeating the brute force.
q = 131749193259488372734882395267267400452018470526669557625739671139987328020291297864452866159092612057179256194270162132453379915244109293125925735503860433762913331882646107732350881847176645056234707603429859084687993766987515407413263165507571230913665811470277548803428982720272589512297691853544766981321
p = n // q
e = 0x10001
d = pow(e, -1, (p - 1) * (q - 1))
m = pow(c, d, n)
# Render the plaintext as a fixed-width big-endian buffer and drop the padding.
flag = m.to_bytes(1024, "big").strip(b"\x00")
print(flag)
# DUCTF{gcd_1s_a_g00d_alg0r1thm_f0r_th3_t00lbox}

apbq rsa ii

from Crypto.Util.number import getPrime, bytes_to_long
from random import randint

# Challenge source for "apbq rsa ii": same setup as part i, but with three
# hints and both coefficients drawn from the large range.
p = getPrime(1024)
q = getPrime(1024)
n = p * q
e = 0x10001

# Each hint is a*p + b*q with a and b both drawn from [0, 2**312]
# (randint includes both endpoints).
hints = []
for _ in range(3):
    a, b = randint(0, 2**312), randint(0, 2**312)
    hints.append(a * p + b * q)

# Encrypt the flag and publish the modulus, the ciphertext and the hints.
FLAG = open('flag.txt', 'rb').read().strip()
c = pow(bytes_to_long(FLAG), e, n)
print(f'{n = }')
print(f'{c = }')
print(f'{hints = }')

和前一題幾乎一樣,不過多了一個 hint,且 $a_i$ 的範圍變大了。

總之有三條等式,先把未知的 $p, q$ 消除得到:

$$a_1 a_3 b_2 h_1 - a_1 a_2 b_3 h_1 - a_1 a_3 b_1 h_2 + a_1^2 b_3 h_2 + a_1 a_2 b_1 h_3 - a_1^2 b_2 h_3 = 0$$

因為 $h_i$ 比較大,可以嘗試用 LLL 求出 short vector:

$$\vec{s}=(a_1 a_3 b_2 - a_1 a_2 b_3,\; -a_1 a_3 b_1 + a_1^2 b_3,\; a_1 a_2 b_1 - a_1^2 b_2)$$

不過因為 $a_1$ 在三個分量都有出現,所以實際上得到的是:

$$\begin{aligned} \vec{s'}&=(a_3 b_2 - a_2 b_3,\; -a_3 b_1 + a_1 b_3,\; a_2 b_1 - a_1 b_2) \\ &=(t,u,v) \end{aligned}$$

還有可能是三個的 gcd 不為 1,不過大不了爆一下就行

因為是 LLL 求出來的,所以可能還有正負號的問題要爆,不過這個好處理。總之我們可以拿六個未知數 $a_i, b_i$ 和 $t,u,v$ 求 groebner basis,發現裡面有兩個等式:

$$\begin{aligned} k_1 a_1+k_2 a_2+k_3 a_3 &= 0 \\ k_1 b_1+k_2 b_2+k_3 b_3 &= 0 \end{aligned}$$

所以拿 $k_i$ 做 LLL 再小爆一下就是 $a_i, b_i$ 了,之後就和前一題一樣 gcd 搞定。

from sage.all import *
from Crypto.Util.number import getPrime, bytes_to_long
from random import randint
from itertools import product, combinations
from output import *


# Solver for "apbq rsa ii" (Sage). Outline:
#   1. LLL on the hints to find a short integer relation (t, u, v) between them.
#   2. Brute-force the signs of (t, u, v) and take a linear element of the
#      Groebner basis of the resulting system.
#   3. LLL on that linear form's coefficients to get candidates for the a_i.
#   4. Same gcd trick as part i to factor n.
#
# Derivation of the relation between the hints:
# assert a2 * h1 - a1 * h2 == (a2 * b1 - a1 * b2) * q
# assert a3 * h1 - a1 * h3 == (a3 * b1 - a1 * b3) * q
# assert (a3 * b1 - a1 * b3) * (a2 * h1 - a1 * h2) == (a2 * b1 - a1 * b2) * (a3 * h1 - a1 * h3)
# so: a1*a3*b2*h1 - a1*a2*b3*h1 - a1*a3*b1*h2 + a1^2*b3*h2 + a1*a2*b1*h3 - a1^2*b2*h3 == 0
# we try to use LLL to find them

h1, h2, h3 = hints
# Rows are (h_i | e_i); scaling the first column by n penalises any
# combination whose h-part does not cancel.
L = matrix(hints).T.augment(matrix.identity(3))
L[:, 0] *= n
L = L.LLL()
for v in L:
    print(v)
    print([x.bit_length() for x in v])

# the expected vector is (a1*a3*b2-a1*a2*b3, -a1*a3*b1+a1**2*b3, a1*a2*b1-a1**2*b2)
# but we can see that a1 divides all of them
# assuming gcd(gcd((a1*a3*b2-a1*a2*b3), (-a1*a3*b1+a1**2*b3)), (a1*a2*b1-a1**2*b2)) == 1 here
# the shortest vector should be (a3*b2-a2*b3, -a3*b1+a1*b3, a2*b1-a1*b2)

# First entry of the row is the (scaled) h-part; the other three are the relation.
_, t, u, v = L[0]
# therefore this should hold
# assert abs(a3*b2 - a2*b3) == abs(t)
# assert abs(a3*b1 - a1*b3) == abs(u)
# assert abs(a2*b1 - a1*b2) == abs(v)
a1s, a2s, a3s, b1s, b2s, b3s = QQ["a1,a2,a3,b1,b2,b3"].gens()
# Only the absolute values of t, u, v are known, so try every sign pattern.
for sign in product((-1, 1), repeat=3):
    I = ideal(
        [
            a3s * b2s - a2s * b3s + sign[0] * t,
            a3s * b1s - a1s * b3s + sign[1] * u,
            a2s * b1s - a1s * b2s + sign[2] * v,
        ]
    )
    # NOTE(review): presumably dimension -1 marks an inconsistent system
    # (unit ideal) - confirm against the Sage documentation.
    if I.dimension() != -1:
        print(sign)
        print("dim", I.dimension())

        def step2(f):
            """Recover candidate (x1, x2, x3) vectors from the linear form f and factor n.

            f is expected to be of the form k1*a1 + k2*a2 + k3*a3; on success
            the flag is printed and the process exits.
            """
            # this f is in the form of k1*a1+k2*a2+k3*a3==0
            # k1*b1+k2*b2+k3*b3==0 also holds: (t, u, v) is, up to sign, the
            # cross product of (a1, a2, a3) and (b1, b2, b3), hence orthogonal to both
            # use LLL to find it
            print("=" * 40)
            print(f)
            # Same lattice construction as above, now on the coefficients of f.
            L = matrix(f.coefficients()).T.augment(matrix.identity(3))
            L[:, 0] *= n
            L = L.LLL()
            print(L[0])
            print(L[1])
            v1 = L[0]
            v2 = L[1]
            xs = []
            # Small combinations of the two shortest rows; keep those whose
            # entries fit the challenge's coefficient range [0, 2**312].
            for c1, c2 in product((-2, -1, 0, 1, 2), repeat=2):
                v = c1 * v1 + c2 * v2
                _, x1, x2, x3 = v
                if all([0 <= x <= 2**312 for x in (x1, x2, x3)]):
                    xs.append((x1, x2, x3))
            # we don't know which one is correct pair of (a1, a2, a3) and (b1, b2, b3)
            # just try all combinations
            for g1, g2 in combinations(xs, 2):
                a1r, a2r, a3r = g1
                # b1r, b2r, b3r are unpacked but not used below; only g1 feeds the gcd.
                b1r, b2r, b3r = g2
                q = gcd(a2r * h1 - a1r * h2, n)
                if 1 < q < n:
                    p = n // q
                    e = 0x10001
                    d = inverse_mod(e, (p - 1) * (q - 1))
                    m = pow(c, d, n)
                    flag = int(m).to_bytes(1024, "big").strip(b"\x00")
                    print(flag)
                    exit()

        step2(I.groebner_basis()[1])
# DUCTF{0rtho_l4tt1c3_1s_a_fun_and_gr34t_t3chn1que_f0r_the_t00lbox!}

fnv

#!/usr/bin/env python3

import os

def fnv1(s):
	"""Hash the byte sequence *s* into a 64-bit integer.

	For each byte the state is multiplied by 0x100000001b3, truncated to
	64 bits, and then XORed with that byte.
	"""
	h = 0xcbf29ce484222325  # initial state
	for b in s:
		h *= 0x00000100000001b3
		h &= 0xffffffffffffffff  # keep the state within 64 bits
		h ^= b
	return h

# Hash value the submitted input has to reach.
TARGET = 0x1337133713371337

print("Welcome to FNV!")
print(f"Please enter a string in hex that hashes to 0x{TARGET:016x}:")
# The input is read as hex and decoded to raw bytes before hashing.
s = bytearray.fromhex(input())
if fnv1(s) == TARGET:
	print('Well done!')
	# The flag is taken from the FLAG environment variable.
	print(os.getenv('FLAG'))
else:
	print('Try again... :(')

簡單來說是找這個 fnv1 hash function 的 preimage。

我是猜測說這個 hash function 高機率是 bijective 的,所以理論上有個 8 bytes 的輸入可以 hash 到 TARGET。而這個 hash function 在處理時是 byte-by-byte 的,所以我想說可以 MITM,前後各搜 4 bytes 解決。

這部分當然要用 C++ 實作才行,而中間的 table 需要 $2^{32} \times 8 \,\text{bytes} = 32 \,\text{GB}$,所以要在 ram 夠大的系統上跑才行。

或是可以用 mmap 之類的,不過應該會慢些

這邊是我的第一個腳本,先搜前半把 table 建好,之後 sort 之後後半每個搜的時候用 binary search 找,找到就把那個中間值和後半的輸入印出:

#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
#include <algorithm>

// One FNV-1 round: multiply the state by the 64-bit FNV prime (wrapping at
// 2^64), then XOR the input byte into it.
__attribute__((always_inline)) inline void fnv_forward(uint64_t *s, uint8_t c) {
	const uint64_t fnv_prime = 0x00000100000001b3;
	*s = (*s * fnv_prime) ^ c;
}
// Undo one fnv_forward round: XOR the byte back out, then multiply by the
// inverse of the FNV prime modulo 2^64.
__attribute__((always_inline)) inline void fnv_backward(uint64_t *s,
                                                        uint8_t c) {
	const uint64_t fnv_prime_inv = 0xce965057aff6957b;
	*s = (*s ^ c) * fnv_prime_inv;
}
// Meet-in-the-middle search for an 8-byte input whose FNV-1 hash equals `end`.
// Phase 1 tabulates the state reached after every 4-byte prefix, phase 2 sorts
// the table, and phase 3 walks every 4-byte suffix backwards from `end` and
// looks the resulting state up in the table. Each hit prints the shared middle
// state together with the suffix bytes in the order they were undone (last
// input byte first). Returns 1 if the table cannot be allocated, 0 otherwise.
int main() {
	const size_t tbl_len = (size_t)1 << 32;
	const size_t tbl_size = tbl_len * sizeof(uint64_t);
	// Alternative backing store: a file-backed mapping instead of RAM.
	// int fd = open("./table", O_RDWR | O_CREAT, 0644);
	// ftruncate(fd, tbl_size);
	// uint64_t *tbl = (uint64_t *)mmap(NULL, tbl_size, PROT_READ | PROT_WRITE,
	//                                  MAP_SHARED, fd, 0);
	// uint64_t *tblend = tbl + (1L << 32);

	uint64_t *tbl = (uint64_t *)malloc(tbl_size);
	if (tbl == NULL) {
		// A 32 GiB request can easily fail; stop instead of writing through NULL.
		fprintf(stderr, "failed to allocate %zu bytes for the table\n",
		        tbl_size);
		return 1;
	}
	uint64_t *tblend = tbl + tbl_len;

	const uint64_t start = 0xcbf29ce484222325;
	const uint64_t end = 0x1337133713371337;

	// meet in the middle attack 8 bytes
	// 4 bytes forward, 4 bytes backward
	// Loop counters are int so that `< 256` is a real termination condition.
	uint64_t *tblptr = tbl;
	for (int a = 0; a < 256; a++) {
		printf("a: %d\n", a);
		for (int b = 0; b < 256; b++) {
			for (int c = 0; c < 256; c++) {
				for (int d = 0; d < 256; d++) {
					uint64_t s = start;
					fnv_forward(&s, (uint8_t)a);
					fnv_forward(&s, (uint8_t)b);
					fnv_forward(&s, (uint8_t)c);
					fnv_forward(&s, (uint8_t)d);
					*tblptr++ = s;
				}
			}
		}
	}
	printf("done part 1\n");

	std::sort(tbl, tblend);
	printf("done sort\n");

	for (int a = 0; a < 256; a++) {
		printf("a: %d\n", a);
		for (int b = 0; b < 256; b++) {
			for (int c = 0; c < 256; c++) {
				for (int d = 0; d < 256; d++) {
					uint64_t s = end;
					fnv_backward(&s, (uint8_t)a);
					fnv_backward(&s, (uint8_t)b);
					fnv_backward(&s, (uint8_t)c);
					fnv_backward(&s, (uint8_t)d);
					const uint64_t *f = std::lower_bound(tbl, tblend, s);
					if (f != tblend && *f == s) {
						// Cast so the format matches on every platform.
						printf("found %llx\n", (unsigned long long)s);
						printf("a: %d, b: %d, c: %d, d: %d\n", a, b, c, d);
					}
				}
			}
		}
	}
	free(tbl);
	return 0;
}
// g++ brute.cpp -Wall -o brute -Ofast -march=native -mtune=native && ./brute

之後再搜一次前半直接和那個中間值比對就行:

#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
#include <algorithm>

// One FNV-1 round: multiply the state by the 64-bit FNV prime (wrapping at
// 2^64), then XOR the input byte into it.
__attribute__((always_inline)) inline void fnv_forward(uint64_t *s, uint8_t c) {
	const uint64_t fnv_prime = 0x00000100000001b3;
	*s = (*s * fnv_prime) ^ c;
}
// Undo one fnv_forward round: XOR the byte back out, then multiply by the
// inverse of the FNV prime modulo 2^64.
__attribute__((always_inline)) inline void fnv_backward(uint64_t *s,
                                                        uint8_t c) {
	const uint64_t fnv_prime_inv = 0xce965057aff6957b;
	*s = (*s ^ c) * fnv_prime_inv;
}
// Second pass of the meet-in-the-middle attack: enumerate every 4-byte prefix
// from the FNV-1 initial state and report the one(s) that reach the middle
// state found by the first program.
int main() {
	/* 
    from brute:
    found f254947ed944e831
    a: 11, b: 201, c: 84, d: 154
    */
	const uint64_t start = 0xcbf29ce484222325;
	const uint64_t target = 0xf254947ed944e831;
	// Loop counters are int so that `< 256` is a real termination condition.
	for (int a = 0; a < 256; a++) {
		printf("a: %d\n", a);
		for (int b = 0; b < 256; b++) {
			for (int c = 0; c < 256; c++) {
				for (int d = 0; d < 256; d++) {
					uint64_t s = start;
					fnv_forward(&s, (uint8_t)a);
					fnv_forward(&s, (uint8_t)b);
					fnv_forward(&s, (uint8_t)c);
					fnv_forward(&s, (uint8_t)d);
					if (s == target) {
						// Cast so the format matches on every platform.
						printf("found %llx\n", (unsigned long long)s);
						printf("a: %d, b: %d, c: %d, d: %d\n", a, b, c, d);
					}
				}
			}
		}
	}
	printf("done\n");
    /*
    and we get this:
    found f254947ed944e831
    a: 104, b: 41, c: 244, d: 81

    so the result is (prefix bytes, then the suffix bytes from the first
    program in reverse order):
    [104, 41, 244, 81, 154, 84, 201, 11]
    */
	return 0;
}

而我最後得到的輸入是 6829f4519a54c90b,輸進 remote 拿到 flag: DUCTF{sorry_but_your_cryptographic_hash_function_is_in_another_castle}

另外據作者所說這是 intended solutions 之一,另一個方法是 LLL。

hhhhh

#!/usr/bin/env python3
from os import getenv as hhhhhhh
from hashlib import md5 as hhhhh

def hhhhhh(hhh):
    """Return the XOR of the MD5 digests of every non-empty prefix of *hhh*.

    The result is 16 bytes long; for empty input it is 16 zero bytes.
    """
    h = hhhhh()  # hhhhh is hashlib.md5 (see the imports)
    hh = bytes([0] * 16)  # running XOR accumulator
    for hhhh in hhh:
        # Feed one more byte, so h.digest() is the MD5 of the prefix so far.
        h.update(bytes([hhhh]))
        # hhhhhhh here is a comprehension-local name; it does not rebind the
        # module-level getenv alias.
        hh = bytes([hhhhhhh ^ hhhhhhhh for hhhhhhh, hhhhhhhh in zip(hh, h.digest())])
    return hh

print('hhh hhh hhhh hhh hhhhh hhhh hhhh hhhhh hhhh hh hhhhhh hhhh?')

# The input is read as hex and decoded to raw bytes.
h = bytes.fromhex(input('h: '))

# The flag (FLAG environment variable) is printed only when the 16-byte
# result equals sixteen 'h' characters.
if hhhhhh(h) == b'hhhhhhhhhhhhhhhh':
    print('hhhh hhhh, hhhh hh hhhh hhhh:', hhhhhhh('FLAG'))
else:
    print('hhhhh, hhh hhhhh!')

這題 identifier 都被混淆了,稍微翻譯一下那個 function 可得:

def fn(inp):
    """Return the XOR of the MD5 digests of every non-empty prefix of *inp*.

    The result is always 16 bytes; empty input yields 16 zero bytes.
    """
    hasher = md5()
    acc = bytes(16)
    for byte in inp:
        # Extend the hashed prefix by one byte and fold its digest in.
        hasher.update(bytes((byte,)))
        digest = hasher.digest()
        acc = bytes(x ^ y for x, y in zip(acc, digest))
    return acc

這題簡單來說就是要找一組 inp 使得:

(md5(inp[:1]) xor md5(inp[:2]) xor md5(inp[:3]) xor md5(inp[:4]) xor ...)[:16] == b'hhhhhhhhhhhhhhhh'

一個最簡單的想法自然是直接 BFS/DFS 搜尋,但顯然不可能找到。後來想了一下突然想到 md5 可以用 fastcoll 之類的工具快速撞出 collision,那這能在這邊產生什麼幫助呢?

fastcoll 這個工具可以指定 prefix 來做 identical-prefix collision(兩個碰撞訊息共用同一個 prefix)。所以假設我們今天有 $H(m_{1,a})=H(m_{1,b})$,然後用 $m_{1,a}$ 當作 prefix 去找另外兩個 block $m_{2,a}, m_{2,b}$ 產生另一個 collision $H(m_{1,a} || m_{2,a}) = H(m_{1,a} || m_{2,b})$。

這邊因為 merkle damgard 的性質,$m_{1,a}$ 和 $m_{1,b}$ 處理完之後 md5 內部的狀態是相同的,所以 $H(m_{1,b} || m_{2,a}) = H(m_{1,b} || m_{2,b})$ 也成立。而這個概念其實可以很簡單的畫成一張圖:

Loading graph...

所以這邊我們不管選擇 $(m_{1,a},m_{2,b})$ 還是 $(m_{1,b},m_{2,a})$ 得到的 hash 都是一樣的。那這對題目原本的函數 $f(x)$ 有什麼關係呢? 因為 $m_{i,a}$ 和 $m_{i,b}$ 肯定有些 byte 不同,所以 $f(m_{i,a})$ 和 $f(m_{i,b})$ 也會不同。這邊把這個不同記為

$$\Delta_i = f(m_{i,a}) \oplus f(m_{i,b})$$

假設我們今天選擇 $f(m_{1,a},m_{2,a},m_{3,a})$,想要在第二條路的地方改走 $m_{2,b}$ 那條路,那麼有

$$f(m_{1,a},m_{2,b},m_{3,a})=f(m_{1,a},m_{2,a},m_{3,a}) \oplus \Delta_2$$

也就是說在這個圖上如果選擇換一條路走就等於對輸出 xor 一個 $\Delta_i$。而這題的 $f(x)$ 輸出是 128 bits,所以我們可以先 $\forall i\in [1,128]$ 尋找 $(m_{i,a},m_{i,b},\Delta_i)$。然後假設 $c=f(m_{1,a},m_{2,a},\cdots)$ 的話可以得到一個線性系統:

$$c + \sum_{i=1}^{128} x_i \Delta_i = t$$

這邊是在 $\mathbb{F}_2^{128}$(也就是 $\mathbb{F}_2$ 上的 128 維向量空間)下面解的,所以 $x_i \in \{0,1\}$,直接用 gaussian elimination 就能求出解了。而解答的意義就是代表我們在第 $i$ 條路應該選 $m_{i,a}$ 還是 $m_{i,b}$ 的意思。

具體實作上我是先寫出一個腳本生出那些 collision 的資料:

from hashlib import md5 as md5
import subprocess
from tempfile import TemporaryDirectory
import os, pickle
from tqdm import trange

FASTCOLL_BIN = os.path.expanduser("~/workspace/fastcoll/fastcoll")


def fastcoll(prefix=b""):
    """Run the fastcoll binary on *prefix* and return the two colliding suffixes.

    The prefix is written to a scratch directory, FASTCOLL_BIN is invoked
    there, and both output files are read back with the leading
    ``len(prefix)`` bytes removed. Raises ``subprocess.CalledProcessError``
    when the tool exits with a non-zero status.
    """
    command = [FASTCOLL_BIN, "-p", "prefix", "-o", "out1", "-o", "out2"]
    with TemporaryDirectory() as workdir:
        with open(f"{workdir}/prefix", "wb") as fh:
            fh.write(prefix)
        # Relative file names resolve inside the scratch directory via cwd.
        subprocess.run(command, cwd=workdir, stdout=subprocess.DEVNULL, check=True)
        with open(f"{workdir}/out1", "rb") as fh:
            first = fh.read()
        with open(f"{workdir}/out2", "rb") as fh:
            second = fh.read()
    skip = len(prefix)
    return first[skip:], second[skip:]


def fn(inp):
    """Return the XOR of the MD5 digests of every non-empty prefix of *inp*.

    The result is always 16 bytes; empty input yields 16 zero bytes.
    """
    hasher = md5()
    acc = bytes(16)
    for byte in inp:
        # Extend the hashed prefix by one byte and fold its digest in.
        hasher.update(bytes((byte,)))
        digest = hasher.digest()
        acc = bytes(x ^ y for x, y in zip(acc, digest))
    return acc


def xor(x, y):
    """Return the byte-wise XOR of *x* and *y*, truncated to the shorter input."""
    out = bytearray()
    for left, right in zip(x, y):
        out.append(left ^ right)
    return bytes(out)


# check 1
# m1, m2 = fastcoll()
# m11, m12 = fastcoll(m1)
# print(md5(m1 + m11).hexdigest())
# print(md5(m1 + m12).hexdigest())
# print(md5(m2 + m11).hexdigest())
# print(md5(m2 + m12).hexdigest())

# check 2
# m1a, m1b = fastcoll()
# m2a, m2b = fastcoll(m1a)
# x1 = xor(fn(m1a), fn(m1b))
# x2 = xor(fn(m1a + m2a), fn(m1a + m2b))
# assert x2 == xor(fn(m1b + m2a), fn(m1b + m2b))
# cur = fn(m1a + m2a)
# assert xor(cur, x1) == fn(m1b + m2a)
# assert xor(cur, x2) == fn(m1a + m2b)

# Build a chain of MD5 collision pairs and record, for each pair, how swapping
# its two blocks changes fn() (the delta), then save everything for the solver.

# fn() outputs 128 bits and the solver asserts that the delta matrix has rank
# 128, so at least 128 pairs are needed. 128 vectors are not guaranteed to be
# linearly independent: raise this a little, or rerun, if that rank check fails.
NUM_PAIRS = 128

ms = []
xs = []
prev = b""
for _ in trange(NUM_PAIRS):
    print(len(prev))
    # Both blocks collide when appended to the same prefix.
    ma, mb = fastcoll(prev)
    ms.append((ma, mb))
    # Effect on fn() of taking mb instead of ma at this position.
    x = xor(fn(prev + ma), fn(prev + mb))
    xs.append(x)
    # The chain always continues along the "a" branch.
    prev += ma

with open("data.pkl", "wb") as f:
    pickle.dump((ms, xs), f)

# known properties:
# swapping any single block XORs the corresponding delta into the output,
# regardless of the blocks that follow it.
m1a, m1b = ms[0]
m2a, m2b = ms[1]
m3a, m3b = ms[2]
x1 = xs[0]
x2 = xs[1]
x3 = xs[2]
cur = fn(m1a + m2a)
assert xor(cur, x1) == fn(m1b + m2a)
assert xor(cur, x2) == fn(m1a + m2b)
cur = fn(m1a + m2a + m3a)
assert xor(cur, x1) == fn(m1b + m2a + m3a)
assert xor(cur, x2) == fn(m1a + m2b + m3a)
assert xor(cur, x3) == fn(m1a + m2a + m3b)

然後最後再用那些資料求解搞定:

from hashlib import md5 as md5
import subprocess
from tempfile import TemporaryDirectory
import os, pickle
from tqdm import trange
from sage.all import *
from binteger import Bin
from pwn import remote


def fn(inp):
    """Return the XOR of the MD5 digests of every non-empty prefix of *inp*.

    The result is always 16 bytes; empty input yields 16 zero bytes.
    """
    hasher = md5()
    acc = bytes(16)
    for byte in inp:
        # Extend the hashed prefix by one byte and fold its digest in.
        hasher.update(bytes((byte,)))
        digest = hasher.digest()
        acc = bytes(x ^ y for x, y in zip(acc, digest))
    return acc


def xor(x, y):
    """Return the byte-wise XOR of *x* and *y*, truncated to the shorter input."""
    out = bytearray()
    for left, right in zip(x, y):
        out.append(left ^ right)
    return bytes(out)


def byt2bv(b, n):
    """Convert *b* into a vector over GF(2) via ``Bin(b, n=n).list``.

    NOTE(review): presumably ``n`` is the bit width and the vector has ``n``
    entries - confirm against the binteger documentation.
    """
    bits = Bin(b, n=n).list
    field = GF(2)
    return vector(field, bits)


def bv2byt(b):
    """Convert the bit sequence *b* back into bytes via ``Bin(b).bytes``."""
    packed = Bin(b)
    return packed.bytes


# Load the collision pairs and their deltas produced by the generation script.
with open("data.pkl", "rb") as f:
    ms, xs = pickle.load(f)

# Baseline output: the message that takes the first block of every pair.
baseline_msg = b"".join(pair[0] for pair in ms)
cur = byt2bv(fn(baseline_msg), 128)
delta_rows = [byt2bv(x, 128) for x in xs]
mat = matrix(GF(2), delta_rows)
# The deltas have to span all 128 output bits for every target to be reachable.
assert mat.rank() == 128
target = byt2bv(b"h" * 16, 128)
# Solve cur + sol * mat == target over GF(2).
sol = mat.solve_left(target - cur)
print(sol)
# A zero coefficient keeps the first block of a pair, anything else swaps in
# the second block.
chosen = []
for v, (ma, mb) in zip(sol, ms):
    chosen.append(ma if v == 0 else mb)
msg = b"".join(chosen)
print(fn(msg))

# Submit the hex-encoded message to the challenge server and print its reply.
io = remote("2023.ductf.dev", 30003)
io.sendline(msg.hex().encode())
print(io.recvallS().strip())
# DUCTF{hhh.hhh_hh_hhhhhh_hhh,hh_hh?h_hh_hhhh_hhh-hh,hhh-hhhhh_hhhhh_hh.}