
Description: TCP protocol implementation with connection management
Language: c
Lines: 978

/*
 * SMACKTM TCP Implementation
 * Transmission Control Protocol with full state machine
 */

#include "../kernel.h"
#include "tcp.h"
#include "ip.h"
#include "network.h"

/* TCP tuning constants. */
#define TCP_MSS             1460    /* default MSS given to new sockets by tcp_socket_init() */
#define TCP_WINDOW_SIZE     65535   /* initial send and receive window for new sockets */
#define TCP_MAX_RETRIES     3       /* NOTE(review): not referenced in this part of the file; presumably used by retransmission code */
#define TCP_TIMEOUT_MS      1000    /* how long tcp_socket_connect() waits for the handshake, in ms */
#define TCP_KEEPALIVE_MS    120000  /* NOTE(review): not referenced in this part of the file */

/* Connection table walked by tcp_timer_tick(); unused slots hold NULL (see tcp_init()). */
static struct tcp_connection* connections[MAX_TCP_CONNECTIONS];
/* Reset by tcp_init(). NOTE(review): not incremented/decremented in the code visible here. */
static int connection_count = 0;
/* Presumably the next initial sequence number to hand out; the 1000 is replaced in tcp_init(). */
static uint32_t next_sequence_number = 1000;

/*
 * Bring up the TCP layer: clear the connection table, reset the connection
 * counter, seed the sequence counter from the timer, and register
 * tcp_receive_packet() as the IP-layer handler for IP_PROTOCOL_TCP.
 */
void tcp_init(void) {
    kprintf("Initializing TCP protocol...\n");

    int slot = 0;
    while (slot < MAX_TCP_CONNECTIONS) {
        connections[slot] = NULL;
        slot++;
    }
    connection_count = 0;

    /* NOTE(review): a timer-derived initial sequence number is predictable;
     * RFC 6528 recommends a keyed-hash based ISN instead. */
    next_sequence_number = get_timer_ticks() * 1000;

    ip_register_protocol(IP_PROTOCOL_TCP, tcp_receive_packet);

    kprintf("TCP protocol initialized\n");
}

/*
 * Attach freshly zeroed TCP state to a socket.
 *
 * The new state starts in TCP_CLOSED with the default MSS and window sizes;
 * every other field is zero.
 *
 * Returns 0 on success, or -ENOMEM if the state cannot be allocated (in
 * which case sock->protocol_data is left untouched).
 */
int tcp_socket_init(struct socket* sock) {
    struct tcp_socket_data* td = kmalloc(sizeof *td);
    if (td == NULL) {
        return -ENOMEM;
    }

    memset(td, 0, sizeof *td);
    td->mss = TCP_MSS;
    td->recv_window = TCP_WINDOW_SIZE;
    td->send_window = TCP_WINDOW_SIZE;
    td->state = TCP_CLOSED;

    sock->protocol_data = td;
    return 0;
}

/*
 * Bind a TCP socket to a local IPv4 address and port.
 *
 * @sock:    socket whose protocol_data was set up by tcp_socket_init()
 * @addr:    address to bind, interpreted as a struct sockaddr_in
 * @addrlen: size of *addr; must be at least sizeof(struct sockaddr_in)
 *
 * A port of 0 asks for an ephemeral port (as with POSIX bind()). Without
 * this, tcp_socket_listen() would treat the socket as unbound.
 *
 * Returns 0 on success. Returns -EINVAL if:
 *   - addr is missing or too short, or
 *   - the socket already has a local port (an earlier bind, or connect()
 *     assigned one), or
 *   - the socket is not in TCP_CLOSED.
 */
