use IP(V6)_PKTINFO in KDC for UDP sockets
authorKen Raeburn <raeburn@mit.edu>
Sat, 7 Apr 2007 05:15:31 +0000 (05:15 +0000)
committerKen Raeburn <raeburn@mit.edu>
Sat, 7 Apr 2007 05:15:31 +0000 (05:15 +0000)
As Denis Vlasenko pointed out in ticket 3306, using IP(V6)_PKTINFO to
get or set the local address in UDP communications instead of
allocating one socket for each address seen at startup will behave
better in environments where the addresses may change while the KDC is
running, or in certain unusual network configurations.

The patch from Denis was specific to Linux (didn't do IPV6_PKTINFO if
IP_PKTINFO wasn't defined).  I've reworked it a fair amount, and
tested the results briefly on Mac OS X (which has IPV6_PKTINFO but not
IP_PKTINFO) and Linux (which has both).

With this change, on systems like Linux supporting both socket
options, the KDC should be able to use just two UDP sockets, one for
IPv4 and one for IPv6.  (And if we turned off IPV6_V6ONLY, we might do
with one.)

Filed as a separate ticket, because Denis's complaint and patch in
3306 cover the RPC code as well.

ticket: new

git-svn-id: svn://anonsvn.mit.edu/krb5/trunk@19408 dc483132-0cff-0310-8789-dd5450dbe970

src/kdc/network.c

index 3cb4400dfc3b45c86419b26a4840c949e3fb8d56..0b4514896db1dcce622e1472946eff1c078a9ccf 100644 (file)
@@ -115,6 +115,33 @@ setv6only(int sock, int value)
 }
 #endif
 
