ws-handshake: Handle protocol & version fields

pull/68/head
Andri Yngvason 2023-04-30 13:44:12 +00:00
parent 58df7dfc5c
commit e5e6767c1e
1 changed files with 36 additions and 2 deletions

View File

@ -8,9 +8,18 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <stdbool.h> #include <stdbool.h>
#include <ctype.h>
static const char magic_uuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static const char magic_uuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
static void tolower_and_remove_ws(char* dst, const char* src)
{
while (*src)
if (!isspace(*src))
*dst++ = tolower(*src++);
*dst = '\0';
}
// TODO: Do some more sanity checks on the input // TODO: Do some more sanity checks on the input
ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input) ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input)
{ {
@ -19,16 +28,39 @@ ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input)
if (http_req_parse(&req, input) < 0) if (http_req_parse(&req, input) < 0)
return -1; return -1;
char protocols[256] = ",";
char versions[256] = ",";
char tmpstring[256];
const char *challenge = NULL; const char *challenge = NULL;
for (size_t i = 0; i < req.field_index; ++i) { for (size_t i = 0; i < req.field_index; ++i) {
if (strcasecmp(req.field[i].key, "Sec-WebSocket-Key") == 0) { if (strcasecmp(req.field[i].key, "Sec-WebSocket-Key") == 0) {
challenge = req.field[i].value; challenge = req.field[i].value;
} }
if (strcasecmp(req.field[i].key, "Sec-WebSocket-Protocol") == 0) {
snprintf(tmpstring, sizeof(tmpstring), "%s%s,",
protocols, req.field[i].value);
tolower_and_remove_ws(protocols, tmpstring);
}
if (strcasecmp(req.field[i].key, "Sec-WebSocket-Version") == 0) {
snprintf(tmpstring, sizeof(tmpstring), "%s%s,",
versions, req.field[i].value);
tolower_and_remove_ws(versions, tmpstring);
}
} }
if (!challenge) if (!challenge)
goto failure; goto failure;
bool have_protocols = strlen(protocols) != 1;
bool have_versions = strlen(versions) != 1;
if (have_protocols && !strstr(protocols, ",chat,"))
goto failure;
if (have_versions && !strstr(versions, ",13,"))
goto failure;
struct sha1_ctx ctx; struct sha1_ctx ctx;
sha1_init(&ctx); sha1_init(&ctx);
sha1_update(&ctx, strlen(challenge), (const uint8_t*)challenge); sha1_update(&ctx, strlen(challenge), (const uint8_t*)challenge);
@ -45,9 +77,11 @@ ssize_t ws_handshake(char* output, size_t output_maxlen, const char* input)
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
"Connection: Upgrade\r\n" "Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n" "Sec-WebSocket-Accept: %s\r\n"
"Sec-WebSocket-Protocol: chat\r\n" "%s%s"
"\r\n", "\r\n",
response); response,
have_protocols ? "Sec-WebSocket-Protocol: char\r\n" : "",
have_versions ? "Sec-WebSocket-Version: 13\r\n" : "");
ssize_t header_len = req.header_length; ssize_t header_len = req.header_length;
ok = len < output_maxlen; ok = len < output_maxlen;