pex: move rx header check to callback function
[project/unetd.git] / pex.c
diff --git a/pex.c b/pex.c
index 1ec140dc61175c19df36e94801c48abf84bb5ec7..1f831a0e1da3bc27e71e406045f5d41588b4b762 100644 (file)
--- a/pex.c
+++ b/pex.c
@@ -5,6 +5,10 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/udp.h>
 #include <fcntl.h>
 #include <stdlib.h>
 #include <inttypes.h>
@@ -57,7 +61,7 @@ pex_get_peer_addr(struct sockaddr_in6 *sin6, struct network *net,
        *sin6 = (struct sockaddr_in6){
                .sin6_family = AF_INET6,
                .sin6_addr = peer->local_addr.in6,
-               .sin6_port = htons(net->net_config.pex_port),
+               .sin6_port = htons(peer->pex_port),
        };
 }
 
@@ -65,11 +69,12 @@ static void pex_msg_send(struct network *net, struct network_peer *peer)
 {
        struct sockaddr_in6 sin6 = {};
 
-       if (!peer || peer == &net->net_config.local_host->peer)
+       if (!peer || peer == &net->net_config.local_host->peer ||
+           !peer->pex_port)
                return;
 
        pex_get_peer_addr(&sin6, net, peer);
-       if (__pex_msg_send(net->pex.fd.fd, &sin6) < 0)
+       if (__pex_msg_send(net->pex.fd.fd, &sin6, NULL, 0) < 0)
                D_PEER(net, peer, "pex_msg_send failed: %s", strerror(errno));
 }
 
@@ -81,7 +86,7 @@ static void pex_msg_send_ext(struct network *net, struct network_peer *peer,
        if (!addr)
                return pex_msg_send(net, peer);
 
-       if (__pex_msg_send(-1, addr) < 0)
+       if (__pex_msg_send(-1, addr, NULL, 0) < 0)
                D_NET(net, "pex_msg_send_ext(%s) failed: %s",
                      inet_ntop(addr->sin6_family, (const void *)&addr->sin6_addr, addrbuf,
                                sizeof(addrbuf)),
@@ -163,8 +168,22 @@ network_pex_handle_endpoint_change(struct network *net, struct network_peer *pee
 static void
 network_pex_host_request_update(struct network *net, struct network_pex_host *host)
 {
+       union {
+               struct {
+                       struct ip ip;
+                       struct udphdr udp;
+               } ipv4;
+               struct {
+                       struct ip6_hdr ip;
+                       struct udphdr udp;
+               } ipv6;
+       } packet = {};
+       struct udphdr *udp;
        char addrstr[INET6_ADDRSTRLEN];
+       union network_endpoint dest_ep;
+       union network_addr local_addr = {};
        uint64_t version = 0;
+       int len;
 
        if (net->net_data_len)
                version = net->net_data_version;
@@ -180,7 +199,57 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                                         net->config.auth_key, &host->endpoint,
                                         version, true))
                return;
-       __pex_msg_send(-1, &host->endpoint);
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+
+       if (!net->net_config.local_host)
+               return;
+
+       pex_msg_init_ext(net, PEX_MSG_ENDPOINT_NOTIFY, true);
+
+       memcpy(&dest_ep, &host->endpoint, sizeof(dest_ep));
+
+       /* work around issue with local address lookup for local broadcast */
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               uint8_t *data = (uint8_t *)&dest_ep.in.sin_addr;
+
+               if (data[3] == 0xff)
+                       data[3] = 0xfe;
+       }
+       network_get_local_addr(&local_addr, &dest_ep);
+
+       memset(&dest_ep, 0, sizeof(dest_ep));
+       dest_ep.sa.sa_family = host->endpoint.sa.sa_family;
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               packet.ipv4.ip = (struct ip){
+                       .ip_hl = 5,
+                       .ip_v = 4,
+                       .ip_ttl = 64,
+                       .ip_p = IPPROTO_UDP,
+                       .ip_src = local_addr.in,
+                       .ip_dst = host->endpoint.in.sin_addr,
+               };
+               dest_ep.in.sin_addr = host->endpoint.in.sin_addr;
+               udp = &packet.ipv4.udp;
+               len = sizeof(packet.ipv4);
+       } else {
+               packet.ipv6.ip = (struct ip6_hdr){
+                       .ip6_flow = htonl(6 << 28),
+                       .ip6_hops = 128,
+                       .ip6_nxt = IPPROTO_UDP,
+                       .ip6_src = local_addr.in6,
+                       .ip6_dst = host->endpoint.in6.sin6_addr,
+               };
+               dest_ep.in6.sin6_addr = host->endpoint.in6.sin6_addr;
+               udp = &packet.ipv6.udp;
+               len = sizeof(packet.ipv6);
+       }
+
+       udp->uh_sport = htons(net->net_config.local_host->peer.port);
+       udp->uh_dport = host->endpoint.in6.sin6_port;
+
+       if (__pex_msg_send(-1, &dest_ep, &packet, len) < 0)
+               D_NET(net, "pex_msg_send_raw failed: %s", strerror(errno));
 }
 
 static void
