crypto: Integrate message handling into cipher

rsa-aes
Andri Yngvason 2023-08-15 22:56:42 +00:00
parent 8cb4910d76
commit ff5ca722b1
4 changed files with 137 additions and 48 deletions

View File

@ -9,6 +9,7 @@ struct crypto_cipher;
struct crypto_hash; struct crypto_hash;
struct crypto_rsa_pub_key; struct crypto_rsa_pub_key;
struct crypto_rsa_priv_key; struct crypto_rsa_priv_key;
struct vec;
enum crypto_cipher_type { enum crypto_cipher_type {
CRYPTO_CIPHER_INVALID = 0, CRYPTO_CIPHER_INVALID = 0,
@ -51,10 +52,10 @@ struct crypto_cipher* crypto_cipher_new(const uint8_t* enc_key,
const uint8_t* dec_key, enum crypto_cipher_type type); const uint8_t* dec_key, enum crypto_cipher_type type);
void crypto_cipher_del(struct crypto_cipher* self); void crypto_cipher_del(struct crypto_cipher* self);
bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst, bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst,
const uint8_t* src, size_t len);
bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
const uint8_t* src, size_t len); const uint8_t* src, size_t len);
ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
size_t dst_size, const uint8_t* src, size_t len);
void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad, void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad,
size_t len); size_t len);
@ -74,6 +75,9 @@ void crypto_hash_digest(struct crypto_hash* self, uint8_t* dst,
struct crypto_rsa_pub_key* crypto_rsa_pub_key_new(void); struct crypto_rsa_pub_key* crypto_rsa_pub_key_new(void);
void crypto_rsa_pub_key_del(struct crypto_rsa_pub_key*); void crypto_rsa_pub_key_del(struct crypto_rsa_pub_key*);
// Returns length in bytes
size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key);
struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus, struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus,
const uint8_t* exponent, size_t size); const uint8_t* exponent, size_t size);

View File

