stream: Move tls specific member into tls impl

pull/93/head
Andri Yngvason 2023-05-28 15:50:36 +00:00
parent c006936fd0
commit b5f37d0227
2 changed files with 65 additions and 61 deletions

View File

@ -23,10 +23,6 @@
#include <stdint.h>
#include <stdbool.h>
#ifdef ENABLE_TLS
#include <gnutls/gnutls.h>
#endif
enum stream_state {
STREAM_STATE_NORMAL = 0,
STREAM_STATE_CLOSED,
@ -86,10 +82,6 @@ struct stream {
struct stream_send_queue send_queue;
#ifdef ENABLE_TLS
gnutls_session_t tls_session;
#endif
uint32_t bytes_sent;
uint32_t bytes_received;

View File

@ -34,28 +34,36 @@
#include "stream-common.h"
#include "sys/queue.h"
struct stream_gnutls {
struct stream base;
gnutls_session_t session;
};
static int stream__try_tls_accept(struct stream* self);
static int stream_gnutls_close(struct stream* self)
static int stream_gnutls_close(struct stream* base)
{
if (self->state == STREAM_STATE_CLOSED)
struct stream_gnutls* self = (struct stream_gnutls*)base;
if (self->base.state == STREAM_STATE_CLOSED)
return -1;
self->state = STREAM_STATE_CLOSED;
self->base.state = STREAM_STATE_CLOSED;
while (!TAILQ_EMPTY(&self->send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->send_queue);
TAILQ_REMOVE(&self->send_queue, req, link);
while (!TAILQ_EMPTY(&self->base.send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
TAILQ_REMOVE(&self->base.send_queue, req, link);
stream_req__finish(req, STREAM_REQ_FAILED);
}
if (self->tls_session)
gnutls_deinit(self->tls_session);
self->tls_session = NULL;
if (self->session)
gnutls_deinit(self->session);
self->session = NULL;
aml_stop(aml_get_default(), self->handler);
close(self->fd);
self->fd = -1;
aml_stop(aml_get_default(), self->base.handler);
close(self->base.fd);
self->base.fd = -1;
return 0;
}
@ -67,25 +75,25 @@ static void stream_gnutls_destroy(struct stream* self)
free(self);
}
static int stream_gnutls__flush(struct stream* self)
static int stream_gnutls__flush(struct stream* base)
{
while (!TAILQ_EMPTY(&self->send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->send_queue);
struct stream_gnutls* self = (struct stream_gnutls*)base;
while (!TAILQ_EMPTY(&self->base.send_queue)) {
struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
ssize_t rc = gnutls_record_send(
self->tls_session, req->payload->payload,
req->payload->size);
ssize_t rc = gnutls_record_send(self->session,
req->payload->payload, req->payload->size);
if (rc < 0) {
if (gnutls_error_is_fatal(rc)) {
stream_close(self);
stream_close(base);
return -1;
}
stream__poll_rw(self);
stream__poll_rw(base);
return 0;
}
self->bytes_sent += rc;
self->base.bytes_sent += rc;
ssize_t remaining = req->payload->size - rc;
@ -94,18 +102,18 @@ static int stream_gnutls__flush(struct stream* self)
size_t s = req->payload->size;
memmove(p, p + s - remaining, remaining);
req->payload->size = remaining;
stream__poll_rw(self);
stream__poll_rw(base);
return 1;
}
assert(remaining == 0);
TAILQ_REMOVE(&self->send_queue, req, link);
TAILQ_REMOVE(&self->base.send_queue, req, link);
stream_req__finish(req, STREAM_REQ_DONE);
}
if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED)
stream__poll_r(self);
if (TAILQ_EMPTY(&base->send_queue) && base->state != STREAM_STATE_CLOSED)
stream__poll_r(base);
return 1;
}
@ -174,15 +182,17 @@ static int stream_gnutls_send(struct stream* self, struct rcbuf* payload,
return stream_gnutls__flush(self);
}
static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size)
static ssize_t stream_gnutls_read(struct stream* base, void* dst, size_t size)
{
ssize_t rc = gnutls_record_recv(self->tls_session, dst, size);
struct stream_gnutls* self = (struct stream_gnutls*)base;
ssize_t rc = gnutls_record_recv(self->session, dst, size);
if (rc == 0) {
stream__remote_closed(self);
stream__remote_closed(base);
return rc;
}
if (rc > 0) {
self->bytes_received += rc;
self->base.bytes_received += rc;
return rc;
}
@ -199,33 +209,34 @@ static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size)
}
// Make sure data wasn't being written.
assert(gnutls_record_get_direction(self->tls_session) == 0);
assert(gnutls_record_get_direction(self->session) == 0);
return -1;
}
static int stream__try_tls_accept(struct stream* self)
static int stream__try_tls_accept(struct stream* base)
{
struct stream_gnutls* self = (struct stream_gnutls*)base;
int rc;
rc = gnutls_handshake(self->tls_session);
rc = gnutls_handshake(self->session);
if (rc == GNUTLS_E_SUCCESS) {
self->state = STREAM_STATE_TLS_READY;
stream__poll_r(self);
self->base.state = STREAM_STATE_TLS_READY;
stream__poll_r(base);
return 0;
}
if (gnutls_error_is_fatal(rc)) {
aml_stop(aml_get_default(), self->handler);
aml_stop(aml_get_default(), self->base.handler);
return -1;
}
int was_writing = gnutls_record_get_direction(self->tls_session);
int was_writing = gnutls_record_get_direction(self->session);
if (was_writing)
stream__poll_w(self);
stream__poll_w(base);
else
stream__poll_r(self);
stream__poll_r(base);
self->state = STREAM_STATE_TLS_HANDSHAKE;
self->base.state = STREAM_STATE_TLS_HANDSHAKE;
return 0;
}
@ -236,40 +247,41 @@ static struct stream_impl impl = {
.send = stream_gnutls_send,
};
int stream_upgrade_to_tls(struct stream* self, void* context)
int stream_upgrade_to_tls(struct stream* base, void* context)
{
struct stream_gnutls* self = (struct stream_gnutls*)base;
int rc;
rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
rc = gnutls_init(&self->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
if (rc != GNUTLS_E_SUCCESS)
return -1;
rc = gnutls_set_default_priority(self->tls_session);
rc = gnutls_set_default_priority(self->session);
if (rc != GNUTLS_E_SUCCESS)
goto failure;
rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE,
rc = gnutls_credentials_set(self->session, GNUTLS_CRD_CERTIFICATE,
context);
if (rc != GNUTLS_E_SUCCESS)
goto failure;
aml_stop(aml_get_default(), self->handler);
aml_unref(self->handler);
aml_stop(aml_get_default(), self->base.handler);
aml_unref(self->base.handler);
self->handler = aml_handler_new(self->fd, stream_gnutls__on_event, self,
NULL);
assert(self->handler);
self->base.handler = aml_handler_new(self->base.fd,
stream_gnutls__on_event, self, NULL);
assert(self->base.handler);
rc = aml_start(aml_get_default(), self->handler);
rc = aml_start(aml_get_default(), self->base.handler);
assert(rc >= 0);
gnutls_transport_set_int(self->tls_session, self->fd);
gnutls_transport_set_int(self->session, self->base.fd);
self->impl = &impl;
self->base.impl = &impl;
return stream__try_tls_accept(self);
return stream__try_tls_accept(base);
failure:
gnutls_deinit(self->tls_session);
gnutls_deinit(self->session);
return -1;
}