Add support for ssh-rsa keys.
authorAndre Noll <maan@systemlinux.org>
Thu, 25 Nov 2010 07:33:20 +0000 (08:33 +0100)
committerAndre Noll <maan@systemlinux.org>
Tue, 26 Apr 2011 20:34:22 +0000 (22:34 +0200)
This allows to use standard ssh keys (that is, keys generated with
ssh-keygen) for the challenge/response authentication method of
paraslash. Only RSA keys without password protection are supported
at the moment.

Since we want that both openssl and ssh keys just work, we introduce
the helper function is_ssh_rsa_key(). It looks at the first few bytes
of the key to decide which type of public key we have. For openssl
keys, we just call openssl's EVP_PKEY_get1_RSA() and be done. Private
keys generated by ssh-keygen do not differ from keys generated by
"openssl rsa" and need no special treatment either.

However, public ssh rsa keys are stored differently, as an uuencoded
byte stream. So this patch adds functions that decode a given buffer
via base64 or uudecode. The two rsa public parameters (modulus and
exponent) are then read from the decoded buffer using BN_bin2bn().

crypt.c
error.h

diff --git a/crypt.c b/crypt.c
index 8e1814d..6f7e611 100644 (file)
--- a/crypt.c
+++ b/crypt.c
@@ -14,6 +14,7 @@
 #include <openssl/rc4.h>
 #include <openssl/pem.h>
 #include <openssl/sha.h>
 #include <openssl/rc4.h>
 #include <openssl/pem.h>
 #include <openssl/sha.h>
+#include <openssl/bn.h>
 
 #include "para.h"
 #include "error.h"
 
 #include "para.h"
 #include "error.h"
@@ -107,6 +108,284 @@ static EVP_PKEY *load_key(const char *file, int private)
        return pkey;
 }
 
        return pkey;
 }
 
