Merge branch 't/osx_overhaul'
[paraslash.git] / crypt.c
1 /*
2  * Copyright (C) 2005-2011 Andre Noll <maan@systemlinux.org>
3  *
4  * Licensed under the GPL v2. For licencing details see COPYING.
5  */
6
7 /** \file crypt.c Openssl-based encryption/decryption routines. */
8
9 #include <regex.h>
10 #include <sys/types.h>
11 #include <sys/socket.h>
12 #include <openssl/rand.h>
13 #include <openssl/err.h>
14 #include <openssl/rc4.h>
15 #include <openssl/pem.h>
16 #include <openssl/sha.h>
17 #include <openssl/bn.h>
18
19 #include "para.h"
20 #include "error.h"
21 #include "string.h"
22 #include "crypt.h"
23 #include "fd.h"
24
25 struct asymmetric_key {
26         RSA *rsa;
27 };
28
29 /**
30  * Fill a buffer with random content.
31  *
32  * \param buf The buffer to fill.
33  * \param num The size of \a buf in bytes.
34  *
35  * This function puts \a num cryptographically strong pseudo-random bytes into
36  * buf. If libssl can not guarantee an unpredictable byte sequence (for example
37  * because the PRNG has not been seeded with enough randomness) the function
38  * logs an error message and calls exit().
39  */
40 void get_random_bytes_or_die(unsigned char *buf, int num)
41 {
42         unsigned long err;
43
44         /* RAND_bytes() returns 1 on success, 0 otherwise. */
45         if (RAND_bytes(buf, num) == 1)
46                 return;
47         err = ERR_get_error();
48         PARA_EMERG_LOG("%s\n", ERR_reason_error_string(err));
49         exit(EXIT_FAILURE);
50 }
51
52 /**
53  * Seed pseudo random number generators.
54  *
55  * This function reads 64 bytes from /dev/urandom and adds them to the SSL
56  * PRNG. It also seeds the PRNG used by random() with a random seed obtained
57  * from SSL. If /dev/random could not be read, an error message is logged and
58  * the function calls exit().
59  *
60  * \sa RAND_load_file(3), \ref get_random_bytes_or_die(), srandom(3),
61  * random(3), \ref para_random().
62  */
63 void init_random_seed_or_die(void)
64 {
65         int seed, ret = RAND_load_file("/dev/urandom", 64);
66
67         if (ret != 64) {
68                 PARA_EMERG_LOG("could not seed PRNG (ret = %d)\n", ret);
69                 exit(EXIT_FAILURE);
70         }
71         get_random_bytes_or_die((unsigned char *)&seed, sizeof(seed));
72         srandom(seed);
73 }
74
75 static int check_key_file(const char *file, int private)
76 {
77         struct stat st;
78
79         if (stat(file, &st) != 0)
80                 return -ERRNO_TO_PARA_ERROR(errno);
81         if (private != LOAD_PRIVATE_KEY)
82                 return 0;
83         if ((st.st_uid == getuid()) && (st.st_mode & 077) != 0)
84                 return -E_KEY_PERM;
85         return 1;
86 }
87
88 static EVP_PKEY *load_key(const char *file, int private)
89 {
90         BIO *key;
91         EVP_PKEY *pkey = NULL;
92         int ret = check_key_file(file, private);
93
94         if (ret < 0) {
95                 PARA_ERROR_LOG("%s\n", para_strerror(-ret));
96                 return NULL;
97         }
98         key = BIO_new(BIO_s_file());
99         if (!key)
100                 return NULL;
101         if (BIO_read_filename(key, file) > 0) {
102                 if (private == LOAD_PRIVATE_KEY)
103                         pkey = PEM_read_bio_PrivateKey(key, NULL, NULL, NULL);
104                 else
105                         pkey = PEM_read_bio_PUBKEY(key, NULL, NULL, NULL);
106         }
107         BIO_free(key);
108         return pkey;
109 }
110
111 static int get_openssl_key(const char *key_file, RSA **rsa, int private)
112 {
113         EVP_PKEY *key = load_key(key_file, private);
114
115         if (!key)
116                 return (private == LOAD_PRIVATE_KEY)? -E_PRIVATE_KEY
117                         : -E_PUBLIC_KEY;
118         *rsa = EVP_PKEY_get1_RSA(key);
119         EVP_PKEY_free(key);
120         if (!*rsa)
121                 return -E_RSA;
122         return RSA_size(*rsa);
123 }
124
125 #define KEY_TYPE_TXT "ssh-rsa"
126
127 /* check if it is an ssh rsa key */
128 static size_t is_ssh_rsa_key(char *data, size_t size)
129 {
130         char *cp;
131
132         if (size < strlen(KEY_TYPE_TXT) + 2)
133                 return 0;
134         cp = memchr(data, ' ', size);
135         if (cp == NULL)
136                 return 0;
137         if (strncmp(KEY_TYPE_TXT, data, strlen(KEY_TYPE_TXT)))
138                 return 0;
139         cp++;
140         if (cp >= data + size)
141                 return 0;
142         if (*cp == '\0')
143                 return 0;
144         return cp - data;
145 }
146
147 /*
148  * This base64/uudecode stuff below is taken from openssh-5.2p1, Copyright (c)
149  * 1996 by Internet Software Consortium.  Portions Copyright (c) 1995 by
150  * International Business Machines, Inc.
151  */
152
153 static const char Base64[] =
154         "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
155 static const char Pad64 = '=';
156 /*
157  * Skips all whitespace anywhere. Converts characters, four at a time, starting
158  * at (or after) src from base - 64 numbers into three 8 bit bytes in the
159  * target area. it returns the number of data bytes stored at the target, or -1
160  * on error.
161  */
162 static int base64_decode(char const *src, unsigned char *target, size_t targsize)
163 {
164         unsigned int tarindex, state;
165         int ch;
166         char *pos;
167
168         state = 0;
169         tarindex = 0;
170
171         while ((ch = *src++) != '\0') {
172                 if (para_isspace(ch)) /* Skip whitespace anywhere. */
173                         continue;
174
175                 if (ch == Pad64)
176                         break;
177
178                 pos = strchr(Base64, ch);
179                 if (pos == 0) /* A non-base64 character. */
180                         return -1;
181
182                 switch (state) {
183                 case 0:
184                         if (target) {
185                                 if (tarindex >= targsize)
186                                         return (-1);
187                                 target[tarindex] = (pos - Base64) << 2;
188                         }
189                         state = 1;
190                         break;
191                 case 1:
192                         if (target) {
193                                 if (tarindex + 1 >= targsize)
194                                         return (-1);
195                                 target[tarindex]   |=  (pos - Base64) >> 4;
196                                 target[tarindex+1]  = ((pos - Base64) & 0x0f)
197                                                         << 4 ;
198                         }
199                         tarindex++;
200                         state = 2;
201                         break;
202                 case 2:
203                         if (target) {
204                                 if (tarindex + 1 >= targsize)
205                                         return (-1);
206                                 target[tarindex]   |=  (pos - Base64) >> 2;
207                                 target[tarindex+1]  = ((pos - Base64) & 0x03)
208                                                         << 6;
209                         }
210                         tarindex++;
211                         state = 3;
212                         break;
213                 case 3:
214                         if (target) {
215                                 if (tarindex >= targsize)
216                                         return (-1);
217                                 target[tarindex] |= (pos - Base64);
218                         }
219                         tarindex++;
220                         state = 0;
221                         break;
222                 }
223         }
224
225         /*
226          * We are done decoding Base-64 chars.  Let's see if we ended
227          * on a byte boundary, and/or with erroneous trailing characters.
228          */
229
230         if (ch == Pad64) {              /* We got a pad char. */
231                 ch = *src++;            /* Skip it, get next. */
232                 switch (state) {
233                 case 0:         /* Invalid = in first position */
234                 case 1:         /* Invalid = in second position */
235                         return (-1);
236
237                 case 2:         /* Valid, means one byte of info */
238                         /* Skip any number of spaces. */
239                         for (; ch != '\0'; ch = *src++)
240                                 if (!isspace(ch))
241                                         break;
242                         /* Make sure there is another trailing = sign. */
243                         if (ch != Pad64)
244                                 return (-1);
245                         ch = *src++;            /* Skip the = */
246                         /* Fall through to "single trailing =" case. */
247                         /* FALLTHROUGH */
248
249                 case 3:         /* Valid, means two bytes of info */
250                         /*
251                          * We know this char is an =.  Is there anything but
252                          * whitespace after it?
253                          */
254                         for (; ch != '\0'; ch = *src++)
255                                 if (!isspace(ch))
256                                         return (-1);
257
258                         /*
259                          * Now make sure for cases 2 and 3 that the "extra"
260                          * bits that slopped past the last full byte were
261                          * zeros.  If we don't check them, they become a
262                          * subliminal channel.
263                          */
264                         if (target && target[tarindex] != 0)
265                                 return (-1);
266                 }
267         } else {
268                 /*
269                  * We ended by seeing the end of the string.  Make sure we
270                  * have no partial bytes lying around.
271                  */
272                 if (state != 0)
273                         return (-1);
274         }
275
276         return (tarindex);
277 }
278
279 static int uudecode(const char *src, unsigned char *target, size_t targsize)
280 {
281         int len;
282         char *encoded, *p;
283
284         /* copy the 'readonly' source */
285         encoded = para_strdup(src);
286         /* skip whitespace and data */
287         for (p = encoded; *p == ' ' || *p == '\t'; p++)
288                 ;
289         for (; *p != '\0' && *p != ' ' && *p != '\t'; p++)
290                 ;
291         /* and remove trailing whitespace because base64_decode needs this */
292         *p = '\0';
293         len = base64_decode(encoded, target, targsize);
294         free(encoded);
295         return len >= 0? len : -E_BASE64;
296 }
297
298 /*
299  * The public key loading functions below were inspired by corresponding code
300  * of openssh-5.2p1, Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo,
301  * Finland. However, not much of the original code remains.
302  */
303
304
305 /*
306  * Can not use the inline functions of portable_io.h here because the byte
307  * order is different.
308  */
309 static uint32_t read_ssh_u32(const void *vp)
310 {
311         const unsigned char *p = (const unsigned char *)vp;
312         uint32_t v;
313
314         v  = (uint32_t)p[0] << 24;
315         v |= (uint32_t)p[1] << 16;
316         v |= (uint32_t)p[2] << 8;
317         v |= (uint32_t)p[3];
318
319         return v;
320 }
321
322 static int read_bignum(const unsigned char *buf, size_t len, BIGNUM **result)
323 {
324         const unsigned char *p = buf, *end = buf + len;
325         uint32_t bnsize;
326         BIGNUM *bn;
327
328         if (p + 4 < p)
329                 return -E_BIGNUM;
330         if (p + 4 > end)
331                 return -E_BIGNUM;
332         bnsize = read_ssh_u32(p);
333         PARA_DEBUG_LOG("bnsize: %u\n", bnsize);
334         p += 4;
335         if (p + bnsize < p)
336                 return -E_BIGNUM;
337         if (p + bnsize > end)
338                 return -E_BIGNUM;
339         if (bnsize > 8192)
340                 return -E_BIGNUM;
341         bn = BN_bin2bn(p, bnsize, NULL);
342         if (!bn)
343                 return -E_BIGNUM;
344         *result = bn;
345         return bnsize + 4;
346 }
347
348 static int read_rsa_bignums(const unsigned char *blob, int blen, RSA **result)
349 {
350         int ret;
351         RSA *rsa;
352         const unsigned char *p = blob, *end = blob + blen;
353         uint32_t rlen;
354
355         *result = NULL;
356         if (p + 4 > end)
357                 return -E_BIGNUM;
358         rlen = read_ssh_u32(p);
359         p += 4;
360         if (p + rlen < p)
361                 return -E_BIGNUM;
362         if (p + rlen > end)
363                 return -E_BIGNUM;
364         if (rlen < strlen(KEY_TYPE_TXT))
365                 return -E_BIGNUM;
366         PARA_DEBUG_LOG("type: %s, rlen: %d\n", p, rlen);
367         if (strncmp((char *)p, KEY_TYPE_TXT, strlen(KEY_TYPE_TXT)))
368                 return -E_BIGNUM;
369         p += rlen;
370
371         rsa = RSA_new();
372         if (!rsa)
373                 return -E_BIGNUM;
374         ret = read_bignum(p, end - p, &rsa->e);
375         if (ret < 0)
376                 goto fail;
377         p += ret;
378         ret = read_bignum(p, end - p, &rsa->n);
379         if (ret < 0)
380                 goto fail;
381         *result = rsa;
382         return 1;
383 fail:
384         if (rsa)
385                 RSA_free(rsa);
386         return ret;
387 }
388
389 /**
390  * Read an asymmetric key from a file.
391  *
392  * \param key_file The file containing the key.
393  * \param private if non-zero, read the private key, otherwise the public key.
394  * \param result The key structure is returned here.
395  *
396  * \return The size of the key on success, negative on errors.
397  *
398  * \sa openssl(1), rsa(1).
399  */
400 int get_asymmetric_key(const char *key_file, int private,
401                 struct asymmetric_key **result)
402 {
403         struct asymmetric_key *key = NULL;
404         void *map = NULL;
405         unsigned char *blob = NULL;
406         size_t map_size, blob_size;
407         int ret, ret2;
408         char *cp;
409
410         key = para_malloc(sizeof(*key));
411         if (private) {
412                 ret = get_openssl_key(key_file, &key->rsa, LOAD_PRIVATE_KEY);
413                 goto out;
414         }
415         ret = mmap_full_file(key_file, O_RDONLY, &map, &map_size, NULL);
416         if (ret < 0)
417                 goto out;
418         ret = is_ssh_rsa_key(map, map_size);
419         if (!ret) {
420                 ret = para_munmap(map, map_size);
421                 map = NULL;
422                 if (ret < 0)
423                         goto out;
424                 ret = get_openssl_key(key_file, &key->rsa, LOAD_PUBLIC_KEY);
425                 goto out;
426         }
427         cp = map + ret;
428         PARA_INFO_LOG("decoding public rsa-ssh key %s\n", key_file);
429         ret = -ERRNO_TO_PARA_ERROR(EOVERFLOW);
430         if (map_size > INT_MAX / 4)
431                 goto out;
432         blob_size = 2 * map_size;
433         blob = para_malloc(blob_size);
434         ret = uudecode(cp, blob, blob_size);
435         if (ret < 0)
436                 goto out;
437         ret = read_rsa_bignums(blob, ret, &key->rsa);
438         if (ret < 0)
439                 goto out;
440         ret = RSA_size(key->rsa);
441 out:
442         ret2 = para_munmap(map, map_size);
443         if (ret >= 0 && ret2 < 0)
444                 ret = ret2;
445         if (ret < 0) {
446                 free(key);
447                 result = NULL;
448                 PARA_ERROR_LOG("key %s: %s\n", key_file, para_strerror(-ret));
449         } else
450                 *result = key;
451         free(blob);
452         return ret;
453 }
454
455 /**
456  * Deallocate an asymmetric key structure.
457  *
458  * \param key Pointer to the key structure to free.
459  *
460  * This must be called for any key obtained by get_asymmetric_key().
461  */
462 void free_asymmetric_key(struct asymmetric_key *key)
463 {
464         if (!key)
465                 return;
466         RSA_free(key->rsa);
467         free(key);
468 }
469
470 /**
471  * Decrypt a buffer using a private key.
472  *
473  * \param key_file Full path of the key.
474  * \param outbuf The output buffer.
475  * \param inbuf The encrypted input buffer.
476  * \param inlen The length of \a inbuf in bytes.
477  *
478  * The \a outbuf must be large enough to hold at least \a rsa_inlen bytes.
479  *
480  * \return The size of the recovered plaintext on success, negative on errors.
481  *
482  * \sa RSA_private_decrypt(3)
483  **/
484 int priv_decrypt(const char *key_file, unsigned char *outbuf,
485                 unsigned char *inbuf, int inlen)
486 {
487         struct asymmetric_key *priv;
488         int ret;
489
490         if (inlen < 0)
491                 return -E_RSA;
492         ret = get_asymmetric_key(key_file, LOAD_PRIVATE_KEY, &priv);
493         if (ret < 0)
494                 return ret;
495         /*
496          * RSA is vulnerable to timing attacks. Generate a random blinding
497          * factor to protect against this kind of attack.
498          */
499         ret = -E_BLINDING;
500         if (RSA_blinding_on(priv->rsa, NULL) == 0)
501                 goto out;
502         ret = RSA_private_decrypt(inlen, inbuf, outbuf, priv->rsa,
503                 RSA_PKCS1_OAEP_PADDING);
504         RSA_blinding_off(priv->rsa);
505         if (ret <= 0)
506                 ret = -E_DECRYPT;
507 out:
508         free_asymmetric_key(priv);
509         return ret;
510 }
511
512 /**
513  * Encrypt a buffer using an RSA key
514  *
515  * \param pub: The public key.
516  * \param inbuf The input buffer.
517  * \param len The length of \a inbuf.
518  * \param outbuf The output buffer.
519  *
520  * \return The size of the encrypted data on success, negative on errors.
521  *
522  * \sa RSA_public_encrypt(3)
523  */
524 int pub_encrypt(struct asymmetric_key *pub, unsigned char *inbuf,
525                 unsigned len, unsigned char *outbuf)
526 {
527         int ret, flen = len; /* RSA_public_encrypt expects a signed int */
528
529         if (flen < 0)
530                 return -E_ENCRYPT;
531         ret = RSA_public_encrypt(flen, inbuf, outbuf, pub->rsa,
532                 RSA_PKCS1_OAEP_PADDING);
533         return ret < 0? -E_ENCRYPT : ret;
534 }
535
536 struct stream_cipher {
537         RC4_KEY key;
538 };
539
540 /**
541  * Allocate and initialize a stream cipher structure.
542  *
543  * \param data The key.
544  * \param len The size of the key.
545  *
546  * \return A new stream cipher structure.
547  */
548 struct stream_cipher *sc_new(const unsigned char *data, int len)
549 {
550         struct stream_cipher *sc = para_malloc(sizeof(*sc));
551         RC4_set_key(&sc->key, len, data);
552         return sc;
553 }
554
555 /**
556  * Deallocate a stream cipher structure.
557  *
558  * \param sc A stream cipher previously obtained by sc_new().
559  */
560 void sc_free(struct stream_cipher *sc)
561 {
562         free(sc);
563 }
564
565 /**
566  * The RC4() implementation of openssl apparently reads and writes data in
567  * blocks of 8 bytes. So we have to make sure our buffer sizes are a multiple
568  * of this.
569  */
570 #define RC4_ALIGN 8
571
572 /**
573  * Encrypt and send a buffer.
574  *
575  * \param scc The context.
576  * \param buf The buffer to send.
577  * \param len The size of \a buf in bytes.
578  *
579  * \return The return value of the underyling call to write_all().
580  *
581  * \sa \ref write_all(), RC4(3).
582  */
583 int sc_send_bin_buffer(struct stream_cipher_context *scc, const char *buf,
584                 size_t len)
585 {
586         int ret;
587         unsigned char *tmp;
588         static unsigned char remainder[RC4_ALIGN];
589         size_t l1 = ROUND_DOWN(len, RC4_ALIGN), l2 = ROUND_UP(len, RC4_ALIGN);
590
591         assert(len);
592         tmp = para_malloc(l2);
593         RC4(&scc->send->key, l1, (const unsigned char *)buf, tmp);
594         if (len > l1) {
595                 memcpy(remainder, buf + l1, len - l1);
596                 RC4(&scc->send->key, len - l1, remainder, tmp + l1);
597         }
598         ret = write_all(scc->fd, (char *)tmp, &len);
599         free(tmp);
600         return ret;
601 }
602
603 /**
604  * Encrypt and send a \p NULL-terminated buffer.
605  *
606  * \param scc The context.
607  * \param buf The buffer to send.
608  *
609  * \return The return value of the underyling call to sc_send_bin_buffer().
610  */
611 int sc_send_buffer(struct stream_cipher_context *scc, const char *buf)
612 {
613         return sc_send_bin_buffer(scc, buf, strlen(buf));
614 }
615
616 /**
617  * Format, encrypt and send a buffer.
618  *
619  * \param scc The context.
620  * \param fmt A format string.
621  *
622  * \return The return value of the underyling call to sc_send_buffer().
623  */
624 __printf_2_3 int sc_send_va_buffer(struct stream_cipher_context *scc,
625                 const char *fmt, ...)
626 {
627         char *msg;
628         int ret;
629
630         PARA_VSPRINTF(fmt, msg);
631         ret = sc_send_buffer(scc, msg);
632         free(msg);
633         return ret;
634 }
635
636 /**
637  * Receive a buffer and decrypt it.
638  *
639  * \param scc The context.
640  * \param buf The buffer to write the decrypted data to.
641  * \param size The size of \a buf.
642  *
643  * \return The number of bytes received on success, negative on errors, zero if
644  * the peer has performed an orderly shutdown.
645  *
646  * \sa recv(2), RC4(3).
647  */
648 int sc_recv_bin_buffer(struct stream_cipher_context *scc, char *buf,
649                 size_t size)
650 {
651         unsigned char *tmp = para_malloc(size);
652         ssize_t ret = recv(scc->fd, tmp, size, 0);
653
654         if (ret > 0)
655                 RC4(&scc->recv->key, ret, tmp, (unsigned char *)buf);
656         else if (ret < 0)
657                 ret = -ERRNO_TO_PARA_ERROR(errno);
658         free(tmp);
659         return ret;
660 }
661
662 /**
663  * Receive a buffer, decrypt it and write terminating NULL byte.
664  *
665  * \param scc The context.
666  * \param buf The buffer to write the decrypted data to.
667  * \param size The size of \a buf.
668  *
669  * Read at most \a size - 1 bytes from file descriptor given by \a scc, decrypt
670  * the received data and write a NULL byte at the end of the decrypted data.
671  *
672  * \return The return value of the underlying call to \ref
673  * sc_recv_bin_buffer().
674  */
675 int sc_recv_buffer(struct stream_cipher_context *scc, char *buf, size_t size)
676 {
677         int n;
678
679         assert(size);
680         n = sc_recv_bin_buffer(scc, buf, size - 1);
681         if (n >= 0)
682                 buf[n] = '\0';
683         else
684                 *buf = '\0';
685         return n;
686 }
687
688 /**
689  * Compute the hash of the given input data.
690  *
691  * \param data Pointer to the data to compute the hash value from.
692  * \param len The length of \a data in bytes.
693  * \param hash Result pointer.
694  *
695  * \a hash must point to an area at least \p HASH_SIZE bytes large.
696  *
697  * \sa sha(3), openssl(1).
698  * */
699 void hash_function(const char *data, unsigned long len, unsigned char *hash)
700 {
701         SHA_CTX c;
702         SHA1_Init(&c);
703         SHA1_Update(&c, data, len);
704         SHA1_Final(hash, &c);
705 }