crypto: Integrate message handling into cipher

pull/100/head
Andri Yngvason 2023-08-15 22:56:42 +00:00
parent c12c1c800a
commit 71aa5acfde
4 changed files with 137 additions and 48 deletions

View File

@ -9,6 +9,7 @@ struct crypto_cipher;
struct crypto_hash;
struct crypto_rsa_pub_key;
struct crypto_rsa_priv_key;
struct vec;
enum crypto_cipher_type {
CRYPTO_CIPHER_INVALID = 0,
@ -51,10 +52,10 @@ struct crypto_cipher* crypto_cipher_new(const uint8_t* enc_key,
const uint8_t* dec_key, enum crypto_cipher_type type);
void crypto_cipher_del(struct crypto_cipher* self);
bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst,
const uint8_t* src, size_t len);
bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst,
const uint8_t* src, size_t len);
ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
size_t dst_size, const uint8_t* src, size_t len);
void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad,
size_t len);
@ -74,6 +75,9 @@ void crypto_hash_digest(struct crypto_hash* self, uint8_t* dst,
struct crypto_rsa_pub_key* crypto_rsa_pub_key_new(void);
void crypto_rsa_pub_key_del(struct crypto_rsa_pub_key*);
// Returns length in bytes
size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key);
struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus,
const uint8_t* exponent, size_t size);

View File

