From ae1a32950d5aff6cd18038d0744f90aeec5a4356 Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Mon, 8 Jun 2026 07:50:39 -0600 Subject: [PATCH 1/6] use fixed KatRng labels Signed-off-by: Mike Lodder --- hqc-kem/src/kem_impl.rs | 46 ++++++++++++++++++++++++++++++++++++++--- hqc-kem/tests/kat.rs | 9 +++++--- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/hqc-kem/src/kem_impl.rs b/hqc-kem/src/kem_impl.rs index a24c9d7..b989c37 100644 --- a/hqc-kem/src/kem_impl.rs +++ b/hqc-kem/src/kem_impl.rs @@ -153,16 +153,55 @@ mod tests { use super::*; use kem_traits::common::{Generate, KeyExport, KeyInit}; use kem_traits::{Decapsulate, Encapsulate, Kem}; + use shake::{ExtendableOutput, Shake256, Shake256Reader, Update, XofReader}; + + struct TestRng { + reader: Shake256Reader, + } + + impl TestRng { + fn new(label: &[u8]) -> Self { + let mut hasher = Shake256::default(); + hasher.update(label); + Self { + reader: hasher.finalize_xof(), + } + } + } + + impl rand::TryRng for TestRng { + type Error = core::convert::Infallible; + + fn try_next_u32(&mut self) -> Result { + let mut buf = [0u8; 4]; + self.try_fill_bytes(&mut buf)?; + Ok(u32::from_le_bytes(buf)) + } + + fn try_next_u64(&mut self) -> Result { + let mut buf = [0u8; 8]; + self.try_fill_bytes(&mut buf)?; + Ok(u64::from_le_bytes(buf)) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + self.reader.read(dest); + Ok(()) + } + } + + impl rand::TryCryptoRng for TestRng {} macro_rules! kem_roundtrip_test { ($name:ident, $params:ty) => { #[test] fn $name() { // Generate via kem trait - let mut rng = rand::rng(); + let mut rng = TestRng::new(concat!(stringify!($name), "-keygen").as_bytes()); let (dk, ek) = <$params>::generate_keypair_from_rng(&mut rng); // Encapsulate + let mut rng = TestRng::new(concat!(stringify!($name), "-encaps").as_bytes()); let (ct, ss1) = ek.encapsulate_with_rng(&mut rng); // Decapsulate (use UFCS to call trait method, not inherent) @@ -185,7 +224,8 @@ mod tests { #[test] fn $name() { // Generate DK, export seed, re-import, verify deterministic - let dk = DecapsulationKey::<$params>::generate_from_rng(&mut rand::rng()); + let mut rng = TestRng::new(concat!(stringify!($name), "-keygen").as_bytes()); + let dk = DecapsulationKey::<$params>::generate_from_rng(&mut rng); let seed = dk.to_bytes(); let dk2 = DecapsulationKey::<$params>::new(&seed); @@ -195,7 +235,7 @@ mod tests { assert_eq!(ek1, ek2); // Both should produce the same shared secret - let mut rng = rand::rng(); + let mut rng = TestRng::new(concat!(stringify!($name), "-encaps").as_bytes()); let (ct, ss1) = ek1.encapsulate_with_rng(&mut rng); let ss2 = <_ as Decapsulate>::decapsulate(&dk2, &ct); assert_eq!(ss1, ss2); diff --git a/hqc-kem/tests/kat.rs b/hqc-kem/tests/kat.rs index cdccda0..6ff66fe 100644 --- a/hqc-kem/tests/kat.rs +++ b/hqc-kem/tests/kat.rs @@ -136,8 +136,9 @@ fn test_hqc256_kat() { #[test] fn test_hqc128_roundtrip() { - let mut rng = rand::rng(); + let mut rng = KatRng::new(b"hqc128-roundtrip-keygen"); let (ek, dk) = hqc128::generate_key(&mut rng); + let mut rng = KatRng::new(b"hqc128-roundtrip-encaps"); let (ct, ss1) = ek.encapsulate(&mut rng); let ss2 = dk.decapsulate(&ct); assert_eq!(ss1, ss2, "HQC-128 roundtrip failed"); @@ -145,8 +146,9 @@ fn test_hqc128_roundtrip() { #[test] fn test_hqc192_roundtrip() { - let mut rng = rand::rng(); + let mut rng = KatRng::new(b"hqc192-roundtrip-keygen"); let (ek, dk) = hqc192::generate_key(&mut rng); + let mut rng = KatRng::new(b"hqc192-roundtrip-encaps"); let (ct, ss1) = ek.encapsulate(&mut rng); let ss2 = dk.decapsulate(&ct); assert_eq!(ss1, ss2, "HQC-192 roundtrip failed"); @@ -154,8 +156,9 @@ fn test_hqc192_roundtrip() { #[test] fn test_hqc256_roundtrip() { - let mut rng = rand::rng(); + let mut rng = KatRng::new(b"hqc256-roundtrip-keygen"); let (ek, dk) = hqc256::generate_key(&mut rng); + let mut rng = KatRng::new(b"hqc256-roundtrip-encaps"); let (ct, ss1) = ek.encapsulate(&mut rng); let ss2 = dk.decapsulate(&ct); assert_eq!(ss1, ss2, "HQC-256 roundtrip failed"); From 4832a28c84fd7a49fcfd886af08d02b7533ae04a Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Fri, 12 Jun 2026 14:30:09 -0600 Subject: [PATCH 2/6] add sntrup-kem Signed-off-by: Mike Lodder --- .github/workflows/publish.yml | 1 + .github/workflows/sntrup-kem.yml | 76 +++++ Cargo.lock | 123 +++++++- Cargo.toml | 1 + sntrup-kem/.gitignore | 1 + sntrup-kem/Cargo.toml | 68 +++++ sntrup-kem/LICENSE-APACHE | 201 ++++++++++++++ sntrup-kem/LICENSE-MIT | 19 ++ sntrup-kem/README.md | 237 ++++++++++++++++ sntrup-kem/benches/mod.rs | 61 ++++ sntrup-kem/src/ct.rs | 17 ++ sntrup-kem/src/error.rs | 14 + sntrup-kem/src/kem.rs | 79 ++++++ sntrup-kem/src/lib.rs | 161 +++++++++++ sntrup-kem/src/params.rs | 225 +++++++++++++++ sntrup-kem/src/r3.rs | 242 ++++++++++++++++ sntrup-kem/src/r3/mod3.rs | 32 +++ sntrup-kem/src/r3/vector.rs | 306 ++++++++++++++++++++ sntrup-kem/src/rq.rs | 290 +++++++++++++++++++ sntrup-kem/src/rq/encoding.rs | 325 ++++++++++++++++++++++ sntrup-kem/src/rq/modq.rs | 63 +++++ sntrup-kem/src/rq/vector.rs | 313 +++++++++++++++++++++ sntrup-kem/src/types.rs | 380 +++++++++++++++++++++++++ sntrup-kem/src/utils.rs | 446 ++++++++++++++++++++++++++++++ sntrup-kem/src/zx.rs | 318 +++++++++++++++++++++ sntrup-kem/tests/data/kat0_ct.hex | 1 + sntrup-kem/tests/data/kat0_sk.hex | 1 + sntrup-kem/tests/data/kat0_ss.hex | 1 + sntrup-kem/tests/data/kat1_ct.hex | 1 + sntrup-kem/tests/data/kat1_sk.hex | 1 + sntrup-kem/tests/data/kat1_ss.hex | 1 + sntrup-kem/tests/kat.rs | 56 ++++ sntrup-kem/tests/kem.rs | 288 +++++++++++++++++++ sntrup-kem/tests/roundtrip.rs | 24 ++ sntrup-kem/tests/serde.rs | 135 +++++++++ sntrup-kem/tests/sizes.rs | 27 ++ 36 files changed, 4531 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/sntrup-kem.yml create mode 100644 sntrup-kem/.gitignore create mode 100644 sntrup-kem/Cargo.toml create mode 100644 sntrup-kem/LICENSE-APACHE create mode 100644 sntrup-kem/LICENSE-MIT create mode 100644 sntrup-kem/README.md create mode 100644 sntrup-kem/benches/mod.rs create mode 100644 sntrup-kem/src/ct.rs create mode 100644 sntrup-kem/src/error.rs create mode 100644 sntrup-kem/src/kem.rs create mode 100644 sntrup-kem/src/lib.rs create mode 100644 sntrup-kem/src/params.rs create mode 100644 sntrup-kem/src/r3.rs create mode 100644 sntrup-kem/src/r3/mod3.rs create mode 100644 sntrup-kem/src/r3/vector.rs create mode 100644 sntrup-kem/src/rq.rs create mode 100644 sntrup-kem/src/rq/encoding.rs create mode 100644 sntrup-kem/src/rq/modq.rs create mode 100644 sntrup-kem/src/rq/vector.rs create mode 100644 sntrup-kem/src/types.rs create mode 100644 sntrup-kem/src/utils.rs create mode 100644 sntrup-kem/src/zx.rs create mode 100644 sntrup-kem/tests/data/kat0_ct.hex create mode 100644 sntrup-kem/tests/data/kat0_sk.hex create mode 100644 sntrup-kem/tests/data/kat0_ss.hex create mode 100644 sntrup-kem/tests/data/kat1_ct.hex create mode 100644 sntrup-kem/tests/data/kat1_sk.hex create mode 100644 sntrup-kem/tests/data/kat1_ss.hex create mode 100644 sntrup-kem/tests/kat.rs create mode 100644 sntrup-kem/tests/kem.rs create mode 100644 sntrup-kem/tests/roundtrip.rs create mode 100644 sntrup-kem/tests/serde.rs create mode 100644 sntrup-kem/tests/sizes.rs diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c6269c4..e0aef03 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,7 @@ on: "frodo-kem/v**", "ml-kem/v**", "module-lattice/v**", + "sntrup-kem/v**", "x-wing/v**" ] diff --git a/.github/workflows/sntrup-kem.yml b/.github/workflows/sntrup-kem.yml new file mode 100644 index 0000000..a52d579 --- /dev/null +++ b/.github/workflows/sntrup-kem.yml @@ -0,0 +1,76 @@ +name: sntrup-kem + +on: + pull_request: + paths: + - ".github/workflows/sntrup-kem.yml" + - "sntrup-kem/**" + - "Cargo.*" + push: + branches: master + +defaults: + run: + working-directory: sntrup-kem + +env: + RUSTFLAGS: "-Dwarnings" + CARGO_INCREMENTAL: 0 + +# Cancels CI jobs when new commits are pushed to a PR branch +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + set-msrv: + uses: RustCrypto/actions/.github/workflows/set-msrv.yml@master + with: + msrv: 1.85.0 + + minimal-versions: + if: false + uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master + with: + working-directory: ${{ github.workflow }} + + test: + needs: set-msrv + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - ${{needs.set-msrv.outputs.msrv}} + - stable + steps: + - uses: actions/checkout@v6.0.2 + - uses: RustCrypto/actions/cargo-cache@master + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - run: cargo build --benches + - run: cargo build --benches --all-features + - run: cargo test --no-default-features + - run: cargo test + - run: cargo test --all-features + - run: cargo test --features serde,force-scalar + + cross: + needs: set-msrv + strategy: + matrix: + include: + # Big-endian target exercises the portable scalar path and encoding byte order. + - target: powerpc-unknown-linux-gnu + rust: ${{needs.set-msrv.outputs.msrv}} + - target: powerpc-unknown-linux-gnu + rust: stable + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6.0.2 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - uses: RustCrypto/actions/cross-install@master + - run: cross test --release --target ${{ matrix.target }} --all-features diff --git a/Cargo.lock b/Cargo.lock index 149fcb6..b113a8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,6 +22,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "anes" version = "0.1.6" @@ -237,7 +246,7 @@ dependencies = [ "cast", "ciborium", "clap", - "criterion-plot", + "criterion-plot 0.6.0", "itertools", "num-traits", "oorandom", @@ -250,6 +259,31 @@ dependencies = [ "walkdir", ] +[[package]] +name = "criterion" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" +dependencies = [ + "alloca", + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot 0.8.2", + "itertools", + "num-traits", + "oorandom", + "page_size", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + [[package]] name = "criterion-plot" version = "0.6.0" @@ -260,6 +294,16 @@ dependencies = [ "itertools", ] +[[package]] +name = "criterion-plot" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" +dependencies = [ + "cast", + "itertools", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -479,7 +523,7 @@ version = "0.1.0" dependencies = [ "aes", "chacha20", - "criterion", + "criterion 0.7.0", "getrandom", "hex", "hybrid-array", @@ -550,11 +594,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "rand_core", "wasip2", "wasip3", + "wasm-bindgen", ] [[package]] @@ -673,7 +719,7 @@ name = "hqc-kem" version = "0.1.0" dependencies = [ "const-oid", - "criterion", + "criterion 0.7.0", "hex", "hybrid-array", "kem", @@ -823,7 +869,7 @@ name = "ml-kem" version = "0.3.2" dependencies = [ "const-oid", - "criterion", + "criterion 0.7.0", "getrandom", "hex", "hex-literal", @@ -949,6 +995,16 @@ dependencies = [ "primeorder", ] +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "pem-rfc7468" version = "1.0.0" @@ -1021,6 +1077,15 @@ dependencies = [ "serde", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1098,6 +1163,16 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand_chacha" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e6af7f3e25ded52c41df4e0b1af2d047e45896c2f3281792ed68a1c243daedb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + [[package]] name = "rand_core" version = "0.10.1" @@ -1383,6 +1458,24 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" +[[package]] +name = "sntrup-kem" +version = "0.1.0" +dependencies = [ + "criterion 0.8.2", + "getrandom", + "hex", + "rand", + "rand_chacha", + "serde", + "serde_json", + "serdect", + "sha2", + "subtle", + "thiserror", + "zeroize", +] + [[package]] name = "spin" version = "0.9.8" @@ -1659,6 +1752,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -1668,6 +1777,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-link" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index bc13478..6a65b1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "ml-kem", "hqc-kem", "module-lattice", + "sntrup-kem", "x-wing" ] diff --git a/sntrup-kem/.gitignore b/sntrup-kem/.gitignore new file mode 100644 index 0000000..2f7896d --- /dev/null +++ b/sntrup-kem/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/sntrup-kem/Cargo.toml b/sntrup-kem/Cargo.toml new file mode 100644 index 0000000..143e56b --- /dev/null +++ b/sntrup-kem/Cargo.toml @@ -0,0 +1,68 @@ +[package] +name = "sntrup-kem" +version = "0.1.0" +authors = ["Michael Lodder "] +license = "MIT OR Apache-2.0" +keywords = ["sntrup", "kem", "post-quantum", "cryptography", "NTRU"] +description = "Pure Rust implementation of the Streamlined NTRU Prime KEM for all parameter sizes" +homepage = "https://github.com/RustCrypto/KEMs/tree/master/sntrup-kem" +repository = "https://github.com/RustCrypto/KEMs" +categories = ["algorithms", "cryptography"] +readme = "README.md" +edition = "2024" + +[features] +default = ["kgen", "ecap", "dcap"] +kgen = [] +ecap = [] +dcap = [] +alloc = [] +force-scalar = [] +std = [] +serde = ["dep:serdect", "dep:serde"] +js = ["getrandom/wasm_js"] + +[dependencies] +hex = "0.4" +rand = "0.10.0" +rand_chacha = "0.10.0" +subtle = "2" +getrandom = { version = "0.4", optional = true } +serde = { version = "1", optional = true, default-features = false } +serdect = { version = "0.4", optional = true } +# sha2 0.11 dropped the `asm` feature; hardware SHA acceleration is now selected +# automatically per-target, so a single unconditional dependency suffices. +sha2 = "0.11" +thiserror = "2.0" +zeroize = { version = "1", features = ["derive"] } + +[dev-dependencies] +criterion = { version = "0.8", features = ["html_reports"] } +serde_json = "1" + +[[bench]] +name = "mod" +harness = false + +[lints.rust] +missing_docs = "deny" +missing_debug_implementations = "deny" +trivial_casts = "deny" +trivial_numeric_casts = "deny" +unstable_features = "deny" +unused_import_braces = "deny" +unused_parens = "deny" +unused_lifetimes = "deny" +unused_qualifications = "deny" +unused_extern_crates = "deny" + +[lints.clippy] +unwrap_used = "deny" +cast_precision_loss = "warn" +cast_possible_truncation = "warn" +cast_possible_wrap = "warn" +cast_sign_loss = "warn" +checked_conversions = "warn" +mod_module_files = "warn" +panic = "warn" +panic_in_result_fn = "warn" diff --git a/sntrup-kem/LICENSE-APACHE b/sntrup-kem/LICENSE-APACHE new file mode 100644 index 0000000..78173fa --- /dev/null +++ b/sntrup-kem/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/sntrup-kem/LICENSE-MIT b/sntrup-kem/LICENSE-MIT new file mode 100644 index 0000000..c4c4d9d --- /dev/null +++ b/sntrup-kem/LICENSE-MIT @@ -0,0 +1,19 @@ +Copyright (c) 2024 RustCrypto Developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sntrup-kem/README.md b/sntrup-kem/README.md new file mode 100644 index 0000000..50b3bd4 --- /dev/null +++ b/sntrup-kem/README.md @@ -0,0 +1,237 @@ +# sntrup-kem +[![Crate][crate-image]][crate-link] +[![Docs][docs-image]][docs-link] +![Apache2/MIT licensed][license-image] +[![Downloads][downloads-image]][crate-link] +![build](https://github.com/RustCrypto/KEMs/actions/workflows/sntrup-kem.yml/badge.svg) +![MSRV][msrv-image] + +A pure-Rust implementation of [Streamlined NTRU Prime](https://ntruprime.cr.yp.to/) for all parameter sizes. + +NTRU Prime is a lattice-based cryptosystem aiming to improve the security of lattice schemes at minimal cost. It is thought to be resistant to quantum computing advances, in particular Shor's algorithm. It made it to NIST final round but was not selected for finalization. + +Please read the [warnings](#warnings) before use. + +The algorithm was authored by Daniel J. Bernstein, Chitchanok Chuengsatiansup, Tanja Lange & Christine van Vredendaal. This implementation is aligned with the [PQClean reference](https://github.com/PQClean/PQClean/tree/master/crypto_kem) and verified against the [IETF draft](https://datatracker.ietf.org/doc/draft-josefsson-ntruprime-streamlined/) KAT vectors. + +## Parameter Sets + +| Parameter Set | NIST Level | P | Q | W | Public Key | Secret Key | Ciphertext | Shared Secret | +|---------------|:----------:|-----:|-----:|----:|-----------:|-----------:|-----------:|--------------:| +| sntrup653 | 1 | 653 | 4621 | 288 | 994 | 1518 | 897 | 32 | +| sntrup761 | 2 | 761 | 4591 | 286 | 1158 | 1763 | 1039 | 32 | +| sntrup857 | 3 | 857 | 5167 | 322 | 1322 | 1999 | 1184 | 32 | +| sntrup953 | 4 | 953 | 6343 | 396 | 1505 | 2254 | 1349 | 32 | +| sntrup1013 | 5 | 1013 | 7177 | 448 | 1623 | 2417 | 1455 | 32 | +| sntrup1277 | 5 | 1277 | 7879 | 492 | 2067 | 3059 | 1847 | 32 | + +All key and ciphertext sizes are in bytes. Sizes are fixed per parameter set using a canonical encoding enforced by the code. + +> **Note:** sntrup653 (NIST Level 1) is recommended for research and testing only. Prefer sntrup761 or higher for production use. + +## Features + +- Pure Rust, `no_std`-compatible, dependency-minimal +- All six parameter sizes: sntrup653, sntrup761, sntrup857, sntrup953, sntrup1013, sntrup1277 +- IND-CCA2 secure with implicit rejection +- Constant-time operations throughout (branchless sort, constant-time comparison and selection) +- SIMD acceleration (AVX2 on x86_64, NEON on aarch64) with automatic detection +- Optional `serde` support via the `serde` feature +- Deterministic key generation from a 32-byte seed + +### Feature Flags + +The KEM API is split into three default features so downstream crates can pull in only what they need: + +| Feature | Default | Description | +|---------|:-------:|-------------| +| `kgen` | **yes** | Key generation: `SntrupKem::generate_key`, `SntrupKem::generate_key_deterministic` | +| `ecap` | **yes** | Encapsulation: `EncapsulationKey::encapsulate` | +| `dcap` | **yes** | Decapsulation: `DecapsulationKey::decapsulate` | +| `force-scalar` | no | Disable SIMD (AVX2/NEON) and use pure-Rust scalar code | +| `serde` | no | Enables `Serialize`/`Deserialize` for all key and ciphertext types (via `serdect` for constant-time hex encoding) | +| `js` | no | Enables WebAssembly support for `wasm32-unknown-unknown` by configuring `getrandom` to use JavaScript's `crypto.getRandomValues()` | + +To use only a subset of the KEM API, disable defaults and pick the features you need: + +```toml +[dependencies] +# Decapsulation only (e.g. a receiver that never generates keys or encapsulates) +sntrup-kem = { version = "0.1", default-features = false, features = ["dcap"] } +``` + +## Usage + +### Key generation + +```rust +use sntrup_kem::{Sntrup761, SntrupKem}; + +let mut rng = rand::rng(); +let (encapsulation_key, decapsulation_key) = Sntrup761::generate_key(&mut rng); +``` + +All six parameter sets are available as type aliases: + +```rust +use sntrup_kem::{Sntrup653, Sntrup761, Sntrup857, Sntrup953, Sntrup1013, Sntrup1277, SntrupKem}; + +let mut rng = rand::rng(); +let (ek_653, dk_653) = Sntrup653::generate_key(&mut rng); +let (ek_761, dk_761) = Sntrup761::generate_key(&mut rng); +let (ek_857, dk_857) = Sntrup857::generate_key(&mut rng); +let (ek_953, dk_953) = Sntrup953::generate_key(&mut rng); +let (ek_1013, dk_1013) = Sntrup1013::generate_key(&mut rng); +let (ek_1277, dk_1277) = Sntrup1277::generate_key(&mut rng); +``` + +Or use the convenience modules with parameter-specific types: + +```rust +let mut rng = rand::rng(); +let (ek, dk) = sntrup_kem::sntrup761::generate_key(&mut rng); +``` + +### Encapsulation + +The sender uses the encapsulation (public) key to produce a ciphertext and shared secret: + +```rust +use sntrup_kem::{Sntrup761, SntrupKem}; + +let mut rng = rand::rng(); +let (encapsulation_key, decapsulation_key) = Sntrup761::generate_key(&mut rng); + +// Sender side +let (ciphertext, shared_secret_sender) = encapsulation_key.encapsulate(&mut rng); +``` + +### Decapsulation + +The receiver uses the decapsulation (secret) key and the ciphertext to recover the shared secret: + +```rust +use sntrup_kem::{Sntrup761, SntrupKem}; + +let mut rng = rand::rng(); +let (encapsulation_key, decapsulation_key) = Sntrup761::generate_key(&mut rng); +let (ciphertext, shared_secret_sender) = encapsulation_key.encapsulate(&mut rng); + +// Receiver side — implicit rejection: always returns a key +let shared_secret_receiver = decapsulation_key.decapsulate(&ciphertext); + +assert_eq!(shared_secret_sender, shared_secret_receiver); +``` + +### Deterministic key generation + +Derive the same keypair from a 32-byte seed: + +```rust +use sntrup_kem::{Sntrup761, SntrupKem}; + +let seed = [0x42u8; 32]; // must come from a cryptographically secure source +let (ek1, dk1) = Sntrup761::generate_key_deterministic(&seed); +let (ek2, dk2) = Sntrup761::generate_key_deterministic(&seed); +assert_eq!(ek1, ek2); +assert_eq!(dk1, dk2); +``` + +### Serialization with serde + +Enable the `serde` feature: + +```toml +sntrup-kem = { version = "0.1", features = ["serde"] } +``` + +Keys and ciphertexts serialize to hex in human-readable formats (JSON) and raw bytes in binary formats (postcard, bincode): + +```rust,ignore +use sntrup_kem::{Sntrup761, SntrupKem, EncapsulationKey, Sntrup761Params}; + +let mut rng = rand::rng(); +let (ek, dk) = Sntrup761::generate_key(&mut rng); +let json = serde_json::to_string(&ek).unwrap(); +let ek2: EncapsulationKey = serde_json::from_str(&json).unwrap(); +assert_eq!(ek, ek2); +``` + +### Byte conversions + +All types support `AsRef<[u8]>` and `TryFrom<&[u8]>`: + +```rust +use sntrup_kem::{Sntrup761, SntrupKem, EncapsulationKey, Sntrup761Params}; + +let mut rng = rand::rng(); +let (ek, dk) = Sntrup761::generate_key(&mut rng); + +// Serialize to bytes +let ek_bytes: &[u8] = ek.as_ref(); + +// Deserialize from bytes (validates size) +let ek2 = EncapsulationKey::::try_from(ek_bytes).unwrap(); +assert_eq!(ek, ek2); +``` + +## WebAssembly + +To compile for `wasm32-unknown-unknown`, enable the `js` feature so that `getrandom` uses JavaScript's `crypto.getRandomValues()` for randomness: + +```toml +[dependencies] +sntrup-kem = { version = "0.1", features = ["js"] } +``` + +Install the target and build: + +```bash +rustup target add wasm32-unknown-unknown +cargo build --target wasm32-unknown-unknown --features js +``` + +For `wasm32-wasi` (or `wasm32-wasip1`), the `js` feature is **not** needed since WASI provides its own random source. + +## Security Properties + +- **IND-CCA2 security** via implicit rejection: decapsulation always returns a shared key. On failure, a pseudorandom key is derived from secret randomness (`rho`), making it indistinguishable from a valid key to an attacker. +- **Hash domain separation**: all hashes use prefix bytes (following the NTRU Prime specification). +- **Constant-time operations**: branchless sorting (djbsort), constant-time weight checks, constant-time ciphertext comparison, and constant-time selection in decapsulation. +- **Zeroization**: secret key material is zeroized on drop. + +## Warnings + +#### Implementation + +This implementation has not undergone any security auditing and while care has been taken no guarantees can be made for either correctness or the constant time running of the underlying functions. **Please use at your own risk.** + +#### Algorithm + +Streamlined NTRU Prime was first published in 2016. The algorithm still requires careful security review. Please see [here](https://ntruprime.cr.yp.to/warnings.html) for further warnings from the authors regarding NTRU Prime and lattice-based encryption schemes. + +# License + +Licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +# Contribution + +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, as defined in the Apache-2.0 +license, shall be licensed as above, without any additional terms or +conditions. + +[//]: # (badges) + +[crate-image]: https://img.shields.io/crates/v/sntrup-kem.svg +[crate-link]: https://crates.io/crates/sntrup-kem +[docs-image]: https://docs.rs/sntrup-kem/badge.svg +[docs-link]: https://docs.rs/sntrup-kem/ +[license-image]: https://img.shields.io/badge/license-Apache2.0/MIT-blue.svg +[downloads-image]: https://img.shields.io/crates/d/sntrup-kem.svg +[msrv-image]: https://img.shields.io/badge/rustc-1.85+-blue.svg diff --git a/sntrup-kem/benches/mod.rs b/sntrup-kem/benches/mod.rs new file mode 100644 index 0000000..4b15433 --- /dev/null +++ b/sntrup-kem/benches/mod.rs @@ -0,0 +1,61 @@ +#![allow(missing_docs)] + +use criterion::{Criterion, criterion_group, criterion_main}; +use sntrup_kem::*; + +fn bench_keygen(c: &mut Criterion) { + let mut group = c.benchmark_group("keygen"); + + group.bench_function("sntrup761", |b| { + let mut rng = rand::rng(); + b.iter(|| Sntrup761::generate_key(&mut rng)); + }); + + group.bench_function("sntrup1277", |b| { + let mut rng = rand::rng(); + b.iter(|| Sntrup1277::generate_key(&mut rng)); + }); + + group.finish(); +} + +fn bench_encapsulate(c: &mut Criterion) { + let mut group = c.benchmark_group("encapsulate"); + + group.bench_function("sntrup761", |b| { + let mut rng = rand::rng(); + let (ek, _dk) = Sntrup761::generate_key(&mut rng); + b.iter(|| ek.encapsulate(&mut rng)); + }); + + group.bench_function("sntrup1277", |b| { + let mut rng = rand::rng(); + let (ek, _dk) = Sntrup1277::generate_key(&mut rng); + b.iter(|| ek.encapsulate(&mut rng)); + }); + + group.finish(); +} + +fn bench_decapsulate(c: &mut Criterion) { + let mut group = c.benchmark_group("decapsulate"); + + group.bench_function("sntrup761", |b| { + let mut rng = rand::rng(); + let (ek, dk) = Sntrup761::generate_key(&mut rng); + let (ct, _ss) = ek.encapsulate(&mut rng); + b.iter(|| dk.decapsulate(&ct)); + }); + + group.bench_function("sntrup1277", |b| { + let mut rng = rand::rng(); + let (ek, dk) = Sntrup1277::generate_key(&mut rng); + let (ct, _ss) = ek.encapsulate(&mut rng); + b.iter(|| dk.decapsulate(&ct)); + }); + + group.finish(); +} + +criterion_group!(benches, bench_keygen, bench_encapsulate, bench_decapsulate); +criterion_main!(benches); diff --git a/sntrup-kem/src/ct.rs b/sntrup-kem/src/ct.rs new file mode 100644 index 0000000..ca48134 --- /dev/null +++ b/sntrup-kem/src/ct.rs @@ -0,0 +1,17 @@ +//! Small branchless helpers shared by the R3 and Rq reciprocal loops. +//! +//! Both extended-GCD reciprocals (`r3::reciprocal`, `rq::reciprocal3`) drive the +//! same constant-time control flow, so these primitives live in one place. + +/// Branchless conditional swap: returns `(y, x)` when `mask == -1`, `(x, y)` when `mask == 0`. +#[inline(always)] +pub(crate) fn swap_int(x: isize, y: isize, mask: isize) -> (isize, isize) { + let t = mask & (x ^ y); + (x ^ t, y ^ t) +} + +/// Branchless sign test: returns `-1` (all ones) when `x < y`, else `0`. +#[inline(always)] +pub(crate) fn smaller_mask(x: isize, y: isize) -> isize { + (x - y) >> 31 +} diff --git a/sntrup-kem/src/error.rs b/sntrup-kem/src/error.rs new file mode 100644 index 0000000..f47e447 --- /dev/null +++ b/sntrup-kem/src/error.rs @@ -0,0 +1,14 @@ +//! Error types for sntrup operations. + +/// Errors returned by sntrup operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum Error { + /// Byte slice has the wrong length for conversion. + #[error("invalid size: expected {expected} bytes, got {actual}")] + InvalidSize { + /// Expected size. + expected: usize, + /// Provided size. + actual: usize, + }, +} diff --git a/sntrup-kem/src/kem.rs b/sntrup-kem/src/kem.rs new file mode 100644 index 0000000..da02d07 --- /dev/null +++ b/sntrup-kem/src/kem.rs @@ -0,0 +1,79 @@ +//! Internal KEM operations for Streamlined NTRU Prime. +//! +//! Top-level keygen/encaps/decaps functions that delegate to `utils` for +//! the core cryptographic operations. + +use crate::params::SntrupParameters; +use crate::{r3, utils, zx}; +use rand::CryptoRng; +use zeroize::Zeroize; + +/// Generate a Streamlined NTRU Prime key pair. +/// +/// Returns `(pk_bytes, sk_bytes)` as `Vec`. +#[cfg(feature = "kgen")] +pub(crate) fn keygen(params: &SntrupParameters, rng: &mut impl CryptoRng) -> (Vec, Vec) { + let p = params.p; + + // Generate g and its reciprocal in R3 + let mut g = vec![0i8; p]; + let mut gr = loop { + zx::random::random_small(&mut g, rng); + let (mask, mut gr) = r3::reciprocal(&g, p); + if mask == 0 { + break gr; + } + // Rejected reciprocal is still derived from the secret g — wipe it. + gr.zeroize(); + }; + + // Generate f with Hamming weight w + let mut f = vec![0i8; p]; + zx::random::random_tsmall(&mut f, p, params.w, rng); + + // Generate random rho for implicit rejection (raw random bytes, per PQClean) + let mut rho = vec![0u8; params.small_encode_size]; + rng.fill_bytes(&mut rho); + + let result = utils::derive_key(&f, &g, &gr, &rho, params); + + // Zeroize secret intermediates + f.zeroize(); + g.zeroize(); + gr.zeroize(); + rho.zeroize(); + + result +} + +/// Encapsulate with a public key. +/// +/// Returns `(ciphertext_bytes, shared_secret_bytes)`. +#[cfg(feature = "ecap")] +pub(crate) fn encaps( + pk: &[u8], + params: &SntrupParameters, + rng: &mut impl CryptoRng, +) -> (Vec, Vec) { + let p = params.p; + + // Generate random r with Hamming weight w + let mut r = vec![0i8; p]; + zx::random::random_tsmall(&mut r, p, params.w, rng); + + let (ct, ss) = utils::create_cipher(&r, pk, params); + + // Zeroize secret intermediate + r.zeroize(); + + (ct, ss.to_vec()) +} + +/// Decapsulate with a secret key. +/// +/// Returns shared secret bytes. +#[cfg(feature = "dcap")] +pub(crate) fn decaps(sk: &[u8], ct: &[u8], params: &SntrupParameters) -> Vec { + let ss = utils::decapsulate_inner(ct, sk, params); + ss.to_vec() +} diff --git a/sntrup-kem/src/lib.rs b/sntrup-kem/src/lib.rs new file mode 100644 index 0000000..f250970 --- /dev/null +++ b/sntrup-kem/src/lib.rs @@ -0,0 +1,161 @@ +//! Pure Rust implementation of Streamlined NTRU Prime KEM for all parameter sizes. +//! +//! Streamlined NTRU Prime is a lattice-based, quantum-resistant cryptographic +//! algorithm designed for secure key exchange. This crate supports all six +//! parameter sets: sntrup653, sntrup761, sntrup857, sntrup953, sntrup1013, +//! and sntrup1277. +//! +//! # Usage +//! +//! ```rust +//! use sntrup_kem::{Sntrup761, SntrupKem}; +//! +//! let mut rng = rand::rng(); +//! let (ek, dk) = Sntrup761::generate_key(&mut rng); +//! let (ct, ss1) = ek.encapsulate(&mut rng); +//! let ss2 = dk.decapsulate(&ct); +//! assert_eq!(ss1, ss2); +//! ``` +//! +//! # Security Levels +//! +//! - [`Sntrup653`] / [`sntrup653`]: NIST Level 1 (128-bit security) — research/testing only, prefer [`Sntrup761`] or higher for production +//! - [`Sntrup761`] / [`sntrup761`]: NIST Level 2 (128-bit+ security, used by OpenSSH) +//! - [`Sntrup857`] / [`sntrup857`]: NIST Level 3 (192-bit security) +//! - [`Sntrup953`] / [`sntrup953`]: NIST Level 4 (192-bit+ security) +//! - [`Sntrup1013`] / [`sntrup1013`]: NIST Level 5 (256-bit security) +//! - [`Sntrup1277`] / [`sntrup1277`]: NIST Level 5 (256-bit security, with extra margin) +//! +//! # Sizes (bytes) +//! +//! | Parameter Set | NIST Level | Public Key | Secret Key | Ciphertext | Shared Secret | +//! |---------------|------------|------------|------------|------------|---------------| +//! | sntrup653 | 1 | 994 | 1518 | 897 | 32 | +//! | sntrup761 | 2 | 1158 | 1763 | 1039 | 32 | +//! | sntrup857 | 3 | 1322 | 1999 | 1184 | 32 | +//! | sntrup953 | 4 | 1505 | 2254 | 1349 | 32 | +//! | sntrup1013 | 5 | 1623 | 2417 | 1455 | 32 | +//! | sntrup1277 | 5 | 2067 | 3059 | 1847 | 32 | +//! +//! # Features +//! +//! - `kgen`: Key generation (default) +//! - `ecap`: Encapsulation (default) +//! - `dcap`: Decapsulation (default) +//! - `serde`: Serde serialization support via `serdect` + +mod ct; +mod error; +mod kem; +mod params; +mod r3; +mod rq; +mod types; +mod utils; +mod zx; + +pub use error::Error; +pub use params::{ + Sntrup653Params, Sntrup761Params, Sntrup857Params, Sntrup953Params, Sntrup1013Params, + Sntrup1277Params, SntrupParams, +}; +pub use types::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret, SntrupKem}; + +/// sntrup653 KEM (NIST Level 1, 128-bit security). +/// +/// **Not recommended for production use.** The 653 parameter set provides the +/// lowest security margin. Prefer [`Sntrup761`] or higher for production deployments. +pub type Sntrup653 = SntrupKem; +/// sntrup761 KEM (NIST Level 2, 128-bit+ security, used by OpenSSH). +pub type Sntrup761 = SntrupKem; +/// sntrup857 KEM (NIST Level 3, 192-bit security). +pub type Sntrup857 = SntrupKem; +/// sntrup953 KEM (NIST Level 4, 192-bit+ security). +pub type Sntrup953 = SntrupKem; +/// sntrup1013 KEM (NIST Level 5, 256-bit security). +pub type Sntrup1013 = SntrupKem; +/// sntrup1277 KEM (NIST Level 5, 256-bit security). +pub type Sntrup1277 = SntrupKem; + +/// Define a per-parameter-set convenience module: size constants (sourced from +/// the `SntrupParams` impl so they cannot drift), type aliases, and free +/// `generate_key` / `generate_key_deterministic` functions. +macro_rules! sntrup_module { + ($modname:ident, $params:ident, $kem:ident, $doc:expr) => { + #[doc = $doc] + pub mod $modname { + use crate::params::SntrupParams; + + /// Public key size in bytes. + pub const PUBLIC_KEY_SIZE: usize = crate::$params::PK_BYTES; + /// Secret key size in bytes. + pub const SECRET_KEY_SIZE: usize = crate::$params::SK_BYTES; + /// Ciphertext size in bytes. + pub const CIPHERTEXT_SIZE: usize = crate::$params::CT_BYTES; + /// Shared secret size in bytes. + pub const SHARED_SECRET_SIZE: usize = crate::params::SS_BYTES; + + /// Encapsulation key for this parameter set. + pub type EncapsulationKey = crate::EncapsulationKey; + /// Decapsulation key for this parameter set. + pub type DecapsulationKey = crate::DecapsulationKey; + /// Ciphertext for this parameter set. + pub type Ciphertext = crate::Ciphertext; + /// Shared secret for this parameter set. + pub type SharedSecret = crate::SharedSecret; + + /// Generate a key pair for this parameter set. + #[cfg(feature = "kgen")] + pub fn generate_key( + rng: &mut impl rand::CryptoRng, + ) -> (EncapsulationKey, DecapsulationKey) { + crate::$kem::generate_key(rng) + } + + /// Generate a key pair deterministically from a 32-byte seed. + #[cfg(feature = "kgen")] + pub fn generate_key_deterministic( + seed: &[u8; 32], + ) -> (EncapsulationKey, DecapsulationKey) { + crate::$kem::generate_key_deterministic(seed) + } + } + }; +} + +sntrup_module!( + sntrup653, + Sntrup653Params, + Sntrup653, + "sntrup653: NIST Level 1 (128-bit security), p=653, q=4621, w=288.\n\n**Not recommended for production use.** Prefer [`sntrup761`] or higher." +); +sntrup_module!( + sntrup761, + Sntrup761Params, + Sntrup761, + "sntrup761: NIST Level 2 (128-bit+ security), p=761, q=4591, w=286. Used by OpenSSH." +); +sntrup_module!( + sntrup857, + Sntrup857Params, + Sntrup857, + "sntrup857: NIST Level 3 (192-bit security), p=857, q=5167, w=322." +); +sntrup_module!( + sntrup953, + Sntrup953Params, + Sntrup953, + "sntrup953: NIST Level 4 (192-bit+ security), p=953, q=6343, w=396." +); +sntrup_module!( + sntrup1013, + Sntrup1013Params, + Sntrup1013, + "sntrup1013: NIST Level 5 (256-bit security), p=1013, q=7177, w=448." +); +sntrup_module!( + sntrup1277, + Sntrup1277Params, + Sntrup1277, + "sntrup1277: NIST Level 5 (256-bit security, extra margin), p=1277, q=7879, w=492." +); diff --git a/sntrup-kem/src/params.rs b/sntrup-kem/src/params.rs new file mode 100644 index 0000000..ad385d4 --- /dev/null +++ b/sntrup-kem/src/params.rs @@ -0,0 +1,225 @@ +//! Streamlined NTRU Prime parameter definitions for all security levels. + +/// Shared secret size in bytes. +pub(crate) const SS_BYTES: usize = 32; + +/// Internal runtime parameter set for Streamlined NTRU Prime. +#[doc(hidden)] +#[derive(Debug, Clone, Copy)] +pub struct SntrupParameters { + /// Polynomial degree. + pub p: usize, + /// Modulus (prime). + pub q: i32, + /// Hamming weight for secret key polynomial. + pub w: usize, + /// (Q-1)/2, used for rounding. + pub q12: i32, + /// Size of small-element encoding in bytes: ceil(P/4). + pub small_encode_size: usize, + /// Size of rounded encoding in bytes (variable-radix). + pub rounded_encode_size: usize, + /// Public key size in bytes (Rq encoding). + pub pk_size: usize, + /// Secret key size in bytes: 3*small_encode_size + pk_size + 32. + pub sk_size: usize, + /// Ciphertext size in bytes: rounded_encode_size + 32. + pub ct_size: usize, + /// Barrett reduction constant 1: floor(2^20 / Q). + pub barrett1: i32, + /// Barrett reduction constant 2: floor(2^28 / Q). + pub barrett2: i32, +} + +/// sntrup653 parameters. +pub(crate) const SNTRUP653: SntrupParameters = SntrupParameters { + p: 653, + q: 4621, + w: 288, + q12: 2310, + small_encode_size: 164, + rounded_encode_size: 865, + pk_size: 994, + sk_size: 1518, + ct_size: 897, + barrett1: 226, + barrett2: 58084, +}; + +/// sntrup761 parameters. +pub(crate) const SNTRUP761: SntrupParameters = SntrupParameters { + p: 761, + q: 4591, + w: 286, + q12: 2295, + small_encode_size: 191, + rounded_encode_size: 1007, + pk_size: 1158, + sk_size: 1763, + ct_size: 1039, + barrett1: 228, + barrett2: 58470, +}; + +/// sntrup857 parameters. +pub(crate) const SNTRUP857: SntrupParameters = SntrupParameters { + p: 857, + q: 5167, + w: 322, + q12: 2583, + small_encode_size: 215, + rounded_encode_size: 1152, + pk_size: 1322, + sk_size: 1999, + ct_size: 1184, + barrett1: 202, + barrett2: 51943, +}; + +/// sntrup953 parameters. +pub(crate) const SNTRUP953: SntrupParameters = SntrupParameters { + p: 953, + q: 6343, + w: 396, + q12: 3171, + small_encode_size: 239, + rounded_encode_size: 1317, + pk_size: 1505, + sk_size: 2254, + ct_size: 1349, + barrett1: 165, + barrett2: 42313, +}; + +/// sntrup1013 parameters. +pub(crate) const SNTRUP1013: SntrupParameters = SntrupParameters { + p: 1013, + q: 7177, + w: 448, + q12: 3588, + small_encode_size: 254, + rounded_encode_size: 1423, + pk_size: 1623, + sk_size: 2417, + ct_size: 1455, + barrett1: 146, + barrett2: 37398, +}; + +/// sntrup1277 parameters. +pub(crate) const SNTRUP1277: SntrupParameters = SntrupParameters { + p: 1277, + q: 7879, + w: 492, + q12: 3939, + small_encode_size: 320, + rounded_encode_size: 1815, + pk_size: 2067, + sk_size: 3059, + ct_size: 1847, + barrett1: 133, + barrett2: 34064, +}; + +mod sealed { + /// Sealed trait preventing external implementations of [`SntrupParams`](super::SntrupParams). + pub trait Sealed {} +} + +/// Trait defining a Streamlined NTRU Prime parameter set. +/// +/// Sealed — cannot be implemented outside this crate. Use one of the provided +/// marker types: [`Sntrup653Params`], [`Sntrup761Params`], [`Sntrup857Params`], +/// [`Sntrup953Params`], [`Sntrup1013Params`], [`Sntrup1277Params`]. +pub trait SntrupParams: sealed::Sealed + 'static { + /// Human-readable name (e.g. `"sntrup761"`). + const NAME: &'static str; + /// Public key size in bytes. + const PK_BYTES: usize; + /// Secret key size in bytes. + const SK_BYTES: usize; + /// Ciphertext size in bytes. + const CT_BYTES: usize; + /// Shared secret size in bytes (always 32). + const SS_BYTES: usize = SS_BYTES; + + /// Runtime parameter struct for internal operations. + #[doc(hidden)] + fn params() -> &'static SntrupParameters; +} + +/// Define a zero-sized parameter-set marker type and its sealed [`SntrupParams`] impl. +macro_rules! sntrup_params_marker { + ($marker:ident, $name:literal, $pk:literal, $sk:literal, $ct:literal, $runtime:ident, $doc:expr) => { + #[doc = $doc] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] + pub struct $marker; + + impl sealed::Sealed for $marker {} + + impl SntrupParams for $marker { + const NAME: &'static str = $name; + const PK_BYTES: usize = $pk; + const SK_BYTES: usize = $sk; + const CT_BYTES: usize = $ct; + fn params() -> &'static SntrupParameters { + &$runtime + } + } + }; +} + +sntrup_params_marker!( + Sntrup653Params, + "sntrup653", + 994, + 1518, + 897, + SNTRUP653, + "sntrup653 parameter marker (NIST Level 1, 128-bit security).\n\n**Not recommended for production use.** Prefer [`Sntrup761Params`] or higher." +); +sntrup_params_marker!( + Sntrup761Params, + "sntrup761", + 1158, + 1763, + 1039, + SNTRUP761, + "sntrup761 parameter marker (NIST Level 2, 128-bit+ security)." +); +sntrup_params_marker!( + Sntrup857Params, + "sntrup857", + 1322, + 1999, + 1184, + SNTRUP857, + "sntrup857 parameter marker (NIST Level 3, 192-bit security)." +); +sntrup_params_marker!( + Sntrup953Params, + "sntrup953", + 1505, + 2254, + 1349, + SNTRUP953, + "sntrup953 parameter marker (NIST Level 4, 192-bit+ security)." +); +sntrup_params_marker!( + Sntrup1013Params, + "sntrup1013", + 1623, + 2417, + 1455, + SNTRUP1013, + "sntrup1013 parameter marker (NIST Level 5, 256-bit security)." +); +sntrup_params_marker!( + Sntrup1277Params, + "sntrup1277", + 2067, + 3059, + 1847, + SNTRUP1277, + "sntrup1277 parameter marker (NIST Level 5, 256-bit security)." +); diff --git a/sntrup-kem/src/r3.rs b/sntrup-kem/src/r3.rs new file mode 100644 index 0000000..3b1e206 --- /dev/null +++ b/sntrup-kem/src/r3.rs @@ -0,0 +1,242 @@ +pub mod mod3; +mod vector; + +use crate::ct::{smaller_mask, swap_int}; + +#[allow(clippy::cast_possible_wrap)] +pub fn reciprocal(s: &[i8], p: usize) -> (isize, Vec) { + let loops = 2 * p + 1; + let mut r = vec![0i8; p]; + let mut f = vec![0i8; p + 1]; + f[0] = -1; + f[1] = -1; + f[p] = 1; + + let mut g = vec![0i8; p + 1]; + g[..p].copy_from_slice(&s[..p]); + let mut d = p as isize; + let mut e = p as isize; + let mut u = vec![0i8; loops + 1]; + let mut v = vec![0i8; loops + 1]; + v[0] = 1; + + for _ in 0..loops { + let c = mod3::quotient(g[p], f[p]); + vector::minus_product_shift(&mut g, p + 1, &f, c); + vector::minus_product_shift(&mut v, loops + 1, &u, c); + e -= 1; + let m = smaller_mask(e, d) & mod3::mask_set(g[p]); + let (e_tmp, d_tmp) = swap_int(e, d, m); + e = e_tmp; + d = d_tmp; + vector::swap(&mut f, &mut g, p + 1, m); + vector::swap(&mut u, &mut v, loops + 1, m); + } + + vector::product(&mut r, p, &u[p..], mod3::reciprocal(f[p])); + (smaller_mask(0, d), r) +} + +#[allow(unsafe_code)] +pub fn mult(h: &mut [i8], f: &[i8], g: &[i8], p: usize) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return mult_avx2(h, f, g, p); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return mult_neon(h, f, g, p); + } + #[allow(unreachable_code)] + mult_scalar(h, f, g, p); +} + +fn mult_scalar(h: &mut [i8], f: &[i8], g: &[i8], p: usize) { + let mut fg = vec![0i8; p * 2 - 1]; + for i in 0..p { + let mut r = 0i32; + for j in 0..=i { + r += f[j] as i32 * g[i - j] as i32; + } + fg[i] = mod3::freeze(r); + } + for i in p..(p * 2 - 1) { + let mut r = 0i32; + for j in (i - p + 1)..p { + r += f[j] as i32 * g[i - j] as i32; + } + fg[i] = mod3::freeze(r); + } + for i in (p..(p * 2) - 1).rev() { + fg[i - p] = mod3::freeze(fg[i - p] as i32 + fg[i] as i32); + fg[i - p + 1] = mod3::freeze(fg[i - p + 1] as i32 + fg[i] as i32); + } + h[..p].copy_from_slice(&fg[..p]); +} + +/// Column-major schoolbook multiplication with AVX2 for R3 polynomials. +/// Uses _mm256_sign_epi16 for {-1,0,1} multiplication and i16 accumulators. +/// Processes 16 coefficients per SIMD instruction. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::needless_range_loop +)] +unsafe fn mult_avx2(h: &mut [i8], f: &[i8], g: &[i8], p: usize) { + unsafe { + use core::arch::x86_64::*; + + let g_pad_len = (p + 15) & !15; // multiple of 16 + let fg_pad_len = p + g_pad_len; // >= 2p-1 + let fg_len = p * 2 - 1; + + // Sign-extend g to i16, padded + let mut g_pad = vec![0i16; g_pad_len]; + for i in 0..p { + g_pad[i] = g[i] as i16; + } + + // i16 accumulators (max value: ±p, fits in i16 for p <= 1277) + let mut fg = vec![0i16; fg_pad_len]; + + // Column-major accumulation: fg[j+k] += f[j] * g[k] + for j in 0..p { + let fj = _mm256_set1_epi16(f[j] as i16); + let mut k = 0usize; + while k + 16 <= g_pad_len { + let gk = _mm256_loadu_si256(g_pad.as_ptr().add(k) as *const __m256i); + // sign_epi16: if fj>0 → gk, if fj==0 → 0, if fj<0 → -gk + let prod = _mm256_sign_epi16(gk, fj); + let acc = _mm256_loadu_si256(fg.as_ptr().add(j + k) as *const __m256i); + _mm256_storeu_si256( + fg.as_mut_ptr().add(j + k) as *mut __m256i, + _mm256_add_epi16(acc, prod), + ); + k += 16; + } + } + + // Vectorized mod-3 freeze: mulhrs(a, 10923) gives floor((a*10923+16384)/32768) + // which is the correct quotient for |a| <= 1277. + // Result: a - 3*q is in {-1, 0, 1}. + let k10923 = _mm256_set1_epi16(10923); + let three16 = _mm256_set1_epi16(3); + + let mut fg8 = vec![0i8; fg_len]; + let mut i = 0usize; + while i + 32 <= fg_len { + // Process 32 values: two batches of 16 i16 → 32 i8 + let a0 = _mm256_loadu_si256(fg.as_ptr().add(i) as *const __m256i); + let q0 = _mm256_mulhrs_epi16(a0, k10923); + let r0 = _mm256_sub_epi16(a0, _mm256_mullo_epi16(q0, three16)); + + let a1 = _mm256_loadu_si256(fg.as_ptr().add(i + 16) as *const __m256i); + let q1 = _mm256_mulhrs_epi16(a1, k10923); + let r1 = _mm256_sub_epi16(a1, _mm256_mullo_epi16(q1, three16)); + + // Pack 16+16 i16 → 32 i8, fix AVX2 lane ordering + let packed = _mm256_permute4x64_epi64(_mm256_packs_epi16(r0, r1), 0xD8); + _mm256_storeu_si256(fg8.as_mut_ptr().add(i) as *mut __m256i, packed); + i += 32; + } + while i < fg_len { + fg8[i] = mod3::freeze(fg[i] as i32); + i += 1; + } + + // Reduction: x^p ≡ x + 1 (mod x^p - x - 1) + for i in (p..(p * 2) - 1).rev() { + fg8[i - p] = mod3::freeze(fg8[i - p] as i32 + fg8[i] as i32); + fg8[i - p + 1] = mod3::freeze(fg8[i - p + 1] as i32 + fg8[i] as i32); + } + h[..p].copy_from_slice(&fg8[..p]); + } +} + +/// Column-major schoolbook multiplication with NEON for R3 polynomials. +/// Uses vmulq_s16 for {-1,0,1} multiplication and i16 accumulators. +/// Processes 8 coefficients per SIMD instruction. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::needless_range_loop +)] +unsafe fn mult_neon(h: &mut [i8], f: &[i8], g: &[i8], p: usize) { + unsafe { + use core::arch::aarch64::*; + + let g_pad_len = (p + 7) & !7; // multiple of 8 + let fg_pad_len = p + g_pad_len; // >= 2p-1 + let fg_len = p * 2 - 1; + + // Sign-extend g to i16, padded + let mut g_pad = vec![0i16; g_pad_len]; + for i in 0..p { + g_pad[i] = g[i] as i16; + } + + // i16 accumulators (max value: ±p, fits in i16 for p <= 1277) + let mut fg = vec![0i16; fg_pad_len]; + + // Column-major accumulation: fg[j+k] += f[j] * g[k] + // vmulq_s16(gk, fj): for fj in {-1,0,1} this produces correct signed product + for j in 0..p { + let fj = vdupq_n_s16(f[j] as i16); + let mut k = 0usize; + while k + 8 <= g_pad_len { + let gk = vld1q_s16(g_pad.as_ptr().add(k)); + let prod = vmulq_s16(gk, fj); + let acc = vld1q_s16(fg.as_ptr().add(j + k)); + vst1q_s16(fg.as_mut_ptr().add(j + k), vaddq_s16(acc, prod)); + k += 8; + } + } + + // Vectorized mod-3 freeze: vqrdmulhq_s16(a, 10923) gives correct quotient + // for |a| <= 1277. Result: a - 3*q is in {-1, 0, 1}. + let k10923 = vdupq_n_s16(10923); + let three16 = vdupq_n_s16(3); + + let mut fg8 = vec![0i8; fg_len]; + let mut i = 0usize; + while i + 16 <= fg_len { + // Process 16 values: two batches of 8 i16 → 16 i8 + let a0 = vld1q_s16(fg.as_ptr().add(i)); + let q0 = vqrdmulhq_s16(a0, k10923); + let r0 = vsubq_s16(a0, vmulq_s16(q0, three16)); + + let a1 = vld1q_s16(fg.as_ptr().add(i + 8)); + let q1 = vqrdmulhq_s16(a1, k10923); + let r1 = vsubq_s16(a1, vmulq_s16(q1, three16)); + + // Pack 8+8 i16 → 16 i8 (naturally ordered, no permute needed) + let packed = vcombine_s8(vqmovn_s16(r0), vqmovn_s16(r1)); + vst1q_s8(fg8.as_mut_ptr().add(i), packed); + i += 16; + } + while i < fg_len { + fg8[i] = mod3::freeze(fg[i] as i32); + i += 1; + } + + // Reduction: x^p ≡ x + 1 (mod x^p - x - 1) + for i in (p..(p * 2) - 1).rev() { + fg8[i - p] = mod3::freeze(fg8[i - p] as i32 + fg8[i] as i32); + fg8[i - p + 1] = mod3::freeze(fg8[i - p + 1] as i32 + fg8[i] as i32); + } + h[..p].copy_from_slice(&fg8[..p]); + } +} diff --git a/sntrup-kem/src/r3/mod3.rs b/sntrup-kem/src/r3/mod3.rs new file mode 100644 index 0000000..8f18fd2 --- /dev/null +++ b/sntrup-kem/src/r3/mod3.rs @@ -0,0 +1,32 @@ +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +pub fn freeze(a: i32) -> i8 { + let b = a - (3 * ((10923 * a) >> 15)); + let c = b - (3 * ((89_478_485 * b + 134_217_728) >> 28)); + c as i8 +} + +#[inline(always)] +pub fn product(a: i8, b: i8) -> i8 { + a * b +} + +#[inline(always)] +pub fn reciprocal(a: i8) -> i8 { + a +} + +#[inline(always)] +pub fn quotient(a: i8, b: i8) -> i8 { + product(a, reciprocal(b)) +} + +#[inline(always)] +pub fn minus_product(a: i8, b: i8, c: i8) -> i8 { + freeze(a as i32 - b as i32 * c as i32) +} + +#[inline(always)] +pub fn mask_set(x: i8) -> isize { + (-x * x) as isize +} diff --git a/sntrup-kem/src/r3/vector.rs b/sntrup-kem/src/r3/vector.rs new file mode 100644 index 0000000..deaaf36 --- /dev/null +++ b/sntrup-kem/src/r3/vector.rs @@ -0,0 +1,306 @@ +#![allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] + +use super::mod3; + +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +pub fn swap(x: &mut [i8], y: &mut [i8], n: usize, mask: isize) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return swap_avx2(x, y, n, mask); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return swap_neon(x, y, n, mask); + } + #[allow(unreachable_code)] + swap_scalar(x, y, n, mask); +} + +#[allow(clippy::cast_possible_truncation)] +fn swap_scalar(x: &mut [i8], y: &mut [i8], n: usize, mask: isize) { + let c = mask as i8; + for i in 0..n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + } +} + +/// 32 i8 elements per SIMD iteration. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +unsafe fn swap_avx2(x: &mut [i8], y: &mut [i8], n: usize, mask: isize) { + unsafe { + use core::arch::x86_64::*; + let cv = _mm256_set1_epi8(mask as i8); + let mut i = 0usize; + while i + 32 <= n { + let xv = _mm256_loadu_si256(x.as_ptr().add(i) as *const __m256i); + let yv = _mm256_loadu_si256(y.as_ptr().add(i) as *const __m256i); + let t = _mm256_and_si256(cv, _mm256_xor_si256(xv, yv)); + _mm256_storeu_si256( + x.as_mut_ptr().add(i) as *mut __m256i, + _mm256_xor_si256(xv, t), + ); + _mm256_storeu_si256( + y.as_mut_ptr().add(i) as *mut __m256i, + _mm256_xor_si256(yv, t), + ); + i += 32; + } + let c = mask as i8; + while i < n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + i += 1; + } + } +} + +/// 16 i8 elements per NEON iteration. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +unsafe fn swap_neon(x: &mut [i8], y: &mut [i8], n: usize, mask: isize) { + unsafe { + use core::arch::aarch64::*; + let cv = vdupq_n_s8(mask as i8); + let mut i = 0usize; + while i + 16 <= n { + let xv = vld1q_s8(x.as_ptr().add(i)); + let yv = vld1q_s8(y.as_ptr().add(i)); + let t = vandq_s8(cv, veorq_s8(xv, yv)); + vst1q_s8(x.as_mut_ptr().add(i), veorq_s8(xv, t)); + vst1q_s8(y.as_mut_ptr().add(i), veorq_s8(yv, t)); + i += 16; + } + let c = mask as i8; + while i < n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + i += 1; + } + } +} + +#[inline(always)] +pub fn product(z: &mut [i8], n: usize, x: &[i8], c: i8) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return product_avx2(z, n, x, c); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return product_neon(z, n, x, c); + } + #[allow(unreachable_code)] + product_scalar(z, n, x, c); +} + +fn product_scalar(z: &mut [i8], n: usize, x: &[i8], c: i8) { + for i in 0..n { + z[i] = mod3::product(x[i], c); + } +} + +/// For c in {-1, 0, 1}: _mm256_sign_epi8(x, c) computes x * c. +/// Processes 32 elements per iteration. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +unsafe fn product_avx2(z: &mut [i8], n: usize, x: &[i8], c: i8) { + unsafe { + use core::arch::x86_64::*; + let cv = _mm256_set1_epi8(c); + let mut i = 0usize; + while i + 32 <= n { + let xv = _mm256_loadu_si256(x.as_ptr().add(i) as *const __m256i); + _mm256_storeu_si256( + z.as_mut_ptr().add(i) as *mut __m256i, + _mm256_sign_epi8(xv, cv), + ); + i += 32; + } + while i < n { + z[i] = mod3::product(x[i], c); + i += 1; + } + } +} + +/// NEON sign_epi8 equivalent: branchless x*sign(c). +/// For c in {-1, 0, 1}: returns x if c>0, -x if c<0, 0 if c==0. +/// Constant-time: no branches on c (which may be secret-derived). +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +#[inline(always)] +unsafe fn sign_epi8_neon( + xv: core::arch::aarch64::int8x16_t, + cv: core::arch::aarch64::int8x16_t, +) -> core::arch::aarch64::int8x16_t { + unsafe { + use core::arch::aarch64::*; + let sign_mask = vreinterpretq_u8_s8(vshrq_n_s8(cv, 7)); // 0xFF if c<0 + let nonzero = vtstq_s8(cv, cv); // 0xFF if c!=0 (uint8x16_t) + let neg_x = vnegq_s8(xv); + let selected = vreinterpretq_s8_u8(vbslq_u8( + sign_mask, + vreinterpretq_u8_s8(neg_x), + vreinterpretq_u8_s8(xv), + )); + vandq_s8(selected, vreinterpretq_s8_u8(nonzero)) + } +} + +/// NEON product for i8: 16 elements per iteration. +/// For c in {-1, 0, 1}, uses branchless sign_epi8 equivalent. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +unsafe fn product_neon(z: &mut [i8], n: usize, x: &[i8], c: i8) { + unsafe { + use core::arch::aarch64::*; + let cv = vdupq_n_s8(c); + let mut i = 0usize; + while i + 16 <= n { + let xv = vld1q_s8(x.as_ptr().add(i)); + vst1q_s8(z.as_mut_ptr().add(i), sign_epi8_neon(xv, cv)); + i += 16; + } + while i < n { + z[i] = mod3::product(x[i], c); + i += 1; + } + } +} + +/// Fused minus_product and shift: z[i+1] = freeze(z[i] - y[i]*c), z[0] = 0. +/// Processes backward to avoid overwrite conflicts, eliminating a separate memmove. +#[inline(always)] +pub fn minus_product_shift(z: &mut [i8], n: usize, y: &[i8], c: i8) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return minus_product_shift_avx2(z, n, y, c); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return minus_product_shift_neon(z, n, y, c); + } + #[allow(unreachable_code)] + minus_product_shift_scalar(z, n, y, c); +} + +fn minus_product_shift_scalar(z: &mut [i8], n: usize, y: &[i8], c: i8) { + for i in (0..n - 1).rev() { + z[i + 1] = mod3::minus_product(z[i], y[i], c); + } + z[0] = 0; +} + +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +unsafe fn minus_product_shift_avx2(z: &mut [i8], n: usize, y: &[i8], c: i8) { + unsafe { + use core::arch::x86_64::*; + let cv = _mm256_set1_epi8(c); + let neg2 = _mm256_set1_epi8(-2); + let pos2 = _mm256_set1_epi8(2); + let three = _mm256_set1_epi8(3); + + let mut j = (n - 2) as isize; + + // Process 32 i8 elements at a time, backward + while j >= 31 { + let start = (j - 31) as usize; + let zv = _mm256_loadu_si256(z.as_ptr().add(start) as *const __m256i); + let yv = _mm256_loadu_si256(y.as_ptr().add(start) as *const __m256i); + let yc = _mm256_sign_epi8(yv, cv); + let r = _mm256_sub_epi8(zv, yc); + // Mod-3 fixup: r is in [-2, 2] + let add = _mm256_and_si256(three, _mm256_cmpeq_epi8(r, neg2)); + let sub = _mm256_and_si256(three, _mm256_cmpeq_epi8(r, pos2)); + let r = _mm256_add_epi8(_mm256_sub_epi8(r, sub), add); + // Store at offset +1 (the shift) + _mm256_storeu_si256(z.as_mut_ptr().add(start + 1) as *mut __m256i, r); + j -= 32; + } + + // Scalar remainder + while j >= 0 { + z[(j + 1) as usize] = mod3::minus_product(z[j as usize], y[j as usize], c); + j -= 1; + } + z[0] = 0; + } +} + +/// NEON minus_product_shift for i8: 16 elements at a time, backward. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +unsafe fn minus_product_shift_neon(z: &mut [i8], n: usize, y: &[i8], c: i8) { + unsafe { + use core::arch::aarch64::*; + let cv = vdupq_n_s8(c); + let neg2 = vdupq_n_s8(-2); + let pos2 = vdupq_n_s8(2); + let three = vdupq_n_s8(3); + + let mut j = (n - 2) as isize; + + // Process 16 i8 elements at a time, backward + while j >= 15 { + let start = (j - 15) as usize; + let zv = vld1q_s8(z.as_ptr().add(start)); + let yv = vld1q_s8(y.as_ptr().add(start)); + let yc = sign_epi8_neon(yv, cv); + let r = vsubq_s8(zv, yc); + // Mod-3 fixup: r is in [-2, 2] + let eq_neg2 = vceqq_s8(r, neg2); + let eq_pos2 = vceqq_s8(r, pos2); + let add = vandq_s8(three, vreinterpretq_s8_u8(eq_neg2)); + let sub = vandq_s8(three, vreinterpretq_s8_u8(eq_pos2)); + let r = vaddq_s8(vsubq_s8(r, sub), add); + // Store at offset +1 (the shift) + vst1q_s8(z.as_mut_ptr().add(start + 1), r); + j -= 16; + } + + // Scalar remainder + while j >= 0 { + z[(j + 1) as usize] = mod3::minus_product(z[j as usize], y[j as usize], c); + j -= 1; + } + z[0] = 0; + } +} diff --git a/sntrup-kem/src/rq.rs b/sntrup-kem/src/rq.rs new file mode 100644 index 0000000..90cc218 --- /dev/null +++ b/sntrup-kem/src/rq.rs @@ -0,0 +1,290 @@ +pub mod encoding; +pub mod modq; +mod vector; + +use crate::ct::{smaller_mask, swap_int}; +use crate::params::SntrupParameters; + +#[allow(clippy::cast_possible_wrap)] +pub fn reciprocal3(s: &[i8], params: &SntrupParameters) -> Vec { + let p = params.p; + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + let loops = 2 * p + 1; + + let mut r = vec![0i16; p]; + let mut f = vec![0i16; p + 1]; + f[0] = -1; + f[1] = -1; + f[p] = 1; + let mut g = vec![0i16; p + 1]; + for i in 0..p { + g[i] = (3 * s[i]) as i16; + } + let mut d = p as isize; + let mut e = p as isize; + let mut u = vec![0i16; loops + 1]; + let mut v = vec![0i16; loops + 1]; + v[0] = 1; + + for _ in 0..loops { + let c = modq::quotient(g[p], f[p], q, b1, b2); + vector::minus_product_shift(&mut g, p + 1, &f, c, q, b1, b2); + vector::minus_product_shift(&mut v, loops + 1, &u, c, q, b1, b2); + e -= 1; + let m = smaller_mask(e, d) & modq::mask_set(g[p]); + let (e_tmp, d_tmp) = swap_int(e, d, m); + e = e_tmp; + d = d_tmp; + vector::swap(&mut f, &mut g, p + 1, m); + vector::swap(&mut u, &mut v, loops + 1, m); + } + vector::product( + &mut r, + p, + &u[p..], + modq::reciprocal(f[p], q, b1, b2), + q, + b1, + b2, + ); + // Note: unlike r3::reciprocal, no invertibility check is returned here. + // For these parameter sets q is prime and x^p - x - 1 is irreducible mod q, + // so R/q is a field and the weight-w secret f is always invertible — the + // reciprocal never fails, so there is no failure mask to propagate. + r +} + +#[allow(clippy::cast_possible_truncation)] +pub fn round3(h: &mut [i16], params: &SntrupParameters) { + let q12 = params.q12; + for coeff in h.iter_mut() { + let inner = 21846i32 * (*coeff as i32 + q12); + *coeff = (((inner + 32768) >> 16) * 3 - q12) as i16; + } +} + +#[allow(unsafe_code)] +pub fn mult(h: &mut [i16], f: &[i16], g: &[i8], params: &SntrupParameters) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 availability verified by cfg target_feature + unsafe { + return mult_avx2(h, f, g, params); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return mult_neon(h, f, g, params); + } + #[allow(unreachable_code)] + mult_scalar(h, f, g, params); +} + +fn mult_scalar(h: &mut [i16], f: &[i16], g: &[i8], params: &SntrupParameters) { + let p = params.p; + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + + let mut fg = vec![0i16; p * 2 - 1]; + for i in 0..p { + let mut r = 0i32; + for j in 0..=i { + r += f[j] as i32 * g[i - j] as i32; + } + fg[i] = modq::freeze(r, q, b1, b2); + } + for i in p..(p * 2 - 1) { + let mut r = 0i32; + for j in (i - p + 1)..p { + r += f[j] as i32 * g[i - j] as i32; + } + fg[i] = modq::freeze(r, q, b1, b2); + } + for i in (p..(p * 2) - 1).rev() { + fg[i - p] = modq::freeze(fg[i - p] as i32 + fg[i] as i32, q, b1, b2); + fg[i - p + 1] = modq::freeze(fg[i - p + 1] as i32 + fg[i] as i32, q, b1, b2); + } + h[..p].copy_from_slice(&fg[..p]); +} + +/// Column-major schoolbook multiplication with AVX2. +/// Processes 8 i32 multiply-accumulates per SIMD instruction. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::needless_range_loop +)] +unsafe fn mult_avx2(h: &mut [i16], f: &[i16], g: &[i8], params: &SntrupParameters) { + unsafe { + use core::arch::x86_64::*; + + let p = params.p; + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + + // Pad to multiples of 8 so SIMD loops need no remainder handling + let g_pad_len = (p + 7) & !7; + let fg_pad_len = p + g_pad_len; + let fg_len = p * 2 - 1; + + let mut g_pad = vec![0i8; g_pad_len]; + g_pad[..p].copy_from_slice(&g[..p]); + let mut fg = vec![0i32; fg_pad_len]; + + // Accumulate f[j]*g[k] into fg[j+k] + for j in 0..p { + let fj = _mm256_set1_epi32(f[j] as i32); + let mut k = 0usize; + while k < g_pad_len { + let gb = _mm_loadl_epi64(g_pad.as_ptr().add(k) as *const __m128i); + let gk = _mm256_cvtepi8_epi32(gb); + let prod = _mm256_mullo_epi32(fj, gk); + let acc = _mm256_loadu_si256(fg.as_ptr().add(j + k) as *const __m256i); + _mm256_storeu_si256( + fg.as_mut_ptr().add(j + k) as *mut __m256i, + _mm256_add_epi32(acc, prod), + ); + k += 8; + } + } + + // Vectorized Barrett freeze: i32 -> i16 + let qv = _mm256_set1_epi32(q); + let kb1 = _mm256_set1_epi32(b1); + let kb2 = _mm256_set1_epi32(b2); + let k134m = _mm256_set1_epi32(134_217_728); + + let mut fg16 = vec![0i16; fg_len]; + let mut i = 0usize; + while i + 16 <= fg_len { + let a0 = _mm256_loadu_si256(fg.as_ptr().add(i) as *const __m256i); + let a1 = _mm256_loadu_si256(fg.as_ptr().add(i + 8) as *const __m256i); + + // freeze(a) = a - Q*((b1*a)>>20) then b - Q*((b2*b+134M)>>28) + let t = _mm256_srai_epi32(_mm256_mullo_epi32(a0, kb1), 20); + let b0 = _mm256_sub_epi32(a0, _mm256_mullo_epi32(t, qv)); + let t = _mm256_srai_epi32(_mm256_add_epi32(_mm256_mullo_epi32(b0, kb2), k134m), 28); + let r0 = _mm256_sub_epi32(b0, _mm256_mullo_epi32(t, qv)); + + let t = _mm256_srai_epi32(_mm256_mullo_epi32(a1, kb1), 20); + let b1v = _mm256_sub_epi32(a1, _mm256_mullo_epi32(t, qv)); + let t = _mm256_srai_epi32(_mm256_add_epi32(_mm256_mullo_epi32(b1v, kb2), k134m), 28); + let r1 = _mm256_sub_epi32(b1v, _mm256_mullo_epi32(t, qv)); + + // Pack 8+8 i32 -> 16 i16 and fix AVX2 lane ordering + let packed = _mm256_permute4x64_epi64(_mm256_packs_epi32(r0, r1), 0xD8); + _mm256_storeu_si256(fg16.as_mut_ptr().add(i) as *mut __m256i, packed); + i += 16; + } + while i < fg_len { + fg16[i] = modq::freeze(fg[i], q, b1, b2); + i += 1; + } + + // Reduction (scalar -- sequential dependencies prevent vectorization) + for i in (p..(p * 2) - 1).rev() { + fg16[i - p] = modq::freeze(fg16[i - p] as i32 + fg16[i] as i32, q, b1, b2); + fg16[i - p + 1] = modq::freeze(fg16[i - p + 1] as i32 + fg16[i] as i32, q, b1, b2); + } + h[..p].copy_from_slice(&fg16[..p]); + } +} + +/// Column-major schoolbook multiplication with NEON. +/// Processes 4 i32 multiply-accumulates per SIMD instruction. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::needless_range_loop +)] +unsafe fn mult_neon(h: &mut [i16], f: &[i16], g: &[i8], params: &SntrupParameters) { + unsafe { + use core::arch::aarch64::*; + + let p = params.p; + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + + // Pad to multiples of 4 so SIMD loops need no remainder handling + let g_pad_len = (p + 3) & !3; + let fg_pad_len = p + g_pad_len; + let fg_len = p * 2 - 1; + + let mut g_pad = vec![0i8; g_pad_len]; + g_pad[..p].copy_from_slice(&g[..p]); + let mut fg = vec![0i32; fg_pad_len]; + + // Accumulate f[j]*g[k] into fg[j+k] + for j in 0..p { + let fj = vdupq_n_s32(f[j] as i32); + let mut k = 0usize; + while k + 4 <= g_pad_len { + // Sign-extend 4 i8 -> i16 -> i32 + let gb = vld1_s8(g_pad.as_ptr().add(k)); + let g16 = vmovl_s8(gb); + let gk = vmovl_s16(vget_low_s16(g16)); + let prod = vmulq_s32(fj, gk); + let acc = vld1q_s32(fg.as_ptr().add(j + k)); + vst1q_s32(fg.as_mut_ptr().add(j + k), vaddq_s32(acc, prod)); + k += 4; + } + } + + // Vectorized Barrett freeze: i32 -> i16 + let qv = vdupq_n_s32(q); + let kb1 = vdupq_n_s32(b1); + let kb2 = vdupq_n_s32(b2); + let k134m = vdupq_n_s32(134_217_728); + + let mut fg16 = vec![0i16; fg_len]; + let mut i = 0usize; + while i + 8 <= fg_len { + // Process 8 values: two batches of 4 i32 -> 8 i16 + let a0 = vld1q_s32(fg.as_ptr().add(i)); + let a1 = vld1q_s32(fg.as_ptr().add(i + 4)); + + let t = vshrq_n_s32(vmulq_s32(a0, kb1), 20); + let b0 = vsubq_s32(a0, vmulq_s32(t, qv)); + let t = vshrq_n_s32(vaddq_s32(vmulq_s32(b0, kb2), k134m), 28); + let r0 = vsubq_s32(b0, vmulq_s32(t, qv)); + + let t = vshrq_n_s32(vmulq_s32(a1, kb1), 20); + let b1v = vsubq_s32(a1, vmulq_s32(t, qv)); + let t = vshrq_n_s32(vaddq_s32(vmulq_s32(b1v, kb2), k134m), 28); + let r1 = vsubq_s32(b1v, vmulq_s32(t, qv)); + + // Pack 4+4 i32 -> 8 i16 (naturally ordered, no permute needed) + let packed = vcombine_s16(vmovn_s32(r0), vmovn_s32(r1)); + vst1q_s16(fg16.as_mut_ptr().add(i), packed); + i += 8; + } + while i < fg_len { + fg16[i] = modq::freeze(fg[i], q, b1, b2); + i += 1; + } + + // Reduction (scalar -- sequential dependencies prevent vectorization) + for i in (p..(p * 2) - 1).rev() { + fg16[i - p] = modq::freeze(fg16[i - p] as i32 + fg16[i] as i32, q, b1, b2); + fg16[i - p + 1] = modq::freeze(fg16[i - p + 1] as i32 + fg16[i] as i32, q, b1, b2); + } + h[..p].copy_from_slice(&fg16[..p]); + } +} diff --git a/sntrup-kem/src/rq/encoding.rs b/sntrup-kem/src/rq/encoding.rs new file mode 100644 index 0000000..d1e239a --- /dev/null +++ b/sntrup-kem/src/rq/encoding.rs @@ -0,0 +1,325 @@ +use super::modq; +use crate::params::SntrupParameters; + +/// Maximum number of pairing levels across all parameter sets (P up to 1277). +/// Levels: 1277 -> 639 -> 320 -> 160 -> 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2 -> base case (n=1). +const fn compute_levels(p: usize) -> usize { + let mut n = p; + let mut levels = 0; + while n > 1 { + levels += 1; + n = n.div_ceil(2); + } + levels +} + +/// Total moduli storage across all levels (including the base case modulus). +const fn compute_m_storage(p: usize) -> usize { + let mut n = p; + let mut total = 0; + while n > 1 { + total += n; + n = n.div_ceil(2); + } + total + 1 // +1 for base case modulus +} + +const MAX_LEVELS: usize = compute_levels(1277); // 11 +const MAX_M_STORAGE: usize = compute_m_storage(1277); // 2557 + +/// Constant-time divmod: *quotient = x / m, returns x % m. +/// m must be > 0 and < 16384. Matches PQClean's two-step Barrett reduction. +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +fn uint32_divmod_uint14(quotient: &mut u32, x: u32, m: u16) -> u16 { + let m32 = m as u32; + let v = (0x80000000u32 as u64) / (m32 as u64); + // First Barrett step + let mut qpart = ((x as u64 * v) >> 31) as u32; + let mut r = x - qpart * m32; + *quotient = qpart; + // Second Barrett step on remainder + qpart = ((r as u64 * v) >> 31) as u32; + r -= qpart * m32; + *quotient += qpart; + // Final speculative correction + r = r.wrapping_sub(m32); + *quotient += 1; + let mask = (r >> 31).wrapping_neg(); // 0xFFFFFFFF if r underflowed (was < m), else 0 + r = r.wrapping_add(mask & m32); + *quotient = quotient.wrapping_add(mask); // subtract 1 if we over-corrected + r as u16 +} + +#[inline(always)] +fn uint32_mod_uint14(x: u32, m: u16) -> u16 { + let mut q = 0u32; + uint32_divmod_uint14(&mut q, x, m) +} + +/// Iterative variable-radix encoding. Pairs values, emits bottom bytes when the +/// combined modulus reaches 16384, then repeats on the paired values. +/// `r` and `m` are modified in place across levels. +#[allow(clippy::cast_possible_truncation)] +fn encode(out: &mut [u8], r: &mut [u16], m: &mut [u16], n_start: usize) -> usize { + if n_start == 0 { + return 0; + } + if n_start == 1 { + return encode_single(out, r[0] as u32, m[0] as u32); + } + + let mut n = n_start; + let mut pos = 0; + + while n > 1 { + let n2 = n.div_ceil(2); + // In-place pairing: read from [2*i, 2*i+1], write to [i]. + // Safe because i < 2*i for i >= 1, so reads precede writes. + for i in 0..n2 { + if 2 * i + 1 < n { + let mut combined = r[2 * i] as u32 + (r[2 * i + 1] as u32) * (m[2 * i] as u32); + let mut cm = (m[2 * i] as u32) * (m[2 * i + 1] as u32); + while cm >= 16384 { + out[pos] = combined as u8; + pos += 1; + combined >>= 8; + cm = (cm + 255) >> 8; + } + r[i] = combined as u16; + m[i] = cm as u16; + } else { + r[i] = r[2 * i]; + m[i] = m[2 * i]; + } + } + n = n2; + } + + // Base case: single remaining value + pos + encode_single(&mut out[pos..], r[0] as u32, m[0] as u32) +} + +#[allow(clippy::cast_possible_truncation)] +fn encode_single(out: &mut [u8], mut val: u32, mut modulus: u32) -> usize { + let mut pos = 0; + while modulus > 1 { + out[pos] = val as u8; + pos += 1; + val >>= 8; + modulus = (modulus + 255) >> 8; + } + pos +} + +/// Iterative variable-radix decoding. Forward pass computes moduli and byte +/// offsets at each level; backward pass expands decoded values from base case. +#[allow(clippy::cast_possible_truncation)] +fn decode(out: &mut [u16], s: &[u8], m_in: &[u16], n_start: usize) { + if n_start == 0 { + return; + } + if n_start == 1 { + decode_single(out, s, m_in[0]); + return; + } + + // --- Forward pass: compute level sizes, moduli, and bottom-byte totals --- + + let mut ns = [0usize; MAX_LEVELS]; + let mut num_levels = 0; + { + let mut n = n_start; + while n > 1 { + ns[num_levels] = n; + num_levels += 1; + n = n.div_ceil(2); + } + } + + // Flat storage for moduli at every level (including paired output for base case) + let mut all_m = [0u16; MAX_M_STORAGE]; + let mut level_m_offset = [0usize; MAX_LEVELS + 1]; + let mut level_bottom_total = [0usize; MAX_LEVELS]; + + // Level 0 input moduli + all_m[..n_start].copy_from_slice(&m_in[..n_start]); + level_m_offset[0] = 0; + let mut m_pos = n_start; + + for level in 0..num_levels { + let n = ns[level]; + let n2 = n.div_ceil(2); + let m_off = level_m_offset[level]; + level_m_offset[level + 1] = m_pos; + let mut total_bottom = 0usize; + + for i in 0..n2 { + if 2 * i + 1 < n { + let mut cm = (all_m[m_off + 2 * i] as u32) * (all_m[m_off + 2 * i + 1] as u32); + let mut bb = 0usize; + while cm >= 16384 { + bb += 1; + cm = (cm + 255) >> 8; + } + total_bottom += bb; + all_m[m_pos] = cm as u16; + } else { + all_m[m_pos] = all_m[m_off + 2 * i]; + } + m_pos += 1; + } + + level_bottom_total[level] = total_bottom; + } + + // Cumulative bottom-byte start positions + let mut level_bottom_start = [0usize; MAX_LEVELS]; + let mut cum_bottom = 0usize; + for level in 0..num_levels { + level_bottom_start[level] = cum_bottom; + cum_bottom += level_bottom_total[level]; + } + + // --- Decode base case (n = 1) --- + let base_m_off = level_m_offset[num_levels]; + decode_single(out, &s[cum_bottom..], all_m[base_m_off]); + + // --- Backward pass: expand decoded values level by level --- + for level in (0..num_levels).rev() { + let n = ns[level]; + let n2 = n.div_ceil(2); + let m_off = level_m_offset[level]; + + // out[0..n2] holds decoded values; expand in-place to out[0..n]. + // Process backwards: reads from out[i], writes to out[2*i] / out[2*i+1]. + let mut bpos = level_bottom_start[level] + level_bottom_total[level]; + + for i in (0..n2).rev() { + if 2 * i + 1 < n { + // Recompute bottom-byte count for this pair + let mut cm = (all_m[m_off + 2 * i] as u32) * (all_m[m_off + 2 * i + 1] as u32); + let mut bb = 0usize; + while cm >= 16384 { + bb += 1; + cm = (cm + 255) >> 8; + } + + bpos -= bb; + let mut combined = out[i] as u32; + for j in (0..bb).rev() { + combined = (combined << 8) | (s[bpos + j] as u32); + } + + let mut q = 0u32; + let remainder = uint32_divmod_uint14(&mut q, combined, all_m[m_off + 2 * i]); + out[2 * i] = remainder; + out[2 * i + 1] = uint32_mod_uint14(q, all_m[m_off + 2 * i + 1]); + } else { + out[2 * i] = out[i]; + } + } + } +} + +fn decode_single(out: &mut [u16], s: &[u8], m: u16) { + if m == 1 { + out[0] = 0; + } else if m <= 256 { + out[0] = uint32_mod_uint14(s[0] as u32, m); + } else { + out[0] = uint32_mod_uint14(s[0] as u32 + ((s[1] as u32) << 8), m); + } +} + +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn rq_encode(f: &[i16], params: &SntrupParameters) -> Vec { + let p = params.p; + let q12 = params.q12; + let q_u16 = params.q as u16; + + let mut r = vec![0u16; p]; + for (ri, &fi) in r.iter_mut().zip(f.iter()) { + *ri = (fi as i32 + q12) as u16; + } + let mut m = vec![q_u16; p]; + let mut out = vec![0u8; params.pk_size]; + encode(&mut out, &mut r, &mut m, p); + out +} + +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn rq_decode(c: &[u8], params: &SntrupParameters) -> Vec { + let p = params.p; + let q12 = params.q12; + let q_u16 = params.q as u16; + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + + let m = vec![q_u16; p]; + let mut r = vec![0u16; p]; + // Callers pass exactly `pk_size` bytes, so borrow directly on the hot path. + // Only allocate-and-pad if the input is short (defensive; never happens via + // the public API, where `EncapsulationKey::try_from` enforces the size). + let mut padded; + let s: &[u8] = if c.len() >= params.pk_size { + &c[..params.pk_size] + } else { + padded = vec![0u8; params.pk_size]; + padded[..c.len()].copy_from_slice(c); + &padded + }; + decode(&mut r, s, &m, p); + let mut f = vec![0i16; p]; + for (fi, &ri) in f.iter_mut().zip(r.iter()) { + *fi = modq::freeze(ri as i32 - q12, q, b1, b2); + } + f +} + +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn rounded_encode(f: &[i16], params: &SntrupParameters) -> Vec { + let p = params.p; + let q12 = params.q12; + let q_rounded = (params.q as u16).div_ceil(3); + + let mut r = vec![0u16; p]; + for (ri, &fi) in r.iter_mut().zip(f.iter()) { + *ri = (((fi as i32 + q12) * 10923) >> 15) as u16; + } + let mut m = vec![q_rounded; p]; + let mut out = vec![0u8; params.rounded_encode_size]; + encode(&mut out, &mut r, &mut m, p); + out +} + +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn rounded_decode(c: &[u8], params: &SntrupParameters) -> Vec { + let p = params.p; + let q12 = params.q12; + let q_rounded = (params.q as u16).div_ceil(3); + let q = params.q; + let b1 = params.barrett1; + let b2 = params.barrett2; + + let m = vec![q_rounded; p]; + let mut r = vec![0u16; p]; + // Callers pass exactly `rounded_encode_size` bytes, so borrow directly on the + // hot path. Only allocate-and-pad if the input is short (defensive; never + // happens via the public API, where `Ciphertext::try_from` enforces the size). + let mut padded; + let s: &[u8] = if c.len() >= params.rounded_encode_size { + &c[..params.rounded_encode_size] + } else { + padded = vec![0u8; params.rounded_encode_size]; + padded[..c.len()].copy_from_slice(c); + &padded + }; + decode(&mut r, s, &m, p); + let mut f = vec![0i16; p]; + for (fi, &ri) in f.iter_mut().zip(r.iter()) { + *fi = modq::freeze(ri as i32 * 3 - q12, q, b1, b2); + } + f +} diff --git a/sntrup-kem/src/rq/modq.rs b/sntrup-kem/src/rq/modq.rs new file mode 100644 index 0000000..36c2ea2 --- /dev/null +++ b/sntrup-kem/src/rq/modq.rs @@ -0,0 +1,63 @@ +/// Barrett reduction: freezes `a` into the range (-q/2, q/2). +/// +/// `barrett1` = floor(2^20 / q), `barrett2` = floor(2^28 / q). +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +pub fn freeze(a: i32, q: i32, barrett1: i32, barrett2: i32) -> i16 { + let mut b = a; + b -= q * ((barrett1 * b) >> 20); + b -= q * ((barrett2 * b + 134_217_728) >> 28); + b as i16 +} + +#[inline(always)] +pub fn product(a: i16, b: i16, q: i32, b1: i32, b2: i32) -> i16 { + freeze(a as i32 * b as i32, q, b1, b2) +} + +#[inline(always)] +pub fn square(a: i16, q: i32, b1: i32, b2: i32) -> i16 { + let a32 = a as i32; + freeze(a32 * a32, q, b1, b2) +} + +/// Compute `a1^(q-2) mod q` via Fermat's little theorem using binary +/// exponentiation (square-and-multiply). This is constant-time because +/// `q` is a public parameter. +#[inline(always)] +pub fn reciprocal(a1: i16, q: i32, b1: i32, b2: i32) -> i16 { + #[allow(clippy::cast_sign_loss)] + let exp = (q - 2) as u32; + // Find the highest set bit position + let bits = 32 - exp.leading_zeros(); // number of significant bits + + // Square-and-multiply from the second-highest bit down + let mut result = a1; + for i in (0..(bits - 1)).rev() { + result = square(result, q, b1, b2); + if (exp >> i) & 1 == 1 { + result = product(result, a1, q, b1, b2); + } + } + result +} + +#[inline(always)] +pub fn quotient(a: i16, b: i16, q: i32, b1: i32, b2: i32) -> i16 { + product(a, reciprocal(b, q, b1, b2), q, b1, b2) +} + +#[inline(always)] +pub fn minus_product(a: i16, b: i16, c: i16, q: i32, b1: i32, b2: i32) -> i16 { + freeze(a as i32 - (b as i32 * c as i32), q, b1, b2) +} + +/// Constant-time: returns -1 if x != 0, 0 if x == 0. +#[inline(always)] +#[allow(clippy::cast_sign_loss)] +pub fn mask_set(x: i16) -> isize { + let mut r = (x as u16) as i32; + r = -r; + r >>= 31; + r as isize +} diff --git a/sntrup-kem/src/rq/vector.rs b/sntrup-kem/src/rq/vector.rs new file mode 100644 index 0000000..59909cc --- /dev/null +++ b/sntrup-kem/src/rq/vector.rs @@ -0,0 +1,313 @@ +#![allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] + +use crate::rq::modq; + +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +pub fn swap(x: &mut [i16], y: &mut [i16], n: usize, mask: isize) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return swap_avx2(x, y, n, mask); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return swap_neon(x, y, n, mask); + } + #[allow(unreachable_code)] + swap_scalar(x, y, n, mask); +} + +#[allow(clippy::cast_possible_truncation)] +fn swap_scalar(x: &mut [i16], y: &mut [i16], n: usize, mask: isize) { + let c = mask as i16; + for i in 0..n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + } +} + +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +unsafe fn swap_avx2(x: &mut [i16], y: &mut [i16], n: usize, mask: isize) { + unsafe { + use core::arch::x86_64::*; + let cv = _mm256_set1_epi16(mask as i16); + let mut i = 0usize; + while i + 16 <= n { + let xv = _mm256_loadu_si256(x.as_ptr().add(i) as *const __m256i); + let yv = _mm256_loadu_si256(y.as_ptr().add(i) as *const __m256i); + let t = _mm256_and_si256(cv, _mm256_xor_si256(xv, yv)); + _mm256_storeu_si256( + x.as_mut_ptr().add(i) as *mut __m256i, + _mm256_xor_si256(xv, t), + ); + _mm256_storeu_si256( + y.as_mut_ptr().add(i) as *mut __m256i, + _mm256_xor_si256(yv, t), + ); + i += 16; + } + let c = mask as i16; + while i < n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + i += 1; + } + } +} + +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +unsafe fn swap_neon(x: &mut [i16], y: &mut [i16], n: usize, mask: isize) { + unsafe { + use core::arch::aarch64::*; + let cv = vdupq_n_s16(mask as i16); + let mut i = 0usize; + while i + 8 <= n { + let xv = vld1q_s16(x.as_ptr().add(i)); + let yv = vld1q_s16(y.as_ptr().add(i)); + let t = vandq_s16(cv, veorq_s16(xv, yv)); + vst1q_s16(x.as_mut_ptr().add(i), veorq_s16(xv, t)); + vst1q_s16(y.as_mut_ptr().add(i), veorq_s16(yv, t)); + i += 8; + } + let c = mask as i16; + while i < n { + let t = c & (x[i] ^ y[i]); + x[i] ^= t; + y[i] ^= t; + i += 1; + } + } +} + +#[inline(always)] +pub fn product(z: &mut [i16], n: usize, x: &[i16], c: i16, q: i32, b1: i32, b2: i32) { + for i in 0..n { + z[i] = modq::product(x[i], c, q, b1, b2); + } +} + +/// Fused minus_product and shift: z[i+1] = freeze(z[i] - y[i]*c), z[0] = 0. +/// Processes backward to avoid overwrite conflicts, eliminating a separate memmove. +#[inline(always)] +pub fn minus_product_shift(z: &mut [i16], n: usize, y: &[i16], c: i16, q: i32, b1: i32, b2: i32) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return minus_product_shift_avx2(z, n, y, c, q, b1, b2); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return minus_product_shift_neon(z, n, y, c, q, b1, b2); + } + #[allow(unreachable_code)] + minus_product_shift_scalar(z, n, y, c, q, b1, b2); +} + +fn minus_product_shift_scalar( + z: &mut [i16], + n: usize, + y: &[i16], + c: i16, + q: i32, + b1: i32, + b2: i32, +) { + for i in (0..n - 1).rev() { + z[i + 1] = modq::minus_product(z[i], y[i], c, q, b1, b2); + } + z[0] = 0; +} + +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +unsafe fn minus_product_shift_avx2( + z: &mut [i16], + n: usize, + y: &[i16], + c: i16, + q: i32, + b1: i32, + b2: i32, +) { + unsafe { + use core::arch::x86_64::*; + let qv = _mm256_set1_epi32(q); + let kb1 = _mm256_set1_epi32(b1); + let kb2 = _mm256_set1_epi32(b2); + let k134m = _mm256_set1_epi32(134_217_728); + let cv = _mm256_set1_epi32(c as i32); + + let mut j = (n - 2) as isize; + + // Process 16 at a time (two 8-wide batches for ILP), backward + while j >= 15 { + let start = (j - 15) as usize; + + // Batch 0: elements start..start+8 + let zv0 = + _mm256_cvtepi16_epi32(_mm_loadu_si128(z.as_ptr().add(start) as *const __m128i)); + let yv0 = + _mm256_cvtepi16_epi32(_mm_loadu_si128(y.as_ptr().add(start) as *const __m128i)); + let a0 = _mm256_sub_epi32(zv0, _mm256_mullo_epi32(yv0, cv)); + + // Batch 1: elements start+8..start+16 + let zv1 = + _mm256_cvtepi16_epi32(_mm_loadu_si128(z.as_ptr().add(start + 8) as *const __m128i)); + let yv1 = + _mm256_cvtepi16_epi32(_mm_loadu_si128(y.as_ptr().add(start + 8) as *const __m128i)); + let a1 = _mm256_sub_epi32(zv1, _mm256_mullo_epi32(yv1, cv)); + + // Barrett freeze batch 0 + let t0 = _mm256_srai_epi32(_mm256_mullo_epi32(a0, kb1), 20); + let b0 = _mm256_sub_epi32(a0, _mm256_mullo_epi32(t0, qv)); + let t0 = _mm256_srai_epi32(_mm256_add_epi32(_mm256_mullo_epi32(b0, kb2), k134m), 28); + let r0 = _mm256_sub_epi32(b0, _mm256_mullo_epi32(t0, qv)); + + // Barrett freeze batch 1 + let t1 = _mm256_srai_epi32(_mm256_mullo_epi32(a1, kb1), 20); + let b1 = _mm256_sub_epi32(a1, _mm256_mullo_epi32(t1, qv)); + let t1 = _mm256_srai_epi32(_mm256_add_epi32(_mm256_mullo_epi32(b1, kb2), k134m), 28); + let r1 = _mm256_sub_epi32(b1, _mm256_mullo_epi32(t1, qv)); + + // Pack 8+8 i32 -> 16 i16 and store at offset +1 (the shift) + let packed = _mm256_permute4x64_epi64(_mm256_packs_epi32(r0, r1), 0xD8); + _mm256_storeu_si256(z.as_mut_ptr().add(start + 1) as *mut __m256i, packed); + j -= 16; + } + + // Process remaining 8 at a time + while j >= 7 { + let start = (j - 7) as usize; + let zv = + _mm256_cvtepi16_epi32(_mm_loadu_si128(z.as_ptr().add(start) as *const __m128i)); + let yv = + _mm256_cvtepi16_epi32(_mm_loadu_si128(y.as_ptr().add(start) as *const __m128i)); + let a = _mm256_sub_epi32(zv, _mm256_mullo_epi32(yv, cv)); + + let t = _mm256_srai_epi32(_mm256_mullo_epi32(a, kb1), 20); + let b = _mm256_sub_epi32(a, _mm256_mullo_epi32(t, qv)); + let t = _mm256_srai_epi32(_mm256_add_epi32(_mm256_mullo_epi32(b, kb2), k134m), 28); + let r = _mm256_sub_epi32(b, _mm256_mullo_epi32(t, qv)); + + let lo = _mm256_castsi256_si128(r); + let hi = _mm256_extracti128_si256(r, 1); + _mm_storeu_si128( + z.as_mut_ptr().add(start + 1) as *mut __m128i, + _mm_packs_epi32(lo, hi), + ); + j -= 8; + } + + // Scalar remainder + while j >= 0 { + z[(j + 1) as usize] = modq::minus_product(z[j as usize], y[j as usize], c, q, b1, b2); + j -= 1; + } + z[0] = 0; + } +} + +/// NEON Barrett minus_product_shift: 4 i32 elements at a time (128-bit), backward. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +unsafe fn minus_product_shift_neon( + z: &mut [i16], + n: usize, + y: &[i16], + c: i16, + q: i32, + b1: i32, + b2: i32, +) { + unsafe { + use core::arch::aarch64::*; + let qv = vdupq_n_s32(q); + let kb1 = vdupq_n_s32(b1); + let kb2 = vdupq_n_s32(b2); + let k134m = vdupq_n_s32(134_217_728); + let cv = vdupq_n_s32(c as i32); + + let mut j = (n - 2) as isize; + + // Process 8 at a time (two 4-wide batches), backward + while j >= 7 { + let start = (j - 7) as usize; + + // Batch 0: elements start..start+4 + let zv0 = vmovl_s16(vld1_s16(z.as_ptr().add(start))); + let yv0 = vmovl_s16(vld1_s16(y.as_ptr().add(start))); + let a0 = vsubq_s32(zv0, vmulq_s32(yv0, cv)); + + // Batch 1: elements start+4..start+8 + let zv1 = vmovl_s16(vld1_s16(z.as_ptr().add(start + 4))); + let yv1 = vmovl_s16(vld1_s16(y.as_ptr().add(start + 4))); + let a1 = vsubq_s32(zv1, vmulq_s32(yv1, cv)); + + // Barrett freeze batch 0 + let t0 = vshrq_n_s32(vmulq_s32(a0, kb1), 20); + let b0 = vsubq_s32(a0, vmulq_s32(t0, qv)); + let t0 = vshrq_n_s32(vaddq_s32(vmulq_s32(b0, kb2), k134m), 28); + let r0 = vsubq_s32(b0, vmulq_s32(t0, qv)); + + // Barrett freeze batch 1 + let t1 = vshrq_n_s32(vmulq_s32(a1, kb1), 20); + let b1 = vsubq_s32(a1, vmulq_s32(t1, qv)); + let t1 = vshrq_n_s32(vaddq_s32(vmulq_s32(b1, kb2), k134m), 28); + let r1 = vsubq_s32(b1, vmulq_s32(t1, qv)); + + // Pack 4+4 i32 -> 8 i16 (naturally ordered, no permute needed) + let packed = vcombine_s16(vmovn_s32(r0), vmovn_s32(r1)); + vst1q_s16(z.as_mut_ptr().add(start + 1), packed); + j -= 8; + } + + // Process 4 at a time + while j >= 3 { + let start = (j - 3) as usize; + let zv = vmovl_s16(vld1_s16(z.as_ptr().add(start))); + let yv = vmovl_s16(vld1_s16(y.as_ptr().add(start))); + let a = vsubq_s32(zv, vmulq_s32(yv, cv)); + + let t = vshrq_n_s32(vmulq_s32(a, kb1), 20); + let b = vsubq_s32(a, vmulq_s32(t, qv)); + let t = vshrq_n_s32(vaddq_s32(vmulq_s32(b, kb2), k134m), 28); + let r = vsubq_s32(b, vmulq_s32(t, qv)); + + vst1_s16(z.as_mut_ptr().add(start + 1), vmovn_s32(r)); + j -= 4; + } + + // Scalar remainder + while j >= 0 { + z[(j + 1) as usize] = modq::minus_product(z[j as usize], y[j as usize], c, q, b1, b2); + j -= 1; + } + z[0] = 0; + } +} diff --git a/sntrup-kem/src/types.rs b/sntrup-kem/src/types.rs new file mode 100644 index 0000000..62fe0bf --- /dev/null +++ b/sntrup-kem/src/types.rs @@ -0,0 +1,380 @@ +//! Generic Streamlined NTRU Prime types parameterized by parameter set. + +use crate::error::Error; +use crate::params::SntrupParams; +use core::marker::PhantomData; +use subtle::ConstantTimeEq; +use zeroize::Zeroize; + +/// Streamlined NTRU Prime encapsulation key (public key). +#[derive(Clone)] +pub struct EncapsulationKey { + bytes: Vec, + _marker: PhantomData

, +} + +/// Streamlined NTRU Prime decapsulation key (secret key). +#[derive(Clone)] +pub struct DecapsulationKey { + bytes: Vec, + _marker: PhantomData

, +} + +/// Streamlined NTRU Prime ciphertext. +#[derive(Clone)] +pub struct Ciphertext { + bytes: Vec, + _marker: PhantomData

, +} + +/// Streamlined NTRU Prime shared secret. +#[derive(Clone)] +pub struct SharedSecret { + bytes: Vec, + _marker: PhantomData