int tcp_socket_bind(struct socket* sock, struct sockaddr* addr, socklen_t addrlen) {
    if (addr == NULL || addrlen < sizeof(struct sockaddr_in)) {
        return -EINVAL;
    }

    struct sockaddr_in* sin = (struct sockaddr_in*)addr;
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;

    /* Re-binding would silently move a listening, connecting or connected socket. */
    if (tcp_data->local_port != 0 || tcp_data->state != TCP_CLOSED) {
        return -EINVAL;
    }

    tcp_data->local_addr = sin->sin_addr.s_addr;
    tcp_data->local_port = ntohs(sin->sin_port);

    /* Port 0 means "choose one for me". */
    if (tcp_data->local_port == 0) {
        tcp_data->local_port = allocate_ephemeral_port();
    }

    sock->state = SOCKET_BOUND;
    return 0;
}

int tcp_socket_listen(struct socket* sock, int backlog) {
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;
    
    if (tcp_data->local_port == 0) {
        return -EINVAL; // Must bind first
    }
    
    tcp_data->state = TCP_LISTEN;
    tcp_data->backlog = backlog;
    sock->state = SOCKET_LISTENING;
    
    return 0;
}

/*
 * Take the first pending connection off a listening socket and wrap it in a
 * new connected socket.
 *
 * @sock:    listening socket
 * @addr:    optional; receives the peer's IPv4 address if *addrlen is large
 *           enough
 * @addrlen: optional; in: size of *addr; out: sizeof(struct sockaddr_in)
 *           when filled
 *
 * Blocks (by yielding) until a connection is pending.
 *
 * Returns the new socket, or NULL in these cases:
 *   - the socket is not listening, or
 *   - it stops listening while we wait, or
 *   - the new socket cannot be created. In that case the pending connection
 *     is put back on the queue rather than lost.
 */
struct socket* tcp_socket_accept(struct socket* sock, struct sockaddr* addr, socklen_t* addrlen) {
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;

    if (tcp_data->state != TCP_LISTEN) {
        return NULL;
    }

    // Wait for incoming connection
    while (tcp_data->pending_connections == NULL) {
        if (tcp_data->state != TCP_LISTEN) {
            return NULL; // Listener was shut down while we waited
        }
        yield(); // Give up CPU while waiting
    }

    // Get first pending connection
    struct tcp_connection* conn = tcp_data->pending_connections;
    tcp_data->pending_connections = conn->next;

