diff --git a/include/common.h b/include/common.h index 9a3b960..0afcb14 100644 --- a/include/common.h +++ b/include/common.h @@ -25,6 +25,11 @@ #include "neatvnc.h" #include "miniz.h" +#include "config.h" + +#ifdef ENABLE_TLS +#include +#endif #define MAX_ENCODINGS 32 #define MAX_OUTGOING_FRAMES 4 @@ -34,11 +39,17 @@ enum nvnc_client_state { VNC_CLIENT_STATE_ERROR = -1, VNC_CLIENT_STATE_WAITING_FOR_VERSION = 0, VNC_CLIENT_STATE_WAITING_FOR_SECURITY, +#ifdef ENABLE_TLS + VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION, + VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE, + VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH, +#endif VNC_CLIENT_STATE_WAITING_FOR_INIT, VNC_CLIENT_STATE_READY, }; struct nvnc; +struct stream; struct nvnc_common { void* userdata; @@ -47,7 +58,7 @@ struct nvnc_common { struct nvnc_client { struct nvnc_common common; int ref; - struct uv_tcp_s stream_handle; + struct stream* net_stream; struct nvnc* server; enum nvnc_client_state state; bool has_pixfmt; @@ -76,7 +87,8 @@ struct vnc_display { struct nvnc { struct nvnc_common common; - uv_tcp_t tcp_handle; + int fd; + uv_poll_t poll_handle; struct nvnc_client_list clients; struct vnc_display display; void* userdata; @@ -85,4 +97,10 @@ struct nvnc { nvnc_fb_req_fn fb_req_fn; nvnc_client_fn new_client_fn; struct nvnc_fb* frame; + +#ifdef ENABLE_TLS + gnutls_certificate_credentials_t tls_creds; + nvnc_auth_fn auth_fn; + void* auth_ud; +#endif }; diff --git a/include/neatvnc.h b/include/neatvnc.h index b8d55c7..36a3881 100644 --- a/include/neatvnc.h +++ b/include/neatvnc.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Andri Yngvason + * Copyright (c) 2019 - 2020 Andri Yngvason * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -41,6 +41,8 @@ typedef void (*nvnc_fb_req_fn)(struct nvnc_client*, bool is_incremental, uint16_t height); typedef void (*nvnc_client_fn)(struct nvnc_client*); typedef void (*nvnc_damage_fn)(struct pixman_region16* damage, void* userdata); +typedef bool (*nvnc_auth_fn)(const char* username, const char* password, + void* userdata); struct nvnc* nvnc_open(const char* addr, uint16_t port); void nvnc_close(struct nvnc* self); @@ -61,6 +63,10 @@ void nvnc_set_fb_req_fn(struct nvnc* self, nvnc_fb_req_fn); void nvnc_set_new_client_fn(struct nvnc* self, nvnc_client_fn); void nvnc_set_client_cleanup_fn(struct nvnc_client* self, nvnc_client_fn fn); +bool nvnc_has_auth(void); +int nvnc_enable_auth(struct nvnc* self, const char* privkey_path, + const char* cert_path, nvnc_auth_fn, void* userdata); + struct nvnc_fb* nvnc_fb_new(uint16_t width, uint16_t height, uint32_t fourcc_format); diff --git a/include/rfb-proto.h b/include/rfb-proto.h index 20fd5d4..23cfcda 100644 --- a/include/rfb-proto.h +++ b/include/rfb-proto.h @@ -29,6 +29,8 @@ enum rfb_security_type { RFB_SECURITY_TYPE_INVALID = 0, RFB_SECURITY_TYPE_NONE = 1, RFB_SECURITY_TYPE_VNC_AUTH = 2, + RFB_SECURITY_TYPE_TIGHT = 16, + RFB_SECURITY_TYPE_VENCRYPT = 19, }; enum rfb_security_handshake_result { @@ -64,6 +66,16 @@ enum rfb_server_to_client_msg_type { RFB_SERVER_TO_CLIENT_SERVER_CUT_TEXT = 3, }; +enum rfb_vencrypt_subtype { + RFB_VENCRYPT_PLAIN = 256, + RFB_VENCRYPT_TLS_NONE, + RFB_VENCRYPT_TLS_VNC, + RFB_VENCRYPT_TLS_PLAIN, + RFB_VENCRYPT_X509_NONE, + RFB_VENCRYPT_X509_VNC, + RFB_VENCRYPT_X509_PLAIN, +}; + struct rfb_security_types_msg { uint8_t n; uint8_t types[1]; @@ -146,3 +158,19 @@ struct rfb_server_fb_update_msg { uint8_t padding; uint16_t n_rects; } RFB_PACKED; + +struct rfb_vencrypt_version_msg { + uint8_t major; + uint8_t minor; +} RFB_PACKED; + +struct rfb_vencrypt_subtypes_msg { + uint8_t n; + uint32_t types[1]; +} RFB_PACKED; + +struct rfb_vencrypt_plain_auth_msg { + uint32_t username_len; + uint32_t password_len; + char text[0]; +} RFB_PACKED; diff --git a/src/server.c b/src/server.c index 413657c..72a3e7b 100644 --- a/src/server.c +++ b/src/server.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Andri Yngvason + * Copyright (c) 2019 - 2020 Andri Yngvason * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -25,7 +25,9 @@ #include "neatvnc.h" #include "common.h" #include "pixels.h" +#include "stream.h" #include "config.h" +#include "logging.h" #include #include @@ -37,6 +39,10 @@ #include #include +#ifdef ENABLE_TLS +#include +#endif + #ifndef DRM_FORMAT_INVALID #define DRM_FORMAT_INVALID 0 #endif @@ -74,38 +80,25 @@ static const char* fourcc_to_string(uint32_t fourcc) return buffer; } -static void allocate_read_buffer(uv_handle_t* handle, size_t suggested_size, - uv_buf_t* buf) +static void client_close(struct nvnc_client* client) { - (void)suggested_size; - - buf->base = malloc(READ_BUFFER_SIZE); - buf->len = buf->base ? READ_BUFFER_SIZE : 0; -} - -static void cleanup_client(uv_handle_t* handle) -{ - struct nvnc_client* client = container_of( - (uv_tcp_t*)handle, struct nvnc_client, stream_handle); + log_debug("client_close(%p): ref %d\n", client, client->ref); nvnc_client_fn fn = client->cleanup_fn; if (fn) fn(client); - deflateEnd(&client->z_stream); - LIST_REMOVE(client, link); + stream_destroy(client->net_stream); + deflateEnd(&client->z_stream); pixman_region_fini(&client->damage); free(client); } -static inline void client_close(struct nvnc_client* client) -{ - uv_close((uv_handle_t*)&client->stream_handle, cleanup_client); -} - static inline void client_unref(struct nvnc_client* client) { + assert(client->ref > 0); + if (--client->ref == 0) client_close(client); } @@ -115,11 +108,11 @@ static inline void client_ref(struct nvnc_client* client) ++client->ref; } -static void close_after_write(uv_write_t* req, int status) +static void close_after_write(void* userdata, enum stream_req_status status) { - struct nvnc_client* client = container_of( - (uv_tcp_t*)req->handle, struct nvnc_client, stream_handle); - + struct nvnc_client* client = userdata; + log_debug("close_after_write(%p): ref %d\n", client, client->ref); + stream_close(client->net_stream); client_unref(client); } @@ -137,15 +130,18 @@ static int handle_unsupported_version(struct nvnc_client* client) reason->length = htonl(strlen(reason_string)); (void)strcmp(reason->message, reason_string); - vnc__write((uv_stream_t*)&client->stream_handle, buffer, - 1 + sizeof(*reason) + strlen(reason_string), - close_after_write); + struct rcbuf* payload = + rcbuf_from_mem(buffer, + 1 + sizeof(*reason) + strlen(reason_string)); + stream_write(client->net_stream, payload, close_after_write, client); return 0; } static int on_version_message(struct nvnc_client* client) { + struct nvnc* server = client->server; + if (client->buffer_len - client->buffer_index < 12) return 0; @@ -156,17 +152,17 @@ static int on_version_message(struct nvnc_client* client) if (strcmp(RFB_VERSION_MESSAGE, version_string) != 0) return handle_unsupported_version(client); - /* clang-format off */ - const static struct rfb_security_types_msg security = { - .n = 1, - .types = { - RFB_SECURITY_TYPE_NONE, - }, - }; - /* clang-format on */ + struct rfb_security_types_msg security = { 0 }; + security.n = 1; + security.types[0] = RFB_SECURITY_TYPE_NONE; - vnc__write((uv_stream_t*)&client->stream_handle, &security, - sizeof(security), NULL); +#ifdef ENABLE_TLS + if (server->auth_fn) + security.types[0] = RFB_SECURITY_TYPE_VENCRYPT; +#endif + + struct rcbuf* payload = rcbuf_from_mem(&security, sizeof(security)); + stream_write(client->net_stream, payload, NULL, NULL); client->state = VNC_CLIENT_STATE_WAITING_FOR_SECURITY; return 12; @@ -189,13 +185,133 @@ static int handle_invalid_security_type(struct nvnc_client* client) reason->length = htonl(strlen(reason_string)); (void)strcmp(reason->message, reason_string); - vnc__write((uv_stream_t*)&client->stream_handle, buffer, - sizeof(*result) + sizeof(*reason) + strlen(reason_string), - close_after_write); + struct rcbuf* payload = + rcbuf_from_mem(buffer, sizeof(*result) + sizeof(*reason) + + strlen(reason_string)); + stream_write(client->net_stream, payload, close_after_write, client); return 0; } +static int security_handshake_ok(struct nvnc_client* client) +{ + uint32_t result = htonl(RFB_SECURITY_HANDSHAKE_OK); + struct rcbuf* payload = rcbuf_from_mem(&result, sizeof(result)); + return stream_write(client->net_stream, payload, NULL, NULL); +} + +static int send_byte(struct nvnc_client* client, uint8_t value) +{ + struct rcbuf* payload = rcbuf_from_mem(&value, sizeof(value)); + return stream_write(client->net_stream, payload, NULL, NULL); +} + +#ifdef ENABLE_TLS +static int vencrypt_send_version(struct nvnc_client* client) +{ + struct rfb_vencrypt_version_msg msg = { + .major = 0, + .minor = 2, + }; + + struct rcbuf* payload = rcbuf_from_mem(&msg, sizeof(msg)); + return stream_write(client->net_stream, payload, NULL, NULL); +} + +static int on_vencrypt_version_message(struct nvnc_client* client) +{ + struct rfb_vencrypt_version_msg* msg = + (struct rfb_vencrypt_version_msg*)&client->msg_buffer[client->buffer_index]; + + if (client->buffer_len - client->buffer_index < sizeof(*msg)) + return 0; + + if (msg->major != 0 || msg->minor != 2) { + // TODO: Say unsupported vencrypt type in message + handle_invalid_security_type(client); + return sizeof(*msg); + } + + send_byte(client, 0); + + struct rfb_vencrypt_subtypes_msg result = { .n = 1, }; + result.types[0] = htonl(RFB_VENCRYPT_X509_PLAIN); + + struct rcbuf* payload = rcbuf_from_mem(&result, sizeof(result)); + stream_write(client->net_stream, payload, NULL, NULL); + + client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE; + + return sizeof(*msg); +} + +static int on_vencrypt_subtype_message(struct nvnc_client* client) +{ + uint32_t* msg = (uint32_t*)&client->msg_buffer[client->buffer_index]; + + if (client->buffer_len - client->buffer_index < sizeof(*msg)) + return 0; + + enum rfb_vencrypt_subtype subtype = ntohl(*msg); + + if (subtype != RFB_VENCRYPT_X509_PLAIN) { + send_byte(client, 0); // TODO Close after write + stream_close(client->net_stream); + client_unref(client); + return sizeof(*msg); + } + + send_byte(client, 1); + + if (stream_upgrade_to_tls(client->net_stream, client->server->tls_creds) < 0) { + stream_close(client->net_stream); + client_unref(client); + return sizeof(*msg); + } + + client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH; + + return sizeof(*msg); +} + +static int on_vencrypt_plain_auth_message(struct nvnc_client* client) +{ + struct nvnc* server = client->server; + + struct rfb_vencrypt_plain_auth_msg* msg = + (void*)(client->msg_buffer + client->buffer_index); + + if (client->buffer_len - client->buffer_index < sizeof(*msg)) + return 0; + + uint32_t ulen = ntohl(msg->username_len); + uint32_t plen = ntohl(msg->password_len); + + if (client->buffer_len - client->buffer_index < sizeof(*msg) + ulen + plen) + return 0; + + char username[256]; + char password[256]; + + memcpy(username, msg->text, MIN(ulen, sizeof(username) - 1)); + memcpy(password, msg->text + ulen, MIN(plen, sizeof(password) - 1)); + + username[MIN(ulen, sizeof(username) - 1)] = '\0'; + password[MIN(plen, sizeof(password) - 1)] = '\0'; + + if (server->auth_fn(username, password, server->auth_ud)) { + log_debug("User \"%s\" authenticated\n", username); + security_handshake_ok(client); + client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT; + } else { + log_debug("User \"%s\" rejected\n", username); + handle_invalid_security_type(client); // TODO say wrong auth + } + + return sizeof(*msg) + ulen + plen; +} +#endif + static int on_security_message(struct nvnc_client* client) { if (client->buffer_len - client->buffer_index < 1) @@ -203,25 +319,38 @@ static int on_security_message(struct nvnc_client* client) uint8_t type = client->msg_buffer[client->buffer_index]; - if (type != RFB_SECURITY_TYPE_NONE) - return handle_invalid_security_type(client); + switch (type) { + case RFB_SECURITY_TYPE_NONE: + security_handshake_ok(client); + client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT; + break; +#ifdef ENABLE_TLS + case RFB_SECURITY_TYPE_VENCRYPT: + vencrypt_send_version(client); + client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION; + break; +#endif + default: + handle_invalid_security_type(client); + break; + } - enum rfb_security_handshake_result result = - htonl(RFB_SECURITY_HANDSHAKE_OK); - - vnc__write((uv_stream_t*)&client->stream_handle, &result, - sizeof(result), NULL); - - client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT; return sizeof(type); } static void disconnect_all_other_clients(struct nvnc_client* client) { struct nvnc_client* node; - LIST_FOREACH (node, &client->server->clients, link) - if (node != client) - client_unref(client); + struct nvnc_client* tmp; + + LIST_FOREACH_SAFE (node, &client->server->clients, link, tmp) + if (node != client) { + log_debug("disconnect other client %p (ref %d)\n", + node, node->ref); + stream_close(node->net_stream); + client_unref(node); + } + } static void send_server_init_message(struct nvnc_client* client) @@ -234,6 +363,7 @@ static void send_server_init_message(struct nvnc_client* client) struct rfb_server_init_msg* msg = calloc(1, size); if (!msg) { + stream_close(client->net_stream); client_unref(client); return; } @@ -244,6 +374,7 @@ static void send_server_init_message(struct nvnc_client* client) int rc = rfb_pixfmt_from_fourcc(&msg->pixel_format, display->pixfmt); if (rc < 0) { + stream_close(client->net_stream); client_unref(client); return; } @@ -252,9 +383,8 @@ static void send_server_init_message(struct nvnc_client* client) msg->pixel_format.green_max = htons(msg->pixel_format.green_max); msg->pixel_format.blue_max = htons(msg->pixel_format.blue_max); - vnc__write((uv_stream_t*)&client->stream_handle, msg, size, NULL); - - free(msg); + struct rcbuf* payload = rcbuf_new(msg, size); + stream_write(client->net_stream, payload, NULL, NULL); } static int on_init_message(struct nvnc_client* client) @@ -288,6 +418,7 @@ static int on_client_set_pixel_format(struct nvnc_client* client) if (!fmt->true_colour_flag) { /* We don't really know what to do with color maps right now */ + stream_close(client->net_stream); client_unref(client); return 0; } @@ -343,7 +474,7 @@ static void process_fb_update_requests(struct nvnc_client* client) if (!client->server->frame) return; - if (uv_is_closing((uv_handle_t*)&client->stream_handle)) + if (client->net_stream->state == STREAM_STATE_CLOSED) return; if (!pixman_region_not_empty(&client->damage)) @@ -474,6 +605,9 @@ static int on_client_message(struct nvnc_client* client) return on_client_cut_text(client); } + log_debug("Got uninterpretable message from client: %p (ref %d)\n", + client, client->ref); + stream_close(client->net_stream); client_unref(client); return 0; } @@ -482,14 +616,21 @@ static int try_read_client_message(struct nvnc_client* client) { switch (client->state) { case VNC_CLIENT_STATE_ERROR: - client_unref(client); - return 0; + return client->buffer_len - client->buffer_index; case VNC_CLIENT_STATE_WAITING_FOR_VERSION: return on_version_message(client); case VNC_CLIENT_STATE_WAITING_FOR_SECURITY: return on_security_message(client); case VNC_CLIENT_STATE_WAITING_FOR_INIT: return on_init_message(client); +#ifdef ENABLE_TLS + case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION: + return on_vencrypt_version_message(client); + case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE: + return on_vencrypt_subtype_message(client); + case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH: + return on_vencrypt_plain_auth_message(client); +#endif case VNC_CLIENT_STATE_READY: return on_client_message(client); } @@ -498,18 +639,35 @@ static int try_read_client_message(struct nvnc_client* client) return 0; } -static void on_client_read(uv_stream_t* stream, ssize_t n_read, - const uv_buf_t* buf) +static void on_client_event(struct stream* stream, enum stream_event event) { - struct nvnc_client* client = container_of( - (uv_tcp_t*)stream, struct nvnc_client, stream_handle); + struct nvnc_client* client = stream->userdata; + + assert(client->net_stream == stream); + + if (event == STREAM_EVENT_REMOTE_CLOSED) { + log_debug("Client %p (%d) hung up\n", client, client->ref); + stream_close(stream); + client_unref(client); + return; + } + +#define BUF_SIZE 16384 + char* buf = malloc(BUF_SIZE); + ssize_t n_read = stream_read(stream, buf, BUF_SIZE); +#undef BUF_SIZE if (n_read == 0) goto done; if (n_read < 0) { - uv_read_stop(stream); - client_unref(client); + if (errno != EAGAIN) { + log_debug("Client connection error: %p (ref %d)\n", + client, client->ref); + stream_close(stream); + client_unref(client); + } + goto done; } @@ -518,12 +676,13 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read, if ((size_t)n_read > MSG_BUFFER_SIZE - client->buffer_len) { /* Can't handle this. Let's just give up */ client->state = VNC_CLIENT_STATE_ERROR; - uv_read_stop(stream); + log_debug("Client whoops: %p (ref %d)\n", client, client->ref); + stream_close(client->net_stream); client_unref(client); goto done; } - memcpy(client->msg_buffer + client->buffer_len, buf->base, n_read); + memcpy(client->msg_buffer + client->buffer_len, buf, n_read); client->buffer_len += n_read; while (1) { @@ -532,6 +691,7 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read, break; client->buffer_index += rc; + } assert(client->buffer_index <= client->buffer_len); @@ -542,13 +702,13 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read, client->buffer_index = 0; done: - free(buf->base); + free(buf); } -static void on_connection(uv_stream_t* server_stream, int status) +static void on_connection(uv_poll_t* poll_handle, int status, int events) { struct nvnc* server = - container_of((uv_tcp_t*)server_stream, struct nvnc, tcp_handle); + container_of(poll_handle, struct nvnc, poll_handle); struct nvnc_client* client = calloc(1, sizeof(*client)); if (!client) @@ -557,6 +717,14 @@ static void on_connection(uv_stream_t* server_stream, int status) client->ref = 1; client->server = server; + int fd = accept(server->fd, NULL, 0); + if (fd < 0) + goto accept_failure; + + client->net_stream = stream_new(fd, on_client_event, client); + if (!client->net_stream) + goto stream_failure; + int rc = deflateInit2(&client->z_stream, /* compression level: */ 1, /* method: */ Z_DEFLATED, @@ -564,51 +732,35 @@ static void on_connection(uv_stream_t* server_stream, int status) /* mem level: */ 9, /* strategy: */ Z_DEFAULT_STRATEGY); - if (rc != Z_OK) { - free(client); - return; - } + if (rc != Z_OK) + goto deflate_failure; pixman_region_init(&client->damage); - uv_tcp_init(uv_default_loop(), &client->stream_handle); + struct rcbuf* payload = rcbuf_from_string(RFB_VERSION_MESSAGE); + if (!payload) + goto payload_failure; - uv_accept((uv_stream_t*)&server->tcp_handle, - (uv_stream_t*)&client->stream_handle); - - uv_read_start((uv_stream_t*)&client->stream_handle, - allocate_read_buffer, on_client_read); - - vnc__write((uv_stream_t*)&client->stream_handle, RFB_VERSION_MESSAGE, - strlen(RFB_VERSION_MESSAGE), NULL); + stream_write(client->net_stream, payload, NULL, NULL); LIST_INSERT_HEAD(&server->clients, client, link); client->state = VNC_CLIENT_STATE_WAITING_FOR_VERSION; -} -int vnc_server_init(struct nvnc* self, const char* address, int port) -{ - LIST_INIT(&self->clients); + log_debug("New client connection: %p (ref %d)\n", client, client->ref); - uv_tcp_init(uv_default_loop(), &self->tcp_handle); + return; - struct sockaddr_in addr = {0}; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = inet_addr(address); - addr.sin_port = htons(port); +payload_failure: + deflateEnd(&client->z_stream); +deflate_failure: + stream_destroy(client->net_stream); +stream_failure: + close(fd); +accept_failure: + free(client); - if (uv_tcp_bind(&self->tcp_handle, (const struct sockaddr*)&addr, 0) < 0) - goto failure; - - if (uv_listen((uv_stream_t*)&self->tcp_handle, 16, on_connection) < 0) - goto failure; - - return 0; - -failure: - uv_unref((uv_handle_t*)&self->tcp_handle); - return -1; + log_debug("Failed to accept a connection\n"); } EXPORT @@ -622,22 +774,32 @@ struct nvnc* nvnc_open(const char* address, uint16_t port) LIST_INIT(&self->clients); - uv_tcp_init(uv_default_loop(), &self->tcp_handle); + self->fd = socket(AF_INET, SOCK_STREAM, 0); + if (self->fd < 0) + return NULL; + + int one = 1; + if (setsockopt(self->fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(int)) < 0) + goto failure; struct sockaddr_in addr = {0}; addr.sin_family = AF_INET; addr.sin_addr.s_addr = inet_addr(address); addr.sin_port = htons(port); - if (uv_tcp_bind(&self->tcp_handle, (const struct sockaddr*)&addr, 0) < 0) + if (bind(self->fd, (const struct sockaddr*)&addr, sizeof(addr)) < 0) goto failure; - if (uv_listen((uv_stream_t*)&self->tcp_handle, 16, on_connection) < 0) + if (listen(self->fd, 16) < 0) goto failure; + uv_poll_init(uv_default_loop(), &self->poll_handle, self->fd); + uv_poll_start(&self->poll_handle, UV_READABLE, on_connection); + return self; + failure: - uv_unref((uv_handle_t*)&self->tcp_handle); + close(self->fd); return NULL; } @@ -649,19 +811,28 @@ void nvnc_close(struct nvnc* self) if (self->frame) nvnc_fb_unref(self->frame); - LIST_FOREACH (client, &self->clients, link) + struct nvnc_client* tmp; + LIST_FOREACH_SAFE (client, &self->clients, link, tmp) client_unref(client); - uv_unref((uv_handle_t*)&self->tcp_handle); + uv_poll_stop(&self->poll_handle); + close(self->fd); + +#ifdef ENABLE_TLS + if (self->tls_creds) { + gnutls_certificate_free_credentials(self->tls_creds); + gnutls_global_deinit(); + } +#endif + free(self); } -static void on_write_frame_done(uv_write_t* req, int status) +static void on_write_frame_done(void* userdata, enum stream_req_status status) { - struct vnc_write_request* rq = (struct vnc_write_request*)req; - struct nvnc_client* client = rq->userdata; + struct nvnc_client* client = userdata; client->is_updating = false; - free(rq->buffer.base); + client_unref(client); } enum rfb_encodings choose_frame_encoding(struct nvnc_client* client) @@ -689,7 +860,7 @@ void do_client_update_fb(uv_work_t* work) enum rfb_encodings encoding = choose_frame_encoding(client); if (encoding == -1) { - uv_read_stop((uv_stream_t*)&client->stream_handle); + stream_close(client->net_stream); client_unref(client); return; } @@ -725,21 +896,28 @@ void on_client_update_fb_done(uv_work_t* work, int status) struct fb_update_work* update = (void*)work; struct nvnc_client* client = update->client; - struct nvnc* server = client->server; struct vec* frame = &update->frame; - if (!uv_is_closing((uv_handle_t*)&client->stream_handle)) - vnc__write2((uv_stream_t*)&client->stream_handle, frame->data, - frame->len, on_write_frame_done, client); - else + client_ref(client); + + + if (client->net_stream->state != STREAM_STATE_CLOSED) { + struct rcbuf* payload = rcbuf_new(frame->data, frame->len); + stream_write(client->net_stream, payload, on_write_frame_done, + client); + } else { client->is_updating = false; + vec_destroy(frame); + client_unref(client); + } client->n_pending_requests--; process_fb_update_requests(client); nvnc_fb_unref(update->fb); - client_unref(client); pixman_region_fini(&update->region); + + client_unref(client); free(update); } @@ -798,8 +976,9 @@ int nvnc_feed_frame(struct nvnc* self, struct nvnc_fb* fb, self->frame = fb; nvnc_fb_ref(self->frame); - LIST_FOREACH (client, &self->clients, link) { - if (uv_is_closing((uv_handle_t*)&client->stream_handle)) + struct nvnc_client* tmp; + LIST_FOREACH_SAFE (client, &self->clients, link, tmp) { + if (client->net_stream->state == STREAM_STATE_CLOSED) continue; pixman_region_union(&client->damage, &client->damage, @@ -878,3 +1057,61 @@ void nvnc_set_name(struct nvnc* self, const char* name) strncpy(self->display.name, name, sizeof(self->display.name)); self->display.name[sizeof(self->display.name) - 1] = '\0'; } + +EXPORT +bool nvnc_has_auth(void) +{ +#ifdef ENABLE_TLS + return true; +#else + return false; +#endif +} + +EXPORT +int nvnc_enable_auth(struct nvnc* self, const char* privkey_path, + const char* cert_path, nvnc_auth_fn auth_fn, + void* userdata) +{ +#ifdef ENABLE_TLS + if (self->tls_creds) + return -1; + + /* Note: This is globally reference counted, so we don't need to worry + * about messing with other libraries. + */ + int rc = gnutls_global_init(); + if (rc != GNUTLS_E_SUCCESS) { + log_error("GnuTLS: Failed to initialise: %s\n", + gnutls_strerror(rc)); + return -1; + } + + rc = gnutls_certificate_allocate_credentials(&self->tls_creds); + if (rc != GNUTLS_E_SUCCESS) { + log_error("GnuTLS: Failed to allocate credentials: %s\n", + gnutls_strerror(rc)); + goto cert_alloc_failure; + } + + rc = gnutls_certificate_set_x509_key_file( + self->tls_creds, cert_path, privkey_path, GNUTLS_X509_FMT_PEM); + if (rc != GNUTLS_E_SUCCESS) { + log_error("GnuTLS: Failed to load credentials: %s\n", + gnutls_strerror(rc)); + goto cert_set_failure; + } + + self->auth_fn = auth_fn; + self->auth_ud = userdata; + + return 0; + +cert_set_failure: + gnutls_certificate_free_credentials(self->tls_creds); + self->tls_creds = NULL; +cert_alloc_failure: + gnutls_global_deinit(); +#endif + return -1; +}