stream: Move tls specific member into tls impl

pull/93/head
Andri Yngvason 2023-05-28 15:50:36 +00:00
parent c006936fd0
commit b5f37d0227
2 changed files with 65 additions and 61 deletions

View File

@ -23,10 +23,6 @@
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
#ifdef ENABLE_TLS
#include <gnutls/gnutls.h>
#endif
enum stream_state { enum stream_state {
STREAM_STATE_NORMAL = 0, STREAM_STATE_NORMAL = 0,
STREAM_STATE_CLOSED, STREAM_STATE_CLOSED,
@ -86,10 +82,6 @@ struct stream {
struct stream_send_queue send_queue; struct stream_send_queue send_queue;
#ifdef ENABLE_TLS
gnutls_session_t tls_session;
#endif
uint32_t bytes_sent; uint32_t bytes_sent;
uint32_t bytes_received; uint32_t bytes_received;

View File

@ -34,28 +34,36 @@
#include "stream-common.h" #include "stream-common.h"
#include "sys/queue.h" #include "sys/queue.h"
struct stream_gnutls {
struct stream base;
gnutls_session_t session;
};
static int stream__try_tls_accept(struct stream* self); static int stream__try_tls_accept(struct stream* self);
static int stream_gnutls_close(struct stream* self) static int stream_gnutls_close(struct stream* base)
{ {
if (self->state == STREAM_STATE_CLOSED) struct stream_gnutls* self = (struct stream_gnutls*)base;
if (self->base.state == STREAM_STATE_CLOSED)
return -1; return -1;
self->state = STREAM_STATE_CLOSED; self->base.state = STREAM_STATE_CLOSED;
while (!TAILQ_EMPTY(&self->send_queue)) { while (!TAILQ_EMPTY(&self->base.send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->send_queue); struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
TAILQ_REMOVE(&self->send_queue, req, link); TAILQ_REMOVE(&self->base.send_queue, req, link);
stream_req__finish(req, STREAM_REQ_FAILED); stream_req__finish(req, STREAM_REQ_FAILED);
} }
if (self->tls_session) if (self->session)
gnutls_deinit(self->tls_session); gnutls_deinit(self->session);
self->tls_session = NULL; self->session = NULL;
aml_stop(aml_get_default(), self->handler); aml_stop(aml_get_default(), self->base.handler);
close(self->fd); close(self->base.fd);
self->fd = -1; self->base.fd = -1;
return 0; return 0;
} }
@ -67,25 +75,25 @@ static void stream_gnutls_destroy(struct stream* self)
free(self); free(self);
} }
static int stream_gnutls__flush(struct stream* self) static int stream_gnutls__flush(struct stream* base)
{ {
while (!TAILQ_EMPTY(&self->send_queue)) { struct stream_gnutls* self = (struct stream_gnutls*)base;
struct stream_req* req = TAILQ_FIRST(&self->send_queue); while (!TAILQ_EMPTY(&self->base.send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
ssize_t rc = gnutls_record_send( ssize_t rc = gnutls_record_send(self->session,
self->tls_session, req->payload->payload, req->payload->payload, req->payload->size);
req->payload->size);
if (rc < 0) { if (rc < 0) {
if (gnutls_error_is_fatal(rc)) { if (gnutls_error_is_fatal(rc)) {
stream_close(self); stream_close(base);
return -1; return -1;
} }
stream__poll_rw(self); stream__poll_rw(base);
return 0; return 0;
} }
self->bytes_sent += rc; self->base.bytes_sent += rc;
ssize_t remaining = req->payload->size - rc; ssize_t remaining = req->payload->size - rc;
@ -94,18 +102,18 @@ static int stream_gnutls__flush(struct stream* self)
size_t s = req->payload->size; size_t s = req->payload->size;
memmove(p, p + s - remaining, remaining); memmove(p, p + s - remaining, remaining);
req->payload->size = remaining; req->payload->size = remaining;
stream__poll_rw(self); stream__poll_rw(base);
return 1; return 1;
} }
assert(remaining == 0); assert(remaining == 0);
TAILQ_REMOVE(&self->send_queue, req, link); TAILQ_REMOVE(&self->base.send_queue, req, link);
stream_req__finish(req, STREAM_REQ_DONE); stream_req__finish(req, STREAM_REQ_DONE);
} }
if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED) if (TAILQ_EMPTY(&base->send_queue) && base->state != STREAM_STATE_CLOSED)
stream__poll_r(self); stream__poll_r(base);
return 1; return 1;
} }
@ -174,15 +182,17 @@ static int stream_gnutls_send(struct stream* self, struct rcbuf* payload,
return stream_gnutls__flush(self); return stream_gnutls__flush(self);
} }
static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size) static ssize_t stream_gnutls_read(struct stream* base, void* dst, size_t size)
{ {
ssize_t rc = gnutls_record_recv(self->tls_session, dst, size); struct stream_gnutls* self = (struct stream_gnutls*)base;
ssize_t rc = gnutls_record_recv(self->session, dst, size);
if (rc == 0) { if (rc == 0) {
stream__remote_closed(self); stream__remote_closed(base);
return rc; return rc;
} }
if (rc > 0) { if (rc > 0) {
self->bytes_received += rc; self->base.bytes_received += rc;
return rc; return rc;
} }
@ -199,33 +209,34 @@ static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size)
} }
// Make sure data wasn't being written. // Make sure data wasn't being written.
assert(gnutls_record_get_direction(self->tls_session) == 0); assert(gnutls_record_get_direction(self->session) == 0);
return -1; return -1;
} }
static int stream__try_tls_accept(struct stream* self) static int stream__try_tls_accept(struct stream* base)
{ {
struct stream_gnutls* self = (struct stream_gnutls*)base;
int rc; int rc;
rc = gnutls_handshake(self->tls_session); rc = gnutls_handshake(self->session);
if (rc == GNUTLS_E_SUCCESS) { if (rc == GNUTLS_E_SUCCESS) {
self->state = STREAM_STATE_TLS_READY; self->base.state = STREAM_STATE_TLS_READY;
stream__poll_r(self); stream__poll_r(base);
return 0; return 0;
} }
if (gnutls_error_is_fatal(rc)) { if (gnutls_error_is_fatal(rc)) {
aml_stop(aml_get_default(), self->handler); aml_stop(aml_get_default(), self->base.handler);
return -1; return -1;
} }
int was_writing = gnutls_record_get_direction(self->tls_session); int was_writing = gnutls_record_get_direction(self->session);
if (was_writing) if (was_writing)
stream__poll_w(self); stream__poll_w(base);
else else
stream__poll_r(self); stream__poll_r(base);
self->state = STREAM_STATE_TLS_HANDSHAKE; self->base.state = STREAM_STATE_TLS_HANDSHAKE;
return 0; return 0;
} }
@ -236,40 +247,41 @@ static struct stream_impl impl = {
.send = stream_gnutls_send, .send = stream_gnutls_send,
}; };
int stream_upgrade_to_tls(struct stream* self, void* context) int stream_upgrade_to_tls(struct stream* base, void* context)
{ {
struct stream_gnutls* self = (struct stream_gnutls*)base;
int rc; int rc;
rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK); rc = gnutls_init(&self->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
if (rc != GNUTLS_E_SUCCESS) if (rc != GNUTLS_E_SUCCESS)
return -1; return -1;
rc = gnutls_set_default_priority(self->tls_session); rc = gnutls_set_default_priority(self->session);
if (rc != GNUTLS_E_SUCCESS) if (rc != GNUTLS_E_SUCCESS)
goto failure; goto failure;
rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE, rc = gnutls_credentials_set(self->session, GNUTLS_CRD_CERTIFICATE,
context); context);
if (rc != GNUTLS_E_SUCCESS) if (rc != GNUTLS_E_SUCCESS)
goto failure; goto failure;
aml_stop(aml_get_default(), self->handler); aml_stop(aml_get_default(), self->base.handler);
aml_unref(self->handler); aml_unref(self->base.handler);
self->handler = aml_handler_new(self->fd, stream_gnutls__on_event, self, self->base.handler = aml_handler_new(self->base.fd,
NULL); stream_gnutls__on_event, self, NULL);
assert(self->handler); assert(self->base.handler);
rc = aml_start(aml_get_default(), self->handler); rc = aml_start(aml_get_default(), self->base.handler);
assert(rc >= 0); assert(rc >= 0);
gnutls_transport_set_int(self->tls_session, self->fd); gnutls_transport_set_int(self->session, self->base.fd);
self->impl = &impl; self->base.impl = &impl;
return stream__try_tls_accept(self); return stream__try_tls_accept(base);
failure: failure:
gnutls_deinit(self->tls_session); gnutls_deinit(self->session);
return -1; return -1;
} }