1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
12 #include <net/protocol.h>
14 #include <net/udp_tunnel.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
19 static DEFINE_SPINLOCK(fou_lock
);
20 static LIST_HEAD(fou_list
);
26 struct udp_offload udp_offloads
;
27 struct list_head list
;
33 struct udp_port_cfg udp_config
;
36 static inline struct fou
*fou_from_sock(struct sock
*sk
)
38 return sk
->sk_user_data
;
41 static int fou_udp_encap_recv_deliver(struct sk_buff
*skb
,
42 u8 protocol
, size_t len
)
44 struct iphdr
*iph
= ip_hdr(skb
);
46 /* Remove 'len' bytes from the packet (UDP header and
47 * FOU header if present), modify the protocol to the one
48 * we found, and then call rcv_encap.
49 */
50 iph
->tot_len
= htons(ntohs(iph
->tot_len
) - len
);
52 skb_postpull_rcsum(skb
, udp_hdr(skb
), len
);
53 skb_reset_transport_header(skb
);
58 static int fou_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
60 struct fou
*fou
= fou_from_sock(sk
);
65 return fou_udp_encap_recv_deliver(skb
, fou
->protocol
,
66 sizeof(struct udphdr
));
69 static int gue_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
71 struct fou
*fou
= fou_from_sock(sk
);
73 struct guehdr
*guehdr
;
79 len
= sizeof(struct udphdr
) + sizeof(struct guehdr
);
80 if (!pskb_may_pull(skb
, len
))
84 guehdr
= (struct guehdr
*)&uh
[1];
86 len
+= guehdr
->hlen
<< 2;
87 if (!pskb_may_pull(skb
, len
))
91 guehdr
= (struct guehdr
*)&uh
[1];
93 if (guehdr
->version
!= 0)
101 return fou_udp_encap_recv_deliver(skb
, guehdr
->next_hdr
, len
);
107 static struct sk_buff
**fou_gro_receive(struct sk_buff
**head
,
110 const struct net_offload
*ops
;
111 struct sk_buff
**pp
= NULL
;
112 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
113 const struct net_offload
**offloads
;
116 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
117 ops
= rcu_dereference(offloads
[proto
]);
118 if (!ops
|| !ops
->callbacks
.gro_receive
)
121 pp
= ops
->callbacks
.gro_receive(head
, skb
);
129 static int fou_gro_complete(struct sk_buff
*skb
, int nhoff
)
131 const struct net_offload
*ops
;
132 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
134 const struct net_offload
**offloads
;
136 udp_tunnel_gro_complete(skb
, nhoff
);
139 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
140 ops
= rcu_dereference(offloads
[proto
]);
141 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
144 err
= ops
->callbacks
.gro_complete(skb
, nhoff
);
152 static struct sk_buff
**gue_gro_receive(struct sk_buff
**head
,
155 const struct net_offload
**offloads
;
156 const struct net_offload
*ops
;
157 struct sk_buff
**pp
= NULL
;
160 struct guehdr
*guehdr
;
161 unsigned int hlen
, guehlen
;
165 off
= skb_gro_offset(skb
);
166 hlen
= off
+ sizeof(*guehdr
);
167 guehdr
= skb_gro_header_fast(skb
, off
);
168 if (skb_gro_header_hard(skb
, hlen
)) {
169 guehdr
= skb_gro_header_slow(skb
, hlen
, off
);
170 if (unlikely(!guehdr
))
174 proto
= guehdr
->next_hdr
;
177 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
178 ops
= rcu_dereference(offloads
[proto
]);
179 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_receive
))
182 guehlen
= sizeof(*guehdr
) + (guehdr
->hlen
<< 2);
184 hlen
= off
+ guehlen
;
185 if (skb_gro_header_hard(skb
, hlen
)) {
186 guehdr
= skb_gro_header_slow(skb
, hlen
, off
);
187 if (unlikely(!guehdr
))
193 for (p
= *head
; p
; p
= p
->next
) {
194 const struct guehdr
*guehdr2
;
196 if (!NAPI_GRO_CB(p
)->same_flow
)
199 guehdr2
= (struct guehdr
*)(p
->data
+ off
);
201 /* Compare the base GUE headers for equality (covers
202 * hlen, version, next_hdr, and flags).
203 */
204 if (guehdr
->word
!= guehdr2
->word
) {
205 NAPI_GRO_CB(p
)->same_flow
= 0;
209 /* Check that the optional fields are the same. */
210 if (guehdr
->hlen
&& memcmp(&guehdr
[1], &guehdr2
[1],
211 guehdr
->hlen
<< 2)) {
212 NAPI_GRO_CB(p
)->same_flow
= 0;
217 skb_gro_pull(skb
, guehlen
);
219 /* Adjust NAPI_GRO_CB(skb)->csum after skb_gro_pull() */
220 skb_gro_postpull_rcsum(skb
, guehdr
, guehlen
);
222 pp
= ops
->callbacks
.gro_receive(head
, skb
);
227 NAPI_GRO_CB(skb
)->flush
|= flush
;
232 static int gue_gro_complete(struct sk_buff
*skb
, int nhoff
)
234 const struct net_offload
**offloads
;
235 struct guehdr
*guehdr
= (struct guehdr
*)(skb
->data
+ nhoff
);
236 const struct net_offload
*ops
;
237 unsigned int guehlen
;
241 proto
= guehdr
->next_hdr
;
243 guehlen
= sizeof(*guehdr
) + (guehdr
->hlen
<< 2);
246 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
247 ops
= rcu_dereference(offloads
[proto
]);
248 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
251 err
= ops
->callbacks
.gro_complete(skb
, nhoff
+ guehlen
);
258 static int fou_add_to_port_list(struct fou
*fou
)
262 spin_lock(&fou_lock
);
263 list_for_each_entry(fout
, &fou_list
, list
) {
264 if (fou
->port
== fout
->port
) {
265 spin_unlock(&fou_lock
);
270 list_add(&fou
->list
, &fou_list
);
271 spin_unlock(&fou_lock
);
276 static void fou_release(struct fou
*fou
)
278 struct socket
*sock
= fou
->sock
;
279 struct sock
*sk
= sock
->sk
;
281 udp_del_offload(&fou
->udp_offloads
);
283 list_del(&fou
->list
);
285 /* Remove hooks into tunnel socket */
286 sk
->sk_user_data
= NULL
;
293 static int fou_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
295 udp_sk(sk
)->encap_rcv
= fou_udp_recv
;
296 fou
->protocol
= cfg
->protocol
;
297 fou
->udp_offloads
.callbacks
.gro_receive
= fou_gro_receive
;
298 fou
->udp_offloads
.callbacks
.gro_complete
= fou_gro_complete
;
299 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
300 fou
->udp_offloads
.ipproto
= cfg
->protocol
;
305 static int gue_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
307 udp_sk(sk
)->encap_rcv
= gue_udp_recv
;
308 fou
->udp_offloads
.callbacks
.gro_receive
= gue_gro_receive
;
309 fou
->udp_offloads
.callbacks
.gro_complete
= gue_gro_complete
;
310 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
315 static int fou_create(struct net
*net
, struct fou_cfg
*cfg
,
316 struct socket
**sockp
)
318 struct fou
*fou
= NULL
;
320 struct socket
*sock
= NULL
;
323 /* Open UDP socket */
324 err
= udp_sock_create(net
, &cfg
->udp_config
, &sock
);
328 /* Allocate FOU port structure */
329 fou
= kzalloc(sizeof(*fou
), GFP_KERNEL
);
337 fou
->port
= cfg
->udp_config
.local_udp_port
;
339 /* Initialize for fou type */
341 case FOU_ENCAP_DIRECT
:
342 err
= fou_encap_init(sk
, fou
, cfg
);
347 err
= gue_encap_init(sk
, fou
, cfg
);
356 udp_sk(sk
)->encap_type
= 1;
359 sk
->sk_user_data
= fou
;
362 udp_set_convert_csum(sk
, true);
364 sk
->sk_allocation
= GFP_ATOMIC
;
366 if (cfg
->udp_config
.family
== AF_INET
) {
367 err
= udp_add_offload(&fou
->udp_offloads
);
372 err
= fou_add_to_port_list(fou
);
389 static int fou_destroy(struct net
*net
, struct fou_cfg
*cfg
)
392 u16 port
= cfg
->udp_config
.local_udp_port
;
395 spin_lock(&fou_lock
);
396 list_for_each_entry(fou
, &fou_list
, list
) {
397 if (fou
->port
== port
) {
398 udp_del_offload(&fou
->udp_offloads
);
404 spin_unlock(&fou_lock
);
409 static struct genl_family fou_nl_family
= {
410 .id
= GENL_ID_GENERATE
,
412 .name
= FOU_GENL_NAME
,
413 .version
= FOU_GENL_VERSION
,
414 .maxattr
= FOU_ATTR_MAX
,
418 static struct nla_policy fou_nl_policy
[FOU_ATTR_MAX
+ 1] = {
419 [FOU_ATTR_PORT
] = { .type
= NLA_U16
, },
420 [FOU_ATTR_AF
] = { .type
= NLA_U8
, },
421 [FOU_ATTR_IPPROTO
] = { .type
= NLA_U8
, },
422 [FOU_ATTR_TYPE
] = { .type
= NLA_U8
, },
425 static int parse_nl_config(struct genl_info
*info
,
428 memset(cfg
, 0, sizeof(*cfg
));
430 cfg
->udp_config
.family
= AF_INET
;
432 if (info
->attrs
[FOU_ATTR_AF
]) {
433 u8 family
= nla_get_u8(info
->attrs
[FOU_ATTR_AF
]);
435 if (family
!= AF_INET
&& family
!= AF_INET6
)
438 cfg
->udp_config
.family
= family
;
441 if (info
->attrs
[FOU_ATTR_PORT
]) {
442 u16 port
= nla_get_u16(info
->attrs
[FOU_ATTR_PORT
]);
444 cfg
->udp_config
.local_udp_port
= port
;
447 if (info
->attrs
[FOU_ATTR_IPPROTO
])
448 cfg
->protocol
= nla_get_u8(info
->attrs
[FOU_ATTR_IPPROTO
]);
450 if (info
->attrs
[FOU_ATTR_TYPE
])
451 cfg
->type
= nla_get_u8(info
->attrs
[FOU_ATTR_TYPE
]);
456 static int fou_nl_cmd_add_port(struct sk_buff
*skb
, struct genl_info
*info
)
461 err
= parse_nl_config(info
, &cfg
);
465 return fou_create(&init_net
, &cfg
, NULL
);
468 static int fou_nl_cmd_rm_port(struct sk_buff
*skb
, struct genl_info
*info
)
472 parse_nl_config(info
, &cfg
);
474 return fou_destroy(&init_net
, &cfg
);
477 static const struct genl_ops fou_nl_ops
[] = {
480 .doit
= fou_nl_cmd_add_port
,
481 .policy
= fou_nl_policy
,
482 .flags
= GENL_ADMIN_PERM
,
486 .doit
= fou_nl_cmd_rm_port
,
487 .policy
= fou_nl_policy
,
488 .flags
= GENL_ADMIN_PERM
,
492 static int __init
fou_init(void)
496 ret
= genl_register_family_with_ops(&fou_nl_family
,
502 static void __exit
fou_fini(void)
504 struct fou
*fou
, *next
;
506 genl_unregister_family(&fou_nl_family
);
508 /* Close all the FOU sockets */
510 spin_lock(&fou_lock
);
511 list_for_each_entry_safe(fou
, next
, &fou_list
, list
)
513 spin_unlock(&fou_lock
);
516 module_init(fou_init
);
517 module_exit(fou_fini
);
518 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
519 MODULE_LICENSE("GPL");