@ -1,5 +1,6 @@
#include "crypto.h"
#include "neatvnc.h"
#include "vec.h"
#include <gmp.h>
#include <nettle/base64.h>
@ -15,10 +16,15 @@
#include <stdbool.h>
#include <string.h>
#include <sys/param.h>
#include <arpa/inet.h>
// TODO: This is linux specific
#include <sys/random.h>
#define UDIV_UP(a, b) (((a) + (b) - 1) / (b))
struct vec;
struct crypto_key {
int g;
mpz_t p;
@ -46,10 +52,13 @@ struct crypto_cipher {
uint8_t mac[16];
bool (*encrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src,
size_t len);
bool (*decrypt)(struct crypto_cipher*, uint8_t* dst, const uint8_t* src,
size_t len);
uint8_t read_buffer[65536];
uint8_t read_buffer_len;
bool (*encrypt)(struct crypto_cipher*, struct vec* dst,
const uint8_t* src, size_t len);
ssize_t (*decrypt)(struct crypto_cipher*, uint8_t* dst,
size_t dst_size, const uint8_t* src, size_t len);
};
struct crypto_hash {
@ -153,6 +162,7 @@ static size_t crypto_export(uint8_t* dst, size_t dst_size, const mpz_t n)
assert(bytesize <= dst_size);
memset(dst, 0, dst_size);
mpz_export(dst + dst_size - bytesize, &bytesize, order, unit_size,
endian, skip_bits, n);
@ -273,17 +283,20 @@ struct crypto_key* crypto_derive_shared_secret(
}
static bool crypto_cipher_aes128_ecb_encrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len)
struct vec* dst, const uint8_t* src, size_t len)
{
aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst, src);
vec_reserve(dst, len);
aes128_encrypt(&self->enc_ctx.aes128_ecb, len, dst->data, src);
dst->len = len;
return true;
}
static bool crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len)
static ssize_t crypto_cipher_aes128_ecb_decrypt(struct crypto_cipher* self,
uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len)
{
assert(dst_size <= len);
aes128_decrypt(&self->dec_ctx.aes128_ecb, len, dst, src);
return true;
return len;
}
static struct crypto_cipher* crypto_cipher_new_aes128_ecb(
@ -321,36 +334,96 @@ static void crypto_aes_eax_update_nonce(struct crypto_aes_eax* self)
}
static bool crypto_cipher_aes_eax_encrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len)
struct vec* dst, const uint8_t* src, size_t len)
{
crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax);
// size_t msg_max_size = 65535;
size_t msg_max_size = 8192;
size_t n_msg = UDIV_UP(len, msg_max_size);
if (self->ad) {
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx,
self->ad_len, self->ad);
} else {
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 0, NULL);
vec_clear(dst);
vec_reserve(dst, len + n_msg * (2 + 16));
for (size_t i = 0; i < n_msg; ++i) {
size_t msglen = MIN(len - i * msg_max_size, msg_max_size);
uint16_t msglen_be = htons(msglen);
nvnc_trace("msglen %zu", msglen);
vec_append(dst, &msglen_be, sizeof(msglen_be));
crypto_aes_eax_update_nonce(&self->enc_ctx.aes_eax);
nettle_eax_aes128_update(&self->enc_ctx.aes_eax.ctx, 2,
(uint8_t*)&msglen_be);
nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, msglen,
(uint8_t*)dst->data + dst->len, src + i * msg_max_size);
dst->len += msglen;
uint8_t mac[16];
nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, sizeof(mac), mac);
vec_append(dst, &mac, sizeof(mac));
}
nettle_eax_aes128_encrypt(&self->enc_ctx.aes_eax.ctx, len, dst, src);
nettle_eax_aes128_digest(&self->enc_ctx.aes_eax.ctx, 16, self->mac);
nvnc_trace("Encrypted buffer of size %zu", dst->len);
return true;
}
static bool crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self,
uint8_t* dst, const uint8_t* src, size_t len)
// TODO: Clean up this mess
static ssize_t crypto_cipher_aes_eax_decrypt(struct crypto_cipher* self,
uint8_t* dst, size_t dst_size, const uint8_t* src, size_t len)
{
crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax);
if (self->ad) {
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx,
self->ad_len, self->ad);
} else {
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx, 0, NULL);
size_t dst_index = 0;
size_t rem = len;
while (rem) {
size_t space = sizeof(self->read_buffer) - self->read_buffer_len;
memcpy(self->read_buffer, src + len - rem, MIN(space, rem));
self->read_buffer_len += len;
rem -= MIN(space, rem);
size_t index = 0;
for (;;) {
uint8_t* msg = &self->read_buffer[index];
uint16_t msglen_be;
memcpy(&msglen_be, msg, 2);
size_t msglen = ntohs(msglen_be);
if (self->read_buffer_len - index < msglen + 2) {
break;
}
if (msglen > dst_size - dst_index) {
break;
}
nvnc_trace("Got message of length: %zu", msglen);
crypto_aes_eax_update_nonce(&self->dec_ctx.aes_eax);
nettle_eax_aes128_update(&self->dec_ctx.aes_eax.ctx,
2, (uint8_t*)&msglen_be);
nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx,
msglen, dst + dst_index, msg + 2);
dst_index += msglen;
assert(dst_index <= len);
uint8_t expected_mac[16];
nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16,
expected_mac);
uint8_t *mac = msg + 2 + msglen;
if (memcmp(expected_mac, mac, 16) != 0)
return -1; // Authentication failure
index += msglen + 2 + 16;
}
self->read_buffer_len -= index;
memmove(self->read_buffer, self->read_buffer + index,
self->read_buffer_len);
}
nettle_eax_aes128_decrypt(&self->dec_ctx.aes_eax.ctx, len, dst, src);
nettle_eax_aes128_digest(&self->dec_ctx.aes_eax.ctx, 16, self->mac);
return true;
return dst_index;
}
static struct crypto_cipher* crypto_cipher_new_aes_eax(const uint8_t* enc_key,
@ -390,16 +463,16 @@ void crypto_cipher_del(struct crypto_cipher* self)
free(self);
}
bool crypto_cipher_encrypt(struct crypto_cipher* self, uint8_t* dst,
bool crypto_cipher_encrypt(struct crypto_cipher* self, struct vec* dst,
const uint8_t* src, size_t len)
{
return self->encrypt(self, dst, src, len);
}
bool crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
const uint8_t* src, size_t len)
ssize_t crypto_cipher_decrypt(struct crypto_cipher* self, uint8_t* dst,
size_t dst_size, const uint8_t* src, size_t len)
{
return self->decrypt(self, dst, src, len);
return self->decrypt(self, dst, dst_size, src, len);
}
void crypto_cipher_set_ad(struct crypto_cipher* self, const uint8_t* ad,
@ -480,9 +553,12 @@ struct crypto_rsa_pub_key* crypto_rsa_pub_key_import(const uint8_t* modulus,
if (!self)
return NULL;
rsa_public_key_init(&self->key);
mpz_init(self->key.n);
crypto_import(self->key.n, modulus, size);
mpz_init(self->key.e);
crypto_import(self->key.e, exponent, size);
self->key.size = rsa_public_key_prepare(&self->key);
rsa_public_key_prepare(&self->key);
return self;
}
@ -519,6 +595,11 @@ void crypto_rsa_priv_key_del(struct crypto_rsa_priv_key* self)
free(self);
}
size_t crypto_rsa_pub_key_length(const struct crypto_rsa_pub_key* key)
{
return key->key.size;
}
static void generate_random_for_rsa(void* random_ctx, size_t len, uint8_t* dst)
{
getrandom(dst, len, 0);

View File

@ -462,7 +462,7 @@ static int on_apple_dh_response(struct nvnc_client* client)
char username[128] = {};
char* password = username + 64;
crypto_cipher_decrypt(cipher, (uint8_t*)username,
crypto_cipher_decrypt(cipher, (uint8_t*)username, sizeof(username),
msg->encrypted_credentials, sizeof(username));
username[63] = '\0';
username[127] = '\0';

View File

@ -39,11 +39,10 @@ static_assert(sizeof(struct stream) <= STREAM_ALLOC_SIZE,
static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload)
{
uint8_t* ciphertext = malloc(payload->size);
assert(ciphertext);
crypto_cipher_encrypt(self->cipher, ciphertext, payload->payload,
struct vec ciphertext = {};
crypto_cipher_encrypt(self->cipher, &ciphertext, payload->payload,
payload->size);
struct rcbuf* result = rcbuf_new(ciphertext, payload->size);
struct rcbuf* result = rcbuf_new(ciphertext.data, ciphertext.len);
rcbuf_unref(payload);
return result;
}
@ -217,12 +216,17 @@ static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size)
if (rc > 0)
self->bytes_received += rc;
if (rc > 0 && self->cipher && !crypto_cipher_decrypt(self->cipher, dst,
read_buffer, rc)) {
nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!");
stream__remote_closed(self);
errno = EPROTO;
return -1;
if (rc > 0 && self->cipher) {
nvnc_trace("Got cipher text of length %zd", rc);
ssize_t len = crypto_cipher_decrypt(self->cipher, dst, size,
read_buffer, rc);
if (len < 0) {
nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!");
stream__remote_closed(self);
errno = EPROTO;
return -1;
}
rc = len;
}
return rc;