blob: 9a33b038cb31985923abe1934ee78a075f0a749d [file] [log] [blame]
/* Copyright (c) 2016,2017 Facebook
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of version 2 of the GNU General Public
 * License as published by the Free Software Foundation.
 */
7#include <stddef.h>
8#include <string.h>
9#include <linux/bpf.h>
10#include <linux/if_ether.h>
11#include <linux/if_packet.h>
12#include <linux/ip.h>
13#include <linux/ipv6.h>
14#include <linux/in.h>
15#include <linux/udp.h>
16#include <linux/tcp.h>
17#include <linux/pkt_cls.h>
18#include <sys/socket.h>
19#include "bpf_helpers.h"
20#include "test_iptunnel_common.h"
21
/*
 * 16-bit host <-> network byte-order conversion.
 *
 * Network byte order is big-endian, so the bytes only have to be swapped
 * when the compilation target is little-endian.  On a big-endian target
 * the conversion is the identity; swapping there would corrupt every
 * EtherType and length field this program reads or writes.
 */
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
	__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define htons(x) ((unsigned short)(x))
#define ntohs(x) ((unsigned short)(x))
#else
#define htons(x) __builtin_bswap16(x)
#define ntohs(x) __builtin_bswap16(x)
#endif
24int _version SEC("version") = 1;
25
26struct bpf_map_def SEC("maps") rxcnt = {
27 .type = BPF_MAP_TYPE_PERCPU_ARRAY,
28 .key_size = sizeof(__u32),
29 .value_size = sizeof(__u64),
30 .max_entries = 256,
31};
32
33struct bpf_map_def SEC("maps") vip2tnl = {
34 .type = BPF_MAP_TYPE_HASH,
35 .key_size = sizeof(struct vip),
36 .value_size = sizeof(struct iptnl_info),
37 .max_entries = MAX_IPTNL_ENTRIES,
38};
39
40static __always_inline void count_tx(__u32 protocol)
41{
42 __u64 *rxcnt_count;
43
44 rxcnt_count = bpf_map_lookup_elem(&rxcnt, &protocol);
45 if (rxcnt_count)
46 *rxcnt_count += 1;
47}
48
49static __always_inline int get_dport(void *trans_data, void *data_end,
50 __u8 protocol)
51{
52 struct tcphdr *th;
53 struct udphdr *uh;
54
55 switch (protocol) {
56 case IPPROTO_TCP:
57 th = (struct tcphdr *)trans_data;
58 if (th + 1 > data_end)
59 return -1;
60 return th->dest;
61 case IPPROTO_UDP:
62 uh = (struct udphdr *)trans_data;
63 if (uh + 1 > data_end)
64 return -1;
65 return uh->dest;
66 default:
67 return 0;
68 }
69}
70
71static __always_inline void set_ethhdr(struct ethhdr *new_eth,
72 const struct ethhdr *old_eth,
73 const struct iptnl_info *tnl,
74 __be16 h_proto)
75{
76 memcpy(new_eth->h_source, old_eth->h_dest, sizeof(new_eth->h_source));
77 memcpy(new_eth->h_dest, tnl->dmac, sizeof(new_eth->h_dest));
78 new_eth->h_proto = h_proto;
79}
80
81static __always_inline int handle_ipv4(struct xdp_md *xdp)
82{
83 void *data_end = (void *)(long)xdp->data_end;
84 void *data = (void *)(long)xdp->data;
85 struct iptnl_info *tnl;
86 struct ethhdr *new_eth;
87 struct ethhdr *old_eth;
88 struct iphdr *iph = data + sizeof(struct ethhdr);
89 __u16 *next_iph;
90 __u16 payload_len;
91 struct vip vip = {};
92 int dport;
93 __u32 csum = 0;
94 int i;
95
96 if (iph + 1 > data_end)
97 return XDP_DROP;
98
99 dport = get_dport(iph + 1, data_end, iph->protocol);
100 if (dport == -1)
101 return XDP_DROP;
102
103 vip.protocol = iph->protocol;
104 vip.family = AF_INET;
105 vip.daddr.v4 = iph->daddr;
106 vip.dport = dport;
107 payload_len = ntohs(iph->tot_len);
108
109 tnl = bpf_map_lookup_elem(&vip2tnl, &vip);
110 /* It only does v4-in-v4 */
111 if (!tnl || tnl->family != AF_INET)
112 return XDP_PASS;
113
114 if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct iphdr)))
115 return XDP_DROP;
116
117 data = (void *)(long)xdp->data;
118 data_end = (void *)(long)xdp->data_end;
119
120 new_eth = data;
121 iph = data + sizeof(*new_eth);
122 old_eth = data + sizeof(*iph);
123
124 if (new_eth + 1 > data_end ||
125 old_eth + 1 > data_end ||
126 iph + 1 > data_end)
127 return XDP_DROP;
128
129 set_ethhdr(new_eth, old_eth, tnl, htons(ETH_P_IP));
130
131 iph->version = 4;
132 iph->ihl = sizeof(*iph) >> 2;
133 iph->frag_off = 0;
134 iph->protocol = IPPROTO_IPIP;
135 iph->check = 0;
136 iph->tos = 0;
137 iph->tot_len = htons(payload_len + sizeof(*iph));
138 iph->daddr = tnl->daddr.v4;
139 iph->saddr = tnl->saddr.v4;
140 iph->ttl = 8;
141
142 next_iph = (__u16 *)iph;
143#pragma clang loop unroll(full)
144 for (i = 0; i < sizeof(*iph) >> 1; i++)
145 csum += *next_iph++;
146
147 iph->check = ~((csum & 0xffff) + (csum >> 16));
148
149 count_tx(vip.protocol);
150
151 return XDP_TX;
152}
153
154static __always_inline int handle_ipv6(struct xdp_md *xdp)
155{
156 void *data_end = (void *)(long)xdp->data_end;
157 void *data = (void *)(long)xdp->data;
158 struct iptnl_info *tnl;
159 struct ethhdr *new_eth;
160 struct ethhdr *old_eth;
161 struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
162 __u16 payload_len;
163 struct vip vip = {};
164 int dport;
165
166 if (ip6h + 1 > data_end)
167 return XDP_DROP;
168
169 dport = get_dport(ip6h + 1, data_end, ip6h->nexthdr);
170 if (dport == -1)
171 return XDP_DROP;
172
173 vip.protocol = ip6h->nexthdr;
174 vip.family = AF_INET6;
175 memcpy(vip.daddr.v6, ip6h->daddr.s6_addr32, sizeof(vip.daddr));
176 vip.dport = dport;
177 payload_len = ip6h->payload_len;
178
179 tnl = bpf_map_lookup_elem(&vip2tnl, &vip);
180 /* It only does v6-in-v6 */
181 if (!tnl || tnl->family != AF_INET6)
182 return XDP_PASS;
183
184 if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct ipv6hdr)))
185 return XDP_DROP;
186
187 data = (void *)(long)xdp->data;
188 data_end = (void *)(long)xdp->data_end;
189
190 new_eth = data;
191 ip6h = data + sizeof(*new_eth);
192 old_eth = data + sizeof(*ip6h);
193
194 if (new_eth + 1 > data_end || old_eth + 1 > data_end ||
195 ip6h + 1 > data_end)
196 return XDP_DROP;
197
198 set_ethhdr(new_eth, old_eth, tnl, htons(ETH_P_IPV6));
199
200 ip6h->version = 6;
201 ip6h->priority = 0;
202 memset(ip6h->flow_lbl, 0, sizeof(ip6h->flow_lbl));
203 ip6h->payload_len = htons(ntohs(payload_len) + sizeof(*ip6h));
204 ip6h->nexthdr = IPPROTO_IPV6;
205 ip6h->hop_limit = 8;
206 memcpy(ip6h->saddr.s6_addr32, tnl->saddr.v6, sizeof(tnl->saddr.v6));
207 memcpy(ip6h->daddr.s6_addr32, tnl->daddr.v6, sizeof(tnl->daddr.v6));
208
209 count_tx(vip.protocol);
210
211 return XDP_TX;
212}
213
214SEC("xdp_tx_iptunnel")
215int _xdp_tx_iptunnel(struct xdp_md *xdp)
216{
217 void *data_end = (void *)(long)xdp->data_end;
218 void *data = (void *)(long)xdp->data;
219 struct ethhdr *eth = data;
220 __u16 h_proto;
221
222 if (eth + 1 > data_end)
223 return XDP_DROP;
224
225 h_proto = eth->h_proto;
226
227 if (h_proto == htons(ETH_P_IP))
228 return handle_ipv4(xdp);
229 else if (h_proto == htons(ETH_P_IPV6))
230
231 return handle_ipv6(xdp);
232 else
233 return XDP_DROP;
234}
235
236char _license[] SEC("license") = "GPL";