, +} + +/// Streamlined NTRU Prime Key Encapsulation Mechanism parameterized by parameter set. +/// +/// Zero-sized marker type providing [`generate_key`](SntrupKem::generate_key). +/// Use the type aliases [`Sntrup653`](crate::Sntrup653), +/// [`Sntrup761`](crate::Sntrup761), [`Sntrup857`](crate::Sntrup857), +/// [`Sntrup953`](crate::Sntrup953), [`Sntrup1013`](crate::Sntrup1013), +/// [`Sntrup1277`](crate::Sntrup1277). +#[derive(Debug, Clone, Copy)] +pub struct SntrupKem(PhantomData

); + +// --------------------------------------------------------------------------- +// Internal constructors +// --------------------------------------------------------------------------- + +/// Internal `from_vec` constructor for a byte-wrapper type. +macro_rules! impl_from_vec { + ($ty:ident) => { + impl $ty

{ + pub(crate) fn from_vec(bytes: Vec) -> Self { + Self { + bytes, + _marker: PhantomData, + } + } + } + }; +} + +impl_from_vec!(EncapsulationKey); +impl_from_vec!(DecapsulationKey); +impl_from_vec!(Ciphertext); +impl_from_vec!(SharedSecret); + +// --------------------------------------------------------------------------- +// DecapsulationKey: extract encapsulation key +// --------------------------------------------------------------------------- + +impl DecapsulationKey

