diff --git a/include/stream.h b/include/stream.h index 68c7e71..24b5713 100644 --- a/include/stream.h +++ b/include/stream.h @@ -19,6 +19,7 @@ #include "config.h" #include "sys/queue.h" #include "rcbuf.h" +#include "vec.h" #include #include @@ -45,6 +46,7 @@ enum stream_event { }; struct stream; +struct crypto_cipher; typedef void (*stream_event_fn)(struct stream*, enum stream_event); typedef void (*stream_req_fn)(void*, enum stream_req_status); @@ -68,6 +70,7 @@ struct stream_impl { stream_req_fn on_done, void* userdata); int (*send_first)(struct stream*, struct rcbuf* payload); void (*exec_and_send)(struct stream*, stream_exec_fn, void* userdata); + int (*install_cipher)(struct stream*, struct crypto_cipher*); }; struct stream { @@ -86,6 +89,9 @@ struct stream { uint32_t bytes_received; bool cork; + + struct crypto_cipher* cipher; + struct vec tmp_buf; }; #ifdef ENABLE_WEBSOCKET @@ -108,3 +114,5 @@ void stream_exec_and_send(struct stream* self, stream_exec_fn, void* userdata); #ifdef ENABLE_TLS int stream_upgrade_to_tls(struct stream* self, void* context); #endif + +int stream_install_cipher(struct stream* self, struct crypto_cipher* cipher); diff --git a/src/stream-gnutls.c b/src/stream-gnutls.c index 2b7afb1..2b4b452 100644 --- a/src/stream-gnutls.c +++ b/src/stream-gnutls.c @@ -82,6 +82,8 @@ static int stream_gnutls__flush(struct stream* base) { struct stream_gnutls* self = (struct stream_gnutls*)base; while (!TAILQ_EMPTY(&self->base.send_queue)) { + assert(self->base.state != STREAM_STATE_CLOSED); + struct stream_req* req = TAILQ_FIRST(&self->base.send_queue); ssize_t rc = gnutls_record_send(self->session, diff --git a/src/stream-tcp.c b/src/stream-tcp.c index eb94c1d..07af480 100644 --- a/src/stream-tcp.c +++ b/src/stream-tcp.c @@ -31,10 +31,23 @@ #include "stream.h" #include "stream-common.h" #include "sys/queue.h" +#include "crypto.h" +#include "neatvnc.h" static_assert(sizeof(struct stream) <= STREAM_ALLOC_SIZE, "struct stream has grown too large, increase STREAM_ALLOC_SIZE"); +static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload) +{ + uint8_t* ciphertext = malloc(payload->size); + assert(ciphertext); + crypto_cipher_encrypt(self->cipher, ciphertext, payload->payload, + payload->size); + struct rcbuf* result = rcbuf_new(ciphertext, payload->size); + rcbuf_unref(payload); + return result; +} + static int stream_tcp_close(struct stream* self) { if (self->state == STREAM_STATE_CLOSED) @@ -57,6 +70,8 @@ static int stream_tcp_close(struct stream* self) static void stream_tcp_destroy(struct stream* self) { + vec_destroy(&self->tmp_buf); + crypto_cipher_del(self->cipher); stream_close(self); aml_unref(self->handler); free(self); @@ -76,7 +91,9 @@ static int stream_tcp__flush(struct stream* self) if (req->exec) { if (req->payload) rcbuf_unref(req->payload); - req->payload = req->exec(self, req->userdata); + struct rcbuf* payload = req->exec(self, req->userdata); + req->payload = self->cipher ? + encrypt_rcbuf(self, payload) : payload; } iov[n_msgs].iov_base = req->payload->payload; @@ -187,11 +204,27 @@ static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) if (self->state != STREAM_STATE_NORMAL) return -1; - ssize_t rc = read(self->fd, dst, size); + uint8_t* read_buffer = dst; + + if (self->cipher) { + vec_reserve(&self->tmp_buf, size); + read_buffer = self->tmp_buf.data; + } + + ssize_t rc = read(self->fd, read_buffer, size); if (rc == 0) stream__remote_closed(self); if (rc > 0) self->bytes_received += rc; + + if (rc > 0 && self->cipher && !crypto_cipher_decrypt(self->cipher, dst, + read_buffer, rc)) { + nvnc_log(NVNC_LOG_ERROR, "Message authentication failed!"); + stream__remote_closed(self); + errno = EPROTO; + return -1; + } + return rc; } @@ -205,7 +238,7 @@ static int stream_tcp_send(struct stream* self, struct rcbuf* payload, if (!req) return -1; - req->payload = payload; + req->payload = self->cipher ? encrypt_rcbuf(self, payload) : payload; req->on_done = on_done; req->userdata = userdata; @@ -247,6 +280,14 @@ static void stream_tcp_exec_and_send(struct stream* self, stream_tcp__flush(self); } +static int stream_tcp_install_cipher(struct stream* self, + struct crypto_cipher* cipher) +{ + assert(!self->cipher); + self->cipher = cipher; + return 0; +} + static struct stream_impl impl = { .close = stream_tcp_close, .destroy = stream_tcp_destroy, @@ -254,6 +295,7 @@ static struct stream_impl impl = { .send = stream_tcp_send, .send_first = stream_tcp_send_first, .exec_and_send = stream_tcp_exec_and_send, + .install_cipher = stream_tcp_install_cipher, }; struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) diff --git a/src/stream.c b/src/stream.c index 457035c..01a51dd 100644 --- a/src/stream.c +++ b/src/stream.c @@ -65,3 +65,11 @@ void stream_exec_and_send(struct stream* self, stream_exec_fn exec_fn, else stream_send(self, exec_fn(self, userdata), NULL, NULL); } + +int stream_install_cipher(struct stream* self, struct crypto_cipher* cipher) +{ + if (!self->impl->install_cipher) { + return -1; + } + return self->impl->install_cipher(self, cipher); +}