Implement VeNCrypt with x509 plain authentication

vencrypt
Andri Yngvason 2020-01-25 15:35:14 +00:00
parent 113f262115
commit 19e4e42036
4 changed files with 413 additions and 124 deletions

View File

@ -25,6 +25,11 @@
#include "neatvnc.h" #include "neatvnc.h"
#include "miniz.h" #include "miniz.h"
#include "config.h"
#ifdef ENABLE_TLS
#include <gnutls/gnutls.h>
#endif
#define MAX_ENCODINGS 32 #define MAX_ENCODINGS 32
#define MAX_OUTGOING_FRAMES 4 #define MAX_OUTGOING_FRAMES 4
@ -34,11 +39,17 @@ enum nvnc_client_state {
VNC_CLIENT_STATE_ERROR = -1, VNC_CLIENT_STATE_ERROR = -1,
VNC_CLIENT_STATE_WAITING_FOR_VERSION = 0, VNC_CLIENT_STATE_WAITING_FOR_VERSION = 0,
VNC_CLIENT_STATE_WAITING_FOR_SECURITY, VNC_CLIENT_STATE_WAITING_FOR_SECURITY,
#ifdef ENABLE_TLS
VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION,
VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE,
VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH,
#endif
VNC_CLIENT_STATE_WAITING_FOR_INIT, VNC_CLIENT_STATE_WAITING_FOR_INIT,
VNC_CLIENT_STATE_READY, VNC_CLIENT_STATE_READY,
}; };
struct nvnc; struct nvnc;
struct stream;
struct nvnc_common { struct nvnc_common {
void* userdata; void* userdata;
@ -47,7 +58,7 @@ struct nvnc_common {
struct nvnc_client { struct nvnc_client {
struct nvnc_common common; struct nvnc_common common;
int ref; int ref;
struct uv_tcp_s stream_handle; struct stream* net_stream;
struct nvnc* server; struct nvnc* server;
enum nvnc_client_state state; enum nvnc_client_state state;
bool has_pixfmt; bool has_pixfmt;
@ -76,7 +87,8 @@ struct vnc_display {
struct nvnc { struct nvnc {
struct nvnc_common common; struct nvnc_common common;
uv_tcp_t tcp_handle; int fd;
uv_poll_t poll_handle;
struct nvnc_client_list clients; struct nvnc_client_list clients;
struct vnc_display display; struct vnc_display display;
void* userdata; void* userdata;
@ -85,4 +97,10 @@ struct nvnc {
nvnc_fb_req_fn fb_req_fn; nvnc_fb_req_fn fb_req_fn;
nvnc_client_fn new_client_fn; nvnc_client_fn new_client_fn;
struct nvnc_fb* frame; struct nvnc_fb* frame;
#ifdef ENABLE_TLS
gnutls_certificate_credentials_t tls_creds;
nvnc_auth_fn auth_fn;
void* auth_ud;
#endif
}; };

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2019 Andri Yngvason * Copyright (c) 2019 - 2020 Andri Yngvason
* *
* Permission to use, copy, modify, and/or distribute this software for any * Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above * purpose with or without fee is hereby granted, provided that the above
@ -41,6 +41,8 @@ typedef void (*nvnc_fb_req_fn)(struct nvnc_client*, bool is_incremental,
uint16_t height); uint16_t height);
typedef void (*nvnc_client_fn)(struct nvnc_client*); typedef void (*nvnc_client_fn)(struct nvnc_client*);
typedef void (*nvnc_damage_fn)(struct pixman_region16* damage, void* userdata); typedef void (*nvnc_damage_fn)(struct pixman_region16* damage, void* userdata);
typedef bool (*nvnc_auth_fn)(const char* username, const char* password,
void* userdata);
struct nvnc* nvnc_open(const char* addr, uint16_t port); struct nvnc* nvnc_open(const char* addr, uint16_t port);
void nvnc_close(struct nvnc* self); void nvnc_close(struct nvnc* self);
@ -61,6 +63,10 @@ void nvnc_set_fb_req_fn(struct nvnc* self, nvnc_fb_req_fn);
void nvnc_set_new_client_fn(struct nvnc* self, nvnc_client_fn); void nvnc_set_new_client_fn(struct nvnc* self, nvnc_client_fn);
void nvnc_set_client_cleanup_fn(struct nvnc_client* self, nvnc_client_fn fn); void nvnc_set_client_cleanup_fn(struct nvnc_client* self, nvnc_client_fn fn);
bool nvnc_has_auth(void);
int nvnc_enable_auth(struct nvnc* self, const char* privkey_path,
const char* cert_path, nvnc_auth_fn, void* userdata);
struct nvnc_fb* nvnc_fb_new(uint16_t width, uint16_t height, struct nvnc_fb* nvnc_fb_new(uint16_t width, uint16_t height,
uint32_t fourcc_format); uint32_t fourcc_format);

View File

@ -29,6 +29,8 @@ enum rfb_security_type {
RFB_SECURITY_TYPE_INVALID = 0, RFB_SECURITY_TYPE_INVALID = 0,
RFB_SECURITY_TYPE_NONE = 1, RFB_SECURITY_TYPE_NONE = 1,
RFB_SECURITY_TYPE_VNC_AUTH = 2, RFB_SECURITY_TYPE_VNC_AUTH = 2,
RFB_SECURITY_TYPE_TIGHT = 16,
RFB_SECURITY_TYPE_VENCRYPT = 19,
}; };
enum rfb_security_handshake_result { enum rfb_security_handshake_result {
@ -64,6 +66,16 @@ enum rfb_server_to_client_msg_type {
RFB_SERVER_TO_CLIENT_SERVER_CUT_TEXT = 3, RFB_SERVER_TO_CLIENT_SERVER_CUT_TEXT = 3,
}; };
enum rfb_vencrypt_subtype {
RFB_VENCRYPT_PLAIN = 256,
RFB_VENCRYPT_TLS_NONE,
RFB_VENCRYPT_TLS_VNC,
RFB_VENCRYPT_TLS_PLAIN,
RFB_VENCRYPT_X509_NONE,
RFB_VENCRYPT_X509_VNC,
RFB_VENCRYPT_X509_PLAIN,
};
struct rfb_security_types_msg { struct rfb_security_types_msg {
uint8_t n; uint8_t n;
uint8_t types[1]; uint8_t types[1];
@ -146,3 +158,19 @@ struct rfb_server_fb_update_msg {
uint8_t padding; uint8_t padding;
uint16_t n_rects; uint16_t n_rects;
} RFB_PACKED; } RFB_PACKED;
struct rfb_vencrypt_version_msg {
uint8_t major;
uint8_t minor;
} RFB_PACKED;
struct rfb_vencrypt_subtypes_msg {
uint8_t n;
uint32_t types[1];
} RFB_PACKED;
struct rfb_vencrypt_plain_auth_msg {
uint32_t username_len;
uint32_t password_len;
char text[0];
} RFB_PACKED;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2019 Andri Yngvason * Copyright (c) 2019 - 2020 Andri Yngvason
* *
* Permission to use, copy, modify, and/or distribute this software for any * Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above * purpose with or without fee is hereby granted, provided that the above
@ -25,7 +25,9 @@
#include "neatvnc.h" #include "neatvnc.h"
#include "common.h" #include "common.h"
#include "pixels.h" #include "pixels.h"
#include "stream.h"
#include "config.h" #include "config.h"
#include "logging.h"
#include <stdlib.h> #include <stdlib.h>
#include <unistd.h> #include <unistd.h>
@ -37,6 +39,10 @@
#include <pixman.h> #include <pixman.h>
#include <pthread.h> #include <pthread.h>
#ifdef ENABLE_TLS
#include <gnutls/gnutls.h>
#endif
#ifndef DRM_FORMAT_INVALID #ifndef DRM_FORMAT_INVALID
#define DRM_FORMAT_INVALID 0 #define DRM_FORMAT_INVALID 0
#endif #endif
@ -74,38 +80,25 @@ static const char* fourcc_to_string(uint32_t fourcc)
return buffer; return buffer;
} }
static void allocate_read_buffer(uv_handle_t* handle, size_t suggested_size, static void client_close(struct nvnc_client* client)
uv_buf_t* buf)
{ {
(void)suggested_size; log_debug("client_close(%p): ref %d\n", client, client->ref);
buf->base = malloc(READ_BUFFER_SIZE);
buf->len = buf->base ? READ_BUFFER_SIZE : 0;
}
static void cleanup_client(uv_handle_t* handle)
{
struct nvnc_client* client = container_of(
(uv_tcp_t*)handle, struct nvnc_client, stream_handle);
nvnc_client_fn fn = client->cleanup_fn; nvnc_client_fn fn = client->cleanup_fn;
if (fn) if (fn)
fn(client); fn(client);
deflateEnd(&client->z_stream);
LIST_REMOVE(client, link); LIST_REMOVE(client, link);
stream_destroy(client->net_stream);
deflateEnd(&client->z_stream);
pixman_region_fini(&client->damage); pixman_region_fini(&client->damage);
free(client); free(client);
} }
static inline void client_close(struct nvnc_client* client)
{
uv_close((uv_handle_t*)&client->stream_handle, cleanup_client);
}
static inline void client_unref(struct nvnc_client* client) static inline void client_unref(struct nvnc_client* client)
{ {
assert(client->ref > 0);
if (--client->ref == 0) if (--client->ref == 0)
client_close(client); client_close(client);
} }
@ -115,11 +108,11 @@ static inline void client_ref(struct nvnc_client* client)
++client->ref; ++client->ref;
} }
static void close_after_write(uv_write_t* req, int status) static void close_after_write(void* userdata, enum stream_req_status status)
{ {
struct nvnc_client* client = container_of( struct nvnc_client* client = userdata;
(uv_tcp_t*)req->handle, struct nvnc_client, stream_handle); log_debug("close_after_write(%p): ref %d\n", client, client->ref);
stream_close(client->net_stream);
client_unref(client); client_unref(client);
} }
@ -137,15 +130,18 @@ static int handle_unsupported_version(struct nvnc_client* client)
reason->length = htonl(strlen(reason_string)); reason->length = htonl(strlen(reason_string));
(void)strcmp(reason->message, reason_string); (void)strcmp(reason->message, reason_string);
vnc__write((uv_stream_t*)&client->stream_handle, buffer, struct rcbuf* payload =
1 + sizeof(*reason) + strlen(reason_string), rcbuf_from_mem(buffer,
close_after_write); 1 + sizeof(*reason) + strlen(reason_string));
stream_write(client->net_stream, payload, close_after_write, client);
return 0; return 0;
} }
static int on_version_message(struct nvnc_client* client) static int on_version_message(struct nvnc_client* client)
{ {
struct nvnc* server = client->server;
if (client->buffer_len - client->buffer_index < 12) if (client->buffer_len - client->buffer_index < 12)
return 0; return 0;
@ -156,17 +152,17 @@ static int on_version_message(struct nvnc_client* client)
if (strcmp(RFB_VERSION_MESSAGE, version_string) != 0) if (strcmp(RFB_VERSION_MESSAGE, version_string) != 0)
return handle_unsupported_version(client); return handle_unsupported_version(client);
/* clang-format off */ struct rfb_security_types_msg security = { 0 };
const static struct rfb_security_types_msg security = { security.n = 1;
.n = 1, security.types[0] = RFB_SECURITY_TYPE_NONE;
.types = {
RFB_SECURITY_TYPE_NONE,
},
};
/* clang-format on */
vnc__write((uv_stream_t*)&client->stream_handle, &security, #ifdef ENABLE_TLS
sizeof(security), NULL); if (server->auth_fn)
security.types[0] = RFB_SECURITY_TYPE_VENCRYPT;
#endif
struct rcbuf* payload = rcbuf_from_mem(&security, sizeof(security));
stream_write(client->net_stream, payload, NULL, NULL);
client->state = VNC_CLIENT_STATE_WAITING_FOR_SECURITY; client->state = VNC_CLIENT_STATE_WAITING_FOR_SECURITY;
return 12; return 12;
@ -189,13 +185,133 @@ static int handle_invalid_security_type(struct nvnc_client* client)
reason->length = htonl(strlen(reason_string)); reason->length = htonl(strlen(reason_string));
(void)strcmp(reason->message, reason_string); (void)strcmp(reason->message, reason_string);
vnc__write((uv_stream_t*)&client->stream_handle, buffer, struct rcbuf* payload =
sizeof(*result) + sizeof(*reason) + strlen(reason_string), rcbuf_from_mem(buffer, sizeof(*result) + sizeof(*reason) +
close_after_write); strlen(reason_string));
stream_write(client->net_stream, payload, close_after_write, client);
return 0; return 0;
} }
static int security_handshake_ok(struct nvnc_client* client)
{
uint32_t result = htonl(RFB_SECURITY_HANDSHAKE_OK);
struct rcbuf* payload = rcbuf_from_mem(&result, sizeof(result));
return stream_write(client->net_stream, payload, NULL, NULL);
}
static int send_byte(struct nvnc_client* client, uint8_t value)
{
struct rcbuf* payload = rcbuf_from_mem(&value, sizeof(value));
return stream_write(client->net_stream, payload, NULL, NULL);
}
#ifdef ENABLE_TLS
static int vencrypt_send_version(struct nvnc_client* client)
{
struct rfb_vencrypt_version_msg msg = {
.major = 0,
.minor = 2,
};
struct rcbuf* payload = rcbuf_from_mem(&msg, sizeof(msg));
return stream_write(client->net_stream, payload, NULL, NULL);
}
static int on_vencrypt_version_message(struct nvnc_client* client)
{
struct rfb_vencrypt_version_msg* msg =
(struct rfb_vencrypt_version_msg*)&client->msg_buffer[client->buffer_index];
if (client->buffer_len - client->buffer_index < sizeof(*msg))
return 0;
if (msg->major != 0 || msg->minor != 2) {
// TODO: Say unsupported vencrypt type in message
handle_invalid_security_type(client);
return sizeof(*msg);
}
send_byte(client, 0);
struct rfb_vencrypt_subtypes_msg result = { .n = 1, };
result.types[0] = htonl(RFB_VENCRYPT_X509_PLAIN);
struct rcbuf* payload = rcbuf_from_mem(&result, sizeof(result));
stream_write(client->net_stream, payload, NULL, NULL);
client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE;
return sizeof(*msg);
}
static int on_vencrypt_subtype_message(struct nvnc_client* client)
{
uint32_t* msg = (uint32_t*)&client->msg_buffer[client->buffer_index];
if (client->buffer_len - client->buffer_index < sizeof(*msg))
return 0;
enum rfb_vencrypt_subtype subtype = ntohl(*msg);
if (subtype != RFB_VENCRYPT_X509_PLAIN) {
send_byte(client, 0); // TODO Close after write
stream_close(client->net_stream);
client_unref(client);
return sizeof(*msg);
}
send_byte(client, 1);
if (stream_upgrade_to_tls(client->net_stream, client->server->tls_creds) < 0) {
stream_close(client->net_stream);
client_unref(client);
return sizeof(*msg);
}
client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH;
return sizeof(*msg);
}
static int on_vencrypt_plain_auth_message(struct nvnc_client* client)
{
struct nvnc* server = client->server;
struct rfb_vencrypt_plain_auth_msg* msg =
(void*)(client->msg_buffer + client->buffer_index);
if (client->buffer_len - client->buffer_index < sizeof(*msg))
return 0;
uint32_t ulen = ntohl(msg->username_len);
uint32_t plen = ntohl(msg->password_len);
if (client->buffer_len - client->buffer_index < sizeof(*msg) + ulen + plen)
return 0;
char username[256];
char password[256];
memcpy(username, msg->text, MIN(ulen, sizeof(username) - 1));
memcpy(password, msg->text + ulen, MIN(plen, sizeof(password) - 1));
username[MIN(ulen, sizeof(username) - 1)] = '\0';
password[MIN(plen, sizeof(password) - 1)] = '\0';
if (server->auth_fn(username, password, server->auth_ud)) {
log_debug("User \"%s\" authenticated\n", username);
security_handshake_ok(client);
client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT;
} else {
log_debug("User \"%s\" rejected\n", username);
handle_invalid_security_type(client); // TODO say wrong auth
}
return sizeof(*msg) + ulen + plen;
}
#endif
static int on_security_message(struct nvnc_client* client) static int on_security_message(struct nvnc_client* client)
{ {
if (client->buffer_len - client->buffer_index < 1) if (client->buffer_len - client->buffer_index < 1)
@ -203,25 +319,38 @@ static int on_security_message(struct nvnc_client* client)
uint8_t type = client->msg_buffer[client->buffer_index]; uint8_t type = client->msg_buffer[client->buffer_index];
if (type != RFB_SECURITY_TYPE_NONE) switch (type) {
return handle_invalid_security_type(client); case RFB_SECURITY_TYPE_NONE:
security_handshake_ok(client);
client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT;
break;
#ifdef ENABLE_TLS
case RFB_SECURITY_TYPE_VENCRYPT:
vencrypt_send_version(client);
client->state = VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION;
break;
#endif
default:
handle_invalid_security_type(client);
break;
}
enum rfb_security_handshake_result result =
htonl(RFB_SECURITY_HANDSHAKE_OK);
vnc__write((uv_stream_t*)&client->stream_handle, &result,
sizeof(result), NULL);
client->state = VNC_CLIENT_STATE_WAITING_FOR_INIT;
return sizeof(type); return sizeof(type);
} }
static void disconnect_all_other_clients(struct nvnc_client* client) static void disconnect_all_other_clients(struct nvnc_client* client)
{ {
struct nvnc_client* node; struct nvnc_client* node;
LIST_FOREACH (node, &client->server->clients, link) struct nvnc_client* tmp;
if (node != client)
client_unref(client); LIST_FOREACH_SAFE (node, &client->server->clients, link, tmp)
if (node != client) {
log_debug("disconnect other client %p (ref %d)\n",
node, node->ref);
stream_close(node->net_stream);
client_unref(node);
}
} }
static void send_server_init_message(struct nvnc_client* client) static void send_server_init_message(struct nvnc_client* client)
@ -234,6 +363,7 @@ static void send_server_init_message(struct nvnc_client* client)
struct rfb_server_init_msg* msg = calloc(1, size); struct rfb_server_init_msg* msg = calloc(1, size);
if (!msg) { if (!msg) {
stream_close(client->net_stream);
client_unref(client); client_unref(client);
return; return;
} }
@ -244,6 +374,7 @@ static void send_server_init_message(struct nvnc_client* client)
int rc = rfb_pixfmt_from_fourcc(&msg->pixel_format, display->pixfmt); int rc = rfb_pixfmt_from_fourcc(&msg->pixel_format, display->pixfmt);
if (rc < 0) { if (rc < 0) {
stream_close(client->net_stream);
client_unref(client); client_unref(client);
return; return;
} }
@ -252,9 +383,8 @@ static void send_server_init_message(struct nvnc_client* client)
msg->pixel_format.green_max = htons(msg->pixel_format.green_max); msg->pixel_format.green_max = htons(msg->pixel_format.green_max);
msg->pixel_format.blue_max = htons(msg->pixel_format.blue_max); msg->pixel_format.blue_max = htons(msg->pixel_format.blue_max);
vnc__write((uv_stream_t*)&client->stream_handle, msg, size, NULL); struct rcbuf* payload = rcbuf_new(msg, size);
stream_write(client->net_stream, payload, NULL, NULL);
free(msg);
} }
static int on_init_message(struct nvnc_client* client) static int on_init_message(struct nvnc_client* client)
@ -288,6 +418,7 @@ static int on_client_set_pixel_format(struct nvnc_client* client)
if (!fmt->true_colour_flag) { if (!fmt->true_colour_flag) {
/* We don't really know what to do with color maps right now */ /* We don't really know what to do with color maps right now */
stream_close(client->net_stream);
client_unref(client); client_unref(client);
return 0; return 0;
} }
@ -343,7 +474,7 @@ static void process_fb_update_requests(struct nvnc_client* client)
if (!client->server->frame) if (!client->server->frame)
return; return;
if (uv_is_closing((uv_handle_t*)&client->stream_handle)) if (client->net_stream->state == STREAM_STATE_CLOSED)
return; return;
if (!pixman_region_not_empty(&client->damage)) if (!pixman_region_not_empty(&client->damage))
@ -474,6 +605,9 @@ static int on_client_message(struct nvnc_client* client)
return on_client_cut_text(client); return on_client_cut_text(client);
} }
log_debug("Got uninterpretable message from client: %p (ref %d)\n",
client, client->ref);
stream_close(client->net_stream);
client_unref(client); client_unref(client);
return 0; return 0;
} }
@ -482,14 +616,21 @@ static int try_read_client_message(struct nvnc_client* client)
{ {
switch (client->state) { switch (client->state) {
case VNC_CLIENT_STATE_ERROR: case VNC_CLIENT_STATE_ERROR:
client_unref(client); return client->buffer_len - client->buffer_index;
return 0;
case VNC_CLIENT_STATE_WAITING_FOR_VERSION: case VNC_CLIENT_STATE_WAITING_FOR_VERSION:
return on_version_message(client); return on_version_message(client);
case VNC_CLIENT_STATE_WAITING_FOR_SECURITY: case VNC_CLIENT_STATE_WAITING_FOR_SECURITY:
return on_security_message(client); return on_security_message(client);
case VNC_CLIENT_STATE_WAITING_FOR_INIT: case VNC_CLIENT_STATE_WAITING_FOR_INIT:
return on_init_message(client); return on_init_message(client);
#ifdef ENABLE_TLS
case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_VERSION:
return on_vencrypt_version_message(client);
case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_SUBTYPE:
return on_vencrypt_subtype_message(client);
case VNC_CLIENT_STATE_WAITING_FOR_VENCRYPT_PLAIN_AUTH:
return on_vencrypt_plain_auth_message(client);
#endif
case VNC_CLIENT_STATE_READY: case VNC_CLIENT_STATE_READY:
return on_client_message(client); return on_client_message(client);
} }
@ -498,18 +639,35 @@ static int try_read_client_message(struct nvnc_client* client)
return 0; return 0;
} }
static void on_client_read(uv_stream_t* stream, ssize_t n_read, static void on_client_event(struct stream* stream, enum stream_event event)
const uv_buf_t* buf)
{ {
struct nvnc_client* client = container_of( struct nvnc_client* client = stream->userdata;
(uv_tcp_t*)stream, struct nvnc_client, stream_handle);
assert(client->net_stream == stream);
if (event == STREAM_EVENT_REMOTE_CLOSED) {
log_debug("Client %p (%d) hung up\n", client, client->ref);
stream_close(stream);
client_unref(client);
return;
}
#define BUF_SIZE 16384
char* buf = malloc(BUF_SIZE);
ssize_t n_read = stream_read(stream, buf, BUF_SIZE);
#undef BUF_SIZE
if (n_read == 0) if (n_read == 0)
goto done; goto done;
if (n_read < 0) { if (n_read < 0) {
uv_read_stop(stream); if (errno != EAGAIN) {
client_unref(client); log_debug("Client connection error: %p (ref %d)\n",
client, client->ref);
stream_close(stream);
client_unref(client);
}
goto done; goto done;
} }
@ -518,12 +676,13 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read,
if ((size_t)n_read > MSG_BUFFER_SIZE - client->buffer_len) { if ((size_t)n_read > MSG_BUFFER_SIZE - client->buffer_len) {
/* Can't handle this. Let's just give up */ /* Can't handle this. Let's just give up */
client->state = VNC_CLIENT_STATE_ERROR; client->state = VNC_CLIENT_STATE_ERROR;
uv_read_stop(stream); log_debug("Client whoops: %p (ref %d)\n", client, client->ref);
stream_close(client->net_stream);
client_unref(client); client_unref(client);
goto done; goto done;
} }
memcpy(client->msg_buffer + client->buffer_len, buf->base, n_read); memcpy(client->msg_buffer + client->buffer_len, buf, n_read);
client->buffer_len += n_read; client->buffer_len += n_read;
while (1) { while (1) {
@ -532,6 +691,7 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read,
break; break;
client->buffer_index += rc; client->buffer_index += rc;
} }
assert(client->buffer_index <= client->buffer_len); assert(client->buffer_index <= client->buffer_len);
@ -542,13 +702,13 @@ static void on_client_read(uv_stream_t* stream, ssize_t n_read,
client->buffer_index = 0; client->buffer_index = 0;
done: done:
free(buf->base); free(buf);
} }
static void on_connection(uv_stream_t* server_stream, int status) static void on_connection(uv_poll_t* poll_handle, int status, int events)
{ {
struct nvnc* server = struct nvnc* server =
container_of((uv_tcp_t*)server_stream, struct nvnc, tcp_handle); container_of(poll_handle, struct nvnc, poll_handle);
struct nvnc_client* client = calloc(1, sizeof(*client)); struct nvnc_client* client = calloc(1, sizeof(*client));
if (!client) if (!client)
@ -557,6 +717,14 @@ static void on_connection(uv_stream_t* server_stream, int status)
client->ref = 1; client->ref = 1;
client->server = server; client->server = server;
int fd = accept(server->fd, NULL, 0);
if (fd < 0)
goto accept_failure;
client->net_stream = stream_new(fd, on_client_event, client);
if (!client->net_stream)
goto stream_failure;
int rc = deflateInit2(&client->z_stream, int rc = deflateInit2(&client->z_stream,
/* compression level: */ 1, /* compression level: */ 1,
/* method: */ Z_DEFLATED, /* method: */ Z_DEFLATED,
@ -564,51 +732,35 @@ static void on_connection(uv_stream_t* server_stream, int status)
/* mem level: */ 9, /* mem level: */ 9,
/* strategy: */ Z_DEFAULT_STRATEGY); /* strategy: */ Z_DEFAULT_STRATEGY);
if (rc != Z_OK) { if (rc != Z_OK)
free(client); goto deflate_failure;
return;
}
pixman_region_init(&client->damage); pixman_region_init(&client->damage);
uv_tcp_init(uv_default_loop(), &client->stream_handle); struct rcbuf* payload = rcbuf_from_string(RFB_VERSION_MESSAGE);
if (!payload)
goto payload_failure;
uv_accept((uv_stream_t*)&server->tcp_handle, stream_write(client->net_stream, payload, NULL, NULL);
(uv_stream_t*)&client->stream_handle);
uv_read_start((uv_stream_t*)&client->stream_handle,
allocate_read_buffer, on_client_read);
vnc__write((uv_stream_t*)&client->stream_handle, RFB_VERSION_MESSAGE,
strlen(RFB_VERSION_MESSAGE), NULL);
LIST_INSERT_HEAD(&server->clients, client, link); LIST_INSERT_HEAD(&server->clients, client, link);
client->state = VNC_CLIENT_STATE_WAITING_FOR_VERSION; client->state = VNC_CLIENT_STATE_WAITING_FOR_VERSION;
}
int vnc_server_init(struct nvnc* self, const char* address, int port) log_debug("New client connection: %p (ref %d)\n", client, client->ref);
{
LIST_INIT(&self->clients);
uv_tcp_init(uv_default_loop(), &self->tcp_handle); return;
struct sockaddr_in addr = {0}; payload_failure:
addr.sin_family = AF_INET; deflateEnd(&client->z_stream);
addr.sin_addr.s_addr = inet_addr(address); deflate_failure:
addr.sin_port = htons(port); stream_destroy(client->net_stream);
stream_failure:
close(fd);
accept_failure:
free(client);
if (uv_tcp_bind(&self->tcp_handle, (const struct sockaddr*)&addr, 0) < 0) log_debug("Failed to accept a connection\n");
goto failure;
if (uv_listen((uv_stream_t*)&self->tcp_handle, 16, on_connection) < 0)
goto failure;
return 0;
failure:
uv_unref((uv_handle_t*)&self->tcp_handle);
return -1;
} }
EXPORT EXPORT
@ -622,22 +774,32 @@ struct nvnc* nvnc_open(const char* address, uint16_t port)
LIST_INIT(&self->clients); LIST_INIT(&self->clients);
uv_tcp_init(uv_default_loop(), &self->tcp_handle); self->fd = socket(AF_INET, SOCK_STREAM, 0);
if (self->fd < 0)
return NULL;
int one = 1;
if (setsockopt(self->fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(int)) < 0)
goto failure;
struct sockaddr_in addr = {0}; struct sockaddr_in addr = {0};
addr.sin_family = AF_INET; addr.sin_family = AF_INET;
addr.sin_addr.s_addr = inet_addr(address); addr.sin_addr.s_addr = inet_addr(address);
addr.sin_port = htons(port); addr.sin_port = htons(port);
if (uv_tcp_bind(&self->tcp_handle, (const struct sockaddr*)&addr, 0) < 0) if (bind(self->fd, (const struct sockaddr*)&addr, sizeof(addr)) < 0)
goto failure; goto failure;
if (uv_listen((uv_stream_t*)&self->tcp_handle, 16, on_connection) < 0) if (listen(self->fd, 16) < 0)
goto failure; goto failure;
uv_poll_init(uv_default_loop(), &self->poll_handle, self->fd);
uv_poll_start(&self->poll_handle, UV_READABLE, on_connection);
return self; return self;
failure: failure:
uv_unref((uv_handle_t*)&self->tcp_handle); close(self->fd);
return NULL; return NULL;
} }
@ -649,19 +811,28 @@ void nvnc_close(struct nvnc* self)
if (self->frame) if (self->frame)
nvnc_fb_unref(self->frame); nvnc_fb_unref(self->frame);
LIST_FOREACH (client, &self->clients, link) struct nvnc_client* tmp;
LIST_FOREACH_SAFE (client, &self->clients, link, tmp)
client_unref(client); client_unref(client);
uv_unref((uv_handle_t*)&self->tcp_handle); uv_poll_stop(&self->poll_handle);
close(self->fd);
#ifdef ENABLE_TLS
if (self->tls_creds) {
gnutls_certificate_free_credentials(self->tls_creds);
gnutls_global_deinit();
}
#endif
free(self); free(self);
} }
static void on_write_frame_done(uv_write_t* req, int status) static void on_write_frame_done(void* userdata, enum stream_req_status status)
{ {
struct vnc_write_request* rq = (struct vnc_write_request*)req; struct nvnc_client* client = userdata;
struct nvnc_client* client = rq->userdata;
client->is_updating = false; client->is_updating = false;
free(rq->buffer.base); client_unref(client);
} }
enum rfb_encodings choose_frame_encoding(struct nvnc_client* client) enum rfb_encodings choose_frame_encoding(struct nvnc_client* client)
@ -689,7 +860,7 @@ void do_client_update_fb(uv_work_t* work)
enum rfb_encodings encoding = choose_frame_encoding(client); enum rfb_encodings encoding = choose_frame_encoding(client);
if (encoding == -1) { if (encoding == -1) {
uv_read_stop((uv_stream_t*)&client->stream_handle); stream_close(client->net_stream);
client_unref(client); client_unref(client);
return; return;
} }
@ -725,21 +896,28 @@ void on_client_update_fb_done(uv_work_t* work, int status)
struct fb_update_work* update = (void*)work; struct fb_update_work* update = (void*)work;
struct nvnc_client* client = update->client; struct nvnc_client* client = update->client;
struct nvnc* server = client->server;
struct vec* frame = &update->frame; struct vec* frame = &update->frame;
if (!uv_is_closing((uv_handle_t*)&client->stream_handle)) client_ref(client);
vnc__write2((uv_stream_t*)&client->stream_handle, frame->data,
frame->len, on_write_frame_done, client);
else if (client->net_stream->state != STREAM_STATE_CLOSED) {
struct rcbuf* payload = rcbuf_new(frame->data, frame->len);
stream_write(client->net_stream, payload, on_write_frame_done,
client);
} else {
client->is_updating = false; client->is_updating = false;
vec_destroy(frame);
client_unref(client);
}
client->n_pending_requests--; client->n_pending_requests--;
process_fb_update_requests(client); process_fb_update_requests(client);
nvnc_fb_unref(update->fb); nvnc_fb_unref(update->fb);
client_unref(client);
pixman_region_fini(&update->region); pixman_region_fini(&update->region);
client_unref(client);
free(update); free(update);
} }
@ -798,8 +976,9 @@ int nvnc_feed_frame(struct nvnc* self, struct nvnc_fb* fb,
self->frame = fb; self->frame = fb;
nvnc_fb_ref(self->frame); nvnc_fb_ref(self->frame);
LIST_FOREACH (client, &self->clients, link) { struct nvnc_client* tmp;
if (uv_is_closing((uv_handle_t*)&client->stream_handle)) LIST_FOREACH_SAFE (client, &self->clients, link, tmp) {
if (client->net_stream->state == STREAM_STATE_CLOSED)
continue; continue;
pixman_region_union(&client->damage, &client->damage, pixman_region_union(&client->damage, &client->damage,
@ -878,3 +1057,61 @@ void nvnc_set_name(struct nvnc* self, const char* name)
strncpy(self->display.name, name, sizeof(self->display.name)); strncpy(self->display.name, name, sizeof(self->display.name));
self->display.name[sizeof(self->display.name) - 1] = '\0'; self->display.name[sizeof(self->display.name) - 1] = '\0';
} }
EXPORT
bool nvnc_has_auth(void)
{
#ifdef ENABLE_TLS
return true;
#else
return false;
#endif
}
EXPORT
int nvnc_enable_auth(struct nvnc* self, const char* privkey_path,
const char* cert_path, nvnc_auth_fn auth_fn,
void* userdata)
{
#ifdef ENABLE_TLS
if (self->tls_creds)
return -1;
/* Note: This is globally reference counted, so we don't need to worry
* about messing with other libraries.
*/
int rc = gnutls_global_init();
if (rc != GNUTLS_E_SUCCESS) {
log_error("GnuTLS: Failed to initialise: %s\n",
gnutls_strerror(rc));
return -1;
}
rc = gnutls_certificate_allocate_credentials(&self->tls_creds);
if (rc != GNUTLS_E_SUCCESS) {
log_error("GnuTLS: Failed to allocate credentials: %s\n",
gnutls_strerror(rc));
goto cert_alloc_failure;
}
rc = gnutls_certificate_set_x509_key_file(
self->tls_creds, cert_path, privkey_path, GNUTLS_X509_FMT_PEM);
if (rc != GNUTLS_E_SUCCESS) {
log_error("GnuTLS: Failed to load credentials: %s\n",
gnutls_strerror(rc));
goto cert_set_failure;
}
self->auth_fn = auth_fn;
self->auth_ud = userdata;
return 0;
cert_set_failure:
gnutls_certificate_free_credentials(self->tls_creds);
self->tls_creds = NULL;
cert_alloc_failure:
gnutls_global_deinit();
#endif
return -1;
}