{ + /// Get the encapsulation (public) key embedded in this decapsulation key. + /// + /// SK layout: f(small_enc) || ginv(small_enc) || pk(pk_size) || rho(small_enc) || hash4(32) + /// The public key starts at offset `2 * small_encode_size` with length `pk_size`. + pub fn encapsulation_key(&self) -> EncapsulationKey

{ + let params = P::params(); + let pk_start = 2 * params.small_encode_size; + let pk_end = pk_start + params.pk_size; + EncapsulationKey::from_vec(self.bytes[pk_start..pk_end].to_vec()) + } +} + +// --------------------------------------------------------------------------- +// Debug +// --------------------------------------------------------------------------- + +impl core::fmt::Debug for EncapsulationKey

{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let name: String = format!("{}::EncapsulationKey", P::NAME); + f.debug_struct(&name) + .field("len", &P::PK_BYTES) + .field("bytes", &hex::encode(&self.bytes)) + .finish() + } +} + +impl core::fmt::Debug for DecapsulationKey

{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let name: String = format!("{}::DecapsulationKey", P::NAME); + f.debug_struct(&name).finish() + } +} + +impl core::fmt::Debug for Ciphertext

{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let name: String = format!("{}::Ciphertext", P::NAME); + f.debug_struct(&name) + .field("len", &P::CT_BYTES) + .field("bytes", &hex::encode(&self.bytes)) + .finish() + } +} + +impl core::fmt::Debug for SharedSecret

