beVX Conference Challenge – OffensiveCon

During OffensiveCon, we launched a reverse-engineering and encryption challenge that gave attendees the chance to win great prizes.
The challenge was divided into two parts: a file (available for download here: https://www.beyondsecurity.com/bevxcon/bevx-challenge-1) that you had to download and reverse engineer, and a server that you had to access to reach a running instance of that file.
The challenge could not be solved without access to the server, because the encryption key you were supposed to extract was present only in the running instance on the server.
We received some great solutions, several of which are posted below. Some arrived after the deadline and some were ineligible because they were incomplete, but in the end we had three winners.
The first-place winner received an all-expenses-paid trip (flight and hotel) and entry to our security conference, beVX, in September; the second-place winner received a flight and entry to the conference; and the third-place winner received free entry to the event.

Challenge Source Code
If you don’t want to see a solution or hints on how to solve it, stop reading here – you have been warned 🙂

#include <thread>
#include <chrono>
#include <random>
// Export decoration for the symbols the front end imports. _WIN32 is the macro
// predefined by Windows compilers (MSVC, MinGW); testing _WIN alone only matched
// when a build defined it explicitly, so _WIN is kept as an alternative for
// backward compatibility.
#if defined(_WIN) || defined(_WIN32)
#define EXPORT __declspec(dllexport)
#else
#define EXPORT
#endif
// Width of one key row, in bit-characters (one char per bit, value 0 or 1).
static const size_t COLS = 0x20;
static const size_t PRIVATE_KEY_ROWS = 3;
// The secret key, one bit per char, most significant bit first within each row.
// Exported so the front end can copy it into its table.
char EXPORT private_key[PRIVATE_KEY_ROWS][COLS] = {
  { 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0 },
  { 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0 },
  { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }
};
static const size_t MIN_KEY_LENGTH = 3 * COLS;
// Number of key bit-chars the front end copies. A randomized length
// (MIN_KEY_LENGTH + 1 + std::rand() % COLS) was sketched here in a comment;
// the length is currently fixed at MIN_KEY_LENGTH.
size_t EXPORT private_key_length = MIN_KEY_LENGTH;
// Row count the front end accepts outside of an encryption.
size_t EXPORT number_of_rows = 0x10;
// Artificial delay, in milliseconds, applied by every encrypt()/decrypt() call.
static const size_t WAIT_FOR = 800;
// Fixed XOR mask used by encrypt()/decrypt().
static const size_t XOR_KEY = 0xDF098B52;
// "Encrypts" a value by XOR-ing it with XOR_KEY after an artificial delay of
// WAIT_FOR milliseconds. XOR with a fixed mask is its own inverse, so this is
// obfuscation rather than encryption; decrypt() performs the same computation.
EXPORT size_t encrypt(size_t num)
{
  const std::chrono::milliseconds delay(WAIT_FOR);
  std::this_thread::sleep_for(delay);
  const size_t masked = num ^ XOR_KEY;
  return masked;
}
// Inverse of encrypt(): waits WAIT_FOR milliseconds, then removes the XOR_KEY
// mask. Because XOR is an involution this is the same computation as encrypt().
EXPORT size_t decrypt(size_t num)
{
  std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_FOR));
  const size_t plain = XOR_KEY ^ num;
  return plain;
}
#include <iostream>
#include <thread>
#include <chrono>
// Import decoration for the symbols defined in the key library. _WIN32 is the
// macro predefined by Windows compilers (MSVC, MinGW); testing _WIN alone only
// matched when a build defined it explicitly, so _WIN is kept as an alternative
// for backward compatibility.
#if defined(_WIN) || defined(_WIN32)
#define IMPORT extern __declspec(dllimport)
#else
#define IMPORT extern
#endif
// Functions provided by the key library.
IMPORT size_t encrypt(size_t num);
IMPORT size_t decrypt(size_t num);
// NOTE(review): the library defines private_key as a non-const
// `char private_key[3][0x20]` and the two sizes as non-const size_t. These
// declarations differ in type/cv-qualification, which is an ODR violation in
// standard C++. It links on Itanium-ABI toolchains because variable symbols
// are not type-mangled. It is kept because the code below treats private_key
// as a flat byte array.
IMPORT const char private_key[];
IMPORT const size_t private_key_length;
IMPORT const size_t number_of_rows;
// Interactive calculator used by the challenge. Each of the ROWS rows of
// `table` holds one number as COLS bit-characters (0/1, most significant bit
// first). StoreNumber() passes values through encrypt() before writing them,
// and RetrieveNumber() passes them through decrypt() after reading them.
// MainLoop() also copies the library's private key bits into the rows starting
// at number_of_rows + 2, where "Private Key Encryption" can use them as an
// exponent.
class Algo
{
public:
  static const size_t ROWS = 0x20;
  static const size_t COLS = 0x20;
  // Bit-per-char storage for every row.
  char table[ROWS][COLS] = { {0} };
  // First table row holding the private key (set in MainLoop()).
  char * password;
  // NOTE(review): `s1.init` and `number_of_rows` share storage. With GCC's
  // LSB-first bit-field layout, `init` is bit 3 (value 8) of the 5-bit
  // `number_of_rows`. Setting init therefore raises the limit checked by
  // ValidateRowIndex() (e.g. from 16 to 24 when number_of_rows is 0x10)
  // while an encryption is in progress. This aliasing is what exposes the
  // key rows, and it is the flaw the published solutions exploit. Reading a
  // union member other than the last one written is also undefined in
  // standard C++.
  union
  {
    struct
    {
      size_t unused : 3;
      size_t init : 1;
    } s1;
    size_t number_of_rows : 5;
  } u1 = {};
  // State of an in-progress "Private Key Encryption":
  //   row      - message row; also receives the result in End_calc()
  //   i        - index of the next exponent bit to consume
  //   exponent - copy of the key row's bits
  //   x        - running square-and-multiply accumulator
  struct
  {
    size_t row : 7;
    int i : 7;
    char exponent[COLS];
    size_t x;
  } expo_data = {};
  // Writes the low COLS bits of `num` into `row`, most significant bit first.
  // When `enc` is true the value is passed through encrypt() first.
  static void StoreNumber(char * row, size_t num, bool enc = true)
  {
    if (enc)
      num = encrypt(num);
    for (int i = COLS - 1; i >= 0; i--)
    {
      row[i] = num % 2;
      num = num / 2;
    }
  }
  // Reads COLS bits from `row` (most significant first) into `num`, then
  // passes the result through decrypt(). It always decrypts, so a row whose
  // bits are all zero reads back as decrypt(0).
  static void RetrieveNumber(char * row, size_t& num)
  {
    num = 0;
    for (int i = 0; i < COLS; i++)
    {
      num = num * 2 + row[i];
    }
    num = decrypt(num);
  }
  // One square-and-multiply step: x = x * x, then, if the current exponent
  // bit is 1, x *= (number in the message row). Over a full run this computes
  // message^key in wrapping size_t arithmetic. Only 1 bits call
  // RetrieveNumber() (and hence decrypt()), so those steps are slower.
  void Calculate_next()
  {
    expo_data.x = expo_data.x * expo_data.x;
    if (expo_data.exponent[expo_data.i++] == 1)
    {
      size_t num = 0;
      RetrieveNumber(table[expo_data.row], num);
      expo_data.x = expo_data.x * num;
    }
  }
  // Finishes an encryption: stores std::hash<size_t>(x) into the message row
  // and clears `init`, which restores the normal row limit. The writeups'
  // results suggest std::hash<size_t> is the identity on the server's standard
  // library; that is implementation-defined, so verify before relying on it.
  void End_calc()
  {
    StoreNumber(table[expo_data.row], std::hash<size_t>()(expo_data.x));
    u1.s1.init = 0;
    // NOTE(review): redundant; init was already cleared on the previous line.
    u1.s1.init = false;
  }
  // Starts an encryption of row `row1` using row `row2` as the exponent. It
  // copies the key bits, skips leading zero bits, sets x = 1 and remembers the
  // message row.
  void InitializeExp(size_t row1, size_t row2)
  {
    std::copy(table[row2], table[row2] + COLS, expo_data.exponent);
    expo_data.i = 0;
    // NOTE(review): exponent[i] is read before `i < COLS` is checked, so an
    // all-zero key row reads one element past `exponent`. The bound check
    // should come first.
    while (expo_data.exponent[expo_data.i] == 0 && expo_data.i < COLS)
    {
      ++expo_data.i;
    }
    expo_data.x = 1;
    expo_data.row = row1;
  }
  // row3 = row1 * row2 (wrapping size_t arithmetic).
  void Multiply(size_t row1, size_t row2, size_t row3)
  {
    size_t num1 = 0, num2 = 0, num3 = 0;
    RetrieveNumber(table[row1], num1);
    RetrieveNumber(table[row2], num2);
    num3 = num1 * num2;
    StoreNumber(table[row3], num3);
  }
  // row3 = row1 + row2 (wrapping size_t arithmetic).
  void Add(size_t row1, size_t row2, size_t row3)
  {
    size_t num1 = 0, num2 = 0, num3 = 0;
    RetrieveNumber(table[row1], num1);
    RetrieveNumber(table[row2], num2);
    num3 = num1 + num2;
    StoreNumber(table[row3], num3);
  }
  // row3 = row1 - row2 (wrapping size_t arithmetic).
  void Sub(size_t row1, size_t row2, size_t row3)
  {
    size_t num1 = 0, num2 = 0, num3 = 0;
    RetrieveNumber(table[row1], num1);
    RetrieveNumber(table[row2], num2);
    num3 = num1 - num2;
    StoreNumber(table[row3], num3);
  }
  // row3 = row1 / row2. A zero divisor silently leaves row3 unchanged.
  void Divide(size_t row1, size_t row2, size_t row3)
  {
    size_t num1 = 0, num2 = 0, num3 = 0;
    RetrieveNumber(table[row1], num1);
    RetrieveNumber(table[row2], num2);
    if (num2 == 0)
    {
      return;
    }
    num3 = num1 / num2;
    StoreNumber(table[row3], num3);
  }
  // True if `row` is below the current row limit, u1.number_of_rows. That
  // limit grows while u1.s1.init is set (see the union note above).
  bool ValidateRowIndex(size_t row)
  {
    return (row < u1.number_of_rows);
  }
  // Drives one step of an in-progress encryption. Once all COLS exponent bits
  // are consumed the result is stored without prompting. Otherwise the user is
  // asked whether to continue ('y'/'Y' runs the next step, 'n'/'N' stops and
  // stores the result); any other answer does nothing, so the prompt repeats.
  // NOTE(review): std::cin is not checked here, and MainLoop() skips its own
  // stream check while init is set, so EOF during an encryption loops forever.
  void Encryption()
  {
    if (expo_data.i > COLS - 1)
    {
      End_calc();
      return;
    }
    char op = 0;
    std::cout << "Continue Encryption? (y/n)" << std::endl;
    std::cin >> op;
    switch (op)
    {
    case 'y':
    case 'Y':
      Calculate_next();
      break;
    case 'n':
    case 'N':
      End_calc();
      break;
    }
    return;
  }
  // Decrypts every row of t1 and stores it in t2 without re-encrypting.
  // Costs one decrypt() call per row.
  void CopyTable(char t1[ROWS][COLS], char t2[ROWS][COLS])
  {
    for (unsigned int i = 0; i < ROWS; i++)
    {
      size_t num = 0;
      RetrieveNumber(t1[i], num);
      StoreNumber(t2[i], num, false);
    }
  }
  // Debug-only dump of the first u1.number_of_rows rows, decrypted. The body
  // is compiled only when DEBUG is defined; otherwise this prints nothing.
  // NOTE(review): Init() tests _DEBUG rather than DEBUG. The inner loop is
  // bounded by ROWS where COLS is meant; the two are equal here (0x20).
  void PrintTable()
  {
#ifdef DEBUG
    char t[ROWS][COLS] = { 0 };
    CopyTable(table, t);
    size_t col_size = 0x20;
    std::cout << " ";
    for (int i = 0; i < col_size; i++)
    {
      std::cout << "--";
    }
    std::cout << std::endl;
    for (int i = 0; i < u1.number_of_rows; i++)
    {
      std::cout << "<|";
      for (int j = 0; j < ROWS; j++)
      {
        std::cout << (char)(('0' + t[i][j])) << "|";
      }
      std::cout << ">" << std::endl;
    }
    std::cout << " ";
    for (int i = 0; i < col_size; i++)
    {
      std::cout << "--";
    }
    std::cout << std::endl;
#endif
  }
  // In _DEBUG builds, stores an encrypted 0 in every row so that rows read
  // back as 0. In other builds the rows keep their all-zero bits and read
  // back as decrypt(0).
  void Init()
  {
#ifdef _DEBUG
    for (unsigned int i = 0; i < ROWS; i++)
    {
      StoreNumber(table[i], 0);
    }
#endif
  }
  // Runs the interactive menu until the user chooses Exit or input fails.
  void MainLoop()
  {
    bool done = false;
    Init();
    // The key is copied in starting at row number_of_rows + 2. With the
    // library's number_of_rows = 0x10 and private_key_length = 3 * 0x20 this
    // fills rows 18-20, which lie outside the normal row range.
    // NOTE(review): nothing checks that the key fits inside `table`.
    password = table[number_of_rows + 2];
    std::copy(&private_key[0], &private_key[0] + private_key_length, password);
    // Force the last key bit to 1.
    password[private_key_length - 1] |= 1;
    // Truncated to the 5-bit field.
    u1.number_of_rows = number_of_rows;
    while (!done)
    {
      // While an encryption is in progress, only its y/n prompt is handled.
      if (u1.s1.init)
      {
        Encryption();
        continue;
      }
      size_t op = 0;
      std::cout << "Please choose your option:" << std::endl;
      std::cout << "0. Store Number" << std::endl;
      std::cout << "1. Get Number" << std::endl;
      std::cout << "2. Add" << std::endl;
      std::cout << "3. Subtract" << std::endl;
      std::cout << "4. Multiply" << std::endl;
      std::cout << "5. Divide" << std::endl;
      std::cout << "6. Private Key Encryption" << std::endl;
      std::cout << "7. Binary Representation" << std::endl;
      std::cout << "8. Exit" << std::endl;
      std::cin >> op;
      if (!std::cin)
      {
        done = true;
        break;
      }
      switch (op)
      {
      case 0:
      {
        size_t row = 0;
        size_t num = 0;
        std::cout << "Enter row and number" << std::endl;
        std::cin >> row >> num;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!ValidateRowIndex(row))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        StoreNumber(table[row], num);
        break;
      }
      case 1:
      {
        size_t row = 0;
        size_t num = 0;
        std::cout << "Enter row" << std::endl;
        std::cin >> row;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!ValidateRowIndex(row))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        RetrieveNumber(table[row], num);
        std::cout << "Result is " << num << std::endl;
        break;
      }
      case 2:
      {
        size_t row1 = 0, row2 = 0, row3 = 0;
        std::cout << "Enter row of arg1, row of arg2 and row of result" << std::endl;
        std::cin >> row1 >> row2 >> row3;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!(ValidateRowIndex(row1) && ValidateRowIndex(row2) && ValidateRowIndex(row3)))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        Add(row1, row2, row3);
        break;
      }
      case 3:
      {
        size_t row1 = 0, row2 = 0, row3 = 0;
        std::cout << "Enter row of arg1, row of arg2 and row of result" << std::endl;
        std::cin >> row1 >> row2 >> row3;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!(ValidateRowIndex(row1) && ValidateRowIndex(row2) && ValidateRowIndex(row3)))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        Sub(row1, row2, row3);
        break;
      }
      case 4:
      {
        size_t row1 = 0, row2 = 0, row3 = 0;
        std::cout << "Enter row of arg1, row of arg2 and row of result" << std::endl;
        std::cin >> row1 >> row2 >> row3;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!(ValidateRowIndex(row1) && ValidateRowIndex(row2) && ValidateRowIndex(row3)))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        Multiply(row1, row2, row3);
        break;
      }
      case 5:
      {
        size_t row1 = 0, row2 = 0, row3 = 0;
        std::cout << "Enter row of arg1, row of arg2 and row of result" << std::endl;
        std::cin >> row1 >> row2 >> row3;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!(ValidateRowIndex(row1) && ValidateRowIndex(row2) && ValidateRowIndex(row3)))
        {
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        Divide(row1, row2, row3);
        break;
      }
      case 6:
      {
        size_t row1 = 0, row2 = 0;
        // NOTE(review): init is set *before* the rows are validated. Through
        // the u1 union this widens ValidateRowIndex() (to 24 rows when
        // number_of_rows is 0x10). The key rows can then be named as the key
        // or as the message row, and End_calc() overwrites the message row.
        // Validating before setting init would close this hole.
        u1.s1.init = 1;
        std::cout << "Enter row of message, row of key" << std::endl;
        std::cin >> row1 >> row2;
        if (!std::cin)
        {
          done = true;
          break;
        }
        if (!(ValidateRowIndex(row1) && ValidateRowIndex(row2)))
        {
          u1.s1.init = 0;
          std::cout << "Row number is out of range" << std::endl;
          break;
        }
        InitializeExp(row1, row2);
        break;
      }
      case 7:
      {
        PrintTable();
        break;
      }
      case 8:
      {
        done = true;
        break;
      }
      default:
      {
        std::cout << "Unknown option." << std::endl;
        break;
      }
      }
    }
  }
};
// Global calculator instance; main() just runs its interactive menu loop.
Algo a;