+static int
+set_pktinfo(int sock, int family)
+{
+    int sockopt = 1;
+    int option = 0, proto = 0;
+
+    switch (family) {
+#if defined(IP_PKTINFO) && defined(HAVE_STRUCT_IN_PKTINFO)
+    case AF_INET:
+       proto = IPPROTO_IP;
+       option = IP_PKTINFO;
+       break;
+#endif
+#if defined(IPV6_PKTINFO) && defined(HAVE_STRUCT_IN6_PKTINFO)
+    case AF_INET6:
+       proto = IPPROTO_IPV6;
+       option = IPV6_PKTINFO;
+       break;
+#endif
+    default:
+       return EINVAL;
+    }
+    if (setsockopt(sock, proto, option, &sockopt, sizeof(sockopt)))
+       return errno;
+    return 0;
+}
+
 \f
 static const char *paddr (struct sockaddr *sa)
 {
@@ -138,7 +165,7 @@ static const char *paddr (struct sockaddr *sa)
 
 /* KDC data.  */
 
-enum kdc_conn_type { CONN_UDP, CONN_TCP_LISTENER, CONN_TCP };
+enum kdc_conn_type { CONN_UDP, CONN_UDP_PKTINFO, CONN_TCP_LISTENER, CONN_TCP };
 
 /* Per-connection info.  */
 struct connection {
@@ -147,12 +174,17 @@ struct connection {
     void (*service)(struct connection *, const char *, int);
     union {
        /* Type-specific information.  */
+#if 0
        struct {
            int x;
        } udp;
+       struct {
+           int x;
+       } udp_pktinfo;
        struct {
            int x;
        } tcp_listener;
+#endif
        struct {
            /* connection */
            struct sockaddr_storage addr_s;
@@ -206,8 +238,8 @@ struct connection {
 #define DEL(set, idx) \
   (set.data[idx] = set.data[--set.n], 0)
 
-#define FREE_SET_DATA(set) if(set.data) free(set.data);                 \
-   (set.data = 0, set.max = 0)
+#define FREE_SET_DATA(set) \
+  (free(set.data), set.data = 0, set.max = 0)
 
 
 /* Set<struct connection *> connections; */
@@ -268,6 +300,9 @@ static krb5_error_code add_tcp_port(int port)
 struct socksetup {
     const char *prog;
     krb5_error_code retval;
+    int udp_flags;
+#define UDP_DO_IPV4 1
+#define UDP_DO_IPV6 2
 };
 
 static struct connection *
@@ -303,9 +338,10 @@ static void accept_tcp_connection(struct connection *, const char *, int);
 static void process_tcp_connection(struct connection *, const char *, int);
 
 static struct connection *
-add_udp_fd (struct socksetup *data, int sock)
+add_udp_fd (struct socksetup *data, int sock, int pktinfo)
 {
-    return add_fd(data, sock, CONN_UDP, process_packet);
+    return add_fd(data, sock, pktinfo ? CONN_UDP_PKTINFO : CONN_UDP,
+                 process_packet);
 }
 
 static struct connection *
@@ -484,15 +520,139 @@ setup_tcp_listener_ports(struct socksetup *data)
     return 0;
 }
 
+#if defined(CMSG_SPACE) && defined(HAVE_STRUCT_CMSGHDR) && (defined(IP_PKTINFO) || defined(IPV6_PKTINFO))
+union pktinfo {
+#ifdef HAVE_STRUCT_IN6_PKTINFO
+    struct in6_pktinfo pi6;
+#endif
+#ifdef HAVE_STRUCT_IN_PKTINFO
+    struct in_pktinfo pi4;
+#endif
+    char c;
+};
+
+static int
+setup_udp_port_1(struct socksetup *data, struct sockaddr *addr,
+                char *haddrbuf, int pktinfo);
+
+static void
+setup_udp_pktinfo_ports(struct socksetup *data)
+{
+#ifdef IP_PKTINFO
+    {
+       struct sockaddr_in sa;
+       int r;
+
+       memset(&sa, 0, sizeof(sa));
+       sa.sin_family = AF_INET;
+#ifdef HAVE_SA_LEN
+       sa.sin_len = sizeof(sa);
+#endif
+       r = setup_udp_port_1(data, (struct sockaddr *) &sa, "0.0.0.0", 4);
+       if (r == 0)
+           data->udp_flags &= ~UDP_DO_IPV4;
+    }
+#endif
+#ifdef IPV6_PKTINFO
+    {
+       struct sockaddr_in6 sa;
+       int r;
+
+       memset(&sa, 0, sizeof(sa));
+       sa.sin6_family = AF_INET6;
+#ifdef HAVE_SA_LEN
+       sa.sin6_len = sizeof(sa);
+#endif
+       r = setup_udp_port_1(data, (struct sockaddr *) &sa, "::", 6);
+       if (r == 0)
+           data->udp_flags &= ~UDP_DO_IPV6;
+    }
+#endif
+}
+#else /* no pktinfo compile-time support */
+static void
+setup_udp_pktinfo_ports(struct socksetup *data)
+{
+}
+#endif
+
+static int
+setup_udp_port_1(struct socksetup *data, struct sockaddr *addr,
+                char *haddrbuf, int pktinfo)
+{
+    int sock = -1, i, r;
+    u_short port;
+
+    FOREACH_ELT (udp_port_data, i, port) {
+       sock = socket (addr->sa_family, SOCK_DGRAM, 0);
+       if (sock == -1) {
+           data->retval = errno;
+           com_err(data->prog, data->retval,
+                   "Cannot create server socket for port %d address %s",
+                   port, haddrbuf);
+           return 1;
+       }
+#ifdef KRB5_USE_INET6
+       if (addr->sa_family == AF_INET6) {
+#ifdef IPV6_V6ONLY
+           if (setv6only(sock, 1))
+               com_err(data->prog, errno, "setsockopt(IPV6_V6ONLY,1) failed");
+           else
+               com_err(data->prog, 0, "setsockopt(IPV6_V6ONLY,1) worked");
+#else
+           krb5_klog_syslog(LOG_INFO, "no IPV6_V6ONLY socket option support");
+#endif /* IPV6_V6ONLY */
+       }
+#endif
+       set_sa_port(addr, htons(port));
+       if (bind (sock, (struct sockaddr *)addr, socklen (addr)) == -1) {
+           data->retval = errno;
+           com_err(data->prog, data->retval,
+                   "Cannot bind server socket to port %d address %s",
+                   port, haddrbuf);
+           close(sock);
+           return 1;
+       }
+#if !(defined(CMSG_SPACE) && defined(HAVE_STRUCT_CMSGHDR) && (defined(IP_PKTINFO) || defined(IPV6_PKTINFO)))
+       assert(pktinfo == 0);
+#endif
+       if (pktinfo) {
+           r = set_pktinfo(sock, addr->sa_family);
+           if (r) {
+               com_err(data->prog, r,
+                       "Cannot request packet info for udp socket address %s port %d",
+                       haddrbuf, port);
+               close(sock);
+               return 1;
+           }
+       }
+       FD_SET (sock, &sstate.rfds);
+       if (sock >= sstate.max)
+           sstate.max = sock + 1;
+       krb5_klog_syslog (LOG_INFO, "listening on fd %d: udp %s%s", sock,
+                         paddr((struct sockaddr *)addr),
+                         pktinfo ? " (pktinfo)" : "");
+       if (add_udp_fd (data, sock, pktinfo) == 0) {
+           close(sock);
+           return 1;
+       }
+    }
+    return 0;
+}
+
 static int
 setup_udp_port(void *P_data, struct sockaddr *addr)
 {
     struct socksetup *data = P_data;
-    int sock = -1, i;
     char haddrbuf[NI_MAXHOST];
     int err;
-    u_short port;
 
+    if (addr->sa_family == AF_INET && !(data->udp_flags & UDP_DO_IPV4))
+       return 0;
+#ifdef AF_INET6
+    if (addr->sa_family == AF_INET6 && !(data->udp_flags & UDP_DO_IPV6))
+       return 0;
+#endif
     err = getnameinfo(addr, socklen(addr), haddrbuf, sizeof(haddrbuf),
                      0, 0, NI_NUMERICHOST);
     if (err)
@@ -530,33 +690,7 @@ setup_udp_port(void *P_data, struct sockaddr *addr)
                          addr->sa_family);
        return 0;
     }
-
-    FOREACH_ELT (udp_port_data, i, port) {
-       sock = socket (addr->sa_family, SOCK_DGRAM, 0);
-       if (sock == -1) {
-           data->retval = errno;
-           com_err(data->prog, data->retval,
-                   "Cannot create server socket for port %d address %s",
-                   port, haddrbuf);
-           return 1;
-       }
-       set_sa_port(addr, htons(port));
-       if (bind (sock, (struct sockaddr *)addr, socklen (addr)) == -1) {
-           data->retval = errno;
-           com_err(data->prog, data->retval,
-                   "Cannot bind server socket to port %d address %s",
-                   port, haddrbuf);
-           return 1;
-       }
-       FD_SET (sock, &sstate.rfds);
-       if (sock >= sstate.max)
-           sstate.max = sock + 1;
-       krb5_klog_syslog (LOG_INFO, "listening on fd %d: udp %s", sock,
-                         paddr((struct sockaddr *)addr));
-       if (add_udp_fd (data, sock) == 0)
-           return 1;
-    }
-    return 0;
+    return setup_udp_port_1(data, addr, haddrbuf, 0);
 }
 
 #if 1
@@ -662,8 +796,12 @@ setup_network(const char *prog)
        so we might need only one UDP socket; fall back to binding
        sockets on each address only if IPV6_PKTINFO isn't
        supported.  */
-    if (foreach_localaddr (&setup_data, setup_udp_port, 0, 0)) {
-       return setup_data.retval;
+    setup_data.udp_flags = UDP_DO_IPV4 | UDP_DO_IPV6;
+    setup_udp_pktinfo_ports(&setup_data);
+    if (setup_data.udp_flags) {
+       if (foreach_localaddr (&setup_data, setup_udp_port, 0, 0)) {
+           return setup_data.retval;
+       }
     }
     setup_tcp_listener_ports(&setup_data);
     krb5_klog_syslog (LOG_INFO, "set up %d sockets", n_sockets);
@@ -707,14 +845,164 @@ static void init_addr(krb5_fulladdr *faddr, struct sockaddr *sa)
     }
 }
 
+static int
+recv_from_to(int s, void *buf, size_t len, int flags,
+            struct sockaddr *from, socklen_t *fromlen,
+            struct sockaddr *to, socklen_t *tolen)
+{
+#if !defined(IP_PKTINFO) && !defined(IPV6_PKTINFO)
+    if (to && tolen)
+       *tolen = 0;
+    return recvfrom(s, buf, len, flags, from, fromlen);
+#else
+    int r;
+    struct iovec iov;
+    char cmsg[CMSG_SPACE(sizeof(union pktinfo))];
+    struct cmsghdr *cmsgptr;
+    struct msghdr msg;
+
+    if (!to || !tolen)
+       return recvfrom(s, buf, len, flags, from, fromlen);
+
+    iov.iov_base = buf;
+    iov.iov_len = len;
+    memset(&msg, 0, sizeof(msg));
+    msg.msg_name = from;
+    msg.msg_namelen = *fromlen;
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cmsg;
+    msg.msg_controllen = sizeof(cmsg);
+
+    r = recvmsg(s, &msg, flags);
+    if (r < 0)
+       return r;
+    *fromlen = msg.msg_namelen;
+
+    /* On Darwin (and presumably all *BSD with KAME stacks),
+       CMSG_FIRSTHDR doesn't check for a non-zero controllen.  RFC
+       3542 recommends making this check, even though the (new) spec
+       for CMSG_FIRSTHDR says it's supposed to do the check.  */
+    if (msg.msg_controllen) {
+       cmsgptr = CMSG_FIRSTHDR(&msg);
+       while (cmsgptr) {
+#ifdef IP_PKTINFO
+           if (cmsgptr->cmsg_level == IPPROTO_IP
+               && cmsgptr->cmsg_type == IP_PKTINFO
+               && *tolen >= sizeof(struct sockaddr_in)) {
+               struct in_pktinfo *pktinfo;
+               memset(to, 0, sizeof(struct sockaddr_in));
+               pktinfo = (struct in_pktinfo *)CMSG_DATA(cmsgptr);
+               ((struct sockaddr_in *)to)->sin_addr = pktinfo->ipi_addr;
+               ((struct sockaddr_in *)to)->sin_family = AF_INET;
+               *tolen = sizeof(struct sockaddr_in);
+               return r;
+           }
+#endif
+#if defined(KRB5_USE_INET6) && defined(IPV6_PKTINFO)
+           if (cmsgptr->cmsg_level == IPPROTO_IPV6
+               && cmsgptr->cmsg_type == IPV6_PKTINFO
+               && *tolen >= sizeof(struct sockaddr_in6)) {
+               struct in6_pktinfo *pktinfo;
+               memset(to, 0, sizeof(struct sockaddr_in6));
+               pktinfo = (struct in6_pktinfo *)CMSG_DATA(cmsgptr);
+               ((struct sockaddr_in6 *)to)->sin6_addr = pktinfo->ipi6_addr;
+               ((struct sockaddr_in6 *)to)->sin6_family = AF_INET6;
+               *tolen = sizeof(struct sockaddr_in6);
+               return r;
+           }
+#endif
+           cmsgptr = CMSG_NXTHDR(&msg, cmsgptr);
+       }
+    }
+    /* No info about destination addr was available.  */
+    *tolen = 0;
+    return r;
+#endif
+}
+
+static int
+send_to_from(int s, void *buf, size_t len, int flags,
+            const struct sockaddr *to, socklen_t tolen,
+            const struct sockaddr *from, socklen_t fromlen)
+{
+#if !defined(IP_PKTINFO) && !defined(IPV6_PKTINFO)
+    return sendto(s, buf, len, flags, to, tolen);
+#else
+    struct iovec iov;
+    struct msghdr msg;
+    struct cmsghdr *cmsgptr;
+    char cbuf[CMSG_SPACE(sizeof(union pktinfo))];
+
+    if (from == 0 || fromlen == 0 || from->sa_family != to->sa_family) {
+    use_sendto:
+       return sendto(s, buf, len, flags, to, tolen);
+    }
+
+    iov.iov_base = buf;
+    iov.iov_len = len;
+    /* Truncation?  */
+    if (iov.iov_len != len)
+       return EINVAL;
+    memset(cbuf, 0, sizeof(cbuf));
+    memset(&msg, 0, sizeof(msg));
+    msg.msg_name = (void *) to;
+    msg.msg_namelen = tolen;
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cbuf;
+    /* CMSG_FIRSTHDR needs a non-zero controllen, or it'll return NULL
+       on Linux.  */
+    msg.msg_controllen = sizeof(cbuf);
+    cmsgptr = CMSG_FIRSTHDR(&msg);
+    msg.msg_controllen = 0;
+
+    switch (from->sa_family) {
+#if defined(IP_PKTINFO)
+    case AF_INET:
+       if (fromlen != sizeof(struct sockaddr_in))
+           goto use_sendto;
+       cmsgptr->cmsg_level = IPPROTO_IP;
+       cmsgptr->cmsg_type = IP_PKTINFO;
+       cmsgptr->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
+       {
+           struct in_pktinfo *p = (struct in_pktinfo *)CMSG_DATA(cmsgptr);
+           const struct sockaddr_in *from4 = (const struct sockaddr_in *)from;
+           p->ipi_spec_dst = from4->sin_addr;
+       }
+       msg.msg_controllen = CMSG_SPACE(sizeof(struct in_pktinfo));
+       break;
+#endif
+#if defined(KRB5_USE_INET6) && defined(IPV6_PKTINFO)
+    case AF_INET6:
+       if (fromlen != sizeof(struct sockaddr_in6))
+           goto use_sendto;
+       cmsgptr->cmsg_level = IPPROTO_IPV6;
+       cmsgptr->cmsg_type = IPV6_PKTINFO;
+       cmsgptr->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
+       {
+           struct in6_pktinfo *p = (struct in6_pktinfo *)CMSG_DATA(cmsgptr);
+           const struct sockaddr_in6 *from6 = (const struct sockaddr_in6 *)from;
+           p->ipi6_addr = from6->sin6_addr;
+       }
+       msg.msg_controllen = CMSG_SPACE(sizeof(struct in6_pktinfo));
+       break;
+#endif
+    default:
+       goto use_sendto;
+    }
+    return sendmsg(s, &msg, flags);
+#endif
+}
+
 static void process_packet(struct connection *conn, const char *prog,
                           int selflags)
 {
     int cc;
-    socklen_t saddr_len;
+    socklen_t saddr_len, daddr_len;
     krb5_fulladdr faddr;
     krb5_error_code retval;
-    struct sockaddr_storage saddr;
+    struct sockaddr_storage saddr, daddr;
     krb5_address addr;
     krb5_data request;
     krb5_data *response;
@@ -723,8 +1011,10 @@ static void process_packet(struct connection *conn, const char *prog,
 
     response = NULL;
     saddr_len = sizeof(saddr);
-    cc = recvfrom(port_fd, pktbuf, sizeof(pktbuf), 0,
-                 (struct sockaddr *)&saddr, &saddr_len);
+    daddr_len = sizeof(daddr);
+    cc = recv_from_to(port_fd, pktbuf, sizeof(pktbuf), 0,
+                     (struct sockaddr *)&saddr, &saddr_len,
+                     (struct sockaddr *)&daddr, &daddr_len);
     if (cc == -1) {
        if (errno != EINTR
            /* This is how Linux indicates that a previous
@@ -738,6 +1028,16 @@ static void process_packet(struct connection *conn, const char *prog,
     if (!cc)
        return;         /* zero-length packet? */
 
+#if 0
+    if (daddr_len > 0) {
+       char addrbuf[100];
+       if (getnameinfo(ss2sa(&daddr), daddr_len, addrbuf, sizeof(addrbuf),
+                       0, 0, NI_NUMERICHOST))
+           strcpy(addrbuf, "?");
+       com_err(prog, 0, "pktinfo says local addr is %s", addrbuf);
+    }
+#endif
+
     request.length = cc;
     request.data = pktbuf;
     faddr.address = &addr;
@@ -747,8 +1047,9 @@ static void process_packet(struct connection *conn, const char *prog,
        com_err(prog, retval, "while dispatching (udp)");
        return;
     }
-    cc = sendto(port_fd, response->data, (socklen_t) response->length, 0,
-               (struct sockaddr *)&saddr, saddr_len);
+    cc = send_to_from(port_fd, response->data, (socklen_t) response->length, 0,
+                     (struct sockaddr *)&saddr, saddr_len,
+                     (struct sockaddr *)&daddr, daddr_len);
     if (cc == -1) {
        char addrbuf[46];
         krb5_free_data(kdc_context, response);