{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let name: String = format!("{}::SharedSecret", P::NAME); + f.debug_struct(&name).finish() + } +} + +// --------------------------------------------------------------------------- +// AsRef<[u8]> +// --------------------------------------------------------------------------- + +/// `AsRef<[u8]>` byte access for a wrapper type. +macro_rules! impl_as_ref { + ($ty:ident) => { + impl AsRef<[u8]> for $ty

{ + fn as_ref(&self) -> &[u8] { + &self.bytes + } + } + }; +} + +impl_as_ref!(EncapsulationKey); +impl_as_ref!(DecapsulationKey); +impl_as_ref!(Ciphertext); +impl_as_ref!(SharedSecret); + +// --------------------------------------------------------------------------- +// TryFrom<&[u8]> +// --------------------------------------------------------------------------- + +/// Generate the `TryFrom` family (`&[u8]`, `Vec`, `&Vec`, `Box<[u8]>`) +/// for a fixed-size wrapper type. The `&[u8]` impl is the single length-checked +/// entry point; the owned variants delegate to it. +macro_rules! impl_try_from { + ($ty:ident, $size:ident) => { + impl TryFrom<&[u8]> for $ty

{ + type Error = Error; + fn try_from(bytes: &[u8]) -> Result { + if bytes.len() != P::$size { + return Err(Error::InvalidSize { + expected: P::$size, + actual: bytes.len(), + }); + } + Ok(Self { + bytes: bytes.to_vec(), + _marker: PhantomData, + }) + } + } + + impl TryFrom> for $ty

