Home Rolled

Brownie in Motion

2020/05/28

Categories: crypto

TJCTF 2020 - Home Rolled

Problem

It’s that time of year again… time to home roll your own crypto! Since pesky CTF players keep breaking my schemes, this time I obfuscated the source code, so you’ll never be able to figure out what it’s doing. I also used cutting-edge Python 3.8 syntax! Security by obscurity!

nc p1.tjctf.org 8012

Solution

The source is given in this challenge; unfortunately, it is a bit obfuscated:

(The original’s two-space indentation has been fixed)

import os,itertools
def c(l):
    while l():
        yield l
r,e,h,p,v,u=open,any,bool,filter,min,len
b=lambda x:(lambda:x)
w=lambda q:(lambda*x:q(2))
m=lambda*l:[p(e(h,l),key=w(os.urandom)).pop(0)for j in c(lambda:v(l))]
f=lambda l:[b(lambda:m(f(l[:k//2]),f(l[k//2:]))),b(b(l))][(k:=u(l))==1]()()
s=r(__file__).read()
t=lambda p:",".join(p)
o=list(itertools.permutations("rehpvu"))
exec(t(o[sum(map(ord,s))%720])+"="+t(b(o[0])()))
a=r("flag.txt").read()
print("".join(hex((g^x)+(1<<8))[7>>1:]for g,x in zip(f(list(range(256))),map(ord,a))))

Essentially, we are given an encrypted flag. Our goal is obviously to obtain the plaintext.

From this point, there are two paths that can be taken; one of them does not solve it:

  1. Deobfuscating the entire source and wasting 8 hours on an irrelevant part of the challenge

  2. Taking a cursory glance to see what is going on and paying attention only to the very last line

We can proceed with the second.

print("".join(hex((g^x)+(1<<8))[7>>1:]for g,x in zip(f(list(range(256))),map(ord,a))))

We can see from the previous line that a is just the flag. If we print out f(list(range(256))) a few times, we see f is a function that shuffles list(range(256)). It happens that the function does perform a uniform shuffle; however, feel free to spend a couple hours or more trying to find biases in it.

Anyway, we then see that hex((g^x)+(1<<8))[7>>1:] is applied to each byte x of the flag and the corresponding element g of the shuffled list.

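What this does is relatively easy to understand: 1<<8 is 256 and 7>>1 is 3, so adding 256 forces the hex string to be 0x1 followed by exactly two digits, and the slice strips that 0x1 prefix. The result is just the zero-padded, two-digit hex representation of g ^ x.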
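A quick way to check what that expression produces is to compare it against Python's own formatting. This is only a sanity-check sketch; the helper names encode_byte and encode_byte_plain are made up for illustration and do not appear in the challenge source.

```python
def encode_byte(g, x):
    # The expression from the challenge: 1 << 8 is 256 and 7 >> 1 is 3.
    return hex((g ^ x) + (1 << 8))[7 >> 1:]


def encode_byte_plain(g, x):
    # Readable equivalent: g ^ x as exactly two lowercase hex digits.
    return format(g ^ x, "02x")


# Exhaustively confirm the two agree for every pair of byte values.
assert all(
    encode_byte(g, x) == encode_byte_plain(g, x)
    for g in range(256)
    for x in range(256)
)

print(encode_byte(0x74, 0x7A))  # 0x74 ^ 0x7a = 0x0e, printed as "0e"
```

The added 256 is what keeps small values such as 0x0e from collapsing to a single digit, so every plaintext byte maps to exactly two characters of output.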

Thus, the final line can be interpreted as xoring the flag byte-by-byte with a key; this key is a shuffled list containing each integer from 0 to 255 once.
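To make that interpretation concrete, here is a rough local model of the last line. It assumes, as claimed above, that f is a uniform shuffle, and swaps in random.shuffle for the obfuscated f; the encrypt function and the sample flag are invented for illustration.

```python
import random


def encrypt(flag, rng=random):
    # Stand-in for f(list(range(256))): a shuffled key in which every
    # byte value from 0 to 255 appears exactly once.
    key = list(range(256))
    rng.shuffle(key)
    # zip() stops at the shorter input, so only len(flag) key bytes are used.
    return "".join(format(g ^ ord(x), "02x") for g, x in zip(key, flag))


sample_flag = "tjctf{example_flag}"  # made-up placeholder, not the real flag
print(encrypt(sample_flag))
print(encrypt(sample_flag))  # a fresh shuffle, so a different ciphertext
```

XORing a ciphertext produced this way with the plaintext gives back the key bytes that were used, and they are always pairwise distinct.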

Now, it is important to observe that no byte of the key is ever repeated. Given that the flag format is known (tjctf{...}), this can be exploited to find the flag.

For example, call the first byte of the ciphertext c. Because we know the flag begins with t, or 0x74, the first element of the key must be c ^ 0x74. Since each element of the key is unique, c ^ 0x74 cannot appear anywhere else in the key. Thus, for each byte of the flag, we can maintain a list of what it cannot be: for a byte b of the encrypted flag, b ^ c ^ 0x74 cannot be the corresponding byte of the plaintext, because the key byte that would make it so has already been used up. Every connection to the server encrypts the flag with a fresh shuffle, so each new ciphertext rules out more candidates. Since we know seven bytes of the flag (tjctf{ and the closing }), narrowing each remaining byte down to one possible character does not take very long.
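The elimination idea can be tried offline before touching the server. The sketch below is a simulation under stated assumptions: SECRET is a made-up flag, oracle is a hypothetical stand-in for one connection (it draws the first few bytes of a random permutation with random.Random.sample), and only the tjctf{ prefix is treated as known.

```python
import random

SECRET = "tjctf{n0t_th3_r34l_fl4g}"  # made-up flag for the simulation
KNOWN = "tjctf{"


def oracle(rng):
    # Hypothetical stand-in for one connection to the server: XOR the flag
    # with the first bytes of a fresh permutation of 0..255.
    key = rng.sample(range(256), len(SECRET))
    return [k ^ ord(ch) for k, ch in zip(key, SECRET)]


def recover(rng, max_samples=10_000):
    candidates = None
    for _ in range(max_samples):
        enc = oracle(rng)
        if candidates is None:
            # One candidate set per unknown position, initially every byte.
            candidates = [set(range(256)) for _ in enc[len(KNOWN):]]
        # Key bytes pinned down by the known prefix; they cannot recur.
        used = {ord(p) ^ c for p, c in zip(KNOWN, enc)}
        for cands, c in zip(candidates, enc[len(KNOWN):]):
            # c ^ k would be the plaintext if key byte k were reused here.
            cands -= {c ^ k for k in used}
        if all(len(cands) == 1 for cands in candidates):
            return KNOWN + "".join(chr(next(iter(s))) for s in candidates)
    raise RuntimeError("did not converge within max_samples")


print(recover(random.Random(2020)))
```

The true plaintext byte is never discarded, because the key byte actually used at that position differs from every key byte under the prefix. Each sample removes at most six wrong candidates per position, so a few hundred simulated connections are typically needed.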

Here is a solve script; to preserve aesthetic value, only the first 6 bytes, tjctf{, were used in the attack.

from pwn import *

def get_enc_flag():
    r = remote('p1.tjctf.org', 8012)
    res = r.readline()[:-1]
    r.close()

    # Convert hex back to integer list
    return [int(chr(x) + chr(y), 16) for x, y in zip(res[::2], res[1::2])]

# Generate the set of key bytes that have been 'used up'
def generate_used(enc_flag):
    known = 'tjctf{'
    used = set()
    for plain, enc in zip(known, enc_flag):
        used.add(ord(plain) ^ enc)
    return used

# Maintain a list of possible characters, one for each character excluding tjctf{
possible_chars = [set(range(256)) for _ in range(32)]

done = False
while not done:
    done = True

    # Get the list of 'used up' key bytes
    enc_flag = get_enc_flag()
    used = generate_used(enc_flag)

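    # For each unknown position, discard every plaintext byte that would
    # require reusing one of the 'used up' key bytes
    for char_set, enc in zip(possible_chars, enc_flag[6:]):
        for used_num in used:
            char_set.discard(used_num ^ enc)
        # Keep sampling while any position still has several candidates
        if len(char_set) > 1:
            done = False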

# I don't really know how python works but I'll pretend to
print('tjctf{' + ''.join(map(lambda x:chr(x.pop()), possible_chars)))