@@ -192,10 +261,17 @@ network_pex_request_update_cb(struct uloop_timeout *t)
 
        uloop_timeout_set(t, 5000);
 
+retry:
        if (list_empty(&pex->hosts))
                return;
 
        host = list_first_entry(&pex->hosts, struct network_pex_host, list);
+       if (host->timeout && host->timeout < unet_gettime()) {
+               list_del(&host->list);
+               free(host);
+               goto retry;
+       }
+
        list_move_tail(&host->list, &pex->hosts);
        network_pex_host_request_update(net, host);
 }
@@ -267,8 +343,12 @@ network_pex_query_hosts(struct network *net)
 static void
 network_pex_send_ping(struct network *net, struct network_peer *peer)
 {
+       if (peer->state.pinged || !peer->state.endpoint.sa.sa_family)
+               return;
+
        pex_msg_init(net, PEX_MSG_PING);
        pex_msg_send(net, peer);
+       peer->state.pinged = true;
 }
 
 static void
@@ -300,11 +380,6 @@ void network_pex_event(struct network *net, struct network_peer *peer,
        if (!network_pex_active(&net->pex))
                return;
 
-       if (peer)
-               D_PEER(net, peer, "PEX event type=%d", ev);
-       else
-               D_NET(net, "PEX event type=%d", ev);
-
        switch (ev) {
        case PEX_EV_HANDSHAKE:
                pex_send_hello(net, peer);
@@ -502,7 +577,7 @@ network_pex_recv_update_response(struct network *net, const uint8_t *data, size_
 
        uloop_timeout_set(&net->reload_timer, no_prev_data ? 1 : UNETD_DATA_UPDATE_DELAY);
        vlist_for_each_element(&net->peers, peer, node) {
-               if (!peer->state.connected)
+               if (!peer->state.connected || !peer->pex_port)
                        continue;
                network_pex_send_update_request(net, peer, NULL);
        }
@@ -542,6 +617,8 @@ network_pex_recv(struct network *net, struct network_peer *peer, struct pex_hdr
                network_pex_recv_update_response(net, data, hdr->len,
                                              NULL, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               break;
        }
 }
 
@@ -575,11 +652,8 @@ network_pex_fd_cb(struct uloop_fd *fd, unsigned int events)
                if (!len)
                        continue;
 
-               if (len < sizeof(*hdr))
-                       continue;
-
-               hdr->len = ntohs(hdr->len);
-               if (len - sizeof(hdr) < hdr->len)
+               hdr = pex_rx_accept(buf, len, false);
+               if (!hdr)
                        continue;
 
                peer = pex_msg_peer(net, hdr->id);
@@ -596,15 +670,29 @@ network_pex_fd_cb(struct uloop_fd *fd, unsigned int events)
        }
 }
 
-static void
-network_pex_create_host(struct network *net, union network_endpoint *ep)
+void network_pex_create_host(struct network *net, union network_endpoint *ep,
+                            unsigned int timeout)
 {
        struct network_pex *pex = &net->pex;
        struct network_pex_host *host;
+       bool new_host = false;
+
+       list_for_each_entry(host, &pex->hosts, list) {
+               if (memcmp(&host->endpoint, ep, sizeof(host->endpoint)) != 0)
+                       continue;
+
+               list_move_tail(&host->list, &pex->hosts);
+               goto out;
+       }
 
        host = calloc(1, sizeof(*host));
+       new_host = true;
        memcpy(&host->endpoint, ep, sizeof(host->endpoint));
        list_add_tail(&host->list, &pex->hosts);
+
+out:
+       if (timeout && (new_host || host->timeout))
+               host->timeout = timeout + unet_gettime();
        network_pex_host_request_update(net, host);
 }
 
@@ -624,7 +712,7 @@ network_pex_open_auth_connect(struct network *net)
        vlist_for_each_element(&net->peers, peer, node) {
                union network_endpoint ep = {};
 
-               if (!peer->endpoint)
+               if (!peer->endpoint || peer->dynamic)
                        continue;
 
                if (network_get_endpoint(&ep, peer->endpoint,
@@ -632,7 +720,7 @@ network_pex_open_auth_connect(struct network *net)
                        continue;
 
                ep.in.sin_port = htons(UNETD_GLOBAL_PEX_PORT);
-               network_pex_create_host(net, &ep);
+               network_pex_create_host(net, &ep, 0);
        }
 
        if (!net->config.auth_connect)
@@ -645,7 +733,7 @@ network_pex_open_auth_connect(struct network *net)
                                         UNETD_GLOBAL_PEX_PORT, 0) < 0)
                        continue;
 
-               network_pex_create_host(net, &ep);
+               network_pex_create_host(net, &ep, 0);
        }
 }
 