{ + type Error = Error; + fn try_from(bytes: Vec) -> Result { + Self::try_from(bytes.as_slice()) + } + } + + impl TryFrom<&Vec> for $ty

{ + type Error = Error; + fn try_from(bytes: &Vec) -> Result { + Self::try_from(bytes.as_slice()) + } + } + + impl TryFrom> for $ty

{ + type Error = Error; + fn try_from(bytes: Box<[u8]>) -> Result { + Self::try_from(bytes.as_ref()) + } + } + }; +} + +impl_try_from!(EncapsulationKey, PK_BYTES); +impl_try_from!(DecapsulationKey, SK_BYTES); +impl_try_from!(Ciphertext, CT_BYTES); + +// --------------------------------------------------------------------------- +// PartialEq / Eq (EncapsulationKey, Ciphertext — non-secret, byte equality) +// --------------------------------------------------------------------------- + +impl PartialEq for EncapsulationKey

{ + fn eq(&self, other: &Self) -> bool { + self.bytes == other.bytes + } +} + +impl Eq for EncapsulationKey

{} + +impl PartialEq for Ciphertext

{ + fn eq(&self, other: &Self) -> bool { + self.bytes == other.bytes + } +} + +impl Eq for Ciphertext

{} + +// --------------------------------------------------------------------------- +// ConstantTimeEq / PartialEq / Eq (DecapsulationKey) +// --------------------------------------------------------------------------- + +impl ConstantTimeEq for DecapsulationKey

{ + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.bytes.as_slice().ct_eq(other.bytes.as_slice()) + } +} + +impl PartialEq for DecapsulationKey

{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for DecapsulationKey

{} + +// --------------------------------------------------------------------------- +// ConstantTimeEq / PartialEq / Eq (SharedSecret) +// --------------------------------------------------------------------------- + +impl ConstantTimeEq for SharedSecret

{ + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.bytes.as_slice().ct_eq(other.bytes.as_slice()) + } +} + +impl PartialEq for SharedSecret

{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for SharedSecret

{} + +// --------------------------------------------------------------------------- +// Zeroize + Drop (secret types) +// --------------------------------------------------------------------------- + +impl Zeroize for DecapsulationKey

{ + fn zeroize(&mut self) { + self.bytes.zeroize(); + } +} + +impl Drop for DecapsulationKey

{ + fn drop(&mut self) { + self.zeroize(); + } +} + +impl Zeroize for SharedSecret

{ + fn zeroize(&mut self) { + self.bytes.zeroize(); + } +} + +impl Drop for SharedSecret

{ + fn drop(&mut self) { + self.zeroize(); + } +} + +// --------------------------------------------------------------------------- +// KEM operations (feature-gated) +// --------------------------------------------------------------------------- + +#[cfg(feature = "kgen")] +impl SntrupKem

{ + /// Generate a Streamlined NTRU Prime key pair. + pub fn generate_key( + rng: &mut impl rand::CryptoRng, + ) -> (EncapsulationKey

, DecapsulationKey

) { + let (pk, sk) = crate::kem::keygen(P::params(), rng); + ( + EncapsulationKey::from_vec(pk), + DecapsulationKey::from_vec(sk), + ) + } + + /// Generate a key pair deterministically from a 32-byte seed. + /// + /// The seed is expanded via ChaCha20Rng to derive the full key pair. + /// Identical seeds always produce identical key pairs. + pub fn generate_key_deterministic( + seed: &[u8; 32], + ) -> (EncapsulationKey

, DecapsulationKey

) { + use rand::SeedableRng; + let mut rng = rand_chacha::ChaCha20Rng::from_seed(*seed); + Self::generate_key(&mut rng) + } +} + +#[cfg(feature = "ecap")] +impl EncapsulationKey

{ + /// Encapsulate: produce a ciphertext and shared secret. + pub fn encapsulate(&self, rng: &mut impl rand::CryptoRng) -> (Ciphertext

, SharedSecret

) { + let (ct, ss) = crate::kem::encaps(&self.bytes, P::params(), rng); + (Ciphertext::from_vec(ct), SharedSecret::from_vec(ss)) + } +} + +#[cfg(feature = "dcap")] +impl DecapsulationKey

{ + /// Decapsulate: recover shared secret from ciphertext. + /// + /// Always returns a shared secret (implicit rejection / IND-CCA2). + /// On failure, returns a pseudorandom key derived from rho, + /// indistinguishable from a valid key to an attacker. + pub fn decapsulate(&self, ct: &Ciphertext

) -> SharedSecret

{ + let ss = crate::kem::decaps(&self.bytes, &ct.bytes, P::params()); + SharedSecret::from_vec(ss) + } +} + +// --------------------------------------------------------------------------- +// Serde (feature-gated) +// --------------------------------------------------------------------------- + +#[cfg(feature = "serde")] +mod serde_impl { + use super::*; + + /// Generate `Serialize`/`Deserialize` for a byte-wrapper type. Deserialization + /// validates that the decoded length matches the parameter set's fixed size, + /// rejecting (rather than silently zero-padding) short or oversized input. + macro_rules! impl_serde { + ($ty:ident, $size:ident) => { + impl serde::Serialize for $ty

{ + fn serialize(&self, s: S) -> Result { + serdect::slice::serialize_hex_lower_or_bin(&self.bytes, s) + } + } + + impl<'de, P: SntrupParams> serde::Deserialize<'de> for $ty