int main()
{
  a.MainLoop();  // returns when the user picks "Exit" or input fails
  return 0;
}
# Builds the key library (build/bevx_lib.so) and the challenge binary
# (build/bevx_cha1) linked against it. The bevx_lib.so, bevx_lib.o and
# bevx_cha1 targets name files in the current directory while the recipes write
# into build/, so these targets are rebuilt on every invocation.
.PHONY: all clean out_dir
CCFLAGS=-std=c++11 -s -Os
# $(CURDIR) is the directory make runs in (it honors -C). $(PWD) comes from the
# invoking shell, so under `make -C` it disagrees with the relative build/
# paths used in the recipes.
OUT_DIR=$(CURDIR)/build/
all: out_dir bevx_cha1
# Recipe lines below must start with a TAB character, not spaces.
bevx_lib.so: bevx_lib.o
	g++ -shared -o build/bevx_lib.so build/bevx_lib.o
bevx_lib.o: out_dir bevx_lib.cpp
	g++ $(CCFLAGS) -fPIC -c bevx_lib.cpp -o build/bevx_lib.o
# NOTE(review): -l:./bevx_lib.so presumably records ./bevx_lib.so as the
# needed library, so bevx_cha1 must be started from build/. Verify.
bevx_cha1: bevx_cha1.cpp bevx_lib.so
	g++ -o build/bevx_cha1 $(CCFLAGS) bevx_cha1.cpp -L$(OUT_DIR) -l:./bevx_lib.so -lstdc++
