diff --git a/include/common.h b/include/common.h index f54ca45..431bfbe 100644 --- a/include/common.h +++ b/include/common.h @@ -97,9 +97,16 @@ struct nvnc_client { LIST_HEAD(nvnc_client_list, nvnc_client); +enum nvnc__socket_type { + NVNC__SOCKET_TCP, + NVNC__SOCKET_UNIX, + NVNC__SOCKET_WEBSOCKET, +}; + struct nvnc { struct nvnc_common common; int fd; + enum nvnc__socket_type socket_type; struct aml_handler* poll_handle; struct nvnc_client_list clients; char name[256]; diff --git a/include/http.h b/include/http.h new file mode 100644 index 0000000..84b889a --- /dev/null +++ b/include/http.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2014-2016, Marel + * Copyright (c) 2023, Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#define URL_INDEX_MAX 32 +#define URL_QUERY_INDEX_MAX 32 +#define HTTP_FIELD_INDEX_MAX 32 + +#include + +enum http_method { + HTTP_GET = 1, + HTTP_PUT = 2, + HTTP_OPTIONS = 4, +}; + +struct http_kv { + char* key; + char* value; +}; + +struct http_req { + enum http_method method; + size_t header_length; + size_t content_length; + char* content_type; + size_t url_index; + char* url[URL_INDEX_MAX]; + size_t url_query_index; + struct http_kv url_query[URL_QUERY_INDEX_MAX]; + size_t field_index; + struct http_kv field[HTTP_FIELD_INDEX_MAX]; +}; + +int http_req_parse(struct http_req* req, const char* head); +void http_req_free(struct http_req* req); + +const char* http_req_query(struct http_req* req, const char* key); diff --git a/include/neatvnc.h b/include/neatvnc.h index 402ec1c..9e68085 100644 --- a/include/neatvnc.h +++ b/include/neatvnc.h @@ -125,6 +125,7 @@ extern const char nvnc_version[]; struct nvnc* nvnc_open(const char* addr, uint16_t port); struct nvnc* nvnc_open_unix(const char *addr); +struct nvnc* nvnc_open_websocket(const char* addr, uint16_t port); void nvnc_close(struct nvnc* self); void nvnc_add_display(struct nvnc*, struct nvnc_display*); diff --git a/include/stream.h b/include/stream.h index f8ee239..c5e0dbf 100644 --- a/include/stream.h +++ b/include/stream.h @@ -91,6 +91,10 @@ struct stream { bool cork; }; +#ifdef ENABLE_WEBSOCKET +struct stream* stream_ws_new(int fd, stream_event_fn on_event, void* userdata); +#endif + struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata); int stream_close(struct stream* self); void stream_destroy(struct stream* self); diff --git a/include/websocket.h b/include/websocket.h new file mode 100644 index 0000000..e1b831a --- /dev/null +++ b/include/websocket.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +#define WS_HEADER_MIN_SIZE 14 + +enum ws_opcode { + WS_OPCODE_CONT = 0, + WS_OPCODE_TEXT, + WS_OPCODE_BIN, + WS_OPCODE_CLOSE = 8, + WS_OPCODE_PING, + WS_OPCODE_PONG, +}; + +struct ws_frame_header { + bool fin; + enum ws_opcode opcode; + bool mask; + uint64_t payload_length; + uint8_t masking_key[4]; + size_t header_length; +}; + +ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input); + +const char *ws_opcode_name(enum ws_opcode op); + +bool ws_parse_frame_header(struct ws_frame_header* header, + const uint8_t* payload, size_t length); +void ws_apply_mask(const struct ws_frame_header* header, + uint8_t* restrict payload); +void ws_copy_payload(const struct ws_frame_header* header, + uint8_t* restrict dst, const uint8_t* restrict src, size_t len); +int ws_write_frame_header(uint8_t* dst, const struct ws_frame_header* header); diff --git a/meson.build b/meson.build index 1519f1f..03b280a 100644 --- a/meson.build +++ b/meson.build @@ -49,6 +49,7 @@ libm = cc.find_library('m', required: false) pixman = dependency('pixman-1') libturbojpeg = dependency('libturbojpeg', required: get_option('jpeg')) gnutls = dependency('gnutls', required: get_option('tls')) +nettle = dependency('nettle', required: get_option('nettle')) zlib = dependency('zlib') gbm = dependency('gbm', required: get_option('gbm')) libdrm = dependency('libdrm', required: get_option('h264')) @@ -101,6 +102,8 @@ dependencies = [ libdrm_inc, ] +enable_websocket = false + config = configuration_data() if libturbojpeg.found() @@ -114,6 +117,11 @@ if gnutls.found() config.set('ENABLE_TLS', true) endif +if nettle.found() + dependencies += nettle + enable_websocket = true +endif + if host_system == 'linux' and get_option('systemtap') and cc.has_header('sys/sdt.h') config.set('HAVE_USDT', true) endif @@ -130,6 +138,16 @@ if gbm.found() and libdrm.found() and libavcodec.found() and libavfilter.found() config.set('HAVE_LIBAVUTIL', true) endif +if enable_websocket + sources += [ + 'src/ws-handshake.c', + 'src/ws-framing.c', + 'src/http.c', + 'src/stream-ws.c', + ] + config.set('ENABLE_WEBSOCKET', true) +endif + configure_file( output: 'config.h', configuration: config, diff --git a/meson_options.txt b/meson_options.txt index 0db3eb3..ce6f81c 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -3,6 +3,7 @@ option('examples', type: 'boolean', value: false, description: 'Build examples') option('tests', type: 'boolean', value: false, description: 'Build unit tests') option('jpeg', type: 'feature', value: 'auto', description: 'Enable JPEG compression') option('tls', type: 'feature', value: 'auto', description: 'Enable encryption & authentication') +option('nettle', type: 'feature', value: 'auto', description: 'Enable nettle low level encryption library') option('systemtap', type: 'boolean', value: false, description: 'Enable tracing using sdt') option('gbm', type: 'feature', value: 'auto', description: 'Enable GBM integration') option('h264', type: 'feature', value: 'auto', description: 'Enable open h264 encoding') diff --git a/src/http.c b/src/http.c new file mode 100644 index 0000000..3d941a2 --- /dev/null +++ b/src/http.c @@ -0,0 +1,572 @@ +/* Copyright (c) 2014-2016, Marel + * Copyright (c) 2023, Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include +#include +#include "vec.h" +#include "http.h" + +enum httplex_token_type { + HTTPLEX_SOLIDUS, + HTTPLEX_CR, + HTTPLEX_LF, + HTTPLEX_WS, + HTTPLEX_LITERAL, + HTTPLEX_KEY, + HTTPLEX_VALUE, + HTTPLEX_QUERY, + HTTPLEX_AMPERSAND, + HTTPLEX_EQ, + HTTPLEX_END, +}; + +struct httplex_token { + enum httplex_token_type type; + const char* value; +}; + +enum httplex_state { + HTTPLEX_STATE_REQUEST = 0, + HTTPLEX_STATE_KEY, + HTTPLEX_STATE_VALUE, +}; + +struct httplex { + enum httplex_state state; + struct httplex_token current_token; + const char* input; + const char* pos; + const char* next_pos; + struct vec buffer; + int accepted; + int errno_; +}; + +static int httplex_init(struct httplex* self, const char* input) +{ + memset(self, 0, sizeof(*self)); + + self->input = input; + self->pos = input; + self->accepted = 1; + + if (vec_reserve(&self->buffer, 256) < 0) + return -1; + + return 0; +} + +static void httplex_destroy(struct httplex* self) +{ + vec_destroy(&self->buffer); +} + +static inline int httplex__is_literal(char c) +{ + switch (c) { + case '/': case '\r': case '\n': case ' ': case '\t': + case '?': case '&': case '=': + return 0; + } + + return isprint(c); +} + +static inline size_t httplex__literal_length(const char* str) +{ + size_t len = 0; + while (httplex__is_literal(*str++)) + ++len; + return len; +} + +static int httplex__classify_request_token(struct httplex* self) +{ + switch (*self->pos) { + case '/': + self->current_token.type = HTTPLEX_SOLIDUS; + self->next_pos = self->pos + strspn(self->pos, "/"); + return 0; + case '\r': + self->current_token.type = HTTPLEX_CR; + self->next_pos = self->pos + 1; + return 0; + case '\n': + self->current_token.type = HTTPLEX_LF; + self->next_pos = self->pos + 1; + return 0; + case '?': + self->current_token.type = HTTPLEX_QUERY; + self->next_pos = self->pos + 1; + return 0; + case '&': + self->current_token.type = HTTPLEX_AMPERSAND; + self->next_pos = self->pos + 1; + return 0; + case '=': + self->current_token.type = HTTPLEX_EQ; + self->next_pos = self->pos + 1; + return 0; + case ' ': + case '\t': + self->current_token.type = HTTPLEX_WS; + self->next_pos = self->pos + strspn(self->pos, " \t"); + return 0; + } + + if (httplex__is_literal(*self->pos)) { + self->current_token.type = HTTPLEX_LITERAL; + size_t len = httplex__literal_length(self->pos); + self->next_pos = self->pos + len; + vec_assign(&self->buffer, self->pos, len); + vec_append(&self->buffer, "", 1); + self->current_token.value = self->buffer.data; + return 0; + } + + return -1; +} + +static inline int httplex__is_key_char(char c) +{ + return isalnum(c) || c == '-'; +} + +static inline size_t httplex__key_length(const char* str) +{ + size_t len = 0; + while (httplex__is_key_char(*str++)) + ++len; + return len; +} + +static int httplex__classify_key_token(struct httplex* self) +{ + switch (*self->pos) { + case '\r': + self->current_token.type = HTTPLEX_CR; + self->next_pos = self->pos + 1; + return 0; + case '\n': + self->current_token.type = HTTPLEX_LF; + self->next_pos = self->pos + 1; + return 0; + } + + if (!httplex__is_key_char(*self->pos)) + return -1; + + size_t len = httplex__key_length(self->pos); + + if (self->pos[len] != ':') + return -1; + + len += 1; + + self->next_pos = self->pos + len; + self->next_pos += strspn(self->next_pos, " \t"); + + vec_assign(&self->buffer, self->pos, len - 1); + vec_append(&self->buffer, "", 1); + + self->current_token.type = HTTPLEX_KEY; + self->current_token.value = self->buffer.data; + return 0; +} + +static int httplex__classify_value_token(struct httplex* self) +{ + size_t len = strcspn(self->pos, "\r"); + if (strncmp(&self->pos[len], "\r\n", 2) != 0) + return -1; + + self->next_pos = self->pos + len + 2; + + vec_assign(&self->buffer, self->pos, len); + vec_append(&self->buffer, "", 1); + + self->current_token.type = HTTPLEX_VALUE; + self->current_token.value = self->buffer.data; + return 0; +} + +static int httplex__classify_token(struct httplex* self) +{ + switch (self->state) { + case HTTPLEX_STATE_REQUEST: + return httplex__classify_request_token(self); + case HTTPLEX_STATE_KEY: + return httplex__classify_key_token(self); + case HTTPLEX_STATE_VALUE: + return httplex__classify_value_token(self); + }; + + abort(); + return -1; +} + +static struct httplex_token* httplex_next_token(struct httplex* self) +{ + if (self->current_token.type == HTTPLEX_END) + return &self->current_token; + + if (!self->accepted) + return &self->current_token; + + if (self->next_pos) + self->pos = self->next_pos; + + if (httplex__classify_token(self) < 0) + return NULL; + + self->accepted = 0; + + return &self->current_token; +} + +static inline int httplex_accept_token(struct httplex* self) +{ + self->accepted = 1; + return 1; +} + +static int http__literal(struct httplex* lex, const char* str) +{ + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_LITERAL) + return 0; + + if (strcasecmp(str, tok->value) != 0) + return 0; + + return httplex_accept_token(lex); +} + +static int http__get(struct http_req* req, struct httplex* lex) +{ + if (!http__literal(lex, "GET")) + return 0; + + req->method = HTTP_GET; + return 1; +} + +static int http__put(struct http_req* req, struct httplex* lex) +{ + if (!http__literal(lex, "PUT")) + return 0; + + req->method = HTTP_PUT; + return 1; +} + +static int http__options(struct http_req* req, struct httplex* lex) +{ + if (!http__literal(lex, "OPTIONS")) + return 0; + + req->method = HTTP_OPTIONS; + return 1; +} + +static int http__method(struct http_req* req, struct httplex* lex) +{ + return http__get(req, lex) + || http__put(req, lex) + || http__options(req, lex); +} + +static int http__peek(struct httplex* lex, enum httplex_token_type type) +{ + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != type) + return 0; + + return 1; +} + +static int http__expect(struct httplex* lex, enum httplex_token_type type) +{ + return http__peek(lex, type) && httplex_accept_token(lex); +} + +static int http__version(struct httplex* lex) +{ + return http__literal(lex, "HTTP") + && http__expect(lex, HTTPLEX_SOLIDUS) + && http__literal(lex, "1.1"); +} + +static int http__url_path(struct http_req* req, struct httplex* lex) +{ + if (!http__expect(lex, HTTPLEX_SOLIDUS)) + return 0; + + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_LITERAL) + return tok->type == HTTPLEX_WS; + + if (req->url_index >= URL_INDEX_MAX) + return 0; + + char* elem = strdup(tok->value); + if (!elem) + return 0; + + req->url[req->url_index++] = elem; + + httplex_accept_token(lex); + + return http__peek(lex, HTTPLEX_SOLIDUS) + ? http__url_path(req, lex) : 1; +} + +static int http__url_query_key(struct http_req* req, struct httplex* lex) +{ + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_LITERAL) + return 0; + + if (req->url_index >= URL_INDEX_MAX) + return 0; + + char* elem = strdup(tok->value); + if (!elem) + return 0; + + req->url_query[req->url_query_index].key = elem; + + return httplex_accept_token(lex); +} + +static int http__url_query_value(struct http_req* req, struct httplex* lex) +{ + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_LITERAL) + return 0; + + if (req->url_index >= URL_INDEX_MAX) + return 0; + + char* elem = strdup(tok->value); + if (!elem) + return 0; + + req->url_query[req->url_query_index++].value = elem; + + return httplex_accept_token(lex); +} + +static int http__url_query(struct http_req* req, struct httplex* lex) +{ + return http__url_query_key(req, lex) + && http__expect(lex, HTTPLEX_EQ) + && http__url_query_value(req, lex) + && http__expect(lex, HTTPLEX_AMPERSAND) + ? http__url_query(req, lex) : 1; +} + +static int http__url(struct http_req* req, struct httplex* lex) +{ + return http__url_path(req, lex) + && http__expect(lex, HTTPLEX_QUERY) ? http__url_query(req, lex) : 1; +} + +static int http__request(struct http_req* req, struct httplex* lex) +{ + return http__method(req, lex) + && http__expect(lex, HTTPLEX_WS) + && http__url(req, lex) + && http__expect(lex, HTTPLEX_WS) + && http__version(lex) + && http__expect(lex, HTTPLEX_CR) + && http__expect(lex, HTTPLEX_LF); +} + +static int http__expect_key(struct httplex* lex, const char* key) +{ + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_KEY) + return 0; + + if (key && strcasecmp(tok->value, key) != 0) + return 0; + + return httplex_accept_token(lex); +} + +static int http__content_length(struct http_req* req, struct httplex* lex) +{ + lex->state = HTTPLEX_STATE_KEY; + if (!http__expect_key(lex, "Content-Length")) + return 0; + + lex->state = HTTPLEX_STATE_VALUE; + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_VALUE) + return 0; + + req->content_length = atoi(tok->value); + + return httplex_accept_token(lex); +} + +static int http__content_type(struct http_req* req, struct httplex* lex) +{ + lex->state = HTTPLEX_STATE_KEY; + if (!http__expect_key(lex, "Content-Type")) + return 0; + + lex->state = HTTPLEX_STATE_VALUE; + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_VALUE) + return 0; + + req->content_type = strdup(tok->value); + + return httplex_accept_token(lex); +} + +static int http__field_key(struct http_req* req, struct httplex* lex) +{ + lex->state = HTTPLEX_STATE_KEY; + + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_KEY) + return 0; + + req->field[req->field_index].key = strdup(tok->value); + + return httplex_accept_token(lex); +} + +static int http__field_value(struct http_req* req, struct httplex* lex) +{ + lex->state = HTTPLEX_STATE_VALUE; + + struct httplex_token* tok = httplex_next_token(lex); + if (!tok) + return 0; + + if (tok->type != HTTPLEX_VALUE) + return 0; + + req->field[req->field_index++].value = strdup(tok->value); + + return httplex_accept_token(lex); +} + +static int http__field_kv(struct http_req* req, struct httplex* lex) +{ + return http__field_key(req, lex) + && http__field_value(req, lex); +} + +static int http__header_kv(struct http_req* req, struct httplex* lex) +{ + return http__content_length(req, lex) + || http__content_type(req, lex) + || http__field_kv(req, lex); +} + +static int http__header(struct http_req* req, struct httplex* lex) +{ + while (http__header_kv(req, lex)); + + lex->state = HTTPLEX_STATE_KEY; + if (http__expect(lex, HTTPLEX_CR)) + return http__expect(lex, HTTPLEX_LF); + + return 1; +} + +int http_req_parse(struct http_req* req, const char* input) +{ + int rc = -1; + memset(req, 0, sizeof(*req)); + + struct httplex lex; + if (httplex_init(&lex, input) < 0) + return -1; + + if (!http__request(req, &lex)) + goto failure; + + if (!http__header(req, &lex)) + goto failure; + + req->header_length = lex.next_pos - input; + + rc = 0; +failure: + httplex_destroy(&lex); + return rc; +} + +void http_req_free(struct http_req* req) +{ + free(req->content_type); + + for (size_t i = 0; i < req->url_index; ++i) + free(req->url[i]); + + for (size_t i = 0; i < req->url_query_index; ++i) { + free(req->url_query[i].key); + free(req->url_query[i].value); + } + + for (size_t i = 0; i < req->field_index; ++i) { + free(req->field[i].key); + free(req->field[i].value); + } +} + +const char* http_req_query(struct http_req* req, const char* key) +{ + for (size_t i = 0; i < req->url_query_index; ++i) + if (strcmp(key, req->url_query[i].key) == 0) + return req->url_query[i].value; + + return NULL; +} diff --git a/src/server.c b/src/server.c index 8f24bf8..762d011 100644 --- a/src/server.c +++ b/src/server.c @@ -65,11 +65,6 @@ #define EXPORT __attribute__((visibility("default"))) -enum addrtype { - ADDRTYPE_TCP, - ADDRTYPE_UNIX, -}; - static int send_desktop_resize(struct nvnc_client* client, struct nvnc_fb* fb); static int send_qemu_key_ext_frame(struct nvnc_client* client); static enum rfb_encodings choose_frame_encoding(struct nvnc_client* client, @@ -1195,7 +1190,16 @@ static void on_connection(void* obj) record_peer_hostname(fd, client); - client->net_stream = stream_new(fd, on_client_event, client); +#ifdef ENABLE_WEBSOCKET + if (server->socket_type == NVNC__SOCKET_WEBSOCKET) + { + client->net_stream = stream_ws_new(fd, on_client_event, client); + } + else +#endif + { + client->net_stream = stream_new(fd, on_client_event, client); + } if (!client->net_stream) { nvnc_log(NVNC_LOG_WARNING, "OOM"); goto stream_failure; @@ -1326,12 +1330,14 @@ static int bind_address_unix(const char* name) return fd; } -static int bind_address(const char* name, uint16_t port, enum addrtype type) +static int bind_address(const char* name, uint16_t port, + enum nvnc__socket_type type) { switch (type) { - case ADDRTYPE_TCP: + case NVNC__SOCKET_TCP: + case NVNC__SOCKET_WEBSOCKET: return bind_address_tcp(name, port); - case ADDRTYPE_UNIX: + case NVNC__SOCKET_UNIX: return bind_address_unix(name); } @@ -1339,7 +1345,8 @@ static int bind_address(const char* name, uint16_t port, enum addrtype type) return -1; } -static struct nvnc* open_common(const char* address, uint16_t port, enum addrtype type) +static struct nvnc* open_common(const char* address, uint16_t port, + enum nvnc__socket_type type) { nvnc__log_init(); @@ -1349,6 +1356,8 @@ static struct nvnc* open_common(const char* address, uint16_t port, enum addrtyp if (!self) return NULL; + self->socket_type = type; + strcpy(self->name, DEFAULT_NAME); LIST_INIT(&self->clients); @@ -1374,7 +1383,7 @@ poll_start_failure: handle_failure: listen_failure: close(self->fd); - if (type == ADDRTYPE_UNIX) { + if (type == NVNC__SOCKET_UNIX) { unlink(address); } bind_failure: @@ -1386,13 +1395,23 @@ bind_failure: EXPORT struct nvnc* nvnc_open(const char* address, uint16_t port) { - return open_common(address, port, ADDRTYPE_TCP); + return open_common(address, port, NVNC__SOCKET_TCP); +} + +EXPORT +struct nvnc* nvnc_open_websocket(const char *address, uint16_t port) +{ +#ifdef ENABLE_WEBSOCKET + return open_common(address, port, NVNC__SOCKET_WEBSOCKET); +#else + return NULL; +#endif } EXPORT struct nvnc* nvnc_open_unix(const char* address) { - return open_common(address, 0, ADDRTYPE_UNIX); + return open_common(address, 0, NVNC__SOCKET_UNIX); } static void unlink_fd_path(int fd) diff --git a/src/stream-ws.c b/src/stream-ws.c new file mode 100644 index 0000000..2944354 --- /dev/null +++ b/src/stream-ws.c @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include "stream.h" +#include "stream-common.h" +#include "websocket.h" +#include "neatvnc.h" + +#include +#include +#include +#include +#include + +enum stream_ws_state { + STREAM_WS_STATE_HANDSHAKE = 0, + STREAM_WS_STATE_READY, +}; + +struct stream_ws { + struct stream base; + enum stream_ws_state ws_state; + struct ws_frame_header header; + enum ws_opcode current_opcode; + uint8_t read_buffer[4096]; // TODO: Is this a reasonable size? + size_t read_index; + struct stream* tcp_stream; +}; + +static int stream_ws_close(struct stream* self) +{ + struct stream_ws* ws = (struct stream_ws*)self; + self->state = STREAM_STATE_CLOSED; + return stream_close(ws->tcp_stream); +} + +static void stream_ws_destroy(struct stream* self) +{ + struct stream_ws* ws = (struct stream_ws*)self; + stream_destroy(ws->tcp_stream); + free(self); +} + +static void stream_ws_read_into_buffer(struct stream_ws* ws) +{ + ssize_t n_read = stream_read(ws->tcp_stream, + ws->read_buffer + ws->read_index, + sizeof(ws->read_buffer) - ws->read_index); + if (n_read > 0) + ws->read_index += n_read; +} + +static void stream_ws_advance_read_buffer(struct stream_ws* ws, size_t size, + size_t offset) +{ + size_t payload_len = MIN(size, ws->read_index - offset); + payload_len = MIN(payload_len, ws->header.payload_length); + + ws->read_index -= offset + payload_len; + memmove(ws->read_buffer, ws->read_buffer + offset + payload_len, + ws->read_index); + ws->header.payload_length -= payload_len; +} + +static ssize_t stream_ws_copy_payload(struct stream_ws* ws, void* dst, + size_t size, size_t offset) +{ + size_t payload_len = MIN(size, ws->read_index - offset); + payload_len = MIN(payload_len, ws->header.payload_length); + + ws_copy_payload(&ws->header, dst, ws->read_buffer + offset, payload_len); + stream_ws_advance_read_buffer(ws, size, offset); + return payload_len; +} + +static ssize_t stream_ws_process_ping(struct stream_ws* ws, size_t offset) +{ + if (offset > 0) { + // This means we're at the start, so send a header + struct ws_frame_header reply = { + .fin = true, + .opcode = WS_OPCODE_PONG, + .payload_length = ws->header.payload_length, + }; + + uint8_t buf[WS_HEADER_MIN_SIZE]; + int reply_len = ws_write_frame_header(buf, &reply); + stream_write(ws->tcp_stream, buf, reply_len, NULL, NULL); + } + + int payload_len = MIN(ws->read_index, ws->header.payload_length); + + // Feed back the payload: + stream_write(ws->tcp_stream, ws->read_buffer + offset, + payload_len, NULL, NULL); + + stream_ws_advance_read_buffer(ws, payload_len, offset); + return 0; +} + +static ssize_t stream_ws_process_payload(struct stream_ws* ws, void* dst, + size_t size, size_t offset) +{ + switch (ws->current_opcode) { + case WS_OPCODE_CONT: + // Remote end started with a continuation frame. This is + // unexpected, so we'll just close. + stream__remote_closed(ws->tcp_stream); + return 0; + case WS_OPCODE_TEXT: + // This is unexpected, but let's just ignore it... + stream_ws_advance_read_buffer(ws, SIZE_MAX, offset); + return 0; + case WS_OPCODE_BIN: + return stream_ws_copy_payload(ws, dst, size, offset); + case WS_OPCODE_CLOSE: + stream__remote_closed(ws->tcp_stream); + return 0; + case WS_OPCODE_PING: + return stream_ws_process_ping(ws, offset); + case WS_OPCODE_PONG: + // Don't care + stream_ws_advance_read_buffer(ws, SIZE_MAX, offset); + return 0; + } + return -1; +} + +/* We don't really care about framing. The binary data is just passed on as it + * arrives and it's not gathered into individual frames. + */ +static ssize_t stream_ws_read_frame(struct stream_ws* ws, void* dst, + size_t size) +{ + if (ws->header.payload_length > 0) { + nvnc_trace("Processing left-over payload chunk"); + return stream_ws_process_payload(ws, dst, size, 0); + } + + if (!ws_parse_frame_header(&ws->header, ws->read_buffer, + ws->read_index)) { + return 0; + } + + nvnc_trace("Got frame header: opcode=%s, header-len: %zu, payload-len: %zu, read-buffer-len: %zu", + ws_opcode_name(ws->header.opcode), + ws->header.header_length, ws->header.payload_length, + ws->read_index); + + if (ws->header.opcode != WS_OPCODE_CONT) { + ws->current_opcode = ws->header.opcode; + } + + // The header is located at the start of the buffer, so an offset is + // needed. + return stream_ws_process_payload(ws, dst, size, + ws->header.header_length); +} + +static ssize_t stream_ws_read_ready(struct stream_ws* ws, void* dst, + size_t size) +{ + size_t total_read = 0; + + while (true) { + ssize_t n_read = stream_ws_read_frame(ws, dst, size); + if (n_read == 0) + break; + + if (n_read < 0) { + if (errno == EAGAIN) { + break; + } + return -1; + } + + total_read += n_read; + dst += n_read; + size -= n_read; + } + + return total_read; +} + +static ssize_t stream_ws_read_handshake(struct stream_ws* ws, void* dst, + size_t size) +{ + char reply[512]; + ssize_t header_len = ws_handshake(reply, sizeof(reply), + (const char*)ws->read_buffer); + if (header_len < 0) + return 0; + + ws->tcp_stream->cork = false; + stream_send_first(ws->tcp_stream, rcbuf_from_mem(reply, strlen(reply))); + + ws->read_index -= header_len; + memmove(ws->read_buffer, ws->read_buffer + header_len, ws->read_index); + + ws->ws_state = STREAM_WS_STATE_READY; + return stream_ws_read_ready(ws, dst, size); +} + +static ssize_t stream_ws_read(struct stream* self, void* dst, size_t size) +{ + struct stream_ws* ws = (struct stream_ws*)self; + + stream_ws_read_into_buffer(ws); + if (self->state == STREAM_STATE_CLOSED) + return 0; + + switch (ws->ws_state) { + case STREAM_WS_STATE_HANDSHAKE: + return stream_ws_read_handshake(ws, dst, size); + case STREAM_WS_STATE_READY: + return stream_ws_read_ready(ws, dst, size); + } + abort(); + return -1; +} + +static int stream_ws_send(struct stream* self, struct rcbuf* payload, + stream_req_fn on_done, void* userdata) +{ + struct stream_ws* ws = (struct stream_ws*)self; + + struct ws_frame_header head = { + .fin = true, + .opcode = WS_OPCODE_BIN, + .payload_length = payload->size, + }; + + uint8_t raw_head[WS_HEADER_MIN_SIZE]; + int head_len = ws_write_frame_header(raw_head, &head); + + stream_write(ws->tcp_stream, &raw_head, head_len, NULL, NULL); + return stream_send(ws->tcp_stream, payload, on_done, userdata); +} + +static void stream_ws_event(struct stream* self, enum stream_event event) +{ + struct stream_ws* ws = self->userdata; + + if (event == STREAM_EVENT_REMOTE_CLOSED) { + ws->base.state = STREAM_STATE_CLOSED; + } + + ws->base.on_event(&ws->base, event); +} + +static struct stream_impl impl = { + .close = stream_ws_close, + .destroy = stream_ws_destroy, + .read = stream_ws_read, + .send = stream_ws_send, +}; + +struct stream* stream_ws_new(int fd, stream_event_fn on_event, void* userdata) +{ + struct stream_ws *self = calloc(1, sizeof(*self)); + if (!self) + return NULL; + + self->base.state = STREAM_STATE_NORMAL; + self->base.impl = &impl; + self->base.on_event = on_event; + self->base.userdata = userdata; + + self->tcp_stream = stream_new(fd, stream_ws_event, self); + if (!self->tcp_stream) { + free(self); + return NULL; + } + + // Don't send anything until handshake is done: + self->tcp_stream->cork = true; + + return &self->base; +} diff --git a/src/ws-framing.c b/src/ws-framing.c new file mode 100644 index 0000000..1cf36dd --- /dev/null +++ b/src/ws-framing.c @@ -0,0 +1,137 @@ +#include "websocket.h" + +#include +#include +#include +#include +#include +#include + +static inline uint64_t u64_from_network_order(uint64_t x) +{ +#if __BYTE_ORDER__ == __BIG_ENDIAN__ + return x; +#else + return __builtin_bswap64(x); +#endif +} + +static inline uint64_t u64_to_network_order(uint64_t x) +{ +#if __BYTE_ORDER__ == __BIG_ENDIAN__ + return x; +#else + return __builtin_bswap64(x); +#endif +} + +const char *ws_opcode_name(enum ws_opcode op) +{ + switch (op) { + case WS_OPCODE_CONT: return "cont"; + case WS_OPCODE_TEXT: return "text"; + case WS_OPCODE_BIN: return "bin"; + case WS_OPCODE_CLOSE: return "close"; + case WS_OPCODE_PING: return "ping"; + case WS_OPCODE_PONG: return "pong"; + } + return "INVALID"; +} + +bool ws_parse_frame_header(struct ws_frame_header* header, + const uint8_t* payload, size_t length) +{ + if (length < 2) + return false; + + int i = 0; + + header->fin = !!(payload[i] & 0x80); + header->opcode = (payload[i++] & 0x0f); + header->mask = !!(payload[i] & 0x80); + header->payload_length = payload[i++] & 0x7f; + + if (header->payload_length == 126) { + if (length - i < 2) + return false; + + uint16_t value = 0; + memcpy(&value, &payload[i], 2); + header->payload_length = ntohs(value); + i += 2; + } else if (header->payload_length == 127) { + if (length - i < 8) + return false; + + uint64_t value = 0; + memcpy(&value, &payload[i], 8); + header->payload_length = u64_from_network_order(value); + i += 8; + } + + if (header->mask) { + if (length - i < 4) + return false; + + memcpy(header->masking_key, &payload[i], 4); + i += 4; + } + + header->header_length = i; + + return true; +} + +void ws_apply_mask(const struct ws_frame_header* header, + uint8_t* restrict payload) +{ + assert(header->mask); + + uint64_t len = header->payload_length; + const uint8_t* restrict key = header->masking_key; + + for (uint64_t i = 0; i < len; ++i) { + payload[i] ^= key[i % 4]; + } +} + +void ws_copy_payload(const struct ws_frame_header* header, + uint8_t* restrict dst, const uint8_t* restrict src, size_t len) +{ + if (!header->mask) { + memcpy(dst, src, len); + return; + } + + const uint8_t* restrict key = header->masking_key; + for (uint64_t i = 0; i < len; ++i) { + dst[i] = src[i] ^ key[i % 4]; + } +} + +int ws_write_frame_header(uint8_t* dst, const struct ws_frame_header* header) +{ + int i = 0; + dst[i++] = ((uint8_t)header->fin << 7) | (header->opcode); + + if (header->payload_length <= 125) { + dst[i++] = ((uint8_t)header->mask << 7) | header->payload_length; + } else if (header->payload_length <= UINT16_MAX) { + dst[i++] = ((uint8_t)header->mask << 7) | 126; + uint16_t be = htons(header->payload_length); + memcpy(&dst[i], &be, 2); + i += 2; + } else { + dst[i++] = ((uint8_t)header->mask << 7) | 127; + uint64_t be = u64_to_network_order(header->payload_length); + memcpy(&dst[i], &be, 8); + i += 8; + } + + if (header->mask) { + memcpy(dst, header->masking_key, 4); + i += 4; + } + + return i; +} diff --git a/src/ws-handshake.c b/src/ws-handshake.c new file mode 100644 index 0000000..3f43aa3 --- /dev/null +++ b/src/ws-handshake.c @@ -0,0 +1,57 @@ +#include "websocket.h" +#include "http.h" + +#include +#include + +#include +#include +#include +#include + +static const char magic_uuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +// TODO: Do some more sanity checks on the input +ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input) +{ + bool ok = false; + struct http_req req = {}; + if (http_req_parse(&req, input) < 0) + return -1; + + const char *challenge = NULL; + for (size_t i = 0; i < req.field_index; ++i) { + if (strcasecmp(req.field[i].key, "Sec-WebSocket-Key") == 0) { + challenge = req.field[i].value; + } + } + + if (!challenge) + goto failure; + + struct sha1_ctx ctx; + sha1_init(&ctx); + sha1_update(&ctx, strlen(challenge), (const uint8_t*)challenge); + sha1_update(&ctx, strlen(magic_uuid), (const uint8_t*)magic_uuid); + + uint8_t hash[SHA1_DIGEST_SIZE]; + sha1_digest(&ctx, sizeof(hash), hash); + + char response[BASE64_ENCODE_RAW_LENGTH(SHA1_DIGEST_SIZE) + 1] = {}; + base64_encode_raw(response, SHA1_DIGEST_SIZE, hash); + + size_t len = snprintf(output, output_maxlen, + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "Sec-WebSocket-Protocol: chat\r\n" + "\r\n", + response); + + ssize_t header_len = req.header_length; + ok = len < output_maxlen; +failure: + http_req_free(&req); + return ok ? header_len : -1; +}