    // Create new socket for connection
    struct socket* new_sock = socket_create(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (!new_sock) {
        /* Re-queue at the head so the connection is not orphaned and a later
         * accept() can still pick it up. */
        conn->next = tcp_data->pending_connections;
        tcp_data->pending_connections = conn;
        return NULL;
    }

    struct tcp_socket_data* new_tcp_data = (struct tcp_socket_data*)new_sock->protocol_data;
    new_tcp_data->connection = conn;
    new_tcp_data->state = TCP_ESTABLISHED;
    new_tcp_data->local_addr = conn->local_addr;
    new_tcp_data->local_port = conn->local_port;
    new_tcp_data->remote_addr = conn->remote_addr;
    new_tcp_data->remote_port = conn->remote_port;

    conn->socket = new_sock;
    new_sock->state = SOCKET_CONNECTED;

    // Fill in client address if requested
    if (addr && addrlen && *addrlen >= sizeof(struct sockaddr_in)) {
        struct sockaddr_in* sin = (struct sockaddr_in*)addr;
        memset(sin, 0, sizeof *sin); // Don't leave stale bytes in sin_zero
        sin->sin_family = AF_INET;
        sin->sin_addr.s_addr = conn->remote_addr;
        sin->sin_port = htons(conn->remote_port);
        *addrlen = sizeof(struct sockaddr_in);
    }

    return new_sock;
}

/*
 * Actively open a TCP connection to an IPv4 peer and wait for the handshake.
 *
 * @sock:    socket in TCP_CLOSED with no connection attached
 * @addr:    peer address, interpreted as a struct sockaddr_in
 * @addrlen: size of *addr; must be at least sizeof(struct sockaddr_in)
 *
 * An unbound socket gets an ephemeral local port. The caller yields for up
 * to TCP_TIMEOUT_MS while the state is TCP_SYN_SENT.
 *
 * Returns:
 *   - 0 once the state is TCP_ESTABLISHED.
 *   - -EINVAL for a bad address, or if the socket is already
 *     listening/connecting/connected (overwriting it would leak the existing
 *     connection).
 *   - -ENOMEM if no connection can be created.
 *   - -ETIMEDOUT otherwise.
 *
 * After a failure the connection stays attached to the socket and is released
 * by tcp_socket_close().
 */
int tcp_socket_connect(struct socket* sock, struct sockaddr* addr, socklen_t addrlen) {
    if (addr == NULL || addrlen < sizeof(struct sockaddr_in)) {
        return -EINVAL;
    }

    struct sockaddr_in* sin = (struct sockaddr_in*)addr;
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;

    if (tcp_data->state != TCP_CLOSED || tcp_data->connection != NULL) {
        return -EINVAL;
    }

    // Allocate local port if not bound
    if (tcp_data->local_port == 0) {
        tcp_data->local_port = allocate_ephemeral_port();
    }

    tcp_data->remote_addr = sin->sin_addr.s_addr;
    tcp_data->remote_port = ntohs(sin->sin_port);

    // Create connection
    struct tcp_connection* conn = tcp_create_connection(tcp_data->local_addr, tcp_data->local_port,
                                                       tcp_data->remote_addr, tcp_data->remote_port);
    if (!conn) {
        return -ENOMEM;
    }

    tcp_data->connection = conn;
    conn->socket = sock;

    /* Enter SYN_SENT before transmitting. If the reply is processed
     * synchronously (e.g. loopback), setting the state afterwards would
     * overwrite ESTABLISHED. */
    tcp_data->state = TCP_SYN_SENT;
    tcp_send_syn(conn);

    /* Compare elapsed ticks rather than absolute deadlines so the wait stays
     * correct when the tick counter wraps. */
    uint32_t start = get_timer_ticks();
    uint32_t wait_ticks = TCP_TIMEOUT_MS * timer_frequency / 1000;
    while (tcp_data->state == TCP_SYN_SENT &&
           (uint32_t)(get_timer_ticks() - start) < wait_ticks) {
        yield();
    }

    if (tcp_data->state == TCP_ESTABLISHED) {
        sock->state = SOCKET_CONNECTED;
        return 0;
    }
    return -ETIMEDOUT;
}

/*
 * Send a buffer over an established connection, split into MSS-sized
 * segments.
 *
 * @sock:   connected TCP socket
 * @buffer: data to send
 * @length: number of bytes in buffer
 * @flags:  currently ignored
 *
 * Before each segment, blocks (by yielding) while conn->send_unacked is at
 * least send_window.
 *
 * Returns:
 *   - the number of bytes handed to tcp_send_data(). This may be short if
 *     sending fails, or if the connection leaves ESTABLISHED part way.
 *   - -ENOTCONN if nothing was sent because the socket is not (or is no
 *     longer) connected.
 *   - -EIO if the very first segment fails.
 */
ssize_t tcp_socket_send(struct socket* sock, const void* buffer, size_t length, int flags) {
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;
    (void)flags;

    if (tcp_data->state != TCP_ESTABLISHED) {
        return -ENOTCONN;
    }

    struct tcp_connection* conn = tcp_data->connection;
    if (!conn) {
        return -ENOTCONN;
    }

    // Fragment data into TCP segments
    size_t bytes_sent = 0;
    const char* data = (const char*)buffer;

    while (bytes_sent < length) {
        size_t segment_size = min(length - bytes_sent, tcp_data->mss);

        // Wait for send window space
        while (conn->send_unacked >= tcp_data->send_window) {
            /* A reset or closed connection will never open the window again;
             * without this check the caller would spin forever. */
            if (tcp_data->state != TCP_ESTABLISHED) {
                return bytes_sent > 0 ? (ssize_t)bytes_sent : -ENOTCONN;
            }
            yield();
        }

        // Send segment
        if (tcp_send_data(conn, data + bytes_sent, segment_size) != 0) {
            /* Cast explicitly: mixing size_t and int in ?: would convert
             * -EIO to a huge unsigned value first. */
            return bytes_sent > 0 ? (ssize_t)bytes_sent : -EIO;
        }

        bytes_sent += segment_size;
    }

    return (ssize_t)bytes_sent;
}

/*
 * Copy up to `length` bytes from the connection's receive buffer into
 * `buffer`.
 *
 * @sock:   TCP socket with an attached connection
 * @buffer: destination
 * @length: capacity of buffer; 0 returns 0 immediately
 * @flags:  currently ignored
 *
 * Data already buffered is returned even if the connection has left
 * ESTABLISHED (e.g. after the peer's FIN), so nothing received before the
 * close is lost. If the buffer is empty, blocks (by yielding) while the
 * state is ESTABLISHED. After copying, sends an ACK so the peer sees the
 * reopened window.
 *
 * Returns:
 *   - the number of bytes copied.
 *   - 0 if the connection left ESTABLISHED while we waited.
 *   - -ENOTCONN if there is no connection, or it is not ESTABLISHED and has
 *     no buffered data.
 */
ssize_t tcp_socket_receive(struct socket* sock, void* buffer, size_t length, int flags) {
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;
    (void)flags;

    struct tcp_connection* conn = tcp_data->connection;
    if (!conn) {
        return -ENOTCONN;
    }

    if (tcp_data->state != TCP_ESTABLISHED && conn->recv_buffer_size == 0) {
        return -ENOTCONN;
    }

    if (length == 0) {
        return 0; // Nothing requested: don't block or send a pointless ACK
    }

    // Wait for data
    while (conn->recv_buffer_size == 0) {
        if (tcp_data->state != TCP_ESTABLISHED) {
            return 0; // Connection closed
        }
        yield();
    }

    // Copy data from receive buffer
    size_t copy_size = min(length, conn->recv_buffer_size);
    memcpy(buffer, conn->recv_buffer, copy_size);

    // Remove copied data from buffer (regions overlap, hence memmove)
    if (copy_size < conn->recv_buffer_size) {
        memmove(conn->recv_buffer, conn->recv_buffer + copy_size,
                conn->recv_buffer_size - copy_size);
    }
    conn->recv_buffer_size -= copy_size;

    // Send ACK to update window
    tcp_send_ack(conn);

    return (ssize_t)copy_size;
}

int tcp_socket_close(struct socket* sock) {
    struct tcp_socket_data* tcp_data = (struct tcp_socket_data*)sock->protocol_data;
    
    if (tcp_data->connection) {
        tcp_close_connection(tcp_data->connection);
    }
    
    kfree(tcp_data);
    return 0;
}

void tcp_receive_packet(struct ip_packet* ip_pkt, void* tcp_data_ptr, size_t length) {
    struct tcp_header* tcp_hdr = (struct tcp_header*)tcp_data_ptr;
    
    // Find connection
    struct tcp_connection* conn = tcp_find_connection(ip_pkt->dest_addr, ntohs(tcp_hdr->dest_port),
                                                     ip_pkt->src_addr, ntohs(tcp_hdr->src_port));
    
    if (!conn) {
        // Check for SYN to listening socket
        if (tcp_hdr->flags & TCP_SYN) {
            conn = tcp_handle_incoming_syn(ip_pkt, tcp_hdr);
        }
        
        if (!conn) {
            // Send RST
            tcp_send_rst(ip_pkt->src_addr, ntohs(tcp_hdr->src_port),
                        ip_pkt->dest_addr, ntohs(tcp_hdr->dest_port),
                        ntohl(tcp_hdr->seq_num) + 1);
            return;
        }
    }
    
    // Process packet based on connection state
    tcp_process_packet(conn, tcp_hdr, (char*)tcp_data_ptr + (tcp_hdr->header_length * 4),
                      length - (tcp_hdr->header_length * 4));
}

void tcp_timer_tick(void) {
    // Process all active connections
    for (int i = 0; i < MAX_TCP_CONNECTIONS; i++) {
        if (connections[i] != NULL) {
            tcp_process_connection_timers(connections[i]);
        }
    }
}

================================================================================


### PROC MODULE ###