out_dir:
	mkdir -p $(OUT_DIR)
clean:
	rm -rf build/

Solution (Tim)
The “encryption” routine operates on 1 bit of the key at a time, modifying an internal ongoing value each step. If the bit is a zero then this value is squared. If the bit is a one then the value is squared and then further multiplied by the value of the “message” that is selected to encrypt.
Because you can stop the encryption at any point, you can encrypt with progressively more bits of the key, allowing each bit to be extracted by comparing the result with the result for the previous number of bits.
Script output:

[+] Connecting to x.x.x.x on port 22: Done
[+] Opening new channel: 'shell': Done
[*] Bit 0 result: 0x1
[*] Key: 0
[*] Bit 1 result: 0x11
[*] Key: 01
[*] Bit 2 result: 0x1331
....
[*] Bit 28 result: 0xb11e13b1
[*] Key: 01100010011001010101011001011
[*] Bit 29 result: 0x60ffc061
[*] Key: 011000100110010101010110010110
[*] Bit 30 result: 0x91cfa4c1
[*] Key: 0110001001100101010101100101100
#!/usr/bin/env python2
# beVX Challenge 1 exploit script
# - timpwn
import pwn          # pip install pwn
import logging
# True: attack the challenge server over SSH; False: run ./cha1 locally
# (see connect()).
remote = True
# Uncomment for verbose pwntools debug logging:
# pwn.context.log_level = logging.DEBUG
def set_value(row, value):
    """Store `value` in table row `row` (menu option 0) and wait for the menu."""
    for token in ("0", str(row), str(value)):
        r.sendline(token)
    r.readuntil(prompt)
