diff --git a/include/stream.h b/include/stream.h index 706a842..8eaf8bf 100644 --- a/include/stream.h +++ b/include/stream.h @@ -23,10 +23,6 @@ #include #include -#ifdef ENABLE_TLS -#include -#endif - enum stream_state { STREAM_STATE_NORMAL = 0, STREAM_STATE_CLOSED, @@ -86,10 +82,6 @@ struct stream { struct stream_send_queue send_queue; -#ifdef ENABLE_TLS - gnutls_session_t tls_session; -#endif - uint32_t bytes_sent; uint32_t bytes_received; diff --git a/src/stream-gnutls.c b/src/stream-gnutls.c index ad1bc9c..416233d 100644 --- a/src/stream-gnutls.c +++ b/src/stream-gnutls.c @@ -34,28 +34,36 @@ #include "stream-common.h" #include "sys/queue.h" +struct stream_gnutls { + struct stream base; + + gnutls_session_t session; +}; + static int stream__try_tls_accept(struct stream* self); -static int stream_gnutls_close(struct stream* self) +static int stream_gnutls_close(struct stream* base) { - if (self->state == STREAM_STATE_CLOSED) + struct stream_gnutls* self = (struct stream_gnutls*)base; + + if (self->base.state == STREAM_STATE_CLOSED) return -1; - self->state = STREAM_STATE_CLOSED; + self->base.state = STREAM_STATE_CLOSED; - while (!TAILQ_EMPTY(&self->send_queue)) { - struct stream_req* req = TAILQ_FIRST(&self->send_queue); - TAILQ_REMOVE(&self->send_queue, req, link); + while (!TAILQ_EMPTY(&self->base.send_queue)) { + struct stream_req* req = TAILQ_FIRST(&self->base.send_queue); + TAILQ_REMOVE(&self->base.send_queue, req, link); stream_req__finish(req, STREAM_REQ_FAILED); } - if (self->tls_session) - gnutls_deinit(self->tls_session); - self->tls_session = NULL; + if (self->session) + gnutls_deinit(self->session); + self->session = NULL; - aml_stop(aml_get_default(), self->handler); - close(self->fd); - self->fd = -1; + aml_stop(aml_get_default(), self->base.handler); + close(self->base.fd); + self->base.fd = -1; return 0; } @@ -67,25 +75,25 @@ static void stream_gnutls_destroy(struct stream* self) free(self); } -static int stream_gnutls__flush(struct stream* self) +static int stream_gnutls__flush(struct stream* base) { - while (!TAILQ_EMPTY(&self->send_queue)) { - struct stream_req* req = TAILQ_FIRST(&self->send_queue); + struct stream_gnutls* self = (struct stream_gnutls*)base; + while (!TAILQ_EMPTY(&self->base.send_queue)) { + struct stream_req* req = TAILQ_FIRST(&self->base.send_queue); - ssize_t rc = gnutls_record_send( - self->tls_session, req->payload->payload, - req->payload->size); + ssize_t rc = gnutls_record_send(self->session, + req->payload->payload, req->payload->size); if (rc < 0) { if (gnutls_error_is_fatal(rc)) { - stream_close(self); + stream_close(base); return -1; } - stream__poll_rw(self); + stream__poll_rw(base); return 0; } - self->bytes_sent += rc; + self->base.bytes_sent += rc; ssize_t remaining = req->payload->size - rc; @@ -94,18 +102,18 @@ static int stream_gnutls__flush(struct stream* self) size_t s = req->payload->size; memmove(p, p + s - remaining, remaining); req->payload->size = remaining; - stream__poll_rw(self); + stream__poll_rw(base); return 1; } assert(remaining == 0); - TAILQ_REMOVE(&self->send_queue, req, link); + TAILQ_REMOVE(&self->base.send_queue, req, link); stream_req__finish(req, STREAM_REQ_DONE); } - if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED) - stream__poll_r(self); + if (TAILQ_EMPTY(&base->send_queue) && base->state != STREAM_STATE_CLOSED) + stream__poll_r(base); return 1; } @@ -174,15 +182,17 @@ static int stream_gnutls_send(struct stream* self, struct rcbuf* payload, return stream_gnutls__flush(self); } -static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size) +static ssize_t stream_gnutls_read(struct stream* base, void* dst, size_t size) { - ssize_t rc = gnutls_record_recv(self->tls_session, dst, size); + struct stream_gnutls* self = (struct stream_gnutls*)base; + + ssize_t rc = gnutls_record_recv(self->session, dst, size); if (rc == 0) { - stream__remote_closed(self); + stream__remote_closed(base); return rc; } if (rc > 0) { - self->bytes_received += rc; + self->base.bytes_received += rc; return rc; } @@ -199,33 +209,34 @@ static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size) } // Make sure data wasn't being written. - assert(gnutls_record_get_direction(self->tls_session) == 0); + assert(gnutls_record_get_direction(self->session) == 0); return -1; } -static int stream__try_tls_accept(struct stream* self) +static int stream__try_tls_accept(struct stream* base) { + struct stream_gnutls* self = (struct stream_gnutls*)base; int rc; - rc = gnutls_handshake(self->tls_session); + rc = gnutls_handshake(self->session); if (rc == GNUTLS_E_SUCCESS) { - self->state = STREAM_STATE_TLS_READY; - stream__poll_r(self); + self->base.state = STREAM_STATE_TLS_READY; + stream__poll_r(base); return 0; } if (gnutls_error_is_fatal(rc)) { - aml_stop(aml_get_default(), self->handler); + aml_stop(aml_get_default(), self->base.handler); return -1; } - int was_writing = gnutls_record_get_direction(self->tls_session); + int was_writing = gnutls_record_get_direction(self->session); if (was_writing) - stream__poll_w(self); + stream__poll_w(base); else - stream__poll_r(self); + stream__poll_r(base); - self->state = STREAM_STATE_TLS_HANDSHAKE; + self->base.state = STREAM_STATE_TLS_HANDSHAKE; return 0; } @@ -236,40 +247,41 @@ static struct stream_impl impl = { .send = stream_gnutls_send, }; -int stream_upgrade_to_tls(struct stream* self, void* context) +int stream_upgrade_to_tls(struct stream* base, void* context) { + struct stream_gnutls* self = (struct stream_gnutls*)base; int rc; - rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK); + rc = gnutls_init(&self->session, GNUTLS_SERVER | GNUTLS_NONBLOCK); if (rc != GNUTLS_E_SUCCESS) return -1; - rc = gnutls_set_default_priority(self->tls_session); + rc = gnutls_set_default_priority(self->session); if (rc != GNUTLS_E_SUCCESS) goto failure; - rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE, + rc = gnutls_credentials_set(self->session, GNUTLS_CRD_CERTIFICATE, context); if (rc != GNUTLS_E_SUCCESS) goto failure; - aml_stop(aml_get_default(), self->handler); - aml_unref(self->handler); + aml_stop(aml_get_default(), self->base.handler); + aml_unref(self->base.handler); - self->handler = aml_handler_new(self->fd, stream_gnutls__on_event, self, - NULL); - assert(self->handler); + self->base.handler = aml_handler_new(self->base.fd, + stream_gnutls__on_event, self, NULL); + assert(self->base.handler); - rc = aml_start(aml_get_default(), self->handler); + rc = aml_start(aml_get_default(), self->base.handler); assert(rc >= 0); - gnutls_transport_set_int(self->tls_session, self->fd); + gnutls_transport_set_int(self->session, self->base.fd); - self->impl = &impl; + self->base.impl = &impl; - return stream__try_tls_accept(self); + return stream__try_tls_accept(base); failure: - gnutls_deinit(self->tls_session); + gnutls_deinit(self->session); return -1; }