diff --git a/net/l2tp/l2tp_eth.c b/net/l2tp/l2tp_eth.c index d2ab837a573e..6355fd449412 100644 --- a/net/l2tp/l2tp_eth.c +++ b/net/l2tp/l2tp_eth.c @@ -51,7 +51,7 @@ struct l2tp_eth { /* via l2tp_session_priv() */ struct l2tp_eth_sess { - struct net_device *dev; + struct net_device __rcu *dev; }; @@ -69,7 +69,14 @@ static int l2tp_eth_dev_init(struct net_device *dev) static void l2tp_eth_dev_uninit(struct net_device *dev) { - dev_put(dev); + struct l2tp_eth *priv = netdev_priv(dev); + struct l2tp_eth_sess *spriv; + + spriv = l2tp_session_priv(priv->session); + RCU_INIT_POINTER(spriv->dev, NULL); + /* No need for synchronize_net() here. We're called by + * unregister_netdev*(), which does the synchronisation for us. + */ } static int l2tp_eth_dev_xmit(struct sk_buff *skb, struct net_device *dev) @@ -123,8 +130,8 @@ static void l2tp_eth_dev_setup(struct net_device *dev) static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb, int data_len) { struct l2tp_eth_sess *spriv = l2tp_session_priv(session); - struct net_device *dev = spriv->dev; - struct l2tp_eth *priv = netdev_priv(dev); + struct net_device *dev; + struct l2tp_eth *priv; if (session->debug & L2TP_MSG_DATA) { unsigned int length; @@ -148,16 +155,25 @@ static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb, skb_dst_drop(skb); nf_reset(skb); + rcu_read_lock(); + dev = rcu_dereference(spriv->dev); + if (!dev) + goto error_rcu; + + priv = netdev_priv(dev); if (dev_forward_skb(dev, skb) == NET_RX_SUCCESS) { atomic_long_inc(&priv->rx_packets); atomic_long_add(data_len, &priv->rx_bytes); } else { atomic_long_inc(&priv->rx_errors); } + rcu_read_unlock(); + return; +error_rcu: + rcu_read_unlock(); error: - atomic_long_inc(&priv->rx_errors); kfree_skb(skb); } @@ -168,11 +184,15 @@ static void l2tp_eth_delete(struct l2tp_session *session) if (session) { spriv = l2tp_session_priv(session); - dev = spriv->dev; + + rtnl_lock(); + dev = rtnl_dereference(spriv->dev); if (dev) { - unregister_netdev(dev); - spriv->dev = NULL; + unregister_netdevice(dev); + rtnl_unlock(); module_put(THIS_MODULE); + } else { + rtnl_unlock(); } } } @@ -182,9 +202,20 @@ static void l2tp_eth_show(struct seq_file *m, void *arg) { struct l2tp_session *session = arg; struct l2tp_eth_sess *spriv = l2tp_session_priv(session); - struct net_device *dev = spriv->dev; + struct net_device *dev; + + rcu_read_lock(); + dev = rcu_dereference(spriv->dev); + if (!dev) { + rcu_read_unlock(); + return; + } + dev_hold(dev); + rcu_read_unlock(); seq_printf(m, " interface %s\n", dev->name); + + dev_put(dev); } #endif @@ -204,7 +235,7 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel, if (dev) { dev_put(dev); rc = -EEXIST; - goto out; + goto err; } strlcpy(name, cfg->ifname, IFNAMSIZ); } else @@ -214,20 +245,13 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel, peer_session_id, cfg); if (IS_ERR(session)) { rc = PTR_ERR(session); - goto out; - } - - l2tp_session_inc_refcount(session); - rc = l2tp_session_register(session, tunnel); - if (rc < 0) { - kfree(session); - goto out; + goto err; } dev = alloc_netdev(sizeof(*priv), name, l2tp_eth_dev_setup); if (!dev) { rc = -ENOMEM; - goto out_del_session; + goto err_sess; } dev_net_set(dev, net); @@ -248,28 +272,48 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel, #endif spriv = l2tp_session_priv(session); - spriv->dev = dev; - rc = register_netdev(dev); - if (rc < 0) - goto out_del_dev; + l2tp_session_inc_refcount(session); + + rtnl_lock(); + + /* Register both device and session while holding the rtnl lock. This + * ensures that l2tp_eth_delete() will see that there's a device to + * unregister, even if it happened to run before we assign spriv->dev. + */ + rc = l2tp_session_register(session, tunnel); + if (rc < 0) { + rtnl_unlock(); + goto err_sess_dev; + } + + rc = register_netdevice(dev); + if (rc < 0) { + rtnl_unlock(); + l2tp_session_delete(session); + l2tp_session_dec_refcount(session); + free_netdev(dev); + + return rc; + } - __module_get(THIS_MODULE); - /* Must be done after register_netdev() */ strlcpy(session->ifname, dev->name, IFNAMSIZ); + rcu_assign_pointer(spriv->dev, dev); + + rtnl_unlock(); + l2tp_session_dec_refcount(session); - dev_hold(dev); + __module_get(THIS_MODULE); return 0; -out_del_dev: - free_netdev(dev); - spriv->dev = NULL; -out_del_session: - l2tp_session_delete(session); +err_sess_dev: l2tp_session_dec_refcount(session); -out: + free_netdev(dev); +err_sess: + kfree(session); +err: return rc; }