diff --git a/.github/workflows/autoport-41.yml b/.github/workflows/autoport-41.yml new file mode 100644 index 00000000000..682ac53c370 --- /dev/null +++ b/.github/workflows/autoport-41.yml @@ -0,0 +1,130 @@ +name: Auto-port to 4.1 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.2' + - '5.0' + +jobs: + autoport: + name: "Auto-porting to 4.1" + concurrency: + group: port-41-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'needs-cherry-pick-4.1') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-4.1" + if [[ $(git branch --show-current) != '4.1' ]]; then + git fetch origin 4.1:4.1 + fi + git checkout -b "$PORT_BRANCH" 4.1 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! 
git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 4.1: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '4.1', + body: `Auto-port of #${context.payload.pull_request.number} to 4.1\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 4.1: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. 
+ - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-4.1' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 4.1.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 4.1 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/autoport-42.yml b/.github/workflows/autoport-42.yml new file mode 100644 index 00000000000..15b27eafe67 --- /dev/null +++ b/.github/workflows/autoport-42.yml @@ -0,0 +1,130 @@ +name: Auto-port to 4.2 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.1' + - '5.0' + +jobs: + autoport: + name: "Auto-porting to 4.2" + concurrency: + group: port-42-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 
'needs-cherry-pick-4.2') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-4.2" + if [[ $(git branch --show-current) != '4.2' ]]; then + git fetch origin 4.2:4.2 + fi + git checkout -b "$PORT_BRANCH" 4.2 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! 
git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 4.2: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '4.2', + body: `Auto-port of #${context.payload.pull_request.number} to 4.2\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 4.2: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. 
+ - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-4.2' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 4.2.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 4.2 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/autoport-50.yml b/.github/workflows/autoport-50.yml new file mode 100644 index 00000000000..2899d56e209 --- /dev/null +++ b/.github/workflows/autoport-50.yml @@ -0,0 +1,130 @@ +name: Auto-port to 5.0 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.1' + - '4.2' + +jobs: + autoport: + name: "Auto-porting to 5.0" + concurrency: + group: port-50-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 
'needs-cherry-pick-5.0') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-5.0" + if [[ $(git branch --show-current) != '5.0' ]]; then + git fetch origin 5.0:5.0 + fi + git checkout -b "$PORT_BRANCH" 5.0 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! 
git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 5.0: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '5.0', + body: `Auto-port of #${context.payload.pull_request.number} to 5.0\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 5.0: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. 
+ - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-5.0' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 5.0.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 5.0 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/ci-deploy.yml b/.github/workflows/ci-deploy.yml index 3f41bd26508..7a8e2abf467 100644 --- a/.github/workflows/ci-deploy.yml +++ b/.github/workflows/ci-deploy.yml @@ -90,7 +90,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8 os: macos-15 diff --git a/.github/workflows/ci-pr.yml b/.github/workflows/ci-pr.yml index 506ac787e17..f67ed7ed0fd 100644 --- a/.github/workflows/ci-pr.yml +++ b/.github/workflows/ci-pr.yml @@ -201,16 +201,35 @@ jobs: - setup: 
linux-x86_64-java11-adaptive docker-compose-build: "-f docker/docker-compose.yaml -f docker/docker-compose.centos-6.111.yaml build" docker-compose-run: "-f docker/docker-compose.yaml -f docker/docker-compose.centos-6.111.yaml run build-leak-adaptive" + - setup: linux-x86_64-java11-awslc + docker-compose-build: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml build" + docker-compose-install-tcnative: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run install-tcnative" + docker-compose-update-tcnative-version: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run update-tcnative-version" + docker-compose-run: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run build" name: ${{ matrix.setup }} build needs: verify-pr + defaults: + run: + working-directory: netty steps: - uses: actions/checkout@v4 + with: + path: netty + + - uses: actions/checkout@v4 + if: ${{ endsWith(matrix.setup, '-awslc') }} + with: + repository: netty/netty-tcnative + ref: main + path: netty-tcnative + fetch-depth: 0 # Cache .m2/repository - name: Cache local Maven repository uses: actions/cache@v4 continue-on-error: true + if: ${{ !endsWith(matrix.setup, '-awslc') }} with: path: ~/.m2/repository key: cache-maven-${{ hashFiles('**/pom.xml') }} @@ -218,9 +237,28 @@ jobs: cache-maven-${{ hashFiles('**/pom.xml') }} cache-maven- + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + if: ${{ endsWith(matrix.setup, '-awslc') }} + with: + path: ~/.m2-al2023/repository + key: cache-maven-al2023-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-al2023-${{ hashFiles('**/pom.xml') }} + cache-maven-al2023- + - name: Build docker image run: docker compose ${{ matrix.docker-compose-build }} + - name: Install custom netty-tcnative + if: ${{ endsWith(matrix.setup, '-awslc') }} + run: docker compose ${{ matrix.docker-compose-install-tcnative }} + + - name: Update netty-tcnative version + 
if: ${{ endsWith(matrix.setup, '-awslc') }} + run: docker compose ${{ matrix.docker-compose-update-tcnative-version }} + - name: Build project with leak detection run: docker compose ${{ matrix.docker-compose-run }} | tee build-leak.output @@ -231,7 +269,7 @@ jobs: run: ./.github/scripts/check_leak.sh build-leak.output - name: print JVM thread dumps when cancelled - uses: ./.github/actions/thread-dump-jvms + uses: ./netty/.github/actions/thread-dump-jvms if: ${{ cancelled() }} - name: Upload Test Results @@ -239,17 +277,17 @@ jobs: uses: actions/upload-artifact@v4 with: name: test-results-${{ matrix.setup }} - path: '**/target/surefire-reports/TEST-*.xml' + path: 'netty/**/target/surefire-reports/TEST-*.xml' - uses: actions/upload-artifact@v4 if: ${{ failure() || cancelled() }} with: name: build-${{ matrix.setup }}-target path: | - **/target/surefire-reports/ - **/target/autobahntestsuite-reports/ - **/hs_err*.log - **/core.* + netty/**/target/surefire-reports/ + netty/**/target/autobahntestsuite-reports/ + netty/**/hs_err*.log + netty/**/core.* build-pr-macos: strategy: @@ -257,7 +295,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8-boringssl - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8-boringssl os: macos-15 diff --git a/.github/workflows/ci-release-4.2.yml b/.github/workflows/ci-release-4.2.yml index 619fb0ac819..51709846987 100644 --- a/.github/workflows/ci-release-4.2.yml +++ b/.github/workflows/ci-release-4.2.yml @@ -185,7 +185,7 @@ jobs: matrix: include: - setup: macos-x86_64-java11 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java11 os: macos-15 diff --git a/.github/workflows/ci-release.yml b/.github/workflows/ci-release.yml index 9f60f3b80e2..c797a64f127 100644 --- a/.github/workflows/ci-release.yml +++ b/.github/workflows/ci-release.yml @@ -185,7 +185,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8 os: macos-15 runs-on: ${{ matrix.os }} diff 
--git a/all/pom.xml b/all/pom.xml index d0a186dc417..2eb01ae4196 100644 --- a/all/pom.xml +++ b/all/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-all diff --git a/bom/pom.xml b/bom/pom.xml index f74949ba8aa..120a9e64113 100644 --- a/bom/pom.xml +++ b/bom/pom.xml @@ -25,7 +25,7 @@ io.netty netty-bom - 4.1.128.1.dse + 4.1.132.1.dse pom Netty/BOM @@ -49,7 +49,7 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.128.Final + netty-4.1.132.Final @@ -73,7 +73,7 @@ - 2.0.74.Final + 2.0.75.Final diff --git a/buffer/pom.xml b/buffer/pom.xml index 21dd16f77d9..bea8c59032e 100644 --- a/buffer/pom.xml +++ b/buffer/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-buffer diff --git a/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java b/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java index d4fba097831..de90de6f784 100644 --- a/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java +++ b/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java @@ -18,22 +18,24 @@ import io.netty.util.ByteProcessor; import io.netty.util.CharsetUtil; import io.netty.util.IllegalReferenceCountException; -import io.netty.util.IntSupplier; +import io.netty.util.IntConsumer; import io.netty.util.NettyRuntime; +import io.netty.util.Recycler; import io.netty.util.Recycler.EnhancedHandle; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap; +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap.IntEntry; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; import io.netty.util.concurrent.MpscAtomicIntegerArrayQueue; import io.netty.util.concurrent.MpscIntQueue; -import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.MathUtil; import 
io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ReferenceCountUpdater; import io.netty.util.internal.SuppressJava6Requirement; import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.ThreadExecutorMap; -import io.netty.util.internal.ThreadLocalRandom; import io.netty.util.internal.UnstableApi; import java.io.IOException; @@ -47,8 +49,10 @@ import java.nio.channels.ScatteringByteChannel; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Iterator; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.concurrent.atomic.LongAdder; @@ -83,6 +87,16 @@ @SuppressJava6Requirement(reason = "Guarded by version check") @UnstableApi final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.AdaptiveAllocatorApi { + private static final int LOW_MEM_THRESHOLD = 512 * 1024 * 1024; + private static final boolean IS_LOW_MEM = Runtime.getRuntime().maxMemory() <= LOW_MEM_THRESHOLD; + + /** + * Whether the IS_LOW_MEM setting should disable thread-local magazines. + * This can have fairly high performance overhead. + */ + private static final boolean DISABLE_THREAD_LOCAL_MAGAZINES_ON_LOW_MEM = SystemPropertyUtil.getBoolean( + "io.netty.allocator.disableThreadLocalMagazinesOnLowMemory", true); + /** * The 128 KiB minimum chunk size is chosen to encourage the system allocator to delegate to mmap for chunk * allocations. For instance, glibc will do this. @@ -90,11 +104,11 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv * which is a much, much larger space. Chunks are also allocated in whole multiples of the minimum * chunk size, which itself is a whole multiple of popular page sizes like 4 KiB, 16 KiB, and 64 KiB. 
*/ - private static final int MIN_CHUNK_SIZE = 128 * 1024; + static final int MIN_CHUNK_SIZE = 128 * 1024; private static final int EXPANSION_ATTEMPTS = 3; private static final int INITIAL_MAGAZINES = 1; private static final int RETIRE_CAPACITY = 256; - private static final int MAX_STRIPES = NettyRuntime.availableProcessors() * 2; + private static final int MAX_STRIPES = IS_LOW_MEM ? 1 : NettyRuntime.availableProcessors() * 2; private static final int BUFS_PER_CHUNK = 8; // For large buffers, aim to have about this many buffers per chunk. /** @@ -102,7 +116,9 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv *

* This number is 8 MiB, and is derived from the limitations of internal histograms. */ - private static final int MAX_CHUNK_SIZE = 8 * 1024 * 1024; // 8 MiB. + private static final int MAX_CHUNK_SIZE = IS_LOW_MEM ? + 2 * 1024 * 1024 : // 2 MiB for systems with small heaps. + 8 * 1024 * 1024; // 8 MiB. private static final int MAX_POOLED_BUF_SIZE = MAX_CHUNK_SIZE / BUFS_PER_CHUNK; /** @@ -150,21 +166,9 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv 16384, 16896, // 16384 + 512 }; - private static final ChunkReleasePredicate CHUNK_RELEASE_ALWAYS = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return true; - } - }; - private static final ChunkReleasePredicate CHUNK_RELEASE_NEVER = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return false; - } - }; private static final int SIZE_CLASSES_COUNT = SIZE_CLASSES.length; - private static final byte[] SIZE_INDEXES = new byte[(SIZE_CLASSES[SIZE_CLASSES_COUNT - 1] / 32) + 1]; + private static final byte[] SIZE_INDEXES = new byte[SIZE_CLASSES[SIZE_CLASSES_COUNT - 1] / 32 + 1]; static { if (MAGAZINE_BUFFER_QUEUE_CAPACITY < 2) { @@ -175,7 +179,7 @@ public boolean shouldReleaseChunk(int chunkSize) { for (int i = 0; i < SIZE_CLASSES_COUNT; i++) { int sizeClass = SIZE_CLASSES[i]; //noinspection ConstantValue - assert (sizeClass & 5) == 0 : "Size class must be a multiple of 32"; + assert (sizeClass & 31) == 0 : "Size class must be a multiple of 32"; int sizeIndex = sizeIndexOf(sizeClass); Arrays.fill(SIZE_INDEXES, lastIndex + 1, sizeIndex + 1, (byte) i); lastIndex = sizeIndex; @@ -193,8 +197,10 @@ public boolean shouldReleaseChunk(int chunkSize) { chunkRegistry = new ChunkRegistry(); sizeClassedMagazineGroups = createMagazineGroupSizeClasses(this, false); largeBufferMagazineGroup = new MagazineGroup( - this, chunkAllocator, new HistogramChunkControllerFactory(true), false); - threadLocalGroup = new 
FastThreadLocal() { + this, chunkAllocator, new BuddyChunkManagementStrategy(), false); + + boolean disableThreadLocalGroups = IS_LOW_MEM && DISABLE_THREAD_LOCAL_MAGAZINES_ON_LOW_MEM; + threadLocalGroup = disableThreadLocalGroups ? null : new FastThreadLocal() { @Override protected MagazineGroup[] initialValue() { if (useCacheForNonEventLoopThreads || ThreadExecutorMap.currentExecutor() != null) { @@ -220,7 +226,7 @@ private static MagazineGroup[] createMagazineGroupSizeClasses( for (int i = 0; i < SIZE_CLASSES.length; i++) { int segmentSize = SIZE_CLASSES[i]; groups[i] = new MagazineGroup(allocator, allocator.chunkAllocator, - new SizeClassChunkControllerFactory(segmentSize), isThreadLocal); + new SizeClassChunkManagementStrategy(segmentSize), isThreadLocal); } return groups; } @@ -245,7 +251,7 @@ private static MagazineGroup[] createMagazineGroupSizeClasses( * * @return A new multi-producer, multi-consumer queue. */ - private static Queue createSharedChunkQueue() { + private static Queue createSharedChunkQueue() { return PlatformDependent.newFixedMpmcQueue(CHUNK_REUSE_QUEUE); } @@ -259,13 +265,14 @@ private AdaptiveByteBuf allocate(int size, int maxCapacity, Thread currentThread if (size <= MAX_POOLED_BUF_SIZE) { final int index = sizeClassIndexOf(size); MagazineGroup[] magazineGroups; - if (!FastThreadLocalThread.willCleanupFastThreadLocals(currentThread) || + if (!FastThreadLocalThread.willCleanupFastThreadLocals(Thread.currentThread()) || + IS_LOW_MEM || (magazineGroups = threadLocalGroup.get()) == null) { magazineGroups = sizeClassedMagazineGroups; } if (index < magazineGroups.length) { allocated = magazineGroups[index].allocate(size, maxCapacity, currentThread, buf); - } else { + } else if (!IS_LOW_MEM) { allocated = largeBufferMagazineGroup.allocate(size, maxCapacity, currentThread, buf); } } @@ -292,8 +299,7 @@ static int[] getSizeClasses() { return SIZE_CLASSES.clone(); } - private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread 
currentThread, - AdaptiveByteBuf buf) { + private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread currentThread, AdaptiveByteBuf buf) { // If we don't already have a buffer, obtain one from the most conveniently available magazine. Magazine magazine; if (buf != null) { @@ -307,10 +313,11 @@ private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread curre } // Create a one-off chunk for this allocation. AbstractByteBuf innerChunk = chunkAllocator.allocate(size, maxCapacity); - Chunk chunk = new Chunk(innerChunk, magazine, false, CHUNK_RELEASE_ALWAYS); + Chunk chunk = new Chunk(innerChunk, magazine); chunkRegistry.add(chunk); try { - chunk.readInitInto(buf, size, size, maxCapacity); + boolean success = chunk.readInitInto(buf, size, size, maxCapacity); + assert success: "Failed to initialize ByteBuf with dedicated chunk"; } finally { // As the chunk is an one-off we need to always call release explicitly as readInitInto(...) // will take care of retain once when successful. 
Once The AdaptiveByteBuf is released it will @@ -355,38 +362,37 @@ private void free() { largeBufferMagazineGroup.free(); } - static int sizeToBucket(int size) { - return HistogramChunkController.sizeToBucket(size); - } - @SuppressJava6Requirement(reason = "Guarded by version check") private static final class MagazineGroup { private final AdaptivePoolingAllocator allocator; private final ChunkAllocator chunkAllocator; - private final ChunkControllerFactory chunkControllerFactory; - private final Queue chunkReuseQueue; + private final ChunkManagementStrategy chunkManagementStrategy; + private final ChunkCache chunkCache; private final StampedLock magazineExpandLock; private final Magazine threadLocalMagazine; + private Thread ownerThread; private volatile Magazine[] magazines; private volatile boolean freed; MagazineGroup(AdaptivePoolingAllocator allocator, ChunkAllocator chunkAllocator, - ChunkControllerFactory chunkControllerFactory, + ChunkManagementStrategy chunkManagementStrategy, boolean isThreadLocal) { this.allocator = allocator; this.chunkAllocator = chunkAllocator; - this.chunkControllerFactory = chunkControllerFactory; - chunkReuseQueue = createSharedChunkQueue(); + this.chunkManagementStrategy = chunkManagementStrategy; + chunkCache = chunkManagementStrategy.createChunkCache(isThreadLocal); if (isThreadLocal) { + ownerThread = Thread.currentThread(); magazineExpandLock = null; - threadLocalMagazine = new Magazine(this, false, chunkReuseQueue, chunkControllerFactory.create(this)); + threadLocalMagazine = new Magazine(this, false, chunkManagementStrategy.createController(this)); } else { + ownerThread = null; magazineExpandLock = new StampedLock(); threadLocalMagazine = null; Magazine[] mags = new Magazine[INITIAL_MAGAZINES]; for (int i = 0; i < mags.length; i++) { - mags[i] = new Magazine(this, true, chunkReuseQueue, chunkControllerFactory.create(this)); + mags[i] = new Magazine(this, true, chunkManagementStrategy.createController(this)); } magazines = 
mags; } @@ -446,12 +452,9 @@ private boolean tryExpandMagazines(int currentLength) { if (mags.length >= MAX_STRIPES || mags.length > currentLength || freed) { return true; } - Magazine firstMagazine = mags[0]; Magazine[] expanded = new Magazine[mags.length * 2]; for (int i = 0, l = expanded.length; i < l; i++) { - Magazine m = new Magazine(this, true, chunkReuseQueue, chunkControllerFactory.create(this)); - firstMagazine.initializeSharedStateIn(m); - expanded[i] = m; + expanded[i] = new Magazine(this, true, chunkManagementStrategy.createController(this)); } magazines = expanded; } finally { @@ -464,22 +467,32 @@ private boolean tryExpandMagazines(int currentLength) { return true; } - boolean offerToQueue(Chunk buffer) { + Chunk pollChunk(int size) { + return chunkCache.pollChunk(size); + } + + boolean offerChunk(Chunk chunk) { if (freed) { return false; } - boolean isAdded = chunkReuseQueue.offer(buffer); + if (chunk.hasUnprocessedFreelistEntries()) { + chunk.processFreelistEntries(); + } + boolean isAdded = chunkCache.offerChunk(chunk); + if (freed && isAdded) { // Help to free the reuse queue. 
- freeChunkReuseQueue(); + freeChunkReuseQueue(ownerThread); } return isAdded; } private void free() { freed = true; + Thread ownerThread = this.ownerThread; if (threadLocalMagazine != null) { + this.ownerThread = null; threadLocalMagazine.free(); } else { long stamp = magazineExpandLock.writeLock(); @@ -492,22 +505,153 @@ private void free() { magazineExpandLock.unlockWrite(stamp); } } - freeChunkReuseQueue(); + freeChunkReuseQueue(ownerThread); } - private void freeChunkReuseQueue() { - for (;;) { - Chunk chunk = chunkReuseQueue.poll(); + private void freeChunkReuseQueue(Thread ownerThread) { + Chunk chunk; + while ((chunk = chunkCache.pollChunk(0)) != null) { + if (ownerThread != null && chunk instanceof SizeClassedChunk) { + SizeClassedChunk threadLocalChunk = (SizeClassedChunk) chunk; + assert ownerThread == threadLocalChunk.ownerThread; + // no release segment can ever happen from the owner Thread since it's not running anymore + // This is required to let the ownerThread to be GC'ed despite there are AdaptiveByteBuf + // that reference some thread local chunk + threadLocalChunk.ownerThread = null; + } + chunk.markToDeallocate(); + } + } + } + + private interface ChunkCache { + Chunk pollChunk(int size); + boolean offerChunk(Chunk chunk); + } + + private static final class ConcurrentQueueChunkCache implements ChunkCache { + private final Queue queue; + + private ConcurrentQueueChunkCache() { + queue = createSharedChunkQueue(); + } + + @Override + public SizeClassedChunk pollChunk(int size) { + // we really don't care about size here since the sized class chunk q + // just care about segments of fixed size! 
+ Queue queue = this.queue; + for (int i = 0; i < CHUNK_REUSE_QUEUE; i++) { + SizeClassedChunk chunk = queue.poll(); if (chunk == null) { + return null; + } + if (chunk.hasRemainingCapacity()) { + return chunk; + } + queue.offer(chunk); + } + return null; + } + + @Override + public boolean offerChunk(Chunk chunk) { + return queue.offer((SizeClassedChunk) chunk); + } + } + + private static final class ConcurrentSkipListChunkCache implements ChunkCache { + private final ConcurrentSkipListIntObjMultimap chunks; + + private ConcurrentSkipListChunkCache() { + chunks = new ConcurrentSkipListIntObjMultimap(-1); + } + + @Override + public Chunk pollChunk(int size) { + if (chunks.isEmpty()) { + return null; + } + IntEntry entry = chunks.pollCeilingEntry(size); + if (entry != null) { + Chunk chunk = entry.getValue(); + if (chunk.hasUnprocessedFreelistEntries()) { + chunk.processFreelistEntries(); + } + return chunk; + } + + Chunk bestChunk = null; + int bestRemainingCapacity = 0; + Iterator> itr = chunks.iterator(); + while (itr.hasNext()) { + entry = itr.next(); + final Chunk chunk; + if (entry != null && (chunk = entry.getValue()).hasUnprocessedFreelistEntries()) { + if (!chunks.remove(entry.getKey(), entry.getValue())) { + continue; + } + chunk.processFreelistEntries(); + int remainingCapacity = chunk.remainingCapacity(); + if (remainingCapacity >= size && + (bestChunk == null || remainingCapacity > bestRemainingCapacity)) { + if (bestChunk != null) { + chunks.put(bestRemainingCapacity, bestChunk); + } + bestChunk = chunk; + bestRemainingCapacity = remainingCapacity; + } else { + chunks.put(remainingCapacity, chunk); + } + } + } + + return bestChunk; + } + + @Override + public boolean offerChunk(Chunk chunk) { + chunks.put(chunk.remainingCapacity(), chunk); + + int size = chunks.size(); + while (size > CHUNK_REUSE_QUEUE) { + // Deallocate the chunk with the fewest incoming references. 
+ int key = -1; + Chunk toDeallocate = null; + for (IntEntry entry : chunks) { + Chunk candidate = entry.getValue(); + if (candidate != null) { + if (toDeallocate == null) { + toDeallocate = candidate; + key = entry.getKey(); + } else { + int candidateRefCnt = candidate.refCnt(); + int toDeallocateRefCnt = toDeallocate.refCnt(); + if (candidateRefCnt < toDeallocateRefCnt || + candidateRefCnt == toDeallocateRefCnt && + candidate.capacity() < toDeallocate.capacity()) { + toDeallocate = candidate; + key = entry.getKey(); + } + } + } + } + if (toDeallocate == null) { break; } - chunk.release(); + if (chunks.remove(key, toDeallocate)) { + toDeallocate.markToDeallocate(); + } + size = chunks.size(); } + return true; } } - private interface ChunkControllerFactory { - ChunkController create(MagazineGroup group); + private interface ChunkManagementStrategy { + ChunkController createController(MagazineGroup group); + + ChunkCache createChunkCache(boolean isThreadLocal); } private interface ChunkController { @@ -516,66 +660,75 @@ private interface ChunkController { */ int computeBufferCapacity(int requestedSize, int maxCapacity, boolean isReallocation); - /** - * Initialize the given chunk factory with shared statistics state (if any) from this factory. - */ - void initializeSharedStateIn(ChunkController chunkController); - /** * Allocate a new {@link Chunk} for the given {@link Magazine}. */ Chunk newChunkAllocation(int promptingSize, Magazine magazine); } - private interface ChunkReleasePredicate { - boolean shouldReleaseChunk(int chunkSize); - } - - private static final class SizeClassChunkControllerFactory implements ChunkControllerFactory { + private static final class SizeClassChunkManagementStrategy implements ChunkManagementStrategy { // To amortize activation/deactivation of chunks, we should have a minimum number of segments per chunk. // We choose 32 because it seems neither too small nor too big. // For segments of 16 KiB, the chunks will be half a megabyte. 
private static final int MIN_SEGMENTS_PER_CHUNK = 32; private final int segmentSize; private final int chunkSize; - private final int[] segmentOffsets; - private SizeClassChunkControllerFactory(int segmentSize) { + private SizeClassChunkManagementStrategy(int segmentSize) { this.segmentSize = ObjectUtil.checkPositive(segmentSize, "segmentSize"); chunkSize = Math.max(MIN_CHUNK_SIZE, segmentSize * MIN_SEGMENTS_PER_CHUNK); - int segmentsCount = chunkSize / segmentSize; - segmentOffsets = new int[segmentsCount]; - for (int i = 0; i < segmentsCount; i++) { - segmentOffsets[i] = i * segmentSize; - } } @Override - public ChunkController create(MagazineGroup group) { - return new SizeClassChunkController(group, segmentSize, chunkSize, segmentOffsets); + public ChunkController createController(MagazineGroup group) { + return new SizeClassChunkController(group, segmentSize, chunkSize); + } + + @Override + public ChunkCache createChunkCache(boolean isThreadLocal) { + return new ConcurrentQueueChunkCache(); } } private static final class SizeClassChunkController implements ChunkController { - private static final ChunkReleasePredicate FALSE_PREDICATE = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return false; - } - }; private final ChunkAllocator chunkAllocator; private final int segmentSize; private final int chunkSize; private final ChunkRegistry chunkRegistry; - private final int[] segmentOffsets; - private SizeClassChunkController(MagazineGroup group, int segmentSize, int chunkSize, int[] segmentOffsets) { + private SizeClassChunkController(MagazineGroup group, int segmentSize, int chunkSize) { chunkAllocator = group.chunkAllocator; this.segmentSize = segmentSize; this.chunkSize = chunkSize; chunkRegistry = group.allocator.chunkRegistry; - this.segmentOffsets = segmentOffsets; + } + + private MpscIntQueue createEmptyFreeList() { + return new MpscAtomicIntegerArrayQueue(chunkSize / segmentSize, 
SizeClassedChunk.FREE_LIST_EMPTY); + } + + private MpscIntQueue createFreeList() { + final int segmentsCount = chunkSize / segmentSize; + final MpscIntQueue freeList = new MpscAtomicIntegerArrayQueue( + segmentsCount, SizeClassedChunk.FREE_LIST_EMPTY); + int segmentOffset = 0; + for (int i = 0; i < segmentsCount; i++) { + freeList.offer(segmentOffset); + segmentOffset += segmentSize; + } + return freeList; + } + + private IntStack createLocalFreeList() { + final int segmentsCount = chunkSize / segmentSize; + int segmentOffset = chunkSize; + int[] offsets = new int[segmentsCount]; + for (int i = 0; i < segmentsCount; i++) { + segmentOffset -= segmentSize; + offsets[i] = segmentOffset; + } + return new IntStack(offsets); } @Override @@ -584,235 +737,59 @@ public int computeBufferCapacity( return Math.min(segmentSize, maxCapacity); } - @Override - public void initializeSharedStateIn(ChunkController chunkController) { - // NOOP - } - @Override public Chunk newChunkAllocation(int promptingSize, Magazine magazine) { AbstractByteBuf chunkBuffer = chunkAllocator.allocate(chunkSize, chunkSize); assert chunkBuffer.capacity() == chunkSize; - SizeClassedChunk chunk = new SizeClassedChunk(chunkBuffer, magazine, true, - segmentSize, segmentOffsets, FALSE_PREDICATE); + SizeClassedChunk chunk = new SizeClassedChunk(chunkBuffer, magazine, this); chunkRegistry.add(chunk); return chunk; } } - private static final class HistogramChunkControllerFactory implements ChunkControllerFactory { - private final boolean shareable; + private static final class BuddyChunkManagementStrategy implements ChunkManagementStrategy { + private final AtomicInteger maxChunkSize = new AtomicInteger(); - private HistogramChunkControllerFactory(boolean shareable) { - this.shareable = shareable; + @Override + public ChunkController createController(MagazineGroup group) { + return new BuddyChunkController(group, maxChunkSize); } @Override - public ChunkController create(MagazineGroup group) { - return new 
HistogramChunkController(group, shareable); + public ChunkCache createChunkCache(boolean isThreadLocal) { + return new ConcurrentSkipListChunkCache(); } } - private static final class HistogramChunkController implements ChunkController, ChunkReleasePredicate { - private static final int MIN_DATUM_TARGET = 1024; - private static final int MAX_DATUM_TARGET = 65534; - private static final int INIT_DATUM_TARGET = 9; - private static final int HISTO_BUCKET_COUNT = 16; - private static final int[] HISTO_BUCKETS = { - 16 * 1024, - 24 * 1024, - 32 * 1024, - 48 * 1024, - 64 * 1024, - 96 * 1024, - 128 * 1024, - 192 * 1024, - 256 * 1024, - 384 * 1024, - 512 * 1024, - 768 * 1024, - 1024 * 1024, - 1792 * 1024, - 2048 * 1024, - 3072 * 1024 - }; - - private final MagazineGroup group; - private final boolean shareable; - private final short[][] histos = { - new short[HISTO_BUCKET_COUNT], new short[HISTO_BUCKET_COUNT], - new short[HISTO_BUCKET_COUNT], new short[HISTO_BUCKET_COUNT], - }; + private static final class BuddyChunkController implements ChunkController { + private final ChunkAllocator chunkAllocator; private final ChunkRegistry chunkRegistry; - private short[] histo = histos[0]; - private final int[] sums = new int[HISTO_BUCKET_COUNT]; - - private int histoIndex; - private int datumCount; - private int datumTarget = INIT_DATUM_TARGET; - private boolean hasHadRotation; - private volatile int sharedPrefChunkSize = MIN_CHUNK_SIZE; - private volatile int localPrefChunkSize = MIN_CHUNK_SIZE; - private volatile int localUpperBufSize; - - private HistogramChunkController(MagazineGroup group, boolean shareable) { - this.group = group; - this.shareable = shareable; - chunkRegistry = group.allocator.chunkRegistry; - } - - @Override - public int computeBufferCapacity( - int requestedSize, int maxCapacity, boolean isReallocation) { - if (!isReallocation) { - // Only record allocation size if it's not caused by a reallocation that was triggered by capacity - // change of the buffer. 
- recordAllocationSize(requestedSize); - } + private final AtomicInteger maxChunkSize; - // Predict starting capacity from localUpperBufSize, but place limits on the max starting capacity - // based on the requested size, because localUpperBufSize can potentially be quite large. - int startCapLimits; - if (requestedSize <= 32768) { // Less than or equal to 32 KiB. - startCapLimits = 65536; // Use at most 64 KiB, which is also the AdaptiveRecvByteBufAllocator max. - } else { - startCapLimits = requestedSize * 2; // Otherwise use at most twice the requested memory. - } - int startingCapacity = Math.min(startCapLimits, localUpperBufSize); - startingCapacity = Math.max(requestedSize, Math.min(maxCapacity, startingCapacity)); - return startingCapacity; - } - - private void recordAllocationSize(int bufferSizeToRecord) { - // Use the preserved size from the reused AdaptiveByteBuf, if available. - // Otherwise, use the requested buffer size. - // This way, we better take into account - if (bufferSizeToRecord == 0) { - return; - } - int bucket = sizeToBucket(bufferSizeToRecord); - histo[bucket]++; - if (datumCount++ == datumTarget) { - rotateHistograms(); - } - } - - static int sizeToBucket(int size) { - int index = binarySearchInsertionPoint(Arrays.binarySearch(HISTO_BUCKETS, size)); - return index >= HISTO_BUCKETS.length ? 
HISTO_BUCKETS.length - 1 : index; - } - - private static int binarySearchInsertionPoint(int index) { - if (index < 0) { - index = -(index + 1); - } - return index; - } - - static int bucketToSize(int sizeBucket) { - return HISTO_BUCKETS[sizeBucket]; - } - - private void rotateHistograms() { - short[][] hs = histos; - for (int i = 0; i < HISTO_BUCKET_COUNT; i++) { - sums[i] = (hs[0][i] & 0xFFFF) + (hs[1][i] & 0xFFFF) + (hs[2][i] & 0xFFFF) + (hs[3][i] & 0xFFFF); - } - int sum = 0; - for (int count : sums) { - sum += count; - } - int targetPercentile = (int) (sum * 0.99); - int sizeBucket = 0; - for (; sizeBucket < sums.length; sizeBucket++) { - if (sums[sizeBucket] > targetPercentile) { - break; - } - targetPercentile -= sums[sizeBucket]; - } - hasHadRotation = true; - int percentileSize = bucketToSize(sizeBucket); - int prefChunkSize = Math.max(percentileSize * BUFS_PER_CHUNK, MIN_CHUNK_SIZE); - localUpperBufSize = percentileSize; - localPrefChunkSize = prefChunkSize; - if (shareable) { - for (Magazine mag : group.magazines) { - HistogramChunkController statistics = (HistogramChunkController) mag.chunkController; - prefChunkSize = Math.max(prefChunkSize, statistics.localPrefChunkSize); - } - } - if (sharedPrefChunkSize != prefChunkSize) { - // Preferred chunk size changed. Increase check frequency. - datumTarget = Math.max(datumTarget >> 1, MIN_DATUM_TARGET); - sharedPrefChunkSize = prefChunkSize; - } else { - // Preferred chunk size did not change. Check less often. - datumTarget = Math.min(datumTarget << 1, MAX_DATUM_TARGET); - } - - histoIndex = histoIndex + 1 & 3; - histo = histos[histoIndex]; - datumCount = 0; - Arrays.fill(histo, (short) 0); - } - - /** - * Get the preferred chunk size, based on statistics from the {@linkplain #recordAllocationSize(int) recorded} - * allocation sizes. - *

- * This method must be thread-safe. - * - * @return The currently preferred chunk allocation size. - */ - int preferredChunkSize() { - return sharedPrefChunkSize; + BuddyChunkController(MagazineGroup group, AtomicInteger maxChunkSize) { + chunkAllocator = group.chunkAllocator; + chunkRegistry = group.allocator.chunkRegistry; + this.maxChunkSize = maxChunkSize; } @Override - public void initializeSharedStateIn(ChunkController chunkController) { - HistogramChunkController statistics = (HistogramChunkController) chunkController; - int sharedPrefChunkSize = this.sharedPrefChunkSize; - statistics.localPrefChunkSize = sharedPrefChunkSize; - statistics.sharedPrefChunkSize = sharedPrefChunkSize; + public int computeBufferCapacity(int requestedSize, int maxCapacity, boolean isReallocation) { + return MathUtil.safeFindNextPositivePowerOfTwo(requestedSize); } @Override public Chunk newChunkAllocation(int promptingSize, Magazine magazine) { - int size = Math.max(promptingSize * BUFS_PER_CHUNK, preferredChunkSize()); - int minChunks = size / MIN_CHUNK_SIZE; - if (MIN_CHUNK_SIZE * minChunks < size) { - // Round up to nearest whole MIN_CHUNK_SIZE unit. The MIN_CHUNK_SIZE is an even multiple of many - // popular small page sizes, like 4k, 16k, and 64k, which makes it easier for the system allocator - // to manage the memory in terms of whole pages. This reduces memory fragmentation, - // but without the potentially high overhead that power-of-2 chunk sizes would bring. - size = MIN_CHUNK_SIZE * (1 + minChunks); - } - - // Limit chunks to the max size, even if the histogram suggests to go above it. - size = Math.min(size, MAX_CHUNK_SIZE); - - // If we haven't rotated the histogram yet, optimisticly record this chunk size as our preferred. 
- if (!hasHadRotation && sharedPrefChunkSize == MIN_CHUNK_SIZE) { - sharedPrefChunkSize = size; - } - - ChunkAllocator chunkAllocator = group.chunkAllocator; - Chunk chunk = new Chunk(chunkAllocator.allocate(size, size), magazine, true, this); + int maxChunkSize = this.maxChunkSize.get(); + int proposedChunkSize = MathUtil.safeFindNextPositivePowerOfTwo(BUFS_PER_CHUNK * promptingSize); + int chunkSize = Math.min(MAX_CHUNK_SIZE, Math.max(maxChunkSize, proposedChunkSize)); + if (chunkSize > maxChunkSize) { + // Update our stored max chunk size. It's fine that this is racy. + this.maxChunkSize.set(chunkSize); + } + BuddyChunk chunk = new BuddyChunk(chunkAllocator.allocate(chunkSize, chunkSize), magazine); chunkRegistry.add(chunk); return chunk; } - - @Override - public boolean shouldReleaseChunk(int chunkSize) { - int preferredSize = preferredChunkSize(); - int givenChunks = chunkSize / MIN_CHUNK_SIZE; - int preferredChunks = preferredSize / MIN_CHUNK_SIZE; - int deviation = Math.abs(givenChunks - preferredChunks); - - // Retire chunks with a 5% probability per unit of MIN_CHUNK_SIZE deviation from preference. 
- return deviation != 0 && - ThreadLocalRandom.current().nextDouble() * 20.0 < deviation; - } } @SuppressJava6Requirement(reason = "Guarded by version check") @@ -823,13 +800,31 @@ private static final class Magazine { } private static final Chunk MAGAZINE_FREED = new Chunk(); - private static final ObjectPool EVENT_LOOP_LOCAL_BUFFER_POOL = ObjectPool.newPool( - new ObjectPool.ObjectCreator() { - @Override - public AdaptiveByteBuf newObject(ObjectPool.Handle handle) { - return new AdaptiveByteBuf(handle); - } - }); + private static final class AdaptiveRecycler extends Recycler { + + private AdaptiveRecycler() { + } + + private AdaptiveRecycler(int maxCapacity) { + // doesn't use fast thread local, shared + super(maxCapacity); + } + + @Override + protected AdaptiveByteBuf newObject(final Handle handle) { + return new AdaptiveByteBuf((EnhancedHandle) handle); + } + + public static AdaptiveRecycler threadLocal() { + return new AdaptiveRecycler(); + } + + public static AdaptiveRecycler sharedWith(int maxCapacity) { + return new AdaptiveRecycler(maxCapacity); + } + } + + private static final AdaptiveRecycler EVENT_LOOP_LOCAL_BUFFER_POOL = AdaptiveRecycler.threadLocal(); private Chunk current; @SuppressWarnings("unused") // updated via NEXT_IN_LINE @@ -837,31 +832,20 @@ public AdaptiveByteBuf newObject(ObjectPool.Handle handle) { private final MagazineGroup group; private final ChunkController chunkController; private final StampedLock allocationLock; - private final Queue bufferQueue; - private final ObjectPool.Handle handle; - private final Queue sharedChunkQueue; + private final AdaptiveRecycler recycler; - Magazine(MagazineGroup group, boolean shareable, Queue sharedChunkQueue, - ChunkController chunkController) { + Magazine(MagazineGroup group, boolean shareable, ChunkController chunkController) { this.group = group; this.chunkController = chunkController; if (shareable) { // We only need the StampedLock if this Magazine will be shared across threads. 
allocationLock = new StampedLock(); - bufferQueue = PlatformDependent.newFixedMpmcQueue(MAGAZINE_BUFFER_QUEUE_CAPACITY); - handle = new ObjectPool.Handle() { - @Override - public void recycle(AdaptiveByteBuf self) { - bufferQueue.offer(self); - } - }; + recycler = AdaptiveRecycler.sharedWith(MAGAZINE_BUFFER_QUEUE_CAPACITY); } else { allocationLock = null; - bufferQueue = null; - handle = null; + recycler = null; } - this.sharedChunkQueue = sharedChunkQueue; } public boolean tryAllocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean reallocate) { @@ -890,7 +874,7 @@ private boolean allocateWithoutLock(int size, int maxCapacity, AdaptiveByteBuf b return false; } if (curr == null) { - curr = sharedChunkQueue.poll(); + curr = group.pollChunk(size); if (curr == null) { return false; } @@ -900,9 +884,10 @@ private boolean allocateWithoutLock(int size, int maxCapacity, AdaptiveByteBuf b int remainingCapacity = curr.remainingCapacity(); int startingCapacity = chunkController.computeBufferCapacity( size, maxCapacity, true /* never update stats as we don't hold the magazine lock */); - if (remainingCapacity >= size) { - curr.readInitInto(buf, size, Math.min(remainingCapacity, startingCapacity), maxCapacity); + if (remainingCapacity >= size && + curr.readInitInto(buf, size, Math.min(remainingCapacity, startingCapacity), maxCapacity)) { allocated = true; + remainingCapacity = curr.remainingCapacity(); } try { if (remainingCapacity >= RETIRE_CAPACITY) { @@ -921,33 +906,17 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean int startingCapacity = chunkController.computeBufferCapacity(size, maxCapacity, reallocate); Chunk curr = current; if (curr != null) { - // We have a Chunk that has some space left. 
+ boolean success = curr.readInitInto(buf, size, startingCapacity, maxCapacity); int remainingCapacity = curr.remainingCapacity(); - if (remainingCapacity > startingCapacity) { - curr.readInitInto(buf, size, startingCapacity, maxCapacity); - // We still have some bytes left that we can use for the next allocation, just early return. - return true; - } - - // At this point we know that this will be the last time current will be used, so directly set it to - // null and release it once we are done. - current = null; - if (remainingCapacity >= size) { - try { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); - return true; - } finally { - curr.releaseFromMagazine(); - } - } - - // Check if we either retain the chunk in the nextInLine cache or releasing it. - if (remainingCapacity < RETIRE_CAPACITY) { - curr.releaseFromMagazine(); - } else { - // See if it makes sense to transfer the Chunk to the nextInLine cache for later usage. - // This method will release curr if this is not the case + if (!success && remainingCapacity > 0) { + current = null; transferToNextInLineOrRelease(curr); + } else if (remainingCapacity == 0) { + current = null; + curr.releaseFromMagazine(); + } + if (success) { + return true; } } @@ -969,32 +938,28 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean } int remainingCapacity = curr.remainingCapacity(); - if (remainingCapacity > startingCapacity) { + if (remainingCapacity > startingCapacity && + curr.readInitInto(buf, size, startingCapacity, maxCapacity)) { // We have a Chunk that has some space left. - curr.readInitInto(buf, size, startingCapacity, maxCapacity); current = curr; return true; } - if (remainingCapacity >= size) { - // At this point we know that this will be the last time curr will be used, so directly set it to - // null and release it once we are done. 
- try { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); - return true; - } finally { - // Release in a finally block so even if readInitInto(...) would throw we would still correctly - // release the current chunk before null it out. - curr.releaseFromMagazine(); + try { + if (remainingCapacity >= size) { + // At this point we know that this will be the last time curr will be used, so directly set it + // to null and release it once we are done. + return curr.readInitInto(buf, size, remainingCapacity, maxCapacity); } - } else { - // Release it as it's too small. + } finally { + // Release in a finally block so even if readInitInto(...) would throw we would still correctly + // release the current chunk before null it out. curr.releaseFromMagazine(); } } // Now try to poll from the central queue first - curr = sharedChunkQueue.poll(); + curr = group.pollChunk(size); if (curr == null) { curr = chunkController.newChunkAllocation(size, this); } else { @@ -1015,14 +980,15 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean } current = curr; + boolean success; try { int remainingCapacity = curr.remainingCapacity(); assert remainingCapacity >= size; if (remainingCapacity > startingCapacity) { - curr.readInitInto(buf, size, startingCapacity, maxCapacity); + success = curr.readInitInto(buf, size, startingCapacity, maxCapacity); curr = null; } else { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); + success = curr.readInitInto(buf, size, remainingCapacity, maxCapacity); } } finally { if (curr != null) { @@ -1032,7 +998,7 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean current = null; } } - return true; + return success; } private void restoreMagazineFreed() { @@ -1063,10 +1029,6 @@ private void transferToNextInLineOrRelease(Chunk chunk) { chunk.releaseFromMagazine(); } - boolean trySetNextInLine(Chunk chunk) { - return NEXT_IN_LINE.compareAndSet(this, null, chunk); - } - void 
free() { // Release the current Chunk and the next that was stored for later usage. restoreMagazineFreed(); @@ -1084,26 +1046,15 @@ void free() { } public AdaptiveByteBuf newBuffer() { - AdaptiveByteBuf buf; - if (handle == null) { - buf = EVENT_LOOP_LOCAL_BUFFER_POOL.get(); - } else { - buf = bufferQueue.poll(); - if (buf == null) { - buf = new AdaptiveByteBuf(handle); - } - } + AdaptiveRecycler recycler = this.recycler; + AdaptiveByteBuf buf = recycler == null? EVENT_LOOP_LOCAL_BUFFER_POOL.get() : recycler.get(); buf.resetRefCnt(); buf.discardMarks(); return buf; } boolean offerToQueue(Chunk chunk) { - return group.offerToQueue(chunk); - } - - public void initializeSharedStateIn(Magazine other) { - chunkController.initializeSharedStateIn(other.chunkController); + return group.offerChunk(chunk); } } @@ -1133,9 +1084,7 @@ private static class Chunk implements ReferenceCounted { protected final AbstractByteBuf delegate; protected Magazine magazine; private final AdaptivePoolingAllocator allocator; - private final ChunkReleasePredicate chunkReleasePredicate; private final int capacity; - private final boolean pooled; protected int allocatedBytes; private static final ReferenceCountUpdater updater = @@ -1161,23 +1110,17 @@ protected long unsafeOffset() { delegate = null; magazine = null; allocator = null; - chunkReleasePredicate = null; capacity = 0; - pooled = false; } - Chunk(AbstractByteBuf delegate, Magazine magazine, boolean pooled, - ChunkReleasePredicate chunkReleasePredicate) { + Chunk(AbstractByteBuf delegate, Magazine magazine) { this.delegate = delegate; - this.pooled = pooled; capacity = delegate.capacity(); updater.setInitialValue(this); attachToMagazine(magazine); // We need the top-level allocator so ByteBuf.capacity(int) can call reallocate() allocator = magazine.group.allocator; - - this.chunkReleasePredicate = chunkReleasePredicate; } Magazine currentMagazine() { @@ -1241,46 +1184,33 @@ public boolean release(int decrement) { /** * Called when a 
magazine is done using this chunk, probably because it was emptied. */ - boolean releaseFromMagazine() { - return release(); + void releaseFromMagazine() { + // Chunks can be reused before they become empty. + // We can therefor put them in the shared queue as soon as the magazine is done with this chunk. + Magazine mag = magazine; + detachFromMagazine(); + if (!mag.offerToQueue(this)) { + markToDeallocate(); + } } /** * Called when a ByteBuf is done using its allocation in this chunk. */ - boolean releaseSegment(int ignoredSegmentId) { - return release(); + void releaseSegment(int ignoredSegmentId, int size) { + release(); } - private void deallocate() { - Magazine mag = magazine; - int chunkSize = delegate.capacity(); - if (!pooled || chunkReleasePredicate.shouldReleaseChunk(chunkSize) || mag == null) { - // Drop the chunk if the parent allocator is closed, - // or if the chunk deviates too much from the preferred chunk size. - detachFromMagazine(); - allocator.chunkRegistry.remove(this); - delegate.release(); - } else { - updater.resetRefCnt(this); - delegate.setIndex(0, 0); - allocatedBytes = 0; - if (!mag.trySetNextInLine(this)) { - // As this Chunk does not belong to the mag anymore we need to decrease the used memory . - detachFromMagazine(); - if (!mag.offerToQueue(this)) { - // The central queue is full. Ensure we release again as we previously did use resetRefCnt() - // which did increase the reference count by 1. 
- boolean released = updater.release(this); - allocator.chunkRegistry.remove(this); - delegate.release(); - assert released; - } - } - } + void markToDeallocate() { + release(); } - public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + protected void deallocate() { + allocator.chunkRegistry.remove(this); + delegate.release(); + } + + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { int startIndex = allocatedBytes; allocatedBytes = startIndex + startingCapacity; Chunk chunk = this; @@ -1297,101 +1227,408 @@ public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, in chunk.release(); } } + return true; } public int remainingCapacity() { return capacity - allocatedBytes; } + public boolean hasUnprocessedFreelistEntries() { + return false; + } + + public void processFreelistEntries() { + } + public int capacity() { return capacity; } } + private static final class IntStack { + + private final int[] stack; + private int top; + + IntStack(int[] initialValues) { + stack = initialValues; + top = initialValues.length - 1; + } + + public boolean isEmpty() { + return top == -1; + } + + public int pop() { + final int last = stack[top]; + top--; + return last; + } + + public void push(int value) { + stack[top + 1] = value; + top++; + } + + public int size() { + return top + 1; + } + } + + /** + * Removes per-allocation retain()/release() atomic ops from the hot path by replacing ref counting + * with a segment-count state machine. Atomics are only needed on the cold deallocation path + * ({@link #markToDeallocate()}), which is rare for long-lived chunks that cycle segments many times. + * The tradeoff is a {@link MpscIntQueue#size()} call (volatile reads, no RMW) per remaining segment + * return after mark — acceptable since it avoids atomic RMWs entirely. + *

+ * State transitions: + *

+ *

+ * Ordering: external {@link #releaseSegment} pushes to the MPSC queue (which has an implicit + * StoreLoad barrier via its {@code offer()}), then reads {@code state} — this guarantees + * visibility of any preceding {@link #markToDeallocate()} write. + */ private static final class SizeClassedChunk extends Chunk { private static final int FREE_LIST_EMPTY = -1; + private static final int AVAILABLE = -1; + // Integer.MIN_VALUE so that `DEALLOCATED + externalFreeList.size()` can never equal `segments`, + // making late-arriving releaseSegment calls on external threads arithmetically harmless. + private static final int DEALLOCATED = Integer.MIN_VALUE; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(SizeClassedChunk.class, "state"); + private volatile int state; + private final int segments; private final int segmentSize; - private final MpscIntQueue freeList; - - SizeClassedChunk(AbstractByteBuf delegate, Magazine magazine, boolean pooled, int segmentSize, - final int[] segmentOffsets, ChunkReleasePredicate shouldReleaseChunk) { - super(delegate, magazine, pooled, shouldReleaseChunk); - this.segmentSize = segmentSize; - int segmentCount = segmentOffsets.length; - assert delegate.capacity() / segmentSize == segmentCount; - assert segmentCount > 0: "Chunk must have a positive number of segments"; - freeList = new MpscAtomicIntegerArrayQueue(segmentCount, FREE_LIST_EMPTY); - freeList.fill(segmentCount, new IntSupplier() { - int counter; - @Override - public int get() { - return segmentOffsets[counter++]; - } - }); + private final MpscIntQueue externalFreeList; + private final IntStack localFreeList; + private Thread ownerThread; + + SizeClassedChunk(AbstractByteBuf delegate, Magazine magazine, + SizeClassChunkController controller) { + super(delegate, magazine); + segmentSize = controller.segmentSize; + segments = controller.chunkSize / segmentSize; + STATE.lazySet(this, AVAILABLE); + ownerThread = 
magazine.group.ownerThread; + if (ownerThread == null) { + externalFreeList = controller.createFreeList(); + localFreeList = null; + } else { + externalFreeList = controller.createEmptyFreeList(); + localFreeList = controller.createLocalFreeList(); + } } @Override - public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { - int startIndex = freeList.poll(); + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + assert state == AVAILABLE; + final int startIndex = nextAvailableSegmentOffset(); if (startIndex == FREE_LIST_EMPTY) { - throw new IllegalStateException("Free list is empty"); + return false; } allocatedBytes += segmentSize; + try { + buf.init(delegate, this, 0, 0, startIndex, size, startingCapacity, maxCapacity); + } catch (Throwable t) { + allocatedBytes -= segmentSize; + releaseSegmentOffsetIntoFreeList(startIndex); + PlatformDependent.throwException(t); + } + return true; + } + + private int nextAvailableSegmentOffset() { + final int startIndex; + IntStack localFreeList = this.localFreeList; + if (localFreeList != null) { + assert Thread.currentThread() == ownerThread; + if (localFreeList.isEmpty()) { + startIndex = externalFreeList.poll(); + } else { + startIndex = localFreeList.pop(); + } + } else { + startIndex = externalFreeList.poll(); + } + return startIndex; + } + + // this can be used by the ConcurrentQueueChunkCache to find the first buffer to use: + // it doesn't update the remaining capacity and it's not consider a single segmentSize + // case as not suitable to be reused + public boolean hasRemainingCapacity() { + int remaining = super.remainingCapacity(); + if (remaining > 0) { + return true; + } + if (localFreeList != null) { + return !localFreeList.isEmpty(); + } + return !externalFreeList.isEmpty(); + } + + @Override + public int remainingCapacity() { + int remaining = super.remainingCapacity(); + return remaining > segmentSize ? 
remaining : updateRemainingCapacity(remaining); + } + + private int updateRemainingCapacity(int snapshotted) { + int freeSegments = externalFreeList.size(); + IntStack localFreeList = this.localFreeList; + if (localFreeList != null) { + freeSegments += localFreeList.size(); + } + int updated = freeSegments * segmentSize; + if (updated != snapshotted) { + allocatedBytes = capacity() - updated; + } + return updated; + } + + private void releaseSegmentOffsetIntoFreeList(int startIndex) { + IntStack localFreeList = this.localFreeList; + if (localFreeList != null && Thread.currentThread() == ownerThread) { + localFreeList.push(startIndex); + } else { + boolean segmentReturned = externalFreeList.offer(startIndex); + assert segmentReturned : "Unable to return segment " + startIndex + " to free list"; + } + } + + @Override + void releaseSegment(int startIndex, int size) { + IntStack localFreeList = this.localFreeList; + if (localFreeList != null && Thread.currentThread() == ownerThread) { + localFreeList.push(startIndex); + int state = this.state; + if (state != AVAILABLE) { + updateStateOnLocalReleaseSegment(state, localFreeList); + } + } else { + boolean segmentReturned = externalFreeList.offer(startIndex); + assert segmentReturned; + // implicit StoreLoad barrier from MPSC offer() + int state = this.state; + if (state != AVAILABLE) { + deallocateIfNeeded(state); + } + } + } + + private void updateStateOnLocalReleaseSegment(int previousLocalSize, IntStack localFreeList) { + int newLocalSize = localFreeList.size(); + boolean alwaysTrue = STATE.compareAndSet(this, previousLocalSize, newLocalSize); + assert alwaysTrue : "this shouldn't happen unless double release in the local free list"; + deallocateIfNeeded(newLocalSize); + } + + private void deallocateIfNeeded(int localSize) { + // Check if all segments have been returned. 
+ int totalFreeSegments = localSize + externalFreeList.size(); + if (totalFreeSegments == segments && STATE.compareAndSet(this, localSize, DEALLOCATED)) { + deallocate(); + } + } + + @Override + void markToDeallocate() { + IntStack localFreeList = this.localFreeList; + int localSize = localFreeList != null ? localFreeList.size() : 0; + STATE.set(this, localSize); + deallocateIfNeeded(localSize); + } + } + + private static final class BuddyChunk extends Chunk implements IntConsumer { + private static final int MIN_BUDDY_SIZE = 32768; + private static final byte IS_CLAIMED = (byte) (1 << 7); + private static final byte HAS_CLAIMED_CHILDREN = 1 << 6; + private static final byte SHIFT_MASK = ~(IS_CLAIMED | HAS_CLAIMED_CHILDREN); + private static final int PACK_OFFSET_MASK = 0xFFFF; + private static final int PACK_SIZE_SHIFT = Integer.SIZE - Integer.numberOfLeadingZeros(PACK_OFFSET_MASK); + + private final MpscIntQueue freeList; + // The bits of each buddy: [1: is claimed][1: has claimed children][30: MIN_BUDDY_SIZE shift to get size] + private final byte[] buddies; + private final int freeListCapacity; + + BuddyChunk(AbstractByteBuf delegate, Magazine magazine) { + super(delegate, magazine); + freeListCapacity = delegate.capacity() / MIN_BUDDY_SIZE; + int maxShift = Integer.numberOfTrailingZeros(freeListCapacity); + assert maxShift <= 30; // The top 2 bits are used for marking. + // At most half of tree (all leaf nodes) can be freed. + freeList = new MpscAtomicIntegerArrayQueue(freeListCapacity, -1); + buddies = new byte[freeListCapacity << 1]; + + // Generate the buddies entries. 
+ int index = 1; + int runLength = 1; + int currentRun = 0; + while (maxShift > 0) { + buddies[index++] = (byte) maxShift; + if (++currentRun == runLength) { + currentRun = 0; + runLength <<= 1; + maxShift--; + } + } + } + + @Override + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + if (!freeList.isEmpty()) { + freeList.drain(freeListCapacity, this); + } + int startIndex = chooseFirstFreeBuddy(1, startingCapacity, 0); + if (startIndex == -1) { + return false; + } Chunk chunk = this; chunk.retain(); try { - buf.init(delegate, chunk, 0, 0, startIndex, size, startingCapacity, maxCapacity); + buf.init(delegate, this, 0, 0, startIndex, size, startingCapacity, maxCapacity); + allocatedBytes += startingCapacity; chunk = null; } finally { if (chunk != null) { + unreserveMatchingBuddy(1, startingCapacity, startIndex, 0); // If chunk is not null we know that buf.init(...) failed and so we need to manually release - // the chunk again as we retained it before calling buf.init(...). Beside this we also need to - // restore the old allocatedBytes value. - allocatedBytes -= segmentSize; - chunk.releaseSegment(startIndex); + // the chunk again as we retained it before calling buf.init(...). + chunk.release(); } } + return true; + } + + @Override + public void accept(int packed) { + // Called by allocating thread when draining freeList. 
+ int size = MIN_BUDDY_SIZE << (packed >> PACK_SIZE_SHIFT); + int offset = (packed & PACK_OFFSET_MASK) * MIN_BUDDY_SIZE; + unreserveMatchingBuddy(1, size, offset, 0); + allocatedBytes -= size; + } + + @Override + void releaseSegment(int startingIndex, int size) { + int packedOffset = startingIndex / MIN_BUDDY_SIZE; + int packedSize = Integer.numberOfTrailingZeros(size / MIN_BUDDY_SIZE) << PACK_SIZE_SHIFT; + int packed = packedOffset | packedSize; + freeList.offer(packed); + release(); } @Override public int remainingCapacity() { - int remainingCapacity = super.remainingCapacity(); - if (remainingCapacity > segmentSize) { - return remainingCapacity; + if (!freeList.isEmpty()) { + freeList.drain(freeListCapacity, this); } - int updatedRemainingCapacity = freeList.size() * segmentSize; - if (updatedRemainingCapacity == remainingCapacity) { - return remainingCapacity; - } - // update allocatedBytes based on what's available in the free list - allocatedBytes = capacity() - updatedRemainingCapacity; - return updatedRemainingCapacity; + return super.remainingCapacity(); } @Override - boolean releaseFromMagazine() { - // Size-classed chunks can be reused before they become empty. - // We can therefor put them in the shared queue as soon as the magazine is done with this chunk. - Magazine mag = magazine; - detachFromMagazine(); - if (!mag.offerToQueue(this)) { - return super.releaseFromMagazine(); + public boolean hasUnprocessedFreelistEntries() { + return !freeList.isEmpty(); + } + + @Override + public void processFreelistEntries() { + freeList.drain(freeListCapacity, this); + } + + /** + * Claim a suitable buddy and return its start offset into the delegate chunk, or return -1 if nothing claimed. 
+ */ + private int chooseFirstFreeBuddy(int index, int size, int currOffset) { + byte[] buddies = this.buddies; + while (index < buddies.length) { + byte buddy = buddies[index]; + int currValue = MIN_BUDDY_SIZE << (buddy & SHIFT_MASK); + if (currValue < size || (buddy & IS_CLAIMED) == IS_CLAIMED) { + return -1; + } + if (currValue == size && (buddy & HAS_CLAIMED_CHILDREN) == 0) { + buddies[index] |= IS_CLAIMED; + return currOffset; + } + int found = chooseFirstFreeBuddy(index << 1, size, currOffset); + if (found != -1) { + buddies[index] |= HAS_CLAIMED_CHILDREN; + return found; + } + index = (index << 1) + 1; + currOffset += currValue >> 1; // Bump offset to skip first half of this layer. } - return false; + return -1; + } + + /** + * Un-reserve the matching buddy and return whether there are any other child or sibling reservations. + */ + private boolean unreserveMatchingBuddy(int index, int size, int offset, int currOffset) { + byte[] buddies = this.buddies; + if (buddies.length <= index) { + return false; + } + byte buddy = buddies[index]; + int currSize = MIN_BUDDY_SIZE << (buddy & SHIFT_MASK); + + if (currSize == size) { + // We're at the right size level. + if (currOffset == offset) { + buddies[index] &= SHIFT_MASK; + return false; + } + throw new IllegalStateException("The intended segment was not found at index " + + index + ", for size " + size + " and offset " + offset); + } + + // We're at a parent size level. Use the target offset to guide our drill-down path. + boolean claims; + int siblingIndex; + if (offset < currOffset + (currSize >> 1)) { + // Must be down the left path. + claims = unreserveMatchingBuddy(index << 1, size, offset, currOffset); + siblingIndex = (index << 1) + 1; + } else { + // Must be down the right path. + claims = unreserveMatchingBuddy((index << 1) + 1, size, offset, currOffset + (currSize >> 1)); + siblingIndex = index << 1; + } + if (!claims) { + // No other claims down the path we took. Check if the sibling has claims. 
+ byte sibling = buddies[siblingIndex]; + if ((sibling & SHIFT_MASK) == sibling) { + // No claims in the sibling. We can clear this level as well. + buddies[index] &= SHIFT_MASK; + return false; + } + } + return true; } @Override - boolean releaseSegment(int startIndex) { - boolean released = release(); - boolean segmentReturned = freeList.offer(startIndex); - assert segmentReturned: "Unable to return segment " + startIndex + " to free list"; - return released; + public String toString() { + int capacity = delegate.capacity(); + int remaining = capacity - allocatedBytes; + return "BuddyChunk[capacity: " + capacity + + ", remaining: " + remaining + + ", free list: " + freeList.size() + ']'; } } static final class AdaptiveByteBuf extends AbstractReferenceCountedByteBuf { - private final ObjectPool.Handle handle; + private final EnhancedHandle handle; // this both act as adjustment and the start index for a free list segment allocation private int startIndex; @@ -1403,7 +1640,7 @@ static final class AdaptiveByteBuf extends AbstractReferenceCountedByteBuf { private boolean hasArray; private boolean hasMemoryAddress; - AdaptiveByteBuf(ObjectPool.Handle recyclerHandle) { + AdaptiveByteBuf(EnhancedHandle recyclerHandle) { super(0); handle = ObjectUtil.checkNotNull(recyclerHandle, "recyclerHandle"); } @@ -1442,12 +1679,11 @@ public int maxFastWritableBytes() { @Override public ByteBuf capacity(int newCapacity) { + checkNewCapacity(newCapacity); if (length <= newCapacity && newCapacity <= maxFastCapacity) { - ensureAccessible(); length = newCapacity; return this; } - checkNewCapacity(newCapacity); if (newCapacity < capacity()) { length = newCapacity; trimIndicesToCapacity(newCapacity); @@ -1460,11 +1696,14 @@ public ByteBuf capacity(int newCapacity) { int readerIndex = this.readerIndex; int writerIndex = this.writerIndex; int baseOldRootIndex = startIndex; - int oldCapacity = length; + int oldLength = length; + int oldCapacity = maxFastCapacity; AbstractByteBuf oldRoot = 
rootParent(); allocator.reallocate(newCapacity, maxCapacity(), this); - oldRoot.getBytes(baseOldRootIndex, this, 0, oldCapacity); - chunk.releaseSegment(baseOldRootIndex); + oldRoot.getBytes(baseOldRootIndex, this, 0, oldLength); + chunk.releaseSegment(baseOldRootIndex, oldCapacity); + assert oldCapacity < maxFastCapacity && newCapacity <= maxFastCapacity: + "Capacity increase failed"; this.readerIndex = readerIndex; this.writerIndex = writerIndex; return this; @@ -1475,6 +1714,7 @@ public ByteBufAllocator alloc() { return rootParent().alloc(); } + @SuppressWarnings("deprecation") @Override public ByteOrder order() { return rootParent().order(); @@ -1841,17 +2081,12 @@ private int idx(int index) { @Override protected void deallocate() { if (chunk != null) { - chunk.releaseSegment(startIndex); + chunk.releaseSegment(startIndex, maxFastCapacity); } tmpNioBuf = null; chunk = null; rootParent = null; - if (handle instanceof EnhancedHandle) { - EnhancedHandle enhancedHandle = (EnhancedHandle) handle; - enhancedHandle.unguardedRecycle(this); - } else { - handle.recycle(this); - } + handle.unguardedRecycle(this); } } diff --git a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java index 4ad86136888..4786724dc0b 100644 --- a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java @@ -2360,4 +2360,17 @@ private void shiftComps(int i, int count) { } componentCount = newSize; } + + /** + * Decreases the reference count by the specified {@code decrement} and deallocates this object if the reference + * count reaches at {@code 0}. At this point it will also decrement the reference count of each internal + * component by {@code 1}. 
+ * + * @param decrement the number by which the reference count should be decreased + * @return {@code true} if and only if the reference count became {@code 0} and this object has been deallocated + */ + @Override + public boolean release(final int decrement) { + return super.release(decrement); + } } diff --git a/buffer/src/main/java/io/netty/buffer/SizeClasses.java b/buffer/src/main/java/io/netty/buffer/SizeClasses.java index b42d455d5e6..d1fa1389855 100644 --- a/buffer/src/main/java/io/netty/buffer/SizeClasses.java +++ b/buffer/src/main/java/io/netty/buffer/SizeClasses.java @@ -107,7 +107,7 @@ final class SizeClasses implements SizeClassesMetric { private final int[] pageIdx2sizeTab; - // lookup table for sizeIdx <= smallMaxSizeIdx + // lookup table for sizeIdx < nSizes private final int[] sizeIdx2sizeTab; // lookup table used for size <= lookupMaxClass diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java index c32183fa707..a5f3675ba66 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java @@ -17,6 +17,7 @@ import io.netty.util.internal.PlatformDependent; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; @@ -26,6 +27,7 @@ import static org.assertj.core.api.Assumptions.assumeThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.abort; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -196,6 +198,27 @@ public void shouldReuseChunks() throws Exception { .isLessThan(8 * 1024 * 1024); } + 
@Test + public void testCapacityNotGreaterThanMaxCapacity() { + testCapacityNotGreaterThanMaxCapacity(true); + testCapacityNotGreaterThanMaxCapacity(false); + } + + private void testCapacityNotGreaterThanMaxCapacity(boolean preferDirect) { + final int maxSize = 100000; + final ByteBuf buf = newAllocator(preferDirect).newDirectBuffer(maxSize, maxSize); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + buf.capacity(maxSize + 1); + } + }); + } finally { + buf.release(); + } + } + protected long expectedUsedMemory(T allocator, int capacity) { return capacity; } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java index 58a4ae82e75..d8ff780f517 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java @@ -57,6 +57,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -74,7 +75,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -2290,7 +2290,7 @@ public void testToString() { } @Test - @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testToStringMultipleThreads() throws Throwable { buffer.clear(); buffer.writeBytes("Hello, World!".getBytes(CharsetUtil.ISO_8859_1)); @@ -2300,7 +2300,7 
@@ public void testToStringMultipleThreads() throws Throwable { static void testToStringMultipleThreads0(final ByteBuf buffer) throws Throwable { final String expected = buffer.toString(CharsetUtil.ISO_8859_1); - final AtomicInteger counter = new AtomicInteger(30000); + final CyclicBarrier startBarrier = new CyclicBarrier(10); final AtomicReference errorRef = new AtomicReference(); List threads = new ArrayList(); for (int i = 0; i < 10; i++) { @@ -2308,11 +2308,15 @@ static void testToStringMultipleThreads0(final ByteBuf buffer) throws Throwable @Override public void run() { try { - while (errorRef.get() == null && counter.decrementAndGet() > 0) { + startBarrier.await(10, TimeUnit.SECONDS); + int counter = 3000; + while (errorRef.get() == null && counter-- > 0) { assertEquals(expected, buffer.toString(CharsetUtil.ISO_8859_1)); } } catch (Throwable cause) { - errorRef.compareAndSet(null, cause); + if (!errorRef.compareAndSet(null, cause)) { + ThrowableUtil.addSuppressed(errorRef.get(), cause); + } } } }); @@ -2322,13 +2326,27 @@ public void run() { thread.start(); } - for (Thread thread : threads) { - thread.join(); - } + joinAllAndReportErrors(threads, errorRef); + } - Throwable error = errorRef.get(); - if (error != null) { - throw error; + private static void joinAllAndReportErrors(List threads, AtomicReference errorRef) + throws Throwable { + try { + for (Thread thread : threads) { + thread.join(); + } + + Throwable error = errorRef.get(); + if (error != null) { + throw error; + } + } catch (Throwable e) { + for (Thread thread : threads) { + if (thread.isAlive()) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(thread, e); + } + } + throw e; } } @@ -2345,7 +2363,7 @@ public void testCopyMultipleThreads0() throws Throwable { static void testCopyMultipleThreads0(final ByteBuf buffer) throws Throwable { final ByteBuf expected = buffer.copy(); try { - final AtomicInteger counter = new AtomicInteger(30000); + final CyclicBarrier startBarrier = new 
CyclicBarrier(10); final AtomicReference errorRef = new AtomicReference(); List threads = new ArrayList(); for (int i = 0; i < 10; i++) { @@ -2353,7 +2371,9 @@ static void testCopyMultipleThreads0(final ByteBuf buffer) throws Throwable { @Override public void run() { try { - while (errorRef.get() == null && counter.decrementAndGet() > 0) { + startBarrier.await(10, TimeUnit.SECONDS); + int counter = 3000; + while (errorRef.get() == null && counter-- > 0) { ByteBuf copy = buffer.copy(); try { assertEquals(expected, copy); @@ -2372,14 +2392,7 @@ public void run() { thread.start(); } - for (Thread thread : threads) { - thread.join(); - } - - Throwable error = errorRef.get(); - if (error != null) { - throw error; - } + joinAllAndReportErrors(threads, errorRef); } finally { expected.release(); } @@ -2879,43 +2892,54 @@ public void testSliceBytesInArrayMultipleThreads() throws Exception { static void testBytesInArrayMultipleThreads( final ByteBuf buffer, final byte[] expectedBytes, final boolean slice) throws Exception { - final AtomicReference cause = new AtomicReference(); - final CountDownLatch latch = new CountDownLatch(60000); - final CyclicBarrier barrier = new CyclicBarrier(11); - for (int i = 0; i < 10; i++) { - new Thread(new Runnable() { - @Override - public void run() { - while (cause.get() == null && latch.getCount() > 0) { - ByteBuf buf; - if (slice) { - buf = buffer.slice(); - } else { - buf = buffer.duplicate(); - } - - byte[] array = new byte[8]; - buf.readBytes(array); + final CyclicBarrier startBarrier = new CyclicBarrier(10); + final CyclicBarrier endBarrier = new CyclicBarrier(11); + Callable callable = new Callable() { + @Override + public Void call() throws Exception { + startBarrier.await(); + for (int i = 0; i < 6000; i++) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } - assertArrayEquals(expectedBytes, array); + byte[] array = new byte[8]; + buf.readBytes(array); - Arrays.fill(array, (byte) 0); - 
buf.getBytes(0, array); - assertArrayEquals(expectedBytes, array); + assertArrayEquals(expectedBytes, array); - latch.countDown(); - } - try { - barrier.await(); - } catch (Exception e) { - // ignore - } + Arrays.fill(array, (byte) 0); + buf.getBytes(0, array); + assertArrayEquals(expectedBytes, array); } - }).start(); + endBarrier.await(); + return null; + } + }; + List> tasks = new ArrayList>(); + for (int i = 0; i < 10; i++) { + FutureTask task = new FutureTask(callable); + new Thread(task).start(); + tasks.add(task); + } + try { + endBarrier.await(30, TimeUnit.SECONDS); + } catch (Exception e) { + for (FutureTask task : tasks) { + try { + task.get(100, TimeUnit.MILLISECONDS); + } catch (Exception ex) { + e.addSuppressed(ex); + } + } + throw e; + } + for (FutureTask task : tasks) { + task.get(1, TimeUnit.SECONDS); } - latch.await(10, TimeUnit.SECONDS); - barrier.await(5, TimeUnit.SECONDS); - assertNull(cause.get()); } public static Object[][] setCharSequenceCombinations() { diff --git a/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java index 448930a3189..4c212410d88 100644 --- a/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java @@ -16,10 +16,17 @@ package io.netty.buffer; import io.netty.util.NettyRuntime; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.RepetitionInfo; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + +import java.lang.reflect.Array; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.SplittableRandom; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicReference; @@ -111,24 +118,29 @@ public void testUsedHeapMemory() { @Test void 
adaptiveChunkMustDeallocateOrReuseWthBufferRelease() throws Exception { AdaptiveByteBufAllocator allocator = newAllocator(false); - ByteBuf a = allocator.heapBuffer(28 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - ByteBuf b = allocator.heapBuffer(100 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - b.release(); - a.release(); - assertEquals(262144, allocator.usedHeapMemory()); - a = allocator.heapBuffer(28 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - b = allocator.heapBuffer(100 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - a.release(); - ByteBuf c = allocator.heapBuffer(28 * 1024); - assertEquals(2 * 262144, allocator.usedHeapMemory()); - c.release(); - assertEquals(2 * 262144, allocator.usedHeapMemory()); - b.release(); - assertEquals(2 * 262144, allocator.usedHeapMemory()); + Deque bufs = new ArrayDeque(); + assertEquals(0, allocator.usedHeapMemory()); + assertEquals(0, allocator.usedHeapMemory()); + bufs.add(allocator.heapBuffer(256)); + long usedHeapMemory = allocator.usedHeapMemory(); + int buffersPerChunk = Math.toIntExact(usedHeapMemory / 256); + for (int i = 0; i < buffersPerChunk; i++) { + bufs.add(allocator.heapBuffer(256)); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + bufs.pop().release(); + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + while (!bufs.isEmpty()) { + bufs.pop().release(); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + for (int i = 0; i < 2 * buffersPerChunk; i++) { + bufs.add(allocator.heapBuffer(256)); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + while (!bufs.isEmpty()) { + bufs.pop().release(); + } } @ParameterizedTest @@ -198,4 +210,71 @@ public void run() { fail("Expected no exception, but got", throwable); } } + + @RepeatedTest(100) + void buddyAllocationConsistency(RepetitionInfo info) { + SplittableRandom rng = new SplittableRandom(info.getCurrentRepetition()); + 
AdaptiveByteBufAllocator allocator = newAllocator(true); + int small = 32768; + int large = 2 * small; + int xlarge = 2 * large; + + int[] allocationSizes = { + small, small, small, small, small, small, small, small, + large, large, large, large, + xlarge, xlarge, + }; + + shuffle(rng, allocationSizes); + + ByteBuf[] bufs = new ByteBuf[allocationSizes.length]; + for (int i = 0; i < bufs.length; i++) { + bufs[i] = allocator.buffer(allocationSizes[i], allocationSizes[i]); + } + + shuffle(rng, bufs); + + int[] reallocations = new int[bufs.length / 2]; + for (int i = 0; i < reallocations.length; i++) { + reallocations[i] = bufs[i].capacity(); + bufs[i].release(); + bufs[i] = null; + } + for (int i = 0; i < reallocations.length; i++) { + assertNull(bufs[i]); + bufs[i] = allocator.buffer(reallocations[i], reallocations[i]); + } + + for (int i = 0; i < bufs.length; i++) { + while (bufs[i].isWritable()) { + bufs[i].writeByte(i + 1); + } + } + try { + for (int i = 0; i < bufs.length; i++) { + while (bufs[i].isReadable()) { + int b = Byte.toUnsignedInt(bufs[i].readByte()); + if (b != i + 1) { + fail("Expected byte " + (i + 1) + + " at index " + (bufs[i].readerIndex() - 1) + + " but got " + b); + } + } + } + } finally { + for (ByteBuf buf : bufs) { + buf.release(); + } + } + } + + private static void shuffle(SplittableRandom rng, Object array) { + int len = Array.getLength(array); + for (int i = 0; i < len; i++) { + int n = rng.nextInt(i, len); + Object value = Array.get(array, i); + Array.set(array, i, Array.get(array, n)); + Array.set(array, n, value); + } + } } diff --git a/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java index ab47050c641..4a4c28deebf 100644 --- a/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java @@ -15,52 +15,11 @@ */ package io.netty.buffer; -import 
org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.util.function.Supplier; - import static org.junit.jupiter.api.Assertions.assertEquals; -class AdaptivePoolingAllocatorTest implements Supplier { - private int i; - - @BeforeEach - void setUp() { - i = 0; - } - - @Override - public String get() { - return "i = " + i; - } - - @Test - void sizeBucketComputations() throws Exception { - assertSizeBucket(0, 16 * 1024); - assertSizeBucket(1, 24 * 1024); - assertSizeBucket(2, 32 * 1024); - assertSizeBucket(3, 48 * 1024); - assertSizeBucket(4, 64 * 1024); - assertSizeBucket(5, 96 * 1024); - assertSizeBucket(6, 128 * 1024); - assertSizeBucket(7, 192 * 1024); - assertSizeBucket(8, 256 * 1024); - assertSizeBucket(9, 384 * 1024); - assertSizeBucket(10, 512 * 1024); - assertSizeBucket(11, 768 * 1024); - assertSizeBucket(12, 1024 * 1024); - assertSizeBucket(13, 1792 * 1024); - assertSizeBucket(14, 2048 * 1024); - assertSizeBucket(15, 3072 * 1024); - // The sizeBucket function will be used for sizes up to 8 MiB - assertSizeBucket(15, 4 * 1024 * 1024); - assertSizeBucket(15, 5 * 1024 * 1024); - assertSizeBucket(15, 6 * 1024 * 1024); - assertSizeBucket(15, 7 * 1024 * 1024); - assertSizeBucket(15, 8 * 1024 * 1024); - } - +class AdaptivePoolingAllocatorTest { @Test void sizeClassComputations() throws Exception { final int[] sizeClasses = AdaptivePoolingAllocator.getSizeClasses(); @@ -75,20 +34,7 @@ void sizeClassComputations() throws Exception { private static void assertSizeClassOf(int expectedSizeClass, int previousSizeIncluded, int maxSizeIncluded) { for (int size = previousSizeIncluded; size <= maxSizeIncluded; size++) { - final int sizeToTest = size; - Supplier messageSupplier = new Supplier() { - @Override - public String get() { - return "size = " + sizeToTest; - } - }; - assertEquals(expectedSizeClass, AdaptivePoolingAllocator.sizeClassIndexOf(size), messageSupplier); - } - } - - private void assertSizeBucket(int expectedSizeBucket, int 
maxSizeIncluded) { - for (; i <= maxSizeIncluded; i++) { - assertEquals(expectedSizeBucket, AdaptivePoolingAllocator.sizeToBucket(i), this); + assertEquals(expectedSizeClass, AdaptivePoolingAllocator.sizeClassIndexOf(size), "size = " + size); } } } diff --git a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java index ecc01065210..64638f8e1cb 100644 --- a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java @@ -20,6 +20,7 @@ import io.netty.util.concurrent.FastThreadLocalThread; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThrowableUtil; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -30,6 +31,7 @@ import java.util.Random; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.FutureTask; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -349,13 +351,13 @@ public void testAllocateSmallOffset() { } @Test - @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + @Timeout(value = 20, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public void testThreadCacheDestroyedByThreadCleaner() throws InterruptedException { testThreadCacheDestroyed(false); } @Test - @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + @Timeout(value = 20, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public void testThreadCacheDestroyedAfterExitRun() throws InterruptedException { testThreadCacheDestroyed(true); } @@ -408,7 +410,6 @@ public void run() { while (allocator.metric().numThreadLocalCaches() > 0) { // Signal we want to have a GC run to ensure we can process our ThreadCleanerReference System.gc(); - 
System.runFinalization(); LockSupport.parkNanos(MILLISECONDS.toNanos(100)); } @@ -416,8 +417,8 @@ public void run() { } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testNumThreadCachesWithNoDirectArenas() throws InterruptedException { + @Timeout(10) + public void testNumThreadCachesWithNoDirectArenas() throws Exception { int numHeapArenas = 1; final PooledByteBufAllocator allocator = new PooledByteBufAllocator(numHeapArenas, 0, 8192, 1); @@ -436,11 +437,11 @@ public void testNumThreadCachesWithNoDirectArenas() throws InterruptedException } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testNumThreadCachesAccountForDirectAndHeapArenas() throws InterruptedException { - int numHeapArenas = 1; + @Timeout(10) + public void testNumThreadCachesAccountForDirectAndHeapArenas() throws Exception { + int numArenas = 1; final PooledByteBufAllocator allocator = - new PooledByteBufAllocator(numHeapArenas, 0, 8192, 1); + new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); ThreadCache tcache0 = createNewThreadCache(allocator, false); assertEquals(1, allocator.metric().numThreadLocalCaches()); @@ -456,8 +457,8 @@ public void testNumThreadCachesAccountForDirectAndHeapArenas() throws Interrupte } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testThreadCacheToArenaMappings() throws InterruptedException { + @Timeout(10) + public void testThreadCacheToArenaMappings() throws Exception { int numArenas = 2; final PooledByteBufAllocator allocator = new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); @@ -500,8 +501,7 @@ private static ThreadCache createNewThreadCache(final PooledByteBufAllocator all throws InterruptedException { final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch cacheLatch = new CountDownLatch(1); - final Thread t = new FastThreadLocalThread(new Runnable() { - + final FutureTask task = new FutureTask(new Runnable() { @Override public void run() { final ByteBuf buf; @@ -527,23 
+527,35 @@ public void run() { FastThreadLocal.removeAll(); } - }); + }, null); + final Thread t = new FastThreadLocalThread(task); t.start(); // Wait until we allocated a buffer and so be sure the thread was started and the cache exists. - cacheLatch.await(); + try { + cacheLatch.await(); + } catch (InterruptedException e) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(t, e); + throw e; + } return new ThreadCache() { @Override - public void destroy() throws InterruptedException { + public void destroy() throws Exception { latch.countDown(); - t.join(); + try { + task.get(); + t.join(); + } catch (InterruptedException e) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(t, e); + throw e; + } } }; } private interface ThreadCache { - void destroy() throws InterruptedException; + void destroy() throws Exception; } @Test diff --git a/buffer/src/test/java/io/netty/buffer/UnpooledTest.java b/buffer/src/test/java/io/netty/buffer/UnpooledTest.java index efc1dafd1ed..fd705f1bed0 100644 --- a/buffer/src/test/java/io/netty/buffer/UnpooledTest.java +++ b/buffer/src/test/java/io/netty/buffer/UnpooledTest.java @@ -476,7 +476,7 @@ public void testUnmodifiableBuffer() throws Exception { } catch (UnsupportedOperationException e) { // Expected } - Mockito.verifyZeroInteractions(inputStream); + Mockito.verifyNoInteractions(inputStream); ScatteringByteChannel scatteringByteChannel = Mockito.mock(ScatteringByteChannel.class); try { @@ -485,7 +485,7 @@ public void testUnmodifiableBuffer() throws Exception { } catch (UnsupportedOperationException e) { // Expected } - Mockito.verifyZeroInteractions(scatteringByteChannel); + Mockito.verifyNoInteractions(scatteringByteChannel); buf.release(); } diff --git a/codec-dns/pom.xml b/codec-dns/pom.xml index 39846136332..00ed2a59064 100644 --- a/codec-dns/pom.xml +++ b/codec-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-dns diff --git 
a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java index 2aea39159fe..80cf862ab6a 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java @@ -16,6 +16,8 @@ package io.netty.handler.codec.dns; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.CorruptedFrameException; /** * The default {@link DnsRecordDecoder} implementation. @@ -99,6 +101,30 @@ protected DnsRecord decodeRecord( DnsCodecUtil.decompressDomainName( in.duplicate().setIndex(offset, offset + length))); } + if (type == DnsRecordType.MX) { + // MX RDATA: 16-bit preference + exchange (domain name, possibly compressed) + if (length < 3) { + throw new CorruptedFrameException("MX record RDATA is too short: " + length); + } + final int pref = in.getUnsignedShort(offset); + ByteBuf exchange = null; + try { + exchange = DnsCodecUtil.decompressDomainName( + in.duplicate().setIndex(offset + 2, offset + length)); + + // Build decompressed RDATA = [preference][expanded exchange name] + final ByteBuf out = in.alloc().buffer(2 + exchange.readableBytes()); + out.writeShort(pref); + out.writeBytes(exchange); + + return new DefaultDnsRawRecord(name, type, dnsClass, timeToLive, out); + } finally { + if (exchange != null) { + exchange.release(); + } + } + } + return new DefaultDnsRawRecord( name, type, dnsClass, timeToLive, in.retainedDuplicate().setIndex(offset, offset + length)); } diff --git a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java index a8379f6d8d7..d66b994b604 100644 --- a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java +++ 
b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java @@ -166,6 +166,51 @@ public void testDecodeCompressionRDataPointer() throws Exception { } } + @Test + public void testDecodeCompressionRDataPointerMX() throws Exception { + DefaultDnsRecordDecoder decoder = new DefaultDnsRecordDecoder(); + byte[] compressionPointer = { + 5, 'n', 'e', 't', 't', 'y', 2, 'i', 'o', 0, + 0, 10, // preference = 10 + (byte) 0xC0, 0 // record is a pointer to netty.io + }; + + byte[] expected = { + 0, 10, // pref = 10 + 5, 'n', 'e', 't', 't', 'y', 2, 'i', 'o', 0 + }; + ByteBuf buffer = Unpooled.wrappedBuffer(compressionPointer); + DefaultDnsRawRecord mxRecord = null; + ByteBuf expectedBuf = null; + try { + mxRecord = (DefaultDnsRawRecord) decoder.decodeRecord( + "mail.example.com", + DnsRecordType.MX, + DnsRecord.CLASS_IN, + 60, + buffer, + 10, + 4); + + expectedBuf = Unpooled.wrappedBuffer(expected); + + assertEquals(0, ByteBufUtil.compare(expectedBuf, mxRecord.content()), + "The rdata of MX-type record should be decompressed in advance"); + assertEquals(10, mxRecord.content().getUnsignedShort(0)); + + ByteBuf exchangerName = mxRecord.content().duplicate().setIndex(2, mxRecord.content().writerIndex()); + assertEquals("netty.io.", DnsCodecUtil.decodeDomainName(exchangerName)); + } finally { + buffer.release(); + if (expectedBuf != null) { + expectedBuf.release(); + } + if (mxRecord != null) { + mxRecord.release(); + } + } + } + @Test public void testDecodeMessageCompression() throws Exception { // See https://www.ietf.org/rfc/rfc1035 [4.1.4. 
Message compression] diff --git a/codec-haproxy/pom.xml b/codec-haproxy/pom.xml index b7a0e7202a1..139ef90e84d 100644 --- a/codec-haproxy/pom.xml +++ b/codec-haproxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-haproxy diff --git a/codec-http/pom.xml b/codec-http/pom.xml index 9d37046e74e..0c6b82ac520 100644 --- a/codec-http/pom.xml +++ b/codec-http/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-http diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java index 3cd8d0c6985..a4762516846 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java @@ -92,7 +92,15 @@ public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String */ public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader) { - super(httpVersion, method, uri, headers); + this(httpVersion, method, uri, content, headers, trailingHeader, true); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, contents, and header and trailer objects. 
+ */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader, boolean validateRequestLine) { + super(httpVersion, method, uri, headers, validateRequestLine); this.content = checkNotNull(content, "content"); this.trailingHeader = checkNotNull(trailingHeader, "trailingHeader"); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java index 271b6069a02..437598503e6 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java @@ -75,9 +75,25 @@ public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri * @param headers the Headers for this Request */ public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, HttpHeaders headers) { + this(httpVersion, method, uri, headers, true); + } + + /** + * Creates a new instance. 
+ * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + * @param headers the Headers for this Request + */ + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, HttpHeaders headers, + boolean validateRequestLine) { super(httpVersion, headers); this.method = checkNotNull(method, "method"); this.uri = checkNotNull(uri, "uri"); + if (validateRequestLine) { + HttpUtil.validateRequestLineTokens(method, uri); + } } @Override diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java new file mode 100644 index 00000000000..6839ce8d8db --- /dev/null +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java @@ -0,0 +1,170 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.ByteProcessor; + +import java.util.BitSet; + +/** + * Validates the chunk start line. That is, the chunk size and chunk extensions, until the CR LF pair. + * See RFC 9112 section 7.1. + * + *

{@code
+ *   chunked-body   = *chunk
+ *                    last-chunk
+ *                    trailer-section
+ *                    CRLF
+ *
+ *   chunk          = chunk-size [ chunk-ext ] CRLF
+ *                    chunk-data CRLF
+ *   chunk-size     = 1*HEXDIG
+ *   last-chunk     = 1*("0") [ chunk-ext ] CRLF
+ *
+ *   chunk-data     = 1*OCTET ; a sequence of chunk-size octets
+ *   chunk-ext      = *( BWS ";" BWS chunk-ext-name
+ *                       [ BWS "=" BWS chunk-ext-val ] )
+ *
+ *   chunk-ext-name = token
+ *   chunk-ext-val  = token / quoted-string
+ *   quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
+ *   qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
+ *   quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
+ *   obs-text       = %x80-FF
+ *   OWS            = *( SP / HTAB )
+ *                  ; optional whitespace
+ *   BWS            = OWS
+ *                  ; "bad" whitespace
+ *   VCHAR          =  %x21-7E
+ *                  ; visible (printing) characters
+ * }
+ */ +final class HttpChunkLineValidatingByteProcessor implements ByteProcessor { + private static final int SIZE = 0; + private static final int CHUNK_EXT_NAME = 1; + private static final int CHUNK_EXT_VAL_START = 2; + private static final int CHUNK_EXT_VAL_QUOTED = 3; + private static final int CHUNK_EXT_VAL_QUOTED_ESCAPE = 4; + private static final int CHUNK_EXT_VAL_QUOTED_END = 5; + private static final int CHUNK_EXT_VAL_TOKEN = 6; + + static final class Match extends BitSet { + private static final long serialVersionUID = 49522994383099834L; + private final int then; + + Match(int then) { + super(256); + this.then = then; + } + + Match chars(String chars) { + return chars(chars, true); + } + + Match chars(String chars, boolean value) { + for (int i = 0, len = chars.length(); i < len; i++) { + set(chars.charAt(i), value); + } + return this; + } + + Match range(int from, int to) { + return range(from, to, true); + } + + Match range(int from, int to, boolean value) { + for (int i = from; i <= to; i++) { + set(i, value); + } + return this; + } + } + + private enum State { + Size( + new Match(SIZE).chars("0123456789abcdefABCDEF \t"), + new Match(CHUNK_EXT_NAME).chars(";")), + ChunkExtName( + new Match(CHUNK_EXT_NAME) + .range(0x21, 0x7E) + .chars(" \t") + .chars("(),/:<=>?@[\\]{}", false), + new Match(CHUNK_EXT_VAL_START).chars("=")), + ChunkExtValStart( + new Match(CHUNK_EXT_VAL_START).chars(" \t"), + new Match(CHUNK_EXT_VAL_QUOTED).chars("\""), + new Match(CHUNK_EXT_VAL_TOKEN) + .range(0x21, 0x7E) + .chars("(),/:<=>?@[\\]{}", false)), + ChunkExtValQuoted( + new Match(CHUNK_EXT_VAL_QUOTED_ESCAPE).chars("\\"), + new Match(CHUNK_EXT_VAL_QUOTED_END).chars("\""), + new Match(CHUNK_EXT_VAL_QUOTED) + .chars("\t !") + .range(0x23, 0x5B) + .range(0x5D, 0x7E) + .range(0x80, 0xFF)), + ChunkExtValQuotedEscape( + new Match(CHUNK_EXT_VAL_QUOTED) + .chars("\t ") + .range(0x21, 0x7E) + .range(0x80, 0xFF)), + ChunkExtValQuotedEnd( + new Match(CHUNK_EXT_VAL_QUOTED_END).chars("\t 
"), + new Match(CHUNK_EXT_NAME).chars(";")), + ChunkExtValToken( + new Match(CHUNK_EXT_VAL_TOKEN) + .range(0x21, 0x7E, true) + .chars("(),/:<=>?@[\\]{}", false), + new Match(CHUNK_EXT_NAME).chars(";")), + ; + + private final Match[] matches; + + State(Match... matches) { + this.matches = matches; + } + + State match(byte value) { + for (Match match : matches) { + if (match.get(value)) { + return STATES_BY_ORDINAL[match.then]; + } + } + if (this == Size) { + throw new NumberFormatException("Invalid chunk size"); + } else { + throw new InvalidChunkExtensionException("Invalid chunk extension"); + } + } + } + + private static final State[] STATES_BY_ORDINAL = State.values(); + + private State state = State.Size; + + @Override + public boolean process(byte value) { + state = state.match(value); + return true; + } + + public void finish() { + if (state != State.Size && state != State.ChunkExtName && state != State.ChunkExtValQuotedEnd) { + throw new InvalidChunkExtensionException("Invalid chunk extension"); + } + } +} diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java index 6b09c1614b3..70a52848c2d 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java @@ -27,14 +27,34 @@ public final class HttpHeaderValues { */ public static final AsciiString APPLICATION_JSON = AsciiString.cached("application/json"); /** - * {@code "application/x-www-form-urlencoded"} + * {@code "application/manifest+json"} */ - public static final AsciiString APPLICATION_X_WWW_FORM_URLENCODED = - AsciiString.cached("application/x-www-form-urlencoded"); + public static final AsciiString APPLICATION_MANIFEST_JSON = AsciiString.cached("application/manifest+json"); /** * {@code "application/octet-stream"} */ public static final AsciiString APPLICATION_OCTET_STREAM = 
AsciiString.cached("application/octet-stream"); + /** + * {@code "application/ogg"} + */ + public static final AsciiString APPLICATION_OGG = AsciiString.cached("application/ogg"); + /** + * {@code "application/pdf"} + */ + public static final AsciiString APPLICATION_PDF = AsciiString.cached("application/pdf"); + /** + * {@code "application/rtf"} + */ + public static final AsciiString APPLICATION_RTF = AsciiString.cached("application/rtf"); + /** + * {@code "application/wasm"} + */ + public static final AsciiString APPLICATION_WASM = AsciiString.cached("application/wasm"); + /** + * {@code "application/x-www-form-urlencoded"} + */ + public static final AsciiString APPLICATION_X_WWW_FORM_URLENCODED = + AsciiString.cached("application/x-www-form-urlencoded"); /** * {@code "application/xhtml+xml"} */ @@ -52,6 +72,34 @@ public final class HttpHeaderValues { * See {@link HttpHeaderNames#CONTENT_DISPOSITION} */ public static final AsciiString ATTACHMENT = AsciiString.cached("attachment"); + /** + * {@code "audio/aac"} + */ + public static final AsciiString AUDIO_AAC = AsciiString.cached("audio/aac"); + /** + * {@code "audio/midi"} + */ + public static final AsciiString AUDIO_MIDI = AsciiString.cached("audio/midi"); + /** + * {@code "audio/x-midi"} + */ + public static final AsciiString AUDIO_X_MIDI = AsciiString.cached("audio/x-midi"); + /** + * {@code "audio/mpeg"} + */ + public static final AsciiString AUDIO_MPEG = AsciiString.cached("audio/mpeg"); + /** + * {@code "audio/ogg"} + */ + public static final AsciiString AUDIO_OGG = AsciiString.cached("audio/ogg"); + /** + * {@code "audio/wav"} + */ + public static final AsciiString AUDIO_WAV = AsciiString.cached("audio/wav"); + /** + * {@code "audio/webm"} + */ + public static final AsciiString AUDIO_WEBM = AsciiString.cached("audio/webm"); /** * {@code "base64"} */ @@ -106,6 +154,22 @@ public final class HttpHeaderValues { * See {@link HttpHeaderNames#CONTENT_DISPOSITION} */ public static final AsciiString FILENAME = 
AsciiString.cached("filename"); + /** + * {@code "font/otf"} + */ + public static final AsciiString FONT_OTF = AsciiString.cached("font/otf"); + /** + * {@code "font/ttf"} + */ + public static final AsciiString FONT_TTF = AsciiString.cached("font/ttf"); + /** + * {@code "font/woff"} + */ + public static final AsciiString FONT_WOFF = AsciiString.cached("font/woff"); + /** + * {@code "font/woff2"} + */ + public static final AsciiString FONT_WOFF2 = AsciiString.cached("font/woff2"); /** * {@code "form-data"} * See {@link HttpHeaderNames#CONTENT_DISPOSITION} @@ -141,6 +205,34 @@ public final class HttpHeaderValues { * {@code "identity"} */ public static final AsciiString IDENTITY = AsciiString.cached("identity"); + /** + * {@code "image/avif"} + */ + public static final AsciiString IMAGE_AVIF = AsciiString.cached("image/avif"); + /** + * {@code "image/bmp"} + */ + public static final AsciiString IMAGE_BMP = AsciiString.cached("image/bmp"); + /** + * {@code "image/jpeg"} + */ + public static final AsciiString IMAGE_JPEG = AsciiString.cached("image/jpeg"); + /** + * {@code "image/png"} + */ + public static final AsciiString IMAGE_PNG = AsciiString.cached("image/png"); + /** + * {@code "image/svg+xml"} + */ + public static final AsciiString IMAGE_SVG_XML = AsciiString.cached("image/svg+xml"); + /** + * {@code "image/tiff"} + */ + public static final AsciiString IMAGE_TIFF = AsciiString.cached("image/tiff"); + /** + * {@code "image/webp"} + */ + public static final AsciiString IMAGE_WEBP = AsciiString.cached("image/webp"); /** * {@code "keep-alive"} */ @@ -222,10 +314,22 @@ public final class HttpHeaderValues { * {@code "text/css"} */ public static final AsciiString TEXT_CSS = AsciiString.cached("text/css"); + /** + * {@code "text/csv"} + */ + public static final AsciiString TEXT_CSV = AsciiString.cached("text/csv"); /** * {@code "text/html"} */ public static final AsciiString TEXT_HTML = AsciiString.cached("text/html"); + /** + * {@code "text/javascript"} + */ + public 
static final AsciiString TEXT_JAVASCRIPT = AsciiString.cached("text/javascript"); + /** + * {@code "text/markdown"} + */ + public static final AsciiString TEXT_MARKDOWN = AsciiString.cached("text/markdown"); /** * {@code "text/event-stream"} */ @@ -242,6 +346,22 @@ public final class HttpHeaderValues { * {@code "upgrade"} */ public static final AsciiString UPGRADE = AsciiString.cached("upgrade"); + /** + * {@code "video/mp4"} + */ + public static final AsciiString VIDEO_MP4 = AsciiString.cached("video/mp4"); + /** + * {@code "video/mpeg"} + */ + public static final AsciiString VIDEO_MPEG = AsciiString.cached("video/mpeg"); + /** + * {@code "video/ogg"} + */ + public static final AsciiString VIDEO_OGG = AsciiString.cached("video/ogg"); + /** + * {@code "video/webm"} + */ + public static final AsciiString VIDEO_WEBM = AsciiString.cached("video/webm"); /** * {@code "websocket"} */ diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java index 1efd2c58b77..8aa4e5cea31 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java @@ -158,14 +158,14 @@ protected boolean isContentLengthInvalid(HttpMessage start, int maxContentLength } } - private static Object continueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { + private Object continueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { if (HttpUtil.isUnsupportedExpectation(start)) { // if the request contains an unsupported expectation, we return 417 pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); return EXPECTATION_FAILED.retainedDuplicate(); } else if (HttpUtil.is100ContinueExpected(start)) { // if the request contains 100-continue but the content-length is too large, we return 413 - if (getContentLength(start, 
-1L) <= maxContentLength) { + if (!isContentLengthInvalid(start, maxContentLength)) { return CONTINUE.retainedDuplicate(); } pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); @@ -247,7 +247,8 @@ protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessa // If the client started to send data already, close because it's impossible to recover. // If keep-alive is off and 'Expect: 100-continue' is missing, no need to leave the connection open. - if (oversized instanceof FullHttpMessage || + // If auto read is false the channel must be closed or it will be stuck without a call to read() + if (oversized instanceof FullHttpMessage || !ctx.channel().config().isAutoRead() || !HttpUtil.is100ContinueExpected(oversized) && !HttpUtil.isKeepAlive(oversized)) { ChannelFuture future = ctx.writeAndFlush(TOO_LARGE_CLOSE.retainedDuplicate()); future.addListener(new ChannelFutureListener() { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java index 2f0d6c4fd72..06819d01245 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java @@ -477,6 +477,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou if (line == null) { return; } + checkChunkExtensions(line); int chunkSize = getChunkSize(line.array(), line.arrayOffset() + line.readerIndex(), line.readableBytes()); this.chunkSize = chunkSize; if (chunkSize == 0) { @@ -723,6 +724,16 @@ private HttpMessage invalidMessage(HttpMessage current, ByteBuf in, Exception ca return current; } + private static void checkChunkExtensions(ByteBuf line) { + int extensionsStart = line.bytesBefore((byte) ';'); + if (extensionsStart == -1) { + return; + } + HttpChunkLineValidatingByteProcessor processor = new HttpChunkLineValidatingByteProcessor(); + 
line.forEachByte(processor); + processor.finish(); + } + private HttpContent invalidChunk(ByteBuf in, Exception cause) { currentState = State.BAD_MESSAGE; message = null; @@ -867,7 +878,6 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { return LastHttpContent.EMPTY_LAST_CONTENT; } - CharSequence lastHeader = null; if (trailer == null) { trailer = this.trailer = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, trailersFactory); } @@ -875,29 +885,19 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { final byte[] lineContent = line.array(); final int startLine = line.arrayOffset() + line.readerIndex(); final byte firstChar = lineContent[startLine]; - if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) { - List current = trailer.trailingHeaders().getAll(lastHeader); - if (!current.isEmpty()) { - int lastPos = current.size() - 1; - //please do not make one line from below code - //as it breaks +XX:OptimizeStringConcat optimization - String lineTrimmed = langAsciiString(lineContent, startLine, line.readableBytes()).trim(); - String currentLastPos = current.get(lastPos); - current.set(lastPos, currentLastPos + lineTrimmed); - } + if (name != null && (firstChar == ' ' || firstChar == '\t')) { + //please do not make one line from below code + //as it breaks +XX:OptimizeStringConcat optimization + String trimmedLine = langAsciiString(lineContent, startLine, lineLength).trim(); + String valueStr = value; + value = valueStr + ' ' + trimmedLine; } else { - splitHeader(lineContent, startLine, lineLength); - AsciiString headerName = name; - if (!HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName) && - !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(headerName) && - !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(headerName)) { - trailer.trailingHeaders().add(headerName, value); + if (name != null && isPermittedTrailingHeader(name)) { + trailer.trailingHeaders().add(name, value); } - lastHeader = name; - // 
reset name and value fields - name = null; - value = null; + splitHeader(lineContent, startLine, lineLength); } + line = headerParser.parse(buffer, defaultStrictCRLFCheck); if (line == null) { return null; @@ -905,10 +905,28 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { lineLength = line.readableBytes(); } + // Add the last trailer + if (name != null && isPermittedTrailingHeader(name)) { + trailer.trailingHeaders().add(name, value); + } + + // reset name and value fields + name = null; + value = null; + this.trailer = null; return trailer; } + /** + * Checks whether the given trailer field name is permitted per RFC 9110 section 6.5 + */ + private static boolean isPermittedTrailingHeader(final AsciiString name) { + return !HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(name) && + !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(name) && + !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(name); + } + protected abstract boolean isDecodingRequest(); protected abstract HttpMessage createMessage(String[] initialLine) throws Exception; protected abstract HttpMessage createInvalidMessage(); @@ -926,7 +944,7 @@ private static int skipWhiteSpaces(byte[] hex, int start, int length) { } private static int getChunkSize(byte[] hex, int start, int length) { - // trim the leading bytes if white spaces, if any + // trim the leading bytes of white spaces, if any final int skipped = skipWhiteSpaces(hex, start, length); if (skipped == length) { // empty case diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java index 409718628b4..643f79a9757 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java @@ -45,7 +45,7 @@ private HttpUtil() { } /** * Determine if a uri is in origin-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.1. 
*/ public static boolean isOriginForm(URI uri) { return isOriginForm(uri.toString()); @@ -53,7 +53,7 @@ public static boolean isOriginForm(URI uri) { /** * Determine if a string uri is in origin-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.1. */ public static boolean isOriginForm(String uri) { return uri.startsWith("/"); @@ -61,7 +61,7 @@ public static boolean isOriginForm(String uri) { /** * Determine if a uri is in asterisk-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.4. */ public static boolean isAsteriskForm(URI uri) { return isAsteriskForm(uri.toString()); @@ -69,16 +69,59 @@ public static boolean isAsteriskForm(URI uri) { /** * Determine if a string uri is in asterisk-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.4. */ public static boolean isAsteriskForm(String uri) { return "*".equals(uri); } + static void validateRequestLineTokens(HttpMethod method, String uri) { + // The HttpVersion class does its own validation, and it's not possible for subclasses to circumvent it. + // The HttpMethod class does its own validation, but subclasses might circumvent it. + if (method.getClass() != HttpMethod.class) { + if (!isEncodingSafeStartLineToken(method.asciiName())) { + throw new IllegalArgumentException( + "The HTTP method name contain illegal characters: " + method.asciiName()); + } + } + + if (!isEncodingSafeStartLineToken(uri)) { + throw new IllegalArgumentException("The URI contain illegal characters: " + uri); + } + } + /** - * Returns {@code true} if and only if the connection can remain open and - * thus 'kept alive'. This methods respects the value of the. + * Validate that the given request line token is safe for verbatim encoding to the network. + * This does not fully check that the token – HTTP method, version, or URI – is valid and formatted correctly. + * Only that the token does not contain characters that would break or + * desynchronize HTTP message parsing of the start line wherein the token would be included. + *

+ * See RFC 9112, 3. * + * @param token The token to check. + * @return {@code true} if the token is safe to encode verbatim into the HTTP message output stream, + * otherwise {@code false}. + */ + public static boolean isEncodingSafeStartLineToken(CharSequence token) { + int lenBytes = token.length(); + for (int i = 0; i < lenBytes; i++) { + char ch = token.charAt(i); + // this is to help AOT compiled code which cannot profile the switch + if (ch <= ' ') { + switch (ch) { + case '\n': + case '\r': + case ' ': + return false; + } + } + } + return true; + } + + /** + * Returns {@code true} if and only if the connection can remain open and + * thus 'kept alive'. This method respects the value of the * {@code "Connection"} header first and then the return value of * {@link HttpVersion#isKeepAliveDefault()}. */ @@ -676,8 +719,10 @@ private static int validateAsciiStringToken(AsciiString token) { */ private static int validateCharSequenceToken(CharSequence token) { for (int i = 0, len = token.length(); i < len; i++) { - byte value = (byte) token.charAt(i); - if (!isValidTokenChar(value)) { + int value = token.charAt(i); + // 1. Check for truncation (anything above 255) + // 2. Check against the BitSet (isValidTokenChar handles 128-255 via bit < 0) + if (value > 0xFF || !isValidTokenChar((byte) value)) { return i; } } @@ -761,18 +806,17 @@ private static int validateCharSequenceToken(CharSequence token) { // .bits('-', '.', '_', '~') // Unreserved characters. // .bits('!', '#', '$', '%', '&', '\'', '*', '+', '^', '`', '|'); // Token special characters. 
- //this constants calculated by the above code + // These constants are calculated by the above code private static final long TOKEN_CHARS_HIGH = 0x57ffffffc7fffffeL; private static final long TOKEN_CHARS_LOW = 0x3ff6cfa00000000L; - private static boolean isValidTokenChar(byte bit) { - if (bit < 0) { + static boolean isValidTokenChar(byte octet) { + if (octet < 0) { return false; } - if (bit < 64) { - return 0 != (TOKEN_CHARS_LOW & 1L << bit); + if (octet < 64) { + return 0 != (TOKEN_CHARS_LOW & 1L << octet); } - return 0 != (TOKEN_CHARS_HIGH & 1L << bit - 64); + return 0 != (TOKEN_CHARS_HIGH & 1L << octet - 64); } - } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 75e958c5ad7..45d5e0ffb13 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; @@ -29,6 +30,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -58,6 +60,7 @@ public class CorsHandler extends ChannelDuplexHandler { private HttpRequest request; private final List configList; private final boolean isShortCircuit; + private boolean consumeContent; /** * Creates a new instance with a single {@link CorsConfig}. 
@@ -87,13 +90,28 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) throw config = getForOrigin(origin); if (isPreflightRequest(request)) { handlePreflight(ctx, request); + // Enable consumeContent so that all following HttpContent + // for this request will be released and not propagated downstream. + consumeContent = true; return; } if (isShortCircuit && !(origin == null || config != null)) { forbidden(ctx, request); + consumeContent = true; return; } + + // This request is forwarded, stop discarding + consumeContent = false; + ctx.fireChannelRead(msg); + return; + } + + if (consumeContent && (msg instanceof HttpContent)) { + ReferenceCountUtil.release(msg); + return; } + ctx.fireChannelRead(msg); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java index 01f1c0036c1..a898439fe22 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java @@ -21,7 +21,7 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -53,7 +53,7 @@ public static List extractExtensions(String extensionHea String name = extensionParameters[0].trim(); Map parameters; if (extensionParameters.length > 1) { - parameters = new HashMap(extensionParameters.length - 1); + parameters = new LinkedHashMap(extensionParameters.length - 1); for (int i = 1; i < extensionParameters.length; i++) { String parameter = extensionParameters[i].trim(); Matcher parameterMatcher = PARAMETER.matcher(parameter); @@ -93,7 +93,7 @@ static String computeMergeExtensionsHeaderValue(String userDefinedHeaderValue, 
extraExtensions.add(userDefined); } else { // merge with higher precedence to user defined parameters - Map mergedParameters = new HashMap(matchingExtra.parameters()); + Map mergedParameters = new LinkedHashMap(matchingExtra.parameters()); mergedParameters.putAll(userDefined.parameters()); extraExtensions.set(i, new WebSocketExtensionData(matchingExtra.name(), mergedParameters)); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java index 944f36e50b4..972b4c9e7ef 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java @@ -232,10 +232,16 @@ public WebSocketClientExtension handshakeExtension(WebSocketExtensionData extens if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { // allowed client_window_size_bits if (allowClientWindowSize) { - clientWindowSize = Integer.parseInt(parameter.getValue()); - if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { - succeed = false; + // RFC 7692: client_max_window_bits may have a value or no value + String value = parameter.getValue(); + if (value != null) { + // Let NumberFormatException bubble up if value is invalid + clientWindowSize = Integer.parseInt(value); + if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { + succeed = false; + } } + // If value is null, keep MAX_WINDOW_SIZE (default) } else { succeed = false; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java 
b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java index ce19476f403..9aeb219b142 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java @@ -220,8 +220,18 @@ public WebSocketServerExtension handshakeExtension(WebSocketExtensionData extens Entry parameter = parametersIterator.next(); if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { - // use preferred clientWindowSize because client is compatible with customization - clientWindowSize = preferredClientWindowSize; + // RFC 7692: client_max_window_bits may have a value or no value + String value = parameter.getValue(); + if (value != null) { + // Let NumberFormatException bubble up if value is invalid + clientWindowSize = Integer.parseInt(value); + if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { + deflateEnabled = false; + } + } else { + // No value specified, use preferred client window size + clientWindowSize = preferredClientWindowSize; + } } else if (SERVER_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { // use provided windowSize if it is allowed if (allowServerWindowSize) { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java index 9ddb597ae9c..0a5b24aae55 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java @@ -17,12 +17,205 @@ import io.netty.util.AsciiString; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import 
org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.SplittableRandom; +import java.util.function.LongFunction; +import java.util.stream.Stream; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class DefaultHttpRequestTest { + @ParameterizedTest + @ValueSource(strings = { + "http://localhost/\r\n", + "/r\r\n?q=1", + "http://localhost/\r\n?q=1", + "/r\r\n/?q=1", + "http://localhost/\r\n/?q=1", + "/r\r\n", + "http://localhost/ HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/r HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/ path", + "/path ", + " /path", + "http://localhost/ ", + " http://localhost/", + "http://local host/", + }) + void constructorMustRejectIllegalUrisByDefault(final String uri) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + } + }); + } + + public static Stream validUris() { + final String pdigit = "123456789"; + final String digit = '0' + pdigit; + final String digitcolon = digit + ':'; + final String alpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + final String alphanum = alpha + digit; + final String alphanumdot = alphanum + '.'; + final String unreserved = alphanumdot + "-_~"; + final String subdelims = "$&%=!+,;'()"; + final String userinfochars = unreserved + subdelims + ':'; + final String pathchars = unreserved + '/'; + final String querychars = pathchars + subdelims + '?'; + return new SplittableRandom().longs(1000) + .mapToObj(new LongFunction() { + @Override + public String apply(long seed) { + SplittableRandom rng = new SplittableRandom(seed); + String start; + String path; + String query; + String fragment; + 
if (rng.nextBoolean()) { + String scheme = rng.nextBoolean() ? "http://" : "HTTP://"; + String userinfo = rng.nextBoolean() ? "" : pick(rng, userinfochars, 1, 8) + '@'; + String host; + String port; + switch (rng.nextInt(3)) { + case 0: + host = pick(rng, alphanum, 1, 1) + pick(rng, alphanumdot, 1, 5); + break; + case 1: + host = pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2); + break; + default: + host = '[' + pick(rng, digitcolon, 1, 8) + ']'; + break; + } + if (rng.nextBoolean()) { + port = ':' + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 4); + } else { + port = ""; + } + start = scheme + userinfo + host + port; + } else { + start = ""; + } + path = '/' + pick(rng, pathchars, 0, 8); + if (rng.nextBoolean()) { + query = '?' + pick(rng, querychars, 0, 8); + } else { + query = ""; + } + if (rng.nextBoolean()) { + fragment = '#' + pick(rng, querychars, 0, 8); + } else { + fragment = ""; + } + return start + path + query + fragment; + } + }); + } + + private static String pick(SplittableRandom rng, String cs, int lowerBound, int upperBound) { + int length = rng.nextInt(lowerBound, upperBound + 1); + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(cs.charAt(rng.nextInt(cs.length()))); + } + return sb.toString(); + } + + @ParameterizedTest + @MethodSource("validUris") + void constructorMustAcceptValidUris(String uri) { + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + } + + @ParameterizedTest + @ValueSource(strings = { + "GET ", + " GET", + "G ET", + " GET ", + "GET\r", + "GET\n", + "GET\r\n", + "GE\rT", + "GE\nT", + "GE\r\nT", + "\rGET", + "\nGET", + "\r\nGET", + " \r\nGET", + "\r \nGET", + "\r\n GET", + "\r\nGET ", + "\nGET ", + "\rGET ", + "\r GET", + " \rGET", + "\nGET ", + "\n GET", + " \nGET", + "GET \n", + "GET \r", 
+ " GET\r", + " GET\r", + "GET \n", + " GET\n", + " GET\n", + "GE\nT ", + "GE\rT ", + " GE\rT", + " GE\rT", + "GE\nT ", + " GE\nT", + " GE\nT", + }) + void constructorMustRejectIllegalHttpMethodByDefault(final String method) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttpRequest(HttpVersion.HTTP_1_0, new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }, "/"); + } + }); + } + + @ParameterizedTest + @ValueSource(strings = { + "GET", + "POST", + "PUT", + "HEAD", + "DELETE", + "OPTIONS", + "CONNECT", + "TRACE", + "PATCH", + "QUERY" + }) + void constructorMustAcceptAllHttpMethods(final String method) { + new DefaultHttpRequest(HttpVersion.HTTP_1_0, new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }, "/"); + + new DefaultHttpRequest(HttpVersion.HTTP_1_0, new HttpMethod(method), "/"); + } @Test public void testHeaderRemoval() { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java index 0bb21521387..4fb28861ad0 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -26,7 +26,6 @@ import io.netty.util.AsciiString; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; import org.mockito.Mockito; @@ -758,4 +757,34 @@ public void execute() { } }); } + + @Test + public void invalidContinueLength() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(1024)); + + channel.writeInbound(Unpooled.copiedBuffer("POST / HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + 
"Content-Length:\r\n" + + "\r\n\r\n", CharsetUtil.US_ASCII)); + assertTrue(channel.finishAndReleaseAll()); + } + + @Test + public void testOversizedRequestWithAutoReadFalse() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4)); + embedder.config().setAutoRead(false); + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Content-Length: 5\r\n\r\n", CharsetUtil.US_ASCII))); + + assertNull(embedder.readInbound()); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + ReferenceCountUtil.release(response); + + assertFalse(embedder.isOpen()); + assertFalse(embedder.finish()); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java index 0ce7d196fad..3cf5f45f345 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -import static io.netty.handler.codec.http.HttpHeaderNames.*; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -269,6 +269,92 @@ public void testEmptyHeaderValue() { assertEquals("", req.headers().get(of("EmptyHeader"))); } + @Test + public void testSingleTrailingHeader() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" 
+ + "X-Checksum: abc123\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("abc123", last.trailingHeaders().get(of("X-Checksum"))); + last.release(); + assertFalse(channel.finish()); + } + + @Test + public void testMultiLineTrailingHeader() { + // Regression: folded trailer values previously threw UnsupportedOperationException + // because trailingHeaders().getAll() returns an AbstractList that does not implement set(). + // Note: obs-fold in trailers is permitted as trailers are field-lines per + // https://www.rfc-editor.org/rfc/rfc9112#section-5.2 + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" + + "X-Long: part1\r\n" + + " part2\r\n" + + "\t\t\t part3\r\n" + + "X-Short: value\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("part1 part2 part3", last.trailingHeaders().get(of("X-Long"))); + assertEquals("value", last.trailingHeaders().get(of("X-Short"))); + last.release(); + assertFalse(channel.finish()); + } + + @Test + public void testForbiddenTrailingHeadersAreDropped() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + 
+ "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" + + HttpHeaderNames.CONTENT_LENGTH + ": 5\r\n" + + HttpHeaderNames.TRANSFER_ENCODING + ": chunked\r\n" + + "X-Custom: keep\r\n" + + HttpHeaderNames.TRAILER + ": X-Checksum\r\n" + // covering post-loop flush path + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertNull(last.trailingHeaders().get(HttpHeaderNames.CONTENT_LENGTH)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRAILER)); + assertEquals("keep", last.trailingHeaders().get(of("X-Custom"))); + last.release(); + assertFalse(channel.finish()); + } + @Test public void test100Continue() { HttpRequestDecoder decoder = new HttpRequestDecoder(); @@ -695,6 +781,187 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { assertFalse(channel.finish()); } + @Test + void mustParsedChunkExtensionsWithQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" ;\t\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. 
+ assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithLineBreaksInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\"\r\n" + // chunk extension quote start + "X\r\n" + + "0\r\n\r\n" + + "GET /two HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. 
+ } + + @Test + void mustParseChunkExtensionsWithQuotedStringsAndEscapes() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\";\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithEscapedLineBreakInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\n;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. 
+ assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithEscapedCarriageReturnInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\r;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain carraige return. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. 
+ } + + @Test + void lineLengthRestrictionMustNotApplyToChunkContents() throws Exception { + char[] chars = new char[10000]; + Arrays.fill(chars, 'a'); + String requestContent = new String(chars); + String requestStr = "POST /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + Integer.toHexString(chars.length) + "\r\n" + + requestContent + "\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + int contentLength = 0; + HttpContent content; + do { + content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + if (decoderResult.cause() != null) { + throw new Exception(decoderResult.cause()); + } + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + contentLength += content.content().readableBytes(); + content.release(); + } while (!(content instanceof LastHttpContent)); + assertEquals(chars.length, contentLength); + assertFalse(channel.finish()); // And there are no other chunks parsed. 
+ } + + @Test + void mustRejectChunkSizeWithNonHexadecimalCharacters() throws Exception { + String requestStr = "POST /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "test\r\n\r\n" + // chunk extension quote start + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers + HttpContent content = channel.readInbound(); + assertTrue(content.decoderResult().isFailure()); + assertThat(content.decoderResult().cause()).isInstanceOf(NumberFormatException.class); + assertFalse(channel.finish()); + } + @Test public void testOrderOfHeadersWithContentLength() { String requestStr = "GET /some/path HTTP/1.1\r\n" + diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java index 2c0ffd7d942..cd822063dab 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java @@ -37,8 +37,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -/** - */ public class HttpRequestEncoderTest { @SuppressWarnings("deprecation") diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java index d38e6169d0c..5573606b06d 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java @@ -26,11 +26,12 @@ import org.junit.jupiter.params.ParameterizedTest; import 
org.junit.jupiter.params.provider.ValueSource; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Random; + import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -206,21 +207,6 @@ public void testResponseChunkedWithValidUncommonPatterns() { assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); - // leading whitespace, trailing control char - - assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + "\0\r\n", - CharsetUtil.US_ASCII))); - assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); - content = ch.readInbound(); - assertEquals(data.length, content.content().readableBytes()); - - decodedData = new byte[data.length]; - content.content().readBytes(decodedData); - assertArrayEquals(data, decodedData); - content.release(); - - assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); - // leading whitespace, trailing semicolon assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + ";\r\n", @@ -665,6 +651,64 @@ private static void testLastResponseWithTrailingHeaderFragmented(byte[] content, assertNull(ch.readInbound()); } + @Test + public void testMultiLineTrailingHeader() { + // Regression: folded trailer values previously threw UnsupportedOperationException + // because trailingHeaders().getAll() returns an AbstractList that does not implement set(). 
+ // Note: obs-fold in trailers is permitted as trailers are field-lines per + // https://www.rfc-editor.org/rfc/rfc9112#section-5.2 + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + String response = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "X-Long: part1\r\n" + + " part2\r\n" + + "\t\t\t part3\r\n" + + "X-Short: value\r\n" + + "\r\n"; + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.US_ASCII))); + HttpResponse res = ch.readInbound(); + assertFalse(res.decoderResult().isFailure()); + assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); + assertEquals(HttpResponseStatus.OK, res.status()); + + LastHttpContent last = ch.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("part1 part2 part3", last.trailingHeaders().get(of("X-Long"))); + assertEquals("value", last.trailingHeaders().get(of("X-Short"))); + last.release(); + assertFalse(ch.finish()); + } + + @Test + public void testForbiddenTrailingHeadersAreDropped() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + String response = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + HttpHeaderNames.CONTENT_LENGTH + ": 5\r\n" + + HttpHeaderNames.TRANSFER_ENCODING + ": chunked\r\n" + + "X-Custom: keep\r\n" + + HttpHeaderNames.TRAILER + ": X-Checksum\r\n" + // covering post-loop flush path + "\r\n"; + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.US_ASCII))); + HttpResponse res = ch.readInbound(); + assertFalse(res.decoderResult().isFailure()); + assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); + assertEquals(HttpResponseStatus.OK, res.status()); + + LastHttpContent last = ch.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertNull(last.trailingHeaders().get(HttpHeaderNames.CONTENT_LENGTH)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRANSFER_ENCODING)); + 
assertNull(last.trailingHeaders().get(HttpHeaderNames.TRAILER)); + assertEquals("keep", last.trailingHeaders().get(of("X-Custom"))); + last.release(); + assertFalse(ch.finish()); + } + @Test public void testResponseWithContentLength() { EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); @@ -1000,7 +1044,7 @@ public void testGarbageChunkAfterWhiteSpaces() { @Test void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { // See full explanation: https://w4ke.info/2025/06/18/funky-chunks.html - String requestStr = "HTTP/1.1 200 OK\r\n" + + String responseStr = "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "2;\n" + // Chunk size followed by illegal single newline (not preceded by carraige return) @@ -1011,7 +1055,7 @@ void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { "Transfer-Encoding: chunked\r\n\r\n" + "0\r\n\r\n"; EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); - assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); HttpResponse response = channel.readInbound(); assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. assertTrue(response.headers().names().contains("Transfer-Encoding")); @@ -1027,7 +1071,7 @@ void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { @Test void mustRejectImproperlyTerminatedChunkBodies() throws Exception { // See full explanation: https://w4ke.info/2025/06/18/funky-chunks.html - String requestStr = "HTTP/1.1 200 OK\r\n" + + String responseStr = "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "5\r\n" + "AAAAXX" + // Chunk body contains extra (XX) bytes, and no CRLF terminator. 
@@ -1037,7 +1081,7 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { "Transfer-Encoding: chunked\r\n\r\n" + "0\r\n\r\n"; EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); - assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); HttpResponse response = channel.readInbound(); assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. assertTrue(response.headers().names().contains("Transfer-Encoding")); @@ -1054,6 +1098,185 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { assertFalse(channel.finish()); } + @Test + void mustParsedChunkExtensionsWithQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" ;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. 
+ } + + @Test + void mustRejectChunkExtensionsWithLineBreaksInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\"\r\n" + // chunk extension quote start + "X\r\n" + + "0\r\n\r\n" + + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustParsedChunkExtensionsWithQuotedStringsAndEscapes() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\";\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. 
+ assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithEscapedLineBreakInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\n;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. 
+ } + + @Test + void mustRejectChunkExtensionsWithEscapedCarraigeReturnInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\r;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain carriage returns. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void lineLengthRestrictionMustNotApplyToChunkContents() throws Exception { + char[] chars = new char[10000]; + Arrays.fill(chars, 'a'); + String requestContent = new String(chars); + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + Integer.toHexString(chars.length) + "\r\n" + + requestContent + "\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. 
+ assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + int contentLength = 0; + HttpContent content; + do { + content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + if (decoderResult.cause() != null) { + throw new Exception(decoderResult.cause()); + } + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + contentLength += content.content().readableBytes(); + content.release(); + } while (!(content instanceof LastHttpContent)); + assertEquals(chars.length, contentLength); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkSizeWithNonHexadecimalCharacters() throws Exception { + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "test\r\n\r\n" + // "test" is not a valid hexadecimal chunk size + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers + HttpContent content = channel.readInbound(); + assertTrue(content.decoderResult().isFailure()); + assertThat(content.decoderResult().cause()).isInstanceOf(NumberFormatException.class); + assertFalse(channel.finish()); + } + @Test public void testConnectionClosedBeforeHeadersReceived() { EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java index 05dd678564d..f77d4e8c297 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java @@ -56,7 +56,8 @@ public void 
testRecognizesOriginForm() { assertFalse(HttpUtil.isOriginForm(URI.create("*"))); } - @Test public void testRecognizesAsteriskForm() { + @Test + public void testRecognizesAsteriskForm() { // Asterisk form: https://tools.ietf.org/html/rfc7230#section-5.3.4 assertTrue(HttpUtil.isAsteriskForm(URI.create("*"))); // Origin form: https://tools.ietf.org/html/rfc7230#section-5.3.1 @@ -67,6 +68,26 @@ public void testRecognizesOriginForm() { assertFalse(HttpUtil.isAsteriskForm(URI.create("www.example.com:80"))); } + @ParameterizedTest + @ValueSource(strings = { + "http://localhost/\r\n", + "/r\r\n?q=1", + "http://localhost/\r\n?q=1", + "/r\r\n/?q=1", + "http://localhost/\r\n/?q=1", + "/r\r\n", + "http://localhost/ HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/r HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "GET ", + " GET", + "HTTP/ 1.1", + "HTTP/\r0.9", + "HTTP/\n1.1", + }) + public void requestLineTokenValidationMustRejectInvalidTokens(String token) throws Exception { + assertFalse(HttpUtil.isEncodingSafeStartLineToken(token)); + } + @Test public void testRemoveTransferEncodingIgnoreCase() { HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); @@ -514,4 +535,22 @@ public void testInvalidTokenChars(char invalidChar) { assertEquals(2, validateToken(asciiStringToken)); assertEquals(2, validateToken(token)); } + + @ParameterizedTest + @ValueSource(chars = { + // High-bit Truncation Candidates (verifying > 0xFF check) + // These characters are chosen because their lower 8 bits + // alias to valid US-ASCII 'tchar' values. + '\u0161', // 0x0161 truncates to 0x61 ('a') + '\u0121', // 0x0121 truncates to 0x21 ('!') + '\u0231', // 0x0231 truncates to 0x31 ('1') + '\u0361' // 0x0361 truncates to 0x61 ('a') + }) + public void testInvalidTokenCharsOutsideAsciiRange(char invalidChar) { + // We use a String here because AsciiString would truncate + // the char to a byte during construction. 
+ String token = "GE" + invalidChar + 'T'; + assertEquals(2, validateToken(token), + String.format("Character U+%04X should be invalid", (int) invalidChar)); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java new file mode 100644 index 00000000000..d2971ac4726 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpVersionParsingTest { + + @Test + void testStandardVersions() { + HttpVersion v10 = HttpVersion.valueOf("HTTP/1.0"); + HttpVersion v11 = HttpVersion.valueOf("HTTP/1.1"); + + assertSame(HttpVersion.HTTP_1_0, v10); + assertSame(HttpVersion.HTTP_1_1, v11); + + assertEquals("HTTP", v10.protocolName()); + assertEquals(1, v10.majorVersion()); + assertEquals(0, v10.minorVersion()); + + assertEquals("HTTP", v11.protocolName()); + assertEquals(1, v11.majorVersion()); + assertEquals(1, v11.minorVersion()); + } + + @Test + void testLowerCaseProtocolNameNonStrict() { + HttpVersion version = HttpVersion.valueOf("http/1.1"); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(1, version.minorVersion()); + assertEquals("HTTP/1.1", version.text()); + } + + @Test + void testMixedCaseProtocolNameNonStrict() { + HttpVersion version = HttpVersion.valueOf("hTtP/1.0"); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(0, version.minorVersion()); + assertEquals("HTTP/1.0", version.text()); + } + + @Test + void testCustomLowerCaseProtocolNonStrict() { + HttpVersion version = HttpVersion.valueOf("mqtt/5.0"); + assertEquals("MQTT", version.protocolName()); + assertEquals(5, version.majorVersion()); + assertEquals(0, version.minorVersion()); + assertEquals("MQTT/5.0", version.text()); + } + + @Test + void testCustomVersionNonStrict() { + HttpVersion version = 
HttpVersion.valueOf("MyProto/2.3"); + assertEquals("MYPROTO", version.protocolName()); // uppercased + assertEquals(2, version.majorVersion()); + assertEquals(3, version.minorVersion()); + assertEquals("MYPROTO/2.3", version.text()); + } + + @Test + void testCustomVersionStrict() { + HttpVersion version = new HttpVersion("HTTP/1.1", true, true); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(1, version.minorVersion()); + } + + @Test + void testCustomVersionStrictFailsOnLongVersion() { + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new HttpVersion("HTTP/10.1", true, true); + } + }); + assertTrue(ex.getMessage().contains("invalid version format")); + } + + @Test + void testInvalidFormatMissingSlash() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + HttpVersion.valueOf("HTTP1.1"); + } + }); + } + + @Test + void testInvalidFormatWhitespaceInProtocol() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + HttpVersion.valueOf("HT TP/1.1"); + } + }); + } + + @ParameterizedTest + @ValueSource(strings = { + "HTTP ", + " HTTP", + "H TTP", + " HTTP ", + "HTTP\r", + "HTTP\n", + "HTTP\r\n", + "HTT\rP", + "HTT\nP", + "HTT\r\nP", + "\rHTTP", + "\nHTTP", + "\r\nHTTP", + " \r\nHTTP", + "\r \nHTTP", + "\r\n HTTP", + "\r\nHTTP ", + "\nHTTP ", + "\rHTTP ", + "\r HTTP", + " \rHTTP", + "\nHTTP ", + "\n HTTP", + " \nHTTP", + "HTTP \n", + "HTTP \r", + " HTTP\r", + " HTTP\r", + "HTTP \n", + " HTTP\n", + " HTTP\n", + "HTT\nTP", + "HTT\rTP", + " HTT\rP", + " HTT\rP", + "HTT\nTP", + " HTT\nP", + " HTT\nP", + }) + void httpVersionMustRejectIllegalTokens(String protocol) { + try { + HttpVersion httpVersion = new HttpVersion(protocol, 1, 0, true); + // If no exception is thrown, then the 
version must have been sanitized and made safe. + assertTrue(HttpUtil.isEncodingSafeStartLineToken(httpVersion.text())); + } catch (IllegalArgumentException ignore) { + // Throwing is good. + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index 8b3065fbb89..d76f8d3f04f 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -21,12 +21,17 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.DefaultHttpHeadersFactory; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpUtil; import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.Test; @@ -512,6 +517,154 @@ public void simpleRequestDoNotAllowPrivateNetwork() { assertTrue(ReferenceCountUtil.release(response)); } + @Test + public void preflightEmptyLastDiscarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + + Object outbound = ch.readOutbound(); + 
assertNotNull(outbound); // preflight response + + LastHttpContent lastHttpContent = LastHttpContent.EMPTY_LAST_CONTENT; + assertFalse(ch.writeInbound(lastHttpContent)); + + // Nothing should have been forwarded + assertNull(ch.readInbound()); + + assertFalse(ch.finish()); + } + + @Test + public void preflightSecondEmptyLastForwardedAfterFirstDiscard() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + ReferenceCountUtil.release(ch.readOutbound()); + + LastHttpContent first = LastHttpContent.EMPTY_LAST_CONTENT; + LastHttpContent second = LastHttpContent.EMPTY_LAST_CONTENT; + + assertFalse(ch.writeInbound(first)); + + assertFalse(ch.writeInbound(second)); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightSecondNonEmptyLastDiscarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + ReferenceCountUtil.release(ch.readOutbound()); + + LastHttpContent first = LastHttpContent.EMPTY_LAST_CONTENT; + LastHttpContent second = new DefaultLastHttpContent( + Unpooled.copiedBuffer("test message", CharsetUtil.UTF_8)); + + assertFalse(ch.writeInbound(first)); + assertFalse(ch.writeInbound(second)); + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightNonEmptyLastForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + 
EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/x"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + Object outbound = ch.releaseOutbound(); + assertNotNull(outbound); + + LastHttpContent nonEmpty = new DefaultLastHttpContent(Unpooled.copiedBuffer("x", CharsetUtil.UTF_8)); + assertFalse(ch.writeInbound(nonEmpty)); + + Object inbound = ch.readInbound(); + assertNull(inbound); + + assertFalse(ch.finish()); + } + + @Test + public void testNormalRequestForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, GET, "/test"); + req.headers().set(ORIGIN, "http://allowed"); + + assertTrue(ch.writeInbound(req)); + + LastHttpContent last = LastHttpContent.EMPTY_LAST_CONTENT; + assertTrue(ch.writeInbound(last)); + + Object firstInbound = ch.readInbound(); + Object secondInbound = ch.readInbound(); + + assertNotNull(firstInbound); + assertNotNull(secondInbound); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightEmptyLastDiscardedThenNewRequestForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + // Preflight request + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/pre"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertFalse(ch.writeInbound(preflight)); + Object preflightResp = ch.readOutbound(); + assertNotNull(preflightResp); + ReferenceCountUtil.release(preflightResp); + + // Empty last content should be discarded + 
assertFalse(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + assertNull(ch.readInbound()); + + // New request should be forwarded + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, GET, "/next"); + req.headers().set(ORIGIN, "http://allowed"); + assertTrue(ch.writeInbound(req)); + + Object firstInbound = ch.readInbound(); + assertNotNull(firstInbound); + + HttpContent content = new DefaultHttpContent(Unpooled.copiedBuffer("test message", CharsetUtil.UTF_8)); + assertTrue(ch.writeInbound(content)); + Object secondInbound = ch.readInbound(); + assertNotNull(secondInbound); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { return simpleRequest(config, origin, null); } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java index 786891d604d..5e46bfe7e71 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java @@ -23,6 +23,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; @@ -36,6 +37,7 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; public class PerMessageDeflateClientExtensionHandshakerTest { @@ -243,4 +245,44 @@ public void 
testDecoderNoClientContext() { assertFalse(decoderChannel.finish()); } + + @Test + public void testClientMaxWindowWithNoValue() { + // Test that client handles client_max_window_bits with no value (null) + // RFC 7692: client_max_window_bits may have no value + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, null); // No value specified + + // Should not throw NumberFormatException + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Handshake should succeed, using MAX_WINDOW_SIZE (15) as default + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + } + + @Test + public void testClientMaxWindowWithInvalidValue() { + // Test that client throws NumberFormatException for invalid client_max_window_bits value + final PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false, 0); + + final Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "invalid"); + + // Should throw NumberFormatException + assertThrows(NumberFormatException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + } + }); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java index e661e05a1a0..efaa7f88679 100644 
--- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java @@ -173,4 +173,40 @@ public void testCustomHandshake() { assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); assertTrue(data.parameters().isEmpty()); } + + @Test + public void testClientMaxWindowWithValue() { + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "12"); + + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + + WebSocketExtensionData data = extension.newReponseData(); + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + // Server should use the client's requested value (12) not the preferred (10) + assertTrue(data.parameters().containsKey(CLIENT_MAX_WINDOW)); + assertEquals("12", data.parameters().get(CLIENT_MAX_WINDOW)); + } + + @Test + public void testClientMaxWindowWithInvalidValue() { + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "7"); // Below MIN_WINDOW_SIZE (8) + + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Handshake should fail when client_max_window_bits is out of range + assertNull(extension); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java 
b/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java index a0c2cd132a4..e496e5ef136 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java @@ -31,7 +31,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; public class SpdyFrameDecoderTest { @@ -841,7 +841,7 @@ public void testDiscardUnknownFrame() throws Exception { buf.writeLong(RANDOM.nextLong()); decoder.decode(buf); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(buf.isReadable()); buf.release(); } @@ -856,7 +856,7 @@ public void testDiscardUnknownEmptyFrame() throws Exception { encodeControlFrameHeader(buf, type, flags, length); decoder.decode(buf); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(buf.isReadable()); buf.release(); } @@ -878,7 +878,7 @@ public void testProgressivelyDiscardUnknownEmptyFrame() throws Exception { decoder.decode(header); decoder.decode(segment1); decoder.decode(segment2); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(header.isReadable()); assertFalse(segment1.isReadable()); assertFalse(segment2.isReadable()); diff --git a/codec-http2/pom.xml b/codec-http2/pom.xml index f363a73fbdb..8d51d999bfd 100644 --- a/codec-http2/pom.xml +++ b/codec-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-http2 diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java index 7747e4fa458..f0fce6e65ba 100644 --- 
a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java @@ -13,7 +13,6 @@ * License for the specific language governing permissions and limitations * under the License. */ - package io.netty.handler.codec.http2; import io.netty.channel.Channel; @@ -113,6 +112,8 @@ public abstract class AbstractHttp2ConnectionHandlerBuilder= maxSmallContinuationFrames) { + throw connectionError(ENHANCE_YOUR_CALM, + "Number of small consecutive continuations frames %d exceeds maximum: %d", + headersContinuation.numSmallFragments(), maxSmallContinuationFrames); + } } private void verifyUnknownFrame() throws Http2Exception { @@ -399,7 +415,6 @@ private void verifyUnknownFrame() throws Http2Exception { private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) throws Http2Exception { int padding = readPadding(payload); - verifyPadding(padding); // Determine how much data there is to read by removing the trailing // padding. @@ -414,7 +429,6 @@ private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, final int headersStreamId = streamId; final Http2Flags headersFlags = flags; final int padding = readPadding(payload); - verifyPadding(padding); // The callback that is invoked is different depending on whether priority information // is present in the headers frame. @@ -536,7 +550,6 @@ private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf paylo Http2FrameListener listener) throws Http2Exception { final int pushPromiseStreamId = streamId; final int padding = readPadding(payload); - verifyPadding(padding); final int promisedStreamId = readUnsignedInt(payload); // Create a handler that invokes the listener when the header block is complete. 
@@ -620,21 +633,19 @@ private int readPadding(ByteBuf payload) { return payload.readUnsignedByte() + 1; } - private void verifyPadding(int padding) throws Http2Exception { - int len = lengthWithoutTrailingPadding(payloadLength, padding); - if (len < 0) { - throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding."); - } - } - /** * The padding parameter consists of the 1 byte pad length field and the trailing padding bytes. This method * returns the number of readable bytes without the trailing padding. */ - private static int lengthWithoutTrailingPadding(int readableBytes, int padding) { - return padding == 0 - ? readableBytes - : readableBytes - (padding - 1); + private static int lengthWithoutTrailingPadding(int readableBytes, int padding) throws Http2Exception { + if (padding == 0) { + return readableBytes; + } + int n = readableBytes - (padding - 1); + if (n < 0) { + throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding."); + } + return n; } /** @@ -650,6 +661,15 @@ private abstract class HeadersContinuation { */ abstract int getStreamId(); + /** + * Return the number of fragments that were used so far. + * + * @return the number of fragments + */ + final int numSmallFragments() { + return builder.numSmallFragments(); + } + /** * Processes the next fragment for the current header block. * @@ -678,6 +698,7 @@ final void close() { */ protected class HeadersBlockBuilder { private ByteBuf headerBlock; + private int numSmallFragments; /** * The local header size maximum has been exceeded while accumulating bytes. @@ -688,6 +709,15 @@ private void headerSizeExceeded() throws Http2Exception { headerListSizeExceeded(headersDecoder.configuration().maxHeaderListSizeGoAway()); } + /** + * Return the number of fragments that was used so far. + * + * @return number of fragments. + */ + int numSmallFragments() { + return numSmallFragments; + } + /** * Adds a fragment to the block. 
* @@ -699,6 +729,11 @@ private void headerSizeExceeded() throws Http2Exception { */ final void addFragment(ByteBuf fragment, int len, ByteBufAllocator alloc, boolean endOfHeaders) throws Http2Exception { + if (maxSmallContinuationFrames > 0 && !endOfHeaders && len < FRAGMENT_THRESHOLD) { + // Only count of the fragment is not the end of header and if its < 8kb. + numSmallFragments++; + } + if (headerBlock == null) { if (len > headersDecoder.configuration().maxHeaderListSizeGoAway()) { headerSizeExceeded(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java index 73e497ccb8c..0587cf49c12 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java @@ -361,6 +361,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception buf.release(); return; } + + // Also take padding into account. + incrementDecompressedBytes(padding); + incrementDecompressedBytes(buf.readableBytes()); // Immediately return the bytes back to the flow controller. 
ConsumedBytesConverter will convert // from the decompressed amount which the user knows about to the compressed amount which flow diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java index f68ad765d84..9343736fffb 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -116,7 +116,7 @@ public final class Http2CodecUtil { public static final int SMALLEST_MAX_CONCURRENT_STREAMS = 100; static final int DEFAULT_MAX_RESERVED_STREAMS = SMALLEST_MAX_CONCURRENT_STREAMS; static final int DEFAULT_MIN_ALLOCATION_CHUNK = 1024; - + static final int DEFAULT_MAX_SMALL_CONTINUATION_FRAME = 16; /** * Calculate the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. * @param maxHeaderListSize diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 61e9cd1213b..1b554972b66 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -248,6 +248,10 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) thro byteDecoder.decode(ctx, in, out); } } catch (Throwable e) { + if (byteDecoder != null) { + // Skip all bytes before we report the exception as + in.skipBytes(in.readableBytes()); + } onError(ctx, false, e); } } @@ -256,13 +260,6 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) thro public void channelActive(ChannelHandlerContext ctx) throws Exception { // The channel just became active - send the connection preface to the remote endpoint. 
sendPreface(ctx); - - if (flushPreface) { - // As we don't know if any channelReadComplete() events will be triggered at all we need to ensure we - // also flush. Otherwise the remote peer might never see the preface / settings frame. - // See https://github.com/netty/netty/issues/12089 - ctx.flush(); - } } @Override @@ -346,12 +343,17 @@ private boolean verifyFirstFrameIsSettings(ByteBuf in) throws Http2Exception { } short frameType = in.getUnsignedByte(in.readerIndex() + 3); - short flags = in.getUnsignedByte(in.readerIndex() + 4); - if (frameType != SETTINGS || (flags & Http2Flags.ACK) != 0) { + if (frameType != SETTINGS) { throw connectionError(PROTOCOL_ERROR, "First received frame was not SETTINGS. " + "Hex dump for first 5 bytes: %s", hexDump(in, in.readerIndex(), 5)); } + short flags = in.getUnsignedByte(in.readerIndex() + 4); + if ((flags & Http2Flags.ACK) != 0) { + throw connectionError(PROTOCOL_ERROR, "First received frame was SETTINGS frame but had ACK flag set. " + + "Hex dump for first 5 bytes: %s", + hexDump(in, in.readerIndex(), 5)); + } return true; } @@ -375,11 +377,20 @@ private void sendPreface(ChannelHandlerContext ctx) throws Exception { encoder.writeSettings(ctx, initialSettings, ctx.newPromise()).addListener( ChannelFutureListener.CLOSE_ON_FAILURE); - if (isClient) { - // If this handler is extended by the user and we directly fire the userEvent from this context then - // the user will not see the event. We should fire the event starting with this handler so this class - // (and extending classes) have a chance to process the event. - userEventTriggered(ctx, Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE); + try { + if (isClient) { + // If this handler is extended by the user and we directly fire the userEvent from this context then + // the user will not see the event. We should fire the event starting with this handler so this + // class (and extending classes) have a chance to process the event. 
+ userEventTriggered(ctx, Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE); + } + } finally { + if (flushPreface) { + // As we don't know if any channelReadComplete() events will be triggered at all we need to ensure + // we also flush. Otherwise the remote peer might never see the preface / settings frame. + // See https://github.com/netty/netty/issues/12089 + ctx.flush(); + } } } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java index d4bd2fe5a3a..2a4a1320d0b 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java @@ -203,6 +203,17 @@ public Http2FrameCodecBuilder encoderEnforceMaxRstFramesPerWindow( return super.encoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); } + @Override + public int decoderEnforceMaxSmallContinuationFrames() { + return super.decoderEnforceMaxSmallContinuationFrames(); + } + + @Override + public Http2FrameCodecBuilder decoderEnforceMaxSmallContinuationFrames( + int maxConsecutiveContinuationsFrames) { + return super.decoderEnforceMaxSmallContinuationFrames(maxConsecutiveContinuationsFrames); + } + /** * Build a {@link Http2FrameCodec} object. */ @@ -216,7 +227,8 @@ public Http2FrameCodec build() { Long maxHeaderListSize = initialSettings().maxHeaderListSize(); Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? 
new DefaultHttp2HeadersDecoder(isValidateHeaders()) : - new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize), + decoderEnforceMaxSmallContinuationFrames()); if (frameLogger() != null) { frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java index 65a1f471555..945c232b7a1 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java @@ -221,6 +221,17 @@ public Http2MultiplexCodecBuilder encoderEnforceMaxRstFramesPerWindow( return super.encoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); } + @Override + public int decoderEnforceMaxSmallContinuationFrames() { + return super.decoderEnforceMaxSmallContinuationFrames(); + } + + @Override + public Http2MultiplexCodecBuilder decoderEnforceMaxSmallContinuationFrames( + int maxConsecutiveContinuationsFrames) { + return super.decoderEnforceMaxSmallContinuationFrames(maxConsecutiveContinuationsFrames); + } + @Override public Http2MultiplexCodec build() { Http2FrameWriter frameWriter = this.frameWriter; @@ -231,7 +242,8 @@ public Http2MultiplexCodec build() { Long maxHeaderListSize = initialSettings().maxHeaderListSize(); Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? 
new DefaultHttp2HeadersDecoder(isValidateHeaders()) : - new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize), + decoderEnforceMaxSmallContinuationFrames()); if (frameLogger() != null) { frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java index e50038a051b..651cce73723 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java @@ -69,6 +69,10 @@ public boolean equals(Object obj) { return id == otherId.id && parentId.equals(otherId.parentId); } + public int getSequenceId() { + return id; + } + @Override public String toString() { return asShortText(); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java index 9bece58a3ba..2c0cf61908e 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java @@ -39,7 +39,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; @@ -50,6 +51,7 @@ import java.net.InetSocketAddress; import java.util.Random; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import static 
io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; @@ -89,6 +91,7 @@ public class DataCompressionHttp2Test { private Http2Connection clientConnection; private Http2ConnectionHandler clientHandler; private ByteArrayOutputStream serverOut; + private final AtomicReference serverException = new AtomicReference(); @BeforeAll public static void beforeAllTests() throws Throwable { @@ -148,8 +151,9 @@ public void teardown() throws InterruptedException { clientGroup.sync(); } - @Test - public void justHeadersNoData() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void justHeadersNoData(final int padding) throws Exception { bootstrapEnv(0); final Http2Headers headers = new DefaultHttp2Headers().method(GET).path(PATH) .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); @@ -157,17 +161,18 @@ public void justHeadersNoData() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); awaitServer(); verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true)); + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(padding), eq(true)); } - @Test - public void gzipEncodingSingleEmptyMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void gzipEncodingSingleEmptyMessage(final int padding) throws Exception { final String text = ""; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); bootstrapEnv(data.readableBytes()); @@ -178,8 +183,8 @@ public void gzipEncodingSingleEmptyMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { 
@Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -190,8 +195,9 @@ public void run() throws Http2Exception { } } - @Test - public void gzipEncodingSingleMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void gzipEncodingSingleMessage(final int padding) throws Exception { final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); bootstrapEnv(data.readableBytes()); @@ -202,8 +208,8 @@ public void gzipEncodingSingleMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -214,8 +220,9 @@ public void run() throws Http2Exception { } } - @Test - public void gzipEncodingMultipleMessages() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void gzipEncodingMultipleMessages(final int padding) throws Exception { final String text1 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff"; final ByteBuf data1 = 
Unpooled.copiedBuffer(text1.getBytes()); @@ -228,9 +235,9 @@ public void gzipEncodingMultipleMessages() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data1.retain(), 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data2.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data1.retain(), padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data2.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -242,8 +249,9 @@ public void run() throws Http2Exception { } } - @Test - public void brotliEncodingSingleEmptyMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void brotliEncodingSingleEmptyMessage(final int padding) throws Exception { final String text = ""; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); bootstrapEnv(data.readableBytes()); @@ -254,8 +262,8 @@ public void brotliEncodingSingleEmptyMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -266,8 +274,9 @@ public void run() throws Http2Exception { } } - @Test - public void brotliEncodingSingleMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void 
brotliEncodingSingleMessage(final int padding) throws Exception { final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name())); bootstrapEnv(data.readableBytes()); @@ -278,8 +287,8 @@ public void brotliEncodingSingleMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -290,8 +299,9 @@ public void run() throws Http2Exception { } } - @Test - public void zstdEncodingSingleEmptyMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void zstdEncodingSingleEmptyMessage(final int padding) throws Exception { final String text = ""; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); bootstrapEnv(data.readableBytes()); @@ -302,8 +312,8 @@ public void zstdEncodingSingleEmptyMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -314,8 +324,9 @@ public void run() throws Http2Exception { } } - @Test - public void zstdEncodingSingleMessage() throws Exception { + 
@ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void zstdEncodingSingleMessage(final int padding) throws Exception { final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name())); bootstrapEnv(data.readableBytes()); @@ -326,8 +337,8 @@ public void zstdEncodingSingleMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -338,8 +349,9 @@ public void run() throws Http2Exception { } } - @Test - public void snappyEncodingSingleEmptyMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void snappyEncodingSingleEmptyMessage(final int padding) throws Exception { final String text = ""; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII)); bootstrapEnv(data.readableBytes()); @@ -350,8 +362,8 @@ public void snappyEncodingSingleEmptyMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -362,8 +374,9 @@ public void run() throws 
Http2Exception { } } - @Test - public void snappyEncodingSingleMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void snappyEncodingSingleMessage(final int padding) throws Exception { final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII)); bootstrapEnv(data.readableBytes()); @@ -374,8 +387,8 @@ public void snappyEncodingSingleMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); @@ -386,8 +399,9 @@ public void run() throws Http2Exception { } } - @Test - public void deflateEncodingWriteLargeMessage() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void deflateEncodingWriteLargeMessage(final int padding) throws Exception { final int BUFFER_SIZE = 1 << 12; final byte[] bytes = new byte[BUFFER_SIZE]; new Random().nextBytes(bytes); @@ -400,8 +414,8 @@ public void deflateEncodingWriteLargeMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); 
} }); @@ -417,6 +431,7 @@ private void bootstrapEnv(int serverOutSize) throws Exception { final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); serverOut = new ByteArrayOutputStream(serverOutSize); serverLatch = new CountDownLatch(1); + serverException.set(null); sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -466,7 +481,18 @@ protected void initChannel(Channel ch) throws Exception { Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(serverConnection, encoder, new DefaultHttp2FrameReader()); Http2ConnectionHandler connectionHandler = new Http2ConnectionHandlerBuilder() - .frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener, 0)) + .frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener, 0) { + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream) throws Http2Exception { + try { + return super.onDataRead(ctx, streamId, data, padding, endOfStream); + } catch (Http2Exception e) { + serverException.set(e); + throw e; + } + } + }) .codec(decoder, encoder).build(); p.addLast(connectionHandler); serverChannelLatch.countDown(); @@ -521,6 +547,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc private void awaitServer() throws Exception { assertTrue(serverLatch.await(5, SECONDS)); serverOut.flush(); + Throwable cause = serverException.get(); + if (cause != null) { + throw new AssertionError("Server-side decompression error", cause); + } } private ChannelHandlerContext ctxClient() { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java index 35863d6c06e..97117dac95e 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java +++ 
b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java @@ -109,6 +109,59 @@ public void readHeaderFrameAndContinuationFrame() throws Http2Exception { } } + @Test + public void readHeaderFrameAndContinuationFrameExceedMax() throws Http2Exception { + frameReader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(true), 2); + final int streamId = 1; + + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(false)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo2", "bar2"), + new Http2Flags().endOfHeaders(false)); + + Http2Exception ex = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + assertEquals(Http2Error.ENHANCE_YOUR_CALM, ex.error()); + } finally { + input.release(); + } + } + + @Test + public void readHeaderFrameAndContinuationFrameDontExceedMax() throws Http2Exception { + frameReader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(true), 2); + final int streamId = 1; + + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(false)); + frameReader.readFrame(ctx, input, listener); + } finally { + input.release(); + } + } + @Test public void readUnknownFrame() throws Http2Exception { ByteBuf 
input = Unpooled.buffer(); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java index 04acf60d55f..23bc51786ce 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java @@ -29,6 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -66,7 +67,7 @@ protected void initChannel(SocketChannel ch) { } }); - ChannelFuture channelFuture = serverBootstrap.bind(0).sync(); + ChannelFuture channelFuture = serverBootstrap.bind(NetUtil.LOCALHOST, 0).sync(); final Bootstrap bootstrap = new Bootstrap() .group(eventLoopGroup) diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java index 0ded0e1ef38..16e4eefa799 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java @@ -53,7 +53,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** @@ -140,7 +140,7 @@ public void windowUpdateShouldChangeConnectionWindow() throws Http2Exception { assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); 
assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -151,7 +151,7 @@ public void windowUpdateShouldChangeStreamWindow() throws Http2Exception { assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -159,10 +159,10 @@ public void payloadSmallerThanWindowShouldBeWrittenImmediately() throws Http2Exc FakeFlowControlled data = new FakeFlowControlled(5); sendData(STREAM_A, data); data.assertNotWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); controller.writePendingBytes(); data.assertFullyWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -172,7 +172,7 @@ public void emptyPayloadShouldBeWrittenImmediately() throws Http2Exception { data.assertNotWritten(); controller.writePendingBytes(); data.assertFullyWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -238,7 +238,7 @@ public void stalledStreamShouldQueuePayloads() throws Http2Exception { sendData(STREAM_A, moreData); controller.writePendingBytes(); moreData.assertNotWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -260,7 +260,7 @@ public void queuedPayloadsReceiveErrorOnStreamClose() throws Http2Exception { connection.stream(STREAM_A).close(); data.assertError(Http2Error.STREAM_CLOSED); moreData.assertError(Http2Error.STREAM_CLOSED); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -748,7 +748,7 @@ public void execute() throws Throwable { verify(flowControlled, never()).writeComplete(); assertEquals(90, windowBefore - window(STREAM_A)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test diff 
--git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java index 4c48e2780dc..ca1b395e417 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -53,10 +53,13 @@ import java.util.concurrent.atomic.AtomicBoolean; import static io.netty.buffer.Unpooled.copiedBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeaderInternal; import static io.netty.handler.codec.http2.Http2Error.CANCEL; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise; @@ -78,7 +81,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** @@ -303,9 +306,11 @@ public void clientShouldSendClientPrefaceStringWhenActive() throws Exception { when(connection.isServer()).thenReturn(false); when(channel.isActive()).thenReturn(false); handler = newHandler(); + verify(ctx, never()).flush(); when(channel.isActive()).thenReturn(true); handler.channelActive(ctx); verify(ctx).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); } @Test @@ -313,9 +318,29 @@ 
public void serverShouldNotSendClientPrefaceStringWhenActive() throws Exception when(connection.isServer()).thenReturn(true); when(channel.isActive()).thenReturn(false); handler = newHandler(); + verify(ctx, never()).flush(); when(channel.isActive()).thenReturn(true); handler.channelActive(ctx); verify(ctx, never()).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); + } + + @Test + public void clientShouldSendClientPrefaceStringWhenAddedAfterActive() throws Exception { + when(connection.isServer()).thenReturn(false); + when(channel.isActive()).thenReturn(true); + handler = newHandler(); + verify(ctx).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); + } + + @Test + public void serverShouldNotSendClientPrefaceStringWhenAddedAfterActive() throws Exception { + when(connection.isServer()).thenReturn(true); + when(channel.isActive()).thenReturn(true); + handler = newHandler(); + verify(ctx, never()).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); } @Test @@ -329,6 +354,20 @@ public void serverReceivingInvalidClientPrefaceStringShouldHandleException() thr assertEquals(0, captor.getValue().refCnt()); } + @Test + public void serverReceivingInvalidClientSettingsAfterPrefaceShouldHandleException() throws Exception { + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, 0, SETTINGS, new Http2Flags().ack(true), 0); + + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + handler.channelRead(ctx, Unpooled.wrappedBuffer(connectionPrefaceBuf(), buf)); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + assertEquals(0, captor.getValue().refCnt()); + } + @Test public void serverReceivingHttp1ClientPrefaceStringShouldIncludePreface() throws Exception { when(connection.isServer()).thenReturn(true); @@ -716,7 +755,8 @@ public void canCloseStreamWithVoidPromise() 
throws Exception { @Test public void channelReadCompleteTriggersFlush() throws Exception { - handler = newHandler(); + // Create the handler in a way that it will flush the preface by itself + handler = newHandler(false); handler.channelReadComplete(ctx); verify(ctx, times(1)).flush(); } @@ -748,7 +788,7 @@ public void clientChannelClosedDoesNotSendGoAwayBeforePreface() throws Exception handler = newHandler(); when(channel.isActive()).thenReturn(true); handler.close(ctx, promise); - verifyZeroInteractions(frameWriter); + verifyNoInteractions(frameWriter); } @Test diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java index c16eba07673..485499dbb72 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java @@ -1016,4 +1016,23 @@ private void assertInboundStreamFrame(int expectedId, Http2StreamFrame streamFra private ChannelHandlerContext eqFrameCodecCtx() { return eq(frameCodec.ctx); } + + @Test + public void invalidPayloadLength() throws Exception { + frameInboundWriter.writeInboundSettings(new Http2Settings()); + channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{ + 0, 0, 4, // length + 0, // type: DATA + 9, // flags: PADDED, END_STREAM + 1, 0, 0, 0, // stream id + 4, // pad length + 0, 0, 0 // not enough space for padding + })); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java index 9dec606d712..0b0c34385f1 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java +++ 
b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java @@ -109,7 +109,9 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - queue.add(msg); + synchronized (queue) { + queue.add(msg); + } } @Override @@ -119,7 +121,9 @@ public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - queue.add(new UserEvent(evt)); + synchronized (queue) { + queue.add(new UserEvent(evt)); + } } @Override @@ -142,11 +146,13 @@ public void checkException() throws Exception { @SuppressWarnings("unchecked") public T readInbound() { - for (int i = 0; i < queue.size(); i++) { - Object o = queue.get(i); - if (!(o instanceof UserEvent)) { - queue.remove(i); - return (T) o; + synchronized (queue) { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (!(o instanceof UserEvent)) { + queue.remove(i); + return (T) o; + } } } @@ -163,11 +169,13 @@ public T blockingReadInbound() { @SuppressWarnings("unchecked") public T readUserEvent() { - for (int i = 0; i < queue.size(); i++) { - Object o = queue.get(i); - if (o instanceof UserEvent) { - queue.remove(i); - return (T) ((UserEvent) o).evt; + synchronized (queue) { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (o instanceof UserEvent) { + queue.remove(i); + return (T) ((UserEvent) o).evt; + } } } @@ -179,14 +187,16 @@ public T readUserEvent() { */ @SuppressWarnings("unchecked") public T readInboundMessageOrUserEvent() { - if (queue.isEmpty()) { - return null; - } - Object o = queue.remove(0); - if (o instanceof UserEvent) { - return (T) ((UserEvent) o).evt; + synchronized (queue) { + if (queue.isEmpty()) { + return null; + } + Object o = queue.remove(0); + if (o instanceof UserEvent) { + return (T) ((UserEvent) o).evt; + } + return (T) 
o; } - return (T) o; } public void writeOutbound(Object... msgs) throws Exception { diff --git a/codec-memcache/pom.xml b/codec-memcache/pom.xml index f5b785a1331..65076403a6a 100644 --- a/codec-memcache/pom.xml +++ b/codec-memcache/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-memcache diff --git a/codec-mqtt/pom.xml b/codec-mqtt/pom.xml index 1a6aaabe7a1..f61acde36b7 100644 --- a/codec-mqtt/pom.xml +++ b/codec-mqtt/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-mqtt diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java index 9a601ead1c2..eec9999c2c7 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java @@ -287,7 +287,7 @@ private static ByteBuf encodeSubscribeMessage( // Payload for (MqttTopicSubscription topic : payload.topicSubscriptions()) { - writeUnsafeUTF8String(buf, topic.topicName()); + writeEagerUTF8String(buf, topic.topicName()); if (mqttVersion == MqttVersion.MQTT_3_1_1 || mqttVersion == MqttVersion.MQTT_3_1) { buf.writeByte(topic.qualityOfService().value()); } else { @@ -347,7 +347,7 @@ private static ByteBuf encodeUnsubscribeMessage( // Payload for (String topicName : payload.topics()) { - writeUnsafeUTF8String(buf, topicName); + writeEagerUTF8String(buf, topicName); } return buf; @@ -720,15 +720,6 @@ private static void writeEagerUTF8String(ByteBuf buf, String s) { buf.setShort(writerIndex, utf8Length); } - private static void writeUnsafeUTF8String(ByteBuf buf, String s) { - final int writerIndex = buf.writerIndex(); - final int startUtf8String = writerIndex + 2; - // no need to reserve any capacity here, already done earlier: that's why is Unsafe - buf.writerIndex(startUtf8String); - final int utf8Length = s != null? 
reserveAndWriteUtf8(buf, s, 0) : 0; - buf.setShort(writerIndex, utf8Length); - } - private static int getVariableLengthInt(int num) { int count = 0; do { diff --git a/codec-redis/pom.xml b/codec-redis/pom.xml index 38b73c9326b..ced0039b057 100644 --- a/codec-redis/pom.xml +++ b/codec-redis/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-redis diff --git a/codec-smtp/pom.xml b/codec-smtp/pom.xml index b0896da7ebc..0f6447970ff 100644 --- a/codec-smtp/pom.xml +++ b/codec-smtp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-smtp diff --git a/codec-socks/pom.xml b/codec-socks/pom.xml index 2397bb2e3fa..1bd1f657607 100644 --- a/codec-socks/pom.xml +++ b/codec-socks/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-socks diff --git a/codec-stomp/pom.xml b/codec-stomp/pom.xml index 285373a69b3..ad6232207fb 100644 --- a/codec-stomp/pom.xml +++ b/codec-stomp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-stomp diff --git a/codec-xml/pom.xml b/codec-xml/pom.xml index 8b7b69babfb..95b20be435a 100644 --- a/codec-xml/pom.xml +++ b/codec-xml/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec-xml diff --git a/codec/pom.xml b/codec/pom.xml index d121be92666..57dc222411b 100644 --- a/codec/pom.xml +++ b/codec/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-codec @@ -74,7 +74,7 @@ true - org.lz4 + at.yawk.lz4 lz4-java true diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 3c341543438..ba28386b4a3 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -27,15 +27,17 @@ import io.netty.util.internal.ObjectUtil; import 
io.netty.util.internal.StringUtil; +import java.util.ArrayDeque; import java.util.List; +import java.util.Queue; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.util.internal.ObjectUtil.checkPositive; -import static java.lang.Integer.MAX_VALUE; /** - * {@link ChannelInboundHandlerAdapter} which decodes bytes in a stream-like fashion from one {@link ByteBuf} to an - * other Message type. - * + * {@link ChannelInboundHandlerAdapter} which decodes bytes in a stream-like fashion from one {@link ByteBuf} to + * another Message type. + *

* For example here is an implementation which reads all readable bytes from * the input {@link ByteBuf} and create a new {@link ByteBuf}. * @@ -66,7 +68,7 @@ * is not always the case. Use in.getInt(in.readerIndex()) instead. *

Pitfalls

*

- * Be aware that sub-classes of {@link ByteToMessageDecoder} MUST NOT + * Be aware that subclasses of {@link ByteToMessageDecoder} MUST NOT * annotated with {@link @Sharable}. *

* Some methods such as {@link ByteBuf#readBytes(int)} will cause a memory leak if the returned buffer @@ -162,6 +164,8 @@ public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) private static final byte STATE_CALLING_CHILD_DECODE = 1; private static final byte STATE_HANDLER_REMOVED_PENDING = 2; + // Used to guard the inputs for reentrant channelRead calls + private Queue inputMessages; ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; @@ -279,49 +283,60 @@ public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof ByteBuf) { - selfFiredChannelRead = true; - CodecOutputList out = CodecOutputList.newInstance(); - try { - first = cumulation == null; - cumulation = cumulator.cumulate(ctx.alloc(), - first ? Unpooled.EMPTY_BUFFER : cumulation, (ByteBuf) msg); - callDecode(ctx, cumulation, out); - } catch (DecoderException e) { - throw e; - } catch (Exception e) { - throw new DecoderException(e); - } finally { - try { - if (cumulation != null && !cumulation.isReadable()) { - numReads = 0; + public void channelRead(ChannelHandlerContext ctx, Object input) throws Exception { + if (decodeState == STATE_INIT) { + do { + if (input instanceof ByteBuf) { + selfFiredChannelRead = true; + CodecOutputList out = CodecOutputList.newInstance(); + try { + first = cumulation == null; + cumulation = cumulator.cumulate(ctx.alloc(), + first ? 
EMPTY_BUFFER : cumulation, (ByteBuf) input); + callDecode(ctx, cumulation, out); + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { try { - cumulation.release(); - } catch (IllegalReferenceCountException e) { - //noinspection ThrowFromFinallyBlock - throw new IllegalReferenceCountException( - getClass().getSimpleName() + "#decode() might have released its input buffer, " + - "or passed it down the pipeline without a retain() call, " + - "which is not allowed.", e); + if (cumulation != null && !cumulation.isReadable()) { + numReads = 0; + try { + cumulation.release(); + } catch (IllegalReferenceCountException e) { + //noinspection ThrowFromFinallyBlock + throw new IllegalReferenceCountException( + getClass().getSimpleName() + + "#decode() might have released its input buffer, " + + "or passed it down the pipeline without a retain() call, " + + "which is not allowed.", e); + } + cumulation = null; + } else if (++numReads >= discardAfterReads) { + // We did enough reads already try to discard some bytes, so we not risk to see a OOME. + // See https://github.com/netty/netty/issues/4275 + numReads = 0; + discardSomeReadBytes(); + } + + int size = out.size(); + firedChannelRead |= out.insertSinceRecycled(); + fireChannelRead(ctx, out, size); + } finally { + out.recycle(); } - cumulation = null; - } else if (++numReads >= discardAfterReads) { - // We did enough reads already try to discard some bytes, so we not risk to see a OOME. - // See https://github.com/netty/netty/issues/4275 - numReads = 0; - discardSomeReadBytes(); } - - int size = out.size(); - firedChannelRead |= out.insertSinceRecycled(); - fireChannelRead(ctx, out, size); - } finally { - out.recycle(); + } else { + ctx.fireChannelRead(input); } - } + } while (inputMessages != null && (input = inputMessages.poll()) != null); } else { - ctx.fireChannelRead(msg); + // Reentrant call. Bail out here and let original call process our message. 
+ if (inputMessages == null) { + inputMessages = new ArrayDeque(2); + } + inputMessages.offer(input); } } @@ -529,12 +544,14 @@ final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in, try { decode(ctx, in, out); } finally { - boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; - decodeState = STATE_INIT; - if (removePending) { - fireChannelRead(ctx, out, out.size()); - out.clear(); - handlerRemoved(ctx); + if (inputMessages == null || inputMessages.isEmpty()) { + boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; + decodeState = STATE_INIT; + if (removePending) { + fireChannelRead(ctx, out, out.size()); + out.clear(); + handlerRemoved(ctx); + } } } } @@ -558,7 +575,7 @@ static ByteBuf expandCumulation(ByteBufAllocator alloc, ByteBuf oldCumulation, B int oldBytes = oldCumulation.readableBytes(); int newBytes = in.readableBytes(); int totalBytes = oldBytes + newBytes; - ByteBuf newCumulation = alloc.buffer(alloc.calculateNewCapacity(totalBytes, MAX_VALUE)); + ByteBuf newCumulation = alloc.buffer(alloc.calculateNewCapacity(totalBytes, Integer.MAX_VALUE)); ByteBuf toRelease = newCumulation; try { // This avoids redundant checks and stack depth compared to calling writeBytes(...) 
diff --git a/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java index 4a38db51be3..8808a3ebf85 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java @@ -32,6 +32,8 @@ */ public final class BrotliDecoder extends ByteToMessageDecoder { + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private enum State { DONE, NEEDS_MORE_INPUT, ERROR } @@ -48,6 +50,7 @@ private enum State { private DecoderJNI.Wrapper decoder; private boolean destroyed; private boolean needsRead; + private ByteBuf accumBuffer; /** * Creates a new BrotliDecoder with a default 8kB input buffer @@ -67,10 +70,25 @@ public BrotliDecoder(int inputBufferSize) { private void forwardOutput(ChannelHandlerContext ctx) { ByteBuffer nativeBuffer = decoder.pull(); // nativeBuffer actually wraps brotli's internal buffer so we need to copy its content - ByteBuf copy = ctx.alloc().buffer(nativeBuffer.remaining()); - copy.writeBytes(nativeBuffer); + int remaining = nativeBuffer.remaining(); + if (accumBuffer == null) { + accumBuffer = ctx.alloc().buffer(remaining); + } + accumBuffer.writeBytes(nativeBuffer); needsRead = false; - ctx.fireChannelRead(copy); + if (accumBuffer.readableBytes() >= DEFAULT_MAX_FORWARD_BYTES) { + ctx.fireChannelRead(accumBuffer); + accumBuffer = null; + } + } + + private void flushAccumBuffer(ChannelHandlerContext ctx) { + if (accumBuffer != null && accumBuffer.isReadable()) { + ctx.fireChannelRead(accumBuffer); + } else if (accumBuffer != null) { + accumBuffer.release(); + } + accumBuffer = null; } private State decompress(ChannelHandlerContext ctx, ByteBuf input) { @@ -145,6 +163,8 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } catch (Exception e) { destroy(); throw e; + } finally { + 
flushAccumBuffer(ctx); } } diff --git a/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java b/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java index d2a06f95287..833b2f8f7cc 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java @@ -16,11 +16,15 @@ package io.netty.handler.codec.compression; import io.netty.buffer.ByteBuf; +import io.netty.util.internal.SystemPropertyUtil; import java.nio.ByteBuffer; final class CompressionUtil { + static final int DEFAULT_MAX_FORWARD_BYTES = SystemPropertyUtil.getInt( + "io.netty.compression.defaultMaxForwardBytes", 64 * 1024); + private CompressionUtil() { } static void checkChecksum(ByteBufChecksum checksum, ByteBuf uncompressed, int currentChecksum) { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java index 51bdd670aa8..81f259f0a0d 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java @@ -28,6 +28,8 @@ public class JZlibDecoder extends ZlibDecoder { private final Inflater z = new Inflater(); private byte[] dictionary; + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private final int maxForwardBytes; private boolean needsRead; private volatile boolean finished; @@ -78,6 +80,7 @@ public JZlibDecoder(ZlibWrapper wrapper) { */ public JZlibDecoder(ZlibWrapper wrapper, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? 
maxAllocation : DEFAULT_MAX_FORWARD_BYTES; ObjectUtil.checkNotNull(wrapper, "wrapper"); @@ -113,6 +116,7 @@ public JZlibDecoder(byte[] dictionary) { */ public JZlibDecoder(byte[] dictionary, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? maxAllocation : DEFAULT_MAX_FORWARD_BYTES; this.dictionary = ObjectUtil.checkNotNull(dictionary, "dictionary"); int resultCode; resultCode = z.inflateInit(JZlib.W_ZLIB); @@ -174,7 +178,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t int outputLength = z.next_out_index - oldNextOutIndex; if (outputLength > 0) { decompressed.writerIndex(decompressed.writerIndex() + outputLength); - if (maxAllocation == 0) { + if (maxAllocation == 0 && decompressed.readableBytes() >= maxForwardBytes) { // If we don't limit the maximum allocations we should just // forward the buffer directly. ByteBuf buffer = decompressed; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java index 0ef03a217b7..ac2b75c8077 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java @@ -59,6 +59,9 @@ private enum GzipState { private int xlen = -1; private boolean needsRead; + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private final int maxForwardBytes; + private volatile boolean finished; private boolean decideZlibOrNone; @@ -161,6 +164,7 @@ public JdkZlibDecoder(boolean decompressConcatenated, int maxAllocation) { private JdkZlibDecoder(ZlibWrapper wrapper, byte[] dictionary, boolean decompressConcatenated, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? 
maxAllocation : DEFAULT_MAX_FORWARD_BYTES; ObjectUtil.checkNotNull(wrapper, "wrapper"); @@ -265,9 +269,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t if (crc != null) { crc.update(outArray, outIndex, outputLength); } - if (maxAllocation == 0) { - // If we don't limit the maximum allocations we should just - // forward the buffer directly. + if (maxAllocation == 0 && decompressed.readableBytes() >= maxForwardBytes) { + // Forward the buffer once it exceeds the threshold to bound memory + // while avoiding excessive fireChannelRead calls. ByteBuf buffer = decompressed; decompressed = null; needsRead = false; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java index 05a35b14b92..3be2dcf0939 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java @@ -202,7 +202,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t boolean success = false; try { - decoder.decodeChunk(inputArray, inPos, outputArray, outPos, outPos + originalLength); + decoder.decodeChunk( + inputArray, inPos, inPos + chunkLength, + outputArray, outPos, outPos + originalLength); if (uncompressed.hasArray()) { uncompressed.writerIndex(uncompressed.writerIndex() + originalLength); } else { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java b/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java index 38793a97e6f..1397e123080 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java @@ -71,7 +71,7 @@ public static BrotliOptions brotli(int quality, int window, BrotliMode mode) { /** * Default implementation of {@link ZstdOptions} 
with{compressionLevel(int)} set to * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, - * {@link ZstdConstants#MAX_BLOCK_SIZE} + * {@link ZstdConstants#DEFAULT_MAX_ENCODE_SIZE} */ public static ZstdOptions zstd() { return ZstdOptions.DEFAULT; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java index 111372c3ede..b9a5aca6514 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java @@ -35,9 +35,9 @@ final class ZstdConstants { static final int MAX_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); /** - * Max block size + * Max encode size */ - static final int MAX_BLOCK_SIZE = 1 << (DEFAULT_COMPRESSION_LEVEL + 7) + 0x0F; // 32 M + static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE; /** * Default block size */ diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java index ef0bf1371d8..e63c04e19ad 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java @@ -41,7 +41,10 @@ public final class ZstdDecoder extends ByteToMessageDecoder { } } + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private final int maximumAllocationSize; + private final int maxForwardBytes; private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream(); private ZstdInputStreamNoFinalizer zstdIs; @@ -62,6 +65,7 @@ public ZstdDecoder() { public ZstdDecoder(int maximumAllocationSize) { this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize"); + this.maxForwardBytes = maximumAllocationSize > 0 ? 
maximumAllocationSize : DEFAULT_MAX_FORWARD_BYTES; } @Override @@ -101,13 +105,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } do { w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes()); - } while (w != -1 && outBuffer.isWritable()); - if (outBuffer.isReadable()) { + } while (w > 0 && outBuffer.isWritable()); + if (!outBuffer.isWritable() || outBuffer.readableBytes() >= maxForwardBytes) { needsRead = false; ctx.fireChannelRead(outBuffer); outBuffer = null; } - } while (w != -1); + } while (w > 0); + if (outBuffer != null && outBuffer.isReadable()) { + needsRead = false; + ctx.fireChannelRead(outBuffer); + outBuffer = null; + } } finally { if (outBuffer != null) { outBuffer.release(); diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java index 7ece3c2a643..36e8f364f75 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java @@ -28,7 +28,7 @@ import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; /** * Compresses a {@link ByteBuf} using the Zstandard algorithm. 
@@ -56,7 +56,7 @@ public final class ZstdEncoder extends MessageToByteEncoder { * please use {@link ZstdEncoder(int,int)} constructor */ public ZstdEncoder() { - this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE); } /** @@ -65,7 +65,7 @@ public ZstdEncoder() { * specifies the level of the compression */ public ZstdEncoder(int compressionLevel) { - this(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + this(compressionLevel, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE); } /** @@ -113,7 +113,9 @@ protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean while (remaining > 0) { int curSize = Math.min(blockSize, remaining); remaining -= curSize; - bufferSize += Zstd.compressBound(curSize); + // calculate the max compressed size with Zstd.compressBound since + // it returns the maximum size of the compressed data + bufferSize = Math.max(bufferSize, Zstd.compressBound(curSize)); } if (bufferSize > maxEncodeSize || 0 > bufferSize) { @@ -141,6 +143,11 @@ protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) { flushBufferedData(out); } } + // return the remaining data in the buffer + // when buffer size is smaller than the block size + if (buffer.isReadable()) { + flushBufferedData(out); + } } private void flushBufferedData(ByteBuf out) { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java index 8b6ce3c5550..583151aa040 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java @@ -21,7 +21,7 @@ import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; import static 
io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; /** * {@link ZstdOptions} holds compressionLevel for @@ -36,9 +36,10 @@ public class ZstdOptions implements CompressionOptions { /** * Default implementation of {@link ZstdOptions} with{compressionLevel(int)} set to * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, - * {@link ZstdConstants#MAX_BLOCK_SIZE} + * {@link ZstdConstants#DEFAULT_MAX_ENCODE_SIZE} */ - static final ZstdOptions DEFAULT = new ZstdOptions(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + static final ZstdOptions DEFAULT = new ZstdOptions(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, + DEFAULT_MAX_ENCODE_SIZE); /** * Create a new {@link ZstdOptions} diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index 84f8c755559..e7069a542e7 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -684,4 +684,80 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { assertEquals(0, buffer.refCnt(), "Buffer should be released"); assertFalse(channel.finish()); } + + @Test + void reentrantReadSafety() throws Exception { + final EmbeddedChannel channel = new EmbeddedChannel(); + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + int reentrancy; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + reentrancy++; + if (reentrancy == 1) { + ByteBuf buf2 = channel.alloc().buffer(); + buf2.writeLong(42); // Adding 8 bytes. 
+ assertFalse(channel.writeInbound(buf2)); // Reentrant call back into ByteToMessageDecoder + ctx.read(); + } + int bytes = in.readableBytes(); + out.add(bytes); + in.skipBytes(bytes); + } + }; + channel.pipeline().addLast(decoder); + ByteBuf buf1 = channel.alloc().buffer(); + buf1.writeInt(42); // Adding 4 bytes. + assertTrue(channel.writeInbound(buf1)); + Integer first = channel.readInbound(); + Integer second = channel.readInbound(); + assertEquals(4, first); + assertEquals(8, second); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + void reentrantReadThenRemoveSafety() throws Exception { + final EmbeddedChannel channel = new EmbeddedChannel(); + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + boolean removed; + int reentrancy; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + reentrancy++; + if (reentrancy == 1) { + ByteBuf buf2 = channel.alloc().buffer(); + buf2.writeLong(42); // Adding 8 bytes. + assertFalse(channel.writeInbound(buf2)); // Reentrant call back into ByteToMessageDecoder + ByteBuf buf3 = channel.alloc().buffer(); + buf3.writeShort(42); // Adding 2 bytes. + assertFalse(channel.writeInbound(buf3)); // Reentrant call back into ByteToMessageDecoder + ctx.read(); + } else if (reentrancy == 2) { + ctx.pipeline().remove(this); + } + int bytes = in.readableBytes(); + out.add(bytes); + in.skipBytes(bytes); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + removed = true; + } + }; + channel.pipeline().addLast(decoder); + ByteBuf buf1 = channel.alloc().buffer(); + buf1.writeInt(42); // Adding 4 bytes. 
+ assertTrue(channel.writeInbound(buf1)); + Integer first = channel.readInbound(); + Integer second = channel.readInbound(); + Integer third = channel.readInbound(); + assertEquals(4, first); + assertEquals(8, second); + assertEquals(2, third); + assertFalse(channel.finishAndReleaseAll()); + } } diff --git a/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java b/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java index 296e3dac7ea..250ed34587a 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java @@ -22,7 +22,9 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; @@ -30,9 +32,10 @@ import java.io.InputStream; - import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; public class ZstdEncoderTest extends AbstractEncoderTest { @@ -46,6 +49,14 @@ public void setup() { when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT); } + public static ByteBuf[] hugeData() { + final byte[] bytesHuge = new byte[36 * 1024 * 1024]; + ByteBuf heap = Unpooled.wrappedBuffer(bytesHuge); + ByteBuf direct = Unpooled.directBuffer(bytesHuge.length); + direct.writeBytes(bytesHuge); + return new ByteBuf[] {heap, direct}; + } + @Override public EmbeddedChannel createChannel() { return new EmbeddedChannel(new ZstdEncoder()); @@ -54,6 +65,16 @@ public EmbeddedChannel createChannel() { 
@ParameterizedTest @MethodSource("largeData") public void testCompressionOfLargeBatchedFlow(final ByteBuf data) throws Exception { + testCompressionOfLargeDataBatchedFlow(data); + } + + @ParameterizedTest + @MethodSource("hugeData") + public void testCompressionOfHugeBatchedFlow(final ByteBuf data) throws Exception { + testCompressionOfLargeDataBatchedFlow(data); + } + + private void testCompressionOfLargeDataBatchedFlow(final ByteBuf data) throws Exception { final int dataLength = data.readableBytes(); int written = 0; @@ -78,6 +99,18 @@ public void testCompressionOfSmallBatchedFlow(final ByteBuf data) throws Excepti testCompressionOfBatchedFlow(data); } + @Test + public void testCompressionOfTinyData() throws Exception { + ByteBuf data = Unpooled.copiedBuffer("Hello, World", CharsetUtil.UTF_8); + assertTrue(channel.writeOutbound(data)); + assertTrue(channel.finish()); + + ByteBuf out = channel.readOutbound(); + assertThat(out.readableBytes()).isPositive(); + out.release(); + assertNull(channel.readOutbound()); + } + @Override protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception { InputStream is = new ByteBufInputStream(compressed, true); diff --git a/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java b/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java index 575dcddb993..620875ac9ab 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java @@ -17,7 +17,7 @@ import io.netty.channel.embedded.EmbeddedChannel; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; public class ZstdIntegrationTest extends AbstractIntegrationTest { @@ -25,7 +25,7 @@ public class ZstdIntegrationTest extends AbstractIntegrationTest { @Override protected 
EmbeddedChannel createEncoder() { - return new EmbeddedChannel(new ZstdEncoder(BLOCK_SIZE, MAX_BLOCK_SIZE)); + return new EmbeddedChannel(new ZstdEncoder(BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE)); } @Override diff --git a/common/pom.xml b/common/pom.xml index 81587e4e9f0..120ed12430d 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-common diff --git a/common/src/main/java/io/netty/util/DefaultAttributeMap.java b/common/src/main/java/io/netty/util/DefaultAttributeMap.java index a39bb5b996f..39a2a28df09 100644 --- a/common/src/main/java/io/netty/util/DefaultAttributeMap.java +++ b/common/src/main/java/io/netty/util/DefaultAttributeMap.java @@ -68,11 +68,12 @@ private static void orderedCopyOnInsert(DefaultAttribute[] sortedSrc, int srcLen int i; for (i = srcLength - 1; i >= 0; i--) { DefaultAttribute attribute = sortedSrc[i]; - assert attribute.key.id() != id; - if (attribute.key.id() < id) { + int attributeKeyId = attribute.key.id(); + assert attributeKeyId != id; + if (attributeKeyId < id) { break; } - copy[i + 1] = sortedSrc[i]; + copy[i + 1] = attribute; } copy[i + 1] = toInsert; final int toCopy = i + 1; @@ -153,7 +154,6 @@ private void removeAttributeIfMatch(AttributeKey key, DefaultAttribute } } - @SuppressWarnings("serial") private static final class DefaultAttribute extends AtomicReference implements Attribute { private static final AtomicReferenceFieldUpdater MAP_UPDATER = diff --git a/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java b/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java new file mode 100644 index 00000000000..b6a36770704 --- /dev/null +++ b/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java @@ -0,0 +1,1550 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not 
use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Written by Doug Lea with assistance from members of JCP JSR-166 + * Expert Group and released to the public domain, as explained at + * https://creativecommons.org/publicdomain/zero/1.0/ + * + * With substantial modifications by The Netty Project team. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.LongCounter; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.ThreadLocalRandom; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A scalable concurrent multimap implementation. + * The map is sorted according to the natural ordering of its {@code int} keys. + * + *

This class implements a concurrent variant of SkipLists + * providing expected average log(n) time cost for the + * {@code containsKey}, {@code get}, {@code put} and + * {@code remove} operations and their variants. Insertion, removal, + * update, and access operations safely execute concurrently by + * multiple threads. + * + *

This class is a multimap, which means the same key can be associated with + * multiple values. Each such instance will be represented by a separate + * {@code IntEntry}. There is no defined ordering for the values mapped to + * the same key. + * + *

As a multimap, certain atomic operations like {@code putIfPresent}, + * {@code compute}, or {@code computeIfPresent}, cannot be supported. + * Likewise, some get-like operations cannot be supported. + * + *

Iterators and spliterators are + * weakly consistent. + * + *

All {@code IntEntry} pairs returned by methods in this class + * represent snapshots of mappings at the time they were + * produced. They do not support the {@code Entry.setValue} + * method. (Note however that it is possible to change mappings in the + * associated map using {@code put}, {@code putIfAbsent}, or + * {@code replace}, depending on exactly which effect you need.) + * + *

Beware that bulk operations {@code putAll}, {@code equals}, + * {@code toArray}, {@code containsValue}, and {@code clear} are + * not guaranteed to be performed atomically. For example, an + * iterator operating concurrently with a {@code putAll} operation + * might view only some of the added elements. + * + *

This class does not permit the use of {@code null} values + * because some null return values cannot be reliably distinguished from + * the absence of elements. + * + * @param the type of mapped values + */ +public class ConcurrentSkipListIntObjMultimap implements Iterable> { + /* + * This class implements a tree-like two-dimensionally linked skip + * list in which the index levels are represented in separate + * nodes from the base nodes holding data. There are two reasons + * for taking this approach instead of the usual array-based + * structure: 1) Array based implementations seem to encounter + * more complexity and overhead 2) We can use cheaper algorithms + * for the heavily-traversed index lists than can be used for the + * base lists. Here's a picture of some of the basics for a + * possible list with 2 levels of index: + * + * Head nodes Index nodes + * +-+ right +-+ +-+ + * |2|---------------->| |--------------------->| |->null + * +-+ +-+ +-+ + * | down | | + * v v v + * +-+ +-+ +-+ +-+ +-+ +-+ + * |1|----------->| |->| |------>| |----------->| |------>| |->null + * +-+ +-+ +-+ +-+ +-+ +-+ + * v | | | | | + * Nodes next v v v v v + * +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ + * | |->|A|->|B|->|C|->|D|->|E|->|F|->|G|->|H|->|I|->|J|->|K|->null + * +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ + * + * The base lists use a variant of the HM linked ordered set + * algorithm. See Tim Harris, "A pragmatic implementation of + * non-blocking linked lists" + * https://www.cl.cam.ac.uk/~tlh20/publications.html and Maged + * Michael "High Performance Dynamic Lock-Free Hash Tables and + * List-Based Sets" + * https://www.research.ibm.com/people/m/michael/pubs.htm. The + * basic idea in these lists is to mark the "next" pointers of + * deleted nodes when deleting to avoid conflicts with concurrent + * insertions, and when traversing to keep track of triples + * (predecessor, node, successor) in order to detect when and how + * to unlink these deleted nodes. 
+ * + * Rather than using mark-bits to mark list deletions (which can + * be slow and space-intensive using AtomicMarkedReference), nodes + * use direct CAS'able next pointers. On deletion, instead of + * marking a pointer, they splice in another node that can be + * thought of as standing for a marked pointer (see method + * unlinkNode). Using plain nodes acts roughly like "boxed" + * implementations of marked pointers, but uses new nodes only + * when nodes are deleted, not for every link. This requires less + * space and supports faster traversal. Even if marked references + * were better supported by JVMs, traversal using this technique + * might still be faster because any search need only read ahead + * one more node than otherwise required (to check for trailing + * marker) rather than unmasking mark bits or whatever on each + * read. + * + * This approach maintains the essential property needed in the HM + * algorithm of changing the next-pointer of a deleted node so + * that any other CAS of it will fail, but implements the idea by + * changing the pointer to point to a different node (with + * otherwise illegal null fields), not by marking it. While it + * would be possible to further squeeze space by defining marker + * nodes not to have key/value fields, it isn't worth the extra + * type-testing overhead. The deletion markers are rarely + * encountered during traversal, are easily detected via null + * checks that are needed anyway, and are normally quickly garbage + * collected. (Note that this technique would not work well in + * systems without garbage collection.) + * + * In addition to using deletion markers, the lists also use + * nullness of value fields to indicate deletion, in a style + * similar to typical lazy-deletion schemes. If a node's value is + * null, then it is considered logically deleted and ignored even + * though it is still reachable. 
+ * + * Here's the sequence of events for a deletion of node n with + * predecessor b and successor f, initially: + * + * +------+ +------+ +------+ + * ... | b |------>| n |----->| f | ... + * +------+ +------+ +------+ + * + * 1. CAS n's value field from non-null to null. + * Traversals encountering a node with null value ignore it. + * However, ongoing insertions and deletions might still modify + * n's next pointer. + * + * 2. CAS n's next pointer to point to a new marker node. + * From this point on, no other nodes can be appended to n. + * which avoids deletion errors in CAS-based linked lists. + * + * +------+ +------+ +------+ +------+ + * ... | b |------>| n |----->|marker|------>| f | ... + * +------+ +------+ +------+ +------+ + * + * 3. CAS b's next pointer over both n and its marker. + * From this point on, no new traversals will encounter n, + * and it can eventually be GCed. + * +------+ +------+ + * ... | b |----------------------------------->| f | ... + * +------+ +------+ + * + * A failure at step 1 leads to simple retry due to a lost race + * with another operation. Steps 2-3 can fail because some other + * thread noticed during a traversal a node with null value and + * helped out by marking and/or unlinking. This helping-out + * ensures that no thread can become stuck waiting for progress of + * the deleting thread. + * + * Skip lists add indexing to this scheme, so that the base-level + * traversals start close to the locations being found, inserted + * or deleted -- usually base level traversals only traverse a few + * nodes. This doesn't change the basic algorithm except for the + * need to make sure base traversals start at predecessors (here, + * b) that are not (structurally) deleted, otherwise retrying + * after processing the deletion. + * + * Index levels are maintained using CAS to link and unlink + * successors ("right" fields). Races are allowed in index-list + * operations that can (rarely) fail to link in a new index node. 
+ * (We can't do this of course for data nodes.) However, even + * when this happens, the index lists correctly guide search. + * This can impact performance, but since skip lists are + * probabilistic anyway, the net result is that under contention, + * the effective "p" value may be lower than its nominal value. + * + * Index insertion and deletion sometimes require a separate + * traversal pass occurring after the base-level action, to add or + * remove index nodes. This adds to single-threaded overhead, but + * improves contended multithreaded performance by narrowing + * interference windows, and allows deletion to ensure that all + * index nodes will be made unreachable upon return from a public + * remove operation, thus avoiding unwanted garbage retention. + * + * Indexing uses skip list parameters that maintain good search + * performance while using sparser-than-usual indices: The + * hardwired parameters k=1, p=0.5 (see method doPut) mean that + * about one-quarter of the nodes have indices. Of those that do, + * half have one level, a quarter have two, and so on (see Pugh's + * Skip List Cookbook, sec 3.4), up to a maximum of 62 levels + * (appropriate for up to 2^63 elements). The expected total + * space requirement for a map is slightly less than for the + * current implementation of java.util.TreeMap. + * + * Changing the level of the index (i.e, the height of the + * tree-like structure) also uses CAS. Creation of an index with + * height greater than the current level adds a level to the head + * index by CAS'ing on a new top-most head. To maintain good + * performance after a lot of removals, deletion methods + * heuristically try to reduce the height if the topmost levels + * appear to be empty. This may encounter races in which it is + * possible (but rare) to reduce and "lose" a level just as it is + * about to contain an index (that will then never be + * encountered). 
This does no structural harm, and in practice + * appears to be a better option than allowing unrestrained growth + * of levels. + * + * This class provides concurrent-reader-style memory consistency, + * ensuring that read-only methods report status and/or values no + * staler than those holding at method entry. This is done by + * performing all publication and structural updates using + * (volatile) CAS, placing an acquireFence in a few access + * methods, and ensuring that linked objects are transitively + * acquired via dependent reads (normally once) unless performing + * a volatile-mode CAS operation (that also acts as an acquire and + * release). This form of fence-hoisting is similar to RCU and + * related techniques (see McKenney's online book + * https://www.kernel.org/pub/linux/kernel/people/paulmck/perfbook/perfbook.html) + * It minimizes overhead that may otherwise occur when using so + * many volatile-mode reads. Using explicit acquireFences is + * logistically easier than targeting particular fields to be read + * in acquire mode: fences are just hoisted up as far as possible, + * to the entry points or loop headers of a few methods. A + * potential disadvantage is that these few remaining fences are + * not easily optimized away by compilers under exclusively + * single-thread use. It requires some care to avoid volatile + * mode reads of other fields. (Note that the memory semantics of + * a reference dependently read in plain mode exactly once are + * equivalent to those for atomic opaque mode.) Iterators and + * other traversals encounter each node and value exactly once. + * Other operations locate an element (or position to insert an + * element) via a sequence of dereferences. This search is broken + * into two parts. Method findPredecessor (and its specialized + * embeddings) searches index nodes only, returning a base-level + * predecessor of the key. 
Callers carry out the base-level + * search, restarting if encountering a marker preventing link + * modification. In some cases, it is possible to encounter a + * node multiple times while descending levels. For mutative + * operations, the reported value is validated using CAS (else + * retrying), preserving linearizability with respect to each + * other. Others may return any (non-null) value holding in the + * course of the method call. (Search-based methods also include + * some useless-looking explicit null checks designed to allow + * more fields to be nulled out upon removal, to reduce floating + * garbage, but which is not currently done, pending discovery of + * a way to do this with less impact on other operations.) + * + * To produce random values without interference across threads, + * we use within-JDK thread local random support (via the + * "secondary seed", to avoid interference with user-level + * ThreadLocalRandom.) + * + * For explanation of algorithms sharing at least a couple of + * features with this one, see Mikhail Fomitchev's thesis + * (https://www.cs.yorku.ca/~mikhail/), Keir Fraser's thesis + * (https://www.cl.cam.ac.uk/users/kaf24/), and Hakan Sundell's + * thesis (https://www.cs.chalmers.se/~phs/). + * + * Notation guide for local variables + * Node: b, n, f, p for predecessor, node, successor, aux + * Index: q, r, d for index node, right, down. + * Head: h + * Keys: k, key + * Values: v, value + * Comparisons: c + */ + + /** No-key sentinel value */ + private final int noKey; + /** Lazily initialized topmost index of the skiplist. */ + private volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Index head; + /** Element count */ + private final LongCounter adder; + + /** + * Nodes hold keys and values, and are singly linked in sorted + * order, possibly with some intervening marker nodes. The list is + * headed by a header node accessible as head.node. Headers and + * marker nodes have null keys. 
The val field (but currently not + * the key field) is nulled out upon deletion. + */ + static final class Node { + final int key; // currently, never detached + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ V val; + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Node next; + Node(int key, V value, Node next) { + this.key = key; + val = value; + this.next = next; + } + } + + /** + * Index nodes represent the levels of the skip list. + */ + static final class Index { + final Node node; // currently, never detached + final Index down; + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Index right; + Index(Node node, Index down, Index right) { + this.node = node; + this.down = down; + this.right = right; + } + } + + /** + * The multimap entry type with primitive {@code int} keys. + */ + public static final class IntEntry implements Comparable> { + private final int key; + private final V value; + + public IntEntry(int key, V value) { + this.key = key; + this.value = value; + } + + /** + * Get the corresponding key. + */ + public int getKey() { + return key; + } + + /** + * Get the corresponding value. + */ + public V getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof IntEntry)) { + return false; + } + + IntEntry intEntry = (IntEntry) o; + return key == intEntry.key && (value == intEntry.value || (value != null && value.equals(intEntry.value))); + } + + @Override + public int hashCode() { + int result = key; + result = 31 * result + (value == null ? 0 : value.hashCode()); + return result; + } + + @Override + public String toString() { + return "IntEntry[" + key + " => " + value + ']'; + } + + @Override + public int compareTo(IntEntry o) { + return cpr(key, o.key); + } + } + + /* ---------------- Utilities -------------- */ + + /** + * Compares using comparator or natural ordering if null. 
+ * Called only by methods that have performed required type checks. + */ + static int cpr(int x, int y) { + return (x < y) ? -1 : x == y ? 0 : 1; + } + + /** + * Returns the header for base node list, or null if uninitialized + */ + final Node baseHead() { + Index h; + acquireFence(); + return (h = head) == null ? null : h.node; + } + + /** + * Tries to unlink deleted node n from predecessor b (if both + * exist), by first splicing in a marker if not already present. + * Upon return, node n is sure to be unlinked from b, possibly + * via the actions of some other thread. + * + * @param b if nonnull, predecessor + * @param n if nonnull, node known to be deleted + */ + static void unlinkNode(Node b, Node n, int noKey) { + if (b != null && n != null) { + Node f, p; + for (;;) { + if ((f = n.next) != null && f.key == noKey) { + p = f.next; // already marked + break; + } else if (NEXT.compareAndSet(n, f, + new Node(noKey, null, f))) { + p = f; // add marker + break; + } + } + NEXT.compareAndSet(b, n, p); + } + } + + /** + * Adds to element count, initializing adder if necessary + * + * @param c count to add + */ + private void addCount(long c) { + adder.add(c); + } + + /** + * Returns element count, initializing adder if necessary. + */ + final long getAdderCount() { + long c; + return (c = adder.value()) <= 0L ? 0L : c; // ignore transient negatives + } + + /* ---------------- Traversal -------------- */ + + /** + * Returns an index node with key strictly less than given key. + * Also unlinks indexes to deleted nodes found along the way. + * Callers rely on this side-effect of clearing indices to deleted + * nodes. 
+ * + * @param key if nonnull the key + * @return a predecessor node of key, or null if uninitialized or null key + */ + private Node findPredecessor(int key) { + Index q; + acquireFence(); + if ((q = head) == null || key == noKey) { + return null; + } else { + for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { // unlink index to deleted node + RIGHT.compareAndSet(q, r, r.right); + } else if (cpr(key, k) > 0) { + q = r; + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + return q.node; + } + } + } + } + + /** + * Returns node holding key or null if no such, clearing out any + * deleted nodes seen along the way. Repeatedly traverses at + * base-level looking for key starting at predecessor returned + * from findPredecessor, processing base-level deletions as + * encountered. Restarts occur, at traversal step encountering + * node n, if n's key field is null, indicating it is a marker, so + * its predecessor is deleted before continuing, which we help do + * by re-finding a valid predecessor. The traversal loops in + * doPut, doRemove, and findNear all include the same checks. + * + * @param key the key + * @return node holding key, or null if no such + */ + private Node findNode(int key) { + if (key == noKey) { + throw new IllegalArgumentException(); // don't postpone errors + } + Node b; + outer: while ((b = findPredecessor(key)) != null) { + for (;;) { + Node n; int k; int c; + if ((n = b.next) == null) { + break outer; // empty + } else if ((k = n.key) == noKey) { + break; // b is deleted + } else if (n.val == null) { + unlinkNode(b, n, noKey); // n is deleted + } else if ((c = cpr(key, k)) > 0) { + b = n; + } else if (c == 0) { + return n; + } else { + break outer; + } + } + } + return null; + } + + /** + * Gets value for key. 
Same idea as findNode, except skips over + * deletions and markers, and returns first encountered value to + * avoid possibly inconsistent rereads. + * + * @param key the key + * @return the value, or null if absent + */ + private V doGet(int key) { + Index q; + acquireFence(); + if (key == noKey) { + throw new IllegalArgumentException(); + } + V result = null; + if ((q = head) != null) { + outer: for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; int k; V v; int c; + if ((p = r.node) == null || (k = p.key) == noKey || + (v = p.val) == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if ((c = cpr(key, k)) > 0) { + q = r; + } else if (c == 0) { + result = v; + break outer; + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + Node b, n; + if ((b = q.node) != null) { + while ((n = b.next) != null) { + V v; int c; + int k = n.key; + if ((v = n.val) == null || k == noKey || + (c = cpr(key, k)) > 0) { + b = n; + } else { + if (c == 0) { + result = v; + } + break; + } + } + } + break; + } + } + } + return result; + } + + /* ---------------- Insertion -------------- */ + + /** + * Main insertion method. Adds element if not present, or + * replaces value if present and onlyIfAbsent is false. + * + * @param key the key + * @param value the value that must be associated with key + * @param onlyIfAbsent if should not insert if already present + */ + private V doPut(int key, V value, boolean onlyIfAbsent) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + for (;;) { + Index h; Node b; + acquireFence(); + int levels = 0; // number of levels descended + if ((h = head) == null) { // try to initialize + Node base = new Node(noKey, null, null); + h = new Index(base, null, null); + b = HEAD.compareAndSet(this, null, h) ? 
base : null; + } else { + for (Index q = h, r, d;;) { // count while descending + while ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if (cpr(key, k) > 0) { + q = r; + } else { + break; + } + } + if ((d = q.down) != null) { + ++levels; + q = d; + } else { + b = q.node; + break; + } + } + } + if (b != null) { + Node z = null; // new node, if inserted + for (;;) { // find insertion point + Node n, p; int k; V v; int c; + if ((n = b.next) == null) { + if (b.key == noKey) { // if empty, type check key now TODO: remove? + cpr(key, key); + } + c = -1; + } else if ((k = n.key) == noKey) { + break; // can't append; restart + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + c = 1; + } else if ((c = cpr(key, k)) > 0) { + b = n; // Multimap +// } else if (c == 0 && +// (onlyIfAbsent || VAL.compareAndSet(n, v, value))) { +// return v; + } + + if (c <= 0 && + NEXT.compareAndSet(b, n, + p = new Node(key, value, n))) { + z = p; + break; + } + } + + if (z != null) { + int lr = ThreadLocalRandom.current().nextInt(); + if ((lr & 0x3) == 0) { // add indices with 1/4 prob + int hr = ThreadLocalRandom.current().nextInt(); + long rnd = ((long) hr << 32) | ((long) lr & 0xffffffffL); + int skips = levels; // levels to descend before add + Index x = null; + for (;;) { // create at most 62 indices + x = new Index(z, x, null); + if (rnd >= 0L || --skips < 0) { + break; + } else { + rnd <<= 1; + } + } + if (addIndices(h, skips, x, noKey) && skips < 0 && + head == h) { // try to add new level + Index hx = new Index(z, x, null); + Index nh = new Index(h.node, h, hx); + HEAD.compareAndSet(this, h, nh); + } + if (z.val == null) { // deleted while adding indices + findPredecessor(key); // clean + } + } + addCount(1L); + return null; + } + } + } + } + + /** + * Add indices after an insertion. 
Descends iteratively to the + * highest level of insertion, then recursively, to chain index + * nodes to lower ones. Returns null on (staleness) failure, + * disabling higher-level insertions. Recursion depths are + * exponentially less probable. + * + * @param q starting index for current level + * @param skips levels to skip before inserting + * @param x index for this insertion + */ + static boolean addIndices(Index q, int skips, Index x, int noKey) { + Node z; int key; + if (x != null && (z = x.node) != null && (key = z.key) != noKey && + q != null) { // hoist checks + boolean retrying = false; + for (;;) { // find splice point + Index r, d; int c; + if ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + c = 0; + } else if ((c = cpr(key, k)) > 0) { + q = r; + } else if (c == 0) { + break; // stale + } + } else { + c = -1; + } + + if (c < 0) { + if ((d = q.down) != null && skips > 0) { + --skips; + q = d; + } else if (d != null && !retrying && + !addIndices(d, 0, x.down, noKey)) { + break; + } else { + x.right = r; + if (RIGHT.compareAndSet(q, r, x)) { + return true; + } else { + retrying = true; // re-find splice point + } + } + } + } + } + return false; + } + + /* ---------------- Deletion -------------- */ + + /** + * Main deletion method. Locates node, nulls value, appends a + * deletion marker, unlinks predecessor, removes associated index + * nodes, and possibly reduces head index level. 
+ * + * @param key the key + * @param value if non-null, the value that must be + * associated with key + * @return the node, or null if not found + */ + final V doRemove(int key, Object value) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + V result = null; + Node b; + outer: while ((b = findPredecessor(key)) != null && + result == null) { + for (;;) { + Node n; int k; V v; int c; + if ((n = b.next) == null) { + break outer; + } else if ((k = n.key) == noKey) { + break; + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else if ((c = cpr(key, k)) > 0) { + b = n; + } else if (c < 0) { + break outer; + } else if (value != null && !value.equals(v)) { +// break outer; + b = n; // Multimap. + } else if (VAL.compareAndSet(n, v, null)) { + result = v; + unlinkNode(b, n, noKey); + break; // loop to clean up + } + } + } + if (result != null) { + tryReduceLevel(); + addCount(-1L); + } + return result; + } + + /** + * Possibly reduce head level if it has no nodes. This method can + * (rarely) make mistakes, in which case levels can disappear even + * though they are about to contain index nodes. This impacts + * performance, not correctness. To minimize mistakes as well as + * to reduce hysteresis, the level is reduced by one only if the + * topmost three levels look empty. Also, if the removed level + * looks non-empty after CAS, we try to change it back quick + * before anyone notices our mistake! (This trick works pretty + * well because this method will practically never make mistakes + * unless current thread stalls immediately before first CAS, in + * which case it is very unlikely to stall again immediately + * afterwards, so will recover.) + *

+ * We put up with all this rather than just let levels grow + * because otherwise, even a small map that has undergone a large + * number of insertions and removals will have a lot of levels, + * slowing down access more than would an occasional unwanted + * reduction. + */ + private void tryReduceLevel() { + Index h, d, e; + if ((h = head) != null && h.right == null && + (d = h.down) != null && d.right == null && + (e = d.down) != null && e.right == null && + HEAD.compareAndSet(this, h, d) && + h.right != null) { // recheck + HEAD.compareAndSet(this, d, h); // try to backout + } + } + + /* ---------------- Finding and removing first element -------------- */ + + /** + * Gets first valid node, unlinking deleted nodes if encountered. + * @return first node or null if empty + */ + final Node findFirst() { + Node b, n; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if (n.val == null) { + unlinkNode(b, n, noKey); + } else { + return n; + } + } + } + return null; + } + + /** + * Entry snapshot version of findFirst + */ + final IntEntry findFirstEntry() { + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else { + return new IntEntry(n.key, v); + } + } + } + return null; + } + + /** + * Removes first entry; returns its snapshot. + * @return null if empty, else snapshot of first entry + */ + private IntEntry doRemoveFirstEntry() { + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) == null || VAL.compareAndSet(n, v, null)) { + int k = n.key; + unlinkNode(b, n, noKey); + if (v != null) { + tryReduceLevel(); + findPredecessor(k); // clean index + addCount(-1L); + return new IntEntry(k, v); + } + } + } + } + return null; + } + + /* ---------------- Finding and removing last element -------------- */ + + /** + * Specialized version of find to get last valid node. 
+ * @return last node or null if empty + */ + final Node findLast() { + outer: for (;;) { + Index q; Node b; + acquireFence(); + if ((q = head) == null) { + break; + } + for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; + if ((p = r.node) == null || p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else { + q = r; + } + } + if ((d = q.down) != null) { + q = d; + } else { + b = q.node; + break; + } + } + if (b != null) { + for (;;) { + Node n; + if ((n = b.next) == null) { + if (b.key == noKey) { // empty + break outer; + } else { + return b; + } + } else if (n.key == noKey) { + break; + } else if (n.val == null) { + unlinkNode(b, n, noKey); + } else { + b = n; + } + } + } + } + return null; + } + + /** + * Entry version of findLast + * @return Entry for last node or null if empty + */ + final IntEntry findLastEntry() { + for (;;) { + Node n; V v; + if ((n = findLast()) == null) { + return null; + } + if ((v = n.val) != null) { + return new IntEntry(n.key, v); + } + } + } + + /** + * Removes last entry; returns its snapshot. + * Specialized variant of doRemove. 
+ * @return null if empty, else snapshot of last entry + */ + private IntEntry doRemoveLastEntry() { + outer: for (;;) { + Index q; Node b; + acquireFence(); + if ((q = head) == null) { + break; + } + for (;;) { + Index d, r; Node p; + while ((r = q.right) != null) { + if ((p = r.node) == null || p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if (p.next != null) { + q = r; // continue only if a successor + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + b = q.node; + break; + } + } + if (b != null) { + for (;;) { + Node n; int k; V v; + if ((n = b.next) == null) { + if (b.key == noKey) { // empty + break outer; + } else { + break; // retry + } + } else if ((k = n.key) == noKey) { + break; + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else if (n.next != null) { + b = n; + } else if (VAL.compareAndSet(n, v, null)) { + unlinkNode(b, n, noKey); + tryReduceLevel(); + findPredecessor(k); // clean index + addCount(-1L); + return new IntEntry(k, v); + } + } + } + } + return null; + } + + /* ---------------- Relational operations -------------- */ + + // Control values OR'ed as arguments to findNear + + private static final int EQ = 1; + private static final int LT = 2; + private static final int GT = 0; // Actually checked as !LT + + /** + * Variant of findNear returning IntEntry + * @param key the key + * @param rel the relation -- OR'ed combination of EQ, LT, GT + * @return Entry fitting relation, or null if no such + */ + final IntEntry findNearEntry(int key, int rel) { + for (;;) { + Node n; V v; + if ((n = findNear(key, rel)) == null) { + return null; + } + if ((v = n.val) != null) { + return new IntEntry(n.key, v); + } + } + } + + /** + * Utility for ceiling, floor, lower, higher methods. 
+ * @param key the key + * @param rel the relation -- OR'ed combination of EQ, LT, GT + * @return nearest node fitting relation, or null if no such + */ + final Node findNear(int key, int rel) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + Node result; + outer: for (Node b;;) { + if ((b = findPredecessor(key)) == null) { + result = null; + break; // empty + } + for (;;) { + Node n; int k; int c; + if ((n = b.next) == null) { + result = (rel & LT) != 0 && b.key != noKey ? b : null; + break outer; + } else if ((k = n.key) == noKey) { + break; + } else if (n.val == null) { + unlinkNode(b, n, noKey); + } else if (((c = cpr(key, k)) == 0 && (rel & EQ) != 0) || + (c < 0 && (rel & LT) == 0)) { + result = n; + break outer; + } else if (c <= 0 && (rel & LT) != 0) { + result = b.key != noKey ? b : null; + break outer; + } else { + b = n; + } + } + } + return result; + } + + /* ---------------- Constructors -------------- */ + + /** + * Constructs a new, empty map, sorted according to the + * {@linkplain Comparable natural ordering} of the keys. + * @param noKey The value to use as a sentinel for signaling the absence of a key. + */ + public ConcurrentSkipListIntObjMultimap(int noKey) { + this.noKey = noKey; + adder = PlatformDependent.newLongCounter(); + } + + /* ------ Map API methods ------ */ + + /** + * Returns {@code true} if this map contains a mapping for the specified + * key. + * + * @param key key whose presence in this map is to be tested + * @return {@code true} if this map contains a mapping for the specified key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public boolean containsKey(int key) { + return doGet(key) != null; + } + + /** + * Returns the value to which the specified key is mapped, + * or {@code null} if this map contains no mapping for the key. + * + *

More formally, if this map contains a mapping from a key + * {@code k} to a value {@code v} such that {@code key} compares + * equal to {@code k} according to the map's ordering, then this + * method returns {@code v}; otherwise it returns {@code null}. + * (There can be at most one such mapping.) + * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public V get(int key) { + return doGet(key); + } + + /** + * Returns the value to which the specified key is mapped, + * or the given defaultValue if this map contains no mapping for the key. + * + * @param key the key + * @param defaultValue the value to return if this map contains + * no mapping for the given key + * @return the mapping for the key, if present; else the defaultValue + * @throws NullPointerException if the specified key is null + * @since 1.8 + */ + public V getOrDefault(int key, V defaultValue) { + V v; + return (v = doGet(key)) == null ? defaultValue : v; + } + + /** + * Associates the specified value with the specified key in this map. + * If the map previously contained a mapping for the key, the old + * value is replaced. + * + * @param key key with which the specified value is to be associated + * @param value value to be associated with the specified key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key or value is null + */ + public void put(int key, V value) { + checkNotNull(value, "value"); + doPut(key, value, false); + } + + /** + * Removes the mapping for the specified key from this map if present. 
+ * + * @param key key for which mapping should be removed + * @return the previous value associated with the specified key, or + * {@code null} if there was no mapping for the key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public V remove(int key) { + return doRemove(key, null); + } + + /** + * Returns {@code true} if this map maps one or more keys to the + * specified value. This operation requires time linear in the + * map size. Additionally, it is possible for the map to change + * during execution of this method, in which case the returned + * result may be inaccurate. + * + * @param value value whose presence in this map is to be tested + * @return {@code true} if a mapping to {@code value} exists; + * {@code false} otherwise + * @throws NullPointerException if the specified value is null + */ + public boolean containsValue(Object value) { + checkNotNull(value, "value"); + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) != null && value.equals(v)) { + return true; + } else { + b = n; + } + } + } + return false; + } + + /** + * Get the approximate size of the collection. + */ + public int size() { + long c; + return baseHead() == null ? 0 : + (c = getAdderCount()) >= Integer.MAX_VALUE ? + Integer.MAX_VALUE : (int) c; + } + + /** + * Check if the collection is empty. + */ + public boolean isEmpty() { + return findFirst() == null; + } + + /** + * Removes all of the mappings from this map. 
+ */ + public void clear() { + Index h, r, d; Node b; + acquireFence(); + while ((h = head) != null) { + if ((r = h.right) != null) { // remove indices + RIGHT.compareAndSet(h, r, null); + } else if ((d = h.down) != null) { // remove levels + HEAD.compareAndSet(this, h, d); + } else { + long count = 0L; + if ((b = h.node) != null) { // remove nodes + Node n; V v; + while ((n = b.next) != null) { + if ((v = n.val) != null && + VAL.compareAndSet(n, v, null)) { + --count; + v = null; + } + if (v == null) { + unlinkNode(b, n, noKey); + } + } + } + if (count != 0L) { + addCount(count); + } else { + break; + } + } + } + } + + /* ------ ConcurrentMap API methods ------ */ + + /** + * Remove the specific entry with the given key and value, if it exist. + * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public boolean remove(int key, Object value) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + return value != null && doRemove(key, value) != null; + } + + /** + * Replace the specific entry with the given key and value, with the given replacement value, + * if such an entry exist. 
+ * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if any of the arguments are null + */ + public boolean replace(int key, V oldValue, V newValue) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + checkNotNull(oldValue, "oldValue"); + checkNotNull(newValue, "newValue"); + for (;;) { + Node n; V v; + if ((n = findNode(key)) == null) { + return false; + } + if ((v = n.val) != null) { + if (!oldValue.equals(v)) { + return false; + } + if (VAL.compareAndSet(n, v, newValue)) { + return true; + } + } + } + } + + /* ------ SortedMap API methods ------ */ + + public int firstKey() { + Node n = findFirst(); + if (n == null) { + return noKey; + } + return n.key; + } + + public int lastKey() { + Node n = findLast(); + if (n == null) { + return noKey; + } + return n.key; + } + + /* ---------------- Relational operations -------------- */ + + /** + * Returns a key-value mapping associated with the greatest key + * strictly less than the given key, or {@code null} if there is + * no such key. The returned entry does not support the + * {@code Entry.setValue} method. + * + * @throws NullPointerException if the specified key is null + */ + public IntEntry lowerEntry(int key) { + return findNearEntry(key, LT); + } + + /** + * @throws NullPointerException if the specified key is null + */ + public int lowerKey(int key) { + Node n = findNear(key, LT); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the greatest key + * less than or equal to the given key, or {@code null} if there + * is no such key. The returned entry does not support + * the {@code Entry.setValue} method. 
+ * + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public IntEntry floorEntry(int key) { + return findNearEntry(key, LT | EQ); + } + + /** + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public int floorKey(int key) { + Node n = findNear(key, LT | EQ); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least key + * greater than or equal to the given key, or {@code null} if + * there is no such entry. The returned entry does not + * support the {@code Entry.setValue} method. + * + * @throws NullPointerException if the specified key is null + */ + public IntEntry ceilingEntry(int key) { + return findNearEntry(key, GT | EQ); + } + + /** + * @throws NullPointerException if the specified key is null + */ + public int ceilingKey(int key) { + Node n = findNear(key, GT | EQ); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least key + * strictly greater than the given key, or {@code null} if there + * is no such key. The returned entry does not support + * the {@code Entry.setValue} method. + * + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public IntEntry higherEntry(int key) { + return findNearEntry(key, GT); + } + + /** + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public int higherKey(int key) { + Node n = findNear(key, GT); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least + * key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry firstEntry() { + return findFirstEntry(); + } + + /** + * Returns a key-value mapping associated with the greatest + * key in this map, or {@code null} if the map is empty. 
+ * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry lastEntry() { + return findLastEntry(); + } + + /** + * Removes and returns a key-value mapping associated with + * the least key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry pollFirstEntry() { + return doRemoveFirstEntry(); + } + + /** + * Removes and returns a key-value mapping associated with + * the greatest key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry pollLastEntry() { + return doRemoveLastEntry(); + } + + public IntEntry pollCeilingEntry(int key) { + // TODO optimize this + Node node; + V val; + do { + node = findNear(key, GT | EQ); + if (node == null) { + return null; + } + val = node.val; + } while (val == null || !remove(node.key, val)); + return new IntEntry(node.key, val); + } + + /* ---------------- Iterators -------------- */ + + /** + * Base of iterator classes + */ + abstract class Iter implements Iterator { + /** the last node returned by next() */ + Node lastReturned; + /** the next node to return from next(); */ + Node next; + /** Cache of next value field to maintain weak consistency */ + V nextValue; + + /** Initializes ascending iterator for entire range. */ + Iter() { + advance(baseHead()); + } + + @Override + public final boolean hasNext() { + return next != null; + } + + /** Advances next to higher entry. 
*/ + final void advance(Node b) { + Node n = null; + V v = null; + if ((lastReturned = b) != null) { + while ((n = b.next) != null && (v = n.val) == null) { + b = n; + } + } + nextValue = v; + next = n; + } + + @Override + public final void remove() { + Node n; int k; + if ((n = lastReturned) == null || (k = n.key) == noKey) { + throw new IllegalStateException(); + } + // It would not be worth all of the overhead to directly + // unlink from here. Using remove is fast enough. + ConcurrentSkipListIntObjMultimap.this.remove(k, n.val); // TODO: inline and optimize this + lastReturned = null; + } + } + + final class EntryIterator extends Iter> { + @Override + public IntEntry next() { + Node n; + if ((n = next) == null) { + throw new NoSuchElementException(); + } + int k = n.key; + V v = nextValue; + advance(n); + return new IntEntry(k, v); + } + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + // VarHandle mechanics + private static final AtomicReferenceFieldUpdater, Index> HEAD; + private static final AtomicReferenceFieldUpdater, Node> NEXT; + private static final AtomicReferenceFieldUpdater, Object> VAL; + private static final AtomicReferenceFieldUpdater, Index> RIGHT; + private static volatile int acquireFenceVariable; + static { + Class> mapCls = cls(ConcurrentSkipListIntObjMultimap.class); + Class> indexCls = cls(Index.class); + Class> nodeCls = cls(Node.class); + + HEAD = AtomicReferenceFieldUpdater.newUpdater(mapCls, indexCls, "head"); + NEXT = AtomicReferenceFieldUpdater.newUpdater(nodeCls, nodeCls, "next"); + VAL = AtomicReferenceFieldUpdater.newUpdater(nodeCls, Object.class, "val"); + RIGHT = AtomicReferenceFieldUpdater.newUpdater(indexCls, indexCls, "right"); + } + + @SuppressWarnings("unchecked") + private static Class cls(Class cls) { + return (Class) cls; + } + + /** + * Orders LOADS before the fence, with LOADS and STORES after the fence. 
+ */ + private static void acquireFence() { + // Volatile store prevent prior loads from ordering down. + acquireFenceVariable = 1; + // Volatile load prevent following loads and stores from ordering up. + int ignore = acquireFenceVariable; + // Note: Putting the volatile store before the volatile load ensures + // surrounding loads and stores don't order "into" the fence. + } +} diff --git a/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java b/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java index 4cf804888af..1640d0897bf 100644 --- a/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java +++ b/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java @@ -56,7 +56,7 @@ public MpscAtomicIntegerArrayQueue(int capacity, int emptyValue) { super(MathUtil.safeFindNextPositivePowerOfTwo(capacity)); if (emptyValue != 0) { this.emptyValue = emptyValue; - int end = capacity - 1; + int end = length() - 1; for (int i = 0; i < end; i++) { lazySet(i, emptyValue); } diff --git a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java index afdb4d5e7c6..b316ab621ed 100644 --- a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java +++ b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java @@ -258,6 +258,8 @@ public void run() { executor.execute(this); return; // done } catch (Throwable ignore) { + // Restore executingThread since we're continuing to execute tasks. + executingThread.set(current); // Reset the state back to running as we will keep on executing tasks. 
state.set(RUNNING); // if an error happened we should just ignore it and let the loop run again as there is not diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent.java b/common/src/main/java/io/netty/util/internal/PlatformDependent.java index 13421fdb240..32405d3f6ea 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent.java @@ -997,7 +997,7 @@ public static int equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2 * The resulting hash code will be case insensitive. */ public static int hashCodeAscii(byte[] bytes, int startPos, int length) { - return !hasUnsafe() || !unalignedAccess() ? + return !hasUnsafe() || !unalignedAccess() || BIG_ENDIAN_NATIVE_ORDER ? hashCodeAsciiSafe(bytes, startPos, length) : PlatformDependent0.hashCodeAscii(bytes, startPos, length); } diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java index 62a1ee0f539..950b93bc959 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java @@ -393,7 +393,7 @@ public Object run() { Class bitsClass = Class.forName("java.nio.Bits", false, getSystemClassLoader()); int version = javaVersion(); - if (unsafeStaticFieldOffsetSupported() && version >= 9) { + if (version >= 9) { // Java9/10 use all lowercase and later versions all uppercase. String fieldName = version >= 11? 
"MAX_MEMORY" : "maxMemory"; // On Java9 and later we try to directly access the field as we can do this without @@ -607,10 +607,6 @@ static boolean isVirtualThread(Thread thread) { } } - private static boolean unsafeStaticFieldOffsetSupported() { - return !RUNNING_IN_NATIVE_IMAGE; - } - static boolean isExplicitNoUnsafe() { return EXPLICIT_NO_UNSAFE_CAUSE != null; } diff --git a/common/src/main/java/io/netty/util/internal/ThrowableUtil.java b/common/src/main/java/io/netty/util/internal/ThrowableUtil.java index c33a19e5591..5af0c7ba883 100644 --- a/common/src/main/java/io/netty/util/internal/ThrowableUtil.java +++ b/common/src/main/java/io/netty/util/internal/ThrowableUtil.java @@ -84,4 +84,19 @@ public static Throwable[] getSuppressed(Throwable source) { } return source.getSuppressed(); } + + /** + * Capture the stack trace of the given thread, interrupt it, and attach the stack trace as a suppressed exception + * to the given cause. + * @param thread The thread to interrupt. + * @param cause The cause to attach a stack trace to. 
+ */ + public static void interruptAndAttachAsyncStackTrace(Thread thread, Throwable cause) { + StackTraceElement[] stackTrace = thread.getStackTrace(); + InterruptedException asyncIE = new InterruptedException( + "Asynchronous interruption: " + thread); + thread.interrupt(); + asyncIE.setStackTrace(stackTrace); + addSuppressed(cause, asyncIE); + } } diff --git a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java index 5445b83d581..63ff46d80a5 100644 --- a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java +++ b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java @@ -16,6 +16,7 @@ package io.netty.util; import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.extension.DynamicTestInvocationContext; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.InvocationInterceptor; import org.junit.jupiter.api.extension.ReflectiveInvocationContext; @@ -37,6 +38,26 @@ public void interceptTestMethod( final Invocation invocation, final ReflectiveInvocationContext invocationContext, final ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + @Override + public void interceptTestTemplateMethod( + Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + @Override + public void interceptDynamicTest( + Invocation invocation, + DynamicTestInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + private static void proceed(final Invocation invocation) throws Throwable { final AtomicReference throwable = new AtomicReference(); Thread thread = new FastThreadLocalThread(new Runnable() { @Override diff --git 
a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java new file mode 100644 index 00000000000..a5f4b4fa9cb --- /dev/null +++ b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(RunInFastThreadLocalThreadExtension.class) +public class RunInFastThreadLocalThreadExtensionTest { + @Test + void normalTest() { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } + + @RepeatedTest(1) + void repeatedTest() { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } + + @ParameterizedTest + @ValueSource(ints = 1) + void parameterizedTest(int ignoreParameter) { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } +} diff --git 
a/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java b/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java new file mode 100644 index 00000000000..e3ffb84f785 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java @@ -0,0 +1,442 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap.IntEntry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.concurrent.ThreadLocalRandom; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ConcurrentSkipListIntObjMultimapTest { + private ConcurrentSkipListIntObjMultimap map; + private int noKey; + + @BeforeEach + void setUp() { + noKey = -1; + map = new ConcurrentSkipListIntObjMultimap(noKey); + } + + @Test + void addIterateAndRemoveEntries() throws Exception { + assertFalse(map.iterator().hasNext()); + map.put(1, "a"); + map.put(2, "b"); + assertFalse(map.isEmpty()); + assertEquals(2, map.size()); + IntEntry entry; + Iterator> itr = map.iterator(); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(1, "a"), entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(2, "b"), entry); + assertFalse(itr.hasNext()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + } + + @Test + void clearMustRemoveAllEntries() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(3, "c"); + assertEquals(3, map.size()); + map.clear(); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + assertTrue(map.isEmpty()); + } + + @Test + void pollingFirstEntryOfUniqueKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + 
map.put(3, "c"); + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(3, "c"), map.pollFirstEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollingLastEntryOfUniqueKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(3, "c"); + assertEquals(new IntEntry(3, "c"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(1, "a"), map.pollLastEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void addMultipleEntriesForSameKey() throws Exception { + map.put(2, "b1"); + map.put(1, "a"); + map.put(2, "b2"); // second entry for the 2 key + map.put(3, "c"); + assertEquals(4, map.size()); + + IntEntry entry; + Iterator> itr = map.iterator(); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(1, "a"), entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + IntEntry otherB = entry; + itr.remove(); + assertThat(entry).isIn(new IntEntry(2, "b1"), new IntEntry(2, "b2")); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertThat(entry).isIn(new IntEntry(2, "b1"), new IntEntry(2, "b2")); + assertNotEquals(otherB, entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(3, "c"), entry); + assertFalse(itr.hasNext()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void iteratorRemoveSecondOfMultiMappedEntry(boolean withPriorRemoval) throws Exception { + map.put(1, "a"); + map.put(1, "b"); + + Iterator> itr = map.iterator(); + itr.next(); + IntEntry entry = itr.next(); + if (withPriorRemoval) { + map.remove(entry.getKey(), entry.getValue()); + } + itr.remove(); + 
assertEquals(1, map.size()); + if (entry.equals(new IntEntry(1, "a"))) { + assertEquals(new IntEntry(1, "b"), map.pollFirstEntry()); + } else { + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + } + } + + @Test + void firstKeyOrEntry() throws Exception { + assertEquals(noKey, map.firstKey()); + assertNull(map.firstEntry()); + map.put(2, "b"); + assertEquals(2, map.firstKey()); + assertEquals(new IntEntry(2, "b"), map.firstEntry()); + map.put(3, "c"); + assertEquals(2, map.firstKey()); + assertEquals(new IntEntry(2, "b"), map.firstEntry()); + map.put(2, "b2"); + assertEquals(2, map.firstKey()); + assertThat(map.firstEntry()).isIn(new IntEntry(2, "b"), new IntEntry(2, "b2")); + map.put(1, "a"); + assertEquals(1, map.firstKey()); + assertEquals(new IntEntry(1, "a"), map.firstEntry()); + map.put(2, "b3"); + assertEquals(1, map.firstKey()); + assertEquals(new IntEntry(1, "a"), map.firstEntry()); + map.pollFirstEntry(); + assertEquals(2, map.firstKey()); + assertThat(map.firstEntry()).isIn( + new IntEntry(2, "b"), new IntEntry(2, "b2"), new IntEntry(2, "b3")); + } + + @Test + void lastKeyOrEntry() throws Exception { + assertEquals(noKey, map.lastKey()); + assertNull(map.lastEntry()); + map.put(2, "b"); + assertEquals(2, map.lastKey()); + assertEquals(new IntEntry(2, "b"), map.lastEntry()); + map.put(1, "a"); + assertEquals(2, map.lastKey()); + assertEquals(new IntEntry(2, "b"), map.lastEntry()); + map.put(2, "b2"); + assertEquals(2, map.lastKey()); + assertThat(map.lastEntry()).isIn(new IntEntry(2, "b"), new IntEntry(2, "b2")); + map.put(3, "c"); + assertEquals(3, map.lastKey()); + assertEquals(new IntEntry(3, "c"), map.lastEntry()); + map.put(2, "b3"); + assertEquals(3, map.lastKey()); + assertEquals(new IntEntry(3, "c"), map.lastEntry()); + map.pollLastEntry(); + assertEquals(2, map.lastKey()); + assertThat(map.lastEntry()).isIn( + new IntEntry(2, "b"), new IntEntry(2, "b2"), new IntEntry(2, "b3")); + } + + @RepeatedTest(100) + void firstLastKeyOrEntry() 
throws Exception { + int[] xs = new int[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + map.put(key, "a"); + xs[i] = key; + } + Arrays.sort(xs); + assertEquals(xs[0], map.firstKey()); + assertEquals(new IntEntry(xs[0], "a"), map.firstEntry()); + assertEquals(xs[xs.length - 1], map.lastKey()); + assertEquals(new IntEntry(xs[xs.length - 1], "a"), map.lastEntry()); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void lowerEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) < 0) { + expected = x; + } else { + break; + } + } + assertEquals(expected, map.lowerEntry(target.getKey())); + assertEquals(expected == null ? 
noKey : expected.getKey(), map.lowerKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void lowerEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + } + assertEquals(1, map.lowerKey(3)); + assertEquals(new IntEntry(1, "a"), map.lowerEntry(3)); + assertEquals(3, map.lowerKey(4)); + assertEquals(new IntEntry(3, "b"), map.lowerEntry(4)); + assertEquals(noKey, map.lowerKey(1)); + assertNull(map.lowerEntry(1)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void floorEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) <= 0) { + expected = x; + } else { + break; + } + } + assertEquals(expected, map.floorEntry(target.getKey())); + assertEquals(expected == null ? 
noKey : expected.getKey(), map.floorKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void floorEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + } + assertEquals(1, map.floorKey(2)); + assertEquals(new IntEntry(1, "a"), map.floorEntry(2)); + assertEquals(3, map.floorKey(3)); + assertEquals(new IntEntry(3, "b"), map.floorEntry(3)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void ceilEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) >= 0) { + expected = x; + break; + } + } + assertEquals(expected, map.ceilingEntry(target.getKey())); + assertEquals(expected == null ? 
noKey : expected.getKey(), map.ceilingKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ceilEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + } + assertEquals(2, map.ceilingKey(2)); + assertEquals(new IntEntry(2, "b"), map.ceilingEntry(2)); + assertEquals(4, map.ceilingKey(3)); + assertEquals(new IntEntry(4, "c"), map.ceilingEntry(3)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void higherEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) > 0) { + expected = x; + break; + } + } + assertEquals(expected, map.higherEntry(target.getKey())); + assertEquals(expected == null ? 
noKey : expected.getKey(), map.higherKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void higherEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + } + assertEquals(4, map.higherKey(2)); + assertEquals(new IntEntry(4, "c"), map.higherEntry(2)); + assertEquals(4, map.higherKey(3)); + assertEquals(new IntEntry(4, "c"), map.higherEntry(3)); + assertEquals(noKey, map.higherKey(4)); + assertNull(map.higherEntry(4)); + } + + @Test + void pollingFirstEntryOfMultiMappedKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(2, "b"); + map.put(3, "c"); + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(3, "c"), map.pollFirstEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollingLastEntryOfMultiMappedKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(2, "b"); + map.put(3, "c"); + assertEquals(new IntEntry(3, "c"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(1, "a"), map.pollLastEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollCeilingEntry() throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(2, "b"); + map.put(3, "c"); + map.put(4, "d"); + map.put(4, "d"); + assertEquals(new IntEntry(2, "b"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(2, "b"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(3, "c"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(4, "d"), 
map.pollCeilingEntry(2)); + assertEquals(new IntEntry(4, "d"), map.pollCeilingEntry(2)); + assertNull(map.pollCeilingEntry(2)); + assertFalse(map.isEmpty()); + assertEquals(1, map.size()); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java b/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java new file mode 100644 index 00000000000..f11003be929 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.IntSupplier; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class MpscIntQueueTest { + @ParameterizedTest + @ValueSource(ints = {1, 7, 8, 15, 16, 17}) + void mustFillWithSpecifiedEmptyEntry(int size) throws Exception { + MpscIntQueue queue = new MpscAtomicIntegerArrayQueue(size, -1); + int filled = queue.fill(size, new IntSupplier() { + @Override + public int get() throws Exception { + return 42; + } + }); + assertEquals(size, filled); + for (int i = 0; i < size; i++) { + assertEquals(42, queue.poll()); + } + assertEquals(-1, queue.poll()); + assertTrue(queue.isEmpty()); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java index aedd77e9d25..7e3e4784920 100644 --- a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java +++ b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java @@ -24,13 +24,17 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -129,6 +133,153 @@ public void run() { } } + @Test + public void testInEventLoopAfterReschedulingFailure() throws 
Exception { + final UnorderedThreadPoolEventExecutor underlying = new UnorderedThreadPoolEventExecutor(1); + final AtomicInteger executeCount = new AtomicInteger(); + + final EventExecutorGroup wrapper = new AbstractEventExecutorGroup() { + @Override + public void shutdown() { + shutdownGracefully(); + } + + private final EventExecutor executor = new AbstractEventExecutor(this) { + @Override + public boolean inEventLoop(Thread thread) { + return underlying.inEventLoop(thread); + } + + @Override + public void shutdown() { + shutdownGracefully(); + } + + @Override + public void execute(Runnable command) { + // Reject the 2nd execute() call (the reschedule attempt) + // 1st call: initial task submission + // 2nd call: reschedule after maxTaskExecutePerRun + if (executeCount.incrementAndGet() == 2) { + throw new RejectedExecutionException("Simulated queue full"); + } + underlying.execute(command); + } + + @Override + public boolean isShuttingDown() { + return underlying.isShuttingDown(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return underlying.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return underlying.terminationFuture(); + } + + @Override + public boolean isShutdown() { + return underlying.isShutdown(); + } + + @Override + public boolean isTerminated() { + return underlying.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return underlying.awaitTermination(timeout, unit); + } + }; + + @Override + public EventExecutor next() { + return executor; + } + + @Override + public Iterator iterator() { + return Collections.singletonList(executor).iterator(); + } + + @Override + public boolean isShuttingDown() { + return underlying.isShuttingDown(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return 
underlying.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return underlying.terminationFuture(); + } + + @Override + public boolean isShutdown() { + return underlying.isShutdown(); + } + + @Override + public boolean isTerminated() { + return underlying.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return underlying.awaitTermination(timeout, unit); + } + }; + + // Use maxTaskExecutePerRun=1 so reschedule happens after first task + NonStickyEventExecutorGroup nonStickyGroup = new NonStickyEventExecutorGroup(wrapper, 1); + + try { + final EventExecutor executor = nonStickyGroup.next(); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference inEventLoopResult = new AtomicReference(); + + // Submit 2 tasks: + // Task 1: completes, triggers reschedule which will be rejected + // Task 2: verifies inEventLoop() still works after failed reschedule + executor.execute(new Runnable() { + @Override + public void run() { + // First task - will trigger reschedule attempt that fails + } + }); + + executor.execute(new Runnable() { + @Override + public void run() { + // This runs AFTER the failed rescheduling + // WITHOUT line 262 fix: executingThread is null, inEventLoop() returns false + // WITH line 262 fix: executingThread restored, inEventLoop() returns true + inEventLoopResult.set(executor.inEventLoop()); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS), "Tasks should complete"); + Boolean result = inEventLoopResult.get(); + assertNotNull(result, "inEventLoop() should have been called"); + assertTrue(result, + "inEventLoop() should return true even after failed reschedule attempt. 
" + + "This indicates executingThread was properly restored in the exception handler."); + } finally { + nonStickyGroup.shutdownGracefully(); + underlying.shutdownGracefully(); + } + } + private static void execute(EventExecutorGroup group, CountDownLatch startLatch) throws Throwable { final EventExecutor executor = group.next(); assertTrue(executor instanceof OrderedEventExecutor); diff --git a/dev-tools/pom.xml b/dev-tools/pom.xml index f587cce722c..de49d37e91b 100644 --- a/dev-tools/pom.xml +++ b/dev-tools/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-dev-tools diff --git a/docker/Dockerfile.al2023 b/docker/Dockerfile.al2023 new file mode 100644 index 00000000000..06c3eedb81f --- /dev/null +++ b/docker/Dockerfile.al2023 @@ -0,0 +1,70 @@ +FROM --platform=linux/amd64 amazonlinux:2023 + +ARG java_version=11.0.30-amzn +ARG aws_lc_version=v1.54.0 +ARG maven_version=3.9.10 +ENV JAVA_VERSION $java_version +ENV AWS_LC_VERSION $aws_lc_version +ENV MAVEN_VERSION $maven_version + +# install dependencies +RUN dnf install -y \ + apr-devel \ + autoconf \ + automake \ + bzip2 \ + cmake \ + gcc \ + gcc-c++ \ + git \ + glibc-devel \ + golang \ + libgcc \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + libtool \ + make \ + ninja-build \ + patch \ + perl \ + perl-parent \ + perl-devel \ + tar \ + unzip \ + wget \ + which \ + zip + +# Downloading and installing SDKMAN! 
+RUN curl -s "https://get.sdkman.io" | bash + +# Installing Java removing some unnecessary SDKMAN files +RUN bash -c "source $HOME/.sdkman/bin/sdkman-init.sh && \ + yes | sdk install java $JAVA_VERSION && \ + yes | sdk install maven $MAVEN_VERSION && \ + rm -rf $HOME/.sdkman/archives/* && \ + rm -rf $HOME/.sdkman/tmp/*" + +RUN echo 'export JAVA_HOME="/root/.sdkman/candidates/java/current"' >> ~/.bashrc +RUN echo 'export PATH=$JAVA_HOME/bin:$PATH' >> ~/.bashrc + +ENV PATH /root/.sdkman/candidates/java/current/bin:/root/.sdkman/candidates/maven/current/bin:$PATH +ENV JAVA_HOME=/root/.sdkman/candidates/java/current + +# install rust and setup PATH +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y +RUN echo 'PATH=$PATH:$HOME/.cargo/bin' >> ~/.bashrc + +RUN mkdir "$HOME/sources" && \ + git clone https://github.com/aws/aws-lc.git "$HOME/sources/aws-lc" && \ + cd "$HOME/sources/aws-lc" && \ + git checkout $AWS_LC_VERSION && \ + cmake -B build -S . -DCMAKE_INSTALL_PREFIX=/opt/aws-lc -DBUILD_SHARED_LIBS=1 -DBUILD_TESTING=0 && \ + cmake --build build -- -j && \ + cmake --install build + +# Cleanup +RUN dnf clean all && \ + rm -rf /var/cache/dnf && \ + rm -rf "$HOME/sources" diff --git a/docker/Dockerfile.cross_compile_aarch64 b/docker/Dockerfile.cross_compile_aarch64 index 8c1077c3f14..a5e21d20982 100644 --- a/docker/Dockerfile.cross_compile_aarch64 +++ b/docker/Dockerfile.cross_compile_aarch64 @@ -1,10 +1,9 @@ FROM --platform=linux/amd64 centos:7.6.1810 -ARG gcc_version=10.2-2020.11 +ARG gcc_version=10.3-2021.07 ENV GCC_VERSION $gcc_version ENV SOURCE_DIR /root/source - # Update to use the vault RUN sed -i -e 's/^mirrorlist/#mirrorlist/g' -e 's/^#baseurl=http:\/\/mirror.centos.org\/centos\/$releasever\//baseurl=https:\/\/linuxsoft.cern.ch\/centos-vault\/\/7.6.1810\//g' /etc/yum.repos.d/CentOS-Base.repo diff --git a/docker/docker-compose.al2023.yaml b/docker/docker-compose.al2023.yaml new file mode 100644 index 00000000000..ae8e9b4106d --- /dev/null +++ 
b/docker/docker-compose.al2023.yaml @@ -0,0 +1,65 @@ +services: + + runtime-setup: + image: netty-al2023:x86_64 + build: + context: ../ + dockerfile: docker/Dockerfile.al2023 + + common: &common + image: netty-al2023:x86_64 + depends_on: [runtime-setup] + environment: + LD_LIBRARY_PATH: /opt/aws-lc/lib64 + volumes: + # Use a separate directory for the AL2023 Maven repository + - ~/.m2-al2023:/root/.m2 + - ..:/netty + - ../../netty-tcnative:/netty-tcnative + working_dir: /netty + + common-tcnative: &common-tcnative + <<: *common + environment: + MAVEN_OPTS: + LD_LIBRARY_PATH: /opt/aws-lc/lib64 + LDFLAGS: -L/opt/aws-lc/lib64 -lssl -lcrypto + CFLAGS: -I/opt/aws-lc/include -DHAVE_OPENSSL -lssl -lcrypto + CXXFLAGS: -I/opt/aws-lc/include -DHAVE_OPENSSL -lssl -lcrypto + + install-tcnative: + <<: *common-tcnative + command: '/bin/bash -cl " + ./mvnw -am -pl openssl-dynamic clean install && + env -u LDFLAGS -u CFLAGS -u CXXFLAGS -u LD_LIBRARY_PATH ./mvnw -am -pl boringssl-static clean install + "' + working_dir: /netty-tcnative + + update-tcnative-version: + <<: *common + command: '/bin/bash -cl " + ./mvnw versions:update-property -Dproperty=tcnative.version -DnewVersion=$(cd /netty-tcnative && ./mvnw help:evaluate -Dexpression=project.version -q -DforceStdout) -DallowSnapshots=true -DprocessParent=true -DgenerateBackupPoms=false + "' + + build: + <<: *common + command: '/bin/bash -cl " + ./mvnw -B -ntp clean install -Dio.netty.testsuite.badHost=netty.io -Dtcnative.classifier=linux-x86_64-fedora -Drevapi.skip=true -Dcheckstyle.skip=true -Dforbiddenapis.skip=true + "' + + build-leak: + <<: *common + command: '/bin/bash -cl " + ./mvnw -B -ntp -Pleak clean install -Dio.netty.testsuite.badHost=netty.io -Dtcnative.classifier=linux-x86_64-fedora -Drevapi.skip=true -Dcheckstyle.skip=true -Dforbiddenapis.skip=true + "' + + shell: + <<: *common + volumes: + - ~/.m2-al2023:/root/.m2 + - ~/.gitconfig:/root/.gitconfig + - ~/.gitignore:/root/.gitignore + - ..:/netty + - 
../../netty-tcnative:/netty-tcnative + working_dir: /netty + entrypoint: /bin/bash -l diff --git a/docker/docker-compose.centos-6.111.yaml b/docker/docker-compose.centos-6.111.yaml index 5ef7aecc48d..28b01f82ea5 100644 --- a/docker/docker-compose.centos-6.111.yaml +++ b/docker/docker-compose.centos-6.111.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-1.11 build: args: - java_version : "11.0.28-zulu" + java_version : "11.0.30-zulu" build: image: netty:centos-6-1.11 diff --git a/docker/docker-compose.centos-6.18.yaml b/docker/docker-compose.centos-6.18.yaml index ee132f5ca28..ecf9eaae22b 100644 --- a/docker/docker-compose.centos-6.18.yaml +++ b/docker/docker-compose.centos-6.18.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-1.8 build: args: - java_version : "8.0.462-zulu" + java_version : "8.0.482-zulu" build: image: netty:centos-6-1.8 diff --git a/docker/docker-compose.centos-6.21.yaml b/docker/docker-compose.centos-6.21.yaml index 35a8f3b7707..dc0bf62a6a4 100644 --- a/docker/docker-compose.centos-6.21.yaml +++ b/docker/docker-compose.centos-6.21.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-21 build: args: - java_version : "21.0.8-zulu" + java_version : "21.0.10-zulu" build: image: netty:centos-6-21 diff --git a/docker/docker-compose.centos-6.24.yaml b/docker/docker-compose.centos-6.24.yaml index 8646af72da3..de70f61a624 100644 --- a/docker/docker-compose.centos-6.24.yaml +++ b/docker/docker-compose.centos-6.24.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-24 build: args: - java_version : "24.0.1-zulu" + java_version : "24.0.2-zulu" build: image: netty:centos-6-24 diff --git a/docker/docker-compose.centos-6.25.yaml b/docker/docker-compose.centos-6.25.yaml index 07ee2ba8ed1..e7b2cff3cf4 100644 --- a/docker/docker-compose.centos-6.25.yaml +++ b/docker/docker-compose.centos-6.25.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-25 build: args: - java_version : "25-zulu" + java_version : "25.0.2-zulu" build: image: netty:centos-6-25 diff 
--git a/docker/docker-compose.centos-7.117.yaml b/docker/docker-compose.centos-7.117.yaml index 411ef802512..464e7082fb8 100644 --- a/docker/docker-compose.centos-7.117.yaml +++ b/docker/docker-compose.centos-7.117.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-7-1.17 build: args: - java_version : "17.0.16-zulu" + java_version : "17.0.18-zulu" build: image: netty:centos-7-1.17 diff --git a/docker/docker-compose.centos-7.yaml b/docker/docker-compose.centos-7.yaml index 6c0facb0652..14437428a34 100644 --- a/docker/docker-compose.centos-7.yaml +++ b/docker/docker-compose.centos-7.yaml @@ -8,8 +8,8 @@ services: context: ../ dockerfile: docker/Dockerfile.cross_compile_aarch64 args: - gcc_version: "10.2-2020.11" - java_version: "8.0.462-zulu" + gcc_version: "10.3-2021.07" + java_version: "8.0.482-zulu" cross-compile-aarch64-common: &cross-compile-aarch64-common depends_on: [ cross-compile-aarch64-runtime-setup ] diff --git a/example/pom.xml b/example/pom.xml index 682c3cb3848..38a222dac5d 100644 --- a/example/pom.xml +++ b/example/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-example @@ -32,6 +32,7 @@ true io.netty.example + true diff --git a/handler-proxy/pom.xml b/handler-proxy/pom.xml index ebc56a14add..35184b88daa 100644 --- a/handler-proxy/pom.xml +++ b/handler-proxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-handler-proxy diff --git a/handler-ssl-ocsp/pom.xml b/handler-ssl-ocsp/pom.xml index d7e0056466e..76b5f0bfb58 100644 --- a/handler-ssl-ocsp/pom.xml +++ b/handler-ssl-ocsp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-handler-ssl-ocsp diff --git a/handler/pom.xml b/handler/pom.xml index e0011af86f8..4f1e2229657 100644 --- a/handler/pom.xml +++ b/handler/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-handler diff --git a/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java 
b/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java index 6265dc27fe1..32f624ccb62 100644 --- a/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java +++ b/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java @@ -28,6 +28,7 @@ import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.SocketChannel; import io.netty.util.NetUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -277,7 +278,12 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { // Initialize if needed if (state.get() == State.INIT) { - initializeIfNecessary(ctx); + try { + initializeIfNecessary(ctx); + } catch (Exception ex) { + ReferenceCountUtil.release(msg); + throw ex; + } } // Only write if State is STARTED @@ -297,7 +303,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { // Initialize if needed if (state.get() == State.INIT) { - initializeIfNecessary(ctx); + try { + initializeIfNecessary(ctx); + } catch (Exception ex) { + ReferenceCountUtil.release(msg); + promise.setFailure(ex); + return; + } } // Only write if State is STARTED diff --git a/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java b/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java index c2c3e9032a9..0807daea621 100644 --- a/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java +++ b/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java @@ -18,15 +18,22 @@ import io.netty.util.internal.SuppressJava6Requirement; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.X509ExtendedTrustManager; -import 
javax.net.ssl.X509TrustManager; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Collection; import java.util.List; - +import javax.naming.ldap.LdapName; +import javax.naming.ldap.Rdn; +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; /** * Wraps an existing {@link X509ExtendedTrustManager} and enhances the {@link CertificateException} that is thrown @@ -34,6 +41,13 @@ */ @SuppressJava6Requirement(reason = "Usage guarded by java version check") final class EnhancingX509ExtendedTrustManager extends X509ExtendedTrustManager { + + // Constants for subject alt names of type DNS and IP. See X509Certificate#getSubjectAlternativeNames() javadocs. + static final int ALTNAME_DNS = 2; + static final int ALTNAME_URI = 6; + static final int ALTNAME_IP = 7; + private static final String SEPARATOR = ", "; + private final X509ExtendedTrustManager wrapped; EnhancingX509ExtendedTrustManager(X509TrustManager wrapped) { @@ -52,7 +66,8 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, Socket try { wrapped.checkServerTrusted(chain, authType, socket); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, + socket instanceof SSLSocket ? ((SSLSocket) socket).getHandshakeSession() : null); } } @@ -68,7 +83,7 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngi try { wrapped.checkServerTrusted(chain, authType, engine); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, engine != null ? 
engine.getHandshakeSession() : null); } } @@ -84,7 +99,7 @@ public void checkServerTrusted(X509Certificate[] chain, String authType) try { wrapped.checkServerTrusted(chain, authType); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, null); } } @@ -93,32 +108,91 @@ public X509Certificate[] getAcceptedIssuers() { return wrapped.getAcceptedIssuers(); } - private static void throwEnhancedCertificateException(X509Certificate[] chain, CertificateException e) - throws CertificateException { + private static void throwEnhancedCertificateException(CertificateException e, X509Certificate[] chain, + SSLSession session) throws CertificateException { // Matching the message is the best we can do sadly. String message = e.getMessage(); - if (message != null && e.getMessage().startsWith("No subject alternative DNS name matching")) { - StringBuilder names = new StringBuilder(64); + if (message != null && + (message.startsWith("No subject alternative") || message.startsWith("No name matching"))) { + StringBuilder sb = new StringBuilder(128); + sb.append(message); + // Some exception messages from sun.security.util.HostnameChecker may end with a dot that we don't need + if (message.charAt(message.length() - 1) == '.') { + sb.setLength(sb.length() - 1); + } + if (session != null) { + sb.append(" for SNIHostName=").append(getSNIHostName(session)) + .append(" and peerHost=").append(session.getPeerHost()); + } + sb.append(" in the chain of ").append(chain.length).append(" certificate(s):"); for (int i = 0; i < chain.length; i++) { X509Certificate cert = chain[i]; Collection> collection = cert.getSubjectAlternativeNames(); + sb.append(' ').append(i + 1).append(". subjectAlternativeNames=["); if (collection != null) { + boolean hasNames = false; for (List altNames : collection) { - // 2 is dNSName. See X509Certificate javadocs. 
- if (altNames.size() >= 2 && ((Integer) altNames.get(0)).intValue() == 2) { - names.append((String) altNames.get(1)).append(","); + if (altNames.size() < 2) { + // We expect at least a pair of 'nameType:value' in that list. + continue; + } + final int nameType = ((Integer) altNames.get(0)).intValue(); + if (nameType == ALTNAME_DNS) { + sb.append("DNS"); + } else if (nameType == ALTNAME_IP) { + sb.append("IP"); + } else if (nameType == ALTNAME_URI) { + // URI names are common in some environments with gRPC services that use SPIFFEs. + // Though the hostname matcher won't be looking at them, having them there can help + // debugging cases where hostname verification was enabled when it shouldn't be. + sb.append("URI"); + } else { + continue; + } + sb.append(':').append((String) altNames.get(1)).append(SEPARATOR); + hasNames = true; + } + if (hasNames) { + // Strip off the last separator + sb.setLength(sb.length() - SEPARATOR.length()); } } + sb.append("], CN=").append(getCommonName(cert)).append('.'); } - if (names.length() != 0) { - // Strip of , - names.setLength(names.length() - 1); - throw new CertificateException(message + - " Subject alternative DNS names in the certificate chain of " + chain.length + - " certificate(s): " + names, e); - } + throw new CertificateException(sb.toString(), e); } throw e; } + + private static String getSNIHostName(SSLSession session) { + if (!(session instanceof ExtendedSSLSession)) { + return null; + } + List names = ((ExtendedSSLSession) session).getRequestedServerNames(); + for (SNIServerName sni : names) { + if (sni instanceof SNIHostName) { + SNIHostName hostName = (SNIHostName) sni; + return hostName.getAsciiName(); + } + } + return null; + } + + private static String getCommonName(X509Certificate cert) { + try { + // 1. Get the X500Principal (better than getSubjectDN which is implementation dependent and deprecated) + X500Principal principal = cert.getSubjectX500Principal(); + // 2.
Parse the DN using LdapName + LdapName ldapName = new LdapName(principal.getName()); + // 3. Iterate over the Relative Distinguished Names (RDNs) to find CN + for (Rdn rdn : ldapName.getRdns()) { + if (rdn.getType().equalsIgnoreCase("CN")) { + return rdn.getValue().toString(); + } + } + } catch (Exception ignore) { + // ignore + } + return "null"; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java index e1eabf71a7e..af4bcf6779a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java @@ -67,6 +67,7 @@ public final class OpenSsl { private static final boolean SUPPORTS_OCSP; private static final boolean TLSV13_SUPPORTED; private static final boolean IS_BORINGSSL; + private static final boolean IS_AWSLC; private static final Set CLIENT_DEFAULT_PROTOCOLS; private static final Set SERVER_DEFAULT_PROTOCOLS; static final Set SUPPORTED_PROTOCOLS_SET; @@ -161,6 +162,7 @@ public final class OpenSsl { } IS_BORINGSSL = "BoringSSL".equals(versionString()); + IS_AWSLC = versionString().startsWith("AWS-LC"); if (IS_BORINGSSL) { EXTRA_SUPPORTED_TLS_1_3_CIPHERS = new String [] { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" , @@ -268,7 +270,7 @@ public final class OpenSsl { try { boolean propertySet = SystemPropertyUtil.contains( "io.netty.handler.ssl.openssl.useKeyManagerFactory"); - if (!IS_BORINGSSL) { + if (!(IS_BORINGSSL || IS_AWSLC)) { useKeyManagerFactory = SystemPropertyUtil.getBoolean( "io.netty.handler.ssl.openssl.useKeyManagerFactory", true); @@ -282,7 +284,7 @@ public final class OpenSsl { if (propertySet) { logger.info("System property " + "'io.netty.handler.ssl.openssl.useKeyManagerFactory'" + - " is deprecated and will be ignored when using BoringSSL"); + " is deprecated and will be ignored when using BoringSSL or AWS-LC"); } } } catch (Throwable ignore) { @@ -453,6 +455,7 @@ public final class OpenSsl { 
SUPPORTS_OCSP = false; TLSV13_SUPPORTED = false; IS_BORINGSSL = false; + IS_AWSLC = false; EXTRA_SUPPORTED_TLS_1_3_CIPHERS = EmptyArrays.EMPTY_STRINGS; EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING = StringUtil.EMPTY_STRING; NAMED_GROUPS = DEFAULT_NAMED_GROUPS; @@ -738,7 +741,7 @@ static boolean isOptionSupported(SslContextOption option) { return true; } // Check for options that are only supported by BoringSSL atm. - if (isBoringSSL()) { + if (isBoringSSL() || isAWSLC()) { return option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD || option == OpenSslContextOption.PRIVATE_KEY_METHOD || option == OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS || @@ -779,4 +782,8 @@ static String[] defaultProtocols(boolean isClient) { static boolean isBoringSSL() { return IS_BORINGSSL; } + + static boolean isAWSLC() { + return IS_AWSLC; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java index 4de373a6231..9886543f97e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java @@ -146,10 +146,13 @@ public ReferenceCounted touch(Object hint) { @Override protected void deallocate() { - destroy(); - if (leak != null) { - boolean closed = leak.close(ReferenceCountedOpenSslContext.this); - assert closed; + try { + destroy(); + } finally { + if (leak != null) { + boolean closed = leak.close(ReferenceCountedOpenSslContext.this); + assert closed; + } } } }; @@ -326,7 +329,8 @@ public ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListene } } else { CipherSuiteConverter.convertToCipherStrings( - unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, OpenSsl.isBoringSSL()); + unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, + OpenSsl.isBoringSSL()); // Set non TLSv1.3 ciphers. 
SSLContext.setCipherSuite(ctx, cipherBuilder.toString(), false); diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 8ed9324c0f1..faee3f098ab 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -386,9 +386,9 @@ public List getStatusResponses() { } } - if (OpenSsl.isBoringSSL() && clientMode) { - // If in client-mode and BoringSSL let's allow to renegotiate once as the server may use this - // for client auth. + if ((OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) && clientMode) { + // If in client-mode and provider is BoringSSL or AWS-LC let's allow to renegotiate once as the + // server may use this for client auth. // // See https://github.com/netty/netty/issues/11529 SSL.setRenegotiateMode(ssl, SSL.SSL_RENEGOTIATE_ONCE); @@ -1704,7 +1704,8 @@ public final void setEnabledCipherSuites(String[] cipherSuites) { final StringBuilder buf = new StringBuilder(); final StringBuilder bufTLSv13 = new StringBuilder(); - CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, OpenSsl.isBoringSSL()); + CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, + OpenSsl.isBoringSSL()); final String cipherSuiteSpec = buf.toString(); final String cipherSuiteSpecTLSv13 = bufTLSv13.toString(); diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index f80b3004a8a..8b8b88d4da7 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1903,6 +1903,10 @@ private void resumeOnEventExecutor() { void runComplete() { EventExecutor executor = ctx.executor(); + if (executor.isShuttingDown()) { + // The executor is already 
shutting down, just return. + return; + } // Jump back on the EventExecutor. We do this even if we are already on the EventLoop to guard against // reentrancy issues. Failing to do so could lead to the situation of tryDecode(...) be called and so // channelRead(...) while still in the decode loop. In this case channelRead(...) might release the input diff --git a/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java b/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java index b502f8cc3e2..d34f038949e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java @@ -15,6 +15,7 @@ */ package io.netty.handler.ssl.util; +import io.netty.util.Recycler; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.SuppressJava6Requirement; @@ -41,18 +42,37 @@ import java.util.Set; public final class LazyX509Certificate extends X509Certificate { + private static final Recycler CERT_FACTORIES = new Recycler() { + @Override + protected CertFactoryHandle newObject(Handle handle) { + try { + return new CertFactoryHandle(CertificateFactory.getInstance("X.509"), handle); + } catch (CertificateException e) { + throw new IllegalStateException(e); + } + } + }; + + private static final class CertFactoryHandle { + private final CertificateFactory factory; + private final Recycler.EnhancedHandle handle; + + private CertFactoryHandle(CertificateFactory factory, Recycler.Handle handle) { + this.factory = factory; + this.handle = (Recycler.EnhancedHandle) handle; + } + + public X509Certificate generateCertificate(byte[] bytes) throws CertificateException { + return (X509Certificate) factory.generateCertificate(new ByteArrayInputStream(bytes)); + } - static final CertificateFactory X509_CERT_FACTORY; - static { - try { - X509_CERT_FACTORY = CertificateFactory.getInstance("X.509"); - } catch (CertificateException e) { - throw new 
ExceptionInInitializerError(e); + public void recycle() { + handle.unguardedRecycle(this); } } private final byte[] bytes; - private X509Certificate wrapped; + private volatile X509Certificate wrapped; /** * Creates a new instance which will lazy parse the given bytes. Be aware that the bytes will not be cloned. @@ -230,11 +250,16 @@ public byte[] getExtensionValue(String oid) { private X509Certificate unwrap() { X509Certificate wrapped = this.wrapped; if (wrapped == null) { + CertFactoryHandle factory = null; try { - wrapped = this.wrapped = (X509Certificate) X509_CERT_FACTORY.generateCertificate( - new ByteArrayInputStream(bytes)); + factory = CERT_FACTORIES.get(); + wrapped = this.wrapped = factory.generateCertificate(bytes); } catch (CertificateException e) { throw new IllegalStateException(e); + } finally { + if (factory != null) { + factory.recycle(); + } } } return wrapped; diff --git a/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java b/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java index 5d1aa8f60db..78e24f46d9d 100644 --- a/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.api.parallel.Isolated; import org.mockito.ArgumentMatcher; import org.slf4j.LoggerFactory; @@ -54,6 +55,7 @@ /** * Verifies the correct functionality of the {@link LoggingHandler}. 
*/ +@Isolated public class LoggingHandlerTest { private static final String LOGGER_NAME = LoggingHandler.class.getName(); diff --git a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java index 9b7398ba5ef..56fa9c4074a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java @@ -223,16 +223,17 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) } } finally { server.close().sync(); + + if (executorService != null) { + executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); + } } } finally { ReferenceCountUtil.release(sslClientContext); } } finally { ReferenceCountUtil.release(sslServerContext); - - if (executorService != null) { - executorService.shutdown(); - } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java b/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java index e3c39cbc7d9..65cbc448aa8 100644 --- a/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java +++ b/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java @@ -42,7 +42,8 @@ public void execute(Runnable command) { PlatformDependent.threadLocalRandom().nextInt(100), TimeUnit.MILLISECONDS); } - void shutdown() { + boolean shutdownAndAwaitTermination(long timeout, TimeUnit unit) throws InterruptedException { service.shutdown(); + return service.awaitTermination(timeout, unit); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java b/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java index 60976127579..9e9a689c878 100644 --- a/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java @@ -17,13 
+17,13 @@ package io.netty.handler.ssl; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLSocket; -import javax.net.ssl.X509ExtendedTrustManager; import java.math.BigInteger; import java.net.Socket; import java.security.Principal; @@ -32,23 +32,48 @@ import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Set; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.security.auth.x500.X500Principal; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_DNS; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_IP; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_URI; +import static io.netty.handler.ssl.SniClientJava8TestUtil.mockSSLSessionWithSNIHostNameAndPeerHost; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; public class EnhancedX509ExtendedTrustManagerTest { + private static final String HOSTNAME = "netty.io"; + private static final String SAN_ENTRY_DNS = "some.netty.io"; + private static final String SAN_ENTRY_IP = "127.0.0.1"; + private static final String SAN_ENTRY_URI = "URI:https://uri.netty.io/profile"; + private static 
final String SAN_ENTRY_RFC822 = "info@netty.io"; + private static final String COMMON_NAME = "leaf.netty.io"; + private static final X509Certificate TEST_CERT = new X509Certificate() { @Override public Collection> getSubjectAlternativeNames() { - return Arrays.asList(Arrays.asList(1, new Object()), Arrays.asList(2, "some.netty.io")); + return Arrays.asList(Arrays.asList(1, new Object()), + Arrays.asList(ALTNAME_DNS, SAN_ENTRY_DNS), Arrays.asList(ALTNAME_IP, SAN_ENTRY_IP), + Arrays.asList(ALTNAME_URI, SAN_ENTRY_URI), Arrays.asList(1 /* rfc822Name */, SAN_ENTRY_RFC822)); + } + + @Override + public X500Principal getSubjectX500Principal() { + return new X500Principal("CN=" + COMMON_NAME + ", O=Netty"); } @Override @@ -192,7 +217,7 @@ public void checkClientTrusted(X509Certificate[] chain, String authType, Socket @Override public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override @@ -203,7 +228,7 @@ public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngi @Override public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override @@ -214,16 +239,23 @@ public void checkClientTrusted(X509Certificate[] chain, String authType) { @Override public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } + + private CertificateException 
newCertificateExceptionWithMatchingMessage() { + return new CertificateException("No subject alternative DNS name matching " + HOSTNAME + " found."); + } }); static List throwingMatchingExecutables() { + if (PlatformDependent.javaVersion() < 8) { + return Collections.emptyList(); + } return Arrays.asList(new Executable() { @Override public void execute() throws Throwable { @@ -232,12 +264,18 @@ public void execute() throws Throwable { }, new Executable() { @Override public void execute() throws Throwable { - MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLEngine) null); + SSLSession session = mockSSLSessionWithSNIHostNameAndPeerHost(HOSTNAME); + SSLEngine engine = Mockito.mock(SSLEngine.class); + Mockito.when(engine.getHandshakeSession()).thenReturn(session); + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, engine); } }, new Executable() { @Override public void execute() throws Throwable { - MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLSocket) null); + SSLSession session = mockSSLSessionWithSNIHostNameAndPeerHost(HOSTNAME); + SSLSocket socket = Mockito.mock(SSLSocket.class); + Mockito.when(socket.getHandshakeSession()).thenReturn(session); + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, socket); } }); } @@ -307,16 +345,28 @@ public void execute() throws Throwable { @ParameterizedTest @MethodSource("throwingMatchingExecutables") - void testEnhanceException(Executable executable) { + void testEnhanceException(Executable executable, TestInfo testInfo) { + assumeTrue(PlatformDependent.javaVersion() >= 8); CertificateException exception = assertThrows(CertificateException.class, executable); // We should wrap the original cause with our own. 
assertInstanceOf(CertificateException.class, exception.getCause()); - assertThat(exception.getMessage()).contains("some.netty.io"); + String message = exception.getMessage(); + if (testInfo.getDisplayName().contains("with")) { + // The following data can be extracted only when we run the test with SSLEngine or SSLSocket: + assertThat(message).contains("SNIHostName=" + HOSTNAME); + assertThat(message).contains("peerHost=" + HOSTNAME); + } + assertThat(message).contains("DNS:" + SAN_ENTRY_DNS); + assertThat(message).contains("IP:" + SAN_ENTRY_IP); + assertThat(message).contains("URI:" + SAN_ENTRY_URI); + assertThat(message).contains("CN=" + COMMON_NAME); + assertThat(message).doesNotContain(SAN_ENTRY_RFC822); } @ParameterizedTest @MethodSource("throwingNonMatchingExecutables") void testNotEnhanceException(Executable executable) { + assumeTrue(PlatformDependent.javaVersion() >= 8); CertificateException exception = assertThrows(CertificateException.class, executable); // We should not wrap the original cause with our own. 
assertNull(exception.getCause()); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java index 9882c722aad..4122245fcbe 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java @@ -71,7 +71,7 @@ public void refreshAlgos() { @Test public void testSimple() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -92,7 +92,7 @@ public void testSimple() throws Throwable { @Test public void testServerPriority() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -116,7 +116,7 @@ public void testServerPriority() throws Throwable { @Test public void testServerPriorityReverse() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -141,7 +141,7 @@ public void testServerPriorityReverse() throws Throwable { @Test public void testFailedNegotiation() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -162,7 +162,7 @@ public void testFailedNegotiation() throws Throwable { @Test public void testAlgoFailure() throws Throwable { - 
assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); TestCertCompressionAlgo badZlibAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { @Override @@ -191,7 +191,7 @@ public void execute() throws Throwable { @Test public void testAlgoException() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); TestCertCompressionAlgo badZlibAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { @Override @@ -220,7 +220,7 @@ public void execute() throws Throwable { @Test public void testTlsLessThan13() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = SslContextBuilder.forClient() .sslProvider(SslProvider.OPENSSL) .protocols(SslProtocols.TLS_v1_2) @@ -251,7 +251,7 @@ public void testTlsLessThan13() throws Throwable { @Test public void testDuplicateAdd() throws Throwable { // Fails with "Failed trying to add certificate compression algorithm" - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); Assertions.assertThrows(Exception.class, new Executable() { @Override public void execute() throws Throwable { @@ -283,7 +283,7 @@ public void execute() throws Throwable { @Test public void testNotBoringAdd() throws Throwable { // Fails with "TLS Cert Compression only supported by BoringSSL" - assumeTrue(!OpenSsl.isBoringSSL()); + assumeTrue(!OpenSsl.isBoringSSL() && !OpenSsl.isAWSLC()); Assertions.assertThrows(Exception.class, new Executable() { @Override public void execute() throws Throwable { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java index 896a21583d0..defabca75d0 100644 --- 
a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java @@ -25,7 +25,7 @@ static void expandCombinations(SSLEngineTest.SSLEngineTestParam param, List output) { output.add(new OpenSslEngineTestParam(true, false, param)); output.add(new OpenSslEngineTestParam(false, false, param)); - if (OpenSsl.isBoringSSL()) { + if (OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) { output.add(new OpenSslEngineTestParam(true, true, param)); output.add(new OpenSslEngineTestParam(false, true, param)); } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java index ef268a8a253..3a3f049bd0b 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java @@ -58,8 +58,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -93,7 +91,7 @@ static Collection parameters() { public static void init() throws Exception { checkShouldUseKeyManagerFactory(); - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API // implementations. 
assumeCipherAvailable(SslProvider.OPENSSL); @@ -110,11 +108,11 @@ public Thread newThread(Runnable r) { } @AfterAll - public static void destroy() { - if (OpenSsl.isBoringSSL()) { + public static void destroy() throws InterruptedException { + if (OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) { GROUP.shutdownGracefully(); + assertTrue(EXECUTOR.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); CERT.delete(); - EXECUTOR.shutdown(); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java index 36ff3de72e3..573f74bb221 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java @@ -28,7 +28,7 @@ import org.mockito.MockitoAnnotations; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class OptionalSslHandlerTest { @@ -115,7 +115,7 @@ public void decodeBuffered() throws Exception { final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3 }); try { handler.decode(context, payload, null); - verifyZeroInteractions(pipeline); + verifyNoInteractions(pipeline); } finally { payload.release(); } diff --git a/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java b/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java index 174243d9a01..1128d8e4cd3 100644 --- a/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java @@ -101,7 +101,7 @@ public void operationComplete(Future future) throws Exception { }); } }); - Channel channel = sb.bind(new LocalAddress("RenegotiateTest")).syncUninterruptibly().channel(); + Channel channel = sb.bind(new LocalAddress(getClass())).syncUninterruptibly().channel(); final SslContext clientContext 
= SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index da3ca84b444..55387e53930 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -50,7 +50,6 @@ import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; -import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.conscrypt.OpenSSLProvider; @@ -547,7 +546,7 @@ public void tearDown() throws InterruptedException { if (clientGroupShutdownFuture != null) { clientGroupShutdownFuture.sync(); } - delegatingExecutor.shutdown(); + assertTrue(delegatingExecutor.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); serverException = null; clientException = null; } @@ -629,11 +628,7 @@ public void testIncompatibleCiphers(final SSLEngineTestParam param) throws Excep serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); // Set the server to only support a single TLSv1.2 cipher - final String serverCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? 
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String serverCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; serverEngine.setEnabledCipherSuites(new String[] { serverCipher }); // Set the client to only support a single TLSv1.3 cipher @@ -1383,9 +1378,9 @@ public void testSessionInvalidate(SSLEngineTestParam param) throws Exception { handshake(param.type(), param.delegate(), clientEngine, serverEngine); SSLSession session = serverEngine.getSession(); - assertTrue(session.isValid()); + assertTrue(session.isValid(), "session should be valid: " + session); session.invalidate(); - assertFalse(session.isValid()); + assertFalse(session.isValid(), "session should be invalid: " + session); } finally { cleanupClientSslEngine(clientEngine); cleanupServerSslEngine(serverEngine); @@ -2247,11 +2242,7 @@ public void testHandshakeCompletesWithNonContiguousProtocolsTLSv1_2CipherOnly(SS SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail // due to no shared/supported cipher. - final String sharedCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String sharedCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; clientSslCtx = wrapContext(param, SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .ciphers(Collections.singletonList(sharedCipher)) @@ -2284,11 +2275,7 @@ public void testHandshakeCompletesWithoutFilteringSupportedCipher(SSLEngineTestP SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail // due to no shared/supported cipher. 
- final String sharedCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String sharedCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; clientSslCtx = wrapContext(param, SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .ciphers(Collections.singletonList(sharedCipher), SupportedCipherSuiteFilter.INSTANCE) @@ -4511,11 +4498,9 @@ public void testMasterKeyLogging(final SSLEngineTestParam param) throws Exceptio * The JDK SSL engine master key retrieval relies on being able to set field access to true. * That is not available in JDK9+ */ - assumeFalse(sslServerProvider() == SslProvider.JDK && PlatformDependent.javaVersion() > 8); - - String originalSystemPropertyValue = SystemPropertyUtil.get(SslMasterKeyHandler.SYSTEM_PROP_KEY); - System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, Boolean.TRUE.toString()); - + if (sslServerProvider() == SslProvider.JDK) { + assumeTrue(SslMasterKeyHandler.isSunSslEngineAvailable()); + } SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) .sslProvider(sslServerProvider()) @@ -4542,6 +4527,12 @@ protected void initChannel(Channel ch) { ch.pipeline().addLast(sslHandler); ch.pipeline().addLast(new SslMasterKeyHandler() { + + @Override + protected boolean masterKeyHandlerEnabled() { + return true; + } + @Override protected void accept(SecretKey masterKey, SSLSession session) { promise.setSuccess(masterKey); @@ -4565,11 +4556,6 @@ protected void accept(SecretKey masterKey, SSLSession session) { assertEquals(48, key.getEncoded().length, "AES secret key must be 48 bytes"); } finally { closeQuietly(socket); - if (originalSystemPropertyValue != null) { - 
System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, originalSystemPropertyValue); - } else { - System.clearProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY); - } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java index 3554e5ae46b..2e67ac87279 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java @@ -35,6 +35,7 @@ import io.netty.util.concurrent.Promise; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.ThrowableUtil; +import org.mockito.Mockito; import javax.net.ssl.ExtendedSSLSession; import javax.net.ssl.KeyManager; @@ -64,6 +65,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -345,4 +347,12 @@ public String chooseEngineServerAlias(String s, Principal[] principals, }, factory.getProvider(), factory.getAlgorithm()); } } + + static SSLSession mockSSLSessionWithSNIHostNameAndPeerHost(String hostname) { + ExtendedSSLSession session = Mockito.mock(ExtendedSSLSession.class); + SNIServerName sniName = new SNIHostName(hostname); + Mockito.when(session.getRequestedServerNames()).thenReturn(Arrays.asList(sniName)); + Mockito.when(session.getPeerHost()).thenReturn(hostname); + return session; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index bd0223cd1f6..6bc63ffdf08 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -31,7 +31,6 @@ import io.netty.handler.codec.TooLongFrameException; import io.netty.handler.ssl.util.CachedSelfSignedCertificate; import io.netty.util.concurrent.Future; - 
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; @@ -70,14 +69,14 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.mockito.Mockito.mock; public class SniHandlerTest { @@ -394,10 +393,19 @@ public void testMajorVersionNot3(SslProvider provider) throws Exception { @ParameterizedTest(name = "{index}: sslProvider={0}") @MethodSource("data") - public void testSniWithApnHandler(SslProvider provider) throws Exception { - SslContext nettyContext = makeSslContext(provider, true); - SslContext sniContext = makeSslContext(provider, true); - final SslContext clientContext = makeSslClientContext(provider, true); + public void testSniWithAlpnHandler(SslProvider provider) throws Exception { + SslContext nettyContext = null; + SslContext sniContext = null; + final SslContext clientContext; + try { + nettyContext = makeSslContext(provider, true); + sniContext = makeSslContext(provider, true); + clientContext = makeSslClientContext(provider, true); + } catch (Exception e) { + ReferenceCountUtil.safeRelease(nettyContext); + ReferenceCountUtil.safeRelease(sniContext); + throw e; + } try { final AtomicBoolean serverApnCtx = new AtomicBoolean(false); final AtomicBoolean clientApnCtx = new AtomicBoolean(false); @@ -455,8 +463,7 @@ protected void 
configurePipeline(ChannelHandlerContext ctx, String protocol) { serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); - ChannelFuture ccf = cb.connect(serverChannel.localAddress()); - assertTrue(ccf.awaitUninterruptibly().isSuccess()); + ChannelFuture ccf = cb.connect(serverChannel.localAddress()).sync(); clientChannel = ccf.channel(); assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS)); @@ -472,7 +479,7 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { if (clientChannel != null) { clientChannel.close().sync(); } - group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + group.shutdownGracefully(100, 5000, TimeUnit.MILLISECONDS).sync(); } } finally { releaseAll(clientContext, nettyContext, sniContext); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java index be0785195fa..17ee9686a6a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java @@ -126,18 +126,18 @@ public void testContextFromManagersOpenssl() throws Exception { @Test public void testUnsupportedPrivateKeyFailsFastForServer() { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); testUnsupportedPrivateKeyFailsFast(true); } @Test public void testUnsupportedPrivateKeyFailsFastForClient() { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); testUnsupportedPrivateKeyFailsFast(false); } private static void testUnsupportedPrivateKeyFailsFast(boolean server) { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); String cert = "-----BEGIN CERTIFICATE-----\n" + "MIICODCCAY2gAwIBAgIEXKTrajAKBggqhkjOPQQDBDBUMQswCQYDVQQGEwJVUzEM\n" + "MAoGA1UECAwDTi9hMQwwCgYDVQQHDANOL2ExDDAKBgNVBAoMA04vYTEMMAoGA1UE\n" + diff --git 
a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 3a5d16d3a46..226bcebfaee 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -1000,7 +1000,7 @@ public void testHandshakeWithExecutorJDK() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.JDK, false); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1029,7 +1029,7 @@ public void testHandshakeWithExecutorOpenSsl() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, false); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1054,7 +1054,7 @@ public void testHandshakeMTLSWithExecutorJDK() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.JDK, true); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1083,7 +1083,7 @@ public void testHandshakeMTLSWithExecutorOpenSsl() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, true); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1574,7 +1574,8 @@ public void testHandshakeFailureCipherMissmatchTLSv12OpenSsl() throws Exception public void testHandshakeFailureCipherMissmatchTLSv13OpenSsl() throws Exception { OpenSsl.ensureAvailability(); assumeTrue(SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); - assumeFalse(OpenSsl.isBoringSSL(), "BoringSSL does not support setting ciphers for TLSv1.3 explicit"); + assumeFalse(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC(), + "Provider does not support setting ciphers for TLSv1.3 explicitly"); 
testHandshakeFailureCipherMissmatch(SslProvider.OPENSSL, true); } diff --git a/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java b/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java index ca678598577..59a28c69aba 100644 --- a/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java @@ -21,9 +21,15 @@ import java.io.ByteArrayInputStream; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; public class LazyX509CertificateTest { @@ -79,7 +85,29 @@ public void testLazyX509Certificate() throws Exception { assertArrayEquals(x509Certificate.getKeyUsage(), lazyX509Certificate.getKeyUsage()); assertEquals(x509Certificate.getExtendedKeyUsage(), lazyX509Certificate.getExtendedKeyUsage()); assertEquals(x509Certificate.getBasicConstraints(), lazyX509Certificate.getBasicConstraints()); - assertEquals(x509Certificate.getSubjectAlternativeNames(), lazyX509Certificate.getSubjectAlternativeNames()); - assertEquals(x509Certificate.getIssuerAlternativeNames(), lazyX509Certificate.getIssuerAlternativeNames()); + assertEqualSans(x509Certificate.getSubjectAlternativeNames(), lazyX509Certificate.getSubjectAlternativeNames()); + assertEqualSans(x509Certificate.getIssuerAlternativeNames(), lazyX509Certificate.getIssuerAlternativeNames()); + } + + private static void assertEqualSans(Collection> expectedSans, Collection> actualSans) { + String errMsgSans = expectedSans + " != " + actualSans; + if (expectedSans == null) { + 
assertNull(actualSans, errMsgSans); + return; + } + assertEquals(expectedSans.size(), actualSans.size(), errMsgSans); + Iterator> expectItr = expectedSans.iterator(); + Iterator> actualItr = actualSans.iterator(); + while (expectItr.hasNext() && actualItr.hasNext()) { + List expectedSan = expectItr.next(); + List actualSan = actualItr.next(); + String errMsgSan = expectedSan + " != " + actualSan; + assertEquals(2, expectedSan.size(), errMsgSan); + assertEquals(2, actualSan.size(), errMsgSan); + assertEquals(expectedSan.get(0), actualSan.get(0), errMsgSan); + assertEquals(expectedSan.get(1), actualSan.get(1), errMsgSan); + } + assertFalse(expectItr.hasNext(), errMsgSans); + assertFalse(actualItr.hasNext(), errMsgSans); } } diff --git a/microbench/pom.xml b/microbench/pom.xml index 8850eb00316..4a4e6153531 100644 --- a/microbench/pom.xml +++ b/microbench/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-microbench @@ -223,6 +223,13 @@ **/Http2FrameWriterBenchmark.java + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + diff --git a/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java b/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java index 7ae7f56c0e5..6eb10982ed0 100644 --- a/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java +++ b/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java @@ -35,6 +35,7 @@ public class AbstractReferenceCountedByteBufBenchmark extends AbstractMicrobenchmark { @Param({ + "0", "1", "10", "100", @@ -60,10 +61,16 @@ public void tearDown() { @OutputTimeUnit(TimeUnit.NANOSECONDS) public boolean retainReleaseUncontended() { buf.retain(); - Blackhole.consumeCPU(delay); + delay(); return buf.release(); } + private void delay() { + if (delay > 0) { + Blackhole.consumeCPU(delay); + } + } + @Benchmark @BenchmarkMode(Mode.AverageTime) 
@OutputTimeUnit(TimeUnit.NANOSECONDS) @@ -71,7 +78,7 @@ public boolean retainReleaseUncontended() { public boolean createUseAndRelease(Blackhole useBuffer) { ByteBuf unpooled = Unpooled.buffer(1); useBuffer.consume(unpooled); - Blackhole.consumeCPU(delay); + delay(); return unpooled.release(); } @@ -81,7 +88,7 @@ public boolean createUseAndRelease(Blackhole useBuffer) { @GroupThreads(4) public boolean retainReleaseContended() { buf.retain(); - Blackhole.consumeCPU(delay); + delay(); return buf.release(); } } diff --git a/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java b/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java index 7d7df184c60..18ab53b24c4 100644 --- a/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java +++ b/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java @@ -23,27 +23,98 @@ import io.netty.util.CharsetUtil; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; -import static io.netty.handler.codec.http.HttpConstants.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import static io.netty.handler.codec.http.HttpConstants.CR; +import static io.netty.handler.codec.http.HttpConstants.LF; +import static io.netty.handler.codec.http.HttpConstants.SP; @State(Scope.Benchmark) @Warmup(iterations = 10) @Measurement(iterations = 20) public class HttpRequestEncoderInsertBenchmark extends AbstractMicrobenchmark { - private final String uri = "http://localhost?eventType=CRITICAL&from=0&to=1497437160327&limit=10&offset=0"; + private static final String[] PARAMS = { + "eventType=CRITICAL", + 
"from=0", + "to=1497437160327", + "limit=10", + "offset=0" + }; + @Param({"1024", "128000"}) + private int samples; + + private String[] uris; + private int index; private final OldHttpRequestEncoder encoderOld = new OldHttpRequestEncoder(); private final HttpRequestEncoder encoderNew = new HttpRequestEncoder(); + @Setup + public void setup() { + List permutations = new ArrayList(); + permute(PARAMS.clone(), 0, permutations); + + String[] allCombinations = new String[permutations.size()]; + String base = "http://localhost?"; + for (int i = 0; i < permutations.size(); i++) { + StringBuilder sb = new StringBuilder(base); + String[] p = permutations.get(i); + for (int j = 0; j < p.length; j++) { + if (j != 0) { + sb.append('&'); + } + sb.append(p[j]); + } + allCombinations[i] = sb.toString(); + } + uris = new String[samples]; + Random rand = new Random(42); + for (int i = 0; i < uris.length; i++) { + uris[i] = allCombinations[rand.nextInt(allCombinations.length)]; + } + index = 0; + } + + private static void permute(String[] arr, int start, List out) { + if (start == arr.length - 1) { + out.add(Arrays.copyOf(arr, arr.length)); + return; + } + for (int i = start; i < arr.length; i++) { + swap(arr, start, i); + permute(arr, start + 1, out); + swap(arr, start, i); + } + } + + private static void swap(String[] a, int i, int j) { + String t = a[i]; + a[i] = a[j]; + a[j] = t; + } + + private String nextUri() { + if (index >= uris.length) { + index = 0; + } + return uris[index++]; + } + @Benchmark public ByteBuf oldEncoder() throws Exception { ByteBuf buffer = Unpooled.buffer(100); try { encoderOld.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, - HttpMethod.GET, uri)); + HttpMethod.GET, nextUri())); return buffer; } finally { buffer.release(); @@ -55,7 +126,7 @@ public ByteBuf newEncoder() throws Exception { ByteBuf buffer = Unpooled.buffer(100); try { encoderNew.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, - HttpMethod.GET, 
uri)); + HttpMethod.GET, nextUri())); return buffer; } finally { buffer.release(); diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java new file mode 100644 index 00000000000..365decd1d7f --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java @@ -0,0 +1,114 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.microbench.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.ReferenceCountUtil; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import static io.netty.handler.codec.http.HttpConstants.CR; +import static io.netty.handler.codec.http.HttpConstants.LF; + +@State(Scope.Thread) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 10, time = 1) +public class HttpChunkedRequestResponseBenchmark extends AbstractMicrobenchmark { + private static final int CRLF_SHORT = (CR << 8) + LF; + + ByteBuf POST; + int readerIndex; + int writeIndex; + EmbeddedChannel nettyChannel; + + @Setup + public void setup() { + HttpRequestDecoder httpRequestDecoder = new HttpRequestDecoder( + HttpRequestDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH, HttpRequestDecoder.DEFAULT_MAX_HEADER_SIZE, + HttpRequestDecoder.DEFAULT_MAX_CHUNK_SIZE, false); + ChannelInboundHandlerAdapter inboundHandlerAdapter = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object o) { + // this is saving a slow type check on LastHttpContent vs HttpRequest + try { + if (o == LastHttpContent.EMPTY_LAST_CONTENT) { + writeResponse(ctx); + } + } finally { + ReferenceCountUtil.release(o); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + } + + private void 
writeResponse(ChannelHandlerContext ctx) { + ByteBuf buffer = ctx.alloc().buffer(); + // Build the response object. + ByteBufUtil.writeAscii(buffer, "HTTP/1.1 200 OK\r\n"); + ByteBufUtil.writeAscii(buffer, "Content-Length: 0\r\n\r\n"); + ctx.write(buffer, ctx.voidPromise()); + } + }; + nettyChannel = new EmbeddedChannel(httpRequestDecoder, inboundHandlerAdapter); + + ByteBuf buffer = Unpooled.buffer(); + ByteBufUtil.writeAscii(buffer, "POST / HTTP/1.1\r\n"); + ByteBufUtil.writeAscii(buffer, "Content-Type: text/plain\r\n"); + ByteBufUtil.writeAscii(buffer, "Transfer-Encoding: chunked\r\n\r\n"); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(43) + "\r\n"); + buffer.writeZero(43); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(18) + + ";extension=kjhkasdhfiushdksjfnskdjfbskdjfbskjdfb\r\n"); + buffer.writeZero(18); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(29) + + ";a=12938746238;b=\"lkjkjhskdfhsdkjh\\\"kjshdflkjhdskjhifuwehwi\";c=lkjdshfkjshdiufh\r\n"); + buffer.writeZero(29); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(9) + + ";A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A\r\n"); + buffer.writeZero(9); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, "0\r\n\r\n"); // Last empty chunk + POST = Unpooled.unreleasableBuffer(buffer); + readerIndex = POST.readerIndex(); + writeIndex = POST.writerIndex(); + } + + @Benchmark + public Object netty() { + POST.setIndex(readerIndex, writeIndex); + ByteBuf byteBuf = POST.retainedDuplicate(); + nettyChannel.writeInbound(byteBuf); + return nettyChannel.outboundMessages().poll(); + } +} diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java index 0716a8fe2f7..54dba8e3ddd 100644 --- 
a/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java +++ b/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java @@ -68,7 +68,7 @@ public class HttpRequestResponseBenchmark extends AbstractMicrobenchmark { static class Alloc implements ByteBufAllocator { - private final ByteBuf buf = Unpooled.buffer(); + private final ByteBuf buf = Unpooled.buffer(512); private final int capacity = buf.capacity(); @Override @@ -82,7 +82,8 @@ public ByteBuf buffer(int initialCapacity) { if (initialCapacity <= capacity) { return buffer(); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException( + "initialCapacity " + initialCapacity + " is greater than capacity " + capacity); } } @@ -91,7 +92,8 @@ public ByteBuf buffer(int initialCapacity, int maxCapacity) { if (initialCapacity <= capacity) { return buffer(); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException( + "initialCapacity " + initialCapacity + " is greater than capacity " + capacity); } } diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java new file mode 100644 index 00000000000..80b2f1ec736 --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.microbench.http; + +import io.netty.handler.codec.http.HttpUtil; +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.TimeUnit; + +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 10, time = 1) +public class HttpUtilBenchmark extends AbstractMicrobenchmark { + private static final String uri = "https://github.com/netty/netty/blob/893508ce62a7f90464f8e4bf2ac28ecc73ce6608/" + + "handler/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java"; + + @Benchmark + public boolean checkIsEncodingSafeUri() { + return HttpUtil.isEncodingSafeStartLineToken(uri); + } +} diff --git a/pom.xml b/pom.xml index d32129b8bd5..dcb97b4e2c8 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ io.netty netty-parent pom - 4.1.128.1.dse + 4.1.132.1.dse Netty https://netty.io/ @@ -53,7 +53,7 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.128.Final + netty-4.1.132.Final @@ -680,7 +680,7 @@ boringssl-snapshot netty-tcnative-boringssl-static - 2.0.75.Final-SNAPSHOT + 2.0.76.Final-SNAPSHOT ${os.detected.classifier} @@ -828,7 +828,7 @@ fedora,suse,arch netty-tcnative - 2.0.74.Final + 2.0.75.Final ${os.detected.classifier} org.conscrypt conscrypt-openjdk-uber @@ -844,7 +844,7 @@ ${os.detected.name}-${os.detected.arch} ${project.basedir}/../common/src/test/resources/logback-test.xml warn - 2.17.2 + 2.25.3 3.0.0 5.12.1 false @@ -1014,7 +1014,7 @@ org.bouncycastle 
bcpkix-jdk15on - 1.69 + 1.70 compile true @@ -1026,7 +1026,7 @@ org.bouncycastle bcprov-jdk15on - 1.69 + 1.70 compile true @@ -1037,7 +1037,7 @@ org.bouncycastle bctls-jdk15on - 1.69 + 1.70 compile true @@ -1056,12 +1056,12 @@ com.ning compress-lzf - 1.0.3 + 1.2.0 - org.lz4 + at.yawk.lz4 lz4-java - 1.8.0 + 1.10.1 com.github.jponge @@ -1234,13 +1234,13 @@ org.assertj assertj-core - 3.18.0 + 3.27.7 test org.mockito mockito-core - 2.18.3 + 4.11.0 test @@ -1288,7 +1288,7 @@ org.apache.commons commons-compress - 1.26.0 + 1.28.0 test @@ -1296,7 +1296,7 @@ commons-io commons-io - 2.14.0 + 2.20.0 test @@ -1335,7 +1335,7 @@ io.projectreactor.tools blockhound - 1.0.14.RELEASE + 1.0.16.RELEASE diff --git a/resolver-dns-classes-macos/pom.xml b/resolver-dns-classes-macos/pom.xml index 230036b95b1..0df3f2a9e8b 100644 --- a/resolver-dns-classes-macos/pom.xml +++ b/resolver-dns-classes-macos/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-resolver-dns-classes-macos diff --git a/resolver-dns-native-macos/pom.xml b/resolver-dns-native-macos/pom.xml index dfa467c1f1e..05afffbd67b 100644 --- a/resolver-dns-native-macos/pom.xml +++ b/resolver-dns-native-macos/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-resolver-dns-native-macos diff --git a/resolver-dns/pom.xml b/resolver-dns/pom.xml index addbb1dae09..db1ced3644c 100644 --- a/resolver-dns/pom.xml +++ b/resolver-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-resolver-dns diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index cc7ed6a7180..e60d54c1144 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -68,6 +68,7 @@ import org.apache.directory.server.dns.store.RecordStore; import 
org.apache.mina.core.buffer.IoBuffer; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -3492,8 +3493,8 @@ private static ServerSocket startDnsServerAndCreateServerSocket(TestDnsServer dn serverSocket.close(); if (i == 10) { // We tried 10 times without success - throw new IllegalStateException( - "Unable to bind TestDnsServer and ServerSocket to the same address", e); + Assumptions.abort("Unable to bind TestDnsServer and ServerSocket to the same address: " + + e.getMessage()); } // We could not start the DnsServer which is most likely because the localAddress was already used, // let's retry diff --git a/resolver/pom.xml b/resolver/pom.xml index 4fca568e0f9..340b51bd533 100644 --- a/resolver/pom.xml +++ b/resolver/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-resolver diff --git a/testsuite-autobahn/pom.xml b/testsuite-autobahn/pom.xml index af5acaf68b3..9b342ff98cb 100644 --- a/testsuite-autobahn/pom.xml +++ b/testsuite-autobahn/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-autobahn diff --git a/testsuite-http2/pom.xml b/testsuite-http2/pom.xml index cec0164e73e..6583f72736e 100644 --- a/testsuite-http2/pom.xml +++ b/testsuite-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-http2 diff --git a/testsuite-native-image-client-runtime-init/pom.xml b/testsuite-native-image-client-runtime-init/pom.xml index 8fa230f0e6a..948f61f9b27 100644 --- a/testsuite-native-image-client-runtime-init/pom.xml +++ b/testsuite-native-image-client-runtime-init/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-native-image-client-runtime-init diff --git a/testsuite-native-image-client/pom.xml b/testsuite-native-image-client/pom.xml index 563d24b1d6f..17aad0baa7b 100644 --- 
a/testsuite-native-image-client/pom.xml +++ b/testsuite-native-image-client/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-native-image-client diff --git a/testsuite-native-image/pom.xml b/testsuite-native-image/pom.xml index 2266f6eac51..bdba1a84c4f 100644 --- a/testsuite-native-image/pom.xml +++ b/testsuite-native-image/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-native-image diff --git a/testsuite-native/pom.xml b/testsuite-native/pom.xml index 3b5d008617a..54ec10e4aca 100644 --- a/testsuite-native/pom.xml +++ b/testsuite-native/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-native diff --git a/testsuite-osgi/pom.xml b/testsuite-osgi/pom.xml index 0ce12fe0e03..a604f282c4a 100644 --- a/testsuite-osgi/pom.xml +++ b/testsuite-osgi/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-osgi diff --git a/testsuite-shading/pom.xml b/testsuite-shading/pom.xml index 43f0355a568..b4ce5268425 100644 --- a/testsuite-shading/pom.xml +++ b/testsuite-shading/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite-shading diff --git a/testsuite/pom.xml b/testsuite/pom.xml index 289cccfee8f..3e20d6ca991 100644 --- a/testsuite/pom.xml +++ b/testsuite/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-testsuite diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java new file mode 100644 index 00000000000..c26e3a0fb7c --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java @@ -0,0 +1,141 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache 
License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.DatagramPacket; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.net.InetSocketAddress; +import java.net.PortUnreachableException; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DatagramConnectedWriteExceptionTest extends AbstractClientSocketTest { + + @Override + protected List> newFactories() { + return SocketTestPermutation.INSTANCE.datagramSocket(); + } + + @Test + @Timeout(value 
= 10000, unit = TimeUnit.MILLISECONDS) + @DisabledOnOs(OS.WINDOWS) + public void testWriteThrowsPortUnreachableException(TestInfo testInfo) throws Throwable { + run(testInfo, new Runner() { + @Override + public void run(Bootstrap bootstrap) throws Throwable { + testWriteExceptionAfterServerStop(bootstrap); + } + }); + } + + protected void testWriteExceptionAfterServerStop(Bootstrap clientBootstrap) throws Throwable { + final CountDownLatch serverReceivedLatch = new CountDownLatch(1); + Bootstrap serverBootstrap = clientBootstrap.clone() + .option(ChannelOption.SO_BROADCAST, false) + .handler(new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + serverReceivedLatch.countDown(); + } + }); + + Channel serverChannel = serverBootstrap.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).sync().channel(); + InetSocketAddress serverAddress = (InetSocketAddress) serverChannel.localAddress(); + + clientBootstrap.option(ChannelOption.AUTO_READ, false) + .handler(new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + // no-op + } + }); + + Channel clientChannel = clientBootstrap.connect(serverAddress).sync().channel(); + + final CountDownLatch clientFirstSendLatch = new CountDownLatch(1); + try { + ByteBuf firstMessage = Unpooled.wrappedBuffer("First message".getBytes(CharsetUtil.UTF_8)); + clientChannel.writeAndFlush(firstMessage) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + clientFirstSendLatch.countDown(); + } + } + }); + + assertTrue(serverReceivedLatch.await(5, TimeUnit.SECONDS), "Server should receive first message"); + assertTrue(clientFirstSendLatch.await(5, TimeUnit.SECONDS), "Client should send first message"); + + serverChannel.close().sync(); + + final AtomicReference writeException = new AtomicReference(); + final 
CountDownLatch writesCompleteLatch = new CountDownLatch(10); + + for (int i = 0; i < 10; i++) { + ByteBuf message = Unpooled.wrappedBuffer(("Message " + i).getBytes(CharsetUtil.UTF_8)); + clientChannel.writeAndFlush(message) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + writeException.compareAndSet(null, future.cause()); + } + writesCompleteLatch.countDown(); + } + }); + Thread.sleep(50); + } + + assertTrue(writesCompleteLatch.await(5, TimeUnit.SECONDS), "All writes should complete"); + + assertNotNull(writeException.get(), "Should have captured a write exception"); + + assertInstanceOf(PortUnreachableException.class, writeException.get(), "Expected " + + "PortUnreachableException but got: " + writeException.get().getClass().getName()); + } finally { + if (clientChannel != null) { + clientChannel.close().sync(); + } + } + } +} diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java index ed91927ad91..4a4ae3a1eb9 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java @@ -128,7 +128,7 @@ public static Collection data() throws Exception { public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate, TestInfo testInfo) throws Throwable { // BoringSSL does not support renegotiation intentionally. 
- assumeFalse("BoringSSL".equals(OpenSsl.versionString())); + assumeFalse("BoringSSL".equals(OpenSsl.versionString()) || OpenSsl.versionString().startsWith("AWS-LC")); assumeTrue(OpenSsl.isAvailable()); run(testInfo, new Runner() { @Override @@ -206,6 +206,7 @@ public void initChannel(Channel sch) throws Exception { } finally { if (executorService != null) { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index a8fab7c4806..63b8a9b0003 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -381,6 +381,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { clientChannel.close().awaitUninterruptibly(); sc.close().awaitUninterruptibly(); delegatedTaskExecutor.shutdown(); + assertTrue(delegatedTaskExecutor.awaitTermination(5, TimeUnit.SECONDS)); if (serverException.get() != null && !(serverException.get() instanceof IOException)) { throw serverException.get(); diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java index 7cd3ece19df..70a20487950 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java @@ -58,6 +58,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class SocketSslGreetingTest extends AbstractSocketTest { @@ -179,6 
+180,7 @@ public void initChannel(Channel sch) throws Exception { } finally { if (executorService != null) { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } } diff --git a/transport-blockhound-tests/pom.xml b/transport-blockhound-tests/pom.xml index c109f9ba179..385b0a4496a 100644 --- a/transport-blockhound-tests/pom.xml +++ b/transport-blockhound-tests/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-blockhound-tests diff --git a/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java b/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java index 403b97a4189..3b8da14d4e0 100644 --- a/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java +++ b/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java @@ -250,6 +250,7 @@ public void testHandshakeWithExecutor() throws Exception { testHandshakeWithExecutor(executorService, "TLSv1.2"); } finally { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } @@ -261,6 +262,7 @@ public void testHandshakeWithExecutorTLSv13() throws Exception { testHandshakeWithExecutor(executorService, "TLSv1.3"); } finally { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } diff --git a/transport-classes-epoll/pom.xml b/transport-classes-epoll/pom.xml index 593265cc3ce..3022af831ba 100644 --- a/transport-classes-epoll/pom.xml +++ b/transport-classes-epoll/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-classes-epoll diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java index 
c4ea86f452d..bf52134a1ac 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java @@ -77,8 +77,7 @@ protected Object filterOutboundMessage(Object msg) throws Exception { final class EpollServerSocketUnsafe extends AbstractEpollUnsafe { // Will hold the remote address after accept(...) was successful. // We need 24 bytes for the address as maximum + 1 byte for storing the length. - // So use 26 bytes as it's a power of two. - private final byte[] acceptedAddress = new byte[26]; + private final byte[] acceptedAddress = new byte[25]; @Override public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) { diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index 69da8a1e3b8..e606eca38be 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -129,8 +129,9 @@ public ChannelMetadata metadata() { *

  • {@link EpollChannelConfig#getEpollMode()} must be {@link EpollMode#LEVEL_TRIGGERED} for this and the * target {@link AbstractEpollStreamChannel}
  • * - * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final int len) { return spliceTo(ch, len, newPromise()); } @@ -147,8 +148,9 @@ public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final i *
  • {@link EpollChannelConfig#getEpollMode()} must be {@link EpollMode#LEVEL_TRIGGERED} for this and the * target {@link AbstractEpollStreamChannel}
  • * - * + * @deprecated will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final int len, final ChannelPromise promise) { if (ch.eventLoop() != eventLoop()) { @@ -182,7 +184,9 @@ public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final i *
  • the {@link FileDescriptor} will not be closed after the {@link ChannelFuture} is notified
  • *
  • this channel must be registered to an event loop or {@link IllegalStateException} will be thrown.
  • * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, final int len) { return spliceTo(ch, offset, len, newPromise()); } @@ -200,7 +204,9 @@ public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, f *
  • the {@link FileDescriptor} will not be closed after the {@link ChannelPromise} is notified
  • *
  • this channel must be registered to an event loop or {@link IllegalStateException} will be thrown.
  • * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, final int len, final ChannelPromise promise) { checkPositiveOrZero(len, "len"); diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index c42ac048467..613e2c2f274 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -415,7 +415,14 @@ private boolean doWriteMessage(Object msg) throws Exception { return true; } - return doWriteOrSendBytes(data, remoteAddress, false) > 0; + try { + return doWriteOrSendBytes(data, remoteAddress, false) > 0; + } catch (NativeIoException e) { + if (remoteAddress == null) { + throw translateForConnected(e); + } + throw e; + } } private static void checkUnresolved(AddressedEnvelope envelope) { diff --git a/transport-classes-kqueue/pom.xml b/transport-classes-kqueue/pom.xml index 4242dd74bc3..113b5ae7a7c 100644 --- a/transport-classes-kqueue/pom.xml +++ b/transport-classes-kqueue/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-classes-kqueue diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java index 8a4c56cd191..93a26b7a85d 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java @@ -77,8 +77,7 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr final class KQueueServerSocketUnsafe extends AbstractKQueueUnsafe { // Will hold the remote 
address after accept(...) was successful. // We need 24 bytes for the address as maximum + 1 byte for storing the capacity. - // So use 26 bytes as it's a power of two. - private final byte[] acceptedAddress = new byte[26]; + private final byte[] acceptedAddress = new byte[25]; @Override void readReady(KQueueRecvByteAllocatorHandle allocHandle) { diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 70f848a3e90..52aa4d4fb27 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -34,6 +34,7 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; +import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; @@ -265,7 +266,11 @@ protected boolean doWriteMessage(Object msg) throws Exception { if (data.hasMemoryAddress()) { long memoryAddress = data.memoryAddress(); if (remoteAddress == null) { - writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + try { + writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + } catch (Errors.NativeIoException e) { + throw translateForConnected(e); + } } else { writtenBytes = socket.sendToAddress(memoryAddress, data.readerIndex(), data.writerIndex(), remoteAddress.getAddress(), remoteAddress.getPort()); @@ -295,6 +300,16 @@ protected boolean doWriteMessage(Object msg) throws Exception { return writtenBytes > 0; } + private static IOException translateForConnected(Errors.NativeIoException e) { + // We need to correctly translate connect errors to match NIO behaviour. 
+ if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { + PortUnreachableException error = new PortUnreachableException(e.getMessage()); + error.initCause(e); + return error; + } + return e; + } + private static void checkUnresolved(AddressedEnvelope envelope) { if (envelope.recipient() instanceof InetSocketAddress && (((InetSocketAddress) envelope.recipient()).isUnresolved())) { diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java index 87081a82c88..99ec7e620fd 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java @@ -97,7 +97,7 @@ private void reallocIfNeeded() { */ void realloc(boolean throwIfFail) { // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. - int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; + int newLength = capacity <= 65536 ? capacity << 1 : capacity + (capacity >> 1); try { ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java index 5c44c57ca45..42ccb20a16b 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java @@ -85,7 +85,7 @@ private long memoryOffset(int index) { private void reallocIfNeeded() { if (size == capacity) { // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. - int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; + int newLength = capacity <= 65536 ? 
capacity << 1 : capacity + (capacity >> 1); ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); // Copy over the old content of the memory and reset the position as we always act on the buffer as if // the position was never increased. diff --git a/transport-native-epoll/pom.xml b/transport-native-epoll/pom.xml index 7ac617deb2a..1e36a68612b 100644 --- a/transport-native-epoll/pom.xml +++ b/transport-native-epoll/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-native-epoll diff --git a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c index cd1e6abfb14..7528c679749 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c +++ b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c @@ -507,8 +507,13 @@ static void netty_epoll_linuxsocket_setTcpMd5Sig(JNIEnv* env, jclass clazz, jint } if (key != NULL) { - md5sig.tcpm_keylen = (*env)->GetArrayLength(env, key); - (*env)->GetByteArrayRegion(env, key, 0, md5sig.tcpm_keylen, (void *) &md5sig.tcpm_key); + jint keylen = (*env)->GetArrayLength(env, key); + if (keylen > TCP_MD5SIG_MAXKEYLEN) { + netty_unix_errors_throwIOException(env, "key is too long"); + return; + } + md5sig.tcpm_keylen = (u_int16_t) keylen; + (*env)->GetByteArrayRegion(env, key, 0, keylen, (void *) &md5sig.tcpm_key); if ((*env)->ExceptionCheck(env) == JNI_TRUE) { return; } diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index eda94b3992b..22e29f65c19 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -277,7 +277,7 @@ static inline jint netty_epoll_wait(JNIEnv* env, jint efd, struct epoll_event *e netty_unix_errors_throwRuntimeExceptionErrorNo(env, "clock_gettime() failed: ", errno); return -1; } - deadline = ts.tv_sec * 
1000 + ts.tv_nsec / 1000 + timeout; + deadline = ts.tv_sec * 1000 + ts.tv_nsec / 1000000 + timeout; while ((rc = epoll_wait(efd, ev, len, timeout)) < 0) { if (errno != EINTR) { @@ -289,7 +289,7 @@ static inline jint netty_epoll_wait(JNIEnv* env, jint efd, struct epoll_event *e return -1; } - now = ts.tv_sec * 1000 + ts.tv_nsec / 1000; + now = ts.tv_sec * 1000 + ts.tv_nsec / 1000000; if (now >= deadline) { return 0; } @@ -495,6 +495,11 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo for (i = 0; i < len; i++) { jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + if (packet == NULL) { + // This should never happen but just handle it and return early. This way if GetObjectArrayElement(...) + // did put an exception on the stack we will see it and not crash. + return -1; + } jbyteArray address = (jbyteArray) (*env)->GetObjectField(env, packet, packetRecipientAddrFieldId); jint addrLen = (*env)->GetIntField(env, packet, packetRecipientAddrLenFieldId); jint packetSegmentSize = (*env)->GetIntField(env, packet, packetSegmentSizeFieldId); @@ -623,7 +628,7 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo #ifdef IP_RECVORIGDSTADDR int readLocalAddr = 0; if (netty_unix_socket_getOption(env, fd, IPPROTO_IP, IP_RECVORIGDSTADDR, - &readLocalAddr, sizeof(readLocalAddr)) < 0) { + &readLocalAddr, sizeof(readLocalAddr)) != -1 && readLocalAddr != 0) { cntrlbuf = malloc(sizeof(char) * storageSize * len); } #endif // IP_RECVORIGDSTADDR @@ -632,11 +637,16 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo for (i = 0; i < len; i++) { jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + if (packet == NULL) { + // This should never happen but just handle it and return early. This way if GetObjectArrayElement(...) + // did put an exception on the stack we will see it and not crash. 
+ return -1; + } msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId); msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId); msg[i].msg_hdr.msg_name = addr + i; - msg[i].msg_hdr.msg_namelen = (socklen_t) addrSize; + msg[i].msg_hdr.msg_namelen = (socklen_t) storageSize; if (cntrlbuf != NULL) { msg[i].msg_hdr.msg_control = cntrlbuf + i * storageSize; diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java index 6bbcb2e660b..17cce3760b9 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java @@ -16,6 +16,7 @@ package io.netty.channel.epoll; import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOption; import io.netty.channel.ConnectTimeoutException; @@ -53,10 +54,10 @@ public static void afterClass() { @BeforeEach public void setup() { - Bootstrap bootstrap = new Bootstrap(); + ServerBootstrap bootstrap = new ServerBootstrap(); server = (EpollServerSocketChannel) bootstrap.group(GROUP) .channel(EpollServerSocketChannel.class) - .handler(new ChannelInboundHandlerAdapter()) + .childHandler(new ChannelInboundHandlerAdapter()) .bind(new InetSocketAddress(NetUtil.LOCALHOST4, 0)).syncUninterruptibly().channel(); } @@ -74,10 +75,10 @@ public void testServerSocketChannelOption() throws Exception { @Test public void testServerOption() throws Exception { - Bootstrap bootstrap = new Bootstrap(); + ServerBootstrap bootstrap = new ServerBootstrap(); EpollServerSocketChannel ch = (EpollServerSocketChannel) bootstrap.group(GROUP) .channel(EpollServerSocketChannel.class) - .handler(new ChannelInboundHandlerAdapter()) + 
.childHandler(new ChannelInboundHandlerAdapter()) .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); ch.config().setOption(EpollChannelOption.TCP_MD5SIG, diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java index 3e09949053e..bfc5fc2fce3 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -33,6 +34,29 @@ public void testIsAvailable() { assertTrue(Epoll.isAvailable()); } + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testEpollWaitTimeoutAccuracy() throws Exception { + final int timeoutMs = 200; + final FileDescriptor epoll = Native.newEpollCreate(); + final EpollEventArray eventArray = new EpollEventArray(8); + try { + long startNs = System.nanoTime(); + // No fds registered, so this will just wait for the timeout. 
+ int ready = Native.epollWait(epoll, eventArray, timeoutMs); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs); + + assertEquals(0, ready); + // Should have waited at least close to the timeout + assertThat(elapsedMs).isGreaterThanOrEqualTo(timeoutMs - 20); + // Should not have waited vastly longer than the timeout + assertThat(elapsedMs).isLessThan(timeoutMs + 200); + } finally { + eventArray.free(); + epoll.close(); + } + } + // Testcase for https://github.com/netty/netty/issues/8444 @Test @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java index 4fe962e8575..2154f0f28d7 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; public class LinuxSocketTest { @@ -94,4 +95,21 @@ public void execute() throws Throwable { socket.close(); } } + + @Test + public void testUnixAbstractDomainSocket() throws IOException { + String address = "\0" + UUID.randomUUID(); + + final DomainSocketAddress domainSocketAddress = new DomainSocketAddress(address); + final Socket socket = Socket.newSocketDomain(); + try { + socket.bind(domainSocketAddress); + DomainSocketAddress local = socket.localDomainSocketAddress(); + assertEquals(domainSocketAddress, local); + assertEquals(address, domainSocketAddress.path()); + assertEquals(address, local.path()); + } finally { + socket.close(); + } + } } diff --git a/transport-native-kqueue/pom.xml b/transport-native-kqueue/pom.xml index f6fcc3b87d0..8936cd398c7 100644 --- 
a/transport-native-kqueue/pom.xml +++ b/transport-native-kqueue/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-native-kqueue diff --git a/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c b/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c index 8e13979ee0a..d0bc5a73ad2 100644 --- a/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c +++ b/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c @@ -144,12 +144,36 @@ static void netty_kqueue_bsdsocket_setAcceptFilter(JNIEnv* env, jclass clazz, ji const char* tmpString = NULL; af.af_name[0] = af.af_arg[0] ='\0'; + jsize len = (*env)->GetStringUTFLength(env, afName); + if (len > sizeof(af.af_name)) { + // Too large and so can't be stored + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", EOVERFLOW); + return; + } tmpString = (*env)->GetStringUTFChars(env, afName, NULL); - strncat(af.af_name, tmpString, sizeof(af.af_name) / sizeof(af.af_name[0])); + if (tmpString == NULL) { + // if NULL is returned it failed due OOME + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", ENOMEM); + return; + } + + strlcat(af.af_name, tmpString, sizeof(af.af_name)); (*env)->ReleaseStringUTFChars(env, afName, tmpString); + len = (*env)->GetStringUTFLength(env, afArg); + if (len > sizeof(af.af_arg)) { + // Too large and so can't be stored + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", EOVERFLOW); + return; + } + tmpString = (*env)->GetStringUTFChars(env, afArg, NULL); - strncat(af.af_arg, tmpString, sizeof(af.af_arg) / sizeof(af.af_arg[0])); + if (tmpString == NULL) { + // if NULL is returned it failed due OOME + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", ENOMEM); + return; + } + strlcat(af.af_arg, tmpString, sizeof(af.af_arg)); (*env)->ReleaseStringUTFChars(env, afArg, tmpString); netty_unix_socket_setOption(env, fd, SOL_SOCKET, 
SO_ACCEPTFILTER, &af, sizeof(af)); diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java new file mode 100644 index 00000000000..c964ca8acb3 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.DatagramConnectedWriteExceptionTest; + +import java.util.List; + +public class KQueueDatagramConnectedWriteExceptionTest extends DatagramConnectedWriteExceptionTest { + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.datagramSocket(); + } +} diff --git a/transport-native-unix-common-tests/pom.xml b/transport-native-unix-common-tests/pom.xml index 4be08a13265..14f541fb564 100644 --- a/transport-native-unix-common-tests/pom.xml +++ b/transport-native-unix-common-tests/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-native-unix-common-tests diff --git a/transport-native-unix-common/pom.xml b/transport-native-unix-common/pom.xml index dd6880e6f51..2eaf353952c 100644 --- a/transport-native-unix-common/pom.xml +++ b/transport-native-unix-common/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-native-unix-common diff --git a/transport-native-unix-common/src/main/c/netty_unix_errors.c b/transport-native-unix-common/src/main/c/netty_unix_errors.c index 1dcac708e22..bdd5e7fef03 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_errors.c +++ b/transport-native-unix-common/src/main/c/netty_unix_errors.c @@ -37,8 +37,11 @@ static jmethodID closedChannelExceptionMethodId = NULL; even on platforms where the GNU variant is exposed. Note: `strerrbuf` must be initialized to all zeros prior to calling this function. XSI or GNU functions do not have such a requirement, but our wrappers do. + + Android exposes the XSI variant by default, see + https://cs.android.com/android/platform/superproject/+/android16-release:bionic/libc/include/string.h;l=145?q=string.h */ -#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! 
_GNU_SOURCE +#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__ || __ANDROID__) && ! _GNU_SOURCE static inline int strerror_r_xsi(int errnum, char *strerrbuf, size_t buflen) { return strerror_r(errnum, strerrbuf, buflen); } diff --git a/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c b/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c index dcb4b34a015..b8ac63ec182 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c +++ b/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c @@ -100,6 +100,9 @@ static jint netty_unix_filedescriptor_close(JNIEnv* env, jclass clazz, jint fd) static jint netty_unix_filedescriptor_open(JNIEnv* env, jclass clazz, jstring path, jint flags) { const char* f_path = (*env)->GetStringUTFChars(env, path, 0); + if (f_path == NULL) { + return -ENOMEM; + } int res = open(f_path, flags, 0666); (*env)->ReleaseStringUTFChars(env, path, f_path); diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.c b/transport-native-unix-common/src/main/c/netty_unix_socket.c index 0efea5d65a0..ebacb134561 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.c +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.c @@ -133,11 +133,23 @@ static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_st return obj; } -static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, int len, jobject local) { +static int domainSocketPathLength(const struct sockaddr_un* s, const socklen_t addrlen) { +#ifdef __linux__ + // Linux supports abstract domain sockets so we need to handle it. 
+ // https://man7.org/linux/man-pages/man7/unix.7.html + if (addrlen >= sizeof(sa_family_t) && s->sun_path[0] == '\0') { + // This is an abstract domain socket address + return (addrlen - sizeof(sa_family_t)); + } +#endif + return strlen(s->sun_path); +} + +static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, const socklen_t addrlen, int len, jobject local) { jclass domainDatagramSocketAddressClass = NULL; jobject obj = NULL; struct sockaddr_un* s = (struct sockaddr_un*) addr; - int pathLength = strlen(s->sun_path); + int pathLength = domainSocketPathLength(s, addrlen); jbyteArray pathBytes = (*env)->NewByteArray(env, pathLength); if (pathBytes == NULL) { return NULL; @@ -157,9 +169,9 @@ static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct socka return obj; } -static jbyteArray netty_unix_socket_createDomainSocketAddressArray(JNIEnv* env, const struct sockaddr_storage* addr) { +static jbyteArray netty_unix_socket_createDomainSocketAddressArray(JNIEnv* env, const struct sockaddr_storage* addr, const socklen_t addrlen) { struct sockaddr_un* s = (struct sockaddr_un*) addr; - int pathLength = strlen(s->sun_path); + int pathLength = domainSocketPathLength(s, addrlen); jbyteArray pathBytes = (*env)->NewByteArray(env, pathLength); if (pathBytes == NULL) { return NULL; @@ -446,7 +458,7 @@ static jobject _recvFromDomainSocket(JNIEnv* env, jint fd, void* buffer, jint po int err; do { - bzero(&addr, sizeof(addr)); // Zap addr so we can strlen(addr.sun_path) later. See unix(4). + memset(&addr, 0, sizeof(addr)); // Zap addr so we can strlen(addr.sun_path) later. See unix(4). 
res = recvfrom(fd, buffer + pos, (size_t) (limit - pos), 0, (struct sockaddr*) &addr, &addrlen); // Keep on reading if it was interrupted } while (res == -1 && ((err = errno) == EINTR)); @@ -464,7 +476,7 @@ static jobject _recvFromDomainSocket(JNIEnv* env, jint fd, void* buffer, jint po return NULL; } - return createDomainDatagramSocketAddress(env, &addr, res, NULL); + return createDomainDatagramSocketAddress(env, &addr, addrlen, res, NULL); } static jint _send(JNIEnv* env, jclass clazz, jint fd, void* buffer, jint pos, jint limit) { @@ -687,8 +699,10 @@ static jint netty_unix_socket_accept(JNIEnv* env, jclass clazz, jint fd, jbyteAr if (accept4) { return socketFd; } + // accept4 was not present so need two more sys-calls ... if (fcntl(socketFd, F_SETFD, FD_CLOEXEC) == -1 || fcntl(socketFd, F_SETFL, O_NONBLOCK) == -1) { - // accept4 was not present so need two more sys-calls ... + // close the fd before report the error so we don't leak it. + close(socketFd); return -errno; } return socketFd; @@ -709,7 +723,7 @@ static jbyteArray netty_unix_socket_remoteDomainSocketAddress(JNIEnv* env, jclas if (getpeername(fd, (struct sockaddr*) &addr, &len) == -1) { return NULL; } - return netty_unix_socket_createDomainSocketAddressArray(env, &addr); + return netty_unix_socket_createDomainSocketAddressArray(env, &addr, len); } static jbyteArray netty_unix_socket_localAddress(JNIEnv* env, jclass clazz, jint fd) { @@ -727,7 +741,7 @@ static jbyteArray netty_unix_socket_localDomainSocketAddress(JNIEnv* env, jclass if (getsockname(fd, (struct sockaddr*) &addr, &len) == -1) { return NULL; } - return netty_unix_socket_createDomainSocketAddressArray(env, &addr); + return netty_unix_socket_createDomainSocketAddressArray(env, &addr, len); } static jint netty_unix_socket_newSocketDgramFd(JNIEnv* env, jclass clazz, jboolean ipv6) { diff --git a/transport-rxtx/pom.xml b/transport-rxtx/pom.xml index df1cc7672c6..daebc62c765 100644 --- a/transport-rxtx/pom.xml +++ b/transport-rxtx/pom.xml @@ 
-21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-rxtx diff --git a/transport-sctp/pom.xml b/transport-sctp/pom.xml index 5ce10d35157..c4ede212c82 100644 --- a/transport-sctp/pom.xml +++ b/transport-sctp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-sctp diff --git a/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java b/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java index 2860ba74b54..bebe380f221 100644 --- a/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java +++ b/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java @@ -54,7 +54,8 @@ public DefaultSctpServerChannelConfig( public Map, Object> getOptions() { return getOptions( super.getOptions(), - ChannelOption.SO_RCVBUF, ChannelOption.SO_SNDBUF, SctpChannelOption.SCTP_INIT_MAXSTREAMS); + ChannelOption.SO_RCVBUF, ChannelOption.SO_SNDBUF, ChannelOption.SO_BACKLOG, + SctpChannelOption.SCTP_INIT_MAXSTREAMS); } @SuppressWarnings("unchecked") @@ -66,6 +67,9 @@ public T getOption(ChannelOption option) { if (option == ChannelOption.SO_SNDBUF) { return (T) Integer.valueOf(getSendBufferSize()); } + if (option == ChannelOption.SO_BACKLOG) { + return (T) Integer.valueOf(getBacklog()); + } if (option == SctpChannelOption.SCTP_INIT_MAXSTREAMS) { return (T) getInitMaxStreams(); } @@ -80,6 +84,8 @@ public boolean setOption(ChannelOption option, T value) { setReceiveBufferSize((Integer) value); } else if (option == ChannelOption.SO_SNDBUF) { setSendBufferSize((Integer) value); + } else if (option == ChannelOption.SO_BACKLOG) { + setSendBufferSize((Integer) value); } else if (option == SctpChannelOption.SCTP_INIT_MAXSTREAMS) { setInitMaxStreams((SctpStandardSocketOptions.InitMaxStreams) value); } else { diff --git a/transport-udt/pom.xml b/transport-udt/pom.xml index f2dcfb7498e..01eec69e4ce 100644 --- 
a/transport-udt/pom.xml +++ b/transport-udt/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport-udt diff --git a/transport/pom.xml b/transport/pom.xml index f9e2a81e2df..001287db1ba 100644 --- a/transport/pom.xml +++ b/transport/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.132.1.dse netty-transport diff --git a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java index f319944318d..d82024fe976 100644 --- a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java @@ -32,6 +32,7 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.logging.InternalLogger; import java.net.InetAddress; @@ -52,6 +53,9 @@ * transports such as datagram (UDP).

    */ public abstract class AbstractBootstrap, C extends Channel> implements Cloneable { + + private static final boolean CLOSE_ON_SET_OPTION_FAILURE = SystemPropertyUtil.getBoolean( + "io.netty.bootstrap.closeOnSetOptionFailure", true); @SuppressWarnings("unchecked") private static final Map.Entry, Object>[] EMPTY_OPTION_ARRAY = new Map.Entry[0]; @SuppressWarnings("unchecked") @@ -357,7 +361,7 @@ final ChannelFuture initAndRegister() { return regFuture; } - abstract void init(Channel channel) throws Exception; + abstract void init(Channel channel) throws Throwable; Collection getInitializerExtensions() { ClassLoader loader = extensionsClassLoader; @@ -474,7 +478,7 @@ static void setAttributes(Channel channel, Map.Entry, Object>[] } static void setChannelOptions( - Channel channel, Map.Entry, Object>[] options, InternalLogger logger) { + Channel channel, Map.Entry, Object>[] options, InternalLogger logger) throws Throwable { for (Map.Entry, Object> e: options) { setChannelOption(channel, e.getKey(), e.getValue(), logger); } @@ -482,7 +486,7 @@ static void setChannelOptions( @SuppressWarnings("unchecked") private static void setChannelOption( - Channel channel, ChannelOption option, Object value, InternalLogger logger) { + Channel channel, ChannelOption option, Object value, InternalLogger logger) throws Throwable { try { if (!channel.config().setOption((ChannelOption) option, value)) { logger.warn("Unknown channel option '{}' for channel '{}' of type '{}'", @@ -492,6 +496,10 @@ private static void setChannelOption( logger.warn( "Failed to set channel option '{}' with value '{}' for channel '{}' of type '{}'", option, value, channel, channel.getClass(), t); + if (CLOSE_ON_SET_OPTION_FAILURE) { + // Only rethrow if we want to close the channel in case of a failure. 
+ throw t; + } } } diff --git a/transport/src/main/java/io/netty/bootstrap/Bootstrap.java b/transport/src/main/java/io/netty/bootstrap/Bootstrap.java index cfba85fe31c..5c71a02d8a1 100644 --- a/transport/src/main/java/io/netty/bootstrap/Bootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/Bootstrap.java @@ -271,11 +271,12 @@ public void run() { } @Override - void init(Channel channel) { + void init(Channel channel) throws Throwable { ChannelPipeline p = channel.pipeline(); p.addLast(config.handler()); setChannelOptions(channel, newOptionsArray(), logger); + setAttributes(channel, newAttributesArray()); Collection extensions = getInitializerExtensions(); if (!extensions.isEmpty()) { diff --git a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java index c8a17fc06f4..b3e14c4e715 100644 --- a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java @@ -132,7 +132,7 @@ public ServerBootstrap childHandler(ChannelHandler childHandler) { } @Override - void init(Channel channel) { + void init(Channel channel) throws Throwable { setChannelOptions(channel, newOptionsArray(), logger); setAttributes(channel, newAttributesArray()); @@ -227,7 +227,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { child.pipeline().addLast(childHandler); - setChannelOptions(child, childOptions, logger); + try { + setChannelOptions(child, childOptions, logger); + } catch (Throwable cause) { + forceClose(child, cause); + return; + } setAttributes(child, childAttrs); if (!extensions.isEmpty()) { diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 8e5f9c50c96..82b50da57c3 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ 
-78,6 +78,13 @@ public void run() { } }; + private final Runnable finishReadTask = new Runnable() { + @Override + public void run() { + finishPeerRead0(LocalChannel.this); + } + }; + private volatile State state; private volatile LocalChannel peer; private volatile LocalAddress localAddress; @@ -418,21 +425,19 @@ private void finishPeerRead(final LocalChannel peer) { } } - private void runFinishPeerReadTask(final LocalChannel peer) { + private void runFinishTask0() { // If the peer is writing, we must wait until after reads are completed for that peer before we can read. So // we keep track of the task, and coordinate later that our read can't happen until the peer is done. - final Runnable finishPeerReadTask = new Runnable() { - @Override - public void run() { - finishPeerRead0(peer); - } - }; + if (writeInProgress) { + finishReadFuture = eventLoop().submit(finishReadTask); + } else { + eventLoop().execute(finishReadTask); + } + } + + private void runFinishPeerReadTask(final LocalChannel peer) { try { - if (peer.writeInProgress) { - peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask); - } else { - peer.eventLoop().execute(finishPeerReadTask); - } + peer.runFinishTask0(); } catch (Throwable cause) { logger.warn("Closing Local channels {}-{} because exception occurred!", this, peer, cause); close(); @@ -482,7 +487,6 @@ public void connect(final SocketAddress remoteAddress, if (state == State.CONNECTED) { Exception cause = new AlreadyConnectedException(); safeSetFailure(promise, cause); - pipeline().fireExceptionCaught(cause); return; } diff --git a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java index bfa2bae8459..4927a337494 100644 --- a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java @@ -30,6 +30,7 @@ import 
io.netty.channel.socket.ChannelInputShutdownReadComplete; import io.netty.channel.socket.SocketChannelConfig; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.ThrowableUtil; import java.io.IOException; import java.nio.channels.SelectableChannel; @@ -115,7 +116,11 @@ private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf, Thro if (byteBuf != null) { if (byteBuf.isReadable()) { readPending = false; - pipeline.fireChannelRead(byteBuf); + try { + pipeline.fireChannelRead(byteBuf); + } catch (Exception e) { + ThrowableUtil.addSuppressed(cause, e); + } } else { byteBuf.release(); } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java index a361f3dc208..9f103b8bf35 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java @@ -555,7 +555,7 @@ public ChannelFuture block( try { key.block(sourceToBlock); } catch (IOException e) { - promise.setFailure(e); + return promise.setFailure(e); } } } diff --git a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java index f788899493f..177538bf352 100644 --- a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java +++ b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java @@ -85,6 +85,29 @@ public static void destroy() { groupB.terminationFuture().syncUninterruptibly(); } + @Test + public void testSetOptionsThrow() { + final ChannelFuture cf = new Bootstrap() + .group(groupA) + .channelFactory(new ChannelFactory() { + @Override + public Channel newChannel() { + return new TestChannel(); + } + }) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .handler(new ChannelInboundHandlerAdapter()) + .register(); + + assertThrows(UnsupportedOperationException.class, new Executable() { + 
@Override + public void execute() throws Throwable { + cf.syncUninterruptibly(); + } + }); + assertFalse(cf.channel().isActive()); + } + @Test public void testOptionsCopied() { final Bootstrap bootstrapA = new Bootstrap(); @@ -578,4 +601,5 @@ public void run() { }; } } + } diff --git a/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java b/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java index 36ed66cbc5a..c2376f7f159 100644 --- a/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java +++ b/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java @@ -16,6 +16,8 @@ package io.netty.bootstrap; import io.netty.channel.Channel; +import io.netty.channel.ChannelFactory; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -24,6 +26,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalEventLoopGroup; @@ -31,6 +34,7 @@ import io.netty.util.AttributeKey; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; import java.util.UUID; import java.util.concurrent.Callable; @@ -40,12 +44,43 @@ import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class ServerBootstrapTest { + @Test + public void testSetOptionsThrow() { + 
LocalEventLoopGroup group = new LocalEventLoopGroup(1); + try { + final ChannelFuture cf = new ServerBootstrap() + .group(group) + .channelFactory(new ChannelFactory() { + @Override + public ServerChannel newChannel() { + return new TestServerChannel(); + } + }) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .handler(new ChannelInboundHandlerAdapter()) + .childHandler(new ChannelInboundHandlerAdapter()) + .register(); + + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() throws Throwable { + cf.syncUninterruptibly(); + } + }); + assertFalse(cf.channel().isActive()); + } finally { + group.shutdownGracefully(); + } + } + @Test @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) public void testHandlerRegister() throws Exception { @@ -240,4 +275,6 @@ public Object call() throws Exception { clientChannel.close().syncUninterruptibly(); group.shutdownGracefully(); } + + private static final class TestServerChannel extends TestChannel implements ServerChannel { } } diff --git a/transport/src/test/java/io/netty/bootstrap/TestChannel.java b/transport/src/test/java/io/netty/bootstrap/TestChannel.java new file mode 100644 index 00000000000..d654d36fffe --- /dev/null +++ b/transport/src/test/java/io/netty/bootstrap/TestChannel.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.bootstrap; + +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.EventLoop; + +import java.net.SocketAddress; + +class TestChannel extends AbstractChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private final ChannelConfig config; + private volatile boolean closed; + + TestChannel() { + this(null); + } + + TestChannel(Channel parent) { + super(parent); + config = new TestConfig(this); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setSuccess(); + } + }; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return true; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + // NOOP + } + + @Override + protected void doDisconnect() { + closed = true; + } + + @Override + protected void doClose() { + closed = true; + } + + @Override + protected void doBeginRead() { + // NOOP + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + // NOOP + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !closed; + } + + @Override + public boolean isActive() { + return !closed; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + private static final class TestConfig extends DefaultChannelConfig { + TestConfig(Channel channel) { + super(channel); + 
} + + @Override + public boolean setOption(ChannelOption option, T value) { + throw new UnsupportedOperationException("Unsupported channel option: " + option); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java b/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java index 3c0378b849c..9df99697701 100644 --- a/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java +++ b/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java @@ -44,7 +44,7 @@ public void shouldNotDoAnythingOnRemove() { ChannelFutureListener l = Mockito.mock(ChannelFutureListener.class); future.removeListener(l); Mockito.verifyNoMoreInteractions(l); - Mockito.verifyZeroInteractions(channel); + Mockito.verifyNoInteractions(channel); } @Test @@ -60,7 +60,7 @@ public void testConstantProperties() throws InterruptedException { assertSame(future, future.awaitUninterruptibly()); assertTrue(future.awaitUninterruptibly(1)); assertTrue(future.awaitUninterruptibly(1, TimeUnit.NANOSECONDS)); - Mockito.verifyZeroInteractions(channel); + Mockito.verifyNoInteractions(channel); } private static class CompleteChannelFutureImpl extends CompleteChannelFuture { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index fa86dd10a92..6182325dc82 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -47,6 +47,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.SocketAddress; import java.util.ArrayDeque; @@ -449,6 +451,183 @@ public void channelRegistered(ChannelHandlerContext ctx) { assertTrue(latch.await(2, 
TimeUnit.SECONDS)); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testInboundOperationsViaContext(boolean inEventLoop) throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + final ChannelHandler handler = new ChannelHandlerAdapter() { }; + pipeline.addLast(handler); + group.register(pipeline.channel()).syncUninterruptibly(); + final BlockingQueue events = new LinkedBlockingQueue(); + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + events.add("channelRegistered"); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + events.add("channelUnregistered"); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + events.add("channelActive"); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + events.add("channelInactive"); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + events.add("channelRead"); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + events.add("channelReadComplete"); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + events.add("userEventTriggered"); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + events.add("channelWritabilityChanged"); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + events.add("exceptionCaught"); + } + }); + final ChannelHandlerContext ctx = pipeline.context(handler); + if (inEventLoop) { + pipeline.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + executeInboundOperations(ctx); + } + }); + } else { + executeInboundOperations(ctx); + } + + assertEquals("channelRegistered", events.take()); + assertEquals("channelUnregistered", events.take()); + assertEquals("channelActive", events.take()); + 
assertEquals("channelInactive", events.take()); + assertEquals("channelRead", events.take()); + assertEquals("channelReadComplete", events.take()); + assertEquals("userEventTriggered", events.take()); + assertEquals("channelWritabilityChanged", events.take()); + assertEquals("exceptionCaught", events.take()); + assertTrue(events.isEmpty()); + pipeline.removeLast(); + pipeline.channel().close().syncUninterruptibly(); + } + + private static void executeInboundOperations(ChannelHandlerContext ctx) { + ctx.fireChannelRegistered(); + ctx.fireChannelUnregistered(); + ctx.fireChannelActive(); + ctx.fireChannelInactive(); + ctx.fireChannelRead(""); + ctx.fireChannelReadComplete(); + ctx.fireUserEventTriggered(""); + ctx.fireChannelWritabilityChanged(); + ctx.fireExceptionCaught(new Exception()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testOutboundOperationsViaContext(boolean inEventLoop) throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + final ChannelHandler handler = new ChannelHandlerAdapter() { }; + pipeline.addLast(handler); + group.register(pipeline.channel()).syncUninterruptibly(); + final BlockingQueue events = new LinkedBlockingQueue(); + pipeline.addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + events.add("bind"); + promise.setSuccess(); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) { + events.add("connect"); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + events.add("close"); + promise.setSuccess(); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + events.add("deregister"); + promise.setSuccess(); + } + + @Override + public void read(ChannelHandlerContext ctx) { + 
events.add("read"); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + events.add("write"); + promise.setSuccess(); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + events.add("flush"); + ctx.flush(); + } + }); + final ChannelHandlerContext ctx = pipeline.context(handler); + if (inEventLoop) { + pipeline.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + executeOutboundOperations(ctx); + } + }); + } else { + executeOutboundOperations(ctx); + } + + assertEquals("bind", events.take()); + assertEquals("connect", events.take()); + assertEquals("close", events.take()); + assertEquals("deregister", events.take()); + assertEquals("read", events.take()); + assertEquals("write", events.take()); + assertEquals("flush", events.take()); + assertTrue(events.isEmpty()); + pipeline.removeFirst(); + pipeline.channel().close().syncUninterruptibly(); + } + + private static void executeOutboundOperations(ChannelHandlerContext ctx) { + ctx.bind(new SocketAddress() { }); + ctx.connect(new SocketAddress() { }); + ctx.close(); + ctx.deregister(); + ctx.read(); + ctx.write(""); + ctx.flush(); + } + @Test public void testPipelineOperation() { ChannelPipeline pipeline = new LocalChannel().pipeline(); diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java index fdb10e5c0c7..f733748809e 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -46,6 +46,7 @@ import org.junit.jupiter.api.function.Executable; import java.net.ConnectException; +import java.nio.channels.AlreadyConnectedException; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -881,6 +882,48 @@ public void execute() { }); } + @Test + public void 
testConnectedAlready() throws Exception { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final AtomicReference causeRef = new AtomicReference(); + cb.group(group1) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + causeRef.set(cause); + } + }); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new TestHandler()); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).sync().channel(); + + ChannelFuture f = cc.connect(sc.localAddress()).awaitUninterruptibly(); + assertInstanceOf(AlreadyConnectedException.class, f.cause()); + cc.close().syncUninterruptibly(); + assertNull(causeRef.get()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener { private LatchChannelFutureListener(int count) { super(count);