server: Allow arbitrary RSA key length

rsa-aes
Andri Yngvason 2023-09-10 17:39:12 +00:00
parent 3c3de5f323
commit 89b759c838
1 changed files with 36 additions and 31 deletions

View File

@ -68,7 +68,7 @@
#define DEFAULT_NAME "Neat VNC" #define DEFAULT_NAME "Neat VNC"
#define SECURITY_TYPES_MAX 3 #define SECURITY_TYPES_MAX 3
#define RSA_AES_SERVER_KEY_LENGTH 256 #define APPLE_DH_SERVER_KEY_LENGTH 256
#define UDIV_UP(a, b) (((a) + (b) - 1) / (b)) #define UDIV_UP(a, b) (((a) + (b) - 1) / (b))
@ -409,11 +409,11 @@ static int apple_dh_send_public_key(struct nvnc_client* client)
crypto_derive_public_key(client->apple_dh_secret); crypto_derive_public_key(client->apple_dh_secret);
assert(pub); assert(pub);
uint8_t mod[RSA_AES_SERVER_KEY_LENGTH] = {}; uint8_t mod[APPLE_DH_SERVER_KEY_LENGTH] = {};
int mod_len = crypto_key_p(pub, mod, sizeof(mod)); int mod_len = crypto_key_p(pub, mod, sizeof(mod));
assert(mod_len == sizeof(mod)); assert(mod_len == sizeof(mod));
uint8_t q[RSA_AES_SERVER_KEY_LENGTH] = {}; uint8_t q[APPLE_DH_SERVER_KEY_LENGTH] = {};
int q_len = crypto_key_q(pub, q, sizeof(q)); int q_len = crypto_key_q(pub, q, sizeof(q));
assert(q_len == sizeof(q)); assert(q_len == sizeof(q));
@ -437,7 +437,7 @@ static int on_apple_dh_response(struct nvnc_client* client)
struct rfb_apple_dh_client_msg* msg = struct rfb_apple_dh_client_msg* msg =
(void*)(client->msg_buffer + client->buffer_index); (void*)(client->msg_buffer + client->buffer_index);
uint8_t p[RSA_AES_SERVER_KEY_LENGTH]; uint8_t p[APPLE_DH_SERVER_KEY_LENGTH];
int key_len = crypto_key_p(client->apple_dh_secret, p, sizeof(p)); int key_len = crypto_key_p(client->apple_dh_secret, p, sizeof(p));
assert(key_len == sizeof(p)); assert(key_len == sizeof(p));
@ -454,7 +454,7 @@ static int on_apple_dh_response(struct nvnc_client* client)
crypto_derive_shared_secret(client->apple_dh_secret, remote_key); crypto_derive_shared_secret(client->apple_dh_secret, remote_key);
assert(shared_secret); assert(shared_secret);
uint8_t shared_buf[RSA_AES_SERVER_KEY_LENGTH]; uint8_t shared_buf[APPLE_DH_SERVER_KEY_LENGTH];
crypto_key_q(shared_secret, shared_buf, sizeof(shared_buf)); crypto_key_q(shared_secret, shared_buf, sizeof(shared_buf));
crypto_key_del(shared_secret); crypto_key_del(shared_secret);
@ -504,22 +504,22 @@ static int rsa_aes_send_public_key(struct nvnc_client* client)
} }
assert(server->rsa_pub && server->rsa_priv); assert(server->rsa_pub && server->rsa_priv);
char buffer[sizeof(struct rfb_rsa_aes_pub_key_msg) + size_t key_len = crypto_rsa_pub_key_length(server->rsa_pub);
RSA_AES_SERVER_KEY_LENGTH * 2] = {}; size_t buf_len = sizeof(struct rfb_rsa_aes_pub_key_msg) + key_len * 2;
char* buffer = calloc(1, buf_len);
assert(buffer);
struct rfb_rsa_aes_pub_key_msg* msg = struct rfb_rsa_aes_pub_key_msg* msg =
(struct rfb_rsa_aes_pub_key_msg*)buffer; (struct rfb_rsa_aes_pub_key_msg*)buffer;
uint8_t* modulus = msg->modulus_and_exponent; uint8_t* modulus = msg->modulus_and_exponent;
uint8_t* exponent = msg->modulus_and_exponent + uint8_t* exponent = msg->modulus_and_exponent + key_len;
RSA_AES_SERVER_KEY_LENGTH;
msg->length = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); msg->length = htonl(key_len * 8);
crypto_rsa_pub_key_modulus(server->rsa_pub, modulus, crypto_rsa_pub_key_modulus(server->rsa_pub, modulus, key_len);
RSA_AES_SERVER_KEY_LENGTH); crypto_rsa_pub_key_exponent(server->rsa_pub, exponent, key_len);
crypto_rsa_pub_key_exponent(server->rsa_pub, exponent,
RSA_AES_SERVER_KEY_LENGTH);
stream_write(client->net_stream, buffer, sizeof(buffer), NULL, NULL); stream_send(client->net_stream, rcbuf_new(buffer, buf_len), NULL, NULL);
return 0; return 0;
} }
@ -533,8 +533,8 @@ static int rsa_aes_send_challenge(struct nvnc_client* client,
(struct rfb_rsa_aes_challenge_msg*)buffer; (struct rfb_rsa_aes_challenge_msg*)buffer;
ssize_t len = crypto_rsa_encrypt(pub, msg->challenge, ssize_t len = crypto_rsa_encrypt(pub, msg->challenge,
RSA_AES_SERVER_KEY_LENGTH, client->rsa.challenge, crypto_rsa_pub_key_length(client->rsa.pub),
client->rsa.challenge_len); client->rsa.challenge, client->rsa.challenge_len);
msg->length = htons(len); msg->length = htons(len);
nvnc_trace("Challenge length is %zd", len); nvnc_trace("Challenge length is %zd", len);
@ -638,12 +638,14 @@ static int on_rsa_aes_challenge(struct nvnc_client* client)
stream_upgrade_to_rsa_eas(client->net_stream, client->rsa.cipher_type, stream_upgrade_to_rsa_eas(client->net_stream, client->rsa.cipher_type,
server_session_key, client_session_key); server_session_key, client_session_key);
uint8_t server_modulus[RSA_AES_SERVER_KEY_LENGTH]; size_t server_key_len = crypto_rsa_pub_key_length(server->rsa_pub);
uint8_t server_exponent[RSA_AES_SERVER_KEY_LENGTH]; uint8_t* server_modulus = malloc(server_key_len * 2);
uint8_t* server_exponent = server_modulus + server_key_len;
crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus, crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus,
RSA_AES_SERVER_KEY_LENGTH); server_key_len);
crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent, crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent,
RSA_AES_SERVER_KEY_LENGTH); server_key_len);
size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub); size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub);
uint8_t* client_modulus = malloc(client_key_len * 2); uint8_t* client_modulus = malloc(client_key_len * 2);
@ -654,21 +656,22 @@ static int on_rsa_aes_challenge(struct nvnc_client* client)
crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent, crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent,
client_key_len); client_key_len);
uint32_t server_key_len_be = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); uint32_t server_key_len_be = htonl(server_key_len * 8);
uint32_t client_key_len_be = htonl(client_key_len * 8); uint32_t client_key_len_be = htonl(client_key_len * 8);
uint8_t server_hash[32] = {}; uint8_t server_hash[32] = {};
crypto_hash_many(server_hash, client_rsa_aes_hash_len(client), crypto_hash_many(server_hash, client_rsa_aes_hash_len(client),
client->rsa.hash_type, (const struct crypto_data_entry[]) { client->rsa.hash_type, (const struct crypto_data_entry[]) {
{ (uint8_t*)&server_key_len_be, 4 }, { (uint8_t*)&server_key_len_be, 4 },
{ server_modulus, RSA_AES_SERVER_KEY_LENGTH }, { server_modulus, server_key_len },
{ server_exponent, RSA_AES_SERVER_KEY_LENGTH }, { server_exponent, server_key_len },
{ (uint8_t*)&client_key_len_be, 4 }, { (uint8_t*)&client_key_len_be, 4 },
{ client_modulus, client_key_len }, { client_modulus, client_key_len },
{ client_exponent, client_key_len }, { client_exponent, client_key_len },
{} {}
}); });
free(server_modulus);
free(client_modulus); free(client_modulus);
stream_write(client->net_stream, server_hash, stream_write(client->net_stream, server_hash,
@ -688,12 +691,13 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client)
struct nvnc* server = client->server; struct nvnc* server = client->server;
uint8_t server_modulus[RSA_AES_SERVER_KEY_LENGTH]; size_t server_key_len = crypto_rsa_pub_key_length(server->rsa_pub);
uint8_t server_exponent[RSA_AES_SERVER_KEY_LENGTH]; uint8_t* server_modulus = malloc(server_key_len * 2);
uint8_t* server_exponent = server_modulus + server_key_len;
crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus, crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus,
RSA_AES_SERVER_KEY_LENGTH); server_key_len);
crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent, crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent,
RSA_AES_SERVER_KEY_LENGTH); server_key_len);
size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub); size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub);
uint8_t* client_modulus = malloc(client_key_len * 2); uint8_t* client_modulus = malloc(client_key_len * 2);
@ -704,7 +708,7 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client)
crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent, crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent,
client_key_len); client_key_len);
uint32_t server_key_len_be = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); uint32_t server_key_len_be = htonl(server_key_len * 8);
uint32_t client_key_len_be = htonl(client_key_len * 8); uint32_t client_key_len_be = htonl(client_key_len * 8);
uint8_t client_hash[32] = {}; uint8_t client_hash[32] = {};
@ -714,12 +718,13 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client)
{ client_modulus, client_key_len }, { client_modulus, client_key_len },
{ client_exponent, client_key_len }, { client_exponent, client_key_len },
{ (uint8_t*)&server_key_len_be, 4 }, { (uint8_t*)&server_key_len_be, 4 },
{ server_modulus, RSA_AES_SERVER_KEY_LENGTH }, { server_modulus, server_key_len },
{ server_exponent, RSA_AES_SERVER_KEY_LENGTH }, { server_exponent, server_key_len },
{} {}
}); });
free(client_modulus); free(client_modulus);
free(server_modulus);
if (memcmp(msg, client_hash, client_rsa_aes_hash_len(client)) != 0) { if (memcmp(msg, client_hash, client_rsa_aes_hash_len(client)) != 0) {
nvnc_log(NVNC_LOG_INFO, "Client hash mismatch"); nvnc_log(NVNC_LOG_INFO, "Client hash mismatch");