diff --git a/include/crypto.h b/include/crypto.h index 958539f..909dbe5 100644 --- a/include/crypto.h +++ b/include/crypto.h @@ -9,6 +9,7 @@ struct crypto_cipher; struct crypto_hash; struct crypto_rsa_pub_key; struct crypto_rsa_priv_key; +struct vec; enum crypto_cipher_type { CRYPTO_CIPHER_INVALID = 0, @@ -51,10 +52,10 @@ struct crypto_cipher* crypto_cipher_new(const uint8_t* enc_key, const uint8_t* dec_key, enum crypto_cipher_type type); void crypto_cipher_del(struct crypto_cipher* self); -bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst, - const uint8_t* src, size_t len); -bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst, +bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst, const uint8_t* src, size_t len); +ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst, + size_t dst_size, const uint8_t* src, size_t len); void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad, size_t len); @@ -74,6 +75,9 @@ void crypto_hash_digest(struct crypto_hash* self, uint8_t* dst, struct crypto_rsa_pub_key* crypto_rsa_pub_key_new(void); void crypto_rsa_pub_key_del(struct crypto_rsa_pub_key*); +// Returns length in bytes +size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key); + struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus, const uint8_t* exponent, size_t size); diff --git a/src/crypto-nettle.c b/src/crypto-nettle.c index e697a15..2ff281f 100644 --- a/src/crypto-nettle.c +++ b/src/crypto-nettle.c @@ -1,5 +1,6 @@ #include "crypto.h" #include "neatvnc.h" +#include "vec.h" #include #include @@ -15,10 +16,15 @@ #include #include #include +#include // TODO: This is linux specific #include +#define UDIV_UP(a, b) (((a) + (b) - 1) / (b)) + +struct vec; + struct crypto_key { int g; mpz_t p; @@ -46,10 +52,13 @@ struct crypto_cipher { uint8_t mac[16]; - bool (*encrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src, - size_t len); - bool (*decrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src, - size_t len); + uint8_t read_buffer[65536]; + uint8_t read_buffer_len; + + bool (*encrypt)(struct crypto_cipher*, struct vec* dst, + const uint8_t* src, size_t len); + ssize_t (*decrypt)(struct crypto_cipher*, uint8_t* dst, + size_t dst_size, const uint8_t* src, size_t len); }; struct crypto_hash { @@ -153,6 +162,7 @@ static size_t crypto_export(uint8_t* dst, size_t dst_size, const mpz_t n) assert(bytesize <= dst_size); + memset(dst, 0, dst_size); mpz_export(dst + dst_size - bytesize, &bytesize, order, unit_size, endian, skip_bits, n); @@ -273,17 +283,20 @@ struct crypto_key* crypto_derive_shared_secret( } static bool crypto_cipher_aes128_ecb_encrypt(struct crypto_cipher* self, - uint8_t* dst, const uint8_t* src, size_t len) + struct vec* dst, const uint8_t* src, size_t len) { - aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst, src); + vec_reserve(dst, len); + aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst->data, src); + dst->len = len; return true; } -static bool crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self, - uint8_t* dst, const uint8_t* src, size_t len) +static ssize_t crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self, + uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len) { + assert(dst_size <= len); aes128_decrypt(&self->dec_ctx.aes128_ecb, len, dst, src); - return true; + return len; } static struct crypto_cipher* crypto_cipher_new_aes128_ecb( @@ -321,36 +334,96 @@ static void crypto_aes_eax_update_nonce(struct crypto_aes_eax* self) } static bool crypto_cipher_aes_eax_encrypt(struct crypto_cipher* self, - uint8_t* dst, const uint8_t* src, size_t len) + struct vec* dst, const uint8_t* src, size_t len) { - crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax); +// size_t msg_max_size = 65535; + size_t msg_max_size = 8192; + size_t n_msg = UDIV_UP(len, msg_max_size); - if (self->ad) { - nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, - self->ad_len, self->ad); - } else { - nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 0, NULL); + vec_clear(dst); + vec_reserve(dst, len + n_msg * (2 + 16)); + + for (size_t i = 0; i < n_msg; ++i) { + size_t msglen = MIN(len - i * msg_max_size, msg_max_size); + uint16_t msglen_be = htons(msglen); + nvnc_trace("msglen %zu", msglen); + + vec_append(dst, &msglen_be, sizeof(msglen_be)); + + crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax); + nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 2, + (uint8_t*)&msglen_be); + nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, msglen, + (uint8_t*)dst->data + dst->len, src + i * msg_max_size); + dst->len += msglen; + + uint8_t mac[16]; + nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, sizeof(mac), mac); + vec_append(dst, &mac, sizeof(mac)); } - nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, len, dst, src); - nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, 16, self->mac); + nvnc_trace("Encrypted buffer of size %zu", dst->len); + return true; } -static bool crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self, - uint8_t* dst, const uint8_t* src, size_t len) +// TODO: Clean up this mess +static ssize_t crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self, + uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len) { - crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax); - if (self->ad) { - nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx, - self->ad_len, self->ad); - } else { - nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx, 0, NULL); + size_t dst_index = 0; + size_t rem = len; + + while (rem) { + size_t space = sizeof(self->read_buffer) - self->read_buffer_len; + memcpy(self->read_buffer, src + len - rem, MIN(space, rem)); + self->read_buffer_len += len; + + rem -= MIN(space, rem); + + size_t index = 0; + for (;;) { + uint8_t* msg = &self->read_buffer[index]; + uint16_t msglen_be; + memcpy(&msglen_be, msg, 2); + size_t msglen = ntohs(msglen_be); + + if (self->read_buffer_len - index < msglen + 2) { + break; + } + + if (msglen > dst_size - dst_index) { + break; + } + + nvnc_trace("Got message of length: %zu", msglen); + + crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax); + nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx, + 2, (uint8_t*)&msglen_be); + + nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx, + msglen, dst + dst_index, msg + 2); + dst_index += msglen; + assert(dst_index <= len); + + uint8_t expected_mac[16]; + nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16, + expected_mac); + + uint8_t *mac = msg + 2 + msglen; + if (memcmp(expected_mac, mac, 16) != 0) + return -1; // Authentication failure + + index += msglen + 2 + 16; + } + + self->read_buffer_len -= index; + memmove(self->read_buffer, self->read_buffer + index, + self->read_buffer_len); } - nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx, len, dst, src); - nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16, self->mac); - return true; + return dst_index; } static struct crypto_cipher* crypto_cipher_new_aes_eax(const uint8_t* enc_key, @@ -390,16 +463,16 @@ void crypto_cipher_del(struct crypto_cipher* self) free(self); } -bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst, +bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst, const uint8_t* src, size_t len) { return self->encrypt(self, dst, src, len); } -bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst, - const uint8_t* src, size_t len) +ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst, + size_t dst_size, const uint8_t* src, size_t len) { - return self->decrypt(self, dst, src, len); + return self->decrypt(self, dst, dst_size, src, len); } void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad, @@ -480,9 +553,12 @@ struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus, if (!self) return NULL; + rsa_public_key_init(&self->key); + mpz_init(self->key.n); crypto_import(self->key.n, modulus, size); + mpz_init(self->key.e); crypto_import(self->key.e, exponent, size); - self->key.size = rsa_public_key_prepare(&self->key); + rsa_public_key_prepare(&self->key); return self; } @@ -519,6 +595,11 @@ void crypto_rsa_priv_key_del(struct crypto_rsa_priv_key* self) free(self); } +size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key) +{ + return key->key.size; +} + static void generate_random_for_rsa(void* random_ctx, size_t len, uint8_t* dst) { getrandom(dst, len, 0); diff --git a/src/server.c b/src/server.c index e03f648..68d669e 100644 --- a/src/server.c +++ b/src/server.c @@ -462,7 +462,7 @@ static int on_apple_dh_response(struct nvnc_client* client) char username[128] = {}; char* password = username + 64; - crypto_cipher_decrypt(cipher, (uint8_t*)username, + crypto_cipher_decrypt(cipher, (uint8_t*)username, sizeof(username), msg->encrypted_credentials, sizeof(username)); username[63] = '\0'; username[127] = '\0'; diff --git a/src/stream-tcp.c b/src/stream-tcp.c index 07af480..9b06703 100644 --- a/src/stream-tcp.c +++ b/src/stream-tcp.c @@ -39,11 +39,10 @@ static_assert(sizeof(struct stream) <= STREAM_ALLOC_SIZE, static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload) { - uint8_t* ciphertext = malloc(payload->size); - assert(ciphertext); - crypto_cipher_encrypt(self->cipher, ciphertext, payload->payload, + struct vec ciphertext = {}; + crypto_cipher_encrypt(self->cipher, &ciphertext, payload->payload, payload->size); - struct rcbuf* result = rcbuf_new(ciphertext, payload->size); + struct rcbuf* result = rcbuf_new(ciphertext.data, ciphertext.len); rcbuf_unref(payload); return result; } @@ -217,12 +216,17 @@ static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) if (rc > 0) self->bytes_received += rc; - if (rc > 0 && self->cipher && !crypto_cipher_decrypt(self->cipher, dst, - read_buffer, rc)) { - nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!"); - stream__remote_closed(self); - errno = EPROTO; - return -1; + if (rc > 0 && self->cipher) { + nvnc_trace("Got cipher text of length %zd", rc); + ssize_t len = crypto_cipher_decrypt(self->cipher, dst, size, + read_buffer, rc); + if (len < 0) { + nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!"); + stream__remote_closed(self); + errno = EPROTO; + return -1; + } + rc = len; } return rc;