Tsayaki

Flag

HTB{th1s_4tt4ck_m4k3s_T34_1n4ppr0pr14t3_f0r_h4sh1ng!}

Intro

File: server.py

from tea import Cipher as TEA
from secret import IV, FLAG
import os

# Number of rounds run() iterates; every round asks for one target ciphertext
# and four previously unused keys that must all encrypt the message to it.
ROUNDS = 10

def show_menu():
    """Print the challenge banner, framed by one blank line above and two below."""
    banner_lines = (
        "",
        "============================================================================================",
        "|| I made this decryption oracle in which I let users choose their own decryption keys.   ||",
        "|| I think that it's secure as the tea cipher doesn't produce collisions (?) ... Right?   ||",
        "|| If you manage to prove me wrong 10 times, you get a special gift.                      ||",
        "============================================================================================",
        "",
    )
    # The empty first/last entries reproduce the leading and trailing newline
    # of the banner text; print() contributes the final newline.
    print("\n".join(banner_lines))

def run():
    """Run the interactive equivalent-key challenge on stdin/stdout.

    A random 20-byte message is generated and shown to the player. In each of
    the ROUNDS rounds the player submits a target ciphertext that was not
    accepted in an earlier round, followed by four 16-byte keys that were never
    submitted before. Every key must encrypt the message (TEA with the secret
    IV) to exactly the target ciphertext. FLAG is printed once all rounds pass.

    Raises:
        SystemExit: on malformed input, a repeated ciphertext or key, a key of
            the wrong length, or a key whose ciphertext differs from the target.
    """
    show_menu()

    server_message = os.urandom(20)
    print(f'Here is my special message: {server_message.hex()}')

    # Sets give O(1) "seen before" checks; only membership is ever needed.
    used_keys = set()
    ciphertexts = set()
    for i in range(ROUNDS):
        print(f'Round {i+1}/{ROUNDS}')
        try:
            ct = bytes.fromhex(input('Enter your target ciphertext (in hex) : '))
            # Explicit raise instead of assert: asserts vanish under `python -O`,
            # which would silently disable these uniqueness checks.
            if ct in ciphertexts:
                raise ValueError('target ciphertext was already used')

            for j in range(4):
                key = bytes.fromhex(input(f'[{i+1}/{j+1}] Enter your encryption key (in hex) : '))
                if len(key) != 16 or key in used_keys:
                    raise ValueError('key must be 16 bytes and not used before')
                used_keys.add(key)
                enc = TEA(key, IV).encrypt(server_message)
                if enc != ct:
                    print(f'Hmm ... close enough, but {enc.hex()} does not look like {ct.hex()} at all! Bye...')
                    raise SystemExit
        except Exception:
            # SystemExit is not an Exception subclass, so the mismatch exit
            # above passes through without also printing 'Nope.'.
            print('Nope.')
            raise SystemExit

        ciphertexts.add(ct)

    print(f'Wait, really? {FLAG}')


# Start the challenge only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    run()

File: tea.py

from Crypto.Util.Padding import pad
from Crypto.Util.number import bytes_to_long as b2l, long_to_bytes as l2b
from enum import Enum

class Mode(Enum):
    """Block-chaining modes a Cipher instance can operate in."""

    # Every block is encrypted independently.
    ECB = 1
    # Every plaintext block is XORed with the previous ciphertext block
    # (the IV for the first block) before being encrypted.
    CBC = 2

class Cipher:
    """TEA block cipher (64-bit blocks, 32 cycles) with ECB and CBC encryption.

    The mode is chosen by the constructor: a truthy ``iv`` selects CBC,
    anything else selects ECB. Only encryption is implemented.
    """

    def __init__(self, key, iv=None):
        """Split ``key`` into big-endian 32-bit words and pick the mode.

        Args:
            key: Key bytes. ``encrypt_block`` reads the first four 4-byte
                words, i.e. a 16-byte key.
            iv: Optional CBC initialisation vector; ``None`` or an empty value
                selects ECB.
        """
        self.BLOCK_SIZE = 64  # block size in bits
        word_len = self.BLOCK_SIZE // 16  # 4 bytes per key word
        self.KEY = [
            int.from_bytes(key[i:i + word_len], 'big')
            for i in range(0, len(key), word_len)
        ]
        self.DELTA = 0x9e3779b9  # added to the running sum once per cycle
        self.IV = iv
        self.mode = Mode.CBC if self.IV else Mode.ECB

    def _xor(self, a, b):
        """Return the bytewise XOR of ``a`` and ``b``, cut to the shorter input."""
        return bytes(x ^ y for x, y in zip(a, b))

    def encrypt(self, msg):
        """Pad ``msg`` to whole blocks and encrypt it in the configured mode.

        Args:
            msg: Plaintext bytes of any length; padded with
                ``Crypto.Util.Padding.pad`` to a multiple of the block length.

        Returns:
            The concatenated ciphertext blocks as bytes.
        """
        block_len = self.BLOCK_SIZE // 8
        msg = pad(msg, block_len)
        blocks = [msg[i:i + block_len] for i in range(0, len(msg), block_len)]

        out = []
        if self.mode == Mode.ECB:
            out = [self.encrypt_block(pt) for pt in blocks]
        elif self.mode == Mode.CBC:
            prev = self.IV
            for pt in blocks:
                # Chain: each block is masked with the previous ciphertext block.
                prev = self.encrypt_block(self._xor(prev, pt))
                out.append(prev)
        return b''.join(out)

    def encrypt_block(self, msg):
        """Encrypt one 8-byte block with 32 TEA cycles.

        Args:
            msg: The block; bytes 0-3 and 4-7 are read as big-endian words.

        Returns:
            Exactly ``BLOCK_SIZE // 8`` ciphertext bytes, including any
            leading zero bytes.
        """
        m0 = int.from_bytes(msg[:4], 'big')
        m1 = int.from_bytes(msg[4:], 'big')
        K = self.KEY
        msk = (1 << (self.BLOCK_SIZE // 2)) - 1

        s = 0
        for _ in range(32):
            # s and the shifted terms are left unreduced; masking m0/m1 after
            # each addition keeps only the low 32 bits, which is all that matters.
            s += self.DELTA
            m0 += ((m1 << 4) + K[0]) ^ (m1 + s) ^ ((m1 >> 5) + K[1])
            m0 &= msk
            m1 += ((m0 << 4) + K[2]) ^ (m0 + s) ^ ((m0 >> 5) + K[3])
            m1 &= msk

        m = ((m0 << (self.BLOCK_SIZE // 2)) + m1) & ((1 << self.BLOCK_SIZE) - 1)  # m = m0 || m1

        # Fixed-width conversion: a minimal-length encoding would drop leading
        # zero bytes, shortening the block and corrupting the CBC chaining value.
        return m.to_bytes(self.BLOCK_SIZE // 8, 'big')

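Solution

TEA has equivalent keys: flipping the most significant bit of both K[0] and K[1], or of both K[2] and K[3], leaves the ciphertext unchanged, so every key belongs to a group of four keys that encrypt identically. The script below prints ten such groups together with their common ciphertext.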

thanks-greysneakthief.png

from tea import Cipher

# CBC initialisation vector used to compute the ciphertexts printed below.
# It must equal the server's secret IV for those ciphertexts to be accepted;
# presumably recovered from the remote instance beforehand (TODO confirm how).
IV = bytes.fromhex('0dddd2773cf4b908')

from Crypto.Util.number import bytes_to_long as b2l, long_to_bytes as l2b

def add_bytes(a, b):
    """Add two big-endian byte strings as unsigned integers.

    Args:
        a: First operand, big-endian bytes.
        b: Second operand, big-endian bytes.

    Returns:
        The sum as big-endian bytes, at least as wide as the wider operand
        (and at least one byte). A carry out of the top byte extends the
        result by one byte instead of being dropped.
    """
    total = int.from_bytes(a, 'big') + int.from_bytes(b, 'big')
    # Keep the operands' width: a minimal-length encoding would strip leading
    # zero bytes, e.g. turning a 16-byte key with a zero top byte into 15 bytes.
    width = max(len(a), len(b), (total.bit_length() + 7) // 8, 1)
    return total.to_bytes(width, 'big')

# A 16-byte key is four big-endian 32-bit words K[0] | K[1] | K[2] | K[3].
# flip_k0 and flip_k1 are two settings of the whole key that differ only in the
# top bit of both K[0] and K[1].
flip_k0 = bytes.fromhex('90000000' + '10000000' * 3)
flip_k1 = bytes.fromhex('10000000' + '90000000' + '10000000' * 2)
# flip_k2 and flip_k3 add 0x80000000 to K[2] or to K[3] respectively.
flip_k2 = bytes.fromhex('00000000' * 2 + '80000000' + '00000000')
flip_k3 = bytes.fromhex('00000000' * 3 + '80000000')

# Every (K[0], K[1]) choice combined with every (K[2], K[3]) choice, in the
# order (k0, k2), (k0, k3), (k1, k2), (k1, k3): four keys per group.
flips = [add_bytes(front, back)
         for front in (flip_k0, flip_k1)
         for back in (flip_k2, flip_k3)]

# Hex of the random message the server printed for this connection.
pt = bytes.fromhex(input('server_message: '))


# One group of four keys per round. The round index is added to the key so
# that no key repeats across rounds. For each group the shared ciphertext is
# printed first, then the four keys, in the order the server asks for them.
for i in range(10):
    base = add_bytes(bytes.fromhex('00' * 16), l2b(i))
    for j, flip in enumerate(flips):
        key = add_bytes(base, flip)
        ct = Cipher(key, IV).encrypt(pt)
        if j != 0:
            # Sanity check: all keys of a group must collide.
            assert ct_old == ct
        else:
            ct_old = ct
            print(ct.hex())
        print(key.hex())

References

Author: calx

Created: 2024-05-30 Thu 05:24

Emacs 29.3 (Org mode 9.6.15)