{ + fn deserialize>(d: D) -> Result { + let mut buf = vec![0u8; P::$size]; + let decoded = serdect::slice::deserialize_hex_or_bin(&mut buf, d)?; + if decoded.len() != P::$size { + return Err(serde::de::Error::invalid_length( + decoded.len(), + &concat!( + stringify!($ty), + " expects exactly P::", + stringify!($size), + " bytes" + ), + )); + } + Ok(Self { + bytes: buf, + _marker: PhantomData, + }) + } + } + }; + } + + impl_serde!(EncapsulationKey, PK_BYTES); + impl_serde!(DecapsulationKey, SK_BYTES); + impl_serde!(Ciphertext, CT_BYTES); + impl_serde!(SharedSecret, SS_BYTES); +} diff --git a/sntrup-kem/src/utils.rs b/sntrup-kem/src/utils.rs new file mode 100644 index 0000000..f2be3da --- /dev/null +++ b/sntrup-kem/src/utils.rs @@ -0,0 +1,446 @@ +use sha2::{Digest, Sha512}; +use zeroize::Zeroize; + +use crate::params::SntrupParameters; +use crate::{r3, rq, zx}; + +/// Hash prefix helper: SHA-512(prefix || input), truncated to 32 bytes. +pub(crate) fn hash_prefix(out: &mut [u8; 32], prefix: u8, input: &[u8]) { + let mut hasher = Sha512::new(); + hasher.update([prefix]); + hasher.update(input); + let digest = hasher.finalize(); + out.copy_from_slice(&digest[..32]); +} + +/// hash_confirm: Hash(2 || Hash(3 || r_enc) || cache) +/// where cache = Hash4(pk) stored in the secret key. +pub(crate) fn hash_confirm(out: &mut [u8; 32], r_enc: &[u8], cache: &[u8; 32]) { + let mut inner = [0u8; 32]; + hash_prefix(&mut inner, 3, r_enc); + + let mut hasher = Sha512::new(); + hasher.update([2u8]); + hasher.update(inner); + hasher.update(&cache[..]); + let digest = hasher.finalize(); + out.copy_from_slice(&digest[..32]); +} + +/// hash_session: Hash(b || Hash(3 || y) || z) +pub(crate) fn hash_session(out: &mut [u8; 32], b: u8, y: &[u8], z: &[u8]) { + let mut inner = [0u8; 32]; + hash_prefix(&mut inner, 3, y); + + let mut hasher = Sha512::new(); + hasher.update([b]); + hasher.update(inner); + hasher.update(z); + let digest = hasher.finalize(); + out.copy_from_slice(&digest[..32]); +} + +/// Constant-time: returns 0 if x == 0, -1 (0xFFFFFFFF) otherwise. +#[allow(clippy::cast_sign_loss)] +fn int16_nonzero_mask(x: i16) -> i32 { + let u = x as u16; + let mut r = u.wrapping_neg() | u; + r >>= 15; + -(r as i32) +} + +/// Constant-time check if weight of `r` equals `w`. +/// Returns 0 if weight == w, -1 otherwise. +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap +)] +pub(crate) fn weightw_mask(r: &[i8], p: usize, w: usize) -> i32 { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return weightw_mask_avx2(r, p, w); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return weightw_mask_neon(r, p, w); + } + #[allow(unreachable_code)] + weightw_mask_scalar(r, p, w) +} + +#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] +fn weightw_mask_scalar(r: &[i8], _p: usize, w: usize) -> i32 { + let mut weight: i32 = 0; + for &val in r.iter() { + weight += (val & 1) as i32; + } + int16_nonzero_mask((weight - w as i32) as i16) +} + +/// Count non-zero elements 32 at a time using AVX2. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap +)] +unsafe fn weightw_mask_avx2(r: &[i8], p: usize, w: usize) -> i32 { + unsafe { + use core::arch::x86_64::*; + let ones = _mm256_set1_epi8(1); + let mut acc = _mm256_setzero_si256(); + let mut i = 0usize; + while i + 32 <= p { + let v = _mm256_loadu_si256(r.as_ptr().add(i) as *const __m256i); + let masked = _mm256_and_si256(v, ones); + acc = _mm256_add_epi8(acc, masked); + i += 32; + } + // Horizontal sum: sad against zero gives sum of abs values in each 8-byte lane + let sad = _mm256_sad_epu8(acc, _mm256_setzero_si256()); + // sad has 4 u64 lanes with partial sums + let lo = _mm256_castsi256_si128(sad); + let hi = _mm256_extracti128_si256(sad, 1); + let sum128 = _mm_add_epi64(lo, hi); + let sum_hi = _mm_srli_si128(sum128, 8); + let total = _mm_add_epi64(sum128, sum_hi); + let mut weight = _mm_cvtsi128_si64(total) as i32; + // Handle remainder + while i < p { + weight += (r[i] & 1) as i32; + i += 1; + } + int16_nonzero_mask((weight - w as i32) as i16) + } +} + +/// Count non-zero elements 16 at a time using NEON. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +#[allow( + unsafe_code, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap +)] +unsafe fn weightw_mask_neon(r: &[i8], p: usize, w: usize) -> i32 { + unsafe { + use core::arch::aarch64::*; + let ones = vdupq_n_s8(1); + let mut acc = vdupq_n_u8(0); + let mut i = 0usize; + while i + 16 <= p { + let v = vld1q_s8(r.as_ptr().add(i)); + let masked = vreinterpretq_u8_s8(vandq_s8(v, ones)); + acc = vaddq_u8(acc, masked); + i += 16; + } + // Progressive horizontal sum: u8 -> u16 -> u32 -> u64 + let sum16 = vpaddlq_u8(acc); + let sum32 = vpaddlq_u16(sum16); + let sum64 = vpaddlq_u32(sum32); + let mut weight = (vgetq_lane_u64(sum64, 0) + vgetq_lane_u64(sum64, 1)) as i32; + // Handle remainder + while i < p { + weight += (r[i] & 1) as i32; + i += 1; + } + int16_nonzero_mask((weight - w as i32) as i16) + } +} + +/// Constant-time comparison of two byte slices. +/// Returns 0 if equal, -1 otherwise. +#[allow(unsafe_code, clippy::cast_possible_wrap)] +fn ciphertexts_diff_mask(a: &[u8], b: &[u8]) -> i32 { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return ciphertexts_diff_mask_avx2(a, b); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return ciphertexts_diff_mask_neon(a, b); + } + #[allow(unreachable_code)] + ciphertexts_diff_mask_scalar(a, b) +} + +#[allow(clippy::cast_possible_wrap)] +fn ciphertexts_diff_mask_scalar(a: &[u8], b: &[u8]) -> i32 { + let mut diff: u16 = 0; + let len = a.len().min(b.len()); + for i in 0..len { + diff |= (a[i] ^ b[i]) as u16; + } + int16_nonzero_mask(diff as i16) +} + +/// XOR-accumulate 32 bytes at a time, then horizontal OR. +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") +))] +#[target_feature(enable = "avx2")] +#[allow(unsafe_code, clippy::cast_possible_wrap, clippy::cast_sign_loss)] +unsafe fn ciphertexts_diff_mask_avx2(a: &[u8], b: &[u8]) -> i32 { + unsafe { + use core::arch::x86_64::*; + let len = a.len().min(b.len()); + let mut acc = _mm256_setzero_si256(); + let mut i = 0usize; + while i + 32 <= len { + let av = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i); + let bv = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i); + acc = _mm256_or_si256(acc, _mm256_xor_si256(av, bv)); + i += 32; + } + // Horizontal OR reduction. + // movemask bit i is 1 iff byte i of acc == 0; mask == 0xFFFFFFFF iff equal. + // Collapse to 0/1 branchlessly — a source-level branch here would leak, + // via the branch predictor, whether the ciphertexts matched (the secret + // the implicit-rejection comparison must hide). + let inv = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(acc, _mm256_setzero_si256())) as u32); + let mut diff: u16 = ((inv | inv.wrapping_neg()) >> 31) as u16; + // Handle remainder + while i < len { + diff |= (a[i] ^ b[i]) as u16; + i += 1; + } + int16_nonzero_mask(diff as i16) + } +} + +/// XOR-accumulate 16 bytes at a time, then horizontal OR. +#[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] +#[allow(unsafe_code, clippy::cast_possible_wrap, clippy::cast_sign_loss)] +unsafe fn ciphertexts_diff_mask_neon(a: &[u8], b: &[u8]) -> i32 { + unsafe { + use core::arch::aarch64::*; + let len = a.len().min(b.len()); + let mut acc = vdupq_n_u8(0); + let mut i = 0usize; + while i + 16 <= len { + let av = vld1q_u8(a.as_ptr().add(i)); + let bv = vld1q_u8(b.as_ptr().add(i)); + acc = vorrq_u8(acc, veorq_u8(av, bv)); + i += 16; + } + // Horizontal max: any-nonzero check + let mut diff: u16 = vmaxvq_u8(acc) as u16; + // Handle remainder + while i < len { + diff |= (a[i] ^ b[i]) as u16; + i += 1; + } + int16_nonzero_mask(diff as i16) + } +} + +/// Derive a keypair from secret polynomials. +/// +/// Returns `(pk_bytes, sk_bytes)` as `Vec`. +/// +/// SK layout: `f_enc || ginv_enc || pk || rho || Hash4(pk)` +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_possible_wrap +)] +pub(crate) fn derive_key( + f: &[i8], + g: &[i8], + gr: &[i8], + rho: &[u8], + params: &SntrupParameters, +) -> (Vec, Vec) { + let p = params.p; + + let mut f3r = rq::reciprocal3(f, params); + let mut h = vec![0i16; p]; + rq::mult(&mut h, &f3r, g, params); + let pk = rq::encoding::rq_encode(&h, params); + + // SK layout: f_enc || ginv_enc || pk || rho || Hash4(pk) + let mut sk = vec![0u8; params.sk_size]; + let mut f_enc = zx::encoding::encode(f, p, params.small_encode_size); + let mut ginv_enc = zx::encoding::encode(gr, p, params.small_encode_size); + + let ses = params.small_encode_size; + sk[..ses].copy_from_slice(&f_enc); + sk[ses..(2 * ses)].copy_from_slice(&ginv_enc); + sk[(2 * ses)..(2 * ses + params.pk_size)].copy_from_slice(&pk); + sk[(2 * ses + params.pk_size)..(2 * ses + params.pk_size + ses)].copy_from_slice(rho); + + // Hash4(pk) = Hash(4 || pk) truncated to 32 bytes + let mut cache = [0u8; 32]; + hash_prefix(&mut cache, 4, &pk); + sk[(2 * ses + params.pk_size + ses)..].copy_from_slice(&cache); + + // Zeroize secret intermediates + f3r.zeroize(); + h.zeroize(); + f_enc.zeroize(); + ginv_enc.zeroize(); + cache.zeroize(); + + (pk, sk) +} + +/// Encrypt a small polynomial `r` under a public key. +/// +/// Returns `(ciphertext_bytes, shared_secret)`. +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_possible_wrap +)] +pub(crate) fn create_cipher(r: &[i8], pk: &[u8], params: &SntrupParameters) -> (Vec, [u8; 32]) { + let p = params.p; + + let h = rq::encoding::rq_decode(pk, params); + let mut c = vec![0i16; p]; + rq::mult(&mut c, &h, r, params); + rq::round3(&mut c, params); + + let r_enc = zx::encoding::encode(r, p, params.small_encode_size); + + // Compute confirm hash: Hash(2 || Hash(3 || r_enc) || Hash4(pk)) + let mut cache = [0u8; 32]; + hash_prefix(&mut cache, 4, pk); + let mut confirm = [0u8; 32]; + hash_confirm(&mut confirm, &r_enc, &cache); + + // Ciphertext layout: rounded(rounded_encode_size) || confirm_hash(32) + let mut cstr = vec![0u8; params.ct_size]; + cstr[..params.rounded_encode_size].copy_from_slice(&rq::encoding::rounded_encode(&c, params)); + cstr[params.rounded_encode_size..].copy_from_slice(&confirm); + + // Shared key: hash_session(1, r_enc, cstr) + let mut k = [0u8; 32]; + hash_session(&mut k, 1, &r_enc, &cstr); + + // Zeroize secret intermediates + // r_enc, cache, confirm are on the stack / local Vecs and will be dropped, + // but we zeroize explicitly for defense in depth. + let mut r_enc = r_enc; + r_enc.zeroize(); + cache.zeroize(); + confirm.zeroize(); + + (cstr, k) +} + +/// Decapsulate a ciphertext with a secret key. +/// +/// Implements implicit rejection (IND-CCA2): on failure, returns a pseudorandom +/// key derived from `rho`, indistinguishable from a valid key. +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_possible_wrap +)] +pub(crate) fn decapsulate_inner(cstr: &[u8], sk: &[u8], params: &SntrupParameters) -> [u8; 32] { + let p = params.p; + let w = params.w; + let ses = params.small_encode_size; + + // Parse SK: f(ses) || ginv(ses) || pk(pk_size) || rho(ses) || cache(32) + let mut f = zx::encoding::decode(&sk[..ses], p); + let mut ginv = zx::encoding::decode(&sk[ses..(2 * ses)], p); + let pk_start = 2 * ses; + let pk_end = pk_start + params.pk_size; + let rho_start = pk_end; + let rho_end = rho_start + ses; + let cache_start = rho_end; + + let mut cache = [0u8; 32]; + cache.copy_from_slice(&sk[cache_start..cache_start + 32]); + + // Decrypt: Rounded_decode, multiply by f, Rq_mult3, R3_fromRq, R3_mult by ginv + let c = rq::encoding::rounded_decode(&cstr[..params.rounded_encode_size], params); + let mut cf = vec![0i16; p]; + rq::mult(&mut cf, &c, &f, params); + let mut t3 = vec![0i8; p]; + for i in 0..p { + t3[i] = r3::mod3::freeze(rq::modq::freeze( + 3 * cf[i] as i32, + params.q, + params.barrett1, + params.barrett2, + ) as i32); + } + let mut r = vec![0i8; p]; + r3::mult(&mut r, &t3, &ginv, p); + + // Weight mask: on failure, set r to default weight-W vector + // (W ones followed by P-W zeros), matching PQClean's Decrypt + let w_mask = weightw_mask(&r, p, w); + let not_mask = (!w_mask) as i8; + for val in r[..w].iter_mut() { + *val = ((*val ^ 1) & not_mask) ^ 1; + } + for val in r[w..p].iter_mut() { + *val &= not_mask; + } + + // Hide: encode r, re-encrypt with pk, compute confirm hash + let mut r_enc = zx::encoding::encode(&r, p, ses); + let h = rq::encoding::rq_decode(&sk[pk_start..pk_end], params); + let mut hr = vec![0i16; p]; + rq::mult(&mut hr, &h, &r, params); + rq::round3(&mut hr, params); + let mut cnew = vec![0u8; params.ct_size]; + cnew[..params.rounded_encode_size].copy_from_slice(&rq::encoding::rounded_encode(&hr, params)); + let mut confirm = [0u8; 32]; + hash_confirm(&mut confirm, &r_enc, &cache); + cnew[params.rounded_encode_size..].copy_from_slice(&confirm); + + // Compare full ciphertexts (rounded + confirm hash) + let mask = ciphertexts_diff_mask(cstr, &cnew); + + // Constant-time select: r_enc on success (mask=0), rho on failure (mask=-1) + let rho = &sk[rho_start..rho_end]; + let mut selected = vec![0u8; ses]; + selected.copy_from_slice(&r_enc); + let mask_byte = mask as u8; + for i in 0..ses { + selected[i] ^= mask_byte & (selected[i] ^ rho[i]); + } + + // Hash session: prefix=1 on success (mask=0), prefix=0 on failure (mask=-1) + let prefix = (1 + mask) as u8; + let mut k = [0u8; 32]; + hash_session(&mut k, prefix, &selected, cstr); + + // Zeroize secret intermediates + f.zeroize(); + ginv.zeroize(); + cache.zeroize(); + cf.zeroize(); + t3.zeroize(); + r.zeroize(); + r_enc.zeroize(); + hr.zeroize(); + cnew.zeroize(); + confirm.zeroize(); + selected.zeroize(); + + k +} diff --git a/sntrup-kem/src/zx.rs b/sntrup-kem/src/zx.rs new file mode 100644 index 0000000..c0c0989 --- /dev/null +++ b/sntrup-kem/src/zx.rs @@ -0,0 +1,318 @@ +/// Small-element (ternary) encoding and decoding. +pub mod encoding { + /// Encode a small polynomial `f` of length `p` into `small_encode_size` bytes. + /// + /// Packs 4 trits per byte (each trit shifted to {0,1,2} by adding 1). + /// The last byte holds `f[p-1] + 1`. + #[allow(clippy::cast_sign_loss)] + pub fn encode(f: &[i8], p: usize, small_encode_size: usize) -> Vec { + let mut c = vec![0u8; small_encode_size]; + for (byte, chunk) in c[..small_encode_size - 1].iter_mut().zip(f.chunks(4)) { + let mut c0 = chunk[0] + 1; + c0 += (chunk[1] + 1) << 2; + c0 += (chunk[2] + 1) << 4; + c0 += (chunk[3] + 1) << 6; + *byte = c0 as u8; + } + c[small_encode_size - 1] = (f[p - 1] + 1) as u8; + c + } + + /// Decode `small_encode_size` bytes into a small polynomial of length `p`. + /// + /// Inverse of [`encode`]: unpacks 4 trits per byte, last element from last byte. + #[allow(clippy::cast_possible_wrap)] + pub fn decode(c: &[u8], p: usize) -> Vec { + let small_encode_size = c.len(); + let mut f = vec![0i8; p]; + for (byte, chunk) in c[..small_encode_size - 1].iter().zip(f.chunks_mut(4)) { + let mut c0 = *byte; + chunk[0] = ((c0 & 3) as i8) - 1; + c0 >>= 2; + chunk[1] = ((c0 & 3) as i8) - 1; + c0 >>= 2; + chunk[2] = ((c0 & 3) as i8) - 1; + c0 >>= 2; + chunk[3] = ((c0 & 3) as i8) - 1; + } + f[p - 1] = ((c[small_encode_size - 1] & 3) as i8) - 1; + f + } +} + +/// Random polynomial generation and constant-time sorting. +pub mod random { + use rand::Rng; + use rand::RngExt; + + /// Branchless constant-time min/max swap (djbsort int32_minmax). + /// Operates on a slice with two indices to avoid borrow issues. + /// + /// Uses wrapping i32 subtraction (matching the original djbsort algorithm) + /// with an XOR fixup for overflow. The `>> 31` extracts the sign bit. + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn int32_minmax(x: &mut [i32], i: usize, j: usize) { + let ab = x[j] ^ x[i]; + let mut c = x[j].wrapping_sub(x[i]); + c ^= ab & (c ^ x[j]); + c >>= 31; + c &= ab; + x[i] ^= c; + x[j] ^= c; + } + + /// Batcher bitonic sort on `n` elements of `x`, dispatching to SIMD when available. + #[allow(unsafe_code)] + pub fn sort(x: &mut [i32], n: usize) { + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + // SAFETY: AVX2 verified by cfg + unsafe { + return sort_avx2(x, n); + } + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + // SAFETY: NEON is baseline on aarch64 + unsafe { + return sort_neon(x, n); + } + #[allow(unreachable_code)] + sort_scalar(x, n); + } + + fn sort_scalar(x: &mut [i32], n: usize) { + if n < 2 { + return; + } + let mut top = 1; + while top < (n - top) { + top += top; + } + let mut p = top; + while p > 0 { + for i in 0..(n - p) { + if i & p == 0 { + int32_minmax(x, i, i + p); + } + } + let mut q = top; + while q > p { + for i in 0..(n - q) { + if i & p == 0 { + int32_minmax(x, i + p, i + q); + } + } + q >>= 1; + } + p >>= 1; + } + } + + /// AVX2-accelerated Batcher bitonic sort. + /// Uses _mm256_min/max_epi32 for 8 parallel comparators when stride >= 8. + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + #[target_feature(enable = "avx2")] + #[allow(unsafe_code)] + unsafe fn sort_avx2(x: &mut [i32], n: usize) { + unsafe { + if n < 2 { + return; + } + let mut top = 1; + while top < (n - top) { + top += top; + } + let mut p = top; + while p > 0 { + // First pass: comparators at stride p + minmax_pass_avx2(x, n, p, 0, p); + + // Sub-passes + let mut q = top; + while q > p { + minmax_pass_avx2(x, n, p, p, q); + q >>= 1; + } + p >>= 1; + } + } + } + + /// Process one pass of comparators: minmax(x[i+off0], x[i+off1]) + /// for all i in 0..(n-off1) where i & p_mask == 0. + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(feature = "force-scalar") + ))] + #[target_feature(enable = "avx2")] + #[allow(unsafe_code)] + unsafe fn minmax_pass_avx2(x: &mut [i32], n: usize, p_mask: usize, off0: usize, off1: usize) { + unsafe { + use core::arch::x86_64::*; + + let end = n.saturating_sub(off1); + if p_mask >= 8 { + // When p_mask >= 8, the condition i & p_mask == 0 selects contiguous + // blocks of p_mask elements. Process 8 at a time with SIMD. + let mut i = 0; + while i < end { + if i & p_mask == 0 { + let block_end = (i + p_mask).min(end); + let mut j = i; + while j + 8 <= block_end { + let a = _mm256_loadu_si256(x.as_ptr().add(j + off0) as *const __m256i); + let b = _mm256_loadu_si256(x.as_ptr().add(j + off1) as *const __m256i); + _mm256_storeu_si256( + x.as_mut_ptr().add(j + off0) as *mut __m256i, + _mm256_min_epi32(a, b), + ); + _mm256_storeu_si256( + x.as_mut_ptr().add(j + off1) as *mut __m256i, + _mm256_max_epi32(a, b), + ); + j += 8; + } + // Scalar remainder for this block + while j < block_end { + int32_minmax(x, j + off0, j + off1); + j += 1; + } + i = block_end + p_mask; // skip the next block (i & p_mask != 0) + } else { + i += 1; + } + } + } else { + // Small strides: scalar + for i in 0..end { + if i & p_mask == 0 { + int32_minmax(x, i + off0, i + off1); + } + } + } + } + } + + /// NEON-accelerated Batcher bitonic sort. + /// Uses vminq_s32/vmaxq_s32 for 4 parallel comparators when stride >= 4. + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + #[allow(unsafe_code)] + unsafe fn sort_neon(x: &mut [i32], n: usize) { + unsafe { + if n < 2 { + return; + } + let mut top = 1; + while top < (n - top) { + top += top; + } + let mut p = top; + while p > 0 { + // First pass: comparators at stride p + minmax_pass_neon(x, n, p, 0, p); + + // Sub-passes + let mut q = top; + while q > p { + minmax_pass_neon(x, n, p, p, q); + q >>= 1; + } + p >>= 1; + } + } + } + + /// Process one pass of comparators with NEON. + #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] + #[allow(unsafe_code)] + unsafe fn minmax_pass_neon(x: &mut [i32], n: usize, p_mask: usize, off0: usize, off1: usize) { + unsafe { + use core::arch::aarch64::*; + + let end = n.saturating_sub(off1); + if p_mask >= 4 { + let mut i = 0; + while i < end { + if i & p_mask == 0 { + let block_end = (i + p_mask).min(end); + let mut j = i; + while j + 4 <= block_end { + let a = vld1q_s32(x.as_ptr().add(j + off0)); + let b = vld1q_s32(x.as_ptr().add(j + off1)); + vst1q_s32(x.as_mut_ptr().add(j + off0), vminq_s32(a, b)); + vst1q_s32(x.as_mut_ptr().add(j + off1), vmaxq_s32(a, b)); + j += 4; + } + // Scalar remainder for this block + while j < block_end { + int32_minmax(x, j + off0, j + off1); + j += 1; + } + i = block_end + p_mask; + } else { + i += 1; + } + } + } else { + // Small strides: scalar + for i in 0..end { + if i & p_mask == 0 { + int32_minmax(x, i + off0, i + off1); + } + } + } + } + } + + /// Fill `g` with random small elements in {-1, 0, 1}. + #[allow(clippy::cast_sign_loss)] + pub fn random_small(g: &mut [i8], rng: &mut impl Rng) { + for val in g.iter_mut() { + let r: i32 = rng.random(); + *val = ((((1_073_741_823 & (r as u32)) * 3) >> 30) as i8) - 1; + } + } + + /// Unsigned sort: XOR with 0x80000000, signed sort, XOR back. + /// Matches PQClean's crypto_sort_uint32. + #[allow(clippy::cast_possible_wrap)] + fn sort_uint32(x: &mut [i32], n: usize) { + for val in x.iter_mut().take(n) { + *val ^= 0x80000000u32 as i32; + } + sort(x, n); + for val in x.iter_mut().take(n) { + *val ^= 0x80000000u32 as i32; + } + } + + /// Generate a random ternary polynomial with exactly `w` non-zero entries out of `p`. + /// + /// The first `w` positions get weight (odd), the remaining `p - w` get zero (even tag), + /// then a constant-time sort shuffles them. + #[allow(clippy::cast_possible_wrap)] + pub fn random_tsmall(f: &mut [i8], p: usize, w: usize, rng: &mut impl Rng) { + let mut r = vec![0i32; p]; + for val in r.iter_mut() { + *val = rng.random(); + } + for val in r[..w].iter_mut() { + *val &= -2; + } + for val in r[w..p].iter_mut() { + *val = (*val & -3) | 1 + } + sort_uint32(&mut r, p); + for (fv, &rv) in f.iter_mut().zip(r.iter()) { + *fv = ((rv & 3) as i8) - 1; + } + } +} diff --git a/sntrup-kem/tests/data/kat0_ct.hex b/sntrup-kem/tests/data/kat0_ct.hex new file mode 100644 index 0000000..1d8aa51 --- /dev/null +++ b/sntrup-kem/tests/data/kat0_ct.hex @@ -0,0 +1 @@ +84F327D38929039EF84366AB796A96E6CC268454D5AD8D05410D3E969B6D228B6CED2CBAD7BAB3B4B7EE453747D3331D43B21CD2ECC3BF557D696E81773EE69F687FE1F61E63742AF913A35418575467118409D4FFDBBAC5CA062DFF3F04D76154D17EC34212DEBFCA967A91E461F73786D949D6A9765140B9DC7849639A487058F45DDA919576090D35E783804BA7B8B33B5C6A8C1EBB7C39F042696CF9F2B624D26EA5D6D101489E242CC35126B573928B6B8BF9945269AD22E7DA70CCB7DC374F75641978D5F6D5F540A1169EB098570A7CFD65C9257C4C196F42C6DA4F4466D8F11A17C6A75A9BDD45B5AC77A8367B005ED22141A81B995A19E4DC9E6743B4B45E0329E7A00D3FF90C7A917CDC9CA75A34FFFDD9A321751B5C8B2AB9E1EA037CF45C04B412A97F52E7C6D66508456D9117961257B4F50BB9F500DE99B6F061C377D44F495D711AB2ED3E88CCF14DF10F8C60BD00E29CD83AD160FAEF2984FAC7E897CBB8CDA144D50E4C1AD20238A2617598DA6EFE786AA82E0931B0B5D1A7FA024B856353AFAEDB0A1150A42DA182353CE5A00403BB6B7A13EFFC14B0FD97E1117302ED8A66FAAED88B43B34CFC73CB33E49A4C34D0BFDEE2FEADC2B4C8BB1EC41215586FC66550D496373305B99CEE254EAC71D3CDAC689A777062ACAE65B031B5A1A973835BE6A4BB061F178531C45DE3B33E7763B483621821E037FE5EF48FD971D8D2BEFDA63250E9B1181D8B8086BED9CDA8CDE6644CDE8EBF7C1F1038AFE798A2A4C492B603B7F4F8805D0FDF19D8274888CEBEC3043CC9ACF024BD30D8B0D28D311985930110386598775CF84E0848B11E090905A65F56142FC17D2573536FC51E7D5B2E0FF7B09471155008561156FBAEC215031ED0D844EB14D6C609173355FFF8AE22C6E5BCCD19B2832E5FBE625C2C0E516E1369D1FCF49A5FDBAF8CF572F04317DB8ED70EE1605B96D077B4F054C72B139F5BCAA5BB627E39BE5BCC97AF922C5E70835BD918281B95E557FC86A43C75980A0D9F0F9162F230014CD867DD07DB23260934546C04499D6B8627B8D5DEC73FA45327F84983529BAF3AA10E885724FE76055099D136CCAACC159B4CE38784404CC9B04723910F990AF77C64D42A29D784FC897955758F1D2A12BB7C6F6A8E5B94EFA740F662CF0DA2D099C61EE1CC6B661A64F0887B805AF53C7B64C4456DAF228582C1B2212A8F718E56DF6B72D7EFA0C5D26582CE96AEFCE13CB2830C3FB32E5115AACC20D3A335F7A4C12E0CD1DE532D8D9D42D69BC2D0BE0AC7DF3B8F416B9A53DB261C47272725A8A88298678351B42103219092076B6ECE744D993388A8F2C475079CB8BC4DC492A4797147C01EA3508225826F7A3D32BF502B0F4552F36FC9F6FB91E82DB1949338B45436EC0FF63B53F900F27F8147D420F93D33C83794980E94143FD361F944DDB56D3114EE974C96A155 \ No newline at end of file diff --git a/sntrup-kem/tests/data/kat0_sk.hex b/sntrup-kem/tests/data/kat0_sk.hex new file mode 100644 index 0000000..4e0440f --- /dev/null +++ b/sntrup-kem/tests/data/kat0_sk.hex @@ -0,0 +1 @@ +55565515155545556569159955A2555455456161450594556556915055895615548852955456550098544995555AA5465400A5555558956510A6515596559665996A44915415414465565555551105624156954944626940555555165665955555A51591555956591565651542404089195454566955859064258656285054569505A6A956949446661415580115565055485955259049A5545525455504526A195455554185525A89962552555658248464451455658965625851545915018508154A51A0448959112A286212608A52911A949119A2895664165900A861541A200A52104029544600425A5A1198101A1A50A094A551425A484959A1506486A618555A89905615029622A2A46696A2A6595A46682622525050489A8040562808168156854185026540A1059A94AAA291811948A259816991110951802409AA5A2A06290092451446655608404096898AA56220562285150891914066110A0905A841044865042A2688AA12441AAA46064669155555A2109659299922820036C969CF1008A6AA9551A784941C65A9BF68C2DC33FA36B5D266B25171B346679F2D22BF3123A79C790D6DEC68E1BC44420A6824F5357C78E3C336FEE0551E620DCB975F563682A312A3353B521C727F57CABED0C3228F09317CAE8B58158EBF5B26BDC6E6365AA601ACAD2ABD37F5830D0BBFE355705C0A62B76A5C910AD04E55E5DAD749C7393D2E2E8AB643E62E4757AD2201CAA33203F53B4A9D4757E7274D72BDB036A31D7DE11E5D1C66CF3059F33B6A2972C1E1D9C9FB2AEBC78B2C055D48D79C3A7C996C08B7DEF0791CDE895053885D8DA1D254EE19C090AF34F4720B0B77139108D8498982FEAB0B54934CEE3DCD24B049F981A84C928028A64A26CDF87052313C3E50B2E1F539394502433C0962996C3599189B15174281B6595B567B8C4CD80902860E613AE1906A55607CBF0E11AD6C0C0DF38C006A5F535FA6E6FB49D10B78B9CB6473BF0630518DB6FDECFD70DED201C7E35FFB3BB78D8AEC181F6DE960C316FE354B7CC69E8048201965AAFEF4EA3F808737FF45255A1779DE57A4B68DB587263D7B7F6CB07D8B01224DD291237EFD9C01676E30154A2A7D60174536580FE64EFEDDA5FB42C13ED8C5768A3CBCAE7A2343B3128CD5C1C663A07C7DC0E1E2642C0D9A02F349C964154A6C4308D8ECF30F47E9A81EED2A32A2BB44DFC36A28A66AE77BB139FA416B6327EEFE33632469CCC21229587573C4F7752CE1CC79CA0C6BDA08CD79E25720D2D9092EA2AB13F31E9A3AF1C69FC6379D8E2AD2B87B514C817A087338B7B0736B8954DE9223225B40079C92601248C14B2104901E74F849D9EEE9F5636C1DAD6031AB477C573197E7EB6CE535D9F0F69183DD9DF5521595E5C9E98D846CE08AA655B70CC0D8041401929690298F645D9112DBB03B189CDB1FCF4D512F6874B409BF55EE1CC284A05B698B8818F043C2591C9F4ABA29CF4259915D6A0BE71B6B93963C618CBB567838E5EADE6500DE9AD250083A01EB3EB44C00EECB2B0F874CDA165FAA6524CA13D435C938DB9469292FE97283C69222107EA2F9EFCA1B6D41DCA5B149B2A8CC2244A1BC54261CC11742ACD27B7F2352C33FB83FE143386479D78D6D5C3E51684B60A56714182449C94327139B0289B9BDE0AE3FEF319FE2C605AB5507C894D2C2761A3BD3EA30F4BA928F9F23303061617E2F042DCF229AFF2C9345A6B3CDCF90D3E3BCC3A1110C85616BD585C31CDE3E69ADC12A18BFA5797DF0543435A7C874A8ACF3786FCCBA4A6CEAA65E5666230AA7206AE78EB1B98CB5236508E6357E18CCCFAEC5693532BE4EE38022C48FE94F62B0293A088BF4D737F48748F23BCB338B58B4D666AD3C64E9ECD07265F971B07AF716D4A5B719C3F5FE35744734CBA5430381661E372B6F6510D61B11E4697A1A961589949DBA53B5BC5BDB76FA09324387D799536506182EFF7078034B34E1DDD2612B0A40E7F1A294FBA869DA2E46C4C36BB08C7E9BEF09A94459E8FC2D3CC579D15284DF1F19EF77E03041511ACC39BF9CA8AC8DE6A0E5EEA0BBD4289B6DAE38E9E82EB50A0397B3EE6C52FBDC7DEEA2C376825E89016C859B09488548BF76CA62D696DF94A61E8F0E13C69EE816CE5AA2768987B0A782D74C673EC0059FA532AE97F8BDF90F90130926654FB8B469D772049D72A8375CD8459D06CC1B90633EB3A899685D21491B5062A9FC73FB6E878D7A73198EFA0B569D9CE665CE126FFDC9862EB00D11F457A7995555F7C92011C24E1CEC45C270FB5F121F08177F97FC3F631C0EF86A92E99D557FB69A0F6FCD8B1C0EF94AA7429B3A8A11614D847202D3D04A98313FC0D63ABF9FD75CA321C879458BF8837B7E74CCF5179C7714FA9800D1821EE3F9639D28136B7910872631F85AE7B6DD289E01034217210859D4E53C65487CE38AFF621DA76BA1C4E2E77ED5380B47B8D983CCB5BB793E \ No newline at end of file diff --git a/sntrup-kem/tests/data/kat0_ss.hex b/sntrup-kem/tests/data/kat0_ss.hex new file mode 100644 index 0000000..4e7c0a4 --- /dev/null +++ b/sntrup-kem/tests/data/kat0_ss.hex @@ -0,0 +1 @@ +344CA5E25F6DA5EA95E4A695B1C5446ECA9859334532E4A9537669F012C743A2 \ No newline at end of file diff --git a/sntrup-kem/tests/data/kat1_ct.hex b/sntrup-kem/tests/data/kat1_ct.hex new file mode 100644 index 0000000..b50d24d --- /dev/null +++ b/sntrup-kem/tests/data/kat1_ct.hex @@ -0,0 +1 @@ +8CE9955F6DC23A0E49D9B263C43026609612696FD84D37DFA1192BA0D9412D2125E25A7C64209A92E5F1C6F170D5C4C891D5050E336F628544620D6B0A9C5058DBBB51C4F540AE3C19BD941E7CD3105A7BCD1774FF05B0B65E310A1D6F88253CFA36C2E7F34130B14613B188D8B9D6C9BE040CE2446F33F75E2A61AF7D7C11FFF54505F1E1EEE49172234915D722599C6BD77D0BCEBEC69435BAC571A194939861CE3ED3960C0FD0FDB8D7CBD272659BBB881C29A70ECB62388062C4ABB2A565A0D9CE73E23E9CA7D589DE524E8289762499F867697003DBB87E1657313D05E5E8E7C7FF407DF7931EB6842C825219369E92D2E65E29CFCF31DCD0F40706AF8B4F70D9E8AF6147AC7D8B1900773D9EE2CE6C4EEC38C4FF44F2E397D95684CE5EE6FD5FE4975216819BAE496D2CE888A6F630FE448A4872BBC5164A70A33A9988A3C4AEA17A8688346191CD4BE2F8D021CC97D5A20BFFFFC628ED6F267145C41BCFB83161C23262943B530DF4BD4B6BDD85D78A2DAF0499B8DE97AFE6191AFD3A35E8268A4D08EC55DBA8BEC15BDE4F823D4B3C62F7990EF7A655D03DD4897F640ABDD60D58A3244F745CEA09ACCDD56DE4D8F08F43025E084E91634E25FC7B39BCAFFC1803E8D4806786548263D6BF8D8ACC84FA9A01D786F438FEDB811DF9FC0EFFCF9C6A787B17228B260D6ACCE43075E7F80CC57C682101CD9CDF12B2D3F931AAAE1A2434AAE772C11CCEE49E2503DAE5CBC98E41F6AD6E885B1F011E9883CE078B33D99FC41471CCA14FC0DC12789464CE92B93A08E8F3D25822827A08273FA2E5114909F0D4A378A482F0EBDF781991E1E28B3F72CAD3886575FF7E59F2B8607CE314D8292E27C1D6187EC5C76CB1347185D760D69F91B7F455DDE52257B03CBB9845B78867EABAA28535B6DEA69D14693666B1E978503F65412C5AFF4AFEC1F537C416EAC390D5EEA0C3B4694DB4FAB95CAC381D12F42F176F27A4BEE0C0A0A458B51E5C71D82724593014C15430BECDBB472659D91B093192590A43A6E56BB6AD3C6C5BAA7AA49D888E1CFA656200680A31ECDB99A4B0109133523653D9769A945ADEC3890734841B94ACC62A8AB0805F0E9E9E3B4471552DA0B5E78935C149E3D4DFCB20F75BDA04A86611B4192BC6AF2B5C3C1C4424E40E8AF9E1A7A3E57B9AACA0B730585931386329C19846D1BF091084C12380D7B2FB3348E7BC25EF9BC39FEBA174B8FC9A191A11C345E0E3CB3170593928429F0729BB96D771BAFCB84633B80377658EBB5866631AE2EBAE05940AFDBA13F34A0341E923EA2C31F11E2A1A421C4522AB50153A77B7950834CBF23A16D01462A0FCD270BCF66B9F6E0A5A39F5989D75643DFC8EC4D5B051AE94C6DB74D7779C54016EAD109D10AC190B07816CDCD2B8EC9BC596090DF00FE23B4E862FE7E0D7EBCBF6B2A1A565435DA662CB34489BB570286BA114F8 \ No newline at end of file diff --git a/sntrup-kem/tests/data/kat1_sk.hex b/sntrup-kem/tests/data/kat1_sk.hex new file mode 100644 index 0000000..e585f4a --- /dev/null +++ b/sntrup-kem/tests/data/kat1_sk.hex @@ -0,0 +1 @@ +645159996555659148569656954105506019451895595549848669516585555885215216015254544569A6521699555985415514151155925466465965555565555561461655950895559696596565645699555A1596555A966566614655A40951559554669590615659A551555125556595551955569915A5614089A15501582951945A15558659555555654110491169A0115656649156245559244155A584145650516445A555555955485655455681166469511A59642155555501990285891A514062416580150298400062840021A908660880A85A28008958212528A0A604218026444A2AA5452A952205120199055041A99129298A989A895206621800648A98A88069864862A4651468A14114268209A55249440A8951906505689586229848816266852198A2A2096459194A550402910486622949A09448210A94429A444126A4A66A0AA18921A4A54514595A81248446A86601905118A2850492910185041822A52422565641A195969068850240588648408A14990A6801D2530F125EE5F208B1976A66BCBC917161F6929E636BA8C73470DE18065F6057528D718744E9248DFFF6BB55C188CEACB9419863C3C456B46A21354834ADA6B2132C67747C9EE70D02560268EE650E56C84BCE6A6700C5E612999110E53A866AE6B4F778230367B8B886C4FEC089C6267F91C7F24D6CE53C754D9CCCDD25756D76CD211D5954F0E8A11343679C1F692CE5C0D0E42A02381144E1E0201B6F49F00628A86E09918488BC3E1E1071561700544C2CE50F4B7BAE4757CA7A0DD0A221260E0574D1F8F81AF072FCBA8061B5B8BB450FDCC6732D35CDBBBC4EF1798EA7E263BCF369059EED3A86E2C61C72CBFCD69521C3A1FD4868AF5638148043162B8B6E39F82B56D6C9C7724807F623C06CC08FF619F7961EC972B1BD2856F87D5E6940DC15B9575264A4AB0CA229C5B02C9AA5BB8DA2BD8EB09A7A5E824C2D666A17C4845CA9530F3A5E45BE5A380B83C81B3965180C706E2CC9ABF4DC8A1559CCF176EFD0FAAD1B8E184AAA08E24CA6F5D070E040D64CD65A1628E99F3BACAB5A7206B5C7D1F7D282651448E259A839E128774530C931C4E4E6F60AB9CDEB645AA24012C3A0983750C04914C59E34AF67B9128CE906AD162D70E582C42E21898EF8023A9D1BC3B5F5BDE6415CD168F1DE726A2ECB27C48A73356FBBEE8F405A6F8C34F8F5A864C109045618B03FBFDCBAFD8BC934658634070D40E5C62A03F9B544C325BCC56D8C70D67DA39879356E9871B016A0CCF3FE1C6AB941B69417FF514B616C9CFE9EA066FBB4B081B25D584E40430B0950A7EBFBFA9B4E4EE003F297BA668F7139F26F218026DC9BFDBBBE6FD1CDE03CAC9814299126E8CBDAFFD54101E07C09B65B88C9743B85C464E44C7256F506F07BC7D3877E2578D589CD2001E5124C9E647FAC2134B34E74DC188478E538BBFB750B47600191C51B71F18460CECD5FBDABAA48D6DD6DBE7E26CEF9498220379F1FF25AFACA60CEFFA408CEB677E4CAFB3A43B3E5FAA5B731EA608A945EDF645136ADC77B07DB34E191D9E786ED7CFD97E2330548FFE021997E2E8B774E4A5F1666A91F05D471874E765DF42C7C3D9EF37CD946E0D69C9EA83CEC9B1AA41A25F1A3A458C9935CBC7D294B64D628DA7F1B0FA4F24E6375DF31AE82815F5DFF4ACD0DD9BD8A8740CB92633D1E000191B837021F143F64B08388A78B9A0F55FE7B824FA9D4E85709EAC46EAB5B24C46BBC2D1D8D39FEC8BA130EB68A7F55769606F07CB7B8CA0AD99739C68A365E415983B964F1F2D261145057B7A76D72300AD49D9FFFFFFE9C41BD48B82F8450274C6C25DCECE61D4C443FEE0D3C359FDD4E4409AD607AE7A707B83B8629134DACCBD54C618EEF0B2F04E848F7B62C494DEC89F2830BACEDB3B876670309B36F7D70BE0219F5D05A3C1E5BB58B1CF053FE92D2E3F934AA2047F963172E04B7457FED3956C08A705C9C441F1FEE07E05344C63817F5DA7F298D5323BF88FB490E1C628015ECD09C5D89978EB42A2DFD32E0820B005BDB60FFFB15EBE1A123C6107530FE48C72C925FB85465749EF1A46C6BA7F0F31D911D9E78608DAA64EC87CB3C82ABEC988FA3656734A038FB5A3F072B7E9E9F383227EA3D5ACE37A6A5536B9E3AA926E900B736895EF729F2159BA82A0D090B0D4A26FFC467511911F39D3EC248467A576EE1F8DCA10DD0C3BE961080D925B823AD9538477F258DD445AF4872DFDEB8E4A819DDE314766683246379B03AA738907ED3671359999FF298C4A44133417821912013E792C90D939815BC3AEBB565E1D6B42BB356CC6A6C79EB6D640001C9D0ED847B4D39B2C38FC2123609EF94608B766EBEB91DD12D228123D29D14C1B4169D8417D26054304B5C900E5CF78159735D0FFA15B691369BC66811A9F11ADB3ED280E3151830BAC71B6E5EE77238F632911067ED8C525887CF983CFE4A2572D3 \ No newline at end of file diff --git a/sntrup-kem/tests/data/kat1_ss.hex b/sntrup-kem/tests/data/kat1_ss.hex new file mode 100644 index 0000000..97f65bc --- /dev/null +++ b/sntrup-kem/tests/data/kat1_ss.hex @@ -0,0 +1 @@ +16C15126F734E51268BA916CE3B39A72E171AE79B8C2B6A68B34AB0DC5621B7E \ No newline at end of file diff --git a/sntrup-kem/tests/kat.rs b/sntrup-kem/tests/kat.rs new file mode 100644 index 0000000..8fb4f6c --- /dev/null +++ b/sntrup-kem/tests/kat.rs @@ -0,0 +1,56 @@ +#![allow(missing_docs)] +#![cfg(feature = "dcap")] + +use sntrup_kem::*; + +/// IETF draft-josefsson-ntruprime-streamlined-00, test vector 0 (sntrup761). +#[test] +fn kat0_decapsulation_761() { + let sk_hex = include_str!("data/kat0_sk.hex"); + let ct_hex = include_str!("data/kat0_ct.hex"); + let ss_hex = include_str!("data/kat0_ss.hex"); + + let sk = sntrup761::DecapsulationKey::try_from( + hex::decode(sk_hex.trim()) + .expect("invalid SK hex") + .as_slice(), + ) + .expect("SK size mismatch"); + + let ct = sntrup761::Ciphertext::try_from( + hex::decode(ct_hex.trim()) + .expect("invalid CT hex") + .as_slice(), + ) + .expect("CT size mismatch"); + + let ss_expected = hex::decode(ss_hex.trim()).expect("invalid SS hex"); + let ss = sk.decapsulate(&ct); + assert_eq!(ss.as_ref(), &ss_expected[..], "KAT0 shared secret mismatch"); +} + +/// IETF draft-josefsson-ntruprime-streamlined-00, test vector 1 (sntrup761). +#[test] +fn kat1_decapsulation_761() { + let sk_hex = include_str!("data/kat1_sk.hex"); + let ct_hex = include_str!("data/kat1_ct.hex"); + let ss_hex = include_str!("data/kat1_ss.hex"); + + let sk = sntrup761::DecapsulationKey::try_from( + hex::decode(sk_hex.trim()) + .expect("invalid SK hex") + .as_slice(), + ) + .expect("SK size mismatch"); + + let ct = sntrup761::Ciphertext::try_from( + hex::decode(ct_hex.trim()) + .expect("invalid CT hex") + .as_slice(), + ) + .expect("CT size mismatch"); + + let ss_expected = hex::decode(ss_hex.trim()).expect("invalid SS hex"); + let ss = sk.decapsulate(&ct); + assert_eq!(ss.as_ref(), &ss_expected[..], "KAT1 shared secret mismatch"); +} diff --git a/sntrup-kem/tests/kem.rs b/sntrup-kem/tests/kem.rs new file mode 100644 index 0000000..aaf494d --- /dev/null +++ b/sntrup-kem/tests/kem.rs @@ -0,0 +1,288 @@ +#![allow(missing_docs)] + +use sntrup_kem::*; + +// --------------------------------------------------------------------------- +// Implicit rejection: corrupted CT still returns a key, but a different one +// --------------------------------------------------------------------------- + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +macro_rules! implicit_rejection_test { + ($name:ident, $kem:ty, $ct_size:expr) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let (ct, ss_encap) = ek.encapsulate(&mut rng); + + // Corrupt the ciphertext + let mut ct_bytes = ct.as_ref().to_vec(); + ct_bytes[0] ^= 0xFF; + ct_bytes[100] ^= 0x42; + let ct_bad = Ciphertext::try_from(ct_bytes.as_slice()).expect("CT size"); + + let ss_decap = dk.decapsulate(&ct_bad); + assert!( + ss_encap != ss_decap, + "corrupted CT must produce different key" + ); + + // Deterministic: same corrupted CT + SK always produces same key + let ss_decap2 = dk.decapsulate(&ct_bad); + assert!( + ss_decap == ss_decap2, + "repeated decap must be deterministic" + ); + } + }; +} + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_653, Sntrup653, 897); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_761, Sntrup761, 1039); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_857, Sntrup857, 1184); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_953, Sntrup953, 1349); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_1013, Sntrup1013, 1455); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +implicit_rejection_test!(implicit_rejection_1277, Sntrup1277, 1847); + +// --------------------------------------------------------------------------- +// Wrong secret key gives different shared secret +// --------------------------------------------------------------------------- + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +macro_rules! wrong_sk_test { + ($name:ident, $kem:ty) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek1, _dk1) = <$kem>::generate_key(&mut rng); + let (_ek2, dk2) = <$kem>::generate_key(&mut rng); + let (ct, ss_encap) = ek1.encapsulate(&mut rng); + let ss_decap = dk2.decapsulate(&ct); + assert!(ss_encap != ss_decap, "wrong SK must produce different key"); + } + }; +} + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_653, Sntrup653); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_761, Sntrup761); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_857, Sntrup857); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_953, Sntrup953); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_1013, Sntrup1013); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +wrong_sk_test!(wrong_sk_1277, Sntrup1277); + +// --------------------------------------------------------------------------- +// Constant-time decapsulate always returns SS_BYTES +// --------------------------------------------------------------------------- + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +macro_rules! constant_time_test { + ($name:ident, $kem:ty, $ct_size:expr) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let (ct, _ss) = ek.encapsulate(&mut rng); + let result = dk.decapsulate(&ct); + assert_eq!(result.as_ref().len(), 32); + + // Even with garbage ciphertext + let garbage = vec![0xABu8; $ct_size]; + let garbage_ct = Ciphertext::try_from(garbage.as_slice()).expect("CT size"); + let result2 = dk.decapsulate(&garbage_ct); + assert_eq!(result2.as_ref().len(), 32); + } + }; +} + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_653, Sntrup653, 897); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_761, Sntrup761, 1039); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_857, Sntrup857, 1184); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_953, Sntrup953, 1349); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_1013, Sntrup1013, 1455); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +constant_time_test!(constant_time_1277, Sntrup1277, 1847); + +// --------------------------------------------------------------------------- +// Deterministic keygen from seed +// --------------------------------------------------------------------------- + +#[cfg(feature = "kgen")] +macro_rules! deterministic_keygen_test { + ($name:ident, $kem:ty) => { + #[test] + fn $name() { + let seed = [0xABu8; 32]; + let (ek1, dk1) = <$kem>::generate_key_deterministic(&seed); + let (ek2, dk2) = <$kem>::generate_key_deterministic(&seed); + assert_eq!(ek1, ek2, "same seed must produce same EK"); + assert!(dk1 == dk2, "same seed must produce same DK"); + + // Different seed produces different key + let (ek3, _dk3) = <$kem>::generate_key_deterministic(&[0xCDu8; 32]); + assert_ne!(ek1, ek3, "different seed must produce different EK"); + } + }; +} + +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_653, Sntrup653); +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_761, Sntrup761); +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_857, Sntrup857); +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_953, Sntrup953); +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_1013, Sntrup1013); +#[cfg(feature = "kgen")] +deterministic_keygen_test!(deterministic_keygen_1277, Sntrup1277); + +// --------------------------------------------------------------------------- +// Extract encapsulation key from decapsulation key +// --------------------------------------------------------------------------- + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +macro_rules! ek_from_dk_test { + ($name:ident, $kem:ty) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let ek_extracted = dk.encapsulation_key(); + assert_eq!(ek, ek_extracted, "extracted EK must match original"); + + // Encapsulating with the extracted key should produce a valid shared secret + let (ct, ss_encap) = ek_extracted.encapsulate(&mut rng); + let ss_decap = dk.decapsulate(&ct); + assert!(ss_encap == ss_decap, "shared secrets must match"); + } + }; +} + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_653, Sntrup653); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_761, Sntrup761); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_857, Sntrup857); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_953, Sntrup953); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_1013, Sntrup1013); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +ek_from_dk_test!(ek_from_dk_1277, Sntrup1277); + +// --------------------------------------------------------------------------- +// TryFrom with wrong sizes +// --------------------------------------------------------------------------- + +macro_rules! try_from_invalid_size_test { + ($name:ident, $ek:ty, $dk:ty, $ct:ty) => { + #[test] + fn $name() { + let short = vec![0u8; 16]; + assert!(<$ek>::try_from(short.as_slice()).is_err()); + assert!(<$dk>::try_from(short.as_slice()).is_err()); + assert!(<$ct>::try_from(short.as_slice()).is_err()); + } + }; +} + +try_from_invalid_size_test!( + try_from_invalid_653, + sntrup653::EncapsulationKey, + sntrup653::DecapsulationKey, + sntrup653::Ciphertext +); +try_from_invalid_size_test!( + try_from_invalid_761, + sntrup761::EncapsulationKey, + sntrup761::DecapsulationKey, + sntrup761::Ciphertext +); +try_from_invalid_size_test!( + try_from_invalid_857, + sntrup857::EncapsulationKey, + sntrup857::DecapsulationKey, + sntrup857::Ciphertext +); +try_from_invalid_size_test!( + try_from_invalid_953, + sntrup953::EncapsulationKey, + sntrup953::DecapsulationKey, + sntrup953::Ciphertext +); +try_from_invalid_size_test!( + try_from_invalid_1013, + sntrup1013::EncapsulationKey, + sntrup1013::DecapsulationKey, + sntrup1013::Ciphertext +); +try_from_invalid_size_test!( + try_from_invalid_1277, + sntrup1277::EncapsulationKey, + sntrup1277::DecapsulationKey, + sntrup1277::Ciphertext +); + +// --------------------------------------------------------------------------- +// TryFrom / AsRef roundtrip +// --------------------------------------------------------------------------- + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +macro_rules! bytes_roundtrip_test { + ($name:ident, $kem:ty) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let (ct, ss) = ek.encapsulate(&mut rng); + + // EK roundtrip + let ek2 = EncapsulationKey::try_from(ek.as_ref()).expect("EK roundtrip"); + assert_eq!(ek, ek2); + + // DK roundtrip + let dk2 = DecapsulationKey::try_from(dk.as_ref()).expect("DK roundtrip"); + assert!(dk == dk2, "DK must match"); + + // CT roundtrip + let ct2 = Ciphertext::try_from(ct.as_ref()).expect("CT roundtrip"); + assert_eq!(ct, ct2); + + // Full KEM roundtrip through bytes + let ss_decap = dk2.decapsulate(&ct2); + assert!(ss == ss_decap, "KEM roundtrip through bytes must work"); + } + }; +} + +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_653, Sntrup653); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_761, Sntrup761); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_857, Sntrup857); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_953, Sntrup953); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_1013, Sntrup1013); +#[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +bytes_roundtrip_test!(bytes_roundtrip_1277, Sntrup1277); diff --git a/sntrup-kem/tests/roundtrip.rs b/sntrup-kem/tests/roundtrip.rs new file mode 100644 index 0000000..842804f --- /dev/null +++ b/sntrup-kem/tests/roundtrip.rs @@ -0,0 +1,24 @@ +#![allow(missing_docs)] +#![cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] + +use sntrup_kem::*; + +macro_rules! roundtrip_test { + ($name:ident, $kem:ty) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let (ct, ss1) = ek.encapsulate(&mut rng); + let ss2 = dk.decapsulate(&ct); + assert_eq!(ss1, ss2); + } + }; +} + +roundtrip_test!(roundtrip_653, Sntrup653); +roundtrip_test!(roundtrip_761, Sntrup761); +roundtrip_test!(roundtrip_857, Sntrup857); +roundtrip_test!(roundtrip_953, Sntrup953); +roundtrip_test!(roundtrip_1013, Sntrup1013); +roundtrip_test!(roundtrip_1277, Sntrup1277); diff --git a/sntrup-kem/tests/serde.rs b/sntrup-kem/tests/serde.rs new file mode 100644 index 0000000..0f070a3 --- /dev/null +++ b/sntrup-kem/tests/serde.rs @@ -0,0 +1,135 @@ +#![allow(missing_docs)] +#![cfg(feature = "serde")] + +use sntrup_kem::*; + +macro_rules! serde_json_test { + ($name:ident, $kem:ty, $params:ty, $pk_size:expr, $ct_size:expr) => { + mod $name { + use super::*; + + #[test] + fn json_roundtrip_encapsulation_key() { + let mut rng = rand::rng(); + let (ek, _dk) = <$kem>::generate_key(&mut rng); + let json = serde_json::to_string(&ek).expect("serialize EK"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse"); + assert!(parsed.is_string(), "EK should serialize as hex string"); + assert_eq!( + parsed.as_str().expect("str").len(), + $pk_size * 2, + "hex length mismatch" + ); + let ek2: EncapsulationKey<$params> = + serde_json::from_str(&json).expect("deserialize EK"); + assert_eq!(ek, ek2); + } + + #[test] + fn json_roundtrip_decapsulation_key() { + let mut rng = rand::rng(); + let (_ek, dk) = <$kem>::generate_key(&mut rng); + let json = serde_json::to_string(&dk).expect("serialize DK"); + let dk2: DecapsulationKey<$params> = + serde_json::from_str(&json).expect("deserialize DK"); + assert!(dk == dk2, "DK must match after JSON roundtrip"); + } + + #[test] + fn json_roundtrip_ciphertext() { + let mut rng = rand::rng(); + let (ek, _dk) = <$kem>::generate_key(&mut rng); + let (ct, _ss) = ek.encapsulate(&mut rng); + let json = serde_json::to_string(&ct).expect("serialize CT"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse"); + assert!(parsed.is_string(), "CT should serialize as hex string"); + assert_eq!( + parsed.as_str().expect("str").len(), + $ct_size * 2, + "hex length mismatch" + ); + let ct2: Ciphertext<$params> = serde_json::from_str(&json).expect("deserialize CT"); + assert_eq!(ct, ct2); + } + + #[test] + fn json_roundtrip_shared_secret() { + let mut rng = rand::rng(); + let (ek, _dk) = <$kem>::generate_key(&mut rng); + let (_ct, ss) = ek.encapsulate(&mut rng); + let json = serde_json::to_string(&ss).expect("serialize SS"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse"); + assert!(parsed.is_string(), "SS should serialize as hex string"); + assert_eq!( + parsed.as_str().expect("str").len(), + 64, + "hex length mismatch" + ); + let ss2: SharedSecret<$params> = + serde_json::from_str(&json).expect("deserialize SS"); + assert!(ss == ss2, "SS must match after JSON roundtrip"); + } + + #[test] + fn json_full_kem_roundtrip() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + let (ct, ss_encap) = ek.encapsulate(&mut rng); + + let dk_json = serde_json::to_string(&dk).expect("serialize DK"); + let ct_json = serde_json::to_string(&ct).expect("serialize CT"); + + let dk2: DecapsulationKey<$params> = + serde_json::from_str(&dk_json).expect("deserialize DK"); + let ct2: Ciphertext<$params> = + serde_json::from_str(&ct_json).expect("deserialize CT"); + + let ss_decap = dk2.decapsulate(&ct2); + assert!(ss_encap == ss_decap, "KEM roundtrip through JSON must work"); + } + } + }; +} + +mod reject_malformed_input { + use super::*; + + /// Inputs shorter than the expected size must be rejected, not zero-padded. + #[test] + fn short_input_rejected() { + let json = "\"deadbeef\""; // 4 bytes — far shorter than any expected size + assert!( + serde_json::from_str::>(json).is_err(), + "short EncapsulationKey input must be rejected, not zero-padded" + ); + assert!( + serde_json::from_str::>(json).is_err(), + "short DecapsulationKey input must be rejected, not zero-padded" + ); + assert!( + serde_json::from_str::>(json).is_err(), + "short Ciphertext input must be rejected, not zero-padded" + ); + assert!( + serde_json::from_str::>(json).is_err(), + "short SharedSecret input must be rejected, not zero-padded" + ); + } + + /// Empty input must be rejected for all four types. + #[test] + fn empty_input_rejected() { + let json = "\"\""; + assert!(serde_json::from_str::>(json).is_err()); + assert!(serde_json::from_str::>(json).is_err()); + assert!(serde_json::from_str::>(json).is_err()); + assert!(serde_json::from_str::>(json).is_err()); + } +} + +serde_json_test!(serde_653, Sntrup653, Sntrup653Params, 994, 897); +serde_json_test!(serde_761, Sntrup761, Sntrup761Params, 1158, 1039); +serde_json_test!(serde_857, Sntrup857, Sntrup857Params, 1322, 1184); +serde_json_test!(serde_953, Sntrup953, Sntrup953Params, 1505, 1349); +serde_json_test!(serde_1013, Sntrup1013, Sntrup1013Params, 1623, 1455); +serde_json_test!(serde_1277, Sntrup1277, Sntrup1277Params, 2067, 1847); diff --git a/sntrup-kem/tests/sizes.rs b/sntrup-kem/tests/sizes.rs new file mode 100644 index 0000000..60d3cee --- /dev/null +++ b/sntrup-kem/tests/sizes.rs @@ -0,0 +1,27 @@ +#![allow(missing_docs)] +#![cfg(all(feature = "kgen", feature = "ecap"))] + +use sntrup_kem::*; + +macro_rules! size_test { + ($name:ident, $kem:ty, $pk:expr, $sk:expr, $ct:expr) => { + #[test] + fn $name() { + let mut rng = rand::rng(); + let (ek, dk) = <$kem>::generate_key(&mut rng); + assert_eq!(ek.as_ref().len(), $pk, "PK size mismatch"); + assert_eq!(dk.as_ref().len(), $sk, "SK size mismatch"); + + let (ct, ss) = ek.encapsulate(&mut rng); + assert_eq!(ct.as_ref().len(), $ct, "CT size mismatch"); + assert_eq!(ss.as_ref().len(), 32, "SS size mismatch"); + } + }; +} + +size_test!(sizes_653, Sntrup653, 994, 1518, 897); +size_test!(sizes_761, Sntrup761, 1158, 1763, 1039); +size_test!(sizes_857, Sntrup857, 1322, 1999, 1184); +size_test!(sizes_953, Sntrup953, 1505, 2254, 1349); +size_test!(sizes_1013, Sntrup1013, 1623, 2417, 1455); +size_test!(sizes_1277, Sntrup1277, 2067, 3059, 1847); From ec4e609e50f5515c39dcd7aff6c1c51e39150c1c Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Fri, 12 Jun 2026 14:36:20 -0600 Subject: [PATCH 3/6] fix versioning for CI Signed-off-by: Mike Lodder --- Cargo.lock | 86 +++---------------------------------------- sntrup-kem/Cargo.toml | 2 +- 2 files changed, 6 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b113a8b..2b8ee27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,15 +22,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "alloca" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" -dependencies = [ - "cc", -] - [[package]] name = "anes" version = "0.1.6" @@ -246,35 +237,10 @@ dependencies = [ "cast", "ciborium", "clap", - "criterion-plot 0.6.0", - "itertools", - "num-traits", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_json", - "tinytemplate", - "walkdir", -] - -[[package]] -name = "criterion" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" -dependencies = [ - "alloca", - "anes", - "cast", - "ciborium", - "clap", - "criterion-plot 0.8.2", + "criterion-plot", "itertools", "num-traits", "oorandom", - "page_size", "plotters", "rayon", "regex", @@ -294,16 +260,6 @@ dependencies = [ "itertools", ] -[[package]] -name = "criterion-plot" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" -dependencies = [ - "cast", - "itertools", -] - [[package]] name = "critical-section" version = "1.2.0" @@ -523,7 +479,7 @@ version = "0.1.0" dependencies = [ "aes", "chacha20", - "criterion 0.7.0", + "criterion", "getrandom", "hex", "hybrid-array", @@ -719,7 +675,7 @@ name = "hqc-kem" version = "0.1.0" dependencies = [ "const-oid", - "criterion 0.7.0", + "criterion", "hex", "hybrid-array", "kem", @@ -869,7 +825,7 @@ name = "ml-kem" version = "0.3.2" dependencies = [ "const-oid", - "criterion 0.7.0", + "criterion", "getrandom", "hex", "hex-literal", @@ -995,16 +951,6 @@ dependencies = [ "primeorder", ] -[[package]] -name = "page_size" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "pem-rfc7468" version = "1.0.0" @@ -1462,7 +1408,7 @@ checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" name = "sntrup-kem" version = "0.1.0" dependencies = [ - "criterion 0.8.2", + "criterion", "getrandom", "hex", "rand", @@ -1752,22 +1698,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - [[package]] name = "winapi-util" version = "0.1.11" @@ -1777,12 +1707,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-link" version = "0.2.1" diff --git a/sntrup-kem/Cargo.toml b/sntrup-kem/Cargo.toml index 143e56b..bbd4bc2 100644 --- a/sntrup-kem/Cargo.toml +++ b/sntrup-kem/Cargo.toml @@ -37,7 +37,7 @@ thiserror = "2.0" zeroize = { version = "1", features = ["derive"] } [dev-dependencies] -criterion = { version = "0.8", features = ["html_reports"] } +criterion = "0.7" serde_json = "1" [[bench]] From a2cfc78d615b4cb83019ebfdbad1227e7b10a3ec Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Fri, 12 Jun 2026 14:43:05 -0600 Subject: [PATCH 4/6] more CI fixes Signed-off-by: Mike Lodder --- sntrup-kem/src/r3/vector.rs | 2 +- sntrup-kem/src/rq/vector.rs | 2 +- sntrup-kem/src/zx.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sntrup-kem/src/r3/vector.rs b/sntrup-kem/src/r3/vector.rs index deaaf36..6747cd1 100644 --- a/sntrup-kem/src/r3/vector.rs +++ b/sntrup-kem/src/r3/vector.rs @@ -196,7 +196,7 @@ unsafe fn product_neon(z: &mut [i8], n: usize, x: &[i8], c: i8) { } } -/// Fused minus_product and shift: z[i+1] = freeze(z[i] - y[i]*c), z[0] = 0. +/// Fused minus_product and shift: `z[i+1] = freeze(z[i] - y[i]*c)`, `z[0] = 0`. /// Processes backward to avoid overwrite conflicts, eliminating a separate memmove. #[inline(always)] pub fn minus_product_shift(z: &mut [i8], n: usize, y: &[i8], c: i8) { diff --git a/sntrup-kem/src/rq/vector.rs b/sntrup-kem/src/rq/vector.rs index 59909cc..d63ff7a 100644 --- a/sntrup-kem/src/rq/vector.rs +++ b/sntrup-kem/src/rq/vector.rs @@ -104,7 +104,7 @@ pub fn product(z: &mut [i16], n: usize, x: &[i16], c: i16, q: i32, b1: i32, b2: } } -/// Fused minus_product and shift: z[i+1] = freeze(z[i] - y[i]*c), z[0] = 0. +/// Fused minus_product and shift: `z[i+1] = freeze(z[i] - y[i]*c)`, `z[0] = 0`. /// Processes backward to avoid overwrite conflicts, eliminating a separate memmove. #[inline(always)] pub fn minus_product_shift(z: &mut [i16], n: usize, y: &[i16], c: i16, q: i32, b1: i32, b2: i32) { diff --git a/sntrup-kem/src/zx.rs b/sntrup-kem/src/zx.rs index c0c0989..58d40b6 100644 --- a/sntrup-kem/src/zx.rs +++ b/sntrup-kem/src/zx.rs @@ -145,7 +145,7 @@ pub mod random { } } - /// Process one pass of comparators: minmax(x[i+off0], x[i+off1]) + /// Process one pass of comparators: `minmax(x[i+off0], x[i+off1])` /// for all i in 0..(n-off1) where i & p_mask == 0. #[cfg(all( target_arch = "x86_64", From 1f28978b377ec82b34eb20468cc4fba21bbdcd08 Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Fri, 12 Jun 2026 15:20:09 -0600 Subject: [PATCH 5/6] fix flake Signed-off-by: Mike Lodder --- sntrup-kem/src/lib.rs | 14 ++++++++++++++ sntrup-kem/src/params.rs | 5 ++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sntrup-kem/src/lib.rs b/sntrup-kem/src/lib.rs index f250970..14bb092 100644 --- a/sntrup-kem/src/lib.rs +++ b/sntrup-kem/src/lib.rs @@ -8,6 +8,8 @@ //! # Usage //! //! ```rust +//! # #[cfg(all(feature = "kgen", feature = "ecap", feature = "dcap"))] +//! # { //! use sntrup_kem::{Sntrup761, SntrupKem}; //! //! let mut rng = rand::rng(); @@ -15,6 +17,7 @@ //! let (ct, ss1) = ek.encapsulate(&mut rng); //! let ss2 = dk.decapsulate(&ct); //! assert_eq!(ss1, ss2); +//! # } //! ``` //! //! # Security Levels @@ -44,6 +47,17 @@ //! - `dcap`: Decapsulation (default) //! - `serde`: Serde serialization support via `serdect` +// The `kgen`/`ecap`/`dcap` features select which KEM operations are compiled. +// Building with a subset (or none) of them leaves some shared internal helpers +// (`ct`, `r3`, `rq`, `zx`, `utils`, and their imports) without a caller — that is +// expected, not a defect. Dead-code/unused-import enforcement is therefore scoped +// to the full-feature build (default + `--all-features`); partial builds tolerate +// the uncalled helpers so the crate stays warning-clean under `-D warnings`. +#![cfg_attr( + not(all(feature = "kgen", feature = "ecap", feature = "dcap")), + allow(dead_code, unused_imports) +)] + mod ct; mod error; mod kem; diff --git a/sntrup-kem/src/params.rs b/sntrup-kem/src/params.rs index ad385d4..900630c 100644 --- a/sntrup-kem/src/params.rs +++ b/sntrup-kem/src/params.rs @@ -88,7 +88,10 @@ pub(crate) const SNTRUP953: SntrupParameters = SntrupParameters { sk_size: 2254, ct_size: 1349, barrett1: 165, - barrett2: 42313, + // floor(2^28 / 6343). The previous value (42313) sat one below the valid + // Barrett window [42314, 42339], so ~0.7% of reductions misreduced — causing + // rare sntrup953 decapsulation/roundtrip failures on every platform. + barrett2: 42319, }; /// sntrup1013 parameters. From f92afd0090c8e021b82b663ad1019322a376f355 Mon Sep 17 00:00:00 2001 From: Mike Lodder Date: Fri, 12 Jun 2026 16:53:22 -0600 Subject: [PATCH 6/6] update README.md Signed-off-by: Mike Lodder --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 60b0f83..e76a404 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ commonly used in transport encryption protocols (e.g. [TLS]) and hybrid cryptosy |----------------------|---------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|--------------------| | [`dhkem`](./dhkem) | [![crates.io](https://img.shields.io/crates/v/dhkem.svg?logo=rust)](https://crates.io/crates/dhkem) | [![Documentation](https://docs.rs/dhkem/badge.svg)](https://docs.rs/dhkem) | Diffie-Hellman KEM | | [`frodo‑kem`](./frodo-kem) | [![crates.io](https://img.shields.io/crates/v/frodo-kem.svg?logo=rust)](https://crates.io/crates/frodo-kem) | [![Documentation](https://docs.rs/frodo-kem/badge.svg)](https://docs.rs/frodo-kem) | Frodo KEM | +| [`hqc‑kem`](./hqc-kem) | [![crates.io](https://img.shields.io/crates/v/hqc-kem.svg?logo=rust)](https://crates.io/crates/hqc-kem) | [![Documentation](https://docs.rs/hqc-kem/badge.svg)](https://docs.rs/hqc-kem) | HQC Code-based KEM | | [`ml‑kem`](./ml-kem) | [![crates.io](https://img.shields.io/crates/v/ml-kem.svg?logo=rust)](https://crates.io/crates/ml-kem) | [![Documentation](https://docs.rs/ml-kem/badge.svg)](https://docs.rs/ml-kem) | Module Lattice KEM | +| [`sntrup‑kem`](./sntrup-kem) | [![crates.io](https://img.shields.io/crates/v/sntrup-kem.svg?logo=rust)](https://crates.io/crates/sntrup-kem) | [![Documentation](https://docs.rs/sntrup-kem/badge.svg)](https://docs.rs/sntrup-kem) | Streamlined NTRU Prime KEM | | [`x‑wing`](./x-wing) | [![crates.io](https://img.shields.io/crates/v/x-wing.svg?logo=rust)](https://crates.io/crates/x-wing) | [![Documentation](https://docs.rs/x-wing/badge.svg)](https://docs.rs/x-wing) | Hybrid PQ KEM | ## License