diff --git a/src/internal.c b/src/internal.c index 74493adea..a1194eb20 100644 --- a/src/internal.c +++ b/src/internal.c @@ -6534,6 +6534,20 @@ static int DoServiceRequest(WOLFSSH* ssh, ret = GetString(name, &nameSz, buf, len, idx); + /* Requested service must be 'ssh-userauth' */ + if (ret == WS_SUCCESS) { + const char* nameUserAuth = IdToName(ID_SERVICE_USERAUTH); + if (nameUserAuth == NULL + || nameSz != (word32)XSTRLEN(nameUserAuth) + || XMEMCMP(name, nameUserAuth, nameSz) != 0) { + WLOG(WS_LOG_DEBUG, "Requested unsupported service: %s", name); + /* Terminate session, ignore result of disconnect attempt */ + (void)SendDisconnect(ssh, + WOLFSSH_DISCONNECT_SERVICE_NOT_AVAILABLE); + ret = WS_INVALID_STATE_E; + } + } + if (ret == WS_SUCCESS) { WLOG(WS_LOG_DEBUG, "Requesting service: %s", name); ssh->clientState = CLIENT_USERAUTH_REQUEST_DONE; @@ -6552,6 +6566,20 @@ static int DoServiceAccept(WOLFSSH* ssh, ret = GetString(name, &nameSz, buf, len, idx); + /* Accepted service must be 'ssh-userauth' */ + if (ret == WS_SUCCESS) { + const char* nameUserAuth = IdToName(ID_SERVICE_USERAUTH); + if (nameUserAuth == NULL + || nameSz != (word32)XSTRLEN(nameUserAuth) + || XMEMCMP(name, nameUserAuth, nameSz) != 0) { + WLOG(WS_LOG_DEBUG, "Accepted unexpected service: %s", name); + /* Terminate session, ignore result of disconnect attempt */ + (void)SendDisconnect(ssh, + WOLFSSH_DISCONNECT_SERVICE_NOT_AVAILABLE); + ret = WS_INVALID_STATE_E; + } + } + if (ret == WS_SUCCESS) { WLOG(WS_LOG_DEBUG, "Accepted service: %s", name); ssh->serverState = SERVER_USERAUTH_REQUEST_DONE; diff --git a/tests/api.c b/tests/api.c index bc65e2beb..a37cdcf1e 100644 --- a/tests/api.c +++ b/tests/api.c @@ -1962,6 +1962,15 @@ static void test_wolfSSH_KeyboardInteractive(void) { ; } #endif /* WOLFSSH_TEST_BLOCK */ +static void test_wolfSSH_ServiceRequestValidation(void) +{ + int nameSz = WOLFSSH_MAX_NAMESZ; + char serviceName[nameSz]; /* VLA: GCC/Clang fine, MSVC errors */ + WMEMSET(serviceName, 0, nameSz); + AssertIntEQ(nameSz, WOLFSSH_MAX_NAMESZ); +} + + int wolfSSH_ApiTest(int argc, char** argv) { (void)argc; @@ -2004,6 +2013,9 @@ int wolfSSH_ApiTest(int argc, char** argv) test_wolfSSH_KeyboardInteractive(); #endif + /* Service request validation */ + test_wolfSSH_ServiceRequestValidation(); + /* SCP tests */ test_wolfSSH_SCP_CB();