diff --git a/include/stream-tcp.h b/include/stream-tcp.h new file mode 100644 index 0000000..dc27d82 --- /dev/null +++ b/include/stream-tcp.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include "stream.h" + +#include + +struct stream; + +int stream_tcp_init(struct stream* self, int fd, stream_event_fn on_event, + void* userdata); +int stream_tcp_close(struct stream* self); +void stream_tcp_destroy(struct stream* self); +ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size); +int stream_tcp_send(struct stream* self, struct rcbuf* payload, + stream_req_fn on_done, void* userdata); +int stream_tcp_send_first(struct stream* self, struct rcbuf* payload); +void stream_tcp_exec_and_send(struct stream* self, + stream_exec_fn exec_fn, void* userdata); +int stream_tcp_install_cipher(struct stream* self, + struct crypto_cipher* cipher); diff --git a/src/stream-tcp.c b/src/stream-tcp.c index 9b06703..0441fab 100644 --- a/src/stream-tcp.c +++ b/src/stream-tcp.c @@ -30,6 +30,7 @@ #include "rcbuf.h" #include "stream.h" #include "stream-common.h" +#include "stream-tcp.h" #include "sys/queue.h" #include "crypto.h" #include "neatvnc.h" @@ -47,7 +48,7 @@ static struct rcbuf* encrypt_rcbuf(struct stream* self, struct rcbuf* payload) return result; } -static int stream_tcp_close(struct stream* self) +int stream_tcp_close(struct stream* self) { if (self->state == STREAM_STATE_CLOSED) return -1; @@ -67,7 +68,7 @@ static int stream_tcp_close(struct stream* self) return 0; } -static void stream_tcp_destroy(struct stream* self) +void stream_tcp_destroy(struct stream* self) { vec_destroy(&self->tmp_buf); crypto_cipher_del(self->cipher); @@ -198,7 +199,7 @@ static void stream_tcp__on_event(void* obj) stream_tcp__on_writable(self); } -static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) +ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) { if (self->state != STREAM_STATE_NORMAL) return -1; @@ -232,7 +233,7 @@ static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) return rc; } -static int stream_tcp_send(struct stream* self, struct rcbuf* payload, +int stream_tcp_send(struct stream* self, struct rcbuf* payload, stream_req_fn on_done, void* userdata) { if (self->state == STREAM_STATE_CLOSED) @@ -251,7 +252,7 @@ static int stream_tcp_send(struct stream* self, struct rcbuf* payload, return stream_tcp__flush(self); } -static int stream_tcp_send_first(struct stream* self, struct rcbuf* payload) +int stream_tcp_send_first(struct stream* self, struct rcbuf* payload) { if (self->state == STREAM_STATE_CLOSED) return -1; @@ -266,7 +267,7 @@ static int stream_tcp_send_first(struct stream* self, struct rcbuf* payload) return stream_tcp__flush(self); } -static void stream_tcp_exec_and_send(struct stream* self, +void stream_tcp_exec_and_send(struct stream* self, stream_exec_fn exec_fn, void* userdata) { if (self->state == STREAM_STATE_CLOSED) @@ -284,7 +285,7 @@ static void stream_tcp_exec_and_send(struct stream* self, stream_tcp__flush(self); } -static int stream_tcp_install_cipher(struct stream* self, +int stream_tcp_install_cipher(struct stream* self, struct crypto_cipher* cipher) { assert(!self->cipher); @@ -302,12 +303,9 @@ static struct stream_impl impl = { .install_cipher = stream_tcp_install_cipher, }; -struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) +int stream_tcp_init(struct stream* self, int fd, stream_event_fn on_event, + void* userdata) { - struct stream* self = calloc(1, STREAM_ALLOC_SIZE); - if (!self) - return NULL; - self->impl = &impl, self->fd = fd; self->on_event = on_event; @@ -319,19 +317,31 @@ struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) self->handler = aml_handler_new(fd, stream_tcp__on_event, self, NULL); if (!self->handler) - goto failure; + return -1; if (aml_start(aml_get_default(), self->handler) < 0) goto start_failure; stream__poll_r(self); - return self; + return 0; start_failure: aml_unref(self->handler); - self = NULL; /* Handled in unref */ -failure: - free(self); - return NULL; + return -1; +} + +struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) +{ + struct stream* self = calloc(1, STREAM_ALLOC_SIZE); + if (!self) + return NULL; + + if (stream_tcp_init(self, fd, on_event, userdata) < 0) { + free(self); + return NULL; + } + + return self; + } diff --git a/src/stream-ws.c b/src/stream-ws.c index e7038af..df7d11a 100644 --- a/src/stream-ws.c +++ b/src/stream-ws.c @@ -16,6 +16,7 @@ #include "stream.h" #include "stream-common.h" +#include "stream-tcp.h" #include "websocket.h" #include "vec.h" #include "neatvnc.h" @@ -38,31 +39,17 @@ struct stream_ws_exec_ctx { struct stream_ws { struct stream base; + stream_event_fn on_event; enum stream_ws_state ws_state; struct ws_frame_header header; enum ws_opcode current_opcode; - uint8_t read_buffer[4096]; // TODO: Is this a reasonable size? size_t read_index; - struct stream* tcp_stream; + uint8_t read_buffer[4096]; // TODO: Is this a reasonable size? }; -static int stream_ws_close(struct stream* self) -{ - struct stream_ws* ws = (struct stream_ws*)self; - self->state = STREAM_STATE_CLOSED; - return stream_close(ws->tcp_stream); -} - -static void stream_ws_destroy(struct stream* self) -{ - struct stream_ws* ws = (struct stream_ws*)self; - stream_destroy(ws->tcp_stream); - free(self); -} - static void stream_ws_read_into_buffer(struct stream_ws* ws) { - ssize_t n_read = stream_read(ws->tcp_stream, + ssize_t n_read = stream_tcp_read(&ws->base, ws->read_buffer + ws->read_index, sizeof(ws->read_buffer) - ws->read_index); if (n_read > 0) @@ -104,14 +91,15 @@ static ssize_t stream_ws_process_ping(struct stream_ws* ws, size_t offset) uint8_t buf[WS_HEADER_MIN_SIZE]; int reply_len = ws_write_frame_header(buf, &reply); - stream_write(ws->tcp_stream, buf, reply_len, NULL, NULL); + stream_tcp_send(&ws->base, rcbuf_from_mem(buf, reply_len), + NULL, NULL); } int payload_len = MIN(ws->read_index, ws->header.payload_length); // Feed back the payload: - stream_write(ws->tcp_stream, ws->read_buffer + offset, - payload_len, NULL, NULL); + stream_tcp_send(&ws->base, rcbuf_from_mem(ws->read_buffer + offset, + payload_len), NULL, NULL); stream_ws_advance_read_buffer(ws, payload_len, offset); return 0; @@ -124,7 +112,7 @@ static ssize_t stream_ws_process_payload(struct stream_ws* ws, void* dst, case WS_OPCODE_CONT: // Remote end started with a continuation frame. This is // unexpected, so we'll just close. - stream__remote_closed(ws->tcp_stream); + stream__remote_closed(&ws->base); return 0; case WS_OPCODE_TEXT: // This is unexpected, but let's just ignore it... @@ -133,7 +121,7 @@ static ssize_t stream_ws_process_payload(struct stream_ws* ws, void* dst, case WS_OPCODE_BIN: return stream_ws_copy_payload(ws, dst, size, offset); case WS_OPCODE_CLOSE: - stream__remote_closed(ws->tcp_stream); + stream__remote_closed(&ws->base); return 0; case WS_OPCODE_PING: return stream_ws_process_ping(ws, offset); @@ -213,8 +201,8 @@ static ssize_t stream_ws_read_handshake(struct stream_ws* ws, void* dst, if (header_len < 0) return 0; - ws->tcp_stream->cork = false; - stream_send_first(ws->tcp_stream, rcbuf_from_mem(reply, strlen(reply))); + ws->base.cork = false; + stream_tcp_send_first(&ws->base, rcbuf_from_mem(reply, strlen(reply))); ws->read_index -= header_len; memmove(ws->read_buffer, ws->read_buffer + header_len, ws->read_index); @@ -255,15 +243,16 @@ static int stream_ws_send(struct stream* self, struct rcbuf* payload, uint8_t raw_head[WS_HEADER_MIN_SIZE]; int head_len = ws_write_frame_header(raw_head, &head); - stream_write(ws->tcp_stream, &raw_head, head_len, NULL, NULL); - return stream_send(ws->tcp_stream, payload, on_done, userdata); + stream_tcp_send(&ws->base, rcbuf_from_mem(&raw_head, head_len), + NULL, NULL); + return stream_tcp_send(&ws->base, payload, on_done, userdata); } static struct rcbuf* stream_ws_chained_exec(struct stream* tcp_stream, void* userdata) { + struct stream_ws* ws = (struct stream_ws*)tcp_stream; struct stream_ws_exec_ctx* ctx = userdata; - struct stream_ws* ws = tcp_stream->userdata; struct rcbuf* buf = ctx->exec(&ws->base, ctx->userdata); @@ -293,23 +282,12 @@ static void stream_ws_exec_and_send(struct stream* self, stream_exec_fn exec, ctx->exec = exec; ctx->userdata = userdata; - stream_exec_and_send(ws->tcp_stream, stream_ws_chained_exec, ctx); -} - -static void stream_ws_event(struct stream* self, enum stream_event event) -{ - struct stream_ws* ws = self->userdata; - - if (event == STREAM_EVENT_REMOTE_CLOSED) { - ws->base.state = STREAM_STATE_CLOSED; - } - - ws->base.on_event(&ws->base, event); + stream_tcp_exec_and_send(&ws->base, stream_ws_chained_exec, ctx); } static struct stream_impl impl = { - .close = stream_ws_close, - .destroy = stream_ws_destroy, + .close = stream_tcp_close, + .destroy = stream_tcp_destroy, .read = stream_ws_read, .send = stream_ws_send, .exec_and_send = stream_ws_exec_and_send, @@ -321,19 +299,11 @@ struct stream* stream_ws_new(int fd, stream_event_fn on_event, void* userdata) if (!self) return NULL; - self->base.state = STREAM_STATE_NORMAL; - self->base.impl = &impl; - self->base.on_event = on_event; - self->base.userdata = userdata; - - self->tcp_stream = stream_new(fd, stream_ws_event, self); - if (!self->tcp_stream) { - free(self); - return NULL; - } + stream_tcp_init(&self->base, fd, on_event, userdata); + self->base.impl = &impl; // Don't send anything until handshake is done: - self->tcp_stream->cork = true; + self->base.cork = true; return &self->base; }