stream: Move tls specific member into tls impl
parent
c006936fd0
commit
b5f37d0227
|
@ -23,10 +23,6 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
|
|
||||||
#ifdef ENABLE_TLS
|
|
||||||
#include <gnutls/gnutls.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
enum stream_state {
|
enum stream_state {
|
||||||
STREAM_STATE_NORMAL = 0,
|
STREAM_STATE_NORMAL = 0,
|
||||||
STREAM_STATE_CLOSED,
|
STREAM_STATE_CLOSED,
|
||||||
|
@ -86,10 +82,6 @@ struct stream {
|
||||||
|
|
||||||
struct stream_send_queue send_queue;
|
struct stream_send_queue send_queue;
|
||||||
|
|
||||||
#ifdef ENABLE_TLS
|
|
||||||
gnutls_session_t tls_session;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
uint32_t bytes_sent;
|
uint32_t bytes_sent;
|
||||||
uint32_t bytes_received;
|
uint32_t bytes_received;
|
||||||
|
|
||||||
|
|
|
@ -34,28 +34,36 @@
|
||||||
#include "stream-common.h"
|
#include "stream-common.h"
|
||||||
#include "sys/queue.h"
|
#include "sys/queue.h"
|
||||||
|
|
||||||
|
struct stream_gnutls {
|
||||||
|
struct stream base;
|
||||||
|
|
||||||
|
gnutls_session_t session;
|
||||||
|
};
|
||||||
|
|
||||||
static int stream__try_tls_accept(struct stream* self);
|
static int stream__try_tls_accept(struct stream* self);
|
||||||
|
|
||||||
static int stream_gnutls_close(struct stream* self)
|
static int stream_gnutls_close(struct stream* base)
|
||||||
{
|
{
|
||||||
if (self->state == STREAM_STATE_CLOSED)
|
struct stream_gnutls* self = (struct stream_gnutls*)base;
|
||||||
|
|
||||||
|
if (self->base.state == STREAM_STATE_CLOSED)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
self->state = STREAM_STATE_CLOSED;
|
self->base.state = STREAM_STATE_CLOSED;
|
||||||
|
|
||||||
while (!TAILQ_EMPTY(&self->send_queue)) {
|
while (!TAILQ_EMPTY(&self->base.send_queue)) {
|
||||||
struct stream_req* req = TAILQ_FIRST(&self->send_queue);
|
struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
|
||||||
TAILQ_REMOVE(&self->send_queue, req, link);
|
TAILQ_REMOVE(&self->base.send_queue, req, link);
|
||||||
stream_req__finish(req, STREAM_REQ_FAILED);
|
stream_req__finish(req, STREAM_REQ_FAILED);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (self->tls_session)
|
if (self->session)
|
||||||
gnutls_deinit(self->tls_session);
|
gnutls_deinit(self->session);
|
||||||
self->tls_session = NULL;
|
self->session = NULL;
|
||||||
|
|
||||||
aml_stop(aml_get_default(), self->handler);
|
aml_stop(aml_get_default(), self->base.handler);
|
||||||
close(self->fd);
|
close(self->base.fd);
|
||||||
self->fd = -1;
|
self->base.fd = -1;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -67,25 +75,25 @@ static void stream_gnutls_destroy(struct stream* self)
|
||||||
free(self);
|
free(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
static int stream_gnutls__flush(struct stream* self)
|
static int stream_gnutls__flush(struct stream* base)
|
||||||
{
|
{
|
||||||
while (!TAILQ_EMPTY(&self->send_queue)) {
|
struct stream_gnutls* self = (struct stream_gnutls*)base;
|
||||||
struct stream_req* req = TAILQ_FIRST(&self->send_queue);
|
while (!TAILQ_EMPTY(&self->base.send_queue)) {
|
||||||
|
struct stream_req* req = TAILQ_FIRST(&self->base.send_queue);
|
||||||
|
|
||||||
ssize_t rc = gnutls_record_send(
|
ssize_t rc = gnutls_record_send(self->session,
|
||||||
self->tls_session, req->payload->payload,
|
req->payload->payload, req->payload->size);
|
||||||
req->payload->size);
|
|
||||||
if (rc < 0) {
|
if (rc < 0) {
|
||||||
if (gnutls_error_is_fatal(rc)) {
|
if (gnutls_error_is_fatal(rc)) {
|
||||||
stream_close(self);
|
stream_close(base);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
stream__poll_rw(self);
|
stream__poll_rw(base);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
self->bytes_sent += rc;
|
self->base.bytes_sent += rc;
|
||||||
|
|
||||||
ssize_t remaining = req->payload->size - rc;
|
ssize_t remaining = req->payload->size - rc;
|
||||||
|
|
||||||
|
@ -94,18 +102,18 @@ static int stream_gnutls__flush(struct stream* self)
|
||||||
size_t s = req->payload->size;
|
size_t s = req->payload->size;
|
||||||
memmove(p, p + s - remaining, remaining);
|
memmove(p, p + s - remaining, remaining);
|
||||||
req->payload->size = remaining;
|
req->payload->size = remaining;
|
||||||
stream__poll_rw(self);
|
stream__poll_rw(base);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(remaining == 0);
|
assert(remaining == 0);
|
||||||
|
|
||||||
TAILQ_REMOVE(&self->send_queue, req, link);
|
TAILQ_REMOVE(&self->base.send_queue, req, link);
|
||||||
stream_req__finish(req, STREAM_REQ_DONE);
|
stream_req__finish(req, STREAM_REQ_DONE);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED)
|
if (TAILQ_EMPTY(&base->send_queue) && base->state != STREAM_STATE_CLOSED)
|
||||||
stream__poll_r(self);
|
stream__poll_r(base);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -174,15 +182,17 @@ static int stream_gnutls_send(struct stream* self, struct rcbuf* payload,
|
||||||
return stream_gnutls__flush(self);
|
return stream_gnutls__flush(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size)
|
static ssize_t stream_gnutls_read(struct stream* base, void* dst, size_t size)
|
||||||
{
|
{
|
||||||
ssize_t rc = gnutls_record_recv(self->tls_session, dst, size);
|
struct stream_gnutls* self = (struct stream_gnutls*)base;
|
||||||
|
|
||||||
|
ssize_t rc = gnutls_record_recv(self->session, dst, size);
|
||||||
if (rc == 0) {
|
if (rc == 0) {
|
||||||
stream__remote_closed(self);
|
stream__remote_closed(base);
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
if (rc > 0) {
|
if (rc > 0) {
|
||||||
self->bytes_received += rc;
|
self->base.bytes_received += rc;
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,33 +209,34 @@ static ssize_t stream_gnutls_read(struct stream* self, void* dst, size_t size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure data wasn't being written.
|
// Make sure data wasn't being written.
|
||||||
assert(gnutls_record_get_direction(self->tls_session) == 0);
|
assert(gnutls_record_get_direction(self->session) == 0);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int stream__try_tls_accept(struct stream* self)
|
static int stream__try_tls_accept(struct stream* base)
|
||||||
{
|
{
|
||||||
|
struct stream_gnutls* self = (struct stream_gnutls*)base;
|
||||||
int rc;
|
int rc;
|
||||||
|
|
||||||
rc = gnutls_handshake(self->tls_session);
|
rc = gnutls_handshake(self->session);
|
||||||
if (rc == GNUTLS_E_SUCCESS) {
|
if (rc == GNUTLS_E_SUCCESS) {
|
||||||
self->state = STREAM_STATE_TLS_READY;
|
self->base.state = STREAM_STATE_TLS_READY;
|
||||||
stream__poll_r(self);
|
stream__poll_r(base);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gnutls_error_is_fatal(rc)) {
|
if (gnutls_error_is_fatal(rc)) {
|
||||||
aml_stop(aml_get_default(), self->handler);
|
aml_stop(aml_get_default(), self->base.handler);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int was_writing = gnutls_record_get_direction(self->tls_session);
|
int was_writing = gnutls_record_get_direction(self->session);
|
||||||
if (was_writing)
|
if (was_writing)
|
||||||
stream__poll_w(self);
|
stream__poll_w(base);
|
||||||
else
|
else
|
||||||
stream__poll_r(self);
|
stream__poll_r(base);
|
||||||
|
|
||||||
self->state = STREAM_STATE_TLS_HANDSHAKE;
|
self->base.state = STREAM_STATE_TLS_HANDSHAKE;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,40 +247,41 @@ static struct stream_impl impl = {
|
||||||
.send = stream_gnutls_send,
|
.send = stream_gnutls_send,
|
||||||
};
|
};
|
||||||
|
|
||||||
int stream_upgrade_to_tls(struct stream* self, void* context)
|
int stream_upgrade_to_tls(struct stream* base, void* context)
|
||||||
{
|
{
|
||||||
|
struct stream_gnutls* self = (struct stream_gnutls*)base;
|
||||||
int rc;
|
int rc;
|
||||||
|
|
||||||
rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
|
rc = gnutls_init(&self->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
|
||||||
if (rc != GNUTLS_E_SUCCESS)
|
if (rc != GNUTLS_E_SUCCESS)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
rc = gnutls_set_default_priority(self->tls_session);
|
rc = gnutls_set_default_priority(self->session);
|
||||||
if (rc != GNUTLS_E_SUCCESS)
|
if (rc != GNUTLS_E_SUCCESS)
|
||||||
goto failure;
|
goto failure;
|
||||||
|
|
||||||
rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE,
|
rc = gnutls_credentials_set(self->session, GNUTLS_CRD_CERTIFICATE,
|
||||||
context);
|
context);
|
||||||
if (rc != GNUTLS_E_SUCCESS)
|
if (rc != GNUTLS_E_SUCCESS)
|
||||||
goto failure;
|
goto failure;
|
||||||
|
|
||||||
aml_stop(aml_get_default(), self->handler);
|
aml_stop(aml_get_default(), self->base.handler);
|
||||||
aml_unref(self->handler);
|
aml_unref(self->base.handler);
|
||||||
|
|
||||||
self->handler = aml_handler_new(self->fd, stream_gnutls__on_event, self,
|
self->base.handler = aml_handler_new(self->base.fd,
|
||||||
NULL);
|
stream_gnutls__on_event, self, NULL);
|
||||||
assert(self->handler);
|
assert(self->base.handler);
|
||||||
|
|
||||||
rc = aml_start(aml_get_default(), self->handler);
|
rc = aml_start(aml_get_default(), self->base.handler);
|
||||||
assert(rc >= 0);
|
assert(rc >= 0);
|
||||||
|
|
||||||
gnutls_transport_set_int(self->tls_session, self->fd);
|
gnutls_transport_set_int(self->session, self->base.fd);
|
||||||
|
|
||||||
self->impl = &impl;
|
self->base.impl = &impl;
|
||||||
|
|
||||||
return stream__try_tls_accept(self);
|
return stream__try_tls_accept(base);
|
||||||
|
|
||||||
failure:
|
failure:
|
||||||
gnutls_deinit(self->tls_session);
|
gnutls_deinit(self->session);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue