enter: Save and restore the terminal settings.
[micoforia.git] / util.c
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 #include "m7a.h"
3
4 #include <sys/ipc.h>
5 #include <sys/sem.h>
6 #include <fcntl.h>
7 #include <ctype.h>
8 #include <sys/mount.h>
9 #include <dirent.h>
10 #include <net/if.h>
11 #include <linux/sockios.h>
12 #include <libmnl/libmnl.h>
13 #include <linux/if_link.h>
14 #include <linux/rtnetlink.h>
15 #include <sys/un.h>
16
17 void die(const char *fmt, ...)
18 {
19         char *str;
20         va_list argp;
21         int ret;
22
23         va_start(argp, fmt);
24         ret = vasprintf(&str, fmt, argp);
25         va_end(argp);
26         if (ret < 0) { /* give up */
27                 EMERG_LOG("OOM\n");
28                 exit(EXIT_FAILURE);
29         }
30         m7a_log(LL_EMERG, "%s\n", str);
31         exit(EXIT_FAILURE);
32 }
33
34 void die_errno(const char *fmt, ...)
35 {
36         char *str;
37         va_list argp;
38         int ret, save_errno = errno;
39
40         va_start(argp, fmt);
41         ret = vasprintf(&str, fmt, argp);
42         va_end(argp);
43         if (ret < 0) {
44                 EMERG_LOG("OOM\n");
45                 exit(EXIT_FAILURE);
46         }
47         m7a_log(LL_EMERG, "%s: %s\n", str, strerror(save_errno));
48         exit(EXIT_FAILURE);
49 }
50
51 void *xrealloc(void *p, size_t size)
52 {
53         assert(size > 0);
54         assert((p = realloc(p, size)));
55         return p;
56 }
57
58 void *xmalloc(size_t size)
59 {
60         return xrealloc(NULL, size);
61 }
62
63 void *xzmalloc(size_t size)
64 {
65         void *p = xrealloc(NULL, size);
66         memset(p, 0, size);
67         return p;
68 }
69
70 void *xstrdup(const char *s)
71 {
72         char *ret = strdup(s? s: "");
73
74         assert(ret);
75         return ret;
76 }
77
78 char *msg(const char *fmt, ...)
79 {
80         char *m;
81         size_t size = 100;
82
83         m = xmalloc(size);
84         while (1) {
85                 int n;
86                 va_list ap;
87
88                 /* Try to print in the allocated space. */
89                 va_start(ap, fmt);
90                 n = vsnprintf(m, size, fmt, ap);
91                 va_end(ap);
92                 /* If that worked, return the string. */
93                 if (n < size)
94                         return m;
95                 /* Else try again with more space. */
96                 size = n + 1; /* precisely what is needed */
97                 m = xrealloc(m, size);
98         }
99 }
100
101 char *xstrcat(char *a, const char *b)
102 {
103         char *tmp;
104
105         if (!a)
106                 return xstrdup(b);
107         if (!b)
108                 return a;
109         tmp = msg("%s%s", a, b);
110         free(a);
111         return tmp;
112 }
113
114 void die_empty_arg(const char *opt)
115 {
116         die("argument to --%s must not be empty", opt);
117 }
118
119 __attribute__ ((noreturn))
120 static void die_range(const char *opt)
121 {
122         die("argument to --%s is out of range", opt);
123 }
124
125 void check_range(uint32_t val, uint32_t min, uint32_t max, const char *opt)
126 {
127         if (val < min || val > max)
128                 die_range(opt);
129 }
130
131 bool fd2buf(int fd, const struct iovec *iov)
132 {
133         ssize_t ret, nread = 0, max;
134         char *buf = iov->iov_base;
135
136         assert(iov->iov_len > 1);
137         max = iov->iov_len - 1;
138         for (;;) {
139                 ret = read(fd, buf + nread, max - nread);
140                 if (ret < 0) {
141                         if (errno == EAGAIN || errno == EINTR)
142                                 continue;
143                         ERROR_LOG("read error: %s\n", strerror(errno));
144                         return false;
145                 }
146                 if (ret == 0) {
147                         buf[nread] = '\0';
148                         DEBUG_LOG("read %zd bytes\n", nread);
149                         return true;
150                 }
151                 nread += ret;
152                 if (nread >= max) {
153                         ERROR_LOG("cmd output truncated\n");
154                         return false;
155                 }
156         }
157 }
158
159 bool xexec(char * const argv[], const struct iovec *iov)
160 {
161         pid_t pid;
162         int pipefd[2] = {-1, -1};
163         unsigned n;
164
165         for (n = 0; argv[n]; n++)
166                 DEBUG_LOG("argv[%u]=%s\n", n, argv[n]);
167         if (iov) {
168                 if (pipe(pipefd) < 0)
169                         die_errno("pipe");
170         }
171         if ((pid = fork()) < 0)
172                 die_errno("fork");
173         if (pid > 0) { /* parent */
174                 int wstatus;
175                 bool success = true;
176                 if (iov) {
177                         close(pipefd[1]);
178                         success = fd2buf(pipefd[0], iov);
179                         close(pipefd[0]);
180                 }
181                 if (waitpid(pid, &wstatus, 0) < 0)
182                         die_errno("waitp");
183                 if (!success)
184                         return false;
185                 if (!WIFEXITED(wstatus))
186                         return false;
187                 if (WEXITSTATUS(wstatus) != EXIT_SUCCESS)
188                         return false;
189                 return true;
190         }
191         if (pipefd[0] >= 0)
192                 close(pipefd[0]);
193         if (pipefd[1] >= 0 && pipefd[1] != STDOUT_FILENO) {
194                 if (dup2(pipefd[1], STDOUT_FILENO) < 0)
195                         die_errno("dup2()");
196                 close(pipefd[1]);
197         }
198         execvp(argv[0], argv);
199         EMERG_LOG("execvp error: %s\n", strerror(errno));
200         _exit(EXIT_FAILURE);
201 }
202
203 void valid_fd012(void)
204 {
205         /* Ensure that file descriptors 0, 1, and 2 are valid. */
206         while (1) {
207                 int fd = open("/dev/null", O_RDWR);
208                 if (fd < 0)
209                         die_errno("open");
210                 if (fd > 2) {
211                         close(fd);
212                         break;
213                 }
214         }
215 }
216
217 void check_name(const char *arg)
218 {
219         size_t m, len;
220         char c;
221
222         len = strlen(arg);
223         if (len == 0)
224                 die("empty name");
225         if (len > 32)
226                 die("name too long: %s", arg);
227         for (m = 0; m < len; m++) {
228                 c = arg[m];
229                 if (!isascii(c))
230                         goto invalid;
231                 if (!isalnum(c) && c != '-')
232                         goto invalid;
233         }
234         return;
235 invalid:
236         die("invalid character '%c' in name %s", c, arg);
237 }
238
239 /* allocates two new strings that should be freed by the caller */
240 void parse_compound_arg(const char *arg, const char *opt, char **name, char **val)
241 {
242         char *copy, *p;
243
244         if (arg[0] == '\0')
245                 die_empty_arg(opt);
246         copy = xstrdup(arg);
247         p = strchr(copy, ':');
248         if (!p)
249                 die("could not parse argument to --%s", opt);
250         *p = '\0';
251         check_name(copy);
252         *name = copy;
253         p++;
254         *val = xstrdup(p);
255 }
256
257 char *parse_cgroup_acl(const char *arg)
258 {
259         if (!strncmp(arg, "allow ", 6))
260                 return msg("a%s", arg + 6);
261         if (!strncmp(arg, "deny ", 5))
262                 return msg("d%s", arg + 5);
263         die("invalid cgroup access specifier: %s", arg);
264 }
265
266 void parse_ifspec(const char *arg, char **bridge, uint8_t *hwaddr)
267 {
268         const char *colon = strchr(arg, ':');
269         size_t len;
270         unsigned n, x[6];
271
272         if (colon) {
273                 len = colon - arg;
274                 *bridge = xmalloc(len + 1);
275                 memcpy(*bridge, arg, len);
276                 (*bridge)[len] = '\0';
277         } else
278                 *bridge = xstrdup(arg);
279         check_name(*bridge);
280         if (!colon) {
281                 memset(hwaddr, 0, 6);
282                 return;
283         }
284         if (sscanf(colon + 1, "%02x:%02x:%02x:%02x:%02x:%02x",
285                 x, x + 1, x + 2, x + 3, x + 4, x + 5) != 6)
286                 die("invalid hwaddress for ifspec %s", arg);
287         if (colon[1 + 6 * 2 + 5] != '\0')
288                 die("trailing garbage at the end of ifspec %s", arg);
289         for (n = 0; n < 6; n++)
290                 hwaddr[n] = x[n];
291 }
292
293 uint32_t atou32(const char *str, const char *opt)
294 {
295         char *endptr;
296         long long tmp;
297
298         errno = 0; /* To distinguish success/failure after call */
299         tmp = strtoll(str, &endptr, 10);
300         if (errno == ERANGE && (tmp == LLONG_MAX || tmp == LLONG_MIN))
301                 die_range(opt);
302         if (tmp < 0 || tmp > (uint32_t)-1)
303                 die_range(opt);
304         /*
305          * If there were no digits at all, strtoll() stores the original value
306          * of str in *endptr.
307          */
308         if (endptr == str)
309                 die_empty_arg(opt);
310         /*
311          * The implementation may also set errno and return 0 in case no
312          * conversion was performed.
313          */
314         if (errno != 0 && tmp == 0)
315                 die_empty_arg(opt);
316         if (*endptr != '\0') /* Further characters after number */
317                 die("--%s: trailing characters after number", opt);
318         return tmp;
319 }
320
321 bool remove_subdirs_recursively(const char *path)
322 {
323         DIR *d = opendir(path);
324         struct dirent *entry;
325         int dfd;
326         struct stat stat;
327
328         if (!d) {
329                 ERROR_LOG("opendir %s: %m\n", path);
330                 return false;
331         }
332         dfd = dirfd(d);
333         assert(dfd >= 0);
334         while ((entry = readdir(d))) {
335                 char *subpath;
336                 if (!strcmp(entry->d_name, "."))
337                         continue;
338                 if (!strcmp(entry->d_name, ".."))
339                         continue;
340                 if (fstatat(dfd, entry->d_name, &stat, 0) == -1) {
341                         WARNING_LOG("%s/%s: %m", path, entry->d_name);
342                         continue;
343                 }
344                 if (!S_ISDIR(stat.st_mode))
345                         continue;
346                 subpath = msg("%s/%s", path, entry->d_name);
347                 remove_subdirs_recursively(subpath);
348                 DEBUG_LOG("removing %s\n", subpath);
349                 if (rmdir(subpath) < 0) {
350                         ERROR_LOG("rmdir %s: %m\n", subpath);
351                         return false;
352                 }
353                 free(subpath);
354         }
355         closedir(d);
356         return true;
357 }
358
359 void daemonize(const char *logfile)
360 {
361         pid_t pid;
362         int nullfd, logfd;
363
364         if ((pid = fork()) < 0)
365                 die_errno("fork");
366         if (pid) /* parent exits */
367                 exit(EXIT_SUCCESS);
368         valid_fd012();
369         /* become session leader */
370         if (setsid() < 0)
371                 die_errno("setsid");
372         if ((nullfd = open("/dev/null", O_RDWR)) < 0)
373                 die_errno("open /dev/null");
374         logfile = logfile? logfile : "/dev/null";
375         if ((logfd = open(logfile, O_WRONLY | O_APPEND | O_CREAT, 0666)) < 0)
376                 die_errno("open %s", logfile);
377         NOTICE_LOG("subsequent log messages go to %s\n", logfile);
378         if (dup2(nullfd, STDIN_FILENO) < 0)
379                 die_errno("dup2");
380         close(nullfd);
381         if (dup2(logfd, STDOUT_FILENO) < 0)
382                 die_errno("dup2");
383         if (dup2(logfd, STDERR_FILENO) < 0)
384                 die_errno("dup2");
385         close(logfd);
386         if (chdir("/") < 0)
387                 die_errno("chdir");
388 }
389
390 static int super_dull_hash(const char *input)
391 {
392         const uint8_t *x = (typeof(x))input;
393         const unsigned p1 = 16777619, p2 = 2971215073;
394         unsigned n, m, h, result = 0;
395
396         for (n = 0; n < 4; n++) {
397                 h = p1 * (x[0] + n);
398                 for (m = 1; x[m] != 0; m++)
399                         h = p2 * (h ^ x[m]);
400                 result = (result << 8) | (h % 256);
401         }
402         return result >> 1;
403 }
404
405 /**
406  * We use a semaphore set with two semaphores. The first semaphore is modified
407  * in all locking related functions while the second semaphore is modified only
408  * in try_lock() and aquire_lock(). This allows us to obtain the PID of the
409  * lock holder by querying the PID that last performed an operation on the
410  * second semaphore. This is achieved by passing GETPID as the control
411  * operation to semctl().
412  */
413
414 static bool get_lock(const char *string, pid_t *pid, bool wait)
415 {
416         int semid, ret;
417         struct sembuf sops[4];
418         key_t key = super_dull_hash(string);
419         bool success;
420         short sem_flg = SEM_UNDO;
421
422         if (!wait)
423                 sem_flg |= IPC_NOWAIT;
424         ret = semget(key, 2, IPC_CREAT | 0600);
425         if (ret < 0) {
426                 ERROR_LOG("semget: %m\n");
427                 return false;
428         }
429         semid = ret;
430         DEBUG_LOG("key: 0x%0x, semid: %d\n", (unsigned)key, semid);
431         ret = semctl(semid, 1, GETPID);
432         if (ret < 0)
433                 return false;
434         if (pid)
435                 *pid = ret;
436         sops[0].sem_num = 0;
437         sops[0].sem_op = 0;
438         sops[0].sem_flg = sem_flg;
439
440         sops[1].sem_num = 0;
441         sops[1].sem_op = 1;
442         sops[1].sem_flg = sem_flg;
443
444         sops[2].sem_num = 1;
445         sops[2].sem_op = 0;
446         sops[2].sem_flg = sem_flg;
447
448         sops[3].sem_num = 1;
449         sops[3].sem_op = 1;
450         sops[3].sem_flg = sem_flg;
451
452         success = semop(semid, sops, 4) >= 0;
453         if (!success)
454                 INFO_LOG("semop: %m\n");
455         return success;
456 }
457
458 bool try_lock(const char *string, pid_t *pid)
459 {
460         return get_lock(string, pid, false /* don't wait */);
461 }
462
463 bool acquire_lock(const char *string)
464 {
465         return get_lock(string, NULL /* don't need pid */, true /* do wait */);
466 }
467
468 bool release_lock(const char *string)
469 {
470         int semid, ret;
471         struct sembuf sops[2];
472         key_t key = super_dull_hash(string);
473         bool success;
474
475         ret = semget(key, 2, IPC_CREAT | 0600);
476         if (ret < 0) {
477                 ERROR_LOG("semget: %m\n");
478                 return false;
479         }
480         semid = ret;
481         DEBUG_LOG("key: 0x%0x, semid: %d\n", (unsigned)key, semid);
482         sops[0].sem_num = 0;
483         sops[0].sem_op = -1;
484         sops[0].sem_flg = SEM_UNDO;
485         sops[1].sem_num = 1;
486         sops[1].sem_op = -1;
487         sops[1].sem_flg = SEM_UNDO;
488         success = semop(semid, sops, 2) >= 0;
489         if (!success)
490                 INFO_LOG("semop: %m\n");
491         return success;
492 }
493
494 bool is_locked(const char *string, pid_t *pid)
495 {
496         int ret, semid;
497         struct sembuf sops = {
498                 .sem_num = 0,
499                 .sem_op = 0,
500                 .sem_flg = SEM_UNDO | IPC_NOWAIT
501         };
502         key_t key = super_dull_hash(string);
503
504         if (pid)
505                 *pid = 0;
506         ret = semget(key, 2, 0);
507         if (ret < 0)
508                 return false;
509         semid = ret;
510         DEBUG_LOG("key: 0x%0x, semid: %d\n", (unsigned)key, semid);
511         if (semop(semid, &sops, 1) >= 0)
512                 return false;
513         ret = semctl(semid, 1, GETPID);
514         if (ret < 0)
515                 return false;
516         if (pid)
517                 *pid = ret;
518         return true;
519 }
520
521 bool attach_to_bridge(const char *iface, const char *bridge)
522 {
523         int fd, idx;
524         struct ifreq ifr;
525         bool success;
526
527         INFO_LOG("adding interface %s to bridge %s\n", iface, bridge);
528         if (!(idx = if_nametoindex(iface))) {
529                 ERROR_LOG("no index for %s: %m\n", iface);
530                 return false;
531         }
532         if ((fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
533                 ERROR_LOG("socket: %m\n");
534                 return false;
535         }
536         strncpy(ifr.ifr_name, bridge, IFNAMSIZ - 1);
537         ifr.ifr_name[IFNAMSIZ - 1] = '\0';
538         ifr.ifr_ifindex = idx;
539         success = ioctl(fd, SIOCBRADDIF, &ifr) == 0;
540         if (!success)
541                 ERROR_LOG("interface %s, bridge %s: ioctl SIOCBRADDIF: %m\n",
542                         iface, bridge);
543         close(fd);
544         return success;
545 }
546
547
548 #define NLMSG_TAIL(nmsg) \
549         ((struct rtattr *) (((void *) (nmsg)) + NLMSG_ALIGN((nmsg)->nlmsg_len)))
550
551 static void addattr_l(struct nlmsghdr *nlh, int type, const void *data,
552                 int alen)
553 {
554         int len = RTA_LENGTH(alen);
555         struct rtattr *rta;
556
557         rta = NLMSG_TAIL(nlh);
558         rta->rta_type = type;
559         rta->rta_len = len;
560         if (alen > 0)
561                 memcpy(RTA_DATA(rta), data, alen);
562         nlh->nlmsg_len = NLMSG_ALIGN(nlh->nlmsg_len) + RTA_ALIGN(len);
563 }
564
565 static struct rtattr *addattr_nest(struct nlmsghdr *n, int type)
566 {
567         struct rtattr *nest = NLMSG_TAIL(n);
568         addattr_l(n, type, NULL, 0);
569         return nest;
570 }
571
572 static void end_nest(struct nlmsghdr *nlh, struct rtattr *attr)
573 {
574         attr->rta_len = (void *)NLMSG_TAIL(nlh) - (void *)attr;
575 }
576
577 static struct mnl_socket *get_and_bind_netlink_socket(void)
578 {
579         struct mnl_socket *nl = mnl_socket_open(NETLINK_ROUTE);
580
581         if (!nl) {
582                 ERROR_LOG("mnl_socket_open error\n");
583                 return NULL;
584         }
585         if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
586                 ERROR_LOG("mnl_socket_bind\n");
587                 mnl_socket_close(nl);
588                 return NULL;
589         }
590         return nl;
591 }
592
593 static struct nlmsghdr *prepare_netlink_msg_header(char *buf)
594 {
595         struct nlmsghdr *nlh = mnl_nlmsg_put_header(buf);
596         nlh->nlmsg_flags = NLM_F_REQUEST;
597         nlh->nlmsg_seq = time(NULL);
598         return nlh;
599 }
600
601 bool rename_interface(const char *before, const char *after)
602 {
603         int idx;
604         struct mnl_socket *nl;
605         char buf[MNL_SOCKET_BUFFER_SIZE];
606         struct nlmsghdr *nlh;
607         struct ifinfomsg *ifm;
608         bool success;
609
610         INFO_LOG("%s -> %s\n", before, after);
611         if (!(idx = if_nametoindex(before))) {
612                 ERROR_LOG("no index for %s\n", before);
613                 return false;
614         }
615         if (!(nl = get_and_bind_netlink_socket()))
616                 return false;
617
618         nlh = prepare_netlink_msg_header(buf);
619         nlh->nlmsg_type = RTM_NEWLINK;
620
621         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
622         ifm->ifi_family = AF_UNSPEC;
623         ifm->ifi_index = idx;
624         addattr_l(nlh, IFLA_IFNAME, after, strlen(after) + 1);
625         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
626                 ERROR_LOG("mnl_socket_sendto failed\n");
627                 success = false;
628                 goto close;
629         }
630         success = true;
631 close:
632         mnl_socket_close(nl);
633         return success;
634 }
635
636 void pretty_print_hwaddr(const uint8_t *hwaddr, char *result)
637 {
638         sprintf(result, "%02x:%02x:%02x:%02x:%02x:%02x", hwaddr[0], hwaddr[1],
639                 hwaddr[2], hwaddr[3], hwaddr[4], hwaddr[5]);
640 }
641
642 bool set_hwaddr(const char *iface, const uint8_t *hwaddr)
643 {
644         struct mnl_socket *nl;
645         char buf[MNL_SOCKET_BUFFER_SIZE];
646         struct nlmsghdr *nlh;
647         struct ifinfomsg *ifm;
648         bool success;
649         const uint8_t zero[6] = {0};
650         char pretty_hwaddr[18];
651
652         if (!memcmp(hwaddr, zero, 6))
653                 return true; /* no hwaddr specified, nothing to do */
654         pretty_print_hwaddr(hwaddr, pretty_hwaddr);
655         INFO_LOG("hardware address of %s: %s\n", iface, pretty_hwaddr);
656         if (!(nl = get_and_bind_netlink_socket()))
657                 return false;
658
659         nlh = prepare_netlink_msg_header(buf);
660         nlh->nlmsg_type = RTM_NEWLINK;
661
662         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
663         ifm->ifi_family = AF_UNSPEC;
664         addattr_l(nlh, IFLA_ADDRESS, hwaddr, 6);
665         addattr_l(nlh, IFLA_IFNAME, iface, strlen(iface) + 1);
666         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
667                 ERROR_LOG("%s: mnl_socket_sendto failed\n", iface);
668                 success = false;
669                 goto close;
670         }
671         success = true;
672 close:
673         mnl_socket_close(nl);
674         return success;
675 }
676
677 bool link_del(const char *iface)
678 {
679         struct mnl_socket *nl;
680         char buf[MNL_SOCKET_BUFFER_SIZE];
681         struct nlmsghdr *nlh;
682         struct ifinfomsg *ifm;
683         bool success;
684
685         INFO_LOG("removing interface %s\n", iface);
686         if (!(nl = get_and_bind_netlink_socket()))
687                 return false;
688
689         nlh = prepare_netlink_msg_header(buf);
690         nlh->nlmsg_type = RTM_DELLINK;
691
692         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
693         ifm->ifi_family = AF_UNSPEC;
694         ifm->ifi_change = IFF_UP;
695         ifm->ifi_flags = IFF_UP;
696         addattr_l(nlh, IFLA_IFNAME, iface, strlen(iface) + 1);
697         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
698                 ERROR_LOG("%s: mnl_socket_sendto failed\n", iface);
699                 success = false;
700                 goto close;
701         }
702         success = true;
703 close:
704         mnl_socket_close(nl);
705         return success;
706 }
707
708 bool link_up(const char *iface)
709 {
710         struct mnl_socket *nl;
711         char buf[MNL_SOCKET_BUFFER_SIZE];
712         struct nlmsghdr *nlh;
713         struct ifinfomsg *ifm;
714         bool success;
715
716         INFO_LOG("activating interface %s\n", iface);
717         if (!(nl = get_and_bind_netlink_socket()))
718                 return false;
719         nlh = prepare_netlink_msg_header(buf);
720         nlh->nlmsg_type = RTM_NEWLINK;
721
722         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
723         ifm->ifi_family = AF_UNSPEC;
724         ifm->ifi_change = IFF_UP;
725         ifm->ifi_flags = IFF_UP;
726         addattr_l(nlh, IFLA_IFNAME, iface, strlen(iface) + 1);
727         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
728                 ERROR_LOG("%s: mnl_socket_sendto failed\n", iface);
729                 success = false;
730                 goto close;
731         }
732         success = true;
733 close:
734         mnl_socket_close(nl);
735         return success;
736 }
737
738 #ifndef VETH_INFO_PEER
739 #define VETH_INFO_PEER 1
740 #endif
741
742 bool create_veth_device_pair(const char *name, char *peer)
743 {
744         struct mnl_socket *nl;
745         char buf[MNL_SOCKET_BUFFER_SIZE];
746         struct rtattr *n1, *n2, *n3;
747         struct nlmsghdr *nlh;
748         struct ifinfomsg *ifm;
749         bool success;
750
751         INFO_LOG("new pair: %s <-> %s\n", name, peer);
752         if (!(nl = get_and_bind_netlink_socket()))
753                 return false;
754
755         nlh = prepare_netlink_msg_header(buf);
756         nlh->nlmsg_type = RTM_NEWLINK;
757         nlh->nlmsg_flags |= NLM_F_CREATE | NLM_F_EXCL;
758
759         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
760         ifm->ifi_family = AF_UNSPEC;
761         n1 = addattr_nest(nlh, IFLA_LINKINFO);
762         addattr_l(nlh, IFLA_INFO_KIND, "veth", 5);
763         n2 = addattr_nest(nlh, IFLA_INFO_DATA);
764         n3 = addattr_nest(nlh, VETH_INFO_PEER);
765         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
766         ifm->ifi_family = AF_UNSPEC;
767         addattr_l(nlh, IFLA_IFNAME, peer, strlen(peer) + 1);
768         end_nest(nlh, n3);
769         end_nest(nlh, n2);
770         end_nest(nlh, n1);
771         addattr_l(nlh, IFLA_IFNAME, name, strlen(name) + 1);
772         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
773                 ERROR_LOG("%s: mnl_socket_sendto\n", name);
774                 success = false;
775                 goto close;
776         }
777         success = true;
778 close:
779         mnl_socket_close(nl);
780         return success;
781 }
782
783 bool set_netns(const char *iface, pid_t pid)
784 {
785         struct mnl_socket *nl;
786         char buf[MNL_SOCKET_BUFFER_SIZE];
787         struct nlmsghdr *nlh;
788         struct ifinfomsg *ifm;
789
790         INFO_LOG("changing net namespace of interface %s to pid %d\n",
791                 iface, (int)pid);
792         if (!(nl = get_and_bind_netlink_socket()))
793                 return false;
794
795         nlh = prepare_netlink_msg_header(buf);
796         nlh->nlmsg_type = RTM_NEWLINK;
797
798         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
799         ifm->ifi_family = AF_UNSPEC;
800         ifm->ifi_change = 0;
801         ifm->ifi_flags = 0;
802         addattr_l(nlh, IFLA_NET_NS_PID, &pid, sizeof(pid));
803         mnl_attr_put_str(nlh, IFLA_IFNAME, iface);
804
805         if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
806                 ERROR_LOG("%s: mnl_socket_sendto failed\n", iface);
807                 return false;
808         }
809         mnl_socket_close(nl);
810         return true;
811 }
812
813 #ifndef UNIX_PATH_MAX
814 #define UNIX_PATH_MAX (sizeof(((struct sockaddr_un *)0)->sun_path))
815 #endif
816
817 static bool init_unix_socket(const char *socket_path, int *socketfd,
818                 struct sockaddr_un *sau)
819 {
820         int fd;
821
822         *socketfd = -1;
823         if (strlen(socket_path) + 1 >= UNIX_PATH_MAX) {
824                 ERROR_LOG("socket path to long: %s\n", socket_path);
825                 return false;
826         }
827         memset(sau, 0, sizeof(struct sockaddr_un));
828         sau->sun_family = PF_UNIX;
829         sau->sun_path[0] = '\0'; /* use the abstract socket namespace */
830         strcpy(sau->sun_path + 1, socket_path);
831         fd = socket(PF_UNIX, SOCK_STREAM, 0);
832         if (fd < 0) {
833                 ERROR_LOG("socket: %m\n");
834                 return false;
835         }
836         *socketfd = fd;
837         return true;
838 }
839
840 bool listen_on_unix_socket(const char *socket_path, int *result)
841 {
842         struct sockaddr_un sau;
843         int fd, flags;
844         bool success = false;
845
846         if (!init_unix_socket(socket_path, &fd, &sau))
847                 return false;
848         flags = fcntl(fd, F_GETFL);
849         if (flags < 0) {
850                 ERROR_LOG("fcntl (F_GETFL): %m\n");
851                 goto fail;
852         }
853         flags = fcntl(fd, F_SETFL, ((long)flags) | O_NONBLOCK);
854         if (flags < 0) {
855                 ERROR_LOG("fcntl (F_SETFL): %m\n");
856                 goto fail;
857         }
858         if (bind(fd, (struct sockaddr *)&sau, sizeof(sau)) < 0) {
859                 ERROR_LOG("bind: %m\n");
860                 goto fail;
861         }
862         if (listen(fd , 5) < 0) {
863                 ERROR_LOG("listen: %m\n");
864                 goto fail;
865         }
866         *result = fd;
867         NOTICE_LOG("listening on fd %d\n", fd);
868         return true;
869 fail:
870         close(fd);
871         return success;
872 }
873 /*
874  * Send a buffer and the credentials of the current process to a socket.
875  *
876  * buf must be zero-terminated.
877  * return the return value of the underlying call to sendmsg().
878  */
879 static bool send_cred_buffer(int sock, char *buf)
880 {
881         char control[255] __attribute__((__aligned__(8)));
882         struct msghdr msg;
883         struct cmsghdr *cmsg;
884         static struct iovec iov;
885         struct ucred c;
886
887         /* Response data */
888         iov.iov_base = buf;
889         iov.iov_len = strlen(buf) + 1;
890         c.pid = getpid();
891         c.uid = getuid();
892         c.gid = getgid();
893         /* compose the message */
894         memset(&msg, 0, sizeof(msg));
895         msg.msg_iov = &iov;
896         msg.msg_iovlen = 1;
897         msg.msg_control = control;
898         msg.msg_controllen = sizeof(control);
899         /* attach the ucred struct */
900         cmsg = CMSG_FIRSTHDR(&msg);
901         cmsg->cmsg_level = SOL_SOCKET;
902         cmsg->cmsg_type = SCM_CREDENTIALS;
903         cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
904         *(struct ucred *)CMSG_DATA(cmsg) = c;
905         msg.msg_controllen = cmsg->cmsg_len;
906         if (sendmsg(sock, &msg, 0) < 0) {
907                 ERROR_LOG("sendmsg: %m\n");
908                 return false;
909         }
910         return true;
911 }
912
913 static void dispose_fds(int *fds, unsigned num)
914 {
915         int i;
916
917         for (i = 0; i < num; i++)
918                 close(fds[i]);
919 }
920
921 /* Receive a buffer and the Unix credentials of the sending process. */
922 bool recv_cred_buffer(int socketfd, char *buf, size_t size,
923                 int *clientfd, uid_t *uid)
924 {
925         char control[255] __attribute__((__aligned__(8)));
926         struct msghdr msg;
927         struct cmsghdr *cmsg;
928         struct iovec iov;
929         int yes = 1, cfd, ret;
930         struct ucred cred;
931         struct sockaddr_un sau;
932         socklen_t sizeof_sau = sizeof(sau);
933
934         ret = accept(socketfd, (struct sockaddr *)&sau, &sizeof_sau);
935         if (ret < 0) {
936                 ERROR_LOG("accept: %m\n");
937                 return false;
938         }
939         cfd = ret;
940         setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &yes, sizeof(int));
941         memset(&msg, 0, sizeof(msg));
942         iov.iov_base = buf;
943         iov.iov_len = size;
944         msg.msg_iov = &iov;
945         msg.msg_iovlen = 1;
946         msg.msg_control = control;
947         msg.msg_controllen = sizeof(control);
948         if (recvmsg(cfd, &msg, 0) < 0) {
949                 ERROR_LOG("recvmsg: %m\n");
950                 goto fail;
951         }
952         cmsg = CMSG_FIRSTHDR(&msg);
953         while (cmsg) {
954                 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type
955                                 == SCM_CREDENTIALS) {
956                         memcpy(&cred, CMSG_DATA(cmsg), sizeof(struct ucred));
957                         *uid = cred.uid;
958                         *clientfd = cfd;
959                         return true;
960                 } else
961                         if (cmsg->cmsg_level == SOL_SOCKET
962                                         && cmsg->cmsg_type == SCM_RIGHTS) {
963                                 dispose_fds((int *)CMSG_DATA(cmsg),
964                                         (cmsg->cmsg_len - CMSG_LEN(0))
965                                         / sizeof(int));
966                         }
967                 cmsg = CMSG_NXTHDR(&msg, cmsg);
968         }
969 fail:
970         close(*clientfd);
971         *clientfd = -1;
972         return false;
973 }
974
975 bool pass_fd(int passfd, int socketfd)
976 {
977         struct msghdr msg = {.msg_iov = NULL};
978         struct cmsghdr *cmsg;
979         char control[255] __attribute__((__aligned__(8)));
980         struct iovec iov;
981         char buf[] = "\0OK";
982
983         iov.iov_base = buf;
984         iov.iov_len  = sizeof(buf);
985
986         msg.msg_iov = &iov;
987         msg.msg_iovlen = 1;
988
989         msg.msg_control = control;
990         msg.msg_controllen = sizeof(control);
991
992         cmsg = CMSG_FIRSTHDR(&msg);
993         cmsg->cmsg_level = SOL_SOCKET;
994         cmsg->cmsg_type = SCM_RIGHTS;
995         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
996         *(int *)CMSG_DATA(cmsg) = passfd;
997
998         /* Sum of the length of all control messages in the buffer */
999         msg.msg_controllen = cmsg->cmsg_len;
1000         DEBUG_LOG("passing %s and fd %d\n", buf, passfd);
1001         if (sendmsg(socketfd, &msg, 0) < 0) {
1002                 ERROR_LOG("sendmsg: %m\n");
1003                 return false;
1004         }
1005         return true;
1006 }
1007
1008 static bool recv_fd(int socketfd, int *recvfd)
1009 {
1010         char control[255] __attribute__((__aligned__(8)));
1011         struct msghdr msg = {.msg_iov = NULL};
1012         struct cmsghdr *cmsg;
1013         struct iovec iov;
1014         char buf[100];
1015         ssize_t sz = sizeof(buf), ssz;
1016
1017         *recvfd = -1;
1018         iov.iov_base = buf;
1019         iov.iov_len = sz - 1;
1020         msg.msg_iov = &iov;
1021         msg.msg_iovlen = 1;
1022         msg.msg_control = control;
1023         msg.msg_controllen = sizeof(control);
1024         memset(buf, 0, sz);
1025         ssz = recvmsg(socketfd, &msg, 0);
1026         if (ssz < 0) {
1027                 ERROR_LOG("recvmsg: %m\n");
1028                 return false;
1029         }
1030         buf[ssz] = '\0';
1031         INFO_LOG("server response: %u (%s)\n", (unsigned)buf[0], buf + 1);
1032         for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
1033                 if (cmsg->cmsg_level != SOL_SOCKET
1034                         || cmsg->cmsg_type != SCM_RIGHTS)
1035                         continue;
1036                 if ((cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int) != 1)
1037                         continue;
1038                 *recvfd = *(int *)CMSG_DATA(cmsg);
1039                 return true;
1040         }
1041         return false;
1042 }
1043
1044 int request_fd(const char *socket_path, char *msg, int *result)
1045 {
1046         struct sockaddr_un sau;
1047         int socketfd, receivefd;
1048
1049         if (!init_unix_socket(socket_path, &socketfd, &sau))
1050                 die("could not init socket");
1051         if (connect(socketfd, (struct sockaddr *)&sau, sizeof(sau)) < 0)
1052                 die_errno("connect");
1053         if (!send_cred_buffer(socketfd, msg))
1054                 die("could not send cred buffer");
1055         if (!recv_fd(socketfd, &receivefd))
1056                 die("did not receive tty fd");
1057         NOTICE_LOG("received fd %d\n", receivefd);
1058         *result = receivefd;
1059         return socketfd;
1060 }
1061
1062 bool request_int(const char *socket_path, char *msg, int *result)
1063 {
1064         struct sockaddr_un sau;
1065         int socketfd;
1066         bool success = false;
1067         char buf[100];
1068         ssize_t ssz;
1069
1070         *result = -1;
1071         if (!init_unix_socket(socket_path, &socketfd, &sau))
1072                 return false;
1073         if (connect(socketfd, (struct sockaddr *)&sau, sizeof(sau)) < 0) {
1074                 ERROR_LOG("connect: %m\n");
1075                 goto close;
1076         }
1077         if (!send_cred_buffer(socketfd, msg)) {
1078                 ERROR_LOG("could not send cred msg \"%s\"\n", msg);
1079                 goto close;
1080         }
1081         ssz = read(socketfd, buf, sizeof(buf) - 1);
1082         if (ssz < 0) {
1083                 ERROR_LOG("did not receive integer: %m\n");
1084                 goto close;
1085         }
1086         if (buf[0] != 0) {
1087                 ERROR_LOG("did not receive integer: %s\n", buf + 1);
1088                 goto close;
1089         }
1090         if (ssz != sizeof(int) + 1) {
1091                 ERROR_LOG("protocol mismatch, server msg: %s\n", buf + 1);
1092                 goto close;
1093         }
1094         memcpy(result, buf + 1, sizeof(int));
1095         DEBUG_LOG("received integer: %d\n", *result);
1096         success = true;
1097 close:
1098         close(socketfd);
1099         return success;
1100 }
1101
1102 int signal_pipe[2];
1103
1104 static void signal_handler(int signum)
1105 {
1106         uint8_t u = signum;
1107         int save_errno = errno;
1108         assert(signum > 0 && signum < 256);
1109         if (write(signal_pipe[1], &u, 1) < 0)
1110                 ERROR_LOG("write to signal pipe: %m\n");
1111         errno = save_errno;
1112 }
1113
1114 void init_signal_handling(void)
1115 {
1116         struct sigaction act;
1117
1118         if (pipe(signal_pipe) < 0)
1119                 die_errno("signal pipe");
1120         act.sa_handler = signal_handler;
1121         sigemptyset(&act.sa_mask);
1122         act.sa_flags = SA_RESTART;
1123         if (sigaction(SIGINT, &act, NULL) < 0)
1124                 die_errno("sigaction");
1125         if (sigaction(SIGTERM, &act, NULL) < 0)
1126                 die_errno("sigaction");
1127         if (sigaction(SIGCHLD, &act, NULL) < 0)
1128                 die_errno("sigaction");
1129 }
1130
1131 int next_signal(void)
1132 {
1133         uint8_t u = 0;
1134 again:
1135         if (read(signal_pipe[0], &u, 1) < 0) {
1136                 if (errno != EINTR)
1137                         die_errno("read");
1138                 goto again;
1139         }
1140         DEBUG_LOG("process %d received signal %u\n", getpid(), u);
1141         return u;
1142 }