diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f931016..43ea2ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [ main, master ] + branches: [ main, master, 'feature/**' ] tags: ['v*'] pull_request: workflow_dispatch: @@ -49,13 +49,62 @@ jobs: --project-option "platform=https://github.com/pioarduino/platform-espressif32.git" \ --project-option "build_unflags=-std=gnu++11" \ --project-option "build_flags=-std=gnu++17" \ - --project-option "lib_deps=ArduinoJson@>=7.0.0, https://github.com/ESPToolKit/esp-worker.git" + --project-option "lib_deps=ArduinoJson@>=7.0.0" fi done - arduino-cli: + espidf-smoke: runs-on: ubuntu-latest needs: build-examples + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Cache PlatformIO + uses: actions/cache@v4 + with: + path: ~/.platformio + key: ${{ runner.os }}-platformio-${{ hashFiles('**/library.json') }} + restore-keys: | + ${{ runner.os }}-platformio- + + - name: Install PIOArduino Core + run: python -m pip install --upgrade https://github.com/pioarduino/platformio-core/archive/refs/tags/v6.1.18.zip + + - name: Install PIOArduino ESP32 Platform + run: pio platform install https://github.com/pioarduino/platform-espressif32.git + + - name: Build ESP-IDF smoke target + run: | + set -euo pipefail + tmpdir="$(mktemp -d)" + cat > "${tmpdir}/main.cpp" <<'EOF' + #include + + extern "C" void app_main(void) { + ESPWebPush webPush; + WebPushVapidConfig vapid{}; + vapid.subject = "mailto:test@example.com"; + } + EOF + + pio ci "${tmpdir}/main.cpp" \ + --board esp32dev \ + --lib="." \ + --project-option "platform=https://github.com/pioarduino/platform-espressif32.git" \ + --project-option "framework=espidf" \ + --project-option "build_unflags=-std=gnu++11" \ + --project-option "build_flags=-std=gnu++17" \ + --project-option "lib_deps=ArduinoJson@>=7.0.0" + + arduino-cli: + runs-on: ubuntu-latest + needs: [build-examples, espidf-smoke] env: ESP32_CORE_VERSION: 3.3.3 ESP32_PACKAGE_URL: https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_index.json @@ -104,7 +153,6 @@ jobs: run: | arduino-cli lib update-index arduino-cli lib install "ArduinoJson" - arduino-cli lib install --git-url "https://github.com/ESPToolKit/esp-worker.git" - name: Add local library to sketchbook run: | diff --git a/.gitignore b/.gitignore index b694934..5696bb7 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -.venv \ No newline at end of file +.pio/ +.venv diff --git a/.pio/build/project.checksum b/.pio/build/project.checksum deleted file mode 100644 index 69fe1aa..0000000 --- a/.pio/build/project.checksum +++ /dev/null @@ -1 +0,0 @@ -cee741bb01450f5a3833f712661d3a10cc362d89 \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 67009c9..b4f39e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,23 +4,35 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Changed +- Breaking: renamed the public transport struct from `Subscription` to `WebPushSubscription` everywhere with no compatibility alias. +- Breaking: renamed `PushMessage.sub` to `PushMessage.subscription` for API consistency. +- Breaking: removed app-level metadata fields `deviceId`, `disabledTags`, and `deleted` from the transport struct. +- `validateSubscription()` now validates only the required Web Push transport fields: `endpoint`, `p256dh`, and `auth`. + +## [2.0.0] - 2026-03-28 + ### Added -- Core ESPWebPush implementation: VAPID JWT signing, AES-GCM payload encryption, and HTTP delivery. -- Async queue + worker task with configurable stack, priority, queue length, and memory caps. -- Sync `send()` API returning structured `WebPushResult`. -- Retry/backoff handling for network/transport failures. -- Basic example sketch and CI workflows. -- Teardown lifecycle tests for pre-init `deinit()`, idempotent `deinit()`, re-init, and destructor teardown. -- Strict `PushPayload` API with typed notification fields and ArduinoJson v7+ overloads. -- User-provided network validator callback support. +- `WebPushVapidConfig` with standards-based `subject`, public key, and private key inputs. +- `WebPushEnqueueResult` for async preflight / queue outcomes. +- `WebPushJoinStatus` plus `requestStop()` / `join(timeoutMs)` for bounded worker shutdown. +- RFC 8291 Appendix A key-derivation and encrypted-body test coverage. +- Payload-size guard with the RFC-safe default limit of 3993 bytes. +- Small per-origin JWT cache for VAPID header reuse. ### Changed -- Teardown contract now uses `isInitialized()` and removes the old `initialized()` naming. -- `deinit()` now always converges teardown, including worker/queue/crypto cleanup and runtime config/key release. -- Structured payload inputs now reject unknown fields, missing required fields, and invalid types before enqueue/send. -- ArduinoJson v7+ is now an explicit dependency. +- Reworked encryption and transport to use RFC 8188 / RFC 8291 `aes128gcm` only. +- `init()` now validates `mailto:` / `https://` VAPID subjects and verifies that the configured public key matches the private key. +- Async `send()` overloads now return `WebPushEnqueueResult` and only invoke callbacks for queued work. +- `deinit()` now returns `WebPushJoinStatus` and uses a bounded stop/join flow instead of waiting forever. +- JWT payload assembly now uses dynamic `std::string` construction instead of a fixed stack buffer. +- Structured and raw payload sends now enforce the payload-size guard before transport. +- README, example sketch, package metadata, and CI now describe the v2 API and drop stale `esp-worker` references. +- CI push triggers now include `feature/**` branches so v2 work runs workflows before merge. +- `library.json` now advertises both Arduino and ESP-IDF compatibility. +- Package metadata now reports the breaking release as `2.0.0`. ### Notes -- JWT signing requires a valid system clock (SNTP). -- Content encoding uses `aesgcm` with VAPID headers (`Authorization`, `Crypto-Key`, `Encryption`). -- Worker configuration now uses `WebPushWorkerConfig` with native FreeRTOS task creation. +- JWT signing still requires a valid system clock (SNTP). +- Push sends use `Content-Encoding: aes128gcm` with VAPID `Authorization`. +- Breaking changes in v2 include the new `init()` signature, async enqueue return type, bounded shutdown API, and RFC 8291-only protocol behavior. diff --git a/README.md b/README.md index 7b85004..9feab04 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,24 @@ # ESPWebPush -ESPWebPush is an **async-first** Web Push sender for ESP32 firmware. It handles VAPID JWT signing, Web Push AES-GCM payload encryption, and HTTP delivery so your devices can notify browsers without extra glue code. +ESPWebPush is an async-first Web Push sender for ESP32 firmware. It handles VAPID JWT signing, RFC 8291 `aes128gcm` payload encryption, and HTTP delivery so devices can notify browsers without custom glue code. -ArduinoJson v7+ is a required dependency for the structured payload API. +ArduinoJson v7+ is required for the structured payload API. ## CI / Release / License [![CI](https://github.com/ESPToolKit/esp-webPush/actions/workflows/ci.yml/badge.svg)](https://github.com/ESPToolKit/esp-webPush/actions/workflows/ci.yml) [![Release](https://img.shields.io/github/v/release/ESPToolKit/esp-webPush?sort=semver)](https://github.com/ESPToolKit/esp-webPush/releases) -[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.md) +[![License: MIT](https://img.shields.io/github/license/ESPToolKit/esp-webPush)](LICENSE.md) ## Features -- VAPID JWT signing (ES256) from base64url private key. -- Web Push AES-GCM payload encryption. +- RFC 8292 VAPID JWT signing with `mailto:` or `https://` subjects. +- RFC 8291 / RFC 8188 `aes128gcm` Web Push encryption. - Async queue + worker task via native FreeRTOS APIs. -- Optional synchronous `send()` API. +- Bounded shutdown via `requestStop()`, `join(timeoutMs)`, and `deinit(timeoutMs)`. +- Sync `send()` API plus async `send()` overloads that return `WebPushEnqueueResult`. - Strict `PushPayload` validation for browser notification fields. -- ArduinoJson v7+ overloads for validated `JsonDocument` / `JsonVariantConst` payloads. -- Configurable queue length, memory caps (internal vs PSRAM), stack, priority, retries, and timeouts. -- Optional application-provided network validator callback. -- Uses the standard Web Push headers (`Authorization`, `Crypto-Key`, `Encryption`, `TTL`). +- Payload-size guard with the RFC-safe default limit of 3993 bytes. +- Small per-origin JWT cache to avoid re-signing every message. +- Configurable queue length, memory caps, retries, timeouts, and worker task settings. ## Quick Start @@ -31,19 +31,21 @@ ESPWebPush webPush; void setup() { Serial.begin(115200); + WebPushVapidConfig vapid; + vapid.subject = "mailto:notify@example.com"; + vapid.publicKeyBase64 = "BAvapidPublicKeyBase64Url..."; + vapid.privateKeyBase64 = "vapidPrivateKeyBase64Url..."; + WebPushConfig cfg; cfg.queueLength = 16; cfg.queueMemory = WebPushQueueMemory::Psram; cfg.worker.stackSizeBytes = 16 * 1024; cfg.worker.priority = 3; cfg.worker.name = "webpush"; + cfg.maxPayloadBytes = 3993; cfg.networkValidator = []() { return true; }; - webPush.init( - "notify@example.com", - "BAvapidPublicKeyBase64Url...", - "vapidPrivateKeyBase64Url...", - cfg); + webPush.init(vapid, cfg); } void loop() {} @@ -51,13 +53,13 @@ void loop() {} ## Usage -### Subscription / Structured Payload +### WebPushSubscription / Structured Payload ```cpp -Subscription sub; -sub.endpoint = "https://fcm.googleapis.com/fcm/send/..."; -sub.p256dh = "BME..."; // base64url from browser subscription -sub.auth = "nsa..."; // base64url from browser subscription +WebPushSubscription subscription; +subscription.endpoint = "https://fcm.googleapis.com/fcm/send/..."; +subscription.p256dh = "BME..."; +subscription.auth = "nsa..."; PushPayload payload; payload.title = "Hello"; @@ -69,7 +71,7 @@ payload.icon = "https://example.com/icon.png"; ### Async Send ```cpp -bool started = webPush.send(sub, payload, [](WebPushResult result) { +WebPushEnqueueResult enqueue = webPush.send(subscription, payload, [](WebPushResult result) { if (!result.ok()) { ESP_LOGE("WEBPUSH", "Push failed: %s (status %d)", result.message, result.statusCode); @@ -78,11 +80,13 @@ bool started = webPush.send(sub, payload, [](WebPushResult result) { ESP_LOGI("WEBPUSH", "Push OK (status %d)", result.statusCode); }); -if (!started) { - ESP_LOGW("WEBPUSH", "Queue full or not initialized"); +if (!enqueue.queued()) { + ESP_LOGW("WEBPUSH", "Enqueue failed: %s", enqueue.message); } ``` +Async preflight failures are returned through `WebPushEnqueueResult`. The callback only runs for messages that were actually queued. + ### ArduinoJson v7+ Send ```cpp @@ -91,13 +95,13 @@ doc["title"] = "Hello"; doc["body"] = "ESP32"; doc["tag"] = "demo"; -WebPushResult result = webPush.send(sub, doc); +WebPushResult result = webPush.send(subscription, doc); ``` ### Sync Send ```cpp -WebPushResult result = webPush.send(sub, payload); +WebPushResult result = webPush.send(subscription, payload); if (!result.ok()) { ESP_LOGW("WEBPUSH", "Sync push failed: %s", result.message); } @@ -107,7 +111,7 @@ if (!result.ok()) { ```cpp PushMessage msg; -msg.sub = sub; +msg.subscription = subscription; msg.payload = "{\"title\":\"Hello\",\"body\":\"ESP32\"}"; // Raw payload strings remain supported, but they are not schema-validated. @@ -118,60 +122,73 @@ WebPushResult result = webPush.send(msg); ```cpp if (webPush.isInitialized()) { - webPush.deinit(); + WebPushJoinStatus stopStatus = webPush.deinit(); + if (stopStatus == WebPushJoinStatus::Timeout) { + ESP_LOGW("WEBPUSH", "Worker did not stop within the timeout"); + } } ``` +`requestStop()` marks shutdown and wakes the worker without blocking. `join(timeoutMs)` waits for the worker to exit and finalizes shutdown when the stop completes in time. `deinit(timeoutMs)` is the convenience wrapper that performs both in one call. + ## Configuration `WebPushConfig` lets you tune the worker and queue: -- `queueLength` – number of queued messages. -- `queueMemory` – `Internal`, `Psram`, or `Any`. -- `worker` – stack size, priority, core id, PSRAM stack usage. -- `requestTimeoutMs` – HTTP timeout. -- `ttlSeconds` – Web Push TTL header. -- `maxRetries`, `retryBaseDelayMs`, `retryMaxDelayMs` – retry/backoff controls. -- `networkValidator` – optional callback for application-defined network readiness checks. +- `queueLength` - number of queued messages. +- `queueMemory` - `Internal`, `Psram`, or `Any`. +- `worker` - stack size, priority, core id, and task name. +- `requestTimeoutMs` - HTTP timeout. +- `ttlSeconds` - Web Push TTL header. +- `maxRetries`, `retryBaseDelayMs`, `retryMaxDelayMs` - retry/backoff controls. +- `maxPayloadBytes` - plaintext payload size guard. The default is 3993 bytes; use `0` to disable. +- `networkValidator` - optional callback for application-defined network readiness checks. ## Gotchas -- **System time is required** for VAPID JWT expiration. Ensure SNTP is synced. -- Web Push endpoints require TLS; `esp_http_client` must be built with TLS support. -- `aesgcm` content encoding is used to match existing Web Push payloads. -- Structured payload inputs reject unknown top-level keys and invalid field types. - -## API Reference (Core) - -- `bool init(contactEmail, publicKeyBase64, privateKeyBase64, config)` -- `bool send(const PushMessage&, WebPushResultCB cb)` (async) -- `WebPushResult send(const PushMessage&)` (sync) -- `bool send(const Subscription&, const PushPayload&, WebPushResultCB cb)` / `WebPushResult send(const Subscription&, const PushPayload&)` -- `bool send(const Subscription&, const JsonDocument&, WebPushResultCB cb)` / `WebPushResult send(const Subscription&, const JsonDocument&)` -- `bool send(const Subscription&, JsonVariantConst, WebPushResultCB cb)` / `WebPushResult send(const Subscription&, JsonVariantConst)` +- System time is required for VAPID JWT expiration. +- Web Push endpoints require TLS-capable `esp_http_client`. +- Only `aes128gcm` is generated. Legacy `aesgcm` is intentionally not supported in v2. +- `subject` must start with `mailto:` or `https://`. +- The configured VAPID public key must match the private key. + +## API Reference + +- `bool init(const WebPushVapidConfig&, const WebPushConfig& = {})` +- `WebPushEnqueueResult send(const PushMessage&, WebPushResultCB cb)` +- `WebPushResult send(const PushMessage&)` +- `WebPushEnqueueResult send(const WebPushSubscription&, const PushPayload&, WebPushResultCB cb)` +- `WebPushResult send(const WebPushSubscription&, const PushPayload&)` +- `WebPushEnqueueResult send(const WebPushSubscription&, const JsonDocument&, WebPushResultCB cb)` +- `WebPushResult send(const WebPushSubscription&, const JsonDocument&)` +- `WebPushEnqueueResult send(const WebPushSubscription&, JsonVariantConst, WebPushResultCB cb)` +- `WebPushResult send(const WebPushSubscription&, JsonVariantConst)` +- `void requestStop()` +- `WebPushJoinStatus join(uint32_t timeoutMs)` - `void setNetworkValidator(WebPushNetworkValidator)` -- `void deinit()` / `bool isInitialized() const` -- `const char* errorToString(WebPushError)` +- `WebPushJoinStatus deinit(uint32_t timeoutMs = 10000)` / `bool isInitialized() const` +- `const char *errorToString(WebPushError)` -## Restrictions -- ESP32-class targets only (Arduino + ESP-IDF). +## Compatibility +- ESP32-class targets only. +- Arduino and ESP-IDF frameworks are supported. - Requires C++17, ArduinoJson v7+, and mbedTLS. - Do not call from ISR context. ## Tests -Host-side tests are disabled. Use the `examples/` sketches with PlatformIO or Arduino CLI. +- On-device Unity tests live in `test/test_esp_webPush`. +- CI builds Arduino examples and includes an ESP-IDF compile smoke build. ## Formatting Baseline This repository follows the firmware formatting baseline from `esptoolkit-template`: - `.clang-format` is the source of truth for C/C++/INO layout. - `.editorconfig` enforces tabs (`tab_width = 4`), LF endings, and final newline. -- Format all tracked firmware sources with `bash scripts/format_cpp.sh`. +- Format tracked firmware sources with `bash scripts/format_cpp.sh`. ## License -MIT — see [LICENSE.md](LICENSE.md). +MIT - see [LICENSE.md](LICENSE.md). ## ESPToolKit - Check out other libraries: -- Hang out on Discord: - Support the project: - Visit the website: diff --git a/examples/basic_web_push/basic_web_push.ino b/examples/basic_web_push/basic_web_push.ino index 32bf93c..7c80fbb 100644 --- a/examples/basic_web_push/basic_web_push.ino +++ b/examples/basic_web_push/basic_web_push.ino @@ -15,22 +15,23 @@ void setup() { cfg.worker.stackSizeBytes = 16 * 1024; cfg.worker.priority = 3; cfg.worker.name = "webpush"; + cfg.maxPayloadBytes = 3993; cfg.networkValidator = []() { // Replace with your own Wi-Fi/Ethernet readiness check. return true; }; - webPush.init( - "notify@example.com", - "BAvapidPublicKeyBase64Url...", - "vapidPrivateKeyBase64Url...", - cfg - ); + WebPushVapidConfig vapid; + vapid.subject = "mailto:notify@example.com"; + vapid.publicKeyBase64 = "BAvapidPublicKeyBase64Url..."; + vapid.privateKeyBase64 = "vapidPrivateKeyBase64Url..."; - Subscription sub; - sub.endpoint = "https://fcm.googleapis.com/fcm/send/..."; - sub.p256dh = "BMEp256dhBase64Url..."; - sub.auth = "authSecretBase64Url..."; + webPush.init(vapid, cfg); + + WebPushSubscription subscription; + subscription.endpoint = "https://fcm.googleapis.com/fcm/send/..."; + subscription.p256dh = "BMEp256dhBase64Url..."; + subscription.auth = "authSecretBase64Url..."; PushPayload payload; payload.title = "Hello"; @@ -38,7 +39,7 @@ void setup() { payload.tag = "basic-demo"; payload.icon = "https://www.esptoolkit.hu/icon.png"; - webPush.send(sub, payload, [](WebPushResult result) { + WebPushEnqueueResult enqueue = webPush.send(subscription, payload, [](WebPushResult result) { if (!result.ok()) { Serial.printf( "[webpush] async failed: %s (status %d)\n", @@ -50,12 +51,19 @@ void setup() { Serial.printf("[webpush] async ok: %d\n", result.statusCode); }); + if (!enqueue.queued()) { + Serial.printf( + "[webpush] enqueue failed: %s\n", + enqueue.message ? enqueue.message : "unknown" + ); + } + JsonDocument jsonPayload; jsonPayload["title"] = "Hello"; jsonPayload["body"] = "ESP32"; jsonPayload["tag"] = "basic-demo"; - WebPushResult syncResult = webPush.send(sub, jsonPayload); + WebPushResult syncResult = webPush.send(subscription, jsonPayload); if (!syncResult.ok()) { Serial.printf( "[webpush] sync failed: %s\n", @@ -70,7 +78,8 @@ void setup() { void loop() { if (!tornDown && webPush.isInitialized() && teardownAtMs != 0 && millis() >= teardownAtMs) { - webPush.deinit(); + WebPushJoinStatus stopStatus = webPush.deinit(); + Serial.printf("[webpush] deinit status: %d\n", static_cast(stopStatus)); tornDown = true; } vTaskDelay(pdMS_TO_TICKS(1000)); diff --git a/library.json b/library.json index 32d4c5f..2dcaeda 100644 --- a/library.json +++ b/library.json @@ -1,6 +1,6 @@ { "name": "ESPWebPush", - "version": "1.0.1", + "version": "2.0.0", "description": "Web Push payload encryption and delivery for ESP32 firmware", "keywords": [ "esp32", @@ -8,7 +8,7 @@ "web-push", "push-notification", "vapid", - "aesgcm" + "aes128gcm" ], "homepage": "https://github.com/ESPToolKit/esp-webPush", "repository": { @@ -21,7 +21,10 @@ "name": "zekageri" } ], - "frameworks": ["arduino"], + "frameworks": [ + "arduino", + "espidf" + ], "platforms": ["espressif32"], "dependencies": [ { diff --git a/library.properties b/library.properties index 630ca4a..b1fcfda 100644 --- a/library.properties +++ b/library.properties @@ -1,9 +1,9 @@ name=ESPWebPush -version=1.0.1 +version=2.0.0 author=zekageri maintainer=zekageri -sentence=Web Push payload encryption and delivery for ESP32 firmware. -paragraph=Provides VAPID JWT signing, AES-GCM payload encryption, and async delivery using a dedicated FreeRTOS worker task. +sentence=RFC 8291 Web Push encryption and delivery for ESP32 firmware. +paragraph=Provides VAPID JWT signing, aes128gcm payload encryption, and async delivery using a dedicated FreeRTOS worker task. category=HTTP url=https://github.com/ESPToolKit/esp-webPush repository=https://github.com/ESPToolKit/esp-webPush.git diff --git a/src/esp_webPush/webPush.cpp b/src/esp_webPush/webPush.cpp index a88f48b..996a0f1 100644 --- a/src/esp_webPush/webPush.cpp +++ b/src/esp_webPush/webPush.cpp @@ -13,6 +13,7 @@ extern "C" { namespace { constexpr const char *kTag = "ESPWebPush"; constexpr TickType_t kWorkerPollTicks = pdMS_TO_TICKS(250); +constexpr TickType_t kStopDelaySliceTicks = pdMS_TO_TICKS(50); uint32_t capsForMemory(WebPushQueueMemory memory) { switch (memory) { @@ -31,16 +32,16 @@ ESPWebPush::~ESPWebPush() { deinit(); } -bool ESPWebPush::init( - const std::string &contactEmail, - const std::string &publicKeyBase64, - const std::string &privateKeyBase64, - const WebPushConfig &config -) { - deinit(); +bool ESPWebPush::init(const WebPushVapidConfig &vapidConfig, const WebPushConfig &config) { + WebPushJoinStatus stopStatus = deinit(); + if (stopStatus == WebPushJoinStatus::Timeout) { + ESP_LOGE(kTag, "init: previous worker did not stop in time"); + return false; + } - if (contactEmail.empty() || publicKeyBase64.empty() || privateKeyBase64.empty()) { - ESP_LOGE(kTag, "init: missing VAPID keys or contact email"); + if (vapidConfig.subject.empty() || vapidConfig.publicKeyBase64.empty() || + vapidConfig.privateKeyBase64.empty()) { + ESP_LOGE(kTag, "init: missing VAPID subject or keys"); return false; } @@ -50,21 +51,19 @@ bool ESPWebPush::init( } _config = config; - _vapidEmail = contactEmail; - _vapidPublicKey = publicKeyBase64; - _vapidPrivateKey = privateKeyBase64; + _vapidConfig = vapidConfig; setNetworkValidator(config.networkValidator); if (_config.worker.name.empty()) { _config.worker.name = "webpush"; } - WebPushResult keyCheck{}; - if (!validateVapidKeys(keyCheck)) { + WebPushResult configCheck{}; + if (!validateVapidSubject(configCheck) || !validateVapidKeys(configCheck)) { ESP_LOGE( kTag, - "init: invalid VAPID keys (%s)", - keyCheck.message ? keyCheck.message : "unknown" + "init: invalid VAPID config (%s)", + configCheck.message ? configCheck.message : "unknown" ); return false; } @@ -75,7 +74,9 @@ bool ESPWebPush::init( return false; } + _deinitRequested.store(false, std::memory_order_release); _stopRequested.store(false, std::memory_order_release); + TaskHandle_t workerTask = nullptr; const char *taskName = _config.worker.name.empty() ? "webpush" : _config.worker.name.c_str(); const BaseType_t created = xTaskCreatePinnedToCore( @@ -101,59 +102,111 @@ bool ESPWebPush::init( return true; } -void ESPWebPush::deinit() { - _initialized.store(false, std::memory_order_release); +void ESPWebPush::requestStop() { _stopRequested.store(true, std::memory_order_release); + _deinitRequested.store(true, std::memory_order_release); + _initialized.store(false, std::memory_order_release); TaskHandle_t workerTask = _workerTask.load(std::memory_order_acquire); - if (workerTask != nullptr) { + if (workerTask != nullptr && _queue != nullptr) { + QueueItem *wake = nullptr; + (void)xQueueSend(_queue, &wake, 0); + } +} + +WebPushJoinStatus ESPWebPush::join(uint32_t timeoutMs) { + TaskHandle_t workerTask = _workerTask.load(std::memory_order_acquire); + if (workerTask == nullptr) { if (_queue != nullptr) { - QueueItem *wake = nullptr; - (void)xQueueSend(_queue, &wake, 0); + (void)cleanupAfterWorkerStop(); + return WebPushJoinStatus::Completed; } - TickType_t start = xTaskGetTickCount(); - while (_workerTask.load(std::memory_order_acquire) != nullptr && - (xTaskGetTickCount() - start) <= pdMS_TO_TICKS(2000)) { - vTaskDelay(pdMS_TO_TICKS(10)); - } - workerTask = _workerTask.load(std::memory_order_acquire); - if (workerTask != nullptr) { - vTaskDelete(workerTask); - _workerTask.store(nullptr, std::memory_order_release); + return WebPushJoinStatus::NotRunning; + } + + if (xTaskGetCurrentTaskHandle() == workerTask) { + return WebPushJoinStatus::Timeout; + } + + const TickType_t start = xTaskGetTickCount(); + const TickType_t timeoutTicks = pdMS_TO_TICKS(timeoutMs); + while (_workerTask.load(std::memory_order_acquire) != nullptr) { + if ((xTaskGetTickCount() - start) >= timeoutTicks) { + return WebPushJoinStatus::Timeout; } + vTaskDelay(pdMS_TO_TICKS(10)); + } + + (void)cleanupAfterWorkerStop(); + return WebPushJoinStatus::Completed; +} + +WebPushJoinStatus ESPWebPush::deinit(uint32_t timeoutMs) { + requestStop(); + + WebPushJoinStatus status = join(timeoutMs); + if (status == WebPushJoinStatus::Timeout) { + return status; + } + + (void)cleanupAfterWorkerStop(); + return status; +} + +bool ESPWebPush::cleanupAfterWorkerStop() { + if (_workerTask.load(std::memory_order_acquire) != nullptr) { + return false; + } + + if (_deinitRequested.load(std::memory_order_acquire)) { + failPendingQueueItems(WebPushError::ShuttingDown); } if (_queue) { - QueueItem *item = nullptr; - while (xQueueReceive(_queue, &item, 0) == pdTRUE) { - if (item) { - freeItem(item); - } - } vQueueDelete(_queue); _queue = nullptr; } deinitCrypto(); - std::string().swap(_vapidPublicKey); - std::string().swap(_vapidPrivateKey); - std::string().swap(_vapidEmail); + { + std::lock_guard guard(_jwtCacheMutex); + _jwtCache = {}; + } + + _vapidConfig = WebPushVapidConfig{}; setNetworkValidator(WebPushNetworkValidator{}); _config = WebPushConfig{}; _stopRequested.store(false, std::memory_order_release); + _deinitRequested.store(false, std::memory_order_release); + return true; } -bool ESPWebPush::send(const PushMessage &msg, WebPushResultCB callback) { +WebPushEnqueueResult ESPWebPush::send(const PushMessage &msg, WebPushResultCB callback) { + if (_stopRequested.load(std::memory_order_acquire)) { + ESP_LOGW(kTag, "send: shutting down"); + return enqueueResultForError(WebPushError::ShuttingDown); + } + if (!isInitialized() || !_queue) { ESP_LOGW(kTag, "send: not initialized"); - return false; + return enqueueResultForError(WebPushError::NotInitialized); + } + + WebPushResult validation{}; + if (!validateMessage(msg, validation)) { + ESP_LOGW( + kTag, + "send: preflight validation failed (%s)", + validation.message ? validation.message : "unknown" + ); + return enqueueResultForError(validation.error); } QueueItem *item = allocateItem(); if (!item) { ESP_LOGW(kTag, "send: out of memory"); - return false; + return enqueueResultForError(WebPushError::OutOfMemory); } item->msg = msg; @@ -164,95 +217,90 @@ bool ESPWebPush::send(const PushMessage &msg, WebPushResultCB callback) { if (xQueueSend(_queue, &payload, waitTicks) != pdTRUE) { ESP_LOGW(kTag, "send: queue full"); freeItem(item); - return false; + if (_stopRequested.load(std::memory_order_acquire)) { + return enqueueResultForError(WebPushError::ShuttingDown); + } + return enqueueResultForError(WebPushError::QueueFull); } - return true; + return enqueueResultForError(WebPushError::None); } WebPushResult ESPWebPush::send(const PushMessage &msg) { + if (_stopRequested.load(std::memory_order_acquire)) { + return resultForError(WebPushError::ShuttingDown); + } if (!isInitialized()) { - WebPushResult result{}; - result.error = WebPushError::NotInitialized; - result.message = errorToString(result.error); - return result; + return resultForError(WebPushError::NotInitialized); } return handleMessage(msg); } -bool ESPWebPush::send( - const Subscription &sub, const PushPayload &payload, WebPushResultCB callback +WebPushEnqueueResult ESPWebPush::send( + const WebPushSubscription &subscription, const PushPayload &payload, WebPushResultCB callback ) { - if (!isInitialized() || !_queue) { - ESP_LOGW(kTag, "send: not initialized"); - return false; - } - PushMessage message; WebPushResult result{}; - if (!buildMessage(sub, payload, message, result)) { - if (callback) { - callback(result); - } - return false; + if (!buildMessage(subscription, payload, message, result)) { + return enqueueResultForError(result.error); } return send(message, std::move(callback)); } -WebPushResult ESPWebPush::send(const Subscription &sub, const PushPayload &payload) { +WebPushResult ESPWebPush::send( + const WebPushSubscription &subscription, const PushPayload &payload +) { + if (_stopRequested.load(std::memory_order_acquire)) { + return resultForError(WebPushError::ShuttingDown); + } if (!isInitialized()) { - WebPushResult result{}; - result.error = WebPushError::NotInitialized; - result.message = errorToString(result.error); - return result; + return resultForError(WebPushError::NotInitialized); } PushMessage message; WebPushResult result{}; - if (!buildMessage(sub, payload, message, result)) { + if (!buildMessage(subscription, payload, message, result)) { return result; } return send(message); } -bool ESPWebPush::send( - const Subscription &sub, const JsonDocument &payload, WebPushResultCB callback +WebPushEnqueueResult ESPWebPush::send( + const WebPushSubscription &subscription, const JsonDocument &payload, WebPushResultCB callback ) { - return send(sub, payload.as(), std::move(callback)); + return send(subscription, payload.as(), std::move(callback)); } -WebPushResult ESPWebPush::send(const Subscription &sub, const JsonDocument &payload) { - return send(sub, payload.as()); +WebPushResult ESPWebPush::send( + const WebPushSubscription &subscription, const JsonDocument &payload +) { + return send(subscription, payload.as()); } -bool ESPWebPush::send(const Subscription &sub, JsonVariantConst payload, WebPushResultCB callback) { - if (!isInitialized() || !_queue) { - ESP_LOGW(kTag, "send: not initialized"); - return false; - } - +WebPushEnqueueResult ESPWebPush::send( + const WebPushSubscription &subscription, JsonVariantConst payload, WebPushResultCB callback +) { PushMessage message; WebPushResult result{}; - if (!buildMessage(sub, payload, message, result)) { - if (callback) { - callback(result); - } - return false; + if (!buildMessage(subscription, payload, message, result)) { + return enqueueResultForError(result.error); } return send(message, std::move(callback)); } -WebPushResult ESPWebPush::send(const Subscription &sub, JsonVariantConst payload) { +WebPushResult ESPWebPush::send( + const WebPushSubscription &subscription, JsonVariantConst payload +) { + if (_stopRequested.load(std::memory_order_acquire)) { + return resultForError(WebPushError::ShuttingDown); + } if (!isInitialized()) { - WebPushResult result{}; - result.error = WebPushError::NotInitialized; - result.message = errorToString(result.error); - return result; + return resultForError(WebPushError::NotInitialized); } PushMessage message; WebPushResult result{}; - if (!buildMessage(sub, payload, message, result)) { + if (!buildMessage(subscription, payload, message, result)) { return result; } return send(message); @@ -264,6 +312,20 @@ void ESPWebPush::setNetworkValidator(WebPushNetworkValidator validator) { _config.networkValidator = _networkValidator; } +WebPushEnqueueResult ESPWebPush::enqueueResultForError(WebPushError error) const { + WebPushEnqueueResult result{}; + result.error = error; + result.message = errorToString(error); + return result; +} + +WebPushResult ESPWebPush::resultForError(WebPushError error) const { + WebPushResult result{}; + result.error = error; + result.message = errorToString(error); + return result; +} + const char *ESPWebPush::errorToString(WebPushError error) const { switch (error) { case WebPushError::None: @@ -282,6 +344,10 @@ const char *ESPWebPush::errorToString(WebPushError error) const { return "queue full"; case WebPushError::OutOfMemory: return "out of memory"; + case WebPushError::PayloadTooLarge: + return "payload too large"; + case WebPushError::ShuttingDown: + return "shutting down"; case WebPushError::CryptoInitFailed: return "crypto init failed"; case WebPushError::EncryptFailed: @@ -309,7 +375,10 @@ WebPushResult ESPWebPush::invalidPayloadResult() const { } bool ESPWebPush::buildMessage( - const Subscription &sub, const PushPayload &payload, PushMessage &message, WebPushResult &result + const WebPushSubscription &subscription, + const PushPayload &payload, + PushMessage &message, + WebPushResult &result ) const { std::string serializedPayload; const char *payloadError = serializePushPayload(payload, serializedPayload); @@ -319,13 +388,16 @@ bool ESPWebPush::buildMessage( return false; } - message.sub = sub; + message.subscription = subscription; message.payload = std::move(serializedPayload); return true; } bool ESPWebPush::buildMessage( - const Subscription &sub, JsonVariantConst payload, PushMessage &message, WebPushResult &result + const WebPushSubscription &subscription, + JsonVariantConst payload, + PushMessage &message, + WebPushResult &result ) const { std::string serializedPayload; const char *payloadError = validateAndSerializePushPayload(payload, serializedPayload); @@ -335,46 +407,50 @@ bool ESPWebPush::buildMessage( return false; } - message.sub = sub; + message.subscription = subscription; message.payload = std::move(serializedPayload); return true; } WebPushResult ESPWebPush::handleMessage(const PushMessage &msg) { WebPushResult result{}; - if (!validateSubscription(msg.sub, result)) { + if (!validateMessage(msg, result)) { return result; } for (uint8_t attempt = 0; attempt <= _config.maxRetries; ++attempt) { + if (_stopRequested.load(std::memory_order_acquire)) { + return resultForError(WebPushError::ShuttingDown); + } + if (!isNetworkReadyForPush()) { result.error = WebPushError::NetworkUnavailable; result.message = errorToString(result.error); if (attempt >= _config.maxRetries) { return result; } - vTaskDelay(pdMS_TO_TICKS(calcRetryDelayMs(attempt))); + if (!waitForStopAwareDelay(calcRetryDelayMs(attempt))) { + return resultForError(WebPushError::ShuttingDown); + } continue; } - std::string salt; - std::string serverKey; - std::vector ciphertext = encryptPayload(msg.payload, msg.sub, salt, serverKey); - if (ciphertext.empty()) { + std::vector body = encryptPayload(msg.payload, msg.subscription); + if (body.empty()) { result.error = WebPushError::EncryptFailed; result.message = errorToString(result.error); return result; } - std::string aud = endpointOrigin(msg.sub.endpoint); - std::string jwt = generateVapidJWT(aud, "mailto:" + _vapidEmail, _vapidPrivateKey); + const std::string aud = endpointOrigin(msg.subscription.endpoint); + const std::string jwt = jwtForAudience(aud); if (jwt.empty()) { result.error = WebPushError::JwtFailed; result.message = errorToString(result.error); return result; } - WebPushResult request = sendPushRequest(msg.sub.endpoint, jwt, salt, serverKey, ciphertext); + WebPushResult request = sendPushRequest(msg.subscription.endpoint, jwt, body); if (request.ok()) { return request; } @@ -384,7 +460,9 @@ WebPushResult ESPWebPush::handleMessage(const PushMessage &msg) { return result; } - vTaskDelay(pdMS_TO_TICKS(calcRetryDelayMs(attempt))); + if (!waitForStopAwareDelay(calcRetryDelayMs(attempt))) { + return resultForError(WebPushError::ShuttingDown); + } } result.error = WebPushError::InternalError; @@ -425,6 +503,24 @@ bool ESPWebPush::shouldRetry(const WebPushResult &result) const { return false; } +bool ESPWebPush::waitForStopAwareDelay(uint32_t delayMs) const { + if (delayMs == 0) { + return !_stopRequested.load(std::memory_order_acquire); + } + + TickType_t remaining = pdMS_TO_TICKS(delayMs); + while (remaining > 0) { + if (_stopRequested.load(std::memory_order_acquire)) { + return false; + } + TickType_t slice = remaining > kStopDelaySliceTicks ? kStopDelaySliceTicks : remaining; + vTaskDelay(slice); + remaining -= slice; + } + + return !_stopRequested.load(std::memory_order_acquire); +} + uint32_t ESPWebPush::calcRetryDelayMs(uint8_t attempt) const { if (_config.retryBaseDelayMs == 0) { return 0; @@ -436,8 +532,10 @@ uint32_t ESPWebPush::calcRetryDelayMs(uint8_t attempt) const { return delay; } -bool ESPWebPush::validateSubscription(const Subscription &sub, WebPushResult &result) const { - if (sub.deleted || sub.endpoint.empty() || sub.p256dh.empty() || sub.auth.empty()) { +bool ESPWebPush::validateSubscription( + const WebPushSubscription &subscription, WebPushResult &result +) const { + if (subscription.endpoint.empty() || subscription.p256dh.empty() || subscription.auth.empty()) { result.error = WebPushError::InvalidSubscription; result.message = errorToString(result.error); return false; @@ -445,24 +543,49 @@ bool ESPWebPush::validateSubscription(const Subscription &sub, WebPushResult &re return true; } -bool ESPWebPush::validateVapidKeys(WebPushResult &result) { - std::vector pubKey; - std::vector privKey; - if (!base64UrlDecode(_vapidPublicKey, pubKey) || !base64UrlDecode(_vapidPrivateKey, privKey)) { - result.error = WebPushError::InvalidVapidKeys; +bool ESPWebPush::validatePayloadSize(const std::string &payload, WebPushResult &result) const { + if (_config.maxPayloadBytes != 0 && payload.size() > _config.maxPayloadBytes) { + result.error = WebPushError::PayloadTooLarge; result.message = errorToString(result.error); return false; } - if (pubKey.size() != 65 || pubKey[0] != 0x04) { + return true; +} + +bool ESPWebPush::validateMessage(const PushMessage &msg, WebPushResult &result) const { + if (!validateSubscription(msg.subscription, result)) { + return false; + } + return validatePayloadSize(msg.payload, result); +} + +bool ESPWebPush::validateVapidSubject(WebPushResult &result) const { + if (_vapidConfig.subject.rfind("mailto:", 0) == 0 || + _vapidConfig.subject.rfind("https://", 0) == 0) { + return true; + } + result.error = WebPushError::InvalidConfig; + result.message = errorToString(result.error); + return false; +} + +bool ESPWebPush::validateVapidKeys(WebPushResult &result) { + std::vector pubKey; + std::vector privKey; + if (!decodeP256PublicKey(_vapidConfig.publicKeyBase64, pubKey) || + !decodeP256PrivateKey(_vapidConfig.privateKeyBase64, privKey)) { result.error = WebPushError::InvalidVapidKeys; result.message = errorToString(result.error); return false; } - if (privKey.size() != 32) { + + std::vector derivedPublicKey; + if (!deriveP256PublicKey(privKey, derivedPublicKey) || derivedPublicKey != pubKey) { result.error = WebPushError::InvalidVapidKeys; result.message = errorToString(result.error); return false; } + return true; } @@ -503,21 +626,48 @@ void ESPWebPush::freeItem(QueueItem *item) { heap_caps_free(item); } +void ESPWebPush::failPendingQueueItems(WebPushError error) { + if (!_queue) { + return; + } + + WebPushResult result = resultForError(error); + QueueItem *item = nullptr; + while (xQueueReceive(_queue, &item, 0) == pdTRUE) { + if (!item) { + continue; + } + if (item->callback) { + item->callback(result); + } + freeItem(item); + } +} + void ESPWebPush::workerLoop() { - while (!_stopRequested.load(std::memory_order_acquire)) { + while (true) { + if (_stopRequested.load(std::memory_order_acquire)) { + break; + } + QueueItem *item = nullptr; if (xQueueReceive(_queue, &item, kWorkerPollTicks) != pdTRUE) { continue; } if (!item) { + if (_stopRequested.load(std::memory_order_acquire)) { + break; + } continue; } + WebPushResult result = handleMessage(item->msg); if (item->callback) { item->callback(result); } freeItem(item); } + _workerTask.store(nullptr, std::memory_order_release); vTaskDelete(nullptr); } diff --git a/src/esp_webPush/webPush.h b/src/esp_webPush/webPush.h index 4d9076b..7198ba2 100644 --- a/src/esp_webPush/webPush.h +++ b/src/esp_webPush/webPush.h @@ -1,9 +1,15 @@ #pragma once +#if defined(ARDUINO) #include +#else +#include +#endif #include +#include #include +#include #include #include #include @@ -25,17 +31,20 @@ struct WebPushWorkerConfig { BaseType_t coreId = tskNO_AFFINITY; }; -struct Subscription { +struct WebPushVapidConfig { + std::string subject; + std::string publicKeyBase64; + std::string privateKeyBase64; +}; + +struct WebPushSubscription { std::string endpoint; std::string p256dh; std::string auth; - std::string deviceId; - std::vector disabledTags; - bool deleted = false; }; struct PushMessage { - Subscription sub; + WebPushSubscription subscription; std::string payload; }; @@ -64,6 +73,8 @@ struct PushPayload { enum class WebPushQueueMemory : uint8_t { Any = 0, Internal, Psram }; +enum class WebPushJoinStatus : uint8_t { Completed = 0, Timeout, NotRunning }; + enum class WebPushError : uint8_t { None = 0, NotInitialized, @@ -73,6 +84,8 @@ enum class WebPushError : uint8_t { InvalidVapidKeys, QueueFull, OutOfMemory, + PayloadTooLarge, + ShuttingDown, CryptoInitFailed, EncryptFailed, JwtFailed, @@ -98,6 +111,19 @@ struct WebPushResult { } }; +struct WebPushEnqueueResult { + WebPushError error = WebPushError::None; + const char *message = nullptr; + + bool queued() const { + return error == WebPushError::None; + } + + explicit operator bool() const { + return queued(); + } +}; + using WebPushResultCB = std::function; using WebPushNetworkValidator = std::function; @@ -111,6 +137,7 @@ struct WebPushConfig { uint8_t maxRetries = 5; uint32_t retryBaseDelayMs = 1500; uint32_t retryMaxDelayMs = 15000; + size_t maxPayloadBytes = 3993; WebPushNetworkValidator networkValidator; }; @@ -119,26 +146,35 @@ class ESPWebPush { ESPWebPush() = default; ~ESPWebPush(); - bool init( - const std::string &contactEmail, - const std::string &publicKeyBase64, - const std::string &privateKeyBase64, - const WebPushConfig &config = WebPushConfig{} - ); + bool init(const WebPushVapidConfig &vapidConfig, const WebPushConfig &config = WebPushConfig{}); - void deinit(); + void requestStop(); + WebPushJoinStatus join(uint32_t timeoutMs); + WebPushJoinStatus deinit(uint32_t timeoutMs = 10000); bool isInitialized() const { return _initialized.load(std::memory_order_acquire); } - bool send(const PushMessage &msg, WebPushResultCB callback); + WebPushEnqueueResult send(const PushMessage &msg, WebPushResultCB callback); WebPushResult send(const PushMessage &msg); - bool send(const Subscription &sub, const PushPayload &payload, WebPushResultCB callback); - WebPushResult send(const Subscription &sub, const PushPayload &payload); - bool send(const Subscription &sub, const JsonDocument &payload, WebPushResultCB callback); - WebPushResult send(const Subscription &sub, const JsonDocument &payload); - bool send(const Subscription &sub, JsonVariantConst payload, WebPushResultCB callback); - WebPushResult send(const Subscription &sub, JsonVariantConst payload); + WebPushEnqueueResult send( + const WebPushSubscription &subscription, + const PushPayload &payload, + WebPushResultCB callback + ); + WebPushResult send(const WebPushSubscription &subscription, const PushPayload &payload); + WebPushEnqueueResult send( + const WebPushSubscription &subscription, + const JsonDocument &payload, + WebPushResultCB callback + ); + WebPushResult send(const WebPushSubscription &subscription, const JsonDocument &payload); + WebPushEnqueueResult send( + const WebPushSubscription &subscription, + JsonVariantConst payload, + WebPushResultCB callback + ); + WebPushResult send(const WebPushSubscription &subscription, JsonVariantConst payload); void setNetworkValidator(WebPushNetworkValidator validator); @@ -150,25 +186,39 @@ class ESPWebPush { WebPushResultCB callback; }; + struct JwtCacheEntry { + std::string aud; + std::string token; + time_t exp = 0; + uint32_t lastUsedTick = 0; + }; + struct CryptoState; struct CryptoDeleter { void operator()(CryptoState *state) const; }; WebPushResult handleMessage(const PushMessage &msg); + WebPushEnqueueResult enqueueResultForError(WebPushError error) const; + WebPushResult resultForError(WebPushError error) const; + bool cleanupAfterWorkerStop(); bool shouldRetry(const WebPushResult &result) const; uint32_t calcRetryDelayMs(uint8_t attempt) const; + bool waitForStopAwareDelay(uint32_t delayMs) const; - bool validateSubscription(const Subscription &sub, WebPushResult &result) const; + bool validateSubscription(const WebPushSubscription &subscription, WebPushResult &result) const; + bool validatePayloadSize(const std::string &payload, WebPushResult &result) const; + bool validateMessage(const PushMessage &msg, WebPushResult &result) const; + bool validateVapidSubject(WebPushResult &result) const; bool validateVapidKeys(WebPushResult &result); bool buildMessage( - const Subscription &sub, + const WebPushSubscription &subscription, const PushPayload &payload, PushMessage &message, WebPushResult &result ) const; bool buildMessage( - const Subscription &sub, + const WebPushSubscription &subscription, JsonVariantConst payload, PushMessage &message, WebPushResult &result @@ -181,46 +231,62 @@ class ESPWebPush { static void workerLoopThunk(void *arg); void workerLoop(); + void failPendingQueueItems(WebPushError error); bool initCrypto(); void deinitCrypto(); std::string base64UrlEncode(const uint8_t *data, size_t len); std::string base64UrlEncode(const std::string &input); - bool base64UrlDecode(const std::string &input, std::vector &output); - - std::string generateVapidJWT( - const std::string &aud, const std::string &sub, const std::string &vapidPrivateKeyBase64 - ); - - std::vector encryptPayload( - const std::string &plaintext, - const Subscription &sub, - std::string &salt, - std::string &publicServerKey - ); + bool base64UrlDecode(const std::string &input, std::vector &output) const; - bool generateSalt(uint8_t *saltBin, std::string &saltOut); + std::string generateVapidJWT(const std::string &aud, time_t &expOut); + std::string jwtForAudience(const std::string &aud); - bool generateECDHContext( - const std::vector &userPubKey, - uint8_t *sharedSecret, - uint8_t *serverPubKey, - size_t &pubLen, - std::string &publicServerKey + bool decodeP256PublicKey(const std::string &keyBase64, std::vector &output) const; + bool decodeP256PrivateKey(const std::string &keyBase64, std::vector &output) const; + bool deriveP256PublicKey( + const std::vector &privateKey, std::vector &publicKeyOut + ) const; + bool deriveSharedSecret( + const std::vector &peerPublicKey, + const std::vector &privateKey, + uint8_t *sharedSecret ); - - bool deriveKeys( + bool deriveInputKeyingMaterial( const uint8_t *authSecret, size_t authSecretLen, - const uint8_t *salt, const uint8_t *sharedSecret, - uint8_t *cek, - uint8_t *nonce, const uint8_t *clientPubKey, size_t clientPubKeyLen, const uint8_t *serverPubKey, - size_t serverPubKeyLen + size_t serverPubKeyLen, + uint8_t *ikm + ) const; + bool deriveContentEncryptionKeyAndNonce( + const uint8_t *salt, const uint8_t *ikm, uint8_t *cek, uint8_t *nonce + ) const; + bool buildRecordBody( + const uint8_t *salt, + uint32_t recordSize, + const uint8_t *serverPubKey, + size_t serverPubKeyLen, + const std::string &plaintext, + const uint8_t *cek, + const uint8_t *nonce, + std::vector &bodyOut + ); + + std::vector encryptPayload( + const std::string &plaintext, + const WebPushSubscription &subscription + ); + + bool generateSalt(uint8_t *saltBin); + + bool generateECDHContext( + const std::vector &privateKey, + std::vector &publicKeyOut ); bool encryptWithAESGCM( @@ -232,28 +298,29 @@ class ESPWebPush { bool isNetworkReadyForPush() const; WebPushResult sendPushRequest( - const std::string &endpoint, - const std::string &jwt, - const std::string &salt, - const std::string &serverPublicKey, - const std::vector &ciphertext + const std::string &endpoint, const std::string &jwt, const std::vector &body ); void printHeaderErr(esp_err_t headErr, const char *headKey) const; std::string endpointOrigin(const std::string &endpoint) const; - std::string _vapidPublicKey{}; - std::string _vapidPrivateKey{}; - std::string _vapidEmail{}; + static constexpr uint32_t kDefaultRecordSize = 4010; + static constexpr uint32_t kDefaultDeinitTimeoutMs = 10000; + static constexpr size_t kJwtCacheSize = 4; + + WebPushVapidConfig _vapidConfig{}; WebPushConfig _config{}; std::atomic _workerTask{nullptr}; QueueHandle_t _queue = nullptr; std::atomic _initialized{false}; std::atomic _stopRequested{false}; + std::atomic _deinitRequested{false}; std::unique_ptr _crypto{}; std::mutex _cryptoMutex; + mutable std::mutex _jwtCacheMutex; mutable std::mutex _networkValidatorMutex; WebPushNetworkValidator _networkValidator{}; + std::array _jwtCache{}; }; diff --git a/src/esp_webPush/webPush_crypto.cpp b/src/esp_webPush/webPush_crypto.cpp index 3855136..5b7d741 100644 --- a/src/esp_webPush/webPush_crypto.cpp +++ b/src/esp_webPush/webPush_crypto.cpp @@ -8,12 +8,20 @@ extern "C" { #include "mbedtls/cipher.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/ecdh.h" +#include "mbedtls/ecp.h" #include "mbedtls/entropy.h" #include "mbedtls/md.h" } namespace { constexpr const char *kTag = "ESPWebPush"; + +void appendUint32(std::vector &buffer, uint32_t value) { + buffer.push_back(static_cast((value >> 24) & 0xFF)); + buffer.push_back(static_cast((value >> 16) & 0xFF)); + buffer.push_back(static_cast((value >> 8) & 0xFF)); + buffer.push_back(static_cast(value & 0xFF)); +} } // namespace struct ESPWebPush::CryptoState { @@ -48,6 +56,7 @@ bool ESPWebPush::initCrypto() { if (_crypto->initialized) { return true; } + const char *pers = "espwebpush_drbg"; int ret = mbedtls_ctr_drbg_seed( &(_crypto->ctrDrbg), @@ -60,6 +69,7 @@ bool ESPWebPush::initCrypto() { ESP_LOGE(kTag, "initCrypto: failed to seed DRBG: -0x%04x", -ret); return false; } + _crypto->initialized = true; return true; } @@ -69,7 +79,7 @@ void ESPWebPush::deinitCrypto() { _crypto.reset(); } -bool ESPWebPush::generateSalt(uint8_t *saltBin, std::string &saltOut) { +bool ESPWebPush::generateSalt(uint8_t *saltBin) { if (!_crypto || !_crypto->initialized) { ESP_LOGE(kTag, "generateSalt: crypto not initialized"); return false; @@ -78,198 +88,229 @@ bool ESPWebPush::generateSalt(uint8_t *saltBin, std::string &saltOut) { ESP_LOGE(kTag, "generateSalt: failed to generate salt"); return false; } - saltOut = base64UrlEncode(saltBin, 16); - return !saltOut.empty(); + return true; +} + +bool ESPWebPush::decodeP256PublicKey(const std::string &keyBase64, std::vector &output) + const { + if (!base64UrlDecode(keyBase64, output) || output.size() != 65 || output[0] != 0x04) { + output.clear(); + return false; + } + return true; +} + +bool ESPWebPush::decodeP256PrivateKey(const std::string &keyBase64, std::vector &output) + const { + if (!base64UrlDecode(keyBase64, output) || output.size() != 32) { + output.clear(); + return false; + } + return true; +} + +bool ESPWebPush::deriveP256PublicKey( + const std::vector &privateKey, std::vector &publicKeyOut +) const { + if (privateKey.size() != 32) { + return false; + } + + bool success = false; + mbedtls_ecp_group group; + mbedtls_mpi d; + mbedtls_ecp_point q; + + mbedtls_ecp_group_init(&group); + mbedtls_mpi_init(&d); + mbedtls_ecp_point_init(&q); + + do { + if (mbedtls_ecp_group_load(&group, MBEDTLS_ECP_DP_SECP256R1) != 0) { + break; + } + if (mbedtls_mpi_read_binary(&d, privateKey.data(), privateKey.size()) != 0) { + break; + } + if (mbedtls_ecp_check_privkey(&group, &d) != 0) { + break; + } + if (mbedtls_ecp_mul(&group, &q, &d, &group.G, nullptr, nullptr) != 0) { + break; + } + if (mbedtls_ecp_check_pubkey(&group, &q) != 0) { + break; + } + + publicKeyOut.assign(65, 0); + size_t actualLen = 0; + if (mbedtls_ecp_point_write_binary( + &group, + &q, + MBEDTLS_ECP_PF_UNCOMPRESSED, + &actualLen, + publicKeyOut.data(), + publicKeyOut.size() + ) != 0) { + publicKeyOut.clear(); + break; + } + if (actualLen != 65 || publicKeyOut[0] != 0x04) { + publicKeyOut.clear(); + break; + } + success = true; + } while (false); + + mbedtls_ecp_point_free(&q); + mbedtls_mpi_free(&d); + mbedtls_ecp_group_free(&group); + + return success; } bool ESPWebPush::generateECDHContext( - const std::vector &userPubKey, - uint8_t *sharedSecret, - uint8_t *serverPubKey, - size_t &pubLen, - std::string &publicServerKey + const std::vector &privateKey, std::vector &publicKeyOut ) { - if (!_crypto || !_crypto->initialized) { - ESP_LOGE(kTag, "generateECDHContext: crypto not initialized"); + if (privateKey.size() != 32) { + ESP_LOGE(kTag, "generateECDHContext: private key length invalid"); return false; } + return deriveP256PublicKey(privateKey, publicKeyOut); +} - bool success = false; - mbedtls_ecdh_context ecdh; - mbedtls_ecdh_init(&ecdh); +bool ESPWebPush::deriveSharedSecret( + const std::vector &peerPublicKey, + const std::vector &privateKey, + uint8_t *sharedSecret +) { + if (peerPublicKey.size() != 65 || peerPublicKey[0] != 0x04 || privateKey.size() != 32 || + sharedSecret == nullptr) { + return false; + } - mbedtls_ecp_group grp; + bool success = false; + mbedtls_ecp_group group; mbedtls_mpi d; mbedtls_mpi z; - mbedtls_ecp_point Q; - mbedtls_ecp_point Qp; + mbedtls_ecp_point q; - mbedtls_ecp_group_init(&grp); + mbedtls_ecp_group_init(&group); mbedtls_mpi_init(&d); mbedtls_mpi_init(&z); - mbedtls_ecp_point_init(&Q); - mbedtls_ecp_point_init(&Qp); + mbedtls_ecp_point_init(&q); do { - if (mbedtls_ecp_group_load(&grp, MBEDTLS_ECP_DP_SECP256R1) != 0) { - ESP_LOGE(kTag, "ECDH: failed to load curve"); + if (mbedtls_ecp_group_load(&group, MBEDTLS_ECP_DP_SECP256R1) != 0) { break; } - if (mbedtls_ecdh_setup(&ecdh, MBEDTLS_ECP_DP_SECP256R1) != 0) { - ESP_LOGE(kTag, "ECDH: failed to setup context"); + if (mbedtls_mpi_read_binary(&d, privateKey.data(), privateKey.size()) != 0) { break; } - if (mbedtls_ecdh_gen_public(&grp, &d, &Q, mbedtls_ctr_drbg_random, &(_crypto->ctrDrbg)) != - 0) { - ESP_LOGE(kTag, "ECDH: failed to generate public key"); + if (mbedtls_ecp_check_privkey(&group, &d) != 0) { break; } - if (mbedtls_ecp_point_read_binary(&grp, &Qp, userPubKey.data(), userPubKey.size()) != 0) { - ESP_LOGE(kTag, "ECDH: failed to read user public key"); + if (mbedtls_ecp_point_read_binary(&group, &q, peerPublicKey.data(), peerPublicKey.size()) != + 0) { break; } - if (mbedtls_ecp_check_pubkey(&grp, &Qp) != 0) { - ESP_LOGE(kTag, "ECDH: user public key invalid"); + if (mbedtls_ecp_check_pubkey(&group, &q) != 0) { break; } if (mbedtls_ecdh_compute_shared( - &grp, + &group, &z, - &Qp, + &q, &d, mbedtls_ctr_drbg_random, - &(_crypto->ctrDrbg) + _crypto ? &(_crypto->ctrDrbg) : nullptr ) != 0) { - ESP_LOGE(kTag, "ECDH: failed to compute shared secret"); - break; - } - size_t zLen = mbedtls_mpi_size(&z); - if (zLen == 0 || zLen > 32) { - ESP_LOGE(kTag, "ECDH: shared secret length invalid (%u)", static_cast(zLen)); break; } memset(sharedSecret, 0, 32); if (mbedtls_mpi_write_binary(&z, sharedSecret, 32) != 0) { - ESP_LOGE(kTag, "ECDH: failed to write shared secret"); break; } - if (mbedtls_ecp_point_write_binary( - &grp, - &Q, - MBEDTLS_ECP_PF_UNCOMPRESSED, - &pubLen, - serverPubKey, - 65 - ) != 0) { - ESP_LOGE(kTag, "ECDH: failed to write ephemeral public key"); - break; - } - if (pubLen != 65 || serverPubKey[0] != 0x04) { - ESP_LOGE(kTag, "ECDH: invalid ephemeral public key"); - break; - } - publicServerKey = base64UrlEncode(serverPubKey, 65); - success = !publicServerKey.empty(); + success = true; } while (false); - mbedtls_ecdh_free(&ecdh); - mbedtls_mpi_free(&d); + mbedtls_ecp_point_free(&q); mbedtls_mpi_free(&z); - mbedtls_ecp_point_free(&Q); - mbedtls_ecp_point_free(&Qp); - mbedtls_ecp_group_free(&grp); + mbedtls_mpi_free(&d); + mbedtls_ecp_group_free(&group); return success; } -bool ESPWebPush::deriveKeys( +bool ESPWebPush::deriveInputKeyingMaterial( const uint8_t *authSecret, size_t authSecretLen, - const uint8_t *salt, const uint8_t *sharedSecret, - uint8_t *cek, - uint8_t *nonce, const uint8_t *clientPubKey, size_t clientPubKeyLen, const uint8_t *serverPubKey, - size_t serverPubKeyLen -) { + size_t serverPubKeyLen, + uint8_t *ikm +) const { const mbedtls_md_info_t *md = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); - if (!md) { - ESP_LOGE(kTag, "deriveKeys: SHA-256 not supported"); + if (!md || !authSecret || !sharedSecret || !clientPubKey || !serverPubKey || !ikm) { return false; } - uint8_t prk[32]; - uint8_t ms[32]; - std::vector authInfo; - authInfo.reserve(23); - const char *authLabel = "Content-Encoding: auth"; - authInfo.insert(authInfo.end(), authLabel, authLabel + strlen(authLabel)); - authInfo.push_back(0x00); - authInfo.push_back(0x01); - - if (mbedtls_md_hmac(md, authSecret, authSecretLen, sharedSecret, 32, prk) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to compute PRK"); + uint8_t prkKey[32]; + std::vector keyInfo; + const char *label = "WebPush: info"; + keyInfo.reserve(strlen(label) + 1 + clientPubKeyLen + serverPubKeyLen + 1); + keyInfo.insert(keyInfo.end(), label, label + strlen(label)); + keyInfo.push_back(0x00); + keyInfo.insert(keyInfo.end(), clientPubKey, clientPubKey + clientPubKeyLen); + keyInfo.insert(keyInfo.end(), serverPubKey, serverPubKey + serverPubKeyLen); + keyInfo.push_back(0x01); + + if (mbedtls_md_hmac(md, authSecret, authSecretLen, sharedSecret, 32, prkKey) != 0) { return false; } - if (mbedtls_md_hmac(md, prk, 32, authInfo.data(), authInfo.size(), ms) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to compute MS"); + return mbedtls_md_hmac(md, prkKey, sizeof(prkKey), keyInfo.data(), keyInfo.size(), ikm) == 0; +} + +bool ESPWebPush::deriveContentEncryptionKeyAndNonce( + const uint8_t *salt, const uint8_t *ikm, uint8_t *cek, uint8_t *nonce +) const { + const mbedtls_md_info_t *md = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + if (!md || !salt || !ikm || !cek || !nonce) { return false; } - std::vector context; - context.reserve(141); - const char *contextLabel = "P-256"; - context.insert(context.end(), contextLabel, contextLabel + strlen(contextLabel)); - context.push_back(0x00); - uint16_t clientLen = static_cast(clientPubKeyLen); - context.push_back((clientLen >> 8) & 0xFF); - context.push_back(clientLen & 0xFF); - context.insert(context.end(), clientPubKey, clientPubKey + clientPubKeyLen); - uint16_t serverLen = static_cast(serverPubKeyLen); - context.push_back((serverLen >> 8) & 0xFF); - context.push_back(serverLen & 0xFF); - context.insert(context.end(), serverPubKey, serverPubKey + serverPubKeyLen); + uint8_t prk[32]; + uint8_t cekFull[32]; + uint8_t nonceFull[32]; + const char *cekLabel = "Content-Encoding: aes128gcm"; std::vector cekInfo; - cekInfo.reserve(165); - const char *cekLabel = "Content-Encoding: aesgcm"; + cekInfo.reserve(strlen(cekLabel) + 2); cekInfo.insert(cekInfo.end(), cekLabel, cekLabel + strlen(cekLabel)); cekInfo.push_back(0x00); - cekInfo.insert(cekInfo.end(), context.begin(), context.end()); cekInfo.push_back(0x01); - uint8_t prkCek[32]; - uint8_t cekFull[32]; - if (mbedtls_md_hmac(md, salt, 16, ms, 32, prkCek) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to compute PRK_CEK"); - return false; - } - if (mbedtls_md_hmac(md, prkCek, 32, cekInfo.data(), cekInfo.size(), cekFull) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to derive CEK"); - return false; - } - memcpy(cek, cekFull, 16); - std::vector nonceInfo; - nonceInfo.reserve(164); const char *nonceLabel = "Content-Encoding: nonce"; + std::vector nonceInfo; + nonceInfo.reserve(strlen(nonceLabel) + 2); nonceInfo.insert(nonceInfo.end(), nonceLabel, nonceLabel + strlen(nonceLabel)); nonceInfo.push_back(0x00); - nonceInfo.insert(nonceInfo.end(), context.begin(), context.end()); nonceInfo.push_back(0x01); - uint8_t prkNonce[32]; - uint8_t nonceFull[32]; - if (mbedtls_md_hmac(md, salt, 16, ms, 32, prkNonce) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to compute PRK_nonce"); + + if (mbedtls_md_hmac(md, salt, 16, ikm, 32, prk) != 0) { return false; } - if (mbedtls_md_hmac(md, prkNonce, 32, nonceInfo.data(), nonceInfo.size(), nonceFull) != 0) { - ESP_LOGE(kTag, "deriveKeys: failed to derive nonce"); + if (mbedtls_md_hmac(md, prk, sizeof(prk), cekInfo.data(), cekInfo.size(), cekFull) != 0) { + return false; + } + if (mbedtls_md_hmac(md, prk, sizeof(prk), nonceInfo.data(), nonceInfo.size(), nonceFull) != 0) { return false; } - memcpy(nonce, nonceFull, 12); + memcpy(cek, cekFull, 16); + memcpy(nonce, nonceFull, 12); return true; } @@ -292,27 +333,22 @@ bool ESPWebPush::encryptWithAESGCM( mbedtls_cipher_init(&cipher); if (mbedtls_cipher_setup(&cipher, mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_GCM)) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: cipher setup failed"); mbedtls_cipher_free(&cipher); return false; } if (mbedtls_cipher_setkey(&cipher, cek, 128, MBEDTLS_ENCRYPT) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: set key failed"); mbedtls_cipher_free(&cipher); return false; } if (mbedtls_cipher_set_iv(&cipher, nonce, 12) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: set IV failed"); mbedtls_cipher_free(&cipher); return false; } if (mbedtls_cipher_reset(&cipher) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: reset failed"); mbedtls_cipher_free(&cipher); return false; } if (mbedtls_cipher_update_ad(&cipher, nullptr, 0) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: update AD failed"); mbedtls_cipher_free(&cipher); return false; } @@ -325,111 +361,142 @@ bool ESPWebPush::encryptWithAESGCM( output.data(), &olen ) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: encrypt failed"); mbedtls_cipher_free(&cipher); return false; } outputLen = olen; if (mbedtls_cipher_finish(&cipher, output.data() + outputLen, &olen) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: finish failed"); mbedtls_cipher_free(&cipher); return false; } outputLen += olen; - if (mbedtls_cipher_write_tag(&cipher, tag, 16) != 0) { - ESP_LOGE(kTag, "encryptWithAESGCM: write tag failed"); + if (mbedtls_cipher_write_tag(&cipher, tag, sizeof(tag)) != 0) { mbedtls_cipher_free(&cipher); return false; } - output.resize(outputLen + 16); - memcpy(output.data() + outputLen, tag, 16); - + output.resize(outputLen + sizeof(tag)); + memcpy(output.data() + outputLen, tag, sizeof(tag)); ciphertextOut = std::move(output); mbedtls_cipher_free(&cipher); return true; } -std::vector ESPWebPush::encryptPayload( +bool ESPWebPush::buildRecordBody( + const uint8_t *salt, + uint32_t recordSize, + const uint8_t *serverPubKey, + size_t serverPubKeyLen, const std::string &plaintext, - const Subscription &sub, - std::string &salt, - std::string &publicServerKey + const uint8_t *cek, + const uint8_t *nonce, + std::vector &bodyOut +) { + if (!salt || !serverPubKey || !cek || !nonce || serverPubKeyLen > 0xFF) { + return false; + } + + std::string recordPlaintext = plaintext; + recordPlaintext.push_back(static_cast(0x02)); + + std::vector ciphertext; + if (!encryptWithAESGCM(recordPlaintext, cek, nonce, ciphertext)) { + return false; + } + + bodyOut.clear(); + bodyOut.reserve(16 + 4 + 1 + serverPubKeyLen + ciphertext.size()); + bodyOut.insert(bodyOut.end(), salt, salt + 16); + appendUint32(bodyOut, recordSize); + bodyOut.push_back(static_cast(serverPubKeyLen)); + bodyOut.insert(bodyOut.end(), serverPubKey, serverPubKey + serverPubKeyLen); + bodyOut.insert(bodyOut.end(), ciphertext.begin(), ciphertext.end()); + return true; +} + +std::vector ESPWebPush::encryptPayload( + const std::string &plaintext, const WebPushSubscription &subscription ) { std::lock_guard guard(_cryptoMutex); if (!initCrypto()) { return {}; } - uint8_t saltBin[16]; - uint8_t sharedSecret[32]; - uint8_t serverPubKey[65]; - size_t pubLen = 0; - uint8_t cek[16]; - uint8_t nonce[12]; - std::vector userPubKey; - if (!base64UrlDecode(sub.p256dh, userPubKey) || userPubKey.empty()) { - ESP_LOGE(kTag, "encryptPayload: failed to decode client public key"); - return {}; - } - if (userPubKey.size() != 65 || userPubKey[0] != 0x04) { + if (!decodeP256PublicKey(subscription.p256dh, userPubKey)) { ESP_LOGE(kTag, "encryptPayload: client public key invalid"); return {}; } - std::vector authSecretBin; - if (!base64UrlDecode(sub.auth, authSecretBin) || authSecretBin.empty()) { - ESP_LOGE(kTag, "encryptPayload: failed to decode auth secret"); + std::vector authSecret; + if (!base64UrlDecode(subscription.auth, authSecret) || authSecret.size() != 16) { + ESP_LOGE(kTag, "encryptPayload: auth secret invalid"); return {}; } - if (authSecretBin.size() != 16) { - ESP_LOGE( - kTag, - "encryptPayload: auth secret length invalid (%u)", - static_cast(authSecretBin.size()) - ); + + std::vector serverPrivateKey(32, 0); + if (mbedtls_ctr_drbg_random(&(_crypto->ctrDrbg), serverPrivateKey.data(), serverPrivateKey.size()) != + 0) { + ESP_LOGE(kTag, "encryptPayload: failed to generate ephemeral private key"); return {}; } - if (!generateSalt(saltBin, salt)) { - ESP_LOGE(kTag, "encryptPayload: failed to generate salt"); + std::vector serverPublicKey; + if (!generateECDHContext(serverPrivateKey, serverPublicKey)) { + ESP_LOGE(kTag, "encryptPayload: failed to derive ephemeral public key"); return {}; } - if (!generateECDHContext(userPubKey, sharedSecret, serverPubKey, pubLen, publicServerKey)) { - ESP_LOGE(kTag, "encryptPayload: failed to generate ECDH context"); + uint8_t sharedSecret[32]; + if (!deriveSharedSecret(userPubKey, serverPrivateKey, sharedSecret)) { + ESP_LOGE(kTag, "encryptPayload: failed to derive shared secret"); return {}; } - if (!deriveKeys( - authSecretBin.data(), - authSecretBin.size(), - saltBin, + uint8_t ikm[32]; + if (!deriveInputKeyingMaterial( + authSecret.data(), + authSecret.size(), sharedSecret, - cek, - nonce, userPubKey.data(), userPubKey.size(), - serverPubKey, - pubLen + serverPublicKey.data(), + serverPublicKey.size(), + ikm )) { - ESP_LOGE(kTag, "encryptPayload: failed to derive keys"); + ESP_LOGE(kTag, "encryptPayload: failed to derive input keying material"); return {}; } - std::string paddedPlaintext; - paddedPlaintext.push_back(0x00); - paddedPlaintext.push_back(0x00); - paddedPlaintext.append(plaintext); + uint8_t salt[16]; + if (!generateSalt(salt)) { + ESP_LOGE(kTag, "encryptPayload: failed to generate salt"); + return {}; + } - std::vector ciphertext; - if (!encryptWithAESGCM(paddedPlaintext, cek, nonce, ciphertext)) { - ESP_LOGE(kTag, "encryptPayload: AES-GCM failed"); + uint8_t cek[16]; + uint8_t nonce[12]; + if (!deriveContentEncryptionKeyAndNonce(salt, ikm, cek, nonce)) { + ESP_LOGE(kTag, "encryptPayload: failed to derive content key and nonce"); + return {}; + } + + std::vector body; + if (!buildRecordBody( + salt, + kDefaultRecordSize, + serverPublicKey.data(), + serverPublicKey.size(), + plaintext, + cek, + nonce, + body + )) { + ESP_LOGE(kTag, "encryptPayload: failed to build encrypted body"); return {}; } - return ciphertext; + return body; } diff --git a/src/esp_webPush/webPush_http.cpp b/src/esp_webPush/webPush_http.cpp index bee7de3..4f26cd2 100644 --- a/src/esp_webPush/webPush_http.cpp +++ b/src/esp_webPush/webPush_http.cpp @@ -16,11 +16,7 @@ void ESPWebPush::printHeaderErr(esp_err_t headErr, const char *headKey) const { } WebPushResult ESPWebPush::sendPushRequest( - const std::string &endpoint, - const std::string &jwt, - const std::string &salt, - const std::string &serverPublicKey, - const std::vector &ciphertext + const std::string &endpoint, const std::string &jwt, const std::vector &body ) { WebPushResult result{}; if (endpoint.empty()) { @@ -33,7 +29,7 @@ WebPushResult ESPWebPush::sendPushRequest( config.url = endpoint.c_str(); config.method = HTTP_METHOD_POST; config.timeout_ms = static_cast(_config.requestTimeoutMs); - config.buffer_size_tx = 6048; + config.buffer_size_tx = 6144; esp_http_client_handle_t client = esp_http_client_init(&config); if (!client) { @@ -42,10 +38,9 @@ WebPushResult ESPWebPush::sendPushRequest( return result; } - std::string authHeader = "vapid t=" + jwt + ", k=" + _vapidPublicKey; - std::string cryptoKeyHeader = "dh=" + serverPublicKey + ";p256ecdsa=" + _vapidPublicKey; - std::string encryptionHeader = "salt=" + salt; - std::string ttlValue = std::to_string(_config.ttlSeconds); + const std::string authHeader = + "vapid t=" + jwt + ", k=" + _vapidConfig.publicKeyBase64; + const std::string ttlValue = std::to_string(_config.ttlSeconds); printHeaderErr( esp_http_client_set_header(client, "Authorization", authHeader.c_str()), @@ -53,30 +48,22 @@ WebPushResult ESPWebPush::sendPushRequest( ); printHeaderErr(esp_http_client_set_header(client, "TTL", ttlValue.c_str()), "TTL"); printHeaderErr( - esp_http_client_set_header(client, "Content-Encoding", "aesgcm"), + esp_http_client_set_header(client, "Content-Encoding", "aes128gcm"), "Content-Encoding" ); printHeaderErr( esp_http_client_set_header(client, "Content-Type", "application/octet-stream"), "Content-Type" ); - printHeaderErr( - esp_http_client_set_header(client, "Encryption", encryptionHeader.c_str()), - "Encryption" - ); - printHeaderErr( - esp_http_client_set_header(client, "Crypto-Key", cryptoKeyHeader.c_str()), - "Crypto-Key" - ); esp_http_client_set_post_field( client, - reinterpret_cast(ciphertext.data()), - ciphertext.size() + reinterpret_cast(body.data()), + static_cast(body.size()) ); - esp_err_t err = esp_http_client_perform(client); - int statusCode = esp_http_client_get_status_code(client); + const esp_err_t err = esp_http_client_perform(client); + const int statusCode = esp_http_client_get_status_code(client); result.transportError = err; result.statusCode = statusCode; diff --git a/src/esp_webPush/webPush_jwt.cpp b/src/esp_webPush/webPush_jwt.cpp index c6b0f8c..1e0eee7 100644 --- a/src/esp_webPush/webPush_jwt.cpp +++ b/src/esp_webPush/webPush_jwt.cpp @@ -1,13 +1,12 @@ #include "webPush.h" #include -#include +#include #include #include extern "C" { #include "esp_log.h" -#include "mbedtls/base64.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/ecdsa.h" #include "mbedtls/ecp.h" @@ -17,76 +16,44 @@ extern "C" { namespace { constexpr const char *kTag = "ESPWebPush"; - -std::string base64Normalize(const std::string &urlEncoded) { - std::string out = urlEncoded; - std::replace(out.begin(), out.end(), '-', '+'); - std::replace(out.begin(), out.end(), '_', '/'); - while (out.size() % 4 != 0) { - out.push_back('='); - } - return out; -} +constexpr time_t kJwtLifetimeSeconds = 12 * 60 * 60; +constexpr time_t kJwtRefreshMarginSeconds = 5 * 60; } // namespace -std::string ESPWebPush::generateVapidJWT( - const std::string &aud, const std::string &sub, const std::string &vapidPrivateKeyBase64 -) { - std::string header = R"({"alg":"ES256","typ":"JWT"})"; - unsigned long now = static_cast(std::time(nullptr)); - unsigned long exp = now + 12 * 60 * 60; - - char payloadBuf[256]; - snprintf( - payloadBuf, - sizeof(payloadBuf), - R"({"aud":"%s","exp":%lu,"sub":"%s"})", - aud.c_str(), - exp, - sub.c_str() - ); - - std::string encodedHeader = base64UrlEncode(header); - std::string encodedPayload = base64UrlEncode(payloadBuf); - if (encodedHeader.empty() || encodedPayload.empty()) { +std::string ESPWebPush::generateVapidJWT(const std::string &aud, time_t &expOut) { + std::vector privateKey; + if (!decodeP256PrivateKey(_vapidConfig.privateKeyBase64, privateKey)) { + ESP_LOGE(kTag, "generateVapidJWT: failed to decode private key"); return ""; } - std::string message = encodedHeader + "." + encodedPayload; - - std::string normalizedKey = base64Normalize(vapidPrivateKeyBase64); - normalizedKey.erase( - std::remove_if(normalizedKey.begin(), normalizedKey.end(), ::isspace), - normalizedKey.end() - ); - - std::vector privBytes(32); - size_t olen = 0; - int res = mbedtls_base64_decode( - privBytes.data(), - privBytes.size(), - &olen, - reinterpret_cast(normalizedKey.data()), - normalizedKey.size() - ); - if (res != 0 || olen != 32) { - ESP_LOGE(kTag, "generateVapidJWT: failed to decode private key"); + const time_t now = std::time(nullptr); + expOut = now + kJwtLifetimeSeconds; + + const std::string header = R"({"alg":"ES256","typ":"JWT"})"; + const std::string payload = std::string(R"({"aud":")") + aud + R"(","exp":)" + + std::to_string(static_cast(expOut)) + + R"(,"sub":")" + _vapidConfig.subject + R"("})"; + + const std::string encodedHeader = base64UrlEncode(header); + const std::string encodedPayload = base64UrlEncode(payload); + if (encodedHeader.empty() || encodedPayload.empty()) { return ""; } + const std::string message = encodedHeader + "." + encodedPayload; + mbedtls_mpi d; mbedtls_mpi r; mbedtls_mpi s; - mbedtls_ecp_group grp; - mbedtls_ecp_point Q; + mbedtls_ecp_group group; mbedtls_ctr_drbg_context ctrDrbg; mbedtls_entropy_context entropy; mbedtls_mpi_init(&d); mbedtls_mpi_init(&r); mbedtls_mpi_init(&s); - mbedtls_ecp_group_init(&grp); - mbedtls_ecp_point_init(&Q); + mbedtls_ecp_group_init(&group); mbedtls_ctr_drbg_init(&ctrDrbg); mbedtls_entropy_init(&entropy); @@ -105,29 +72,13 @@ std::string ESPWebPush::generateVapidJWT( ESP_LOGE(kTag, "generateVapidJWT: failed to seed DRBG"); break; } - - if (mbedtls_ecp_group_load(&grp, MBEDTLS_ECP_DP_SECP256R1) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: failed to load curve"); - break; - } - - if (mbedtls_mpi_read_binary(&d, privBytes.data(), 32) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: failed to load private key"); - break; - } - - if (mbedtls_ecp_check_privkey(&grp, &d) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: private key invalid"); + if (mbedtls_ecp_group_load(&group, MBEDTLS_ECP_DP_SECP256R1) != 0) { break; } - - if (mbedtls_ecp_mul(&grp, &Q, &d, &grp.G, mbedtls_ctr_drbg_random, &ctrDrbg) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: failed to derive public key"); + if (mbedtls_mpi_read_binary(&d, privateKey.data(), privateKey.size()) != 0) { break; } - - if (mbedtls_ecp_check_pubkey(&grp, &Q) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: derived public key invalid"); + if (mbedtls_ecp_check_privkey(&group, &d) != 0) { break; } @@ -136,25 +87,23 @@ std::string ESPWebPush::generateVapidJWT( mbedtls_md_init(&mdctx); const mbedtls_md_info_t *mdinfo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); if (!mdinfo || mbedtls_md_setup(&mdctx, mdinfo, 0) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: md setup failed"); mbedtls_md_free(&mdctx); break; } if (mbedtls_md_starts(&mdctx) != 0 || mbedtls_md_update( &mdctx, - reinterpret_cast(message.data()), + reinterpret_cast(message.data()), message.size() ) != 0 || mbedtls_md_finish(&mdctx, hash) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: sha256 failed"); mbedtls_md_free(&mdctx); break; } mbedtls_md_free(&mdctx); if (mbedtls_ecdsa_sign_det_ext( - &grp, + &group, &r, &s, &d, @@ -164,15 +113,16 @@ std::string ESPWebPush::generateVapidJWT( mbedtls_ctr_drbg_random, &ctrDrbg ) != 0) { - ESP_LOGE(kTag, "generateVapidJWT: signing failed"); break; } - uint8_t sig[64] = {}; - mbedtls_mpi_write_binary(&r, sig, 32); - mbedtls_mpi_write_binary(&s, sig + 32, 32); + std::array sig{}; + if (mbedtls_mpi_write_binary(&r, sig.data(), 32) != 0 || + mbedtls_mpi_write_binary(&s, sig.data() + 32, 32) != 0) { + break; + } - std::string encodedSig = base64UrlEncode(sig, sizeof(sig)); + const std::string encodedSig = base64UrlEncode(sig.data(), sig.size()); if (encodedSig.empty()) { break; } @@ -181,13 +131,62 @@ std::string ESPWebPush::generateVapidJWT( success = true; } while (false); - mbedtls_mpi_free(&d); - mbedtls_mpi_free(&r); - mbedtls_mpi_free(&s); - mbedtls_ecp_group_free(&grp); - mbedtls_ecp_point_free(&Q); - mbedtls_ctr_drbg_free(&ctrDrbg); mbedtls_entropy_free(&entropy); + mbedtls_ctr_drbg_free(&ctrDrbg); + mbedtls_ecp_group_free(&group); + mbedtls_mpi_free(&s); + mbedtls_mpi_free(&r); + mbedtls_mpi_free(&d); return success ? signedToken : ""; } + +std::string ESPWebPush::jwtForAudience(const std::string &aud) { + if (aud.empty()) { + return ""; + } + + const time_t now = std::time(nullptr); + { + std::lock_guard guard(_jwtCacheMutex); + for (JwtCacheEntry &entry : _jwtCache) { + if (entry.aud == aud && !entry.token.empty() && + entry.exp > (now + kJwtRefreshMarginSeconds)) { + entry.lastUsedTick = xTaskGetTickCount(); + return entry.token; + } + } + } + + time_t exp = 0; + const std::string jwt = generateVapidJWT(aud, exp); + if (jwt.empty()) { + return ""; + } + + std::lock_guard guard(_jwtCacheMutex); + JwtCacheEntry *target = nullptr; + for (JwtCacheEntry &entry : _jwtCache) { + if (entry.aud == aud) { + target = &entry; + break; + } + if (!target && entry.token.empty()) { + target = &entry; + } + } + if (!target) { + target = &_jwtCache[0]; + for (JwtCacheEntry &entry : _jwtCache) { + if (entry.lastUsedTick < target->lastUsedTick) { + target = &entry; + } + } + } + + target->aud = aud; + target->token = jwt; + target->exp = exp; + target->lastUsedTick = xTaskGetTickCount(); + return target->token; +} diff --git a/src/esp_webPush/webPush_utils.cpp b/src/esp_webPush/webPush_utils.cpp index de3cb17..e61d710 100644 --- a/src/esp_webPush/webPush_utils.cpp +++ b/src/esp_webPush/webPush_utils.cpp @@ -36,7 +36,7 @@ std::string ESPWebPush::base64UrlEncode(const std::string &input) { return base64UrlEncode(reinterpret_cast(input.data()), input.size()); } -bool ESPWebPush::base64UrlDecode(const std::string &input, std::vector &output) { +bool ESPWebPush::base64UrlDecode(const std::string &input, std::vector &output) const { if (input.empty()) { output.clear(); return false; diff --git a/test/test_esp_webPush/test_esp_webPush.cpp b/test/test_esp_webPush/test_esp_webPush.cpp index 2b20a71..90cf962 100644 --- a/test/test_esp_webPush/test_esp_webPush.cpp +++ b/test/test_esp_webPush/test_esp_webPush.cpp @@ -1,6 +1,19 @@ #include #include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define private public #include +#undef private + #include #include "freertos/FreeRTOS.h" @@ -8,10 +21,31 @@ namespace { -constexpr const char *kContact = "notify@example.com"; -constexpr const char *kPublicKey = - "BAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0-P0A"; -constexpr const char *kPrivateKey = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA"; +constexpr const char *kSubject = "mailto:notify@example.com"; +constexpr const char *kAltSubject = "https://www.esptoolkit.hu/contact"; +constexpr const char *kSenderPublicKey = + "BP4z9KsN6nGRTbVYI_c7VJSPQTBtkgcy27mlmlMoZIIgDll6e3vCYLocInmYWAmS6TlzAC8wEqKK6PBru3jl7A8"; +constexpr const char *kSenderPrivateKey = "yfWPiYE-n46HLnH0KqZOF1fJJU3MYrct3AELtAQ-oRw"; +constexpr const char *kReceiverPublicKey = + "BCVxsr7N_eNgVRqvHtD0zTZsEc6-VV-JvLexhqUzORcxaOzi6-AYWXvTBHm4bjyPjs7Vd8pZGH6SRpkNtoIAiw4"; +constexpr const char *kReceiverPrivateKey = "q1dXpw3UpT5VOmu_cf_v6ih07Aems3njxI-JWgLcM94"; +constexpr const char *kAuthSecret = "BTBZMqHH6r4Tts7J_aSIgg"; +constexpr const char *kSalt = "DGv6ra1nlYgDCS1FRnbzlw"; +constexpr const char *kExpectedSharedSecret = "kyrL1jIIOHEzg3sM2ZWRHDRB62YACZhhSlknJ672kSs"; +constexpr const char *kExpectedIkm = "S4lYMb_L0FxCeq0WhDx813KgSYqU26kOyzWUdsXYyrg"; +constexpr const char *kExpectedCek = "oIhVW04MRdy2XN9CiKLxTg"; +constexpr const char *kExpectedNonce = "4h_95klXJ5E_qnoN"; +constexpr const char *kExpectedBody = + "DGv6ra1nlYgDCS1FRnbzlwAAEABBBP4z9KsN6nGRTbVYI_c7VJSPQTBtkgcy27mlmlMoZIIgDll6e3vCYLocInmYWAmS6TlzAC8wEqKK6PBru3jl7A_yl95bQpu6cVPTpK4Mqgkf1CXztLVBSt2Ks3oZwbuwXPXLWyouBWLVWGNWQexSgSxsj_Qulcy4a-fN"; +constexpr const char *kPlaintext = "When I grow up, I want to be a watermelon"; + +WebPushVapidConfig testVapidConfig() { + WebPushVapidConfig cfg{}; + cfg.subject = kSubject; + cfg.publicKeyBase64 = kSenderPublicKey; + cfg.privateKeyBase64 = kSenderPrivateKey; + return cfg; +} WebPushConfig testConfig() { WebPushConfig cfg{}; @@ -20,74 +54,154 @@ WebPushConfig testConfig() { cfg.worker.stackSizeBytes = 4096; cfg.worker.priority = 2; cfg.worker.name = "wp-test"; + cfg.requestTimeoutMs = 200; return cfg; } -Subscription testSubscription() { - Subscription sub{}; - sub.endpoint = "https://example.com/push"; - sub.p256dh = "invalid-p256dh"; - sub.auth = "invalid-auth"; - return sub; +WebPushSubscription testSubscription() { + WebPushSubscription subscription{}; + subscription.endpoint = "https://example.com/push"; + subscription.p256dh = kReceiverPublicKey; + subscription.auth = kAuthSecret; + return subscription; } PushPayload testPayload() { - PushPayload payload{}; + PushPayload payload; payload.title = "Hello"; payload.body = "World"; return payload; } +bool waitForFlag(std::atomic &flag, uint32_t timeoutMs) { + const TickType_t deadline = xTaskGetTickCount() + pdMS_TO_TICKS(timeoutMs); + while (!flag.load() && xTaskGetTickCount() < deadline) { + vTaskDelay(pdMS_TO_TICKS(10)); + } + return flag.load(); +} + +std::string buildLongHttpsSubject() { + std::string subject = "https://www.esptoolkit.hu/contact?"; + for (int i = 0; i < 32; ++i) { + subject += "segment" + std::to_string(i) + "=abcdefghijklmnopqrstuvwxyz0123456789&"; + } + return subject; +} + void test_deinit_is_safe_before_init() { ESPWebPush webPush; TEST_ASSERT_FALSE(webPush.isInitialized()); - webPush.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::NotRunning), + static_cast(webPush.deinit()) + ); TEST_ASSERT_FALSE(webPush.isInitialized()); } void test_deinit_is_idempotent() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); TEST_ASSERT_TRUE(webPush.isInitialized()); - webPush.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.deinit()) + ); TEST_ASSERT_FALSE(webPush.isInitialized()); - webPush.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::NotRunning), + static_cast(webPush.deinit()) + ); TEST_ASSERT_FALSE(webPush.isInitialized()); } void test_reinit_after_deinit() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); TEST_ASSERT_TRUE(webPush.isInitialized()); - webPush.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.deinit()) + ); TEST_ASSERT_FALSE(webPush.isInitialized()); - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + WebPushVapidConfig alt = testVapidConfig(); + alt.subject = kAltSubject; + TEST_ASSERT_TRUE(webPush.init(alt, testConfig())); TEST_ASSERT_TRUE(webPush.isInitialized()); - webPush.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.deinit()) + ); } void test_destructor_deinits_active_instance() { { ESPWebPush first; - TEST_ASSERT_TRUE(first.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(first.init(testVapidConfig(), testConfig())); TEST_ASSERT_TRUE(first.isInitialized()); } ESPWebPush second; - TEST_ASSERT_TRUE(second.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(second.init(testVapidConfig(), testConfig())); TEST_ASSERT_TRUE(second.isInitialized()); - second.deinit(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(second.deinit()) + ); +} + +void test_request_stop_is_safe_before_init() { + ESPWebPush webPush; + webPush.requestStop(); + TEST_ASSERT_FALSE(webPush.isInitialized()); +} + +void test_join_returns_not_running_before_init() { + ESPWebPush webPush; + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::NotRunning), + static_cast(webPush.join(100)) + ); +} + +void test_request_stop_and_join_complete_for_idle_worker() { + ESPWebPush webPush; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); + + webPush.requestStop(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.join(1000)) + ); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::NotRunning), + static_cast(webPush.deinit()) + ); +} + +void test_invalid_subject_rejected() { + ESPWebPush webPush; + WebPushVapidConfig bad = testVapidConfig(); + bad.subject = "notify@example.com"; + TEST_ASSERT_FALSE(webPush.init(bad, testConfig())); +} + +void test_mismatched_vapid_keys_rejected() { + ESPWebPush webPush; + WebPushVapidConfig bad = testVapidConfig(); + bad.publicKeyBase64 = kReceiverPublicKey; + TEST_ASSERT_FALSE(webPush.init(bad, testConfig())); } void test_push_payload_rejects_missing_required_fields() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); - PushPayload payload{}; + PushPayload payload; payload.title = "Hello"; WebPushResult result = webPush.send(testSubscription(), payload); @@ -96,12 +210,12 @@ void test_push_payload_rejects_missing_required_fields() { static_cast(result.error) ); - webPush.deinit(); + (void)webPush.deinit(); } void test_json_document_rejects_unknown_top_level_keys() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); JsonDocument doc; doc["title"] = "Hello"; @@ -114,12 +228,12 @@ void test_json_document_rejects_unknown_top_level_keys() { static_cast(result.error) ); - webPush.deinit(); + (void)webPush.deinit(); } void test_json_variant_rejects_wrong_types() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); JsonDocument doc; doc["title"] = "Hello"; @@ -131,38 +245,108 @@ void test_json_variant_rejects_wrong_types() { static_cast(result.error) ); - webPush.deinit(); + (void)webPush.deinit(); +} + +void test_subscription_requires_only_transport_fields() { + ESPWebPush webPush; + WebPushConfig cfg = testConfig(); + cfg.networkValidator = []() { return false; }; + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); + + WebPushSubscription subscription = testSubscription(); + WebPushResult validResult = webPush.send(subscription, testPayload()); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::NetworkUnavailable), + static_cast(validResult.error) + ); + + subscription.endpoint.clear(); + WebPushResult invalidResult = webPush.send(subscription, testPayload()); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::InvalidSubscription), + static_cast(invalidResult.error) + ); + + (void)webPush.deinit(); } -void test_async_invalid_payload_reports_failure_without_enqueue() { +void test_async_invalid_payload_returns_enqueue_error_without_callback() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); bool callbackCalled = false; - WebPushError callbackError = WebPushError::None; JsonDocument doc; doc["title"] = "Hello"; - bool queued = webPush.send(testSubscription(), doc, [&](WebPushResult result) { + WebPushEnqueueResult enqueue = webPush.send(testSubscription(), doc, [&](WebPushResult) { callbackCalled = true; - callbackError = result.error; }); - TEST_ASSERT_FALSE(queued); - TEST_ASSERT_TRUE(callbackCalled); + TEST_ASSERT_FALSE(enqueue.queued()); TEST_ASSERT_EQUAL( static_cast(WebPushError::InvalidPayload), - static_cast(callbackError) + static_cast(enqueue.error) ); + TEST_ASSERT_FALSE(callbackCalled); - webPush.deinit(); + (void)webPush.deinit(); +} + +void test_payload_limit_is_enforced_for_raw_messages() { + ESPWebPush webPush; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); + + PushMessage fits{}; + fits.subscription = testSubscription(); + fits.payload.assign(3993, 'a'); + + PushMessage tooLarge = fits; + tooLarge.payload.push_back('b'); + + WebPushResult fitResult = webPush.send(fits); + TEST_ASSERT_NOT_EQUAL( + static_cast(WebPushError::PayloadTooLarge), + static_cast(fitResult.error) + ); + + WebPushResult largeResult = webPush.send(tooLarge); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::PayloadTooLarge), + static_cast(largeResult.error) + ); + + (void)webPush.deinit(); +} + +void test_payload_limit_can_be_disabled() { + ESPWebPush webPush; + WebPushConfig cfg = testConfig(); + cfg.maxPayloadBytes = 0; + cfg.networkValidator = []() { return false; }; + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); + + PushMessage msg{}; + msg.subscription = testSubscription(); + msg.payload.assign(5000, 'x'); + + WebPushResult result = webPush.send(msg); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::NetworkUnavailable), + static_cast(result.error) + ); + + (void)webPush.deinit(); } void test_network_validator_false_short_circuits_send() { ESPWebPush webPush; WebPushConfig cfg = testConfig(); cfg.networkValidator = []() { return false; }; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, cfg)); + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); WebPushResult result = webPush.send(testSubscription(), testPayload()); TEST_ASSERT_EQUAL( @@ -170,12 +354,12 @@ void test_network_validator_false_short_circuits_send() { static_cast(result.error) ); - webPush.deinit(); + (void)webPush.deinit(); } void test_missing_network_validator_does_not_force_network_unavailable() { ESPWebPush webPush; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, testConfig())); + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), testConfig())); WebPushResult result = webPush.send(testSubscription(), testPayload()); TEST_ASSERT_NOT_EQUAL( @@ -190,7 +374,8 @@ void test_network_validator_can_be_replaced_after_init() { ESPWebPush webPush; WebPushConfig cfg = testConfig(); cfg.networkValidator = []() { return false; }; - TEST_ASSERT_TRUE(webPush.init(kContact, kPublicKey, kPrivateKey, cfg)); + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); WebPushResult blocked = webPush.send(testSubscription(), testPayload()); TEST_ASSERT_EQUAL( @@ -208,6 +393,221 @@ void test_network_validator_can_be_replaced_after_init() { webPush.deinit(); } +void test_async_queued_message_invokes_callback_once() { + ESPWebPush webPush; + WebPushConfig cfg = testConfig(); + cfg.networkValidator = []() { return false; }; + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); + + std::atomic callbackDone{false}; + int callbackCount = 0; + WebPushError callbackError = WebPushError::None; + + WebPushEnqueueResult enqueue = webPush.send(testSubscription(), testPayload(), [&](WebPushResult result) { + ++callbackCount; + callbackError = result.error; + callbackDone.store(true); + }); + + TEST_ASSERT_TRUE(enqueue.queued()); + TEST_ASSERT_TRUE(waitForFlag(callbackDone, 1000)); + TEST_ASSERT_EQUAL(1, callbackCount); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::NetworkUnavailable), + static_cast(callbackError) + ); + + webPush.deinit(); +} + +void test_deinit_fails_pending_queue_items_with_shutting_down() { + ESPWebPush webPush; + WebPushConfig cfg = testConfig(); + cfg.queueLength = 4; + cfg.networkValidator = []() { return false; }; + cfg.maxRetries = 6; + cfg.retryBaseDelayMs = 200; + cfg.retryMaxDelayMs = 200; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); + + std::atomic firstCalled{false}; + std::atomic secondCalled{false}; + WebPushError firstError = WebPushError::None; + WebPushError secondError = WebPushError::None; + + WebPushEnqueueResult first = webPush.send(testSubscription(), testPayload(), [&](WebPushResult result) { + firstError = result.error; + firstCalled.store(true); + }); + WebPushEnqueueResult second = webPush.send(testSubscription(), testPayload(), [&](WebPushResult result) { + secondError = result.error; + secondCalled.store(true); + }); + + TEST_ASSERT_TRUE(first.queued()); + TEST_ASSERT_TRUE(second.queued()); + + vTaskDelay(pdMS_TO_TICKS(20)); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.deinit()) + ); + + TEST_ASSERT_TRUE(waitForFlag(firstCalled, 1000)); + TEST_ASSERT_TRUE(waitForFlag(secondCalled, 1000)); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::ShuttingDown), + static_cast(firstError) + ); + TEST_ASSERT_EQUAL( + static_cast(WebPushError::ShuttingDown), + static_cast(secondError) + ); +} + +void test_join_timeout_can_be_followed_by_later_success() { + ESPWebPush webPush; + WebPushConfig cfg = testConfig(); + cfg.networkValidator = []() { return false; }; + cfg.maxRetries = 0; + TEST_ASSERT_TRUE(webPush.init(testVapidConfig(), cfg)); + + std::atomic callbackEntered{false}; + std::atomic callbackDone{false}; + + WebPushEnqueueResult enqueue = + webPush.send(testSubscription(), testPayload(), [&](WebPushResult) { + callbackEntered.store(true); + vTaskDelay(pdMS_TO_TICKS(200)); + callbackDone.store(true); + }); + + TEST_ASSERT_TRUE(enqueue.queued()); + TEST_ASSERT_TRUE(waitForFlag(callbackEntered, 500)); + + webPush.requestStop(); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Timeout), + static_cast(webPush.join(10)) + ); + TEST_ASSERT_TRUE(waitForFlag(callbackDone, 1000)); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::Completed), + static_cast(webPush.join(1000)) + ); + TEST_ASSERT_EQUAL( + static_cast(WebPushJoinStatus::NotRunning), + static_cast(webPush.deinit()) + ); +} + +void test_rfc8291_key_derivation_matches_appendix_a() { + ESPWebPush webPush; + TEST_ASSERT_TRUE(webPush.initCrypto()); + + std::vector authSecret; + std::vector senderPrivate; + std::vector senderPublic; + std::vector receiverPublic; + std::vector salt; + uint8_t sharedSecret[32]; + uint8_t ikm[32]; + uint8_t cek[16]; + uint8_t nonce[12]; + + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kAuthSecret, authSecret)); + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kSenderPrivateKey, senderPrivate)); + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kSalt, salt)); + TEST_ASSERT_TRUE(webPush.decodeP256PublicKey(kSenderPublicKey, senderPublic)); + TEST_ASSERT_TRUE(webPush.decodeP256PublicKey(kReceiverPublicKey, receiverPublic)); + TEST_ASSERT_TRUE(webPush.deriveSharedSecret(receiverPublic, senderPrivate, sharedSecret)); + TEST_ASSERT_TRUE(webPush.deriveInputKeyingMaterial( + authSecret.data(), + authSecret.size(), + sharedSecret, + receiverPublic.data(), + receiverPublic.size(), + senderPublic.data(), + senderPublic.size(), + ikm)); + TEST_ASSERT_TRUE(webPush.deriveContentEncryptionKeyAndNonce(salt.data(), ikm, cek, nonce)); + + TEST_ASSERT_EQUAL_STRING( + kExpectedSharedSecret, + webPush.base64UrlEncode(sharedSecret, sizeof(sharedSecret)).c_str() + ); + TEST_ASSERT_EQUAL_STRING( + kExpectedIkm, + webPush.base64UrlEncode(ikm, sizeof(ikm)).c_str() + ); + TEST_ASSERT_EQUAL_STRING( + kExpectedCek, + webPush.base64UrlEncode(cek, sizeof(cek)).c_str() + ); + TEST_ASSERT_EQUAL_STRING( + kExpectedNonce, + webPush.base64UrlEncode(nonce, sizeof(nonce)).c_str() + ); + + webPush.deinitCrypto(); +} + +void test_rfc8291_body_matches_example() { + ESPWebPush webPush; + + std::vector salt; + std::vector senderPublic; + uint8_t cek[16]; + uint8_t nonce[12]; + std::vector cekBytes; + std::vector nonceBytes; + std::vector body; + + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kSalt, salt)); + TEST_ASSERT_TRUE(webPush.decodeP256PublicKey(kSenderPublicKey, senderPublic)); + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kExpectedCek, cekBytes)); + TEST_ASSERT_TRUE(webPush.base64UrlDecode(kExpectedNonce, nonceBytes)); + memcpy(cek, cekBytes.data(), sizeof(cek)); + memcpy(nonce, nonceBytes.data(), sizeof(nonce)); + + TEST_ASSERT_TRUE(webPush.buildRecordBody( + salt.data(), + 4096, + senderPublic.data(), + senderPublic.size(), + kPlaintext, + cek, + nonce, + body)); + + TEST_ASSERT_EQUAL_STRING(kExpectedBody, webPush.base64UrlEncode(body.data(), body.size()).c_str()); +} + +void test_generate_vapid_jwt_keeps_long_https_subject() { + ESPWebPush webPush; + webPush._vapidConfig = testVapidConfig(); + webPush._vapidConfig.subject = buildLongHttpsSubject(); + + time_t exp = 0; + std::string jwt = webPush.generateVapidJWT("https://push.example.com", exp); + TEST_ASSERT_FALSE(jwt.empty()); + + size_t firstDot = jwt.find('.'); + size_t secondDot = jwt.find('.', firstDot + 1); + TEST_ASSERT_NOT_EQUAL(std::string::npos, firstDot); + TEST_ASSERT_NOT_EQUAL(std::string::npos, secondDot); + + std::vector payloadBytes; + TEST_ASSERT_TRUE(webPush.base64UrlDecode( + jwt.substr(firstDot + 1, secondDot - firstDot - 1), + payloadBytes + )); + + std::string payload(payloadBytes.begin(), payloadBytes.end()); + TEST_ASSERT_NOT_EQUAL(std::string::npos, payload.find(webPush._vapidConfig.subject)); +} + } // namespace void setUp() { @@ -222,13 +622,27 @@ void setup() { RUN_TEST(test_deinit_is_idempotent); RUN_TEST(test_reinit_after_deinit); RUN_TEST(test_destructor_deinits_active_instance); + RUN_TEST(test_request_stop_is_safe_before_init); + RUN_TEST(test_join_returns_not_running_before_init); + RUN_TEST(test_request_stop_and_join_complete_for_idle_worker); + RUN_TEST(test_invalid_subject_rejected); + RUN_TEST(test_mismatched_vapid_keys_rejected); RUN_TEST(test_push_payload_rejects_missing_required_fields); RUN_TEST(test_json_document_rejects_unknown_top_level_keys); RUN_TEST(test_json_variant_rejects_wrong_types); - RUN_TEST(test_async_invalid_payload_reports_failure_without_enqueue); + RUN_TEST(test_subscription_requires_only_transport_fields); + RUN_TEST(test_async_invalid_payload_returns_enqueue_error_without_callback); + RUN_TEST(test_payload_limit_is_enforced_for_raw_messages); + RUN_TEST(test_payload_limit_can_be_disabled); RUN_TEST(test_network_validator_false_short_circuits_send); RUN_TEST(test_missing_network_validator_does_not_force_network_unavailable); RUN_TEST(test_network_validator_can_be_replaced_after_init); + RUN_TEST(test_async_queued_message_invokes_callback_once); + RUN_TEST(test_deinit_fails_pending_queue_items_with_shutting_down); + RUN_TEST(test_join_timeout_can_be_followed_by_later_success); + RUN_TEST(test_rfc8291_key_derivation_matches_appendix_a); + RUN_TEST(test_rfc8291_body_matches_example); + RUN_TEST(test_generate_vapid_jwt_keeps_long_https_subject); UNITY_END(); }