ElGamal.py
from Crypto.Util.number import *
from PRNG import PRNG
class ElGamal(object):
    """ElGamal encryption over the integers modulo a safe prime, generator 2.

    The private exponent and every per-message ephemeral exponent are drawn
    from a PRNG(256) instance, so the scheme is only as unpredictable as
    that generator.
    """

    def __init__(self, length, prime=None):
        """Set up the group and a fresh key pair.

        :param length: bit length of the exponents drawn from the PRNG; when
            no prime is given it also sizes the generated safe prime
            (q = 2p + 1 with p >= 2**length).
        :param prime: modulus to use instead of generating one. It is used
            as given, without any primality or safe-prime check.
        """
        self.prng = PRNG(256)
        self.length = length
        self.g = 2
        self.q = prime
        while self.q is None:
            p = self.next_prime(2 ** length + self.prng.get_bits(length))
            candidate = 2 * p + 1
            # q % 8 in {3, 5} makes 2 a quadratic non-residue mod q, which for
            # a safe prime means g = 2 generates the full group of order q - 1.
            if candidate % 8 in (3, 5) and isPrime(candidate):
                self.q = candidate
        # Private key x and public key h = g^x mod q.
        self.x = self.prng.get_bits(self.length)
        self.h = pow(self.g, self.x, self.q)

    def next_prime(self, x):
        """Return the smallest odd number >= x that isPrime accepts.

        Even inputs are bumped to the next odd number first, so 2 is never
        returned.
        """
        if x & 1 == 0:
            x += 1
        while not isPrime(x):
            x += 2
        return x

    def encrypt(self, message):
        """Encrypt an integer message.

        Returns (y, c1, c2): the ephemeral exponent y, c1 = g^y mod q and
        c2 = message * h^y mod q. Since h is public, anyone who learns y can
        remove the mask h^y from c2, so y must be treated as secret.
        """
        y = self.prng.get_bits(self.length)
        s = pow(self.h, y, self.q)
        c1 = pow(self.g, y, self.q)
        c2 = message * s % self.q
        return y, c1, c2
PRNG.py
import os
class PRNG(object):
    """Bit generator built on a linear feedback shift register.

    Each output bit is bit 0 of ``state``. After it is read, ``state`` is
    shifted right by one, and the parity of ``state & key`` is inserted at
    bit ``length - 1``. Both ``state`` and ``key`` are seeded from
    ``os.urandom``. The feedback is linear over GF(2), so this generator is
    NOT cryptographically secure.
    """

    def __init__(self, length):
        """Create a generator whose state and key are ``length`` bits wide."""
        self.length = length
        self.state = self.getseed()
        self.key = self.getseed()

    def parity(self, x):
        """Return the XOR (0 or 1) of the low ``self.length`` bits of ``x``."""
        # Counting set bits is exact for any width. A halving fold with
        # ceil((aux + 1) / 2) shifts re-mixes stale high bits whenever an
        # intermediate width is odd, giving wrong results for lengths that
        # are not powers of two.
        return bin(x & ((1 << self.length) - 1)).count('1') & 1

    def getseed(self):
        """Return ``self.length`` bits drawn from ``os.urandom``."""
        import binascii
        # Round up to whole bytes, then mask, so every length gets exactly
        # ``length`` random bits (and length < 8 no longer reads 0 bytes).
        nbytes = (self.length + 7) // 8
        if nbytes == 0:
            return 0
        seed = int(binascii.hexlify(os.urandom(nbytes)), 16)
        return seed & ((1 << self.length) - 1)

    def next_state(self):
        """Shift the register right once, feeding back parity(state & key)."""
        self.state = (self.state >> 1) | (self.parity(self.state & self.key) << (self.length - 1))

    def get_bit(self):
        """Return the current low state bit, then advance the register."""
        output = self.state & 1
        self.next_state()
        return output

    def get_bits(self, bits):
        """Return ``bits`` output bits packed MSB-first (first bit is highest)."""
        output = 0
        for i in range(bits):
            output = (output << 1) + self.get_bit()
        return output
server.py
from Crypto.Util.number import *
import SocketServer
from ElGamal import ElGamal
from secret import flag
from text import *
import os
from hashlib import sha256
# Fixed ElGamal modulus. handle() passes it to ElGamal(1024, prime), so no
# prime is generated per connection.
prime = 685221181007655969055643176795598500987539499099103356485206001142535588544010199273848305625469714980556215814571531400059127410834556408213573648640620782264547499664798168588595354413552517947288874272302443155003080155578117949303088899418619338227631822471359675459950492254857920916052407700056096582307
# TCP port the challenge server listens on.
PORT = 2000
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
    """Serves one client of the ElGamal decryption challenge.

    After a proof of work, the client must decrypt 10 ciphertexts in a row.
    A wrong answer resets the streak and reveals both the plaintext and the
    ephemeral exponent y used for it. Written for Python 2 (SocketServer,
    str.encode('hex') / str.decode('hex')).
    """

    def PoW(self):
        """Run the proof-of-work exchange; return True if the client passed.

        The client must send a hex string X whose SHA-256 hex digest ends with
        the same 6 characters as the digest of 10 random server bytes.
        """
        s = os.urandom(10)
        h = sha256(s).hexdigest()
        self.request.sendall("Provide a hex string X such that sha256(X)[-6:] = {}\n".format(h[-6:]))
        inp = self.request.recv(2048).strip().lower()
        # Hex decoding raises on odd-length input. That would kill the handler
        # thread instead of rejecting the client, so require an even number
        # of hex digits up front.
        is_hex = len(inp) % 2 == 0 and all(c in '0123456789abcdef' for c in inp)
        if is_hex and sha256(inp.decode('hex')).hexdigest()[-6:] == h[-6:]:
            self.request.sendall('Good, you can continue!\n')
            return True
        else:
            self.request.sendall('Oops, your string didn\'t respect the criterion.\n')
            return False

    def handle(self):
        """Run one session: PoW, up to 10 decryption rounds, then the flag."""
        self.request.settimeout(120)
        if not self.PoW():
            return
        self.request.sendall(intro)
        enc = ElGamal(1024, prime)
        # enc.x is this session's private key; it is deliberately not logged.
        self.request.sendall(details.format(enc.q, enc.g))
        correct = 0
        while correct < 10:
            m = int(os.urandom(32).encode('hex'), 16) % enc.q
            y, c1, c2 = enc.encrypt(m)
            self.request.sendall(challenge.format(c1, c2))
            self.request.sendall(get_input)
            # Only the receive/parse step is guarded: non-integer input or a
            # receive timeout ends the session.
            try:
                x = int(self.request.recv(1024))
            except Exception:
                self.request.sendall(bad_input)
                break
            if x == m:
                self.request.sendall(correct_answer)
                correct += 1
            else:
                self.request.sendall(wrong_answer.format(m, y))
                correct = 0
        if correct == 10:
            self.request.sendall(flag + '\n')
class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    """TCP server that hands each incoming connection to its own thread."""
if __name__ == '__main__':
    # The TCPServer constructor binds the socket immediately, and
    # server_bind() reads allow_reuse_address at that moment. The flag must
    # therefore be set before construction; assigning it on the instance
    # afterwards has no effect.
    ThreadedTCPServer.allow_reuse_address = True
    server = ThreadedTCPServer(('0.0.0.0', PORT), ThreadedTCPRequestHandler)
    server.serve_forever()
text.py
# Messages the challenge server sends to clients (imported by server.py).
intro = 'Welcome!\nIn this challenge you will have to decrypt 10 messages in a row to get the flag!\n'
details = 'Here are some things to help you in your quest!\nq : {}\ng : {}\n'
challenge = 'Decrypt this:\nc1 : {}\nc2 : {}\n'
# Sent after a wrong answer; formatted with the expected plaintext m and the
# ephemeral exponent y.
wrong_answer = "Nope, that's not it!\nhere's what you should've sent:\nm : {}\nalso here is your y : {}\n"
correct_answer = 'Hey, you did it, great job!\n'
get_input = 'Send the decrypted message as an integer: '
bad_input = 'Input does not match desired criteria. Aborting!\n'
STEP 1 - Recover h in ElGamal class
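After we send a wrong $m$, the server reveals both the correct $m$ and the ephemeral exponent $y$.
From these we can recover $h^y = c_2 \cdot m^{-1} \bmod q$.
Once we have found a pair $y_1, y_2$ with $\gcd(y_1, y_2) = 1$, use the extended Euclidean algorithm to find $v_1, v_2$ satisfying $v_1 y_1 + v_2 y_2 = 1$, and recover $h = (h^{y_1})^{v_1} \cdot (h^{y_2})^{v_2} \bmod q$. A negative exponent means raising the modular inverse to the corresponding positive power.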
STEP 2 - Recover state, key in PRNG class
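This PRNG is a linear feedback shift register. Each new bit it shifts into the state is the XOR (parity) of the state bits selected by key. For example, if key = (1<<128) | (1<<64) | (1<<32), the new bit is bit 128 $\oplus$ bit 64 $\oplus$ bit 32 of the state, counting from the least significant bit starting at 0. Equivalently, these are the 128th, 192nd and 224th bits counting from the most significant end starting at 1.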
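Since the key is fixed and the feedback is linear, a single 1024-bit $y$ is enough. Every output bit after the first 256 equals the XOR of the key-selected bits among the 256 bits output just before it. Taking 256 such relations gives a linear system of 256 equations in the 256 unknown key bits over GF(2); solving it yields the key.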
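Recovering the state is even easier:
1. Take the least significant 256 bits of the last $Y$ and bit-reverse them. Mind the bit order: the PRNG outputs the state's low bit first, but get_bits places earlier bits in more significant positions. The result is the PRNG state just before those 256 bits were produced.
2. Build a local PRNG with this state and the recovered key, and discard 256 output bits. Its state is then synchronized with the server's.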
Sage code for extracting the key
# Recover the PRNG feedback key from one observed 1024-bit ephemeral exponent Y.
#
# get_bits packs bits most-significant first, so bit n of Y is the
# (1023 - n)-th output bit. Writing o_t for the t-th output bit, every bit
# from the 257th on satisfies o_{t+256} = XOR_m key_m * o_{t+m}. Indexing by
# Y's bit positions turns 256 of these relations into a 256x256 GF(2) system.
Y = 89983791425551027210400346476092936344826533636016877225918481462124667356462216675054392621920865356045823186208305570456929667777504930092311160509140608128950177361952489418169806815449239730008355253892404586754110522969077661227581488968435597924500263911858546622164629597080557119791160012652564123666
sz = 256
# bb[n] is bit n of Y (bit 0 = least significant = last bit output). Shifting
# keeps any leading zero bits of the 1024-bit value, unlike a string reversal.
bb = [(Y >> n) & 1 for n in range(4 * sz)]
print("{} done".format(str(Y)[:6]))
mat = [[bb[1 + i + j] for j in range(sz)] for i in range(sz)]
A = Matrix(IntegerModRing(2), mat)
b = vector(IntegerModRing(2), [bb[i] for i in range(sz)])
# solve_right works without inverting A, so a singular A does not abort here.
ans = A.solve_right(b)
print(ans)
# ans[0] is the key's most significant bit.
key = 0
for bit in ans:
    key = (key << 1) | Integer(bit)
print("key {}".format(key))
solver.py
from hashlib import sha256
from Crypto.Util.number import *
import random, socket
############ my socket ###############
def interactive(socket):
    """Relay the connection to the terminal until the peer closes it.

    Each received chunk is printed, then one line typed by the user (plus a
    trailing newline) is sent back. Returns None once recv yields nothing.
    """
    print("[+] interactive mode")
    while True:
        incoming = socket.recv(2**16).decode()
        if incoming:
            print(incoming)
            reply = input('> ') + '\n'
            socket.send(reply.encode())
            continue
        print("[!] socket closed")
        return None
def remote(ip, port):
    """Open a TCP connection to (ip, port) and return the connected socket."""
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    target = (ip, port)
    print("[+] Connecting to {}:{}".format(*target))
    conn.connect(target)
    print("[+] Done!")
    return conn
def sendline(socket, msg):
    """Send msg (str or bytes) so that it ends with exactly one newline.

    A newline is appended only when msg does not already end with one. Empty
    messages become a bare newline.
    """
    if isinstance(msg, str):
        msg = msg.encode()
    # endswith() compares bytes to bytes; indexing msg[-1] yields an int on
    # Python 3, which never equals b'\n'.
    if not msg.endswith(b'\n'):
        msg += b'\n'
    # sendall() keeps writing until every byte is sent; send() may stop short.
    socket.sendall(msg)
def recv(socket):
    """Read one chunk (at most 64 KiB) from socket and decode it as UTF-8."""
    chunk = socket.recv(1 << 16)
    return chunk.decode()
class PRNG(object):
    """Local replica of the challenge's LFSR bit generator.

    Unlike the challenge's version, key and state are passed in explicitly
    rather than seeded from os.urandom. Each output bit is bit 0 of
    ``state``. After it is read, ``state`` shifts right by one, and the
    parity of ``state & key`` is inserted at bit ``length - 1``.
    """

    def __init__(self, length, key, state):
        """Create a replica with a ``length``-bit register, key and state."""
        self.length = length
        self.state = state
        self.key = key

    def parity(self, x):
        """Return the XOR (0 or 1) of the low ``self.length`` bits of ``x``."""
        # Counting set bits is exact for any width. A ceil-halving fold
        # re-mixes stale high bits whenever an intermediate width is odd.
        return bin(x & ((1 << self.length) - 1)).count('1') & 1

    def next_state(self):
        """Shift the register right once, feeding back parity(state & key)."""
        self.state = (self.state >> 1) | (self.parity(self.state & self.key) << (self.length - 1))

    def get_bit(self):
        """Return the current low state bit, then advance the register."""
        output = self.state & 1
        self.next_state()
        return output

    def get_bits(self, bits):
        """Return ``bits`` output bits packed MSB-first (first bit is highest)."""
        output = 0
        for i in range(bits):
            output = (output << 1) + self.get_bit()
        return output
def rev_bit(x, sz):
    """Reverse the binary digits of x after left-padding it to sz digits.

    Padding never truncates, so a value wider than sz bits is reversed over
    its full width.
    """
    digits = format(x, 'b').rjust(sz, '0')
    return int(digits[::-1], 2)
def egcd(a, b):
    """Extended Euclid: return (g, x, y) with a*x + b*y == g.

    For non-negative inputs g == gcd(a, b); egcd(0, b) returns (b, 0, 1).
    The loop is iterative, so it cannot hit Python's recursion limit. A
    recursive version needs one frame per division step, and consecutive
    Fibonacci numbers of about 1000 bits take well over 1000 steps. The
    coefficients are the same ones the recursive formulation produces.
    """
    # Invariant: a == xa*a0 + ya*b0 and b == xb*a0 + yb*b0, where a0, b0 are
    # the original arguments.
    xa, ya, xb, yb = 1, 0, 0, 1
    while a != 0:
        q, rem = divmod(b, a)
        a, b = rem, a
        xa, ya, xb, yb = xb - q * xa, yb - q * ya, xa, ya
    return (b, xb, yb)
# Exploit driver.
#
# Step 1: answer every challenge wrongly so the server reveals the plaintext
#         m and the ephemeral exponent y. Each round gives s = h^y = c2 * m^-1
#         (mod q); two rounds with coprime y values give h via Bezout.
# Step 2: the PRNG key is solved offline from the last y (e.g. with the Sage
#         script) and typed in. The PRNG is rebuilt from that y's low 256 bits
#         and stepped in lockstep with the server, so every later y, and hence
#         every plaintext, can be predicted.

# Prompt the server prints after each challenge; used as a message delimiter.
PROMPT = 'Send the decrypted message as an integer: '


class _Reader(object):
    """Socket reader that returns text up to a marker and buffers the rest.

    TCP may split or merge the server's messages arbitrarily, so assuming
    one recv() per message is unreliable. Text received past the marker is
    kept for the next call.
    """

    def __init__(self, sock):
        self.sock = sock
        self.buf = ''

    def until(self, marker):
        """Return buffered/received text up to and including `marker`.

        Raises EOFError if the connection closes before `marker` arrives.
        """
        while marker not in self.buf:
            chunk = self.sock.recv(2**16).decode()
            if not chunk:
                raise EOFError('connection closed while waiting for %r' % marker)
            self.buf += chunk
        end = self.buf.index(marker) + len(marker)
        data, self.buf = self.buf[:end], self.buf[end:]
        return data


def _field(text, label):
    """Return the integer printed after the last occurrence of `label`.

    Raises ValueError if `label` is missing instead of parsing the wrong
    part of the message.
    """
    start = text.rindex(label) + len(label)
    return int(text[start:].split('\n', 1)[0].strip())


def _pow_signed(base, exp, mod):
    """Return base**exp % mod, using the modular inverse when exp < 0."""
    if exp >= 0:
        return pow(base, exp, mod)
    return pow(inverse(base, mod), -exp, mod)


q = 685221181007655969055643176795598500987539499099103356485206001142535588544010199273848305625469714980556215814571531400059127410834556408213573648640620782264547499664798168588595354413552517947288874272302443155003080155578117949303088899418619338227631822471359675459950492254857920916052407700056096582307
r = remote('challs.xmas.htsp.ro', 10000)
reader = _Reader(r)

# Proof of work: the target suffix ends the first line. The matching hex
# string is found separately and typed in.
pow_target = reader.until('\n').strip()[-6:]
print(pow_target)
sendline(r, input())
reader.until('\n')  # Good, you can continue

c1arr = []
c2arr = []
marr = []
yarr = []
sarr = []
i1, i2 = -1, -1
# Step 1. Find h.
for step in range(100000):
    print("Collection phase ", step)
    x = reader.until(PROMPT)  # first round also carries intro + q, g
    if step == 0:
        print(x)
    c1 = _field(x, 'c1 : ')
    c2 = _field(x, 'c2 : ')
    c1arr.append(c1)
    c2arr.append(c2)
    sendline(r, '1')
    # The wrong-answer message ends with the line "... your y : <y>".
    x = reader.until('y : ') + reader.until('\n')
    m = _field(x, 'm : ')
    y = _field(x, 'y : ')
    s = c2 * inverse(m, q) % q  # s = h^y
    sarr.append(s)
    marr.append(m)
    yarr.append(y)
    # Compare with every earlier y, including the immediately preceding one.
    for i in range(step):
        if GCD(yarr[i], y) == 1:
            i1, i2 = i, step
            break
    if i1 != -1:
        break
else:
    raise RuntimeError('no coprime pair of y values found')

g, v1, v2 = egcd(yarr[i1], yarr[i2])
assert g == 1
assert v1 * yarr[i1] + v2 * yarr[i2] == 1
# v1*y1 + v2*y2 == 1, so (h^y1)^v1 * (h^y2)^v2 == h.
h = _pow_signed(sarr[i1], v1, q) * _pow_signed(sarr[i2], v2, q) % q
# Sanity-check h against every collected round, including the last one.
for yi, si, mi, ci in zip(yarr, sarr, marr, c2arr):
    assert pow(h, yi, q) == si
    assert ci == mi * si % q

Y = yarr[-1]
print("Y : ", Y)
# The PRNG key is solved offline from Y and pasted in as a decimal integer.
key = int(input())

# Step 2
# Y's low 256 bits are the server PRNG's last 256 output bits. Output comes
# from bit 0 of the state, and get_bits puts earlier bits higher, so
# reversing them gives the state just before those bits were produced.
# Stepping 256 bits forward brings the replica level with the server.
prng = PRNG(256, key, rev_bit(Y & (2**256 - 1), 256))
prng.get_bits(256)
for step in range(10):
    print("Challenge phase ", step)
    x = reader.until(PROMPT)
    c1 = _field(x, 'c1 : ')
    c2 = _field(x, 'c2 : ')
    nexty = prng.get_bits(1024)
    s = pow(h, nexty, q)
    m = c2 * inverse(s, q) % q
    print("guessed m : ", m)
    sendline(r, str(m))
    print(reader.until('\n'))  # verdict line
# After ten correct answers in a row the server sends the flag line.
print(reader.until('\n'))
'CTF > Crypto' 카테고리의 다른 글
[Codegate 2022] PrimeGenerator (0) | 2022.02.28 |
---|---|
[2021 PBCTF] Steroid Stream (2) | 2021.10.11 |
[2021 PBCTF] Alkaloid Stream (2) | 2021.10.11 |
[2019 X-MAS CTF] Hashed Presents (0) | 2019.12.14 |
[2019 X-MAS CTF] DeFUNct Ransomware (0) | 2019.12.14 |
[HITCON CTF 2019 Quals] Very Simple Haskell (0) | 2019.10.14 |