[4/8] net_ns: Split set_netns() and introduce new set_netns_by_id()

Submitted by Kirill Tkhai on June 28, 2017, 11:49 a.m.

Details

Message ID 149865056894.12218.16260258482550016819.stgit@localhost.localdomain
State New
Series "One-level leaked net_ns support"
Headers show

Commit Message

Kirill Tkhai June 28, 2017, 11:49 a.m.
Rename set_netns() in set_netns_by_id() and implement new
set_netns() which works with struct ns_id.

Signed-off-by: Kirill Tkhai <ktkhai@virtuozzo.com>
---
 criu/include/sockets.h |    3 ++-
 criu/sk-inet.c         |    2 +-
 criu/sk-netlink.c      |    2 +-
 criu/sk-packet.c       |    2 +-
 criu/sk-unix.c         |    4 ++--
 criu/sockets.c         |   26 ++++++++++++++++++--------
 6 files changed, 25 insertions(+), 14 deletions(-)

Patch hide | download patch | download mbox

diff --git a/criu/include/sockets.h b/criu/include/sockets.h
index adea91115..05d1a38f5 100644
--- a/criu/include/sockets.h
+++ b/criu/include/sockets.h
@@ -84,7 +84,8 @@  static inline int sk_decode_shutdown(int val)
 #define NETLINK_SOCK_DIAG NETLINK_INET_DIAG
 #endif
 
-extern int set_netns(uint32_t ns_id);
+extern int set_netns(struct ns_id *ns);
+extern int set_netns_by_id(uint32_t ns_id);
 
 #ifndef SIOCGSKNS
 #define SIOCGSKNS      0x894C          /* get socket network namespace */
diff --git a/criu/sk-inet.c b/criu/sk-inet.c
index 5077d35b0..a56129a12 100644
--- a/criu/sk-inet.c
+++ b/criu/sk-inet.c
@@ -650,7 +650,7 @@  static int open_inet_sk(struct file_desc *d, int *new_fd)
 	if (inet_validate_address(ie))
 		return -1;
 
-	if (set_netns(ie->ns_id))
+	if (set_netns_by_id(ie->ns_id))
 		return -1;
 
 	sk = socket(ie->family, ie->type, ie->proto);
diff --git a/criu/sk-netlink.c b/criu/sk-netlink.c
index 44982a1da..a303e6b4e 100644
--- a/criu/sk-netlink.c
+++ b/criu/sk-netlink.c
@@ -179,7 +179,7 @@  static int open_netlink_sk(struct file_desc *d, int *new_fd)
 
 	pr_info("Opening netlink socket id %#x\n", nse->id);
 
-	if (set_netns(nse->ns_id))
+	if (set_netns_by_id(nse->ns_id))
 		return -1;
 
 	sk = socket(PF_NETLINK, SOCK_RAW, nse->protocol);
diff --git a/criu/sk-packet.c b/criu/sk-packet.c
index 372e9be7b..079640831 100644
--- a/criu/sk-packet.c
+++ b/criu/sk-packet.c
@@ -470,7 +470,7 @@  static int open_packet_sk(struct file_desc *d, int *new_fd)
 
 	pr_info("Opening packet socket id %#x\n", pse->id);
 
-	if (set_netns(pse->ns_id))
+	if (set_netns_by_id(pse->ns_id))
 		return -1;
 
 	if (pse->type == SOCK_PACKET)
diff --git a/criu/sk-unix.c b/criu/sk-unix.c
index 7c71448de..f0969fe83 100644
--- a/criu/sk-unix.c
+++ b/criu/sk-unix.c
@@ -1104,7 +1104,7 @@  static int open_unixsk_pair_master(struct unix_sk_info *ui, int *new_fd)
 	pr_info("Opening pair master (id %#x ino %#x peer %#x)\n",
 			ui->ue->id, ui->ue->ino, ui->ue->peer);
 
-	if (set_netns(ui->ue->ns_id))
+	if (set_netns_by_id(ui->ue->ns_id))
 		return -1;
 
 	if (socketpair(PF_UNIX, ui->ue->type, 0, sk) < 0) {
@@ -1162,7 +1162,7 @@  static int open_unixsk_standalone(struct unix_sk_info *ui, int *new_fd)
 	pr_info("Opening standalone socket (id %#x ino %#x peer %#x)\n",
 			ui->ue->id, ui->ue->ino, ui->ue->peer);
 
-	if (set_netns(ui->ue->ns_id))
+	if (set_netns_by_id(ui->ue->ns_id))
 		return -1;
 
 	/*
diff --git a/criu/sockets.c b/criu/sockets.c
index 0768888f7..d010ee3a9 100644
--- a/criu/sockets.c
+++ b/criu/sockets.c
@@ -745,19 +745,14 @@  int collect_sockets(struct ns_id *ns)
 	return err;
 }
 
-int set_netns(uint32_t ns_id)
+int set_netns(struct ns_id *ns)
 {
-	struct ns_id *ns;
 	int nsfd;
 
-	if (ns_id == current->net_ns->id)
+	BUG_ON(ns->nd != &net_ns_desc);
+	if (ns == current->net_ns)
 		return 0;
 
-	ns = lookup_ns_by_id(ns_id, &net_ns_desc);
-	if (ns == NULL) {
-		pr_err("Unable to find a network namespace\n");
-		return -1;
-	}
 	nsfd = fdstore_get(ns->net.nsfd_id);
 	if (nsfd < 0)
 		return -1;
@@ -774,6 +769,21 @@  int set_netns(uint32_t ns_id)
 	return 0;
 }
 
+int set_netns_by_id(uint32_t ns_id)
+{
+	struct ns_id *ns;
+
+	if (ns_id == current->net_ns->id)
+		return 0;
+
+	ns = lookup_ns_by_id(ns_id, &net_ns_desc);
+	if (ns == NULL) {
+		pr_err("Unable to find a network namespace\n");
+		return -1;
+	}
+	return set_netns(ns);
+}
+
 void fixup_sock_net_ns_id(uint32_t *ns_id, protobuf_c_boolean *has_ns_id)
 {
 	if (*has_ns_id)