diff --git a/src/server.c b/src/server.c index 5a8c1cc..da40ca3 100644 --- a/src/server.c +++ b/src/server.c @@ -68,7 +68,7 @@ #define DEFAULT_NAME "Neat VNC" #define SECURITY_TYPES_MAX 3 -#define RSA_AES_SERVER_KEY_LENGTH 256 +#define APPLE_DH_SERVER_KEY_LENGTH 256 #define UDIV_UP(a, b) (((a) + (b) - 1) / (b)) @@ -409,11 +409,11 @@ static int apple_dh_send_public_key(struct nvnc_client* client) crypto_derive_public_key(client->apple_dh_secret); assert(pub); - uint8_t mod[RSA_AES_SERVER_KEY_LENGTH] = {}; + uint8_t mod[APPLE_DH_SERVER_KEY_LENGTH] = {}; int mod_len = crypto_key_p(pub, mod, sizeof(mod)); assert(mod_len == sizeof(mod)); - uint8_t q[RSA_AES_SERVER_KEY_LENGTH] = {}; + uint8_t q[APPLE_DH_SERVER_KEY_LENGTH] = {}; int q_len = crypto_key_q(pub, q, sizeof(q)); assert(q_len == sizeof(q)); @@ -437,7 +437,7 @@ static int on_apple_dh_response(struct nvnc_client* client) struct rfb_apple_dh_client_msg* msg = (void*)(client->msg_buffer + client->buffer_index); - uint8_t p[RSA_AES_SERVER_KEY_LENGTH]; + uint8_t p[APPLE_DH_SERVER_KEY_LENGTH]; int key_len = crypto_key_p(client->apple_dh_secret, p, sizeof(p)); assert(key_len == sizeof(p)); @@ -454,7 +454,7 @@ static int on_apple_dh_response(struct nvnc_client* client) crypto_derive_shared_secret(client->apple_dh_secret, remote_key); assert(shared_secret); - uint8_t shared_buf[RSA_AES_SERVER_KEY_LENGTH]; + uint8_t shared_buf[APPLE_DH_SERVER_KEY_LENGTH]; crypto_key_q(shared_secret, shared_buf, sizeof(shared_buf)); crypto_key_del(shared_secret); @@ -504,22 +504,22 @@ static int rsa_aes_send_public_key(struct nvnc_client* client) } assert(server->rsa_pub && server->rsa_priv); - char buffer[sizeof(struct rfb_rsa_aes_pub_key_msg) + - RSA_AES_SERVER_KEY_LENGTH * 2] = {}; + size_t key_len = crypto_rsa_pub_key_length(server->rsa_pub); + size_t buf_len = sizeof(struct rfb_rsa_aes_pub_key_msg) + key_len * 2; + + char* buffer = calloc(1, buf_len); + assert(buffer); struct rfb_rsa_aes_pub_key_msg* msg = (struct rfb_rsa_aes_pub_key_msg*)buffer; uint8_t* modulus = msg->modulus_and_exponent; - uint8_t* exponent = msg->modulus_and_exponent + - RSA_AES_SERVER_KEY_LENGTH; + uint8_t* exponent = msg->modulus_and_exponent + key_len; - msg->length = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); - crypto_rsa_pub_key_modulus(server->rsa_pub, modulus, - RSA_AES_SERVER_KEY_LENGTH); - crypto_rsa_pub_key_exponent(server->rsa_pub, exponent, - RSA_AES_SERVER_KEY_LENGTH); + msg->length = htonl(key_len * 8); + crypto_rsa_pub_key_modulus(server->rsa_pub, modulus, key_len); + crypto_rsa_pub_key_exponent(server->rsa_pub, exponent, key_len); - stream_write(client->net_stream, buffer, sizeof(buffer), NULL, NULL); + stream_send(client->net_stream, rcbuf_new(buffer, buf_len), NULL, NULL); return 0; } @@ -533,8 +533,8 @@ static int rsa_aes_send_challenge(struct nvnc_client* client, (struct rfb_rsa_aes_challenge_msg*)buffer; ssize_t len = crypto_rsa_encrypt(pub, msg->challenge, - RSA_AES_SERVER_KEY_LENGTH, client->rsa.challenge, - client->rsa.challenge_len); + crypto_rsa_pub_key_length(client->rsa.pub), + client->rsa.challenge, client->rsa.challenge_len); msg->length = htons(len); nvnc_trace("Challenge length is %zd", len); @@ -638,12 +638,14 @@ static int on_rsa_aes_challenge(struct nvnc_client* client) stream_upgrade_to_rsa_eas(client->net_stream, client->rsa.cipher_type, server_session_key, client_session_key); - uint8_t server_modulus[RSA_AES_SERVER_KEY_LENGTH]; - uint8_t server_exponent[RSA_AES_SERVER_KEY_LENGTH]; + size_t server_key_len = crypto_rsa_pub_key_length(server->rsa_pub); + uint8_t* server_modulus = malloc(server_key_len * 2); + uint8_t* server_exponent = server_modulus + server_key_len; + crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus, - RSA_AES_SERVER_KEY_LENGTH); + server_key_len); crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent, - RSA_AES_SERVER_KEY_LENGTH); + server_key_len); size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub); uint8_t* client_modulus = malloc(client_key_len * 2); @@ -654,21 +656,22 @@ static int on_rsa_aes_challenge(struct nvnc_client* client) crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent, client_key_len); - uint32_t server_key_len_be = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); + uint32_t server_key_len_be = htonl(server_key_len * 8); uint32_t client_key_len_be = htonl(client_key_len * 8); uint8_t server_hash[32] = {}; crypto_hash_many(server_hash, client_rsa_aes_hash_len(client), client->rsa.hash_type, (const struct crypto_data_entry[]) { { (uint8_t*)&server_key_len_be, 4 }, - { server_modulus, RSA_AES_SERVER_KEY_LENGTH }, - { server_exponent, RSA_AES_SERVER_KEY_LENGTH }, + { server_modulus, server_key_len }, + { server_exponent, server_key_len }, { (uint8_t*)&client_key_len_be, 4 }, { client_modulus, client_key_len }, { client_exponent, client_key_len }, {} }); + free(server_modulus); free(client_modulus); stream_write(client->net_stream, server_hash, @@ -688,12 +691,13 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client) struct nvnc* server = client->server; - uint8_t server_modulus[RSA_AES_SERVER_KEY_LENGTH]; - uint8_t server_exponent[RSA_AES_SERVER_KEY_LENGTH]; + size_t server_key_len = crypto_rsa_pub_key_length(server->rsa_pub); + uint8_t* server_modulus = malloc(server_key_len * 2); + uint8_t* server_exponent = server_modulus + server_key_len; crypto_rsa_pub_key_modulus(server->rsa_pub, server_modulus, - RSA_AES_SERVER_KEY_LENGTH); + server_key_len); crypto_rsa_pub_key_exponent(server->rsa_pub, server_exponent, - RSA_AES_SERVER_KEY_LENGTH); + server_key_len); size_t client_key_len = crypto_rsa_pub_key_length(client->rsa.pub); uint8_t* client_modulus = malloc(client_key_len * 2); @@ -704,7 +708,7 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client) crypto_rsa_pub_key_exponent(client->rsa.pub, client_exponent, client_key_len); - uint32_t server_key_len_be = htonl(RSA_AES_SERVER_KEY_LENGTH * 8); + uint32_t server_key_len_be = htonl(server_key_len * 8); uint32_t client_key_len_be = htonl(client_key_len * 8); uint8_t client_hash[32] = {}; @@ -714,12 +718,13 @@ static int on_rsa_aes_client_hash(struct nvnc_client* client) { client_modulus, client_key_len }, { client_exponent, client_key_len }, { (uint8_t*)&server_key_len_be, 4 }, - { server_modulus, RSA_AES_SERVER_KEY_LENGTH }, - { server_exponent, RSA_AES_SERVER_KEY_LENGTH }, + { server_modulus, server_key_len }, + { server_exponent, server_key_len }, {} }); free(client_modulus); + free(server_modulus); if (memcmp(msg, client_hash, client_rsa_aes_hash_len(client)) != 0) { nvnc_log(NVNC_LOG_INFO, "Client hash mismatch");