def get_value(row):
    """Read table row `row` (menu option 1) and return its value as an int."""
    r.sendline("1")
    r.readline()  # discard the "Enter row" line
    r.sendline(str(row))
    r.readuntil("Result is ")
    result = int(r.readline())
    r.readuntil(prompt)
    return result
def encrypt(message_row, key_row, bit_count):
    """Run "Private Key Encryption" for at most `bit_count` steps, then stop.

    Answers 'y' up to `bit_count` times. If the server finishes on its own
    first (no further "Continue" prompt), it returns after the menu; otherwise
    it answers 'n' so the partial result is stored in `message_row`.
    """
    continue_prompt = "Continue Encryption? (y/n)"
    r.sendline("6")
    r.readline()
    for arg in (message_row, key_row):
        r.sendline(str(arg))
    r.readuntil(continue_prompt)
    r.readline()
    for step in range(bit_count):
        r.sendline("y")
        line = r.readline().strip()
        if line == "y":
            # The remote system echoes our input back
            line = r.readline().strip()
        if line != continue_prompt:
            pwn.log.debug("No more encryption at bit {}".format(step))
            r.readuntil(prompt)
            return
    r.sendline("n")
    r.readuntil(prompt)
def decode_key_row(row):
    """Recover the 32-bit key stored in table row `row` as a '0'/'1' string.

    The server's "encryption" skips the key's leading zero bits and starts from
    x = 1. For each remaining key bit it squares x and, on a 1 bit, multiplies
    it by the message. Encrypting the odd multiplier 0x11 and stopping after
    1, 2, 3, ... steps therefore reveals one bit per probe: the new result is
    either previous**2 (bit 0) or previous**2 * 0x11 (bit 1), modulo 2**32.
    Once the probe count exceeds the key length the result stops changing,
    which marks the end of the key.

    Returns a 32-character string; '?' marks a bit whose result matched
    neither prediction.
    """
    result = ""
    multiplier = 0x11
    # Zero steps would always yield 1, so start from that known value instead
    # of spending a probe on it.
    previous_value = 1
    p = pwn.log.progress("Reading row {}".format(row))
    # Up to 32 probes. A key that uses all 32 bits needs the 32nd probe; a
    # shorter key is detected one probe after its last bit (the `break` below).
    for bit in range(1, 33):
        p.status("Getting bit {}".format(bit))
        set_value(0, multiplier)
        encrypt(0, row, bit)
        v = get_value(0)
        pwn.log.debug("Bit {} result: 0x{:x}".format(bit, v))
        # See if our multiplier has been used, which indicates that
        # the key has a "1" in this position
        previous_squared = previous_value * previous_value
        if v == previous_squared & 0xffffffff:
            result += "0"
        elif v == (previous_squared * multiplier) & 0xffffffff:
            result += "1"
        elif v == previous_value:
            # Ran past the end of the key: every key bit has been read.
            break
        else:
            pwn.log.warn("Unexpected value!")
            result += "?"
        pwn.log.debug("Progress: " + result)
        previous_value = v
    # The server skips the key's leading zero bits, so restore them.
    result = result.rjust(32, "0")
    p.success("Got row, value: " + result)
    return result
