diff --git a/include/net/sock.h b/include/net/sock.h index 4a0204bfb12e..232103e3de66 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -229,7 +229,7 @@ struct cg_proto; * @sk_lock: synchronizer * @sk_rcvbuf: size of receive buffer in bytes * @sk_wq: sock wait queue and async head - * @sk_rx_dst: receive input route used by early tcp demux + * @sk_rx_dst: receive input route used by early demux * @sk_dst_cache: destination cache * @sk_dst_lock: destination cache lock * @sk_policy: flow policy diff --git a/include/net/udp.h b/include/net/udp.h index 7b5837f22a38..2c4e8ecda56a 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -175,6 +175,7 @@ extern int udp_lib_get_port(struct sock *sk, unsigned short snum, unsigned int hash2_nulladdr); /* net/ipv4/udp.c */ +extern void udp_v4_early_demux(struct sk_buff *skb); extern int udp_get_port(struct sock *sk, unsigned short snum, int (*saddr_cmp)(const struct sock *, const struct sock *)); diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index f1269915090a..cda3e9ef77d0 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -1593,6 +1593,7 @@ static const struct net_offload tcp_offload = { }; static const struct net_protocol udp_protocol = { + .early_demux = udp_v4_early_demux, .handler = udp_rcv, .err_handler = udp_err, .no_policy = 1, diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 8986f481c732..552fad0fa683 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -103,6 +103,7 @@ #include #include #include +#include #include #include #include @@ -570,6 +571,26 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, } EXPORT_SYMBOL_GPL(udp4_lib_lookup); +static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif, unsigned short hnum) +{ + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net) || + udp_sk(sk)->udp_port_hash != hnum || + (inet->inet_daddr && inet->inet_daddr != rmt_addr) || + (inet->inet_dport != rmt_port && inet->inet_dport) || + (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) || + ipv6_only_sock(sk) || + (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)) + return false; + if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif)) + return false; + return true; +} + static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk, __be16 loc_port, __be32 loc_addr, __be16 rmt_port, __be32 rmt_addr, @@ -580,20 +601,11 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk, unsigned short hnum = ntohs(loc_port); sk_nulls_for_each_from(s, node) { - struct inet_sock *inet = inet_sk(s); - - if (!net_eq(sock_net(s), net) || - udp_sk(s)->udp_port_hash != hnum || - (inet->inet_daddr && inet->inet_daddr != rmt_addr) || - (inet->inet_dport != rmt_port && inet->inet_dport) || - (inet->inet_rcv_saddr && - inet->inet_rcv_saddr != loc_addr) || - ipv6_only_sock(s) || - (s->sk_bound_dev_if && s->sk_bound_dev_if != dif)) - continue; - if (!ip_mc_sf_allow(s, loc_addr, rmt_addr, dif)) - continue; - goto found; + if (__udp_is_mcast_sock(net, s, + loc_port, loc_addr, + rmt_port, rmt_addr, + dif, hnum)) + goto found; } s = NULL; found: @@ -1583,6 +1595,14 @@ static void flush_stack(struct sock **stack, unsigned int count, kfree_skb(skb1); } +static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + + dst_hold(dst); + sk->sk_rx_dst = dst; +} + /* * Multicasts and broadcasts go to each listener. * @@ -1711,11 +1731,28 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, if (udp4_csum_init(skb, uh, proto)) goto csum_error; - if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST)) - return __udp4_lib_mcast_deliver(net, skb, uh, - saddr, daddr, udptable); + if (skb->sk) { + int ret; + sk = skb->sk; - sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); + if (unlikely(sk->sk_rx_dst == NULL)) + udp_sk_rx_dst_set(sk, skb); + + ret = udp_queue_rcv_skb(sk, skb); + + /* a return value > 0 means to resubmit the input, but + * it wants the return to be -protocol, or 0 + */ + if (ret > 0) + return -ret; + return 0; + } else { + if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST)) + return __udp4_lib_mcast_deliver(net, skb, uh, + saddr, daddr, udptable); + + sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); + } if (sk != NULL) { int ret = udp_queue_rcv_skb(sk, skb); @@ -1771,6 +1808,135 @@ drop: return 0; } +/* We can only early demux multicast if there is a single matching socket. + * If more than one socket found returns NULL + */ +static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif) +{ + struct sock *sk, *result; + struct hlist_nulls_node *node; + unsigned short hnum = ntohs(loc_port); + unsigned int count, slot = udp_hashfn(net, hnum, udp_table.mask); + struct udp_hslot *hslot = &udp_table.hash[slot]; + + rcu_read_lock(); +begin: + count = 0; + result = NULL; + sk_nulls_for_each_rcu(sk, node, &hslot->head) { + if (__udp_is_mcast_sock(net, sk, + loc_port, loc_addr, + rmt_port, rmt_addr, + dif, hnum)) { + result = sk; + ++count; + } + } + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (get_nulls_value(node) != slot) + goto begin; + + if (result) { + if (count != 1 || + unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2))) + result = NULL; + else if (unlikely(!__udp_is_mcast_sock(net, sk, + loc_port, loc_addr, + rmt_port, rmt_addr, + dif, hnum))) { + sock_put(result); + result = NULL; + } + } + rcu_read_unlock(); + return result; +} + +/* For unicast we should only early demux connected sockets or we can + * break forwarding setups. The chains here can be long so only check + * if the first socket is an exact match and if not move on. + */ +static struct sock *__udp4_lib_demux_lookup(struct net *net, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif) +{ + struct sock *sk, *result; + struct hlist_nulls_node *node; + unsigned short hnum = ntohs(loc_port); + unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum); + unsigned int slot2 = hash2 & udp_table.mask; + struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; + INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr) + const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); + + rcu_read_lock(); + result = NULL; + udp_portaddr_for_each_entry_rcu(sk, node, &hslot2->head) { + if (INET_MATCH(sk, net, acookie, + rmt_addr, loc_addr, ports, dif)) + result = sk; + /* Only check first socket in chain */ + break; + } + + if (result) { + if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2))) + result = NULL; + else if (unlikely(!INET_MATCH(sk, net, acookie, + rmt_addr, loc_addr, + ports, dif))) { + sock_put(result); + result = NULL; + } + } + rcu_read_unlock(); + return result; +} + +void udp_v4_early_demux(struct sk_buff *skb) +{ + const struct iphdr *iph = ip_hdr(skb); + const struct udphdr *uh = udp_hdr(skb); + struct sock *sk; + struct dst_entry *dst; + struct net *net = dev_net(skb->dev); + int dif = skb->dev->ifindex; + + /* validate the packet */ + if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) + return; + + if (skb->pkt_type == PACKET_BROADCAST || + skb->pkt_type == PACKET_MULTICAST) + sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, + uh->source, iph->saddr, dif); + else if (skb->pkt_type == PACKET_HOST) + sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr, + uh->source, iph->saddr, dif); + else + return; + + if (!sk) + return; + + skb->sk = sk; + skb->destructor = sock_edemux; + dst = sk->sk_rx_dst; + + if (dst) + dst = dst_check(dst, 0); + if (dst) + skb_dst_set_noref(skb, dst); +} + int udp_rcv(struct sk_buff *skb) { return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP);