programming/back-from-brazil

465 points

Writeup by Aryan

This is the challenge source code we were given:

import random, time

def solve(eggs):
    """Placeholder for the server's solver, whose body was redacted.

    In the published source the real solver is blacked out and kept only as
    the string literal below. As written here, the function ignores *eggs*
    and returns the sum of the code points of that redacted text.
    NOTE(review): the deployed server presumably computed the true maximum
    path sum here; that code is not visible in this listing.
    """
    # Redacted solver body, kept as an opaque string (not executed).
    redactedscript = """
    █ █ █████████
    ██ █ ██

    ███ █ ██ █████████
        ██████████████ █ ██

    ████████ █ ██████████

    ███ █ ██ █████████
        ███ █ ██ █████████
            ██ █ ██ █ ███ █ ██ ██
                ████████

            ██ █ ██ ██
                ████████ █ ████ █ █████
            ██ █ ██ ██
                ████████ █ █████████████ ███████ █ ███

            ████████ ██ ██████████

    ██████ ██████████
    """

    # Stand-in result: depends only on the redacted text, never on eggs.
    return sum([ord(c) for c in redactedscript])

# Grid side length: every round uses an n x n grid.
n = 1000

# Reference point for the time limit; checked at the end of every round.
start = time.time()

# 10 rounds; all of them must pass before the flag is printed.
for _ in range(10):
    # Build and print the grid. Row i is printed on line i as space-separated
    # ints (each followed by a space), so eggs[x][y] is line x, column y.
    eggs = []
    for i in range(n):
        row = []
        for j in range(n):
            row.append(random.randint(0, 696969))
            print(row[j], end=' ')
        eggs.append(row)
        print()

    solution = solve(eggs)
    print("optimal: " + str(solution) + " 🥚")
    inputPath = input()
    # The walk starts on the top-left cell, which is always counted.
    inputAns = eggs[0][0]
    x = 0
    y = 0

    for direction in inputPath:
        # 'd' advances x (the row, i.e. the printed line); 'r' advances y (the
        # column). Any other character aborts the session.
        match direction:
            case 'd':
                x += 1
            case 'r':
                y += 1
            case _:
                print("🤔")
                exit()

        # Leaving the grid ends the session. Every move raises x + y by one,
        # so no path longer than 2n - 2 moves can pass this check.
        if x == n or y == n:
            print("out of bounds")
            exit()

        inputAns += eggs[x][y]



    # The sum is checked first; a path shorter than 2n - 2 moves is rejected
    # only once its sum has already reached `solution`.
    if inputAns < solution:
        print(inputAns)
        print("you didn't find enough 🥚")
        exit()
    elif len(inputPath) < 2 * n - 2:
        print("noooooooooooooooo, I'm still in Brazil")
        exit()

    # `start` is taken once before the first round, so this 60 s limit is
    # cumulative over all rounds so far, not per round.
    if int(time.time()) - start > 60:
        print("you ran out of time")
        exit()

print("tnxs for finding all my 🥚")
# NOTE(review): the file handle is never closed; harmless because the script
# ends right after this read, but a `with` block would be cleaner.
f = open("/flag.txt", "r")
print(f.read())

We get a 1000 x 1000 grid of random numbers and must find a path from the top-left cell to the bottom-right cell whose cells add up to the maximum possible sum. We may only move right (`r`) or down (`d`), and we have to do this for 10 grids within 60 seconds in total.

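This is a standard dynamic-programming problem, and many implementations exist online if you want to understand the approach in more detail. The idea: let `dp[i][j]` be the largest sum of any path ending at cell `(i, j)`. That cell can only be entered from above or from the left, so `dp[i][j] = eggs[i][j] + max(dp[i-1][j], dp[i][j-1])`. To recover the path itself, remember which of the two neighbours each maximum came from and walk back from the bottom-right corner. Attached below is my solve script.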

import socket

def connect_and_retrieve(sock):
    """Read one round's egg grid and the server's optimal sum from *sock*.

    The challenge server prints the grid as space-separated integers, one row
    per line, then a line ``optimal: <sum> 🥚``, and then blocks waiting for
    our path. We therefore read until that whole line, up to its newline, has
    arrived:

    * Stopping as soon as ``optimal:`` appears could cut the number short.
    * Searching for the marker only in the latest chunk misses it when it
      straddles two ``recv`` calls.

    In either case we would then wait forever on a server that is itself
    waiting for input.

    Args:
        sock: A connected socket (anything with a ``recv(bufsize)`` method).

    Returns:
        ``(eggs, optimal)``: the grid as a list of rows of ints, and the
        server's optimal sum as an int.

    Raises:
        ConnectionError: If the connection closes before a complete
            ``optimal:`` line arrives (e.g. the server printed an error and
            exited).
    """
    marker = b"optimal:"
    # A bytearray grows in place; bytes += would build a new immutable object
    # on every append of a multi-megabyte grid.
    buf = bytearray()
    marker_at = -1
    newline_at = -1
    while True:
        chunk = sock.recv(65536)
        if not chunk:
            tail = buf[-200:].decode("utf-8", errors="replace")
            raise ConnectionError(
                f"connection closed before the 'optimal:' line; last output: {tail!r}"
            )
        # Rescan only the new data plus enough old bytes to catch a marker
        # split across two chunks.
        scan_from = max(0, len(buf) - len(marker) + 1)
        buf += chunk
        if marker_at == -1:
            marker_at = buf.find(marker, scan_from)
        if marker_at != -1:
            newline_at = buf.find(b"\n", marker_at)
            if newline_at != -1:
                break

    grid_text = buf[:marker_at].decode("utf-8")
    eggs = [
        [int(tok) for tok in line.split()]
        for line in grid_text.splitlines()
        if line.strip()  # ignore blank lines
    ]
    # The line reads "optimal: <sum> 🥚"; the sum is the first token.
    optimal_text = buf[marker_at + len(marker):newline_at].decode("utf-8")
    optimal = int(optimal_text.split()[0])
    return eggs, optimal

def solve(eggs, down='d', right='r'):
    """Return the maximum egg sum over monotone paths and one path achieving it.

    A path starts on ``eggs[0][0]``, ends on the bottom-right cell, and only
    moves down (next row, ``eggs[i + 1][j]``) or right (next column,
    ``eggs[i][j + 1]``). The challenge server maps ``'d'`` to ``x += 1`` and
    ``'r'`` to ``y += 1`` while summing ``eggs[x][y]``. So ``'d'`` advances the
    row and ``'r'`` the column, and the defaults follow that mapping.

    Args:
        eggs: Non-empty rectangular grid (list of equal-length rows) of ints.
        down: Move letter emitted for a step to the next row.
        right: Move letter emitted for a step to the next column.

    Returns:
        ``(best_sum, path)``, where ``path`` consists of
        ``len(eggs) + len(eggs[0]) - 2`` move letters. On ties the step
        coming from the left is preferred.

    Raises:
        ValueError: If *eggs* or its first row is empty.
    """
    if not eggs or not eggs[0]:
        raise ValueError("egg grid must be non-empty")
    rows, cols = len(eggs), len(eggs[0])

    # best[i][j]: largest sum of a path ending on (i, j).
    # came_down[i][j]: True when that best path entered (i, j) from (i-1, j).
    # One flag per cell replaces a full path string per cell, which would
    # cost about rows*cols*(rows+cols)/2 characters (~1e9 for 1000 x 1000).
    best = [[0] * cols for _ in range(rows)]
    came_down = [[False] * cols for _ in range(rows)]

    best[0][0] = eggs[0][0]
    for j in range(1, cols):
        best[0][j] = best[0][j - 1] + eggs[0][j]
    for i in range(1, rows):
        best[i][0] = best[i - 1][0] + eggs[i][0]
        came_down[i][0] = True
        for j in range(1, cols):
            up, left = best[i - 1][j], best[i][j - 1]
            if up > left:  # strict: ties keep the step from the left
                best[i][j] = up + eggs[i][j]
                came_down[i][j] = True
            else:
                best[i][j] = left + eggs[i][j]

    # Walk back from the bottom-right corner to recover the moves.
    moves = []
    i, j = rows - 1, cols - 1
    while i or j:
        if came_down[i][j]:
            moves.append(down)
            i -= 1
        else:
            moves.append(right)
            j -= 1
    moves.reverse()
    return best[rows - 1][cols - 1], ''.join(moves)

def main(server_ip="24.199.110.35", server_port=43298, rounds=10):
    """Play every round against the challenge server and print the flag.

    Each round reads the grid and the server's optimal sum, computes our own
    best path, checks that the two agree, and sends the path back. After the
    last round the server prints a thank-you line and the flag and then
    exits, so we read everything it sends until it closes the connection.

    Args:
        server_ip: Challenge host.
        server_port: Challenge port.
        rounds: Number of grids the server sends (10 in the challenge source).

    Raises:
        RuntimeError: If our optimum disagrees with the server's, or the path
            does not have the length of a full corner-to-corner walk. Sending
            such a path could only fail on the server side.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((server_ip, server_port))

        for round_no in range(rounds):
            print(f"Iteration #{round_no}")
            eggs, optimal = connect_and_retrieve(s)
            print("got the input")

            optimal_calculated, path = solve(eggs)
            print("solve script executed")

            # Explicit checks rather than `assert`, which `python -O` strips.
            if optimal != optimal_calculated:
                raise RuntimeError(
                    f"optimal dp doesnt match: server {optimal}, ours {optimal_calculated}"
                )
            # A full top-left -> bottom-right walk makes (rows - 1) + (cols - 1)
            # moves: 1998 for the challenge's 1000 x 1000 grid.
            expected_len = len(eggs) + len(eggs[0]) - 2
            if len(path) != expected_len:
                raise RuntimeError(
                    f"wrong path length: {len(path)} (expected {expected_len})"
                )

            s.sendall((path + "\n").encode('utf-8'))

        # The thank-you line and the flag may arrive in separate segments, so
        # read until EOF instead of doing a single recv(). The timeout is a
        # best-effort guard in case the server leaves the connection open.
        s.settimeout(5)
        chunks = []
        try:
            while True:
                chunk = s.recv(4096)
                if not chunk:
                    break
                chunks.append(chunk)
        except socket.timeout:
            pass  # keep whatever arrived before the stall
        flag = b"".join(chunks).decode('utf-8', errors='replace')
        print("Flag:", flag)


if __name__ == "__main__":
    main()

Flag: n00bz{1_g0t_b4ck_h0m3!!!}

Last updated