def connect():
    """Open a session to the challenge and return (tube, menu_terminator).

    menu_terminator is the last menu line ("8. Exit") with the line ending the
    target uses: CRLF over the SSH shell, LF for a local process.
    """
    if not remote:
        return (pwn.process("./cha1"), "8. Exit\n")
    session = pwn.ssh(user="challenge",
                      host="x.x.x.x",
                      password="challenge")
    return (session.shell(), "8. Exit\r\n")
r, prompt = connect()
r.readuntil(prompt)
# We can get rows 18-20 ok, after which no encryption cycles are possible -
# this is most likely because the rows are all zeroes.
rows_bits = list()
for row in range(18, 21):
    partial_key = decode_key_row(row)
    rows_bits.append(partial_key)
# Pad out and combine the key parts
combined = ""
for r in rows_bits:
    padded = ("0" * (32 - len(r))) + r
    print "Row value:", hex(int(padded, 2))
    combined += "{:04x}".format(int(padded, 2))
pwn.log.info("Key: " + repr(combined.decode("hex")))

Solution (mongo)
Key: “beVX Sep 20!”

import operator
import ctypes, sys, re, os
from pwn import *
from pwnlib.tubes.ssh import *
remote = 1
# Pick the live challenge server or a local test VM.
if remote:
	conn = ssh(host='x.x.x.x', user='challenge', password='challenge')
