2018. 12. 19. 08:22, CTF/Crypto
#!/usr/bin/python3
from Crypto.PublicKey import RSA
from Crypto.Util.number import *
FLAG = open('flag.txt', 'r').read().strip()
def menu():
print()
print('[1] Encrypt')
print('[2] Decrypt')
print('[3] Exit')
return input()
def encrypt(m):
return pow(m, rsa.e, rsa.n)
def decrypt(c):
return pow(c, rsa.d, rsa.n)
rsa = RSA.generate(1024)
flag_encrypted = pow(bytes_to_long(FLAG.encode()), rsa.e, rsa.n)
used = [bytes_to_long(FLAG.encode())]
print('Ho, ho, ho and welcome back!')
print('Your list for this year:n')
print('Sarah - Nice')
print('Bob - Nice')
print('Eve - Naughty')
print('Galf - ' + hex(flag_encrypted)[2:])
print('Alice - Nice')
print('Johnny - Naughty')
while True:
choice = menu()
if choice == '1':
m = bytes_to_long(input('nPlaintext > ').strip().encode())
used.append(m)
print('nEncrypted: ' + str(encrypt(m)))
elif choice == '2':
c = int(input('nCiphertext > ').strip())
if c == flag_encrypted:
print('Ho, ho, no...')
else:
m = decrypt(c)
for no in used:
if m % no == 0:
print('Ho, ho, no...')
break
else:
print('nDecrypted: ' + str(m))
elif choice == '3':
print('Till next time.nMerry Christmas!')
break
문제의 코드는 위와 같습니다. 참고로 e는 65537입니다. 편의상 Flag을 M이라고 했을 때 우선 $M^e$가 주어지고, Encryption/Decryption Oracle도 주어집니다. 단 Decryption Oracle에는 Decrypt 결과 $M$이나 Encryption Oracle에서 넣었던 수의 배수가 나오게 될 경우 결과를 알려주지 않습니다.
우선 $N$을 알아내야 뭔가 진행이 될 것 같습니다. $N$은 아래의 식으로부터 이끌어낼 수 있습니다.
- $(P^2)^d \equiv (P)^d(P)^d p \,\, mod \, n$
당연한 식을 왜 써놓은건가 싶겠지만, Decryption Oracle에 $P^2$와 $P$를 넣고나면.. 뭔가 느낌이 오나요? $P^2$를 넣은 결과를 $a$라고 하고, $P$를 넣은 결과를 b라고 할 때 $b^2-a$는 N의 배수이기 때문에 이러한 값들을 여러 개 구해 gcd를 취하면 N을 복원할 수 있습니다.
이제 N을 구하고나면 문제는 간단하게 해결할 수 있습니다. Textbook RSA vulnerability라고 검색하면 쉽게 찾을 수 있는 공격 방법일텐데, 임의의 꽤 큰 정수 a를 잡아 $a^e$를 계산하고, 주어진 $M^e$와 곱해 Encryption Oracle에 $(a^eM^e)$를 넣어 $(a^eM^e)^d=aM$를 계산합니다. 그리고 그 값에 $a^{-1}$을 곱하면 M을 복원해낼 수 있습니다.
$aM$은 M의 배수이나, a가 꽤 클 경우 $aM \equiv N$은 M의 배수가 아니기 때문에 해당 방식이 통합니다.
from Crypto.Util.number import *
import binascii,socket,sys
import hashlib
sys.setrecursionlimit(1000000)
def gcd(a, b):
if a == 0: return b
return gcd(b%a, a)
def egcd(a, b):
if a == 0:
return (b, 0, 1)
g, y, x = egcd(b%a,a)
return (g, x - (b//a) * y, y)
def inv(a, m):
g, x, y = egcd(a, m)
if g != 1:
raise Exception('No modular inverse')
return x%m
# x**2 = a (mod m), m is prime
def quad_congruence_equation(a, m, sign):
assert((m+1)%4 == 0)
if sign == 0: return pow(a, (m+1)//4, m)
return m-pow(a, (m+1)//4, m)
# m must satisfies pairwise relatively prime
def crt(a, m):
n = len(m)
ret = a[0]
mod = m[0]
for i in range(1,n):
m1 = mod
mod *= m[i]
m2inv = inv(m[i],m1)
m1inv = inv(m1,m[i])
ret = (ret*m[i]*m2inv+a[i]*m1*m1inv)%mod
return ret
# 0x6161 -> 'AA'
def i2s(x):
return long_to_bytes(x).decode("utf-8")
############ my socket ###############
def interactive(socket):
print("[+] interactive mode")
while True:
rr = socket.recv(2**16).decode()
if not rr:
print("[!] socket closed")
return None
print(rr)
socket.send((input('> ')+'n').encode())
def remote(ip, port):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print("[+] Connecting to {}:{}".format(ip,port))
sock.connect((ip,port))
print("[+] Done!")
return sock
def sendline(socket, msg):
if type(msg) not in [str,bytes]: msg = str(msg)
if type(msg) == str: msg = msg.encode()
if msg[-1] != b'n': msg += b'n'
socket.send(msg)
def recv(socket):
return socket.recv(2**16).decode()
def md5_pow(prefix):
for i in range(100000000):
if hashlib.md5(str(i).encode()).hexdigest().startswith(prefix):
return str(i)
###################################
def extract(msg):
msg = msg[msg.find(': ')+2:]
return int(msg[:msg.find('n')])
r = remote('199.247.6.180',16001)
intro1 = recv(r)
#print("intro1", intro1)
flag_encrypted = intro1[intro1.find('Galf - ')+7:]
flag_encrypted = int(flag_encrypted[:flag_encrypted.find('n')],16)
print("flag_encrypted : ", hex(flag_encrypted))
n = 0
e = 65537
for p in [19,23,29,31,37,43,47]:
sendline(r,2)
recv(r) # Ciphertext >
sendline(r,p**2)
val1 = extract(recv(r))
sendline(r,2)
recv(r)
sendline(r,p)
val2 = extract(recv(r))
n = gcd(n,val2*val2-val1)
print('n : ', n)
sendline(r,2)
recv(r) # Ciphertext >
pad = n//245+232
sendline(r,pow(pad,e,n)*flag_encrypted)
plain = extract(recv(r)) * inv(pad,n) % n
print()
assert(pow(plain,e,n)==flag_encrypted)
print(hex(plain))
print(i2s(plain))
'CTF > Crypto' 카테고리의 다른 글
[0CTF/TCTF 2019] zer0lfsr (0) | 2019.03.28 |
---|---|
[0CTF/TCTF 2019] zero0des (0) | 2019.03.28 |
[0CTF/TCTF 2019] babysponge (0) | 2019.03.28 |
[2018 X-MAS CTF] Santa's list 2.0 (0) | 2018.12.19 |
[2018 X-MAS CTF] Special Christmas Wishlist (0) | 2018.12.19 |
[2018 X-MAS CTF] Hanukkah (0) | 2018.12.18 |
Comments