[Devel,2/2] test: check ucred for netlink messages

Submitted by Andrei Vagin on Feb. 14, 2017, 5:59 a.m.

Details

Message ID 1487051961-9422-2-git-send-email-avagin@openvz.org
State New
Series "Series without cover letter"
Headers show

Commit Message

Andrei Vagin Feb. 14, 2017, 5:59 a.m.
From: Andrei Vagin <avagin@virtuozzo.com>

Send a netlink message with ucred and check that it received
with the same ucred.

Signed-off-by: Andrei Vagin <avagin@virtuozzo.com>
---
 test/zdtm/static/sk-netlink.c | 55 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

Patch hide | download patch | download mbox

diff --git a/test/zdtm/static/sk-netlink.c b/test/zdtm/static/sk-netlink.c
index 976a140..bc61dea 100644
--- a/test/zdtm/static/sk-netlink.c
+++ b/test/zdtm/static/sk-netlink.c
@@ -1,8 +1,12 @@ 
+#define _GNU_SOURCE
 #include <unistd.h>
 #include <linux/netlink.h>
 #include <sys/socket.h>
 #include <linux/socket.h>
 #include <string.h>
+#include <sys/un.h>
+#include <signal.h>
+#include <sys/wait.h>
 
 #include "zdtmtst.h"
 
@@ -17,7 +21,7 @@  const char *test_author	= "Andrew Vagin <avagin@parallels.com>";
 
 int main(int argc, char ** argv)
 {
-	int ssk, bsk, csk, dsk;
+	int ssk, bsk, csk, dsk, on = 1;
 	struct sockaddr_nl addr;
 	struct msghdr msg;
 	struct {
@@ -25,14 +29,30 @@  int main(int argc, char ** argv)
 	} req;
 	struct iovec iov;
 	char buf[4096];
+	char cmsg[1024];
+	struct cmsghdr *ch;
+	struct ucred *ucred;
+	pid_t pid;
 
 	test_init(argc, argv);
 
+	pid = fork();
+	if (pid < 0) {
+		pr_err("fork");
+		return 1;
+	}
+
+	if (pid == 0) {
+		test_waitsig();
+		return 0;
+	}
+
 	ssk = socket(PF_NETLINK, SOCK_RAW, NETLINK_KOBJECT_UEVENT);
 	if (ssk < 0) {
 		pr_perror("Can't create sock diag socket");
 		return -1;
 	}
+	setsockopt(ssk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
 	bsk = socket(PF_NETLINK, SOCK_RAW, NETLINK_KOBJECT_UEVENT);
 	if (bsk < 0) {
 		pr_perror("Can't create sock diag socket");
@@ -108,6 +128,7 @@  int main(int argc, char ** argv)
 		return 1;
 	}
 
+
 	memset(&msg, 0, sizeof(msg));
 	msg.msg_namelen = 0;
 	msg.msg_iov     = &iov;
@@ -140,6 +161,18 @@  int main(int argc, char ** argv)
 	msg.msg_name	= &addr;
 	msg.msg_iov     = &iov;
 	msg.msg_iovlen  = 1;
+	msg.msg_control = cmsg;
+	msg.msg_controllen = sizeof(cmsg);
+
+	ch = CMSG_FIRSTHDR(&msg);
+	ch->cmsg_len = CMSG_LEN(sizeof(struct ucred));
+	ch->cmsg_level = SOL_SOCKET;
+	ch->cmsg_type = SCM_CREDENTIALS;
+	ucred = (struct ucred *) CMSG_DATA(ch);
+	ucred->pid = pid;
+	ucred->uid = 58;
+	ucred->gid = 39;
+	msg.msg_controllen = CMSG_SPACE(sizeof(struct ucred));
 
 	iov.iov_base    = (void *) &req;
 	iov.iov_len     = sizeof(req);
@@ -168,12 +201,17 @@  int main(int argc, char ** argv)
 	test_waitsig();
 #endif
 
+	kill(pid, SIGTERM);
+	wait(NULL);
+
 	memset(&msg, 0, sizeof(msg));
 	memset(&addr, 0, sizeof(addr));
 	msg.msg_namelen = sizeof(addr);
 	msg.msg_name	= &addr;
 	msg.msg_iov     = &iov;
 	msg.msg_iovlen  = 1;
+	msg.msg_control = cmsg;
+	msg.msg_controllen = sizeof(cmsg);
 
 	iov.iov_base    = buf;
 	iov.iov_len     = sizeof(buf);
@@ -183,6 +221,21 @@  int main(int argc, char ** argv)
 		return 1;
 	}
 
+	ch = CMSG_FIRSTHDR(&msg);
+	if (!ch || ch->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
+	    ch->cmsg_level != SOL_SOCKET ||
+	    ch->cmsg_type != SCM_CREDENTIALS) {
+		pr_err("Unable to get ucred\n");
+		return 1;
+	}
+
+	ucred = (struct ucred *) CMSG_DATA(ch);
+	if (ucred->pid != pid || ucred->uid != 58 || ucred->gid != 39) {
+		pr_err("pid %d uid %d gid %d\n",
+			ucred->pid, ucred->uid, ucred->gid);
+		return -1;
+	}
+
 	if (addr.nl_pid != getpid() * 10) {
 		fail("address mismatch: %x != %x size %d", addr.nl_pid, getpid(), msg.msg_namelen);
 		return 1;