else:
	conn = ssh(host='192.168.1.5', user='cha1', password='cha1')
s = conn.shell()
# Alias sendall to the tube's send method; the helpers below call sendall().
s.sendall = s.send
# Consume the first line and the initial menu.
s.recvuntil("\r\n")
s.recvuntil("8. Exit")
def store(idx, val):
	"""Menu option 0: write `val` into row `idx`, then wait for the menu."""
	request = "0\n%d\n%d\n" % (idx, val)
	s.sendall(request)
	s.recvuntil("8. Exit")
def get(idx):
	"""Menu option 1: read row `idx` and return its value as an int."""
	s.sendall("1\n%d\n" % (idx))
	s.recvuntil("Result is ")
	line = s.recvuntil("\r\n")
	value = int(line)
	s.recvuntil("8. Exit")
	return value
def privateenc(msg_idx, key_idx):
	"""Menu option 6: begin encrypting row `msg_idx` with key row `key_idx`.

	Returns once the request has been sent; the caller reads whatever the
	server prints next (encryption prompts or the menu).
	"""
	s.sendall("6\n")
	s.recvuntil("Enter row of message, row of key\r\n")
	s.sendall("%d\n%d\n" % (msg_idx, key_idx))
	# Two more lines follow - presumably the echo of the two row numbers.
	for _ in range(2):
		s.recvuntil("\n")
