diff --git a/src/ws-handshake.c b/src/ws-handshake.c index 3f43aa3..700b042 100644 --- a/src/ws-handshake.c +++ b/src/ws-handshake.c @@ -8,9 +8,18 @@ #include #include #include +#include static const char magic_uuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +static void tolower_and_remove_ws(char* dst, const char* src) +{ + while (*src) + if (!isspace(*src)) + *dst++ = tolower(*src++); + *dst = '\0'; +} + // TODO: Do some more sanity checks on the input ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input) { @@ -19,16 +28,39 @@ ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input) if (http_req_parse(&req, input) < 0) return -1; + char protocols[256] = ","; + char versions[256] = ","; + char tmpstring[256]; + const char *challenge = NULL; for (size_t i = 0; i < req.field_index; ++i) { if (strcasecmp(req.field[i].key, "Sec-WebSocket-Key") == 0) { challenge = req.field[i].value; } + if (strcasecmp(req.field[i].key, "Sec-WebSocket-Protocol") == 0) { + snprintf(tmpstring, sizeof(tmpstring), "%s%s,", + protocols, req.field[i].value); + tolower_and_remove_ws(protocols, tmpstring); + } + if (strcasecmp(req.field[i].key, "Sec-WebSocket-Version") == 0) { + snprintf(tmpstring, sizeof(tmpstring), "%s%s,", + versions, req.field[i].value); + tolower_and_remove_ws(versions, tmpstring); + } } if (!challenge) goto failure; + bool have_protocols = strlen(protocols) != 1; + bool have_versions = strlen(versions) != 1; + + if (have_protocols && !strstr(protocols, ",chat,")) + goto failure; + + if (have_versions && !strstr(versions, ",13,")) + goto failure; + struct sha1_ctx ctx; sha1_init(&ctx); sha1_update(&ctx, strlen(challenge), (const uint8_t*)challenge); @@ -45,9 +77,11 @@ ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input) "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" - "Sec-WebSocket-Protocol: chat\r\n" + "%s%s" "\r\n", - response); + response, + have_protocols ? "Sec-WebSocket-Protocol: char\r\n" : "", + have_versions ? "Sec-WebSocket-Version: 13\r\n" : ""); ssize_t header_len = req.header_length; ok = len < output_maxlen;