[2018 X-MAS CTF] Santa's list
#!/usr/bin/python3
from Crypto.PublicKey import RSA
from Crypto.Util.number import *

FLAG = open('flag.txt', 'r').read().strip()

def menu():
    print()
    print('[1] Encrypt')
    print('[2] Decrypt')
    print('[3] Exit')
    return input()


def encrypt(m):
    return pow(m, rsa.e, rsa.n)


def decrypt(c):
    return pow(c, rsa.d, rsa.n)


rsa = RSA.generate(1024)
flag_encrypted = pow(bytes_to_long(FLAG.encode()), rsa.e, rsa.n)
used = [bytes_to_long(FLAG.encode())]

print('Ho, ho, ho and welcome back!')
print('Your list for this year:n')
print('Sarah - Nice')
print('Bob - Nice')
print('Eve - Naughty')
print('Galf - ' + hex(flag_encrypted)[2:])
print('Alice - Nice')
print('Johnny - Naughty')

while True:
    choice = menu()

    if choice == '1':
        m = bytes_to_long(input('nPlaintext > ').strip().encode())
        used.append(m)

        print('nEncrypted: ' + str(encrypt(m)))

    elif choice == '2':
        c = int(input('nCiphertext > ').strip())

        if c == flag_encrypted:
            print('Ho, ho, no...')

        else:
            m = decrypt(c)

            for no in used:
                if m % no == 0:
                    print('Ho, ho, no...')
                    break

            else:
                print('nDecrypted: ' + str(m))

    elif choice == '3':
        print('Till next time.nMerry Christmas!')
        break

문제의 코드는 위와 같습니다. 참고로 e는 65537입니다. 편의상 Flag을 M이라고 했을 때 우선 $M^e$가 주어지고, Encryption/Decryption Oracle도 주어집니다. 단 Decryption Oracle에는 Decrypt 결과 $M$이나 Encryption Oracle에서 넣었던 수의 배수가 나오게 될 경우 결과를 알려주지 않습니다.

 

우선 $N$을 알아내야 뭔가 진행이 될 것 같습니다. $N$은 아래의 식으로부터 이끌어낼 수 있습니다.

 

  • $(P^2)^d \equiv (P)^d(P)^d p \,\, mod \, n$

당연한 식을 왜 써놓은건가 싶겠지만, Decryption Oracle에 $P^2$와 $P$를 넣고나면.. 뭔가 느낌이 오나요? $P^2$를 넣은 결과를 $a$라고 하고, $P$를 넣은 결과를 b라고 할 때 $b^2-a$는 N의 배수이기 때문에 이러한 값들을 여러 개 구해 gcd를 취하면 N을 복원할 수 있습니다.

 

이제 N을 구하고나면 문제는 간단하게 해결할 수 있습니다. Textbook RSA vulnerability라고 검색하면 쉽게 찾을 수 있는 공격 방법일텐데, 임의의 꽤 큰 정수 a를 잡아 $a^e$를 계산하고, 주어진 $M^e$와 곱해 Encryption Oracle에 $(a^eM^e)$를 넣어 $(a^eM^e)^d=aM$를 계산합니다. 그리고 그 값에 $a^{-1}$을 곱하면 M을 복원해낼 수 있습니다.

 

$aM$은 M의 배수이나, a가 꽤 클 경우 $aM \equiv N$은 M의 배수가 아니기 때문에 해당 방식이 통합니다.

from Crypto.Util.number import *
import binascii,socket,sys
import hashlib
sys.setrecursionlimit(1000000)
def gcd(a, b):
  if a == 0: return b
  return gcd(b%a, a)

def egcd(a, b):
  if a == 0:
    return (b, 0, 1)
  g, y, x = egcd(b%a,a)
  return (g, x - (b//a) * y, y)

def inv(a, m):
  g, x, y = egcd(a, m)
  if g != 1:
    raise Exception('No modular inverse')
  return x%m

# x**2 = a (mod m), m is prime
def quad_congruence_equation(a, m, sign):
  assert((m+1)%4 == 0)
  if sign == 0: return pow(a, (m+1)//4, m)
  return m-pow(a, (m+1)//4, m)
# m must satisfies pairwise relatively prime
def crt(a, m):
  n = len(m)
  ret = a[0]
  mod = m[0]
  for i in range(1,n):
    m1 = mod
    mod *= m[i]
    m2inv = inv(m[i],m1)
    m1inv = inv(m1,m[i])
    ret = (ret*m[i]*m2inv+a[i]*m1*m1inv)%mod
  return ret

# 0x6161 -> 'AA'
def i2s(x):
  return long_to_bytes(x).decode("utf-8")

############ my socket ###############
def interactive(socket):
  print("[+] interactive mode")
  while True:
    rr = socket.recv(2**16).decode()
    if not rr:
      print("[!] socket closed")
      return None
    print(rr)
    socket.send((input('> ')+'n').encode())

def remote(ip, port):
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  print("[+] Connecting to {}:{}".format(ip,port))
  sock.connect((ip,port))
  print("[+] Done!")
  return sock

def sendline(socket, msg):
  if type(msg) not in [str,bytes]: msg = str(msg)
  if type(msg) == str: msg = msg.encode()
  if msg[-1] != b'n': msg += b'n'
  socket.send(msg)

def recv(socket):
  return socket.recv(2**16).decode()

def md5_pow(prefix):
  for i in range(100000000):
    if hashlib.md5(str(i).encode()).hexdigest().startswith(prefix):
      return str(i)
###################################

def extract(msg):
  msg = msg[msg.find(': ')+2:]
  return int(msg[:msg.find('n')])
r = remote('199.247.6.180',16001)
intro1 = recv(r)
#print("intro1", intro1)
flag_encrypted = intro1[intro1.find('Galf - ')+7:]
flag_encrypted = int(flag_encrypted[:flag_encrypted.find('n')],16)
print("flag_encrypted : ", hex(flag_encrypted))
n = 0
e = 65537
for p in [19,23,29,31,37,43,47]:
  sendline(r,2)
  recv(r) # Ciphertext > 
  sendline(r,p**2)
  val1 = extract(recv(r))
  sendline(r,2)
  recv(r)
  sendline(r,p)
  val2 = extract(recv(r))
  n = gcd(n,val2*val2-val1)

print('n : ', n)

sendline(r,2)
recv(r) # Ciphertext >
pad = n//245+232
sendline(r,pow(pad,e,n)*flag_encrypted)
plain = extract(recv(r)) * inv(pad,n) % n
print() 
assert(pow(plain,e,n)==flag_encrypted)
print(hex(plain))
print(i2s(plain))

'CTF > Crypto' 카테고리의 다른 글

[0CTF/TCTF 2019] zer0lfsr  (0) 2019.03.28
[0CTF/TCTF 2019] zero0des  (0) 2019.03.28
[0CTF/TCTF 2019] babysponge  (0) 2019.03.28
[2018 X-MAS CTF] Santa's list 2.0  (0) 2018.12.19
[2018 X-MAS CTF] Special Christmas Wishlist  (0) 2018.12.19
[2018 X-MAS CTF] Hanukkah  (0) 2018.12.18
  Comments