def reset():
	"""Store 0 in rows 0-15."""
	for row in range(16):
		store(row, 0)
def get_key_part(key_idx):
	privateenc(16, key_idx)
	times = 0
	"""
	check the number of bits left in this key part
	"""
	for i in range(32):
		v = s.recvuntil("\r\n").strip()
		#print "<<", v
		if "Continue Encryption" in v:
			s.sendall("y\n")
			s.recvuntil("\n")
			times += 1
		else:
			break
	s.recvuntil("8. Exit")
	print "key %d bits = %d" % (key_idx, times)
	num_bits = times
	"""
	now, get key_part
	keep in mind result is multiplied by itself at every step
	if key bit is 1, we also multiply by 3 (baseval)
	"""
	baseval = 3
	pos = 32 - num_bits - 1
	skip = 0
	cur_val = 1
	key_part = 0
	for i in range(32 - num_bits, 32):
		store(0, baseval)
		privateenc(0, key_idx)
		val_if_1 = ((cur_val * cur_val) * baseval) & 0xFFFFFFFF
		val_if_0 = ((cur_val * cur_val)) & 0xFFFFFFFF
		times = 0
		for i in range(skip + 1):
			v = s.recvuntil("\n").strip()
			s.sendall("y\n")
			s.recvuntil("\n")
		v = s.recvuntil("\n").strip()
		if "Continue" in v:
			s.sendall("n\n")
			s.recvuntil("\n")
		s.recvuntil("8. Exit")
		res = get(0)
		#print "res=", res
		if res not in [val_if_0, val_if_1]:
			print res
			print [val_if_0, val_if_1]
			raise "Fail"
		key_part = (key_part << 1) | (1 if res == val_if_1 else 0)
		cur_val = res
		skip += 1
	print bin(key_part), "%08X" % key_part, ("%08X" % key_part).decode('hex')
# Dump rows 18-21: the key occupies rows 18-20, and row 21 presumably reads
# back empty. Then close the session.
for key_row in range(18, 22):
	get_key_part(key_row)
s.close()

Solution (Dmitry)
Secret key is “beVX Sep 20!”
Solution:
It is easy to detect that valid row numbers are 0..15 (thanks to error messages).
The “Private Key Encryption” handler sets bit 3 (value 8) of the row-count field, thus allowing access to rows 16..23. The key is stored in rows 18..20, and each row represents a 32-bit value.
Encryption is just calculation of pow(msgRow, keyRow, 1<<32)
The fastest method (a timing attack) recovers a row value in a single pass: each non-zero bit in the exponent requires an additional call to decrypt(), which causes a noticeable delay.
However, because automating the interactive SSH session was difficult, I derived each row in 3 steps:
1. Find the number of bits in the exponent (by counting “Continue Encryption? (y/n)” prompts)
2. Find the highest 16 bits of the exponent (by stopping encryption 16 bits before its end and brute-forcing the 16-bit exponent value)
3. Find the complete exponent (by brute-forcing the lowest 16 bits)

