Implement full netmasking for IPv4 and IPv6.
authorMats Erik Andersson <debian@gisladisker.se>
Thu, 14 Oct 2010 10:49:54 +0000 (12:49 +0200)
committerEmil Mikulic <emikulic@gmail.com>
Mon, 9 May 2011 12:48:50 +0000 (22:48 +1000)
A helper function "strtonum()" is implemented for platforms
where it is missing. This is detected in "configure.ac".

Netmasks with prefix length is implemented for IPv4, whereas
netmasks with explicit masks or with prefixes as built for IPv6.

acct.c
configure.ac
conv.c
conv.h

diff --git a/acct.c b/acct.c
index 8f947a8..a8e20d0 100644 (file)
--- a/acct.c
+++ b/acct.c
 #include <sys/socket.h>
 #include <stdlib.h> /* for free */
 #include <string.h> /* for memcpy */
+#include <ctype.h>  /* isdigit() */
 
 uint64_t total_packets = 0, total_bytes = 0;
 
 static int using_localnet = 0;
+static int using_localnet6 = 0;
 static in_addr_t localnet, localmask;
+static struct in6_addr localnet6, localmask6;
 
 /* Parse the net/mask specification into two IPs or die trying. */
 void
 acct_init_localnet(const char *spec)
 {
-   char **tokens;
-   int num_tokens;
+   char **tokens, *p;
+   int num_tokens, isnum, j;
+   int build_ipv6;  /* Zero for IPv4, one for IPv6.  */
+   int pfxlen, octets, remainder;
    struct in_addr addr;
+   struct in6_addr addr6;
 
    tokens = split('/', spec, &num_tokens);
    if (num_tokens != 2)
       errx(1, "expecting network/netmask, got \"%s\"", spec);
 
-   if (inet_aton(tokens[0], &addr) != 1)
-      errx(1, "invalid network address \"%s\"", tokens[0]);
-   localnet = addr.s_addr;
+   /* Presence of a colon distinguishes address families.  */
+   if (strchr(tokens[0], ':')) {
+      build_ipv6 = 1;
+      if (inet_pton(AF_INET6, tokens[0], &addr6) != 1)
+         errx(1, "invalid IPv6 network address \"%s\"", tokens[0]);
+      memcpy(&localnet6, &addr6, sizeof(localnet6));
+   } else {
+      build_ipv6 = 0;
+      if (inet_pton(AF_INET, tokens[0], &addr) != 1)
+         errx(1, "invalid network address \"%s\"", tokens[0]);
+      localnet = addr.s_addr;
+   }
+
+   /* Detect a purely numeric argument.  */
+   isnum = 0;
+   p = tokens[1];
+   while (*p != '\0') {
+      if (isdigit(*p)) {
+         isnum = 1;
+         ++p;
+         continue;
+      } else {
+         isnum = 0;
+         break;
+      }
+   }
+
+   if (!isnum) {
+      if (build_ipv6) {
+         if (inet_pton(AF_INET6, tokens[1], &addr6) != 1)
+            errx(1, "invalid IPv6 network mask \"%s\"", tokens[1]);
+         memcpy(&localmask6, &addr6, sizeof(localmask6));
+      } else {
+         if (inet_pton(AF_INET, tokens[1], &addr) != 1)
+            errx(1, "invalid network mask \"%s\"", tokens[1]);
+         localmask = addr.s_addr;
+      }
+   } else {
+      uint8_t frac, *p;
+
+      /* Compute the prefix length.  */
+      pfxlen = strtonum(tokens[1], 1, build_ipv6 ? 128 : 32, NULL);
+      if (pfxlen == 0)
+         errx(1, "invalid network prefix length \"%s\"", tokens[1]);
+
+      /* Construct the network mask.  */
+      octets = pfxlen / 8;
+      remainder = pfxlen % 8;
+      p = build_ipv6 ? (uint8_t *) localmask6.s6_addr : (uint8_t *) &localmask;
+
+      if (build_ipv6)
+         memset(&localmask6, 0, sizeof(localmask6));
+      else
+         memset(&localmask, 0, sizeof(localmask));
+
+      for (j = 0; j < octets; ++j)
+         p[j] = 0xff;
+
+      frac = 0xff << (8 - remainder);
+      if (frac)
+         p[j] = frac;   /* Have contribution for next position.  */
+   }
 
-   if (inet_aton(tokens[1], &addr) != 1)
-      errx(1, "invalid network mask \"%s\"", tokens[1]);
-   localmask = addr.s_addr;
-   /* FIXME: improve so we can accept masks like /24 for 255.255.255.0 */
+   /* Register the correct netmask and calculate the correct net.  */
+   if (build_ipv6) {
+      using_localnet6 = 1;
+      for (j = 0; j < 16; ++j)
+         localnet6.s6_addr[j] &= localmask6.s6_addr[j];
+   } else {
+      using_localnet = 1;
+      localnet &= localmask;
+   }
 
-   using_localnet = 1;
    free(tokens[0]);
    free(tokens[1]);
    free(tokens);
 
-   verbosef("local network address: %s", ip_to_str_af(&localnet, AF_INET));
-   verbosef("   local network mask: %s", ip_to_str_af(&localmask, AF_INET));
+   if (build_ipv6) {
+      verbosef("local network address: %s", ip_to_str_af(&localnet6, AF_INET6));
+      verbosef("   local network mask: %s", ip_to_str_af(&localmask6, AF_INET6));
+   } else {
+      verbosef("local network address: %s", ip_to_str_af(&localnet, AF_INET));
+      verbosef("   local network mask: %s", ip_to_str_af(&localmask, AF_INET));
+   }
 
-   if ((localnet & localmask) != localnet)
-      errx(1, "this is an invalid combination of address and mask!\n"
-      "it cannot match any address!");
 }
 
 /* Account for the given packet summary. */
@@ -78,7 +149,8 @@ acct_for(const pktsummary *sm)
    struct bucket *hs = NULL, *hd = NULL;
    struct bucket *ps, *pd;
    struct addr46 ipaddr;
-   int dir_in, dir_out;
+   struct in6_addr scribble;
+   int dir_in, dir_out, j;
 
 #if 0 /* WANT_CHATTY? */
    printf("%15s > ", ip_to_str_af(&sm->src_ip, AF_INET));
@@ -122,11 +194,23 @@ acct_for(const pktsummary *sm)
             dir_in = 1;
       }
    } else if (sm->af == AF_INET6) {
-      /* Only exact address has been implemented. */
-      if (memcmp(&sm->src_ip6, &localip6, sizeof(localip6)) == 0)
-         dir_out = 1;
-      if (memcmp(&sm->dest_ip6, &localip6, sizeof(localip6)) == 0)
-         dir_in = 1;
+      if (using_localnet6) {
+         for (j = 0; j < 16; ++j)
+            scribble.s6_addr[j] = sm->src_ip6.s6_addr[j] & localmask6.s6_addr[j];
+         if (memcmp(&scribble, &localnet6, sizeof(scribble)) == 0)
+            dir_out = 1;
+         else {
+            for (j = 0; j < 16; ++j)
+               scribble.s6_addr[j] = sm->dest_ip6.s6_addr[j] & localmask6.s6_addr[j];
+            if (memcmp(&scribble, &localnet6, sizeof(scribble)) == 0)
+               dir_in = 1;
+         }
+      } else {
+         if (memcmp(&sm->src_ip6, &localip6, sizeof(localip6)) == 0)
+            dir_out = 1;
+         if (memcmp(&sm->dest_ip6, &localip6, sizeof(localip6)) == 0)
+            dir_in = 1;
+      }
    }
 
    if (dir_out) {
index 0a3759c..a457ef9 100644 (file)
@@ -173,6 +173,10 @@ AC_CHECK_LIB(c, strlcat,
  AC_DEFINE(HAVE_STRLCAT, 1,
   [Define to 1 if you have strlcat().]))
 
+AC_CHECK_LIB(c, strtonum,
+ AC_DEFINE(HAVE_STRTONUM, 1,
+  [Define to 1 if you have strtonum(3).]))
+
 # Some OSes (Solaris) need sys/sockio.h for SIOCGIFADDR
 AC_CHECK_HEADERS(sys/sockio.h)
 
diff --git a/conv.c b/conv.c
index a3c2476..52c249a 100644 (file)
--- a/conv.c
+++ b/conv.c
@@ -31,6 +31,7 @@
 #include <string.h>
 #include <time.h>
 #include <unistd.h>
+#include <limits.h>
 
 #define PATH_DEVNULL "/dev/null"
 
@@ -455,4 +456,55 @@ strlcat(dst, src, siz)
 }
 #endif
 
+#ifndef HAVE_STRTONUM
+/*
+ * Convert an ASCII string to a decimal numerical value. An acceptable
+ * range is specified, and an optional error message string.
+ *
+ * Implementation built from the manual page description of OpenBSD 4.6.
+ */
+long long
+strtonum(const char *nptr, long long minval, long long maxval,
+         const char **errstr)
+{
+   long long val;
+   char *p;
+
+   if ((nptr == NULL) || (*nptr == '\0') || (minval > maxval)) {
+      if (errstr)
+         *errstr = "invalid";
+      errno = EINVAL;
+      return 0;
+   }
+
+   errno = 0;
+   val = strtoll(nptr, &p, 10);
+
+   if (*p != '\0') {
+      if (errstr)
+         *errstr = "invalid";
+      errno = EINVAL;
+      return 0;
+   }
+
+   if ((val == LLONG_MIN) || (val < minval)) {
+      if (errstr)
+         *errstr = "too small";
+      errno = ERANGE;
+      return 0;
+   }
+   if ((val == LLONG_MAX) || (val > maxval)) {
+      if (errstr)
+         *errstr = "too large";
+      errno = ERANGE;
+      return 0;
+   }
+
+   /* Correct conversion.  */
+   if (errstr)
+      *errstr = NULL;
+   return val;
+}
+#endif /* !HAVE_STRTONUM */
+
 /* vim:set ts=3 sw=3 tw=78 expandtab: */
diff --git a/conv.h b/conv.h
index 867975e..4377c0d 100644 (file)
--- a/conv.h
+++ b/conv.h
@@ -30,4 +30,9 @@ size_t strlcpy(char *dst, const char *src, size_t siz);
 size_t strlcat(char *dst, const char *src, size_t siz);
 #endif
 
+#ifndef HAVE_STRTONUM
+long long strtonum(const char *nptr, long long min,
+                  long long max, const char **estr);
+#endif
+
 /* vim:set ts=3 sw=3 tw=78 expandtab: */