@ -1,5 +1,6 @@
#include "crypto.h" #include "crypto.h"
#include "neatvnc.h" #include "neatvnc.h"
#include "vec.h"
#include <gmp.h> #include <gmp.h>
#include <nettle/base64.h> #include <nettle/base64.h>
@ -15,10 +16,15 @@
#include <stdbool.h> #include <stdbool.h>
#include <string.h> #include <string.h>
#include <sys/param.h> #include <sys/param.h>
#include <arpa/inet.h>
// TODO: This is linux specific // TODO: This is linux specific
#include <sys/random.h> #include <sys/random.h>
#define UDIV_UP(a, b) (((a) + (b) - 1) / (b))
struct vec;
struct crypto_key { struct crypto_key {
int g; int g;
mpz_t p; mpz_t p;
@ -46,10 +52,13 @@ struct crypto_cipher {
uint8_t mac[16]; uint8_t mac[16];
bool (*encrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src, uint8_t read_buffer[65536];
size_t len); uint8_t read_buffer_len;
bool (*decrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src,
size_t len); bool (*encrypt)(struct crypto_cipher*, struct vec* dst,
const uint8_t* src, size_t len);
ssize_t (*decrypt)(struct crypto_cipher*, uint8_t* dst,
size_t dst_size, const uint8_t* src, size_t len);
}; };
struct crypto_hash { struct crypto_hash {
@ -153,6 +162,7 @@ static size_t crypto_export(uint8_t* dst, size_t dst_size, const mpz_t n)
assert(bytesize <= dst_size); assert(bytesize <= dst_size);
memset(dst, 0, dst_size);
mpz_export(dst + dst_size - bytesize, &bytesize, order, unit_size, mpz_export(dst + dst_size - bytesize, &bytesize, order, unit_size,
endian, skip_bits, n); endian, skip_bits, n);
@ -273,17 +283,20 @@ struct crypto_key* crypto_derive_shared_secret(
} }
static bool crypto_cipher_aes128_ecb_encrypt(struct crypto_cipher* self, static bool crypto_cipher_aes128_ecb_encrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len) struct vec* dst, const uint8_t* src, size_t len)
{ {
aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst, src); vec_reserve(dst, len);
aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst->data, src);
dst->len = len;
return true; return true;
} }
static bool crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self, static ssize_t crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len) uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len)
{ {
assert(dst_size <= len);
aes128_decrypt(&self->dec_ctx.aes128_ecb, len, dst, src); aes128_decrypt(&self->dec_ctx.aes128_ecb, len, dst, src);
return true; return len;
} }
static struct crypto_cipher* crypto_cipher_new_aes128_ecb( static struct crypto_cipher* crypto_cipher_new_aes128_ecb(
@ -321,36 +334,96 @@ static void crypto_aes_eax_update_nonce(struct crypto_aes_eax* self)
} }
static bool crypto_cipher_aes_eax_encrypt(struct crypto_cipher* self, static bool crypto_cipher_aes_eax_encrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len) struct vec* dst, const uint8_t* src, size_t len)
{ {
// size_t msg_max_size = 65535;
size_t msg_max_size = 8192;
size_t n_msg = UDIV_UP(len, msg_max_size);
vec_clear(dst);
vec_reserve(dst, len + n_msg * (2 + 16));
for (size_t i = 0; i < n_msg; ++i) {
size_t msglen = MIN(len - i * msg_max_size, msg_max_size);
uint16_t msglen_be = htons(msglen);
nvnc_trace("msglen %zu", msglen);
vec_append(dst, &msglen_be, sizeof(msglen_be));
crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax); crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax);
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 2,
(uint8_t*)&msglen_be);
nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, msglen,
(uint8_t*)dst->data + dst->len, src + i * msg_max_size);
dst->len += msglen;
if (self->ad) { uint8_t mac[16];
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, sizeof(mac), mac);
self->ad_len, self->ad); vec_append(dst, &mac, sizeof(mac));
} else {
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 0, NULL);
} }
nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, len, dst, src); nvnc_trace("Encrypted buffer of size %zu", dst->len);
nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, 16, self->mac);
return true; return true;
} }
static bool crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self, // TODO: Clean up this mess
uint8_t* dst, const uint8_t* src, size_t len) static ssize_t crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self,
uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len)
{ {
crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax); size_t dst_index = 0;
if (self->ad) { size_t rem = len;
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx,
self->ad_len, self->ad); while (rem) {
} else { size_t space = sizeof(self->read_buffer) - self->read_buffer_len;
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx, 0, NULL); memcpy(self->read_buffer, src + len - rem, MIN(space, rem));
self->read_buffer_len += len;
rem -= MIN(space, rem);
size_t index = 0;
for (;;) {
uint8_t* msg = &self->read_buffer[index];
uint16_t msglen_be;
memcpy(&msglen_be, msg, 2);
size_t msglen = ntohs(msglen_be);
if (self->read_buffer_len - index < msglen + 2) {
break;
} }
nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx, len, dst, src); if (msglen > dst_size - dst_index) {
nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16, self->mac); break;
return true; }
nvnc_trace("Got message of length: %zu", msglen);
crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax);
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx,
2, (uint8_t*)&msglen_be);
nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx,
msglen, dst + dst_index, msg + 2);
dst_index += msglen;
assert(dst_index <= len);
uint8_t expected_mac[16];
nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16,
expected_mac);
uint8_t *mac = msg + 2 + msglen;
if (memcmp(expected_mac, mac, 16) != 0)
return -1; // Authentication failure
index += msglen + 2 + 16;
}
self->read_buffer_len -= index;
memmove(self->read_buffer, self->read_buffer + index,
self->read_buffer_len);
}
return dst_index;
} }
static struct crypto_cipher* crypto_cipher_new_aes_eax(const uint8_t* enc_key, static struct crypto_cipher* crypto_cipher_new_aes_eax(const uint8_t* enc_key,
@ -390,16 +463,16 @@ void crypto_cipher_del(struct crypto_cipher* self)
free(self); free(self);
} }
bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst, bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst,
const uint8_t* src, size_t len) const uint8_t* src, size_t len)
{ {
return self->encrypt(self, dst, src, len); return self->encrypt(self, dst, src, len);
} }
bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst, ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
const uint8_t* src, size_t len) size_t dst_size, const uint8_t* src, size_t len)
{ {
return self->decrypt(self, dst, src, len); return self->decrypt(self, dst, dst_size, src, len);
} }
void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad, void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad,
@ -480,9 +553,12 @@ struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus,
if (!self) if (!self)
return NULL; return NULL;
rsa_public_key_init(&self->key);
mpz_init(self->key.n);
crypto_import(self->key.n, modulus, size); crypto_import(self->key.n, modulus, size);
mpz_init(self->key.e);
crypto_import(self->key.e, exponent, size); crypto_import(self->key.e, exponent, size);
self->key.size = rsa_public_key_prepare(&self->key); rsa_public_key_prepare(&self->key);
return self; return self;
} }
@ -519,6 +595,11 @@ void crypto_rsa_priv_key_del(struct crypto_rsa_priv_key* self)
free(self); free(self);
} }
size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key)
{
return key->key.size;
}
static void generate_random_for_rsa(void* random_ctx, size_t len, uint8_t* dst) static void generate_random_for_rsa(void* random_ctx, size_t len, uint8_t* dst)
{ {
getrandom(dst, len, 0); getrandom(dst, len, 0);

View File

@ -462,7 +462,7 @@ static int on_apple_dh_response(struct nvnc_client* client)
char username[128] = {}; char username[128] = {};
char* password = username + 64; char* password = username + 64;
crypto_cipher_decrypt(cipher, (uint8_t*)username, crypto_cipher_decrypt(cipher, (uint8_t*)username, sizeof(username),
msg->encrypted_credentials, sizeof(username)); msg->encrypted_credentials, sizeof(username));
username[63] = '\0'; username[63] = '\0';
username[127] = '\0'; username[127] = '\0';

View File

@ -39,11 +39,10 @@ static_assert(sizeof(struct stream) <= STREAM_ALLOC_SIZE,
static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload) static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload)
{ {
uint8_t* ciphertext = malloc(payload->size); struct vec ciphertext = {};
assert(ciphertext); crypto_cipher_encrypt(self->cipher, &ciphertext, payload->payload,
crypto_cipher_encrypt(self->cipher, ciphertext, payload->payload,
payload->size); payload->size);
struct rcbuf* result = rcbuf_new(ciphertext, payload->size); struct rcbuf* result = rcbuf_new(ciphertext.data, ciphertext.len);
rcbuf_unref(payload); rcbuf_unref(payload);
return result; return result;
} }
@ -217,13 +216,18 @@ static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size)
if (rc > 0) if (rc > 0)
self->bytes_received += rc; self->bytes_received += rc;
if (rc > 0 && self->cipher && !crypto_cipher_decrypt(self->cipher, dst, if (rc > 0 && self->cipher) {
read_buffer, rc)) { nvnc_trace("Got cipher text of length %zd", rc);
ssize_t len = crypto_cipher_decrypt(self->cipher, dst, size,
read_buffer, rc);
if (len < 0) {
nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!"); nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!");
stream__remote_closed(self); stream__remote_closed(self);
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
rc = len;
}
return rc; return rc;
} }