import sys, subprocess, time
class SSH_beVX(object):
  """Drive the challenge menu through an interactive `plink` SSH session.

  The remote side echoes every command back. send_command() consumes that
  echo and timestamps it, so measure_crypt() can time how long the server
  takes to answer each encryption step.
  """
  EMSG = "8. Exit"
  # Wall-clock timer for the timing attack. time.clock() measures CPU time on
  # Unix (so the server-side delay would be invisible) and was removed in
  # Python 3.8. Use perf_counter where it exists, otherwise time.time.
  _now = staticmethod(getattr(time, "perf_counter", time.time))

  def __init__(self, host, username, password, port=22):
    """Start plink with the given credentials and talk to it over pipes."""
    args = ["plink", "-l", username, "-pw", password, "-P", "%d" % port, host]
    self.proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)

  def send_command(self, cmd):
    """Send one line, consume its echo and record when it was sent."""
    self.proc.stdin.write(cmd + "\n")
    self.proc.stdin.flush()
    ln = self.proc.stdout.readline()
    assert ln.startswith(cmd)
    self.started = self._now()

  def read_line(self):
    """Return the next line from the session ("" at EOF)."""
    return self.proc.stdout.readline()

  def read_until(self, msg=EMSG):
    """Read lines up to and including the first one starting with `msg`.

    Raises EOFError if the session ends before such a line arrives.
    """
    lines = []
    while True:
      line = self.read_line()
      if not line:
        # readline() returns "" only at EOF; without this check the loop
        # would spin forever.
        raise EOFError("session closed before %r was received" % msg)
      lines.append(line)
      if line.startswith(msg):
        return lines

  def write_row(self, row, val):
    """Menu option 0: store `val` in `row`."""
    self.send_command("0")
    self.read_line() # Enter row and number
    self.send_command("%d %d" % (row, val))
    self.read_until() # Please choose your option:

  def read_row(self, row):
    """Menu option 1: return the value stored in `row`."""
    self.send_command("1")
    self.read_line() # Enter row
    self.send_command("%d" % row)
    ln = self.read_line() # Result is
    assert ln.startswith("Result is")
    self.read_until() # Please choose your option:
    return int(ln.split()[-1])

  def measure_crypt(self, keyRow, msgRow=0, val=3):
    """Recover the exponent in `keyRow` by timing a full encryption of `val`.

    Each step that consumes a 1 key bit makes the server call decrypt(),
    which sleeps, so a reply slower than 0.7 s is read as a 1 bit. Returns the
    exponent as a 4-byte string.
    """
    self.write_row(msgRow, val)
    self.send_command("6")
    self.read_line() # Enter row of message, row of key
    self.send_command("%d %d" % (msgRow, keyRow))
    exp = 0
    while True:
      ln = self.read_line()
      delta = self._now() - self.started
      bit = 1 if delta > 0.7 else 0
      exp = (exp*2) + bit
      sys.stderr.write("\r%8X" % exp)
      if not ln.startswith("Continue Encryption? (y/n)"): break
      self.send_command("Y")
    self.read_until() # Please choose your option:
    # The last reply also includes the delay of storing the result, so the
    # lowest bit always times as 1. Confirm it against the real ciphertext.
    res = self.read_row(msgRow)
    if res != pow(val, exp, 1<<32):
      exp ^= 1
      if res != pow(val, exp, 1<<32): raise Exception("Can't find key[%d]" % keyRow)
    s = ("%08X" % exp).decode("hex")
    sys.stderr.write("\r%08X [%s]\n" % (exp, s))
    return s
def main():
  ssh = SSH_beVX("x.x.x.x", "challenge", "challenge")
  ssh.read_until() # Please choose your option:
  r = [ssh.measure_crypt(keyRow, 0, 7) for keyRow in xrange(18, 21)]
  print "Key is [%s]" % "".join(r) # "beVX Sep 20!"
  ssh.send_command("8")
# Run the attack only when executed as a script.
if __name__ == "__main__":
  main()