+static int get_openssl_key(const char *key_file, RSA **rsa, int private)
+{
+       EVP_PKEY *key = load_key(key_file, private);
+
+       if (!key)
+               return (private == LOAD_PRIVATE_KEY)? -E_PRIVATE_KEY
+                       : -E_PUBLIC_KEY;
+       *rsa = EVP_PKEY_get1_RSA(key);
+       EVP_PKEY_free(key);
+       if (!*rsa)
+               return -E_RSA;
+       return RSA_size(*rsa);
+}
+
+#define KEY_TYPE_TXT "ssh-rsa"
+
+/* check if it is an ssh rsa key */
+static size_t is_ssh_rsa_key(char *data, size_t size)
+{
+       char *cp;
+
+       if (size < strlen(KEY_TYPE_TXT) + 2)
+               return 0;
+       cp = memchr(data, ' ', size);
+       if (cp == NULL)
+               return 0;
+       if (strncmp(KEY_TYPE_TXT, data, strlen(KEY_TYPE_TXT)))
+               return 0;
+       cp++;
+       if (cp >= data + size)
+               return 0;
+       if (*cp == '\0')
+               return 0;
+       return cp - data;
+}
+
+/*
+ * This base64/uudecode stuff below is taken from openssh-5.2p1, Copyright (c)
+ * 1996 by Internet Software Consortium.  Portions Copyright (c) 1995 by
+ * International Business Machines, Inc.
+ */
+
+static const char Base64[] =
+       "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static const char Pad64 = '=';
+/*
+ * Skips all whitespace anywhere. Converts characters, four at a time, starting
+ * at (or after) src from base - 64 numbers into three 8 bit bytes in the
+ * target area. it returns the number of data bytes stored at the target, or -1
+ * on error.
+ */
+static int base64_decode(char const *src, unsigned char *target, size_t targsize)
+{
+       unsigned int tarindex, state;
+       int ch;
+       char *pos;
+
+       state = 0;
+       tarindex = 0;
+
+       while ((ch = *src++) != '\0') {
+               if (para_isspace(ch)) /* Skip whitespace anywhere. */
+                       continue;
+
+               if (ch == Pad64)
+                       break;
+
+               pos = strchr(Base64, ch);
+               if (pos == 0) /* A non-base64 character. */
+                       return -1;
+
+               switch (state) {
+               case 0:
+                       if (target) {
+                               if (tarindex >= targsize)
+                                       return (-1);
+                               target[tarindex] = (pos - Base64) << 2;
+                       }
+                       state = 1;
+                       break;
+               case 1:
+                       if (target) {
+                               if (tarindex + 1 >= targsize)
+                                       return (-1);
+                               target[tarindex]   |=  (pos - Base64) >> 4;
+                               target[tarindex+1]  = ((pos - Base64) & 0x0f)
+                                                       << 4 ;
+                       }
+                       tarindex++;
+                       state = 2;
+                       break;
+               case 2:
+                       if (target) {
+                               if (tarindex + 1 >= targsize)
+                                       return (-1);
+                               target[tarindex]   |=  (pos - Base64) >> 2;
+                               target[tarindex+1]  = ((pos - Base64) & 0x03)
+                                                       << 6;
+                       }
+                       tarindex++;
+                       state = 3;
+                       break;
+               case 3:
+                       if (target) {
+                               if (tarindex >= targsize)
+                                       return (-1);
+                               target[tarindex] |= (pos - Base64);
+                       }
+                       tarindex++;
+                       state = 0;
+                       break;
+               }
+       }
+
+       /*
+        * We are done decoding Base-64 chars.  Let's see if we ended
+        * on a byte boundary, and/or with erroneous trailing characters.
+        */
+
+       if (ch == Pad64) {              /* We got a pad char. */
+               ch = *src++;            /* Skip it, get next. */
+               switch (state) {
+               case 0:         /* Invalid = in first position */
+               case 1:         /* Invalid = in second position */
+                       return (-1);
+
+               case 2:         /* Valid, means one byte of info */
+                       /* Skip any number of spaces. */
+                       for (; ch != '\0'; ch = *src++)
+                               if (!isspace(ch))
+                                       break;
+                       /* Make sure there is another trailing = sign. */
+                       if (ch != Pad64)
+                               return (-1);
+                       ch = *src++;            /* Skip the = */
+                       /* Fall through to "single trailing =" case. */
+                       /* FALLTHROUGH */
+
+               case 3:         /* Valid, means two bytes of info */
+                       /*
+                        * We know this char is an =.  Is there anything but
+                        * whitespace after it?
+                        */
+                       for (; ch != '\0'; ch = *src++)
+                               if (!isspace(ch))
+                                       return (-1);
+
+                       /*
+                        * Now make sure for cases 2 and 3 that the "extra"
+                        * bits that slopped past the last full byte were
+                        * zeros.  If we don't check them, they become a
+                        * subliminal channel.
+                        */
+                       if (target && target[tarindex] != 0)
+                               return (-1);
+               }
+       } else {
+               /*
+                * We ended by seeing the end of the string.  Make sure we
+                * have no partial bytes lying around.
+                */
+               if (state != 0)
+                       return (-1);
+       }
+
+       return (tarindex);
+}
+
+static int uudecode(const char *src, unsigned char *target, size_t targsize)
+{
+       int len;
+       char *encoded, *p;
+
+       /* copy the 'readonly' source */
+       encoded = para_strdup(src);
+       /* skip whitespace and data */
+       for (p = encoded; *p == ' ' || *p == '\t'; p++)
+               ;
+       for (; *p != '\0' && *p != ' ' && *p != '\t'; p++)
+               ;
+       /* and remove trailing whitespace because base64_decode needs this */
+       *p = '\0';
+       len = base64_decode(encoded, target, targsize);
+       free(encoded);
+       return len >= 0? len : -E_BASE64;
+}
+
+/*
+ * The public key loading functions below were inspired by corresponding code
+ * of openssh-5.2p1, Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo,
+ * Finland. However, not much of the original code remains.
+ */
+
+
+/*
+ * Can not use the inline functions of portable_io.h here because the byte
+ * order is different.
+ */
+static uint32_t read_ssh_u32(const void *vp)
+{
+       const unsigned char *p = (const unsigned char *)vp;
+       uint32_t v;
+
+       v  = (uint32_t)p[0] << 24;
+       v |= (uint32_t)p[1] << 16;
+       v |= (uint32_t)p[2] << 8;
+       v |= (uint32_t)p[3];
+
+       return v;
+}
+
+static int read_bignum(const unsigned char *buf, size_t len, BIGNUM **result)
+{
+       const unsigned char *p = buf, *end = buf + len;
+       uint32_t bnsize;
+       BIGNUM *bn;
+
+       if (p + 4 < p)
+               return -E_BIGNUM;
+       if (p + 4 > end)
+               return -E_BIGNUM;
+       bnsize = read_ssh_u32(p);
+       PARA_DEBUG_LOG("bnsize: %u\n", bnsize);
+       p += 4;
+       if (p + bnsize < p)
+               return -E_BIGNUM;
+       if (p + bnsize > end)
+               return -E_BIGNUM;
+       if (bnsize > 8192)
+               return -E_BIGNUM;
+       bn = BN_bin2bn(p, bnsize, NULL);
+       if (!bn)
+               return -E_BIGNUM;
+       *result = bn;
+       return bnsize + 4;
+}
+
+static int read_rsa_bignums(const unsigned char *blob, int blen, RSA **result)
+{
+       int ret;
+       RSA *rsa;
+       const unsigned char *p = blob, *end = blob + blen;
+       uint32_t rlen;
+
+       *result = NULL;
+       if (p + 4 > end)
+               return -E_BIGNUM;
+       rlen = read_ssh_u32(p);
+       p += 4;
+       if (p + rlen < p)
+               return -E_BIGNUM;
+       if (p + rlen > end)
+               return -E_BIGNUM;
+       if (rlen < strlen(KEY_TYPE_TXT))
+               return -E_BIGNUM;
+       PARA_DEBUG_LOG("type: %s, rlen: %d\n", p, rlen);
+       if (strncmp((char *)p, KEY_TYPE_TXT, strlen(KEY_TYPE_TXT)))
+               return -E_BIGNUM;
+       p += rlen;
+
+       rsa = RSA_new();
+       if (!rsa)
+               return -E_BIGNUM;
+       ret = read_bignum(p, end - p, &rsa->e);
+       if (ret < 0)
+               goto fail;
+       p += ret;
+       ret = read_bignum(p, end - p, &rsa->n);
+       if (ret < 0)
+               goto fail;
+       *result = rsa;
+       return 1;
+fail:
+       if (rsa)
+               RSA_free(rsa);
+       return ret;
+}
+
 /**
  * Read an asymmetric key from a file.
  *
 /**
  * Read an asymmetric key from a file.
  *
@@ -121,21 +400,56 @@ static EVP_PKEY *load_key(const char *file, int private)
 int get_asymmetric_key(const char *key_file, int private,
                struct asymmetric_key **result)
 {
 int get_asymmetric_key(const char *key_file, int private,
                struct asymmetric_key **result)
 {
-       struct asymmetric_key *key;
-       RSA *rsa;
-       EVP_PKEY *pkey = load_key(key_file, private);
+       struct asymmetric_key *key = NULL;
+       void *map = NULL;
+       unsigned char *blob = NULL;
+       size_t map_size, blob_size;
+       int ret, ret2;
+       char *cp;
 
 
-       if (!pkey)
-               return (private == LOAD_PRIVATE_KEY)? -E_PRIVATE_KEY
-                       : -E_PUBLIC_KEY;
-       rsa = EVP_PKEY_get1_RSA(pkey);
-       EVP_PKEY_free(pkey);
-       if (!rsa)
-               return -E_RSA;
        key = para_malloc(sizeof(*key));
        key = para_malloc(sizeof(*key));
-       key->rsa = rsa;
-       *result = key;
-       return RSA_size(rsa);
+       if (private) {
+               ret = get_openssl_key(key_file, &key->rsa, LOAD_PRIVATE_KEY);
+               goto out;
+       }
+       ret = mmap_full_file(key_file, O_RDONLY, &map, &map_size, NULL);
+       if (ret < 0)
+               goto out;
+       ret = is_ssh_rsa_key(map, map_size);
+       if (!ret) {
+               ret = para_munmap(map, map_size);
+               map = NULL;
+               if (ret < 0)
+                       goto out;
+               ret = get_openssl_key(key_file, &key->rsa, LOAD_PUBLIC_KEY);
+               goto out;
+       }
+       cp = map + ret;
+       PARA_INFO_LOG("decoding public rsa-ssh key %s\n", key_file);
+       ret = -ERRNO_TO_PARA_ERROR(EOVERFLOW);
+       if (map_size > INT_MAX / 4)
+               goto out;
+       blob_size = 2 * map_size;
+       blob = para_malloc(blob_size);
+       ret = uudecode(cp, blob, blob_size);
+       if (ret < 0)
+               goto out;
+       ret = read_rsa_bignums(blob, ret, &key->rsa);
+       if (ret < 0)
+               goto out;
+       ret = RSA_size(key->rsa);
+out:
+       ret2 = para_munmap(map, map_size);
+       if (ret >= 0 && ret2 < 0)
+               ret = ret2;
+       if (ret < 0) {
+               free(key);
+               result = NULL;
+               PARA_ERROR_LOG("key %s: %s\n", key_file, para_strerror(-ret));
+       } else
+               *result = key;
+       free(blob);
+       return ret;
 }
 
 /**
 }
 
 /**
diff --git a/error.h b/error.h
index 2e32c24..c3cde04 100644 (file)
--- a/error.h
+++ b/error.h
@@ -364,6 +364,8 @@ extern const char **para_errlist[];
        PARA_ERROR(DECRYPT, "decrypt error"), \
        PARA_ERROR(BLINDING, "failed to activate key blinding"), \
        PARA_ERROR(KEY_PERM, "unprotected private key"), \
        PARA_ERROR(DECRYPT, "decrypt error"), \
        PARA_ERROR(BLINDING, "failed to activate key blinding"), \
        PARA_ERROR(KEY_PERM, "unprotected private key"), \
+       PARA_ERROR(BASE64, "failed to base64-decode ssh private key"), \
+       PARA_ERROR(BIGNUM, "bignum error"), \
 
 
 #define COMMAND_ERRORS \
 
 
 #define COMMAND_ERRORS \