DownUnderCTF 2023 Writeups

今年在 ${CyStick} 中參加了今年的 DownUnderCTF,解了幾個題目而已。

Crypto

apbq rsa i

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from Crypto.Util.number import getPrime, bytes_to_long
from random import randint

# Challenge source: textbook RSA with a 2048-bit modulus, plus two leaked
# linear combinations of the secret primes.
p = getPrime(1024)
q = getPrime(1024)
n = p * q
e = 0x10001

# Each hint is a*p + b*q.  randint is inclusive on both ends, so a is at
# most 2**12 (small enough to brute-force) while b can be as large as 2**312.
hints = []
for _ in range(2):
    a, b = randint(0, 2**12), randint(0, 2**312)
    hints.append(a * p + b * q)

FLAG = open('flag.txt', 'rb').read().strip()
c = pow(bytes_to_long(FLAG), e, n)
print(f'{n = }')
print(f'{c = }')
print(f'{hints = }')

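RSA,並額外提供了兩個 hint:

$$h_i = a_i p + b_i q \quad (i = 1, 2)$$

其中 $a_i \le 2^{12}$ 很小,所以爆一下 $a_1, a_2$ 之後求

$$a_2 h_1 - a_1 h_2 = (a_2 b_1 - a_1 b_2)\, q$$

得到 $q$ 的倍數,和 $n$ 取 gcd 結束。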

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from output import *
from itertools import product
from math import gcd
from tqdm import tqdm

h1, h2 = hints

# Brute-force the two small coefficients.  With h_i = a_i*p + b_i*q,
#   a2*h1 - a1*h2 == (a2*b1 - a1*b2) * q,
# so for the right pair (a, b) == (a2, a1) the gcd with n reveals q.
# The challenge draws the coefficients with randint(0, 2**12), which is
# inclusive on both ends, hence the "+ 1" in the range.
for a, b in product(range(2**12 + 1), repeat=2):
    q = gcd(a * h1 - b * h2, n)
    if q != 1 and q < n:
        print(q, n)
        break
else:
    raise SystemExit("no nontrivial factor of n found")
# Value found on the author's run, kept for reference:
# q = 131749193259488372734882395267267400452018470526669557625739671139987328020291297864452866159092612057179256194270162132453379915244109293125925735503860433762913331882646107732350881847176645056234707603429859084687993766987515407413263165507571230913665811470277548803428982720272589512297691853544766981321
p = n // q
e = 0x10001
d = pow(e, -1, (p - 1) * (q - 1))
m = pow(c, d, n)
flag = m.to_bytes(1024, "big").strip(b"\x00")
print(flag)
# DUCTF{gcd_1s_a_g00d_alg0r1thm_f0r_th3_t00lbox}

apbq rsa ii

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from Crypto.Util.number import getPrime, bytes_to_long
from random import randint

# Challenge source: same setup as "apbq rsa i", but with three hints and
# both coefficients drawn from the full [0, 2**312] range.
p = getPrime(1024)
q = getPrime(1024)
n = p * q
e = 0x10001

# Each hint is a*p + b*q with a and b both up to 2**312 (randint is
# inclusive), so neither coefficient can be brute-forced any more.
hints = []
for _ in range(3):
    a, b = randint(0, 2**312), randint(0, 2**312)
    hints.append(a * p + b * q)

FLAG = open('flag.txt', 'rb').read().strip()
c = pow(bytes_to_long(FLAG), e, n)
print(f'{n = }')
print(f'{c = }')
print(f'{hints = }')

和前一題幾乎一樣,不過多了一個 hint,且 $a_i$ 的範圍變大了(和 $b_i$ 一樣可以到 $2^{312}$)。

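總之有三條等式 $h_i = a_i p + b_i q$,先把未知的 $p, q$ 消除得到:

$$a_1 a_3 b_2 h_1 - a_1 a_2 b_3 h_1 - a_1 a_3 b_1 h_2 + a_1^2 b_3 h_2 + a_1 a_2 b_1 h_3 - a_1^2 b_2 h_3 = 0$$

因為 $h_i$ 比較大,可以嘗試用 LLL 求出 short vector:

$$(a_1 a_3 b_2 - a_1 a_2 b_3,\ -a_1 a_3 b_1 + a_1^2 b_3,\ a_1 a_2 b_1 - a_1^2 b_2)$$

不過因為 $a_1$ 在三個都有出現,所以實際上得到的是:

$$(a_3 b_2 - a_2 b_3,\ -a_3 b_1 + a_1 b_3,\ a_2 b_1 - a_1 b_2)$$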

還有可能是三個的 gcd 不為 1,不過大不了爆一下就行

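因為 LLL 所以可能還有正負號的問題要爆,不過這個好處理。總之把 LLL 得到的向量記為 $(t, u, v)$,我們可以拿六個未知數 $a_1, a_2, a_3, b_1, b_2, b_3$ 求 groebner basis,發現說裡面有兩個等式:

$$t a_1 + u a_2 + v a_3 = 0$$

$$t b_1 + u b_2 + v b_3 = 0$$

所以拿 $(t, u, v)$ 做 LLL 再小爆一下就是 $(a_1, a_2, a_3)$ 和 $(b_1, b_2, b_3)$ 了,之後就和前一題一樣 gcd 搞定。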

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from sage.all import *
from Crypto.Util.number import getPrime, bytes_to_long
from random import randint
from itertools import product, combinations
from output import *


# With h_i = a_i*p + b_i*q, eliminating p from pairs of hints gives:
# assert a2 * h1 - a1 * h2 == (a2 * b1 - a1 * b2) * q
# assert a3 * h1 - a1 * h3 == (a3 * b1 - a1 * b3) * q
# and cross-multiplying eliminates q as well:
# assert (a3 * b1 - a1 * b3) * (a2 * h1 - a1 * h2) == (a2 * b1 - a1 * b2) * (a3 * h1 - a1 * h3)
# so: a1*a3*b2*h1 - a1*a2*b3*h1 - a1*a3*b1*h2 + a1^2*b3*h2 + a1*a2*b1*h3 - a1^2*b2*h3 == 0
# we try to use LLL to find them

h1, h2, h3 = hints
L = matrix(hints).T.augment(matrix.identity(3))
# Scale the hint column so that reduced (short) rows have a 0 there,
# i.e. their remaining three entries are a relation between h1, h2, h3.
L[:, 0] *= n
L = L.LLL()
for v in L:
    print(v)
    print([x.bit_length() for x in v])

# the expected vector is (a1*a3*b2-a1*a2*b3, -a1*a3*b1+a1**2*b3, a1*a2*b1-a1**2*b2)
# but we can see that a1 divides all of them
# assuming gcd(gcd((a1*a3*b2-a1*a2*b3), (-a1*a3*b1+a1**2*b3)), (a1*a2*b1-a1**2*b2)) == 1 here
# the shortest vector should be (a3*b2-a2*b3, -a3*b1+a1*b3, a2*b1-a1*b2)

_, t, u, v = L[0]
# therefore this should hold
# assert abs(a3*b2 - a2*b3) == abs(t)
# assert abs(a3*b1 - a1*b3) == abs(u)
# assert abs(a2*b1 - a1*b2) == abs(v)
a1s, a2s, a3s, b1s, b2s, b3s = QQ["a1,a2,a3,b1,b2,b3"].gens()


def step2(f):
    """Recover (a1, a2, a3) from a linear relation and decrypt the flag.

    f is a Groebner-basis element of the form k1*a1 + k2*a2 + k3*a3 == 0.
    Reads the module-level n, c, h1 and h2.  On success the flag is printed
    and the process exits; otherwise the function returns None.
    """
    # this f is in the form of k1*a1+k2*a2+k3*a3==0
    # for some reason, k1*b1+k2*b2+k3*b3==0 also holds
    # use LLL to find it
    print("=" * 40)
    print(f)
    L = matrix(f.coefficients()).T.augment(matrix.identity(3))
    L[:, 0] *= n
    L = L.LLL()
    print(L[0])
    print(L[1])
    v1 = L[0]
    v2 = L[1]
    # The two short rows span the solutions; the real coefficient vectors are
    # small combinations of them with every entry inside the randint range.
    xs = []
    for c1, c2 in product((-2, -1, 0, 1, 2), repeat=2):
        cand = c1 * v1 + c2 * v2
        _, x1, x2, x3 = cand
        if all(0 <= x <= 2**312 for x in (x1, x2, x3)):
            xs.append((x1, x2, x3))
    # we don't know which candidate is (a1, a2, a3) and which is (b1, b2, b3),
    # so test every candidate on its own: only its first two entries are used.
    # (Pairing candidates with combinations() skipped the last one and never
    # looked at the second member of the pair.)
    for a1r, a2r, _ in xs:
        q = gcd(a2r * h1 - a1r * h2, n)
        if 1 < q < n:
            p = n // q
            e = 0x10001
            d = inverse_mod(e, (p - 1) * (q - 1))
            m = pow(c, d, n)
            flag = int(m).to_bytes(1024, "big").strip(b"\x00")
            print(flag)
            exit()


# LLL only determines (t, u, v) up to sign, so try every sign choice;
# step2 exits the process as soon as one of them yields the factorisation.
for sign in product((-1, 1), repeat=3):
    I = ideal(
        [
            a3s * b2s - a2s * b3s + sign[0] * t,
            a3s * b1s - a1s * b3s + sign[1] * u,
            a2s * b1s - a1s * b2s + sign[2] * v,
        ]
    )
    if I.dimension() != -1:
        print(sign)
        print("dim", I.dimension())
        step2(I.groebner_basis()[1])
# DUCTF{0rtho_l4tt1c3_1s_a_fun_and_gr34t_t3chn1que_f0r_the_t00lbox!}

fnv

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#!/usr/bin/env python3

import os

def fnv1(s):
    """Return the 64-bit FNV-1 hash of the byte sequence *s*.

    Starts from 0xcbf29ce484222325 and, for every byte, multiplies the state
    by 0x100000001b3 (truncated to 64 bits) and then XORs the byte in.
    """
    h = 0xcbf29ce484222325
    for b in s:
        h *= 0x00000100000001b3
        h &= 0xffffffffffffffff
        h ^= b
    return h

TARGET = 0x1337133713371337

# Challenge front-end: read a hex string from stdin and reveal the flag only
# if its FNV-1 hash equals TARGET.
print("Welcome to FNV!")
print(f"Please enter a string in hex that hashes to 0x{TARGET:016x}:")
s = bytearray.fromhex(input())
if fnv1(s) == TARGET:
    print('Well done!')
    print(os.getenv('FLAG'))
else:
    print('Try again... :(')

簡單來說是找這個 fnv1 hash function 的 preimage。

我是猜測說這個 hash function 高機率是 bijective 的,所以理論上有個 8 bytes 的輸入可以 hash 到 TARGET。而這個 hash function 在處理時是 byte-by-byte 的,所以我想說可以 MITM,前後各搜 4 bytes 解決。

這部分當然要用 C++ 實作才行,而中間的 table 因為要 $2^{32} \times 8$ bytes = 32 GiB 才行,所以要在 ram 夠大的系統上跑才行。

或是可以用 mmap 之類的,不過應該會慢些

這邊是我的第一個腳本,先搜前半把 table 建好,之後 sort 之後後半每個搜的時候用 binary search 找,找到就把那個中間值和後半的輸入印出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
#include <algorithm>

// One forward FNV-1 round on *s: multiply the state by 0x100000001b3
// (wrapping at 64 bits), then XOR in the input byte c.
__attribute__((always_inline)) inline void fnv_forward(uint64_t *s, uint8_t c) {
  const uint64_t multiplied = *s * 0x00000100000001b3;
  *s = multiplied ^ c;
}
// Undo one fnv_forward round on *s: XOR the byte c back out, then multiply by
// 0xce965057aff6957b, the inverse of 0x100000001b3 modulo 2^64.
__attribute__((always_inline)) inline void fnv_backward(uint64_t *s,
                                                        uint8_t c) {
  const uint64_t unxored = *s ^ c;
  *s = unxored * 0xce965057aff6957b;
}
// Meet-in-the-middle preimage search for an 8-byte FNV-1 input.
// Phase 1 tabulates the state reached after every 4-byte prefix, phase 2
// sorts the table, phase 3 walks 4 bytes backwards from the target hash and
// binary-searches each resulting state.  Every match prints the shared middle
// state and the four trailing bytes (a is the LAST input byte, d the fifth).
// Returns 0 on completion, 1 if the table cannot be allocated.
int main() {
  // 2^32 entries of 8 bytes each = 32 GiB; needs a 64-bit size_t.
  const size_t tbl_count = (size_t)1 << 32;
  const size_t tbl_size = tbl_count * sizeof(uint64_t);
  // Disk-backed alternative for machines without enough RAM:
  // int fd = open("./table", O_RDWR | O_CREAT, 0644);
  // ftruncate(fd, tbl_size);
  // uint64_t *tbl = (uint64_t *)mmap(NULL, tbl_size, PROT_READ | PROT_WRITE,
  //                                  MAP_SHARED, fd, 0);
  // uint64_t *tblend = tbl + (1L << 32);

  uint64_t *tbl = (uint64_t *)malloc(tbl_size);
  if (tbl == NULL) {
    // A 32 GiB request can easily fail; bail out instead of writing to NULL.
    perror("malloc");
    return 1;
  }
  uint64_t *tblend = tbl + tbl_count;

  const uint64_t start = 0xcbf29ce484222325;
  const uint64_t end = 0x1337133713371337;

  // meet in the middle attack 8 bytes
  // 4 bytes forward, 4 bytes backward
  // Counters are int so that "< 256" actually terminates the loops; the
  // partial states are carried down so each level does one round only.
  uint64_t *tblptr = tbl;
  for (int a = 0; a < 256; a++) {
    printf("a: %d\n", a);
    uint64_t sa = start;
    fnv_forward(&sa, (uint8_t)a);
    for (int b = 0; b < 256; b++) {
      uint64_t sb = sa;
      fnv_forward(&sb, (uint8_t)b);
      for (int c = 0; c < 256; c++) {
        uint64_t sc = sb;
        fnv_forward(&sc, (uint8_t)c);
        for (int d = 0; d < 256; d++) {
          uint64_t s = sc;
          fnv_forward(&s, (uint8_t)d);
          *tblptr++ = s;
        }
      }
    }
  }
  printf("done part 1\n");

  std::sort(tbl, tblend);
  printf("done sort\n");

  for (int a = 0; a < 256; a++) {
    printf("a: %d\n", a);
    uint64_t sa = end;
    fnv_backward(&sa, (uint8_t)a);
    for (int b = 0; b < 256; b++) {
      uint64_t sb = sa;
      fnv_backward(&sb, (uint8_t)b);
      for (int c = 0; c < 256; c++) {
        uint64_t sc = sb;
        fnv_backward(&sc, (uint8_t)c);
        for (int d = 0; d < 256; d++) {
          uint64_t s = sc;
          fnv_backward(&s, (uint8_t)d);
          if (std::binary_search(tbl, tblend, s)) {
            // Cast so the format matches on every platform.
            printf("found %llx\n", (unsigned long long)s);
            printf("a: %d, b: %d, c: %d, d: %d\n", a, b, c, d);
          }
        }
      }
    }
  }
  free(tbl);
  return 0;
}
// g++ brute.cpp -Wall -o brute -Ofast -march=native -mtune=native && ./brute

之後再搜一次前半直接和那個中間值比對就行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
#include <algorithm>

// One forward FNV-1 round on *s: multiply the state by 0x100000001b3
// (wrapping at 64 bits), then XOR in the input byte c.
__attribute__((always_inline)) inline void fnv_forward(uint64_t *s, uint8_t c) {
  const uint64_t multiplied = *s * 0x00000100000001b3;
  *s = multiplied ^ c;
}
// Undo one fnv_forward round on *s: XOR the byte c back out, then multiply by
// 0xce965057aff6957b, the inverse of 0x100000001b3 modulo 2^64.
__attribute__((always_inline)) inline void fnv_backward(uint64_t *s,
                                                        uint8_t c) {
  const uint64_t unxored = *s ^ c;
  *s = unxored * 0xce965057aff6957b;
}
// Second pass of the meet-in-the-middle search: re-enumerate every 4-byte
// prefix and print the one whose forward state equals the middle value
// reported by the first program.
int main() {
  /*
  from brute:
  found f254947ed944e831
  a: 11, b: 201, c: 84, d: 154
  */
  const uint64_t start = 0xcbf29ce484222325;
  const uint64_t target = 0xf254947ed944e831;
  // Counters are int so that "< 256" actually terminates the loops; the
  // partial states are carried down so each level does one round only.
  for (int a = 0; a < 256; a++) {
    printf("a: %d\n", a);
    uint64_t sa = start;
    fnv_forward(&sa, (uint8_t)a);
    for (int b = 0; b < 256; b++) {
      uint64_t sb = sa;
      fnv_forward(&sb, (uint8_t)b);
      for (int c = 0; c < 256; c++) {
        uint64_t sc = sb;
        fnv_forward(&sc, (uint8_t)c);
        for (int d = 0; d < 256; d++) {
          uint64_t s = sc;
          fnv_forward(&s, (uint8_t)d);
          if (s == target) {
            // Cast so the format matches on every platform.
            printf("found %llx\n", (unsigned long long)s);
            printf("a: %d, b: %d, c: %d, d: %d\n", a, b, c, d);
          }
        }
      }
    }
  }
  printf("done\n");
  /*
  and we get this:
  found f254947ed944e831
  a: 104, b: 41, c: 244, d: 81

  so the result is (prefix bytes in order, then the first program's
  bytes in reverse, because it walked backwards from the end):
  [104, 41, 244, 81, 154, 84, 201, 11]
  */
  return 0;
}

而我最後得到的輸入是 6829f4519a54c90b,輸進 remote 拿到 flag: DUCTF{sorry_but_your_cryptographic_hash_function_is_in_another_castle}

另外據作者所說這是 intended solutions 之一,另一個方法是 LLL。

hhhhh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#!/usr/bin/env python3
from os import getenv as hhhhhhh
from hashlib import md5 as hhhhh

def hhhhhh(hhh):
    """Return the XOR of the MD5 digests of every non-empty prefix of *hhh*.

    The hash object is fed one byte at a time and digested after each byte;
    the 16-byte digests are XORed into an accumulator that starts at zero.
    """
    h = hhhhh()
    hh = bytes([0] * 16)
    for hhhh in hhh:
        h.update(bytes([hhhh]))
        hh = bytes([hhhhhhh ^ hhhhhhhh for hhhhhhh, hhhhhhhh in zip(hh, h.digest())])
    return hh

# Challenge front-end: read hex input and reveal the flag only if the
# prefix-XOR hash of the decoded bytes is exactly sixteen 'h' characters.
print('hhh hhh hhhh hhh hhhhh hhhh hhhh hhhhh hhhh hh hhhhhh hhhh?')

h = bytes.fromhex(input('h: '))

if hhhhhh(h) == b'hhhhhhhhhhhhhhhh':
    print('hhhh hhhh, hhhh hh hhhh hhhh:', hhhhhhh('FLAG'))
else:
    print('hhhhh, hhh hhhhh!')

這題 identifier 都被混淆了,稍微翻譯一下那個 function 可得:

1
2
3
4
5
6
7
def fn(inp):
    """Return the XOR of md5(inp[:1]), md5(inp[:2]), ..., md5(inp) as 16 bytes.

    An empty input yields sixteen zero bytes.
    """
    h = md5()
    ret = bytes([0] * 16)
    for b in inp:
        # Extend the running hash by one byte and fold that prefix's digest in.
        h.update(bytes([b]))
        ret = bytes([a ^ b for a, b in zip(ret, h.digest())])
    return ret

這題簡單來說就是要找一組 inp 使得:

1
(md5(inp[:1]) xor md5(inp[:2]) xor md5(inp[:3]) xor md5(inp[:4]) xor ...)[:16] == b'hhhhhhhhhhhhhhhh'

一個最簡單的想法自然是直接 BFS/DFS 搜尋,但顯然不可能找到。後來想了一下突然想到 md5 可以用 fastcoll 之類的工具快速撞出 collision,那這能在這邊產生什麼幫助呢?

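fastcoll 這個工具是可以指定 prefix 做 identical-prefix collision 的,所以假設我們今天有 $\mathrm{md5}(m_{1a}) = \mathrm{md5}(m_{1b})$,然後用 $m_{1a}$ 當作 prefix 去找另外兩個 block 產生另一個 collision $\mathrm{md5}(m_{1a} \| m_{2a}) = \mathrm{md5}(m_{1a} \| m_{2b})$。

這邊因為 merkle damgard 的性質,$m_{1a}$ 和 $m_{1b}$ 在 md5 內部的狀態都是相同的,所以 $\mathrm{md5}(m_{1b} \| m_{2a}) = \mathrm{md5}(m_{1b} \| m_{2b})$ 也成立。而這個概念其實可以很簡單的畫成一張圖: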

stateDiagram-v2
    [*] --> State1: m_1a
    [*] --> State1: m_1b
    State1 --> State2: m_2a
    State1 --> State2: m_2b
    State2 --> [*]: m_3a
    State2 --> [*]: m_3b

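所以這邊我們不管選擇 $m_{ia}$ 還是 $m_{ib}$ 得到的 hash 都是一樣的。那這對題目原本的 fn 有什麼關係呢? 因為 $m_{1a}$ 和 $m_{1b}$ 肯定有些 byte 不同所以 $\mathrm{fn}(m_{1a})$ 和 $\mathrm{fn}(m_{1b})$ 也會不同。這邊把這個不同記為 $x_1 = \mathrm{fn}(m_{1a}) \oplus \mathrm{fn}(m_{1b})$,同理第 $i$ 組 collision 的差記為 $x_i$。

假設我們今天選擇 $m_{1a} \| m_{2a} \| m_{3a}$,想要在第二條路的地方改走 $m_{2b}$ 那條路,那麼有

$$\mathrm{fn}(m_{1a} \| m_{2b} \| m_{3a}) = \mathrm{fn}(m_{1a} \| m_{2a} \| m_{3a}) \oplus x_2$$

也就是說在這個圖上如果選擇換一條路走就等於對輸出 xor 一個 $x_i$。而這題的 fn 輸出是 128 bits,所以我們可以先用 fastcoll 尋找 $x_1, x_2, \ldots, x_{128}$。然後假設 $c = \mathrm{fn}(m_{1a} \| m_{2a} \| \cdots \| m_{128a})$ 的話可以得到一個線性系統:

$$c \oplus \bigoplus_{i=1}^{128} s_i x_i = \texttt{hhhhhhhhhhhhhhhh}$$

這邊是在 $\mathbb{F}_2$ 下面解的,所以 $s_i \in \{0, 1\}$,直接用 gaussian elimination 就能求出解了。而解答的意義就是代表我們在第 $i$ 條路應該選 $m_{ia}$ 還是 $m_{ib}$ 的意思。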

具體實作上我是先寫出一個腳本生出那些 collision 的資料:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from hashlib import md5 as md5
import subprocess
from tempfile import TemporaryDirectory
import os, pickle
from tqdm import trange

FASTCOLL_BIN = os.path.expanduser("~/workspace/fastcoll/fastcoll")


def fastcoll(prefix=b""):
    """Run the external fastcoll binary and return one MD5 collision pair.

    *prefix* is written to a file in a temporary directory and passed to the
    tool with ``-p``; the two output files are read back and the common
    prefix is stripped, so the result is the pair of colliding suffixes.

    Raises subprocess.CalledProcessError if the tool exits with an error.
    """
    # "workdir" rather than "dir" so the builtin is not shadowed.
    with TemporaryDirectory() as workdir:
        with open(os.path.join(workdir, "prefix"), "wb") as f:
            f.write(prefix)
        subprocess.run(
            [
                FASTCOLL_BIN,
                "-p",
                "prefix",
                "-o",
                "out1",
                "-o",
                "out2",
            ],
            cwd=workdir,
            stdout=subprocess.DEVNULL,
            check=True,
        )
        with open(os.path.join(workdir, "out1"), "rb") as f:
            m1 = f.read()
        with open(os.path.join(workdir, "out2"), "rb") as f:
            m2 = f.read()
    return m1[len(prefix) :], m2[len(prefix) :]


def fn(inp):
    """Return the XOR of md5(inp[:1]), md5(inp[:2]), ..., md5(inp) as 16 bytes.

    An empty input yields sixteen zero bytes.
    """
    h = md5()
    ret = bytes([0] * 16)
    for b in inp:
        # Extend the running hash by one byte and fold that prefix's digest in.
        h.update(bytes([b]))
        ret = bytes([a ^ b for a, b in zip(ret, h.digest())])
    return ret


def xor(x, y):
    """Return the byte-wise XOR of *x* and *y*, truncated to the shorter one."""
    return bytes([a ^ b for a, b in zip(x, y)])


# check 1
# m1, m2 = fastcoll()
# m11, m12 = fastcoll(m1)
# print(md5(m1 + m11).hexdigest())
# print(md5(m1 + m12).hexdigest())
# print(md5(m2 + m11).hexdigest())
# print(md5(m2 + m12).hexdigest())

# check 2
# m1a, m1b = fastcoll()
# m2a, m2b = fastcoll(m1a)
# x1 = xor(fn(m1a), fn(m1b))
# x2 = xor(fn(m1a + m2a), fn(m1a + m2b))
# assert x2 == xor(fn(m1b + m2a), fn(m1b + m2b))
# cur = fn(m1a + m2a)
# assert xor(cur, x1) == fn(m1b + m2a)
# assert xor(cur, x2) == fn(m1a + m2b)

# Build a chain of collisions: each round extends the running prefix with the
# "a" message and records both messages plus the XOR difference that picking
# the "b" message instead would make to fn's output.
# NOTE(review): the solver asserts that the difference matrix has rank 128,
# so the real run presumably used at least 128 rounds; 3 is only enough for
# the sanity checks below — confirm before reusing.
ms = []
xs = []
prev = b""
for _ in trange(3):
    print(len(prev))
    ma, mb = fastcoll(prev)
    ms.append((ma, mb))
    x = xor(fn(prev + ma), fn(prev + mb))
    xs.append(x)
    prev += ma

with open("data.pkl", "wb") as f:
    pickle.dump((ms, xs), f)

# known properties:
# swapping any single step from "a" to "b" XORs exactly that step's
# difference into the output, independently of the other steps.
m1a, m1b = ms[0]
m2a, m2b = ms[1]
m3a, m3b = ms[2]
x1 = xs[0]
x2 = xs[1]
x3 = xs[2]
cur = fn(m1a + m2a)
assert xor(cur, x1) == fn(m1b + m2a)
assert xor(cur, x2) == fn(m1a + m2b)
cur = fn(m1a + m2a + m3a)
assert xor(cur, x1) == fn(m1b + m2a + m3a)
assert xor(cur, x2) == fn(m1a + m2b + m3a)
assert xor(cur, x3) == fn(m1a + m2a + m3b)

然後最後再用那些資料求解搞定:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from hashlib import md5 as md5
import subprocess
from tempfile import TemporaryDirectory
import os, pickle
from tqdm import trange
from sage.all import *
from binteger import Bin
from pwn import remote


def fn(inp):
    """Return the XOR of md5(inp[:1]), md5(inp[:2]), ..., md5(inp) as 16 bytes.

    An empty input yields sixteen zero bytes.
    """
    h = md5()
    ret = bytes([0] * 16)
    for b in inp:
        # Extend the running hash by one byte and fold that prefix's digest in.
        h.update(bytes([b]))
        ret = bytes([a ^ b for a, b in zip(ret, h.digest())])
    return ret


def xor(x, y):
    """Return the byte-wise XOR of *x* and *y*, truncated to the shorter one."""
    return bytes([a ^ b for a, b in zip(x, y)])


def byt2bv(b, n):
    """Convert the bytes *b* into a vector over GF(2) built from Bin(b, n=n).list.

    NOTE(review): presumably n is the bit width and the vector has n entries —
    confirm against the binteger documentation.
    """
    return vector(GF(2), Bin(b, n=n).list)


def bv2byt(b):
    """Convert a bit vector *b* back to bytes via Bin(b).bytes (inverse direction of byt2bv)."""
    return Bin(b).bytes


# Load the collision pairs (ms) and their XOR differences (xs) written by the
# generator script.  pickle is acceptable here only because data.pkl is a
# file we produced ourselves.
with open("data.pkl", "rb") as f:
    ms, xs = pickle.load(f)

# Baseline output when the "a" message is chosen at every step.
cur = byt2bv(fn(b"".join([ma for ma, mb in ms])), 128)
mat = matrix(GF(2), [byt2bv(x, 128) for x in xs])
# The differences must span all 128 output bits for every target to be reachable.
assert mat.rank() == 128
target = byt2bv(b"h" * 16, 128)
# cur+?*mat=target
sol = mat.solve_left(target - cur)
print(sol)
# sol[i] == 1 means "switch step i from its a-message to its b-message".
msg = b"".join(ma if v == 0 else mb for v, (ma, mb) in zip(sol, ms))
print(fn(msg))

io = remote("2023.ductf.dev", 30003)
io.sendline(msg.hex().encode())
print(io.recvallS().strip())
# DUCTF{hhh.hhh_hh_hhhhhh_hhh,hh_hh?h_hh_hhhh_hhh-hh,hhh-hhhhh_hhhhh_hh.}