# proxy.py
from flask import Flask, session
import json
import socket
import subprocess
import threading
import secrets
import os
import hashlib

from pathlib import Path
from time import time

from routes import register_routes
from federation import Federation

# Filesystem locations. STORAGE_DIR and RESILIENCE_BIN are absolute paths;
# USERS_FILE and SECRET_FILE are relative, so they resolve against the
# process's current working directory at the time they are opened.
STORAGE_DIR = Path("/var", "lib", "resilience")        # holds <user_id>.bin storage files
RESILIENCE_BIN = Path("/usr/local/bin", "resilience")  # executable launched per user
USERS_FILE = Path("users.json")                        # JSON account table
SECRET_FILE = Path("secret.key")                       # persisted GLOBAL_SECRET bytes

# Long-lived secret handed to register_routes and Federation in __main__.
# It is generated once and persisted to SECRET_FILE so it survives restarts.
try:
    GLOBAL_SECRET = SECRET_FILE.read_bytes()
except FileNotFoundError:
    GLOBAL_SECRET = secrets.token_bytes(32)
    # Create the key file with owner-only permissions (0o600) rather than the
    # umask default, which is commonly world-readable. O_EXCL makes creation
    # fail loudly instead of silently overwriting a key that another process
    # wrote between the read attempt above and this point.
    _secret_fd = os.open(SECRET_FILE, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    with os.fdopen(_secret_fd, 'wb') as f:
        f.write(GLOBAL_SECRET)

# Public domain name of this proxy (required). It is passed to the route
# handlers, the federation server and the resilience initialisation command,
# and is part of the default CERT_PATH.
DOMAIN = os.environ.get("DOMAIN")
if not DOMAIN:
    # SystemExit with a string argument prints the message to stderr and exits
    # with status 1. The builtin exit() is injected by the site module for the
    # interactive interpreter and is not guaranteed to exist in programs
    # (e.g. under `python -S`).
    raise SystemExit(
        "ERROR: DOMAIN environment variable must be set\n"
        "Example: DOMAIN=payment.example.com python proxy.py"
    )

# TLS material handed to the federation server. Each value can be overridden
# through the environment; the defaults match certbot's live/<domain> layout.
CERT_PATH = Path(os.getenv("CERT_PATH", f"/etc/letsencrypt/live/{DOMAIN}"))
CERT_NAME = os.getenv("CERT_NAME", "fullchain.pem")  # certificate chain file in CERT_PATH
KEY_NAME = os.getenv("KEY_NAME", "privkey.pem")      # private key file in CERT_PATH

# Runtime registries, all keyed by user id:
#   users     - account records, loaded from / saved to USERS_FILE
#   processes - Popen handle of each user's running resilience process
#   ports     - local UDP port each resilience process was started on
users, processes, ports = {}, {}, {}
# Next local UDP port to hand out; spawn_resilience increments it per spawn.
next_port = 9_000

app = Flask(__name__)
# Session-signing key (Flask's SECRET_KEY config entry, the same setting that
# app.secret_key exposes). It is freshly random on every start, so session
# cookies issued by a previous run stop validating after a restart.
app.config["SECRET_KEY"] = secrets.token_bytes(32)

# Pending incoming payment requests: recipient -> list of request dicts built
# by add_payment_request. A request older than PAYMENT_REQUEST_TIMEOUT seconds
# counts as expired. payment_lock serialises add_payment_request's updates and
# is also handed to register_routes.
PAYMENT_REQUEST_TIMEOUT = 15 * 60
payment_requests = {}
payment_lock = threading.Lock()

def add_payment_request(recipient, sender, sender_domain, key, *, max_pending=16):
    """Queue a pending payment request for ``recipient``.

    The recipient's existing requests older than PAYMENT_REQUEST_TIMEOUT
    seconds are pruned first. If ``max_pending`` unexpired requests remain,
    the new request is rejected.

    Args:
        recipient: Key under which the request is stored in ``payment_requests``.
        sender: Sender identifier, stored as-is.
        sender_domain: Sender's domain, stored as-is.
        key: Bytes-like payload. Its SHA-256 hex digest becomes the request ``id``.
        max_pending: Per-recipient cap on unexpired requests (default 16).

    Returns:
        True if the request was queued, False if the recipient's queue is full.
    """
    with payment_lock:
        now = time()
        # Expired requests must not count against the cap.
        pending = [
            req for req in payment_requests.get(recipient, [])
            if now - req['timestamp'] < PAYMENT_REQUEST_TIMEOUT
        ]
        payment_requests[recipient] = pending

        if len(pending) >= max_pending:
            return False

        pending.append({
            'id': hashlib.sha256(key).hexdigest(),
            'sender': sender,
            'sender_domain': sender_domain,
            'key': key,
            'timestamp': now,
        })
        return True

def load_users():
    """Load the account table from USERS_FILE into the module-level ``users`` dict.

    The dict is updated in place rather than rebound. Any component already
    holding a reference to ``users`` (it is handed to register_routes and
    Federation in __main__) therefore sees the loaded data. A missing file
    yields an empty table.

    Raises:
        json.JSONDecodeError: if USERS_FILE is not valid JSON. ``users`` is
            left untouched in that case.
    """
    if USERS_FILE.exists():
        with open(USERS_FILE, encoding='utf-8') as f:
            loaded = json.load(f)
    else:
        loaded = {}
    users.clear()
    users.update(loaded)

def save_users():
    """Persist the module-level ``users`` dict to USERS_FILE as JSON, atomically.

    The data is written and fsync'ed to a uniquely named temporary file in the
    same directory, which then replaces USERS_FILE in a single os.replace().
    Readers therefore see either the old or the new table, never a partial one.

    A unique name (not a fixed ``users.tmp``) keeps concurrent saves from the
    threaded route handlers from interleaving writes into the same file.
    mkstemp creates the file with mode 0600, so the saved table, which holds
    passwords, is readable only by the owner. If anything fails, the temporary
    file is removed and the exception is re-raised.
    """
    import tempfile

    fd, tmp_name = tempfile.mkstemp(
        dir=USERS_FILE.parent, prefix=USERS_FILE.name + ".", suffix=".tmp"
    )
    tmp_file = Path(tmp_name)
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(users, f)
            f.flush()
            os.fsync(f.fileno())
        tmp_file.replace(USERS_FILE)
    except BaseException:
        tmp_file.unlink(missing_ok=True)
        raise

# Serialises port allocation. spawn_resilience is handed to the route
# handlers, which Flask may run on several threads at once
# (app.run(threaded=True)).
_port_lock = threading.Lock()


def spawn_resilience(user_id):
    """Start the resilience process for ``user_id``, initialising its storage if needed.

    Steps:
      1. Allocate the next local UDP port from ``next_port``.
      2. If ``STORAGE_DIR/<user_id>.bin`` does not exist, run
         ``RESILIENCE_BIN <user_id> <DOMAIN> 0 <64 zeros> <storage_file>`` and
         wait for it. Presumably this creates the storage file; confirm against
         the binary.
      3. Launch ``RESILIENCE_BIN 1 <storage_file> <port>`` in the background and
         record its Popen handle in ``processes`` and its port in ``ports``.

    Raises:
        subprocess.CalledProcessError: if the step-2 command exits non-zero.

    NOTE(review): user_id is interpolated into a file path and argv without
    validation here. Confirm callers reject values containing '/' or '..'.
    NOTE(review): an existing entry in ``processes`` for user_id is overwritten
    without terminating that process.
    """
    global next_port
    storage_file = STORAGE_DIR / f"{user_id}.bin"

    # Read-then-increment must be atomic, or two concurrent spawns could be
    # handed the same port.
    with _port_lock:
        port = next_port
        next_port += 1

    if not storage_file.exists():
        subprocess.run([
            str(RESILIENCE_BIN), user_id, DOMAIN,
            "0", "0" * 64, str(storage_file)
        ], check=True)

    proc = subprocess.Popen([str(RESILIENCE_BIN), "1", str(storage_file), str(port)])
    processes[user_id] = proc
    ports[user_id] = port

def remove_resilience(user_id, *, timeout=5.0):
    """Stop ``user_id``'s resilience process and delete its storage file.

    Does nothing if no process is registered for the user. The process is
    sent SIGTERM via ``terminate()`` and waited on for up to ``timeout``
    seconds, then killed if still running. Waiting reaps the child, so no
    zombie is left behind. It also guarantees the process has exited before
    its storage file is unlinked.

    Args:
        user_id: Key into ``processes`` / ``ports``.
        timeout: Seconds to wait after terminate() before escalating to kill().
    """
    proc = processes.pop(user_id, None)
    if proc is None:
        return
    ports.pop(user_id, None)

    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()

    storage_file = STORAGE_DIR / f"{user_id}.bin"
    storage_file.unlink(missing_ok=True)

def send_udp_command(user_id, command, args=b''):
    """Send one command datagram to ``user_id``'s resilience process and await the reply.

    The datagram is ``bytes([8, command]) + args``, sent to 127.0.0.1 on the
    user's registered port. The leading 8 looks like the USER_SESSION flag
    tested in p2p_router; confirm against the resilience protocol.

    Returns:
        ``(status, payload)``: the first byte of the reply as an int and the
        remaining bytes. ``(1, b'User not found')`` if the user has no
        registered port.

    Raises:
        socket.timeout: if no reply arrives within 5 seconds.
    """
    target_port = ports.get(user_id)
    if not target_port:
        return 1, b'User not found'

    packet = bytes([8, command]) + args
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp:
        udp.settimeout(5.0)
        udp.sendto(packet, ('127.0.0.1', target_port))
        reply = udp.recvfrom(4096)[0]
    status, payload = reply[0], reply[1:]
    return status, payload

def require_auth(f):
    """Decorator: reject the request with 401 unless the session has a ``user_id``.

    Unauthenticated requests get a JSON ``{'error': 'Unauthorized'}`` body
    with status 401, and ``f`` is not called.

    functools.wraps copies ``__name__`` onto the wrapper, which Flask uses as
    the default endpoint name. It also copies ``__doc__``, ``__module__`` and
    ``__qualname__``, and sets ``__wrapped__`` to ``f``.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        if 'user_id' not in session:
            from flask import jsonify
            return jsonify({'error': 'Unauthorized'}), 401
        return f(*args, **kwargs)
    return wrapper

def require_admin(f):
    """Decorator: reject the request with 403 unless the session is a logged-in admin.

    The session must contain a ``user_id`` and a truthy ``is_admin`` flag.
    Otherwise a JSON ``{'error': 'Forbidden'}`` body is returned with status
    403, and ``f`` is not called.

    functools.wraps copies ``__name__`` onto the wrapper, which Flask uses as
    the default endpoint name. It also copies ``__doc__``, ``__module__`` and
    ``__qualname__``, and sets ``__wrapped__`` to ``f``.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        if 'user_id' not in session or not session.get('is_admin'):
            from flask import jsonify
            return jsonify({'error': 'Forbidden'}), 403
        return f(*args, **kwargs)
    return wrapper

def create_ssl_connection(server, port):
    """Open a verified TLS client connection to ``server:port``.

    Uses ssl.create_default_context(), so the peer certificate is checked
    against the system trust store and ``server`` is used for SNI and hostname
    verification. The returned socket carries a 5-second timeout. The timeout
    also applies to the TCP connect and the TLS handshake.

    socket.create_connection resolves ``server`` and tries every returned
    address, so IPv6-only peers are reachable too. If the handshake fails, the
    TCP connection is closed rather than leaked.

    Returns:
        ssl.SSLSocket connected to the peer.

    Raises:
        OSError: on resolution, connection or timeout failure. ssl.SSLError
            (including certificate verification errors) is an OSError subclass.
    """
    import ssl
    context = ssl.create_default_context()

    raw_sock = socket.create_connection((server, port), timeout=5.0)
    try:
        return context.wrap_socket(raw_sock, server_hostname=server)
    except BaseException:
        raw_sock.close()
        raise

def p2p_router(host='0.0.0.0', port=2012):
    """Relay inbound peer-to-peer UDP datagrams to local resilience processes.

    Binds a UDP socket on ``host:port`` (all interfaces, port 2012 by default)
    and loops forever, reading at most 508 bytes per datagram. For each one:

    * fewer than 33 bytes: dropped;
    * first byte has the 0x08 flag (USER_SESSION) set: dropped;
    * otherwise bytes 1-32 name the recipient (decoded as UTF-8 with
      undecodable bytes ignored, cut at the first NUL). If that user has an
      entry in ``ports``, the datagram is forwarded unchanged to
      127.0.0.1:<port> from a fresh ephemeral socket.

    Forwarding is best-effort. An error on one datagram is swallowed so a
    single bad packet or send failure cannot stop the relay thread.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind((host, port))

        while True:
            try:
                data, _ = sock.recvfrom(508)
                if len(data) < 33 or data[0] & 8:  # too short, or USER_SESSION
                    continue

                recipient = data[1:33].decode('utf-8', errors='ignore').split('\0')[0]
                # A single .get() instead of `in` + `[]`: remove_resilience may
                # delete the entry from another thread between two lookups.
                target_port = ports.get(recipient)
                if target_port is None:
                    continue

                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as fwd_sock:
                    fwd_sock.sendto(data, ('127.0.0.1', target_port))
            except Exception:
                # Deliberately best-effort (see docstring). Exception rather
                # than a bare except so BaseExceptions such as SystemExit are
                # not swallowed.
                continue

if __name__ == '__main__':
    STORAGE_DIR.mkdir(parents=True, exist_ok=True)
    load_users()

    if not users:
        # First start: seed an administrator account. The password comes from
        # ADMIN_PASSWORD if set. Otherwise a random one is generated and
        # printed once, instead of the publicly guessable admin/admin default.
        # NOTE(review): passwords are stored in plaintext in USERS_FILE.
        # Hashing them needs a matching change in the login route (not
        # visible here).
        admin_password = os.environ.get("ADMIN_PASSWORD")
        generated = not admin_password
        if generated:
            admin_password = secrets.token_urlsafe(16)
        users['admin'] = {'password': admin_password, 'is_admin': True}
        save_users()
        if generated:
            print(f"Created initial admin account 'admin' with password: {admin_password}")

    register_routes(
        app,
        send_udp_command,
        require_auth,
        require_admin,
        create_ssl_connection,
        GLOBAL_SECRET,
        DOMAIN,
        users,
        save_users,
        spawn_resilience,
        remove_resilience,
        payment_requests,
        payment_lock,
        add_payment_request
    )

    # Admin accounts get no resilience process; every other user gets one.
    for user_id, user_data in users.items():
        if not user_data.get('is_admin', False):
            spawn_resilience(user_id)

    # Both background servers are daemon threads, so they do not keep the
    # process alive once app.run() returns.
    threading.Thread(target=p2p_router, daemon=True).start()

    fed_server = Federation(
        users, ports, GLOBAL_SECRET, DOMAIN, send_udp_command,
        CERT_PATH, CERT_NAME, KEY_NAME, add_payment_request
    )
    threading.Thread(target=fed_server.run, daemon=True).start()

    print("Resilience Proxy starting on port 3000")
    app.run(host='0.0.0.0', port=3000, threaded=True)