| #!/usr/bin/env python |
| # Copyright (c) PLUMgrid, Inc. |
| # Licensed under the Apache License, Version 2.0 (the "License") |
| |
# test program that attaches a socket filter to eth0, flood-pings a peer with
# 100 packets, and checks the rx/tx packet counts recorded for that address
# pair in the "stats" table
| |
| from ctypes import c_uint, c_ulong, Structure |
| from netaddr import IPAddress |
| from bcc import BPF |
| from subprocess import check_call |
| import sys |
| from unittest import main, TestCase |
| |
# Command-line handling.  The first argument is mandatory and is later passed
# to BPF() as its first positional argument; the second is optional and falls
# back to "".  Each one is popped off sys.argv as it is read (presumably so
# that unittest's main() only sees its own options -- confirm).
arg1 = sys.argv.pop(1)
arg2 = sys.argv.pop(1) if len(sys.argv) > 1 else ""
| |
# Explicit ctypes layouts for the table's key and leaf are only declared when
# the program source named by arg1 ends in ".b".  For any other source both
# names are None, and None is what get_table() receives.
if arg1.endswith(".b"):
    class Key(Structure):
        # Two unsigned ints: "dip" first, then "sip".
        _fields_ = [
            ("dip", c_uint),
            ("sip", c_uint),
        ]

    class Leaf(Structure):
        # Two unsigned longs holding the rx and tx packet counters.
        _fields_ = [
            ("rx_pkts", c_ulong),
            ("tx_pkts", c_ulong),
        ]
else:
    Key = None
    Leaf = None
| |
class TestBPFSocket(TestCase):
    """Tests for the "stats" table of a BPF socket filter attached to eth0.

    Relies on the module-level ``arg1``/``arg2`` (arguments for ``BPF``) and
    ``Key``/``Leaf`` (table layouts, possibly ``None``).
    """

    def setUp(self):
        # Build the program, load its "on_packet" function as a socket
        # filter, attach it to eth0 and keep a handle on the "stats" table.
        prog = BPF(arg1, arg2, debug=0)
        on_packet = prog.load_func("on_packet", BPF.SOCKET_FILTER)
        BPF.attach_raw_socket(on_packet, "eth0")
        self.stats = prog.get_table("stats", Key, Leaf)

    def _check_len(self, expected):
        # Assert that the table currently holds exactly `expected` entries.
        self.assertEqual(len(self.stats), expected)

    def test_ping(self):
        # Generate traffic: 100 flood pings to 172.16.1.1.
        check_call(["ping", "-f", "-c", "100", "172.16.1.1"])

        # Debugging aid: iterate self.stats.items() and print
        # IPAddress(key.sip), IPAddress(key.dip), leaf.rx_pkts, leaf.tx_pkts.

        table = self.stats
        other_end = IPAddress("172.16.1.2").value
        ping_target = IPAddress("172.16.1.1").value
        key = table.Key(other_end, ping_target)

        # Both directions must have seen exactly the 100 packets sent.
        counters = table[key]
        self.assertEqual(counters.rx_pkts, 100)
        self.assertEqual(counters.tx_pkts, 100)

        # Once deleted, the key can neither be read nor deleted again.
        del table[key]
        with self.assertRaises(KeyError):
            table[key]
        with self.assertRaises(KeyError):
            del table[key]

        # clear() empties the table, both before and after a re-insert.
        table.clear()
        self._check_len(0)
        table[key] = counters
        self._check_len(1)
        table.clear()
        self._check_len(0)

    def test_empty_key(self):
        # test with a 0 key
        table = self.stats
        table.clear()

        # Insert under the default-constructed key and pop it straight out.
        table[table.Key()] = table.Leaf(100, 200)
        table.popitem()

        # With only a non-default key present, the default key must be absent.
        table[table.Key(10, 20)] = table.Leaf(300, 400)
        with self.assertRaises(KeyError):
            table[table.Key()]

        # The single remaining entry carries the values stored above.
        _, leaf = table.popitem()
        self.assertEqual(leaf.rx_pkts, 300)
        self.assertEqual(leaf.tx_pkts, 400)

        table.clear()
        self._check_len(0)

        # The default key plus three more distinct keys give four entries.
        table[table.Key()] = leaf
        for second in (1, 2, 3):
            table[table.Key(0, second)] = leaf
        self._check_len(4)
| |
# Script entry point: when run directly (not imported), hand control to
# unittest's main() to run the tests defined in this module.
if __name__ == "__main__":
    main()