Implement websocket

websockets
Andri Yngvason 2023-04-06 21:02:01 +00:00
parent e385a98238
commit 8847511596
12 changed files with 1210 additions and 13 deletions

View File

@ -97,9 +97,16 @@ struct nvnc_client {
LIST_HEAD(nvnc_client_list, nvnc_client);
enum nvnc__socket_type {
NVNC__SOCKET_TCP,
NVNC__SOCKET_UNIX,
NVNC__SOCKET_WEBSOCKET,
};
struct nvnc {
struct nvnc_common common;
int fd;
enum nvnc__socket_type socket_type;
struct aml_handler* poll_handle;
struct nvnc_client_list clients;
char name[256];

52
include/http.h 100644
View File

@ -0,0 +1,52 @@
/* Copyright (c) 2014-2016, Marel
* Copyright (c) 2023, Andri Yngvason
*
* Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
* REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
* AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
* INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
* LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
* PERFORMANCE OF THIS SOFTWARE.
*/
#pragma once
#define URL_INDEX_MAX 32
#define URL_QUERY_INDEX_MAX 32
#define HTTP_FIELD_INDEX_MAX 32
#include <stddef.h>
enum http_method {
HTTP_GET = 1,
HTTP_PUT = 2,
HTTP_OPTIONS = 4,
};
struct http_kv {
char* key;
char* value;
};
struct http_req {
enum http_method method;
size_t header_length;
size_t content_length;
char* content_type;
size_t url_index;
char* url[URL_INDEX_MAX];
size_t url_query_index;
struct http_kv url_query[URL_QUERY_INDEX_MAX];
size_t field_index;
struct http_kv field[HTTP_FIELD_INDEX_MAX];
};
int http_req_parse(struct http_req* req, const char* head);
void http_req_free(struct http_req* req);
const char* http_req_query(struct http_req* req, const char* key);

View File

@ -125,6 +125,7 @@ extern const char nvnc_version[];
struct nvnc* nvnc_open(const char* addr, uint16_t port);
struct nvnc* nvnc_open_unix(const char *addr);
struct nvnc* nvnc_open_websocket(const char* addr, uint16_t port);
void nvnc_close(struct nvnc* self);
void nvnc_add_display(struct nvnc*, struct nvnc_display*);

View File

@ -91,6 +91,10 @@ struct stream {
bool cork;
};
#ifdef ENABLE_WEBSOCKET
struct stream* stream_ws_new(int fd, stream_event_fn on_event, void* userdata);
#endif
struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata);
int stream_close(struct stream* self);
void stream_destroy(struct stream* self);

View File

@ -0,0 +1,37 @@
#pragma once
#include <stdint.h>
#include <stdbool.h>
#include <unistd.h>
#define WS_HEADER_MIN_SIZE 14
enum ws_opcode {
WS_OPCODE_CONT = 0,
WS_OPCODE_TEXT,
WS_OPCODE_BIN,
WS_OPCODE_CLOSE = 8,
WS_OPCODE_PING,
WS_OPCODE_PONG,
};
struct ws_frame_header {
bool fin;
enum ws_opcode opcode;
bool mask;
uint64_t payload_length;
uint8_t masking_key[4];
size_t header_length;
};
ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input);
const char *ws_opcode_name(enum ws_opcode op);
bool ws_parse_frame_header(struct ws_frame_header* header,
const uint8_t* payload, size_t length);
void ws_apply_mask(const struct ws_frame_header* header,
uint8_t* restrict payload);
void ws_copy_payload(const struct ws_frame_header* header,
uint8_t* restrict dst, const uint8_t* restrict src, size_t len);
int ws_write_frame_header(uint8_t* dst, const struct ws_frame_header* header);

View File

@ -49,6 +49,7 @@ libm = cc.find_library('m', required: false)
pixman = dependency('pixman-1')
libturbojpeg = dependency('libturbojpeg', required: get_option('jpeg'))
gnutls = dependency('gnutls', required: get_option('tls'))
nettle = dependency('nettle', required: get_option('nettle'))
zlib = dependency('zlib')
gbm = dependency('gbm', required: get_option('gbm'))
libdrm = dependency('libdrm', required: get_option('h264'))
@ -101,6 +102,8 @@ dependencies = [
libdrm_inc,
]
enable_websocket = false
config = configuration_data()
if libturbojpeg.found()
@ -114,6 +117,11 @@ if gnutls.found()
config.set('ENABLE_TLS', true)
endif
if nettle.found()
dependencies += nettle
enable_websocket = true
endif
if host_system == 'linux' and get_option('systemtap') and cc.has_header('sys/sdt.h')
config.set('HAVE_USDT', true)
endif
@ -130,6 +138,16 @@ if gbm.found() and libdrm.found() and libavcodec.found() and libavfilter.found()
config.set('HAVE_LIBAVUTIL', true)
endif
if enable_websocket
sources += [
'src/ws-handshake.c',
'src/ws-framing.c',
'src/http.c',
'src/stream-ws.c',
]
config.set('ENABLE_WEBSOCKET', true)
endif
configure_file(
output: 'config.h',
configuration: config,

View File

@ -3,6 +3,7 @@ option('examples', type: 'boolean', value: false, description: 'Build examples')
option('tests', type: 'boolean', value: false, description: 'Build unit tests')
option('jpeg', type: 'feature', value: 'auto', description: 'Enable JPEG compression')
option('tls', type: 'feature', value: 'auto', description: 'Enable encryption & authentication')
option('nettle', type: 'feature', value: 'auto', description: 'Enable nettle low level encryption library')
option('systemtap', type: 'boolean', value: false, description: 'Enable tracing using sdt')
option('gbm', type: 'feature', value: 'auto', description: 'Enable GBM integration')
option('h264', type: 'feature', value: 'auto', description: 'Enable open h264 encoding')

572
src/http.c 100644
View File

@ -0,0 +1,572 @@
/* Copyright (c) 2014-2016, Marel
* Copyright (c) 2023, Andri Yngvason
*
* Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
* REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
* AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
* INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
* LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
* PERFORMANCE OF THIS SOFTWARE.
*/
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include "vec.h"
#include "http.h"
enum httplex_token_type {
HTTPLEX_SOLIDUS,
HTTPLEX_CR,
HTTPLEX_LF,
HTTPLEX_WS,
HTTPLEX_LITERAL,
HTTPLEX_KEY,
HTTPLEX_VALUE,
HTTPLEX_QUERY,
HTTPLEX_AMPERSAND,
HTTPLEX_EQ,
HTTPLEX_END,
};
struct httplex_token {
enum httplex_token_type type;
const char* value;
};
enum httplex_state {
HTTPLEX_STATE_REQUEST = 0,
HTTPLEX_STATE_KEY,
HTTPLEX_STATE_VALUE,
};
struct httplex {
enum httplex_state state;
struct httplex_token current_token;
const char* input;
const char* pos;
const char* next_pos;
struct vec buffer;
int accepted;
int errno_;
};
static int httplex_init(struct httplex* self, const char* input)
{
memset(self, 0, sizeof(*self));
self->input = input;
self->pos = input;
self->accepted = 1;
if (vec_reserve(&self->buffer, 256) < 0)
return -1;
return 0;
}
static void httplex_destroy(struct httplex* self)
{
vec_destroy(&self->buffer);
}
static inline int httplex__is_literal(char c)
{
switch (c) {
case '/': case '\r': case '\n': case ' ': case '\t':
case '?': case '&': case '=':
return 0;
}
return isprint(c);
}
static inline size_t httplex__literal_length(const char* str)
{
size_t len = 0;
while (httplex__is_literal(*str++))
++len;
return len;
}
static int httplex__classify_request_token(struct httplex* self)
{
switch (*self->pos) {
case '/':
self->current_token.type = HTTPLEX_SOLIDUS;
self->next_pos = self->pos + strspn(self->pos, "/");
return 0;
case '\r':
self->current_token.type = HTTPLEX_CR;
self->next_pos = self->pos + 1;
return 0;
case '\n':
self->current_token.type = HTTPLEX_LF;
self->next_pos = self->pos + 1;
return 0;
case '?':
self->current_token.type = HTTPLEX_QUERY;
self->next_pos = self->pos + 1;
return 0;
case '&':
self->current_token.type = HTTPLEX_AMPERSAND;
self->next_pos = self->pos + 1;
return 0;
case '=':
self->current_token.type = HTTPLEX_EQ;
self->next_pos = self->pos + 1;
return 0;
case ' ':
case '\t':
self->current_token.type = HTTPLEX_WS;
self->next_pos = self->pos + strspn(self->pos, " \t");
return 0;
}
if (httplex__is_literal(*self->pos)) {
self->current_token.type = HTTPLEX_LITERAL;
size_t len = httplex__literal_length(self->pos);
self->next_pos = self->pos + len;
vec_assign(&self->buffer, self->pos, len);
vec_append(&self->buffer, "", 1);
self->current_token.value = self->buffer.data;
return 0;
}
return -1;
}
static inline int httplex__is_key_char(char c)
{
return isalnum(c) || c == '-';
}
static inline size_t httplex__key_length(const char* str)
{
size_t len = 0;
while (httplex__is_key_char(*str++))
++len;
return len;
}
static int httplex__classify_key_token(struct httplex* self)
{
switch (*self->pos) {
case '\r':
self->current_token.type = HTTPLEX_CR;
self->next_pos = self->pos + 1;
return 0;
case '\n':
self->current_token.type = HTTPLEX_LF;
self->next_pos = self->pos + 1;
return 0;
}
if (!httplex__is_key_char(*self->pos))
return -1;
size_t len = httplex__key_length(self->pos);
if (self->pos[len] != ':')
return -1;
len += 1;
self->next_pos = self->pos + len;
self->next_pos += strspn(self->next_pos, " \t");
vec_assign(&self->buffer, self->pos, len - 1);
vec_append(&self->buffer, "", 1);
self->current_token.type = HTTPLEX_KEY;
self->current_token.value = self->buffer.data;
return 0;
}
static int httplex__classify_value_token(struct httplex* self)
{
size_t len = strcspn(self->pos, "\r");
if (strncmp(&self->pos[len], "\r\n", 2) != 0)
return -1;
self->next_pos = self->pos + len + 2;
vec_assign(&self->buffer, self->pos, len);
vec_append(&self->buffer, "", 1);
self->current_token.type = HTTPLEX_VALUE;
self->current_token.value = self->buffer.data;
return 0;
}
static int httplex__classify_token(struct httplex* self)
{
switch (self->state) {
case HTTPLEX_STATE_REQUEST:
return httplex__classify_request_token(self);
case HTTPLEX_STATE_KEY:
return httplex__classify_key_token(self);
case HTTPLEX_STATE_VALUE:
return httplex__classify_value_token(self);
};
abort();
return -1;
}
static struct httplex_token* httplex_next_token(struct httplex* self)
{
if (self->current_token.type == HTTPLEX_END)
return &self->current_token;
if (!self->accepted)
return &self->current_token;
if (self->next_pos)
self->pos = self->next_pos;
if (httplex__classify_token(self) < 0)
return NULL;
self->accepted = 0;
return &self->current_token;
}
static inline int httplex_accept_token(struct httplex* self)
{
self->accepted = 1;
return 1;
}
static int http__literal(struct httplex* lex, const char* str)
{
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_LITERAL)
return 0;
if (strcasecmp(str, tok->value) != 0)
return 0;
return httplex_accept_token(lex);
}
static int http__get(struct http_req* req, struct httplex* lex)
{
if (!http__literal(lex, "GET"))
return 0;
req->method = HTTP_GET;
return 1;
}
static int http__put(struct http_req* req, struct httplex* lex)
{
if (!http__literal(lex, "PUT"))
return 0;
req->method = HTTP_PUT;
return 1;
}
static int http__options(struct http_req* req, struct httplex* lex)
{
if (!http__literal(lex, "OPTIONS"))
return 0;
req->method = HTTP_OPTIONS;
return 1;
}
static int http__method(struct http_req* req, struct httplex* lex)
{
return http__get(req, lex)
|| http__put(req, lex)
|| http__options(req, lex);
}
static int http__peek(struct httplex* lex, enum httplex_token_type type)
{
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != type)
return 0;
return 1;
}
static int http__expect(struct httplex* lex, enum httplex_token_type type)
{
return http__peek(lex, type) && httplex_accept_token(lex);
}
static int http__version(struct httplex* lex)
{
return http__literal(lex, "HTTP")
&& http__expect(lex, HTTPLEX_SOLIDUS)
&& http__literal(lex, "1.1");
}
static int http__url_path(struct http_req* req, struct httplex* lex)
{
if (!http__expect(lex, HTTPLEX_SOLIDUS))
return 0;
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_LITERAL)
return tok->type == HTTPLEX_WS;
if (req->url_index >= URL_INDEX_MAX)
return 0;
char* elem = strdup(tok->value);
if (!elem)
return 0;
req->url[req->url_index++] = elem;
httplex_accept_token(lex);
return http__peek(lex, HTTPLEX_SOLIDUS)
? http__url_path(req, lex) : 1;
}
static int http__url_query_key(struct http_req* req, struct httplex* lex)
{
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_LITERAL)
return 0;
if (req->url_index >= URL_INDEX_MAX)
return 0;
char* elem = strdup(tok->value);
if (!elem)
return 0;
req->url_query[req->url_query_index].key = elem;
return httplex_accept_token(lex);
}
static int http__url_query_value(struct http_req* req, struct httplex* lex)
{
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_LITERAL)
return 0;
if (req->url_index >= URL_INDEX_MAX)
return 0;
char* elem = strdup(tok->value);
if (!elem)
return 0;
req->url_query[req->url_query_index++].value = elem;
return httplex_accept_token(lex);
}
static int http__url_query(struct http_req* req, struct httplex* lex)
{
return http__url_query_key(req, lex)
&& http__expect(lex, HTTPLEX_EQ)
&& http__url_query_value(req, lex)
&& http__expect(lex, HTTPLEX_AMPERSAND)
? http__url_query(req, lex) : 1;
}
static int http__url(struct http_req* req, struct httplex* lex)
{
return http__url_path(req, lex)
&& http__expect(lex, HTTPLEX_QUERY) ? http__url_query(req, lex) : 1;
}
static int http__request(struct http_req* req, struct httplex* lex)
{
return http__method(req, lex)
&& http__expect(lex, HTTPLEX_WS)
&& http__url(req, lex)
&& http__expect(lex, HTTPLEX_WS)
&& http__version(lex)
&& http__expect(lex, HTTPLEX_CR)
&& http__expect(lex, HTTPLEX_LF);
}
static int http__expect_key(struct httplex* lex, const char* key)
{
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_KEY)
return 0;
if (key && strcasecmp(tok->value, key) != 0)
return 0;
return httplex_accept_token(lex);
}
static int http__content_length(struct http_req* req, struct httplex* lex)
{
lex->state = HTTPLEX_STATE_KEY;
if (!http__expect_key(lex, "Content-Length"))
return 0;
lex->state = HTTPLEX_STATE_VALUE;
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_VALUE)
return 0;
req->content_length = atoi(tok->value);
return httplex_accept_token(lex);
}
static int http__content_type(struct http_req* req, struct httplex* lex)
{
lex->state = HTTPLEX_STATE_KEY;
if (!http__expect_key(lex, "Content-Type"))
return 0;
lex->state = HTTPLEX_STATE_VALUE;
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_VALUE)
return 0;
req->content_type = strdup(tok->value);
return httplex_accept_token(lex);
}
static int http__field_key(struct http_req* req, struct httplex* lex)
{
lex->state = HTTPLEX_STATE_KEY;
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_KEY)
return 0;
req->field[req->field_index].key = strdup(tok->value);
return httplex_accept_token(lex);
}
static int http__field_value(struct http_req* req, struct httplex* lex)
{
lex->state = HTTPLEX_STATE_VALUE;
struct httplex_token* tok = httplex_next_token(lex);
if (!tok)
return 0;
if (tok->type != HTTPLEX_VALUE)
return 0;
req->field[req->field_index++].value = strdup(tok->value);
return httplex_accept_token(lex);
}
static int http__field_kv(struct http_req* req, struct httplex* lex)
{
return http__field_key(req, lex)
&& http__field_value(req, lex);
}
static int http__header_kv(struct http_req* req, struct httplex* lex)
{
return http__content_length(req, lex)
|| http__content_type(req, lex)
|| http__field_kv(req, lex);
}
static int http__header(struct http_req* req, struct httplex* lex)
{
while (http__header_kv(req, lex));
lex->state = HTTPLEX_STATE_KEY;
if (http__expect(lex, HTTPLEX_CR))
return http__expect(lex, HTTPLEX_LF);
return 1;
}
int http_req_parse(struct http_req* req, const char* input)
{
int rc = -1;
memset(req, 0, sizeof(*req));
struct httplex lex;
if (httplex_init(&lex, input) < 0)
return -1;
if (!http__request(req, &lex))
goto failure;
if (!http__header(req, &lex))
goto failure;
req->header_length = lex.next_pos - input;
rc = 0;
failure:
httplex_destroy(&lex);
return rc;
}
void http_req_free(struct http_req* req)
{
free(req->content_type);
for (size_t i = 0; i < req->url_index; ++i)
free(req->url[i]);
for (size_t i = 0; i < req->url_query_index; ++i) {
free(req->url_query[i].key);
free(req->url_query[i].value);
}
for (size_t i = 0; i < req->field_index; ++i) {
free(req->field[i].key);
free(req->field[i].value);
}
}
const char* http_req_query(struct http_req* req, const char* key)
{
for (size_t i = 0; i < req->url_query_index; ++i)
if (strcmp(key, req->url_query[i].key) == 0)
return req->url_query[i].value;
return NULL;
}

View File

@ -65,11 +65,6 @@
#define EXPORT __attribute__((visibility("default")))
enum addrtype {
ADDRTYPE_TCP,
ADDRTYPE_UNIX,
};
static int send_desktop_resize(struct nvnc_client* client, struct nvnc_fb* fb);
static int send_qemu_key_ext_frame(struct nvnc_client* client);
static enum rfb_encodings choose_frame_encoding(struct nvnc_client* client,
@ -1195,7 +1190,16 @@ static void on_connection(void* obj)
record_peer_hostname(fd, client);
client->net_stream = stream_new(fd, on_client_event, client);
#ifdef ENABLE_WEBSOCKET
if (server->socket_type == NVNC__SOCKET_WEBSOCKET)
{
client->net_stream = stream_ws_new(fd, on_client_event, client);
}
else
#endif
{
client->net_stream = stream_new(fd, on_client_event, client);
}
if (!client->net_stream) {
nvnc_log(NVNC_LOG_WARNING, "OOM");
goto stream_failure;
@ -1326,12 +1330,14 @@ static int bind_address_unix(const char* name)
return fd;
}
static int bind_address(const char* name, uint16_t port, enum addrtype type)
static int bind_address(const char* name, uint16_t port,
enum nvnc__socket_type type)
{
switch (type) {
case ADDRTYPE_TCP:
case NVNC__SOCKET_TCP:
case NVNC__SOCKET_WEBSOCKET:
return bind_address_tcp(name, port);
case ADDRTYPE_UNIX:
case NVNC__SOCKET_UNIX:
return bind_address_unix(name);
}
@ -1339,7 +1345,8 @@ static int bind_address(const char* name, uint16_t port, enum addrtype type)
return -1;
}
static struct nvnc* open_common(const char* address, uint16_t port, enum addrtype type)
static struct nvnc* open_common(const char* address, uint16_t port,
enum nvnc__socket_type type)
{
nvnc__log_init();
@ -1349,6 +1356,8 @@ static struct nvnc* open_common(const char* address, uint16_t port, enum addrtyp
if (!self)
return NULL;
self->socket_type = type;
strcpy(self->name, DEFAULT_NAME);
LIST_INIT(&self->clients);
@ -1374,7 +1383,7 @@ poll_start_failure:
handle_failure:
listen_failure:
close(self->fd);
if (type == ADDRTYPE_UNIX) {
if (type == NVNC__SOCKET_UNIX) {
unlink(address);
}
bind_failure:
@ -1386,13 +1395,23 @@ bind_failure:
EXPORT
struct nvnc* nvnc_open(const char* address, uint16_t port)
{
return open_common(address, port, ADDRTYPE_TCP);
return open_common(address, port, NVNC__SOCKET_TCP);
}
EXPORT
struct nvnc* nvnc_open_websocket(const char *address, uint16_t port)
{
#ifdef ENABLE_WEBSOCKET
return open_common(address, port, NVNC__SOCKET_WEBSOCKET);
#else
return NULL;
#endif
}
EXPORT
struct nvnc* nvnc_open_unix(const char* address)
{
return open_common(address, 0, ADDRTYPE_UNIX);
return open_common(address, 0, NVNC__SOCKET_UNIX);
}
static void unlink_fd_path(int fd)

292
src/stream-ws.c 100644
View File

@ -0,0 +1,292 @@
/*
* Copyright (c) 2023 Andri Yngvason
*
* Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
* REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
* AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
* INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
* LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
* PERFORMANCE OF THIS SOFTWARE.
*/
#include "stream.h"
#include "stream-common.h"
#include "websocket.h"
#include "neatvnc.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <sys/param.h>
enum stream_ws_state {
STREAM_WS_STATE_HANDSHAKE = 0,
STREAM_WS_STATE_READY,
};
struct stream_ws {
struct stream base;
enum stream_ws_state ws_state;
struct ws_frame_header header;
enum ws_opcode current_opcode;
uint8_t read_buffer[4096]; // TODO: Is this a reasonable size?
size_t read_index;
struct stream* tcp_stream;
};
static int stream_ws_close(struct stream* self)
{
struct stream_ws* ws = (struct stream_ws*)self;
self->state = STREAM_STATE_CLOSED;
return stream_close(ws->tcp_stream);
}
static void stream_ws_destroy(struct stream* self)
{
struct stream_ws* ws = (struct stream_ws*)self;
stream_destroy(ws->tcp_stream);
free(self);
}
static void stream_ws_read_into_buffer(struct stream_ws* ws)
{
ssize_t n_read = stream_read(ws->tcp_stream,
ws->read_buffer + ws->read_index,
sizeof(ws->read_buffer) - ws->read_index);
if (n_read > 0)
ws->read_index += n_read;
}
static void stream_ws_advance_read_buffer(struct stream_ws* ws, size_t size,
size_t offset)
{
size_t payload_len = MIN(size, ws->read_index - offset);
payload_len = MIN(payload_len, ws->header.payload_length);
ws->read_index -= offset + payload_len;
memmove(ws->read_buffer, ws->read_buffer + offset + payload_len,
ws->read_index);
ws->header.payload_length -= payload_len;
}
static ssize_t stream_ws_copy_payload(struct stream_ws* ws, void* dst,
size_t size, size_t offset)
{
size_t payload_len = MIN(size, ws->read_index - offset);
payload_len = MIN(payload_len, ws->header.payload_length);
ws_copy_payload(&ws->header, dst, ws->read_buffer + offset, payload_len);
stream_ws_advance_read_buffer(ws, size, offset);
return payload_len;
}
static ssize_t stream_ws_process_ping(struct stream_ws* ws, size_t offset)
{
if (offset > 0) {
// This means we're at the start, so send a header
struct ws_frame_header reply = {
.fin = true,
.opcode = WS_OPCODE_PONG,
.payload_length = ws->header.payload_length,
};
uint8_t buf[WS_HEADER_MIN_SIZE];
int reply_len = ws_write_frame_header(buf, &reply);
stream_write(ws->tcp_stream, buf, reply_len, NULL, NULL);
}
int payload_len = MIN(ws->read_index, ws->header.payload_length);
// Feed back the payload:
stream_write(ws->tcp_stream, ws->read_buffer + offset,
payload_len, NULL, NULL);
stream_ws_advance_read_buffer(ws, payload_len, offset);
return 0;
}
static ssize_t stream_ws_process_payload(struct stream_ws* ws, void* dst,
size_t size, size_t offset)
{
switch (ws->current_opcode) {
case WS_OPCODE_CONT:
// Remote end started with a continuation frame. This is
// unexpected, so we'll just close.
stream__remote_closed(ws->tcp_stream);
return 0;
case WS_OPCODE_TEXT:
// This is unexpected, but let's just ignore it...
stream_ws_advance_read_buffer(ws, SIZE_MAX, offset);
return 0;
case WS_OPCODE_BIN:
return stream_ws_copy_payload(ws, dst, size, offset);
case WS_OPCODE_CLOSE:
stream__remote_closed(ws->tcp_stream);
return 0;
case WS_OPCODE_PING:
return stream_ws_process_ping(ws, offset);
case WS_OPCODE_PONG:
// Don't care
stream_ws_advance_read_buffer(ws, SIZE_MAX, offset);
return 0;
}
return -1;
}
/* We don't really care about framing. The binary data is just passed on as it
* arrives and it's not gathered into individual frames.
*/
static ssize_t stream_ws_read_frame(struct stream_ws* ws, void* dst,
size_t size)
{
if (ws->header.payload_length > 0) {
nvnc_trace("Processing left-over payload chunk");
return stream_ws_process_payload(ws, dst, size, 0);
}
if (!ws_parse_frame_header(&ws->header, ws->read_buffer,
ws->read_index)) {
return 0;
}
nvnc_trace("Got frame header: opcode=%s, header-len: %zu, payload-len: %zu, read-buffer-len: %zu",
ws_opcode_name(ws->header.opcode),
ws->header.header_length, ws->header.payload_length,
ws->read_index);
if (ws->header.opcode != WS_OPCODE_CONT) {
ws->current_opcode = ws->header.opcode;
}
// The header is located at the start of the buffer, so an offset is
// needed.
return stream_ws_process_payload(ws, dst, size,
ws->header.header_length);
}
static ssize_t stream_ws_read_ready(struct stream_ws* ws, void* dst,
size_t size)
{
size_t total_read = 0;
while (true) {
ssize_t n_read = stream_ws_read_frame(ws, dst, size);
if (n_read == 0)
break;
if (n_read < 0) {
if (errno == EAGAIN) {
break;
}
return -1;
}
total_read += n_read;
dst += n_read;
size -= n_read;
}
return total_read;
}
static ssize_t stream_ws_read_handshake(struct stream_ws* ws, void* dst,
size_t size)
{
char reply[512];
ssize_t header_len = ws_handshake(reply, sizeof(reply),
(const char*)ws->read_buffer);
if (header_len < 0)
return 0;
ws->tcp_stream->cork = false;
stream_send_first(ws->tcp_stream, rcbuf_from_mem(reply, strlen(reply)));
ws->read_index -= header_len;
memmove(ws->read_buffer, ws->read_buffer + header_len, ws->read_index);
ws->ws_state = STREAM_WS_STATE_READY;
return stream_ws_read_ready(ws, dst, size);
}
static ssize_t stream_ws_read(struct stream* self, void* dst, size_t size)
{
struct stream_ws* ws = (struct stream_ws*)self;
stream_ws_read_into_buffer(ws);
if (self->state == STREAM_STATE_CLOSED)
return 0;
switch (ws->ws_state) {
case STREAM_WS_STATE_HANDSHAKE:
return stream_ws_read_handshake(ws, dst, size);
case STREAM_WS_STATE_READY:
return stream_ws_read_ready(ws, dst, size);
}
abort();
return -1;
}
static int stream_ws_send(struct stream* self, struct rcbuf* payload,
stream_req_fn on_done, void* userdata)
{
struct stream_ws* ws = (struct stream_ws*)self;
struct ws_frame_header head = {
.fin = true,
.opcode = WS_OPCODE_BIN,
.payload_length = payload->size,
};
uint8_t raw_head[WS_HEADER_MIN_SIZE];
int head_len = ws_write_frame_header(raw_head, &head);
stream_write(ws->tcp_stream, &raw_head, head_len, NULL, NULL);
return stream_send(ws->tcp_stream, payload, on_done, userdata);
}
static void stream_ws_event(struct stream* self, enum stream_event event)
{
struct stream_ws* ws = self->userdata;
if (event == STREAM_EVENT_REMOTE_CLOSED) {
ws->base.state = STREAM_STATE_CLOSED;
}
ws->base.on_event(&ws->base, event);
}
static struct stream_impl impl = {
.close = stream_ws_close,
.destroy = stream_ws_destroy,
.read = stream_ws_read,
.send = stream_ws_send,
};
struct stream* stream_ws_new(int fd, stream_event_fn on_event, void* userdata)
{
struct stream_ws *self = calloc(1, sizeof(*self));
if (!self)
return NULL;
self->base.state = STREAM_STATE_NORMAL;
self->base.impl = &impl;
self->base.on_event = on_event;
self->base.userdata = userdata;
self->tcp_stream = stream_new(fd, stream_ws_event, self);
if (!self->tcp_stream) {
free(self);
return NULL;
}
// Don't send anything until handshake is done:
self->tcp_stream->cork = true;
return &self->base;
}

137
src/ws-framing.c 100644
View File

@ -0,0 +1,137 @@
#include "websocket.h"
#include <stdint.h>
#include <stdbool.h>
#include <unistd.h>
#include <string.h>
#include <assert.h>
#include <arpa/inet.h>
static inline uint64_t u64_from_network_order(uint64_t x)
{
#if __BYTE_ORDER__ == __BIG_ENDIAN__
return x;
#else
return __builtin_bswap64(x);
#endif
}
static inline uint64_t u64_to_network_order(uint64_t x)
{
#if __BYTE_ORDER__ == __BIG_ENDIAN__
return x;
#else
return __builtin_bswap64(x);
#endif
}
const char *ws_opcode_name(enum ws_opcode op)
{
switch (op) {
case WS_OPCODE_CONT: return "cont";
case WS_OPCODE_TEXT: return "text";
case WS_OPCODE_BIN: return "bin";
case WS_OPCODE_CLOSE: return "close";
case WS_OPCODE_PING: return "ping";
case WS_OPCODE_PONG: return "pong";
}
return "INVALID";
}
bool ws_parse_frame_header(struct ws_frame_header* header,
const uint8_t* payload, size_t length)
{
if (length < 2)
return false;
int i = 0;
header->fin = !!(payload[i] & 0x80);
header->opcode = (payload[i++] & 0x0f);
header->mask = !!(payload[i] & 0x80);
header->payload_length = payload[i++] & 0x7f;
if (header->payload_length == 126) {
if (length - i < 2)
return false;
uint16_t value = 0;
memcpy(&value, &payload[i], 2);
header->payload_length = ntohs(value);
i += 2;
} else if (header->payload_length == 127) {
if (length - i < 8)
return false;
uint64_t value = 0;
memcpy(&value, &payload[i], 8);
header->payload_length = u64_from_network_order(value);
i += 8;
}
if (header->mask) {
if (length - i < 4)
return false;
memcpy(header->masking_key, &payload[i], 4);
i += 4;
}
header->header_length = i;
return true;
}
void ws_apply_mask(const struct ws_frame_header* header,
uint8_t* restrict payload)
{
assert(header->mask);
uint64_t len = header->payload_length;
const uint8_t* restrict key = header->masking_key;
for (uint64_t i = 0; i < len; ++i) {
payload[i] ^= key[i % 4];
}
}
void ws_copy_payload(const struct ws_frame_header* header,
uint8_t* restrict dst, const uint8_t* restrict src, size_t len)
{
if (!header->mask) {
memcpy(dst, src, len);
return;
}
const uint8_t* restrict key = header->masking_key;
for (uint64_t i = 0; i < len; ++i) {
dst[i] = src[i] ^ key[i % 4];
}
}
int ws_write_frame_header(uint8_t* dst, const struct ws_frame_header* header)
{
int i = 0;
dst[i++] = ((uint8_t)header->fin << 7) | (header->opcode);
if (header->payload_length <= 125) {
dst[i++] = ((uint8_t)header->mask << 7) | header->payload_length;
} else if (header->payload_length <= UINT16_MAX) {
dst[i++] = ((uint8_t)header->mask << 7) | 126;
uint16_t be = htons(header->payload_length);
memcpy(&dst[i], &be, 2);
i += 2;
} else {
dst[i++] = ((uint8_t)header->mask << 7) | 127;
uint64_t be = u64_to_network_order(header->payload_length);
memcpy(&dst[i], &be, 8);
i += 8;
}
if (header->mask) {
memcpy(dst, header->masking_key, 4);
i += 4;
}
return i;
}

57
src/ws-handshake.c 100644
View File

@ -0,0 +1,57 @@
#include "websocket.h"
#include "http.h"
#include <nettle/sha1.h>
#include <nettle/base64.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
static const char magic_uuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// TODO: Do some more sanity checks on the input
ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input)
{
bool ok = false;
struct http_req req = {};
if (http_req_parse(&req, input) < 0)
return -1;
const char *challenge = NULL;
for (size_t i = 0; i < req.field_index; ++i) {
if (strcasecmp(req.field[i].key, "Sec-WebSocket-Key") == 0) {
challenge = req.field[i].value;
}
}
if (!challenge)
goto failure;
struct sha1_ctx ctx;
sha1_init(&ctx);
sha1_update(&ctx, strlen(challenge), (const uint8_t*)challenge);
sha1_update(&ctx, strlen(magic_uuid), (const uint8_t*)magic_uuid);
uint8_t hash[SHA1_DIGEST_SIZE];
sha1_digest(&ctx, sizeof(hash), hash);
char response[BASE64_ENCODE_RAW_LENGTH(SHA1_DIGEST_SIZE) + 1] = {};
base64_encode_raw(response, SHA1_DIGEST_SIZE, hash);
size_t len = snprintf(output, output_maxlen,
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"Sec-WebSocket-Protocol: chat\r\n"
"\r\n",
response);
ssize_t header_len = req.header_length;
ok = len < output_maxlen;
failure:
http_req_free(&req);
return ok ? header_len : -1;
}