@@ -661,7 +749,7 @@ int network_pex_open(struct network *net)
 
        network_pex_open_auth_connect(net);
 
-       if (!local_host || !net->net_config.pex_port)
+       if (!local_host || !local_host->peer.pex_port)
                return 0;
 
        local = &local_host->peer;
@@ -675,7 +763,7 @@ int network_pex_open(struct network *net)
        sin6.sin6_family = AF_INET6;
        memcpy(&sin6.sin6_addr, &local->local_addr.in6,
               sizeof(local->local_addr.in6));
-       sin6.sin6_port = htons(net->net_config.pex_port);
+       sin6.sin6_port = htons(local_host->peer.pex_port);
 
        if (bind(fd, (struct sockaddr *)&sin6, sizeof(sin6)) < 0) {
                perror("bind");
@@ -704,9 +792,16 @@ void network_pex_close(struct network *net)
 {
        struct network_pex *pex = &net->pex;
        struct network_pex_host *host, *tmp;
+       uint64_t now = unet_gettime();
 
        uloop_timeout_cancel(&pex->request_update_timer);
        list_for_each_entry_safe(host, tmp, &pex->hosts, list) {
+               if (host->timeout)
+                       continue;
+
+               if (host->last_active + UNETD_PEX_HOST_ACITVE_TIMEOUT >= now)
+                       continue;
+
                list_del(&host->list);
                free(host);
        }
@@ -719,6 +814,17 @@ void network_pex_close(struct network *net)
        network_pex_init(net);
 }
 
+void network_pex_free(struct network *net)
+{
+       struct network_pex *pex = &net->pex;
+       struct network_pex_host *host, *tmp;
+
+       list_for_each_entry_safe(host, tmp, &pex->hosts, list) {
+               list_del(&host->list);
+               free(host);
+       }
+}
+
 static struct network *
 global_pex_find_network(const uint8_t *id)
 {
@@ -733,12 +839,36 @@ global_pex_find_network(const uint8_t *id)
 }
 
 static void
-global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
+global_pex_set_active(struct network *net, struct sockaddr_in6 *addr)
 {
-       struct pex_ext_hdr *ehdr = (void *)(hdr + 1);
+       struct network_pex *pex = &net->pex;
+       struct network_pex_host *host;
+
+       list_for_each_entry(host, &pex->hosts, list) {
+               if (memcmp(&host->endpoint.in6, addr, sizeof(*addr)) != 0)
+                       continue;
+
+               host->last_active = unet_gettime();
+       }
+}
+
+static void
+global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
+{
+       struct pex_hdr *hdr;
+       struct pex_ext_hdr *ehdr;
        struct network_peer *peer;
        struct network *net;
-       void *data = (void *)(ehdr + 1);
+       char buf[INET6_ADDRSTRLEN];
+       void *data;
+       int addr_len;
+
+       hdr = pex_rx_accept(msg, msg_len, true);
+       if (!hdr)
+               return;
+
+       ehdr = (void *)(hdr + 1);
+       data = (void *)(ehdr + 1);
 
        if (hdr->version != 0)
                return;
@@ -749,6 +879,8 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
 
        *(uint64_t *)hdr->id ^= pex_network_hash(net->config.auth_key, ehdr->nonce);
 
+       global_pex_set_active(net, addr);
+
        D("PEX global rx op=%d", hdr->opcode);
        switch (hdr->opcode) {
        case PEX_MSG_HELLO:
@@ -767,15 +899,49 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
        case PEX_MSG_UPDATE_RESPONSE_NO_DATA:
                network_pex_recv_update_response(net, data, hdr->len, addr, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               peer = pex_msg_peer(net, hdr->id);
+               if (!peer)
+                       break;
+
+               D_PEER(net, peer, "receive endpoint notification from %s",
+                 inet_ntop(addr->sin6_family, network_endpoint_addr((void *)addr, &addr_len),
+                           buf, sizeof(buf)));
+
+               memcpy(&peer->state.next_endpoint, addr, sizeof(*addr));
+               break;
        }
 }
 
-int global_pex_open(void)
+static void
+pex_recv_control(struct pex_msg_local_control *msg, int len)
+{
+       struct network *net;
+
+       if (msg->msg_type != 0)
+               return;
+
+       net = global_pex_find_network(msg->auth_id);
+       if (!net)
+               return;
+
+       if (!msg->timeout)
+               msg->timeout = 60;
+       network_pex_create_host(net, &msg->ep, msg->timeout);
+}
+
+int global_pex_open(const char *unix_path)
 {
        struct sockaddr_in6 sin6 = {};
+       int ret;
 
        sin6.sin6_family = AF_INET6;
        sin6.sin6_port = htons(global_pex_port);
 
-       return pex_open(&sin6, sizeof(sin6), global_pex_recv, true);
+       ret = pex_open(&sin6, sizeof(sin6), global_pex_recv, true);
+
+       if (unix_path)
+               pex_unix_open(unix_path, pex_recv_control);
+
+       return ret;
 }