diff --git a/Adam/Net/Tcp.HC b/Adam/Net/Tcp.HC index 4027a49..fe98345 100644 --- a/Adam/Net/Tcp.HC +++ b/Adam/Net/Tcp.HC @@ -102,6 +102,58 @@ class CTcpPseudoHeader { U16 tcp_length; }; +class CTcpSocketListItem { + CTcpSocketListItem *prev; + CTcpSocketListItem *next; + CTcpSocket *sock; +}; + +static CTcpSocketListItem** tcp_socket_list; + +static CTcpSocket* GetTcpSocketFromList(CIPv4Packet* packet, CTcpHeader* hdr) { + CTcpSocketListItem* item = tcp_socket_list[ntohs(hdr->dest_port)]->next; + while (item) { + if (item->sock->remote_addr == packet->source_ip && + item->sock->remote_port == ntohs(hdr->source_port)) { + return item->sock; + } + item = item->next; + } + return NULL; +} + +U0 AddTcpSocketToList(CTcpSocket* s) { + CTcpSocketListItem* prev = tcp_socket_list[s->local_port]; + CTcpSocketListItem* new = CAlloc(sizeof(CTcpSocketListItem)); + while (prev->next) { + prev = prev->next; + } + new->prev = prev; + new->sock = s; + prev->next = new; +} + +CTcpSocket* RemoveTcpSocketFromList(CTcpSocket* s) { + CTcpSocketListItem* prev = NULL; + CTcpSocketListItem* next = NULL; + CTcpSocketListItem* item = tcp_socket_list[s->local_port]->next; + while (item) { + if (item->sock==s) { + prev = item->prev; + next = item->next; + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + return s; + } + item = item->next; + } + return NULL; +} + // TODO: this takes up half a meg, change it to a binary tree or something static CTcpSocket** tcp_bound_sockets; @@ -454,7 +506,8 @@ I64 TcpSocketClose(CTcpSocket* s) { } if (s->local_port) - tcp_bound_sockets[s->local_port] = NULL; + if (!RemoveTcpSocketFromList(s)) + tcp_bound_sockets[s->local_port] = NULL; Free(s->recv_buf); Free(s); @@ -710,8 +763,7 @@ U0 TcpSocketHandle(CTcpSocket* s, CIPv4Packet* packet, CTcpHeader* hdr, U8* data TcpSend2(new_socket, TCP_FLAG_SYN | TCP_FLAG_ACK); new_socket->state = TCP_STATE_SYN_RECEIVED; - // FIXME FIXME FIXME FIXME - tcp_bound_sockets[new_socket->local_port] = new_socket; + AddTcpSocketToList(new_socket); if (s->backlog_last) s->backlog_last->backlog_next = new_socket; @@ -896,7 +948,9 @@ I64 TcpHandler(CIPv4Packet* packet) { U16 dest_port = ntohs(hdr->dest_port); //"%u => %p\n", dest_port, tcp_bound_sockets[dest_port]; - CTcpSocket* s = tcp_bound_sockets[dest_port]; + CTcpSocket* s = GetTcpSocketFromList(packet, hdr); + if (!s) + s = tcp_bound_sockets[dest_port]; // FIXME: should also check that bound address is INADDR_ANY, // OR packet dest IP matches bound address @@ -911,8 +965,14 @@ I64 TcpHandler(CIPv4Packet* packet) { } U0 TcpInit() { + I64 i; tcp_bound_sockets = MAlloc(65536 * sizeof(CTcpSocket*)); MemSet(tcp_bound_sockets, 0, 65536 * sizeof(CTcpSocket*)); + tcp_socket_list = MAlloc(65536 * sizeof(CTcpSocketListItem*)); + for (i=0; i<65536; i++) + { + tcp_socket_list[i] = CAlloc(sizeof(CTcpSocketListItem)); + } } TcpInit; diff --git a/Demo/Network/TcpEchoServer.HC b/Demo/Network/TcpEchoServer.HC index 081cc82..2087e88 100644 --- a/Demo/Network/TcpEchoServer.HC +++ b/Demo/Network/TcpEchoServer.HC @@ -4,6 +4,23 @@ #define PORT 8000 +U0 TcpEchoSession(CTcpSocket* client) { + U8 buffer[2048 + 1]; + I64 count = recv(client, buffer, sizeof(buffer) - 1, 0); + + if (count <= 0) { + "$FG,6$recv: error %d\n$FG$", count; + } + else { + buffer[count] = 0; + "$FG,8$Received %d bytes:\n$FG$%s\n", count, buffer; + } + + send(client, buffer, count, 0); + + close(client); +} + I64 TcpEchoServer() { SocketInit(); @@ -32,23 +49,14 @@ I64 TcpEchoServer() { "$FG,2$Listening on port %d\n$FG$", PORT; - I64 client = accept(sock, 0, 0); - - U8 buffer[2048 + 1]; - I64 count = recv(client, buffer, sizeof(buffer) - 1, 0); - - if (count <= 0) { - "$FG,6$recv: error %d\n$FG$", count; + while (1) { + I64 client = accept(sock, 0, 0); + if (client) + Spawn(&TcpEchoSession, client); + else + break; + Yield; // loop unconditionally } - else { - buffer[count] = 0; - "$FG,8$Received %d bytes:\n$FG$%s\n", count, buffer; - } - - send(client, buffer, count, 0); - - close(client); - close(sock); return 0; }