diff --git a/include/common.h b/include/common.h index 6c72f46..c7ed237 100644 --- a/include/common.h +++ b/include/common.h @@ -34,6 +34,7 @@ #define MAX_ENCODINGS 32 #define MAX_OUTGOING_FRAMES 4 #define MSG_BUFFER_SIZE 4096 +#define MAX_CUT_TEXT_SIZE 10000000 enum nvnc_client_state { VNC_CLIENT_STATE_ERROR = -1, @@ -58,6 +59,12 @@ struct nvnc_common { void* userdata; }; +struct cut_text { + char* buffer; + size_t length; + size_t index; +}; + struct nvnc_client { struct nvnc_common common; int ref; @@ -80,6 +87,7 @@ struct nvnc_client { uint8_t msg_buffer[MSG_BUFFER_SIZE]; uint32_t known_width; uint32_t known_height; + struct cut_text cut_text; }; LIST_HEAD(nvnc_client_list, nvnc_client); diff --git a/src/server.c b/src/server.c index d3f0ab9..bd5bf44 100644 --- a/src/server.c +++ b/src/server.c @@ -92,6 +92,7 @@ static void client_close(struct nvnc_client* client) tight_encoder_destroy(&client->tight_encoder); deflateEnd(&client->z_stream); pixman_region_fini(&client->damage); + free(client->cut_text.buffer); free(client); } @@ -602,31 +603,95 @@ void nvnc_send_cut_text(struct nvnc* server, const char* text, uint32_t len) static int on_client_cut_text(struct nvnc_client* client) { struct nvnc* server = client->server; + nvnc_cut_text_fn fn = server->cut_text_fn; struct rfb_cut_text_msg* msg = (struct rfb_cut_text_msg*)(client->msg_buffer + client->buffer_index); - if (client->buffer_len - client->buffer_index < sizeof(*msg)) + size_t left_to_process = client->buffer_len - client->buffer_index; + + if (left_to_process < sizeof(*msg)) return 0; uint32_t length = ntohl(msg->length); + uint32_t max_length = MAX_CUT_TEXT_SIZE; /* Messages greater than this size are unsupported */ - if (length > MSG_BUFFER_SIZE - sizeof(*msg)) { + if (length > max_length) { + log_error("Copied text length (%d) is greater than max supported length (%d)\n", + length, max_length); stream_close(client->net_stream); client_unref(client); return 0; } - if (client->buffer_len - client->buffer_index < sizeof(*msg) + length) - return 0; + size_t msg_size = sizeof(*msg) + length; - nvnc_cut_text_fn fn = server->cut_text_fn; - if (fn) { - fn(server, msg->text, length); + if (msg_size <= left_to_process) { + if (fn) + fn(server, msg->text, length); + + return msg_size; } - return sizeof(*msg) + length; + assert(!client->cut_text.buffer); + + client->cut_text.buffer = malloc(length); + if (!client->cut_text.buffer) { + log_error("OOM: %m\n"); + stream_close(client->net_stream); + client_unref(client); + return 0; + } + + size_t partial_size = left_to_process - sizeof(*msg); + + memcpy(client->cut_text.buffer, msg->text, partial_size); + + client->cut_text.length = length; + client->cut_text.index = partial_size; + + return left_to_process; +} + +static void process_big_cut_text(struct nvnc_client* client) +{ + struct nvnc* server = client->server; + nvnc_cut_text_fn fn = server->cut_text_fn; + + assert(client->cut_text.length > client->cut_text.index); + + void* start = client->cut_text.buffer + client->cut_text.index; + size_t space = client->cut_text.length - client->cut_text.index; + + space = MIN(space, MSG_BUFFER_SIZE); + + ssize_t n_read = stream_read(client->net_stream, start, space); + + if (n_read == 0) + return; + + if (n_read < 0) { + if (errno != EAGAIN) { + log_debug("Client connection error: %p (ref %d)\n", + client, client->ref); + stream_close(client->net_stream); + client_unref(client); + } + + return; + } + + client->cut_text.index += n_read; + + if (client->cut_text.index != client->cut_text.length) + return; + + if (fn) + fn(server, client->cut_text.buffer, client->cut_text.length); + + free(client->cut_text.buffer); + client->cut_text.buffer = NULL; } static int on_client_message(struct nvnc_client* client) @@ -699,6 +764,11 @@ static void on_client_event(struct stream* stream, enum stream_event event) return; } + if (client->cut_text.buffer) { + process_big_cut_text(client); + return; + } + assert(client->buffer_index == 0); void* start = client->msg_buffer + client->buffer_len;