Make host_get() not invalidate old records.
authorEmil Mikulic <emikulic@gmail.com>
Sun, 15 May 2011 07:54:11 +0000 (17:54 +1000)
committerEmil Mikulic <emikulic@gmail.com>
Sat, 28 May 2011 10:10:53 +0000 (20:10 +1000)
acct.c
hosts_db.c

diff --git a/acct.c b/acct.c
index 49ec8e2..4b04793 100644 (file)
--- a/acct.c
+++ b/acct.c
@@ -203,45 +203,42 @@ acct_for(const struct pktsummary * const sm)
    if (hosts_max == 0) return; /* skip per-host accounting */
 
    /* Hosts. */
+   hosts_db_reduce();
    hs = host_get(&(sm->src));
    hs->out   += sm->len;
    hs->total += sm->len;
    memcpy(hs->u.host.mac_addr, sm->src_mac, sizeof(sm->src_mac));
    hs->u.host.last_seen = now;
 
-   hd = host_get(&(sm->dst)); /* this can invalidate hs! */
+   hd = host_get(&(sm->dst));
    hd->in    += sm->len;
    hd->total += sm->len;
    memcpy(hd->u.host.mac_addr, sm->dst_mac, sizeof(sm->dst_mac));
    hd->u.host.last_seen = now;
 
    /* Protocols. */
-   hs = host_find(&(sm->src));
-   if (hs != NULL) {
+   if (sm->proto != IPPROTO_INVALID) {
       ps = host_get_ip_proto(hs, sm->proto);
       ps->out   += sm->len;
       ps->total += sm->len;
-   }
 
-   pd = host_get_ip_proto(hd, sm->proto);
-   pd->in    += sm->len;
-   pd->total += sm->len;
+      pd = host_get_ip_proto(hd, sm->proto);
+      pd->in    += sm->len;
+      pd->total += sm->len;
+   }
 
    if (ports_max == 0) return; /* skip ports accounting */
 
    /* Ports. */
-   switch (sm->proto)
-   {
+   switch (sm->proto) {
    case IPPROTO_TCP:
-      if ((sm->src_port <= highest_port) && (hs != NULL))
-      {
+      if (sm->src_port <= highest_port) {
          ps = host_get_port_tcp(hs, sm->src_port);
          ps->out   += sm->len;
          ps->total += sm->len;
       }
 
-      if (sm->dst_port <= highest_port)
-      {
+      if (sm->dst_port <= highest_port) {
          pd = host_get_port_tcp(hd, sm->dst_port);
          pd->in    += sm->len;
          pd->total += sm->len;
@@ -251,15 +248,13 @@ acct_for(const struct pktsummary * const sm)
       break;
 
    case IPPROTO_UDP:
-      if ((sm->src_port <= highest_port) && (hs != NULL))
-      {
+      if (sm->src_port <= highest_port) {
          ps = host_get_port_udp(hs, sm->src_port);
          ps->out   += sm->len;
          ps->total += sm->len;
       }
 
-      if (sm->dst_port <= highest_port)
-      {
+      if (sm->dst_port <= highest_port) {
          pd = host_get_port_udp(hd, sm->dst_port);
          pd->in    += sm->len;
          pd->total += sm->len;
index 902774f..ce465fd 100644 (file)
@@ -597,16 +597,18 @@ hashtable_search(struct hashtable *h, const void *key)
    return (NULL);
 }
 
+typedef enum { NO_REDUCE = 0, ALLOW_REDUCE = 1 } reduce_bool;
 /* Search for a key.  If it's not there, make and insert a bucket for it. */
 static struct bucket *
-hashtable_find_or_insert(struct hashtable *h, const void *key)
+hashtable_find_or_insert(struct hashtable *h, const void *key,
+      const reduce_bool allow_reduce)
 {
    struct bucket *b = hashtable_search(h, key);
 
    if (b == NULL) {
       /* Not found, so insert after checking occupancy. */
-      /*assert(h->count <= h->count_max);*/
-      if (h->count >= h->count_max) hashtable_reduce(h);
+      if (allow_reduce && (h->count >= h->count_max))
+         hashtable_reduce(h);
       b = h->make_func(key);
       hashtable_insert(h, b);
    }
@@ -644,7 +646,7 @@ hashtable_free(struct hashtable *h)
 struct bucket *
 host_get(const struct addr *const a)
 {
-   return (hashtable_find_or_insert(hosts_db, a));
+   return (hashtable_find_or_insert(hosts_db, a, NO_REDUCE));
 }
 
 /* ---------------------------------------------------------------------------
@@ -744,6 +746,13 @@ hashtable_reduce(struct hashtable *ht)
    hashtable_rehash(ht, ht->bits); /* is this needed? */
 }
 
+/* Reduce hosts_db if needed. */
+void hosts_db_reduce(void)
+{
+   if (hosts_db->count >= hosts_db->count_max)
+      hashtable_reduce(hosts_db);
+}
+
 /* ---------------------------------------------------------------------------
  * Reset hosts_db to empty.
  */
@@ -801,7 +810,7 @@ host_get_port_tcp(struct bucket *host, const uint16_t port)
          hash_func_short, free_func_simple, key_func_port_tcp,
          find_func_port_tcp, make_func_port_tcp,
          format_cols_port_tcp, format_row_port_tcp);
-   return (hashtable_find_or_insert(h->ports_tcp, &port));
+   return (hashtable_find_or_insert(h->ports_tcp, &port, ALLOW_REDUCE));
 }
 
 /* ---------------------------------------------------------------------------
@@ -817,7 +826,7 @@ host_get_port_udp(struct bucket *host, const uint16_t port)
          hash_func_short, free_func_simple, key_func_port_udp,
          find_func_port_udp, make_func_port_udp,
          format_cols_port_udp, format_row_port_udp);
-   return (hashtable_find_or_insert(h->ports_udp, &port));
+   return (hashtable_find_or_insert(h->ports_udp, &port, ALLOW_REDUCE));
 }
 
 /* ---------------------------------------------------------------------------
@@ -834,7 +843,7 @@ host_get_ip_proto(struct bucket *host, const uint8_t proto)
          hash_func_byte, free_func_simple, key_func_ip_proto,
          find_func_ip_proto, make_func_ip_proto,
          format_cols_ip_proto, format_row_ip_proto);
-   return (hashtable_find_or_insert(h->ip_protos, &proto));
+   return (hashtable_find_or_insert(h->ip_protos, &proto, ALLOW_REDUCE));
 }
 
 static struct str *html_hosts_main(const char *qs);