From 979d10ce62c78ba015143831f613b9efcf97a100 Mon Sep 17 00:00:00 2001 From: Andri Yngvason Date: Thu, 6 Apr 2023 11:50:44 +0000 Subject: [PATCH] Turn stream into abstract interface class --- include/stream-common.h | 39 ++++ include/stream.h | 19 +- meson.build | 3 + src/stream-common.c | 37 ++++ src/stream-gnutls.c | 271 ++++++++++++++++++++++++++ src/stream-tcp.c | 237 +++++++++++++++++++++++ src/stream.c | 419 +--------------------------------------- 7 files changed, 610 insertions(+), 415 deletions(-) create mode 100644 include/stream-common.h create mode 100644 src/stream-common.c create mode 100644 src/stream-gnutls.c create mode 100644 src/stream-tcp.c diff --git a/include/stream-common.h b/include/stream-common.h new file mode 100644 index 0000000..2c92552 --- /dev/null +++ b/include/stream-common.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020 - 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include "stream.h" + +#include + +static inline void stream__poll_r(struct stream* self) +{ + aml_set_event_mask(self->handler, AML_EVENT_READ); +} + +static inline void stream__poll_w(struct stream* self) +{ + aml_set_event_mask(self->handler, AML_EVENT_WRITE); +} + +static inline void stream__poll_rw(struct stream* self) +{ + aml_set_event_mask(self->handler, AML_EVENT_READ | AML_EVENT_WRITE); +} + +void stream_req__finish(struct stream_req* req, enum stream_req_status status); +void stream__remote_closed(struct stream* self); diff --git a/include/stream.h b/include/stream.h index 734391b..f166745 100644 --- a/include/stream.h +++ b/include/stream.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Andri Yngvason + * Copyright (c) 2020 - 2023 Andri Yngvason * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -14,6 +14,8 @@ * PERFORMANCE OF THIS SOFTWARE. */ +#pragma once + #include "config.h" #include "sys/queue.h" #include "rcbuf.h" @@ -33,11 +35,6 @@ enum stream_state { #endif }; -enum stream_status { - STREAM_READY = 0, - STREAM_CLOSED, -}; - enum stream_req_status { STREAM_REQ_DONE = 0, STREAM_REQ_FAILED, @@ -62,7 +59,17 @@ struct stream_req { TAILQ_HEAD(stream_send_queue, stream_req); +struct stream_impl { + int (*close)(struct stream*); + void (*destroy)(struct stream*); + ssize_t (*read)(struct stream*, void* dst, size_t size); + int (*send)(struct stream*, struct rcbuf* payload, + stream_req_fn on_done, void* userdata); +}; + struct stream { + struct stream_impl *impl; + enum stream_state state; int fd; diff --git a/meson.build b/meson.build index 270152c..1519f1f 100644 --- a/meson.build +++ b/meson.build @@ -77,6 +77,8 @@ sources = [ 'src/fb_pool.c', 'src/rcbuf.c', 'src/stream.c', + 'src/stream-common.c', + 'src/stream-tcp.c', 'src/desktop-layout.c', 'src/display.c', 'src/tight.c', @@ -107,6 +109,7 @@ if libturbojpeg.found() endif if gnutls.found() + sources += 'src/stream-gnutls.c' dependencies += gnutls config.set('ENABLE_TLS', true) endif diff --git a/src/stream-common.c b/src/stream-common.c new file mode 100644 index 0000000..63f17c0 --- /dev/null +++ b/src/stream-common.c @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 - 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include "stream.h" +#include "stream-common.h" + +#include + +void stream_req__finish(struct stream_req* req, enum stream_req_status status) +{ + if (req->on_done) + req->on_done(req->userdata, status); + + rcbuf_unref(req->payload); + free(req); +} + +void stream__remote_closed(struct stream* self) +{ + stream_close(self); + + if (self->on_event) + self->on_event(self, STREAM_EVENT_REMOTE_CLOSED); +} diff --git a/src/stream-gnutls.c b/src/stream-gnutls.c new file mode 100644 index 0000000..27cd24e --- /dev/null +++ b/src/stream-gnutls.c @@ -0,0 +1,271 @@ +/* + * Copyright (c) 2020 - 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "rcbuf.h" +#include "stream.h" +#include "stream-common.h" +#include "sys/queue.h" + +static int stream__try_tls_accept(struct stream* self); + +static int stream_gnutls_close(struct stream* self) +{ + if (self->state == STREAM_STATE_CLOSED) + return -1; + + self->state = STREAM_STATE_CLOSED; + + while (!TAILQ_EMPTY(&self->send_queue)) { + struct stream_req* req = TAILQ_FIRST(&self->send_queue); + TAILQ_REMOVE(&self->send_queue, req, link); + stream_req__finish(req, STREAM_REQ_FAILED); + } + + if (self->tls_session) + gnutls_deinit(self->tls_session); + self->tls_session = NULL; + + aml_stop(aml_get_default(), self->handler); + close(self->fd); + self->fd = -1; + + return 0; +} + +static void stream_gnutls_destroy(struct stream* self) +{ + stream_close(self); + aml_unref(self->handler); +} + +static int stream_gnutls__flush(struct stream* self) +{ + while (!TAILQ_EMPTY(&self->send_queue)) { + struct stream_req* req = TAILQ_FIRST(&self->send_queue); + + ssize_t rc = gnutls_record_send( + self->tls_session, req->payload->payload, + req->payload->size); + if (rc < 0) { + gnutls_record_discard_queued(self->tls_session); + if (gnutls_error_is_fatal(rc)) + stream_close(self); + return -1; + } + + self->bytes_sent += rc; + + ssize_t remaining = req->payload->size - rc; + + if (remaining > 0) { + char* p = req->payload->payload; + size_t s = req->payload->size; + memmove(p, p + s - remaining, remaining); + req->payload->size = remaining; + stream__poll_rw(self); + return 1; + } + + assert(remaining == 0); + + TAILQ_REMOVE(&self->send_queue, req, link); + stream_req__finish(req, STREAM_REQ_DONE); + } + + if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED) + stream__poll_r(self); + + return 1; +} + +static void stream_gnutls__on_readable(struct stream* self) +{ + switch (self->state) { + case STREAM_STATE_NORMAL: + /* fallthrough */ + case STREAM_STATE_TLS_READY: + if (self->on_event) + self->on_event(self, STREAM_EVENT_READ); + break; + case STREAM_STATE_TLS_HANDSHAKE: + stream__try_tls_accept(self); + break; + case STREAM_STATE_CLOSED: + break; + } +} + +static void stream_gnutls__on_writable(struct stream* self) +{ + switch (self->state) { + case STREAM_STATE_NORMAL: + /* fallthrough */ + case STREAM_STATE_TLS_READY: + stream_gnutls__flush(self); + break; + case STREAM_STATE_TLS_HANDSHAKE: + stream__try_tls_accept(self); + break; + case STREAM_STATE_CLOSED: + break; + } +} + +static void stream_gnutls__on_event(void* obj) +{ + struct stream* self = aml_get_userdata(obj); + uint32_t events = aml_get_revents(obj); + + if (events & AML_EVENT_READ) + stream_gnutls__on_readable(self); + + if (events & AML_EVENT_WRITE) + stream_gnutls__on_writable(self); +} + +static int stream_gnutls_send(struct stream* self, struct rcbuf* payload, + stream_req_fn on_done, void* userdata) +{ + if (self->state == STREAM_STATE_CLOSED) + return -1; + + struct stream_req* req = calloc(1, sizeof(*req)); + if (!req) + return -1; + + req->payload = payload; + req->on_done = on_done; + req->userdata = userdata; + + TAILQ_INSERT_TAIL(&self->send_queue, req, link); + + return stream_gnutls__flush(self); +} + +static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size) +{ + ssize_t rc = gnutls_record_recv(self->tls_session, dst, size); + if (rc == 0) { + stream__remote_closed(self); + return rc; + } + if (rc > 0) { + self->bytes_received += rc; + return rc; + } + + switch (rc) { + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + default: + errno = 0; + break; + } + + // Make sure data wasn't being written. + assert(gnutls_record_get_direction(self->tls_session) == 0); + return -1; +} + +static int stream__try_tls_accept(struct stream* self) +{ + int rc; + + rc = gnutls_handshake(self->tls_session); + if (rc == GNUTLS_E_SUCCESS) { + self->state = STREAM_STATE_TLS_READY; + stream__poll_r(self); + return 0; + } + + if (gnutls_error_is_fatal(rc)) { + aml_stop(aml_get_default(), self->handler); + return -1; + } + + int was_writing = gnutls_record_get_direction(self->tls_session); + if (was_writing) + stream__poll_w(self); + else + stream__poll_r(self); + + self->state = STREAM_STATE_TLS_HANDSHAKE; + return 0; +} + +static struct stream_impl impl = { + .close = stream_gnutls_close, + .destroy = stream_gnutls_destroy, + .read = stream_gnutls_read, + .send = stream_gnutls_send, +}; + +int stream_upgrade_to_tls(struct stream* self, void* context) +{ + int rc; + + rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK); + if (rc != GNUTLS_E_SUCCESS) + return -1; + + rc = gnutls_set_default_priority(self->tls_session); + if (rc != GNUTLS_E_SUCCESS) + goto failure; + + rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE, + context); + if (rc != GNUTLS_E_SUCCESS) + goto failure; + + aml_stop(aml_get_default(), self->handler); + aml_unref(self->handler); + + self->handler = aml_handler_new(self->fd, stream_gnutls__on_event, self, + free); + assert(self->handler); + + rc = aml_start(aml_get_default(), self->handler); + assert(rc >= 0); + + gnutls_transport_set_int(self->tls_session, self->fd); + + self->impl = &impl; + + return stream__try_tls_accept(self); + +failure: + gnutls_deinit(self->tls_session); + return -1; +} diff --git a/src/stream-tcp.c b/src/stream-tcp.c new file mode 100644 index 0000000..7cff31f --- /dev/null +++ b/src/stream-tcp.c @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2020 - 2023 Andri Yngvason + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rcbuf.h" +#include "stream.h" +#include "stream-common.h" +#include "sys/queue.h" + +static int stream_tcp_close(struct stream* self) +{ + if (self->state == STREAM_STATE_CLOSED) + return -1; + + self->state = STREAM_STATE_CLOSED; + + while (!TAILQ_EMPTY(&self->send_queue)) { + struct stream_req* req = TAILQ_FIRST(&self->send_queue); + TAILQ_REMOVE(&self->send_queue, req, link); + stream_req__finish(req, STREAM_REQ_FAILED); + } + + aml_stop(aml_get_default(), self->handler); + close(self->fd); + self->fd = -1; + + return 0; +} + +static void stream_tcp_destroy(struct stream* self) +{ + stream_close(self); + aml_unref(self->handler); +} + +static int stream_tcp__flush(struct stream* self) +{ + static struct iovec iov[IOV_MAX]; + size_t n_msgs = 0; + ssize_t bytes_sent; + + struct stream_req* req; + TAILQ_FOREACH(req, &self->send_queue, link) { + iov[n_msgs].iov_base = req->payload->payload; + iov[n_msgs].iov_len = req->payload->size; + + if (++n_msgs >= IOV_MAX) + break; + } + + if (n_msgs == 0) + return 0; + + struct msghdr msghdr = { + .msg_iov = iov, + .msg_iovlen = n_msgs, + }; + bytes_sent = sendmsg(self->fd, &msghdr, MSG_NOSIGNAL); + if (bytes_sent < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + stream__poll_rw(self); + errno = EAGAIN; + } else if (errno == EPIPE) { + stream__remote_closed(self); + errno = EPIPE; + } + + return bytes_sent; + } + + self->bytes_sent += bytes_sent; + + ssize_t bytes_left = bytes_sent; + + struct stream_req* tmp; + TAILQ_FOREACH_SAFE(req, &self->send_queue, link, tmp) { + bytes_left -= req->payload->size; + + if (bytes_left >= 0) { + TAILQ_REMOVE(&self->send_queue, req, link); + stream_req__finish(req, STREAM_REQ_DONE); + } else { + char* p = req->payload->payload; + size_t s = req->payload->size; + memmove(p, p + s + bytes_left, -bytes_left); + req->payload->size = -bytes_left; + stream__poll_rw(self); + } + + if (bytes_left <= 0) + break; + } + + if (bytes_left == 0 && self->state != STREAM_STATE_CLOSED) + stream__poll_r(self); + + assert(bytes_left <= 0); + + return bytes_sent; +} + +static void stream_tcp__on_readable(struct stream* self) +{ + switch (self->state) { + case STREAM_STATE_NORMAL: + /* fallthrough */ + if (self->on_event) + self->on_event(self, STREAM_EVENT_READ); + break; + case STREAM_STATE_CLOSED: + break; + default:; + } +} + +static void stream_tcp__on_writable(struct stream* self) +{ + switch (self->state) { + case STREAM_STATE_NORMAL: + /* fallthrough */ + stream_tcp__flush(self); + break; + case STREAM_STATE_CLOSED: + break; + default:; + } +} + +static void stream_tcp__on_event(void* obj) +{ + struct stream* self = aml_get_userdata(obj); + uint32_t events = aml_get_revents(obj); + + if (events & AML_EVENT_READ) + stream_tcp__on_readable(self); + + if (events & AML_EVENT_WRITE) + stream_tcp__on_writable(self); +} + +static ssize_t stream_tcp_read(struct stream* self, void* dst, size_t size) +{ + if (self->state != STREAM_STATE_NORMAL) + return -1; + + ssize_t rc = read(self->fd, dst, size); + if (rc == 0) + stream__remote_closed(self); + if (rc > 0) + self->bytes_received += rc; + return rc; +} + +static int stream_tcp_send(struct stream* self, struct rcbuf* payload, + stream_req_fn on_done, void* userdata) +{ + if (self->state == STREAM_STATE_CLOSED) + return -1; + + struct stream_req* req = calloc(1, sizeof(*req)); + if (!req) + return -1; + + req->payload = payload; + req->on_done = on_done; + req->userdata = userdata; + + TAILQ_INSERT_TAIL(&self->send_queue, req, link); + + return stream_tcp__flush(self); +} + +static struct stream_impl impl = { + .close = stream_tcp_close, + .destroy = stream_tcp_destroy, + .read = stream_tcp_read, + .send = stream_tcp_send, +}; + +struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) +{ + struct stream* self = calloc(1, sizeof(*self)); + if (!self) + return NULL; + + self->impl = &impl, + self->fd = fd; + self->on_event = on_event; + self->userdata = userdata; + + TAILQ_INIT(&self->send_queue); + + fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); + + self->handler = aml_handler_new(fd, stream_tcp__on_event, self, free); + if (!self->handler) + goto failure; + + if (aml_start(aml_get_default(), self->handler) < 0) + goto start_failure; + + stream__poll_r(self); + + return self; + +start_failure: + aml_unref(self->handler); + self = NULL; /* Handled in unref */ +failure: + free(self); + return NULL; +} diff --git a/src/stream.c b/src/stream.c index 7b54be6..7f762e1 100644 --- a/src/stream.c +++ b/src/stream.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Andri Yngvason + * Copyright (c) 2023 Andri Yngvason * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -14,323 +14,27 @@ * PERFORMANCE OF THIS SOFTWARE. */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef ENABLE_TLS -#include -#endif - -#include "rcbuf.h" #include "stream.h" -#include "sys/queue.h" -static void stream__on_event(void* obj); -#ifdef ENABLE_TLS -static int stream__try_tls_accept(struct stream* self); -#endif - -static inline void stream__poll_r(struct stream* self) -{ - aml_set_event_mask(self->handler, AML_EVENT_READ); -} - -static inline void stream__poll_w(struct stream* self) -{ - aml_set_event_mask(self->handler, AML_EVENT_WRITE); -} - -static inline void stream__poll_rw(struct stream* self) -{ - aml_set_event_mask(self->handler, AML_EVENT_READ | AML_EVENT_WRITE); -} - -static void stream_req__finish(struct stream_req* req, - enum stream_req_status status) -{ - if (req->on_done) - req->on_done(req->userdata, status); - - rcbuf_unref(req->payload); - free(req); -} +#include int stream_close(struct stream* self) { - if (self->state == STREAM_STATE_CLOSED) - return -1; - - self->state = STREAM_STATE_CLOSED; - - while (!TAILQ_EMPTY(&self->send_queue)) { - struct stream_req* req = TAILQ_FIRST(&self->send_queue); - TAILQ_REMOVE(&self->send_queue, req, link); - stream_req__finish(req, STREAM_REQ_FAILED); - } - -#ifdef ENABLE_TLS - if (self->tls_session) - gnutls_deinit(self->tls_session); - self->tls_session = NULL; -#endif - - // TODO: Maybe use explicit loop object instead of the default one? - aml_stop(aml_get_default(), self->handler); - close(self->fd); - self->fd = -1; - - return 0; + assert(self->impl && self->impl->close); + return self->impl->close(self); } void stream_destroy(struct stream* self) { - stream_close(self); - aml_unref(self->handler); -} - -static void stream__remote_closed(struct stream* self) -{ - stream_close(self); - - if (self->on_event) - self->on_event(self, STREAM_EVENT_REMOTE_CLOSED); -} - -static int stream__flush_plain(struct stream* self) -{ - static struct iovec iov[IOV_MAX]; - size_t n_msgs = 0; - ssize_t bytes_sent; - - struct stream_req* req; - TAILQ_FOREACH(req, &self->send_queue, link) { - iov[n_msgs].iov_base = req->payload->payload; - iov[n_msgs].iov_len = req->payload->size; - - if (++n_msgs >= IOV_MAX) - break; - } - - if (n_msgs == 0) - return 0; - - struct msghdr msghdr = { - .msg_iov = iov, - .msg_iovlen = n_msgs, - }; - bytes_sent = sendmsg(self->fd, &msghdr, MSG_NOSIGNAL); - if (bytes_sent < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - stream__poll_rw(self); - errno = EAGAIN; - } else if (errno == EPIPE) { - stream__remote_closed(self); - errno = EPIPE; - } - - return bytes_sent; - } - - self->bytes_sent += bytes_sent; - - ssize_t bytes_left = bytes_sent; - - struct stream_req* tmp; - TAILQ_FOREACH_SAFE(req, &self->send_queue, link, tmp) { - bytes_left -= req->payload->size; - - if (bytes_left >= 0) { - TAILQ_REMOVE(&self->send_queue, req, link); - stream_req__finish(req, STREAM_REQ_DONE); - } else { - char* p = req->payload->payload; - size_t s = req->payload->size; - memmove(p, p + s + bytes_left, -bytes_left); - req->payload->size = -bytes_left; - stream__poll_rw(self); - } - - if (bytes_left <= 0) - break; - } - - if (bytes_left == 0 && self->state != STREAM_STATE_CLOSED) - stream__poll_r(self); - - assert(bytes_left <= 0); - - return bytes_sent; -} - -#ifdef ENABLE_TLS -static int stream__flush_tls(struct stream* self) -{ - while (!TAILQ_EMPTY(&self->send_queue)) { - struct stream_req* req = TAILQ_FIRST(&self->send_queue); - - ssize_t rc = gnutls_record_send( - self->tls_session, req->payload->payload, - req->payload->size); - if (rc < 0) { - gnutls_record_discard_queued(self->tls_session); - if (gnutls_error_is_fatal(rc)) - stream_close(self); - return -1; - } - - self->bytes_sent += rc; - - ssize_t remaining = req->payload->size - rc; - - if (remaining > 0) { - char* p = req->payload->payload; - size_t s = req->payload->size; - memmove(p, p + s - remaining, remaining); - req->payload->size = remaining; - stream__poll_rw(self); - return 1; - } - - assert(remaining == 0); - - TAILQ_REMOVE(&self->send_queue, req, link); - stream_req__finish(req, STREAM_REQ_DONE); - } - - if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED) - stream__poll_r(self); - - return 1; -} -#endif - -static int stream__flush(struct stream* self) -{ - switch (self->state) { - case STREAM_STATE_NORMAL: return stream__flush_plain(self); -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_READY: return stream__flush_tls(self); -#endif - default: - break; - } - abort(); - return -1; -} - -static void stream__on_readable(struct stream* self) -{ - switch (self->state) { - case STREAM_STATE_NORMAL: - /* fallthrough */ -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_READY: -#endif - if (self->on_event) - self->on_event(self, STREAM_EVENT_READ); - break; -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_HANDSHAKE: - stream__try_tls_accept(self); - break; -#endif - case STREAM_STATE_CLOSED: - break; - } -} - -static void stream__on_writable(struct stream* self) -{ - switch (self->state) { - case STREAM_STATE_NORMAL: - /* fallthrough */ -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_READY: -#endif - stream__flush(self); - break; -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_HANDSHAKE: - stream__try_tls_accept(self); - break; -#endif - case STREAM_STATE_CLOSED: - break; - } -} - -static void stream__on_event(void* obj) -{ - struct stream* self = aml_get_userdata(obj); - uint32_t events = aml_get_revents(obj); - - if (events & AML_EVENT_READ) - stream__on_readable(self); - - if (events & AML_EVENT_WRITE) - stream__on_writable(self); -} - -struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata) -{ - struct stream* self = calloc(1, sizeof(*self)); - if (!self) - return NULL; - - self->fd = fd; - self->on_event = on_event; - self->userdata = userdata; - - TAILQ_INIT(&self->send_queue); - - fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); - - self->handler = aml_handler_new(fd, stream__on_event, self, free); - if (!self->handler) - goto failure; - - if (aml_start(aml_get_default(), self->handler) < 0) - goto start_failure; - - stream__poll_r(self); - - return self; - -start_failure: - aml_unref(self->handler); - self = NULL; /* Handled in unref */ -failure: - free(self); - return NULL; + assert(self->impl && self->impl->destroy); + return self->impl->destroy(self); } int stream_send(struct stream* self, struct rcbuf* payload, stream_req_fn on_done, void* userdata) { - if (self->state == STREAM_STATE_CLOSED) - return -1; - - struct stream_req* req = calloc(1, sizeof(*req)); - if (!req) - return -1; - - req->payload = payload; - req->on_done = on_done; - req->userdata = userdata; - - TAILQ_INSERT_TAIL(&self->send_queue, req, link); - - return stream__flush(self); + assert(self->impl && self->impl->send); + return self->impl->send(self, payload, on_done, userdata); } int stream_write(struct stream* self, const void* payload, size_t len, @@ -340,111 +44,8 @@ int stream_write(struct stream* self, const void* payload, size_t len, return buf ? stream_send(self, buf, on_done, userdata) : -1; } -static ssize_t stream__read_plain(struct stream* self, void* dst, size_t size) -{ - ssize_t rc = read(self->fd, dst, size); - if (rc == 0) - stream__remote_closed(self); - if (rc > 0) - self->bytes_received += rc; - return rc; -} - -#ifdef ENABLE_TLS -static ssize_t stream__read_tls(struct stream* self, void* dst, size_t size) -{ - ssize_t rc = gnutls_record_recv(self->tls_session, dst, size); - if (rc == 0) { - stream__remote_closed(self); - return rc; - } - if (rc > 0) { - self->bytes_received += rc; - return rc; - } - - switch (rc) { - case GNUTLS_E_INTERRUPTED: - errno = EINTR; - break; - case GNUTLS_E_AGAIN: - errno = EAGAIN; - break; - default: - errno = 0; - break; - } - - // Make sure data wasn't being written. - assert(gnutls_record_get_direction(self->tls_session) == 0); - return -1; -} -#endif - ssize_t stream_read(struct stream* self, void* dst, size_t size) { - switch (self->state) { - case STREAM_STATE_NORMAL: return stream__read_plain(self, dst, size); -#ifdef ENABLE_TLS - case STREAM_STATE_TLS_READY: return stream__read_tls(self, dst, size); -#endif - default: break; - } - - abort(); - return -1; + assert(self->impl && self->impl->read); + return self->impl->read(self, dst, size); } - -#ifdef ENABLE_TLS -static int stream__try_tls_accept(struct stream* self) -{ - int rc; - - rc = gnutls_handshake(self->tls_session); - if (rc == GNUTLS_E_SUCCESS) { - self->state = STREAM_STATE_TLS_READY; - stream__poll_r(self); - return 0; - } - - if (gnutls_error_is_fatal(rc)) { - aml_stop(aml_get_default(), self->handler); - return -1; - } - - int was_writing = gnutls_record_get_direction(self->tls_session); - if (was_writing) - stream__poll_w(self); - else - stream__poll_r(self); - - self->state = STREAM_STATE_TLS_HANDSHAKE; - return 0; -} - -int stream_upgrade_to_tls(struct stream* self, void* context) -{ - int rc; - - rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK); - if (rc != GNUTLS_E_SUCCESS) - return -1; - - rc = gnutls_set_default_priority(self->tls_session); - if (rc != GNUTLS_E_SUCCESS) - goto failure; - - rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE, - context); - if (rc != GNUTLS_E_SUCCESS) - goto failure; - - gnutls_transport_set_int(self->tls_session, self->fd); - - return stream__try_tls_accept(self); - -failure: - gnutls_deinit(self->tls_session); - return -1; -} -#endif