diff --git a/.gitignore b/.gitignore index f03ee8a4..cec6dbbd 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ kite_sql_tpcc copy.csv tests/data/row_20000.csv -tests/data/distinct_rows.csv \ No newline at end of file +tests/data/distinct_rows.csv diff --git a/Cargo.lock b/Cargo.lock index 423bb6b0..ce235dbf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -995,36 +995,6 @@ dependencies = [ "slab", ] -[[package]] -name = "genawaiter" -version = "0.99.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c86bd0361bcbde39b13475e6e36cb24c329964aa2611be285289d1e4b751c1a0" -dependencies = [ - "genawaiter-macro", - "genawaiter-proc-macro", - "proc-macro-hack", -] - -[[package]] -name = "genawaiter-macro" -version = "0.99.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b32dfe1fdfc0bbde1f22a5da25355514b5e450c33a6af6770884c8750aedfbc" - -[[package]] -name = "genawaiter-proc-macro" -version = "0.99.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784f84eebc366e15251c4a8c3acee82a6a6f427949776ecb88377362a9621738" -dependencies = [ - "proc-macro-error", - "proc-macro-hack", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -1337,7 +1307,7 @@ dependencies = [ [[package]] name = "kite_sql" -version = "0.1.8" +version = "0.2.0" dependencies = [ "ahash 0.8.12", "async-trait", @@ -1353,13 +1323,14 @@ dependencies = [ "env_logger", "fixedbitset", "futures", - "genawaiter", "getrandom 0.2.16", "getrandom 0.3.3", "indicatif", "itertools 0.12.1", "js-sys", "kite_sql_serde_macros", + "lmdb", + "lmdb-sys", "log", "once_cell", "ordered-float", @@ -1389,7 +1360,7 @@ dependencies = [ [[package]] name = "kite_sql_serde_macros" -version = "0.1.2" +version = "0.2.0" dependencies = [ "darling", "proc-macro2", @@ -1503,6 +1474,28 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" +[[package]] +name = "lmdb" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0908efb5d6496aa977d96f91413da2635a902e5e31dbef0bfb88986c248539" +dependencies = [ + "bitflags 1.3.2", + "libc", + "lmdb-sys", +] + +[[package]] +name = "lmdb-sys" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5b392838cfe8858e86fac37cf97a0e8c55cc60ba0a18365cadc33092f128ce9" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "lock_api" version = "0.4.13" @@ -1942,38 +1935,6 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "syn-mid", - "version_check", -] - -[[package]] -name = "proc-macro-hack" -version = "0.5.20+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" - [[package]] name = "proc-macro2" version = "1.0.95" @@ -2740,17 +2701,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "syn-mid" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea305d57546cc8cd04feb14b62ec84bf17f50e3f7b12560d7bfa9265f39d9ed" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "tap" version = "1.0.1" @@ -2939,6 +2889,7 @@ dependencies = [ "indicatif", 
"kite_sql", "ordered-float", + "pprof", "rand 0.8.5", "rust_decimal", "sqlite", diff --git a/Cargo.toml b/Cargo.toml index dfd44da9..0677305a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "kite_sql" -version = "0.1.8" +version = "0.2.0" edition = "2021" authors = ["Kould ", "Xwg "] description = "SQL as a Function for Rust" @@ -17,17 +17,19 @@ default-run = "kite_sql" [[bin]] name = "kite_sql" path = "src/bin/server.rs" -required-features = ["net"] +required-features = ["net", "rocksdb"] [lib] doctest = false crate-type = ["cdylib", "rlib"] [features] -default = ["macros"] +default = ["macros", "rocksdb"] macros = [] orm = [] -net = ["dep:pgwire", "dep:async-trait", "dep:clap", "dep:env_logger", "dep:futures", "dep:log", "dep:tokio"] +rocksdb = ["dep:rocksdb"] +lmdb = ["dep:lmdb", "dep:lmdb-sys"] +net = ["rocksdb", "dep:pgwire", "dep:async-trait", "dep:clap", "dep:env_logger", "dep:futures", "dep:log", "dep:tokio"] pprof = ["pprof/criterion", "pprof/flamegraph"] python = ["dep:pyo3"] @@ -55,13 +57,12 @@ recursive = { version = "0.1" } regex = { version = "1" } rust_decimal = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } -kite_sql_serde_macros = { version = "0.1.2", path = "kite_sql_serde_macros" } +kite_sql_serde_macros = { version = "0.2.0", path = "kite_sql_serde_macros" } siphasher = { version = "1", features = ["serde"] } sqlparser = { version = "0.61", features = ["serde"] } thiserror = { version = "1" } typetag = { version = "0.2" } ulid = { version = "1", features = ["serde"] } -genawaiter = { version = "0.99" } # Feature: net async-trait = { version = "0.1", optional = true } @@ -84,7 +85,9 @@ tempfile = { version = "3.10" } sqlite = { version = "0.34" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -rocksdb = { version = "0.23" } +rocksdb = { version = "0.23", optional = true } +lmdb = { version = "0.8.0", optional = true } +lmdb-sys = { version = "0.8.0", optional = true } 
[target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = { version = "0.2.106" } diff --git a/Makefile b/Makefile index 12639b6a..a2e5e467 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,12 @@ CARGO ?= cargo WASM_PACK ?= wasm-pack SQLLOGIC_PATH ?= tests/slt/**/*.slt PYO3_PYTHON ?= /usr/bin/python3.12 +TPCC_MEASURE_TIME ?= 15 +TPCC_NUM_WARE ?= 1 +TPCC_PPROF_OUTPUT ?= /tmp/tpcc_lmdb.svg +TPCC_SQLITE_PROFILE ?= balanced -.PHONY: test test-python test-wasm test-slt test-all wasm-build check tpcc tpcc-dual cargo-check build wasm-examples native-examples fmt clippy +.PHONY: test test-python test-wasm test-slt test-all wasm-build check tpcc tpcc-kitesql-rocksdb tpcc-kitesql-lmdb tpcc-lmdb-flamegraph tpcc-sqlite tpcc-sqlite-practical tpcc-sqlite-balanced tpcc-dual cargo-check build wasm-examples native-examples fmt clippy ## Run default Rust tests in the current environment (non-WASM). test: @@ -48,9 +52,31 @@ clippy: ## Run formatting (check mode) and clippy linting together. check: fmt clippy -## Execute the TPCC workload example as a standalone command. -tpcc: - $(CARGO) run -p tpcc --release +tpcc: tpcc-kitesql-lmdb + +## Execute the TPCC workload on KiteSQL with RocksDB storage. +tpcc-kitesql-rocksdb: + $(CARGO) run -p tpcc --release -- --backend kitesql-rocksdb + +## Execute the TPCC workload on KiteSQL with LMDB storage. +tpcc-kitesql-lmdb: + $(CARGO) run -p tpcc --release -- --backend kitesql-lmdb + +## Execute TPCC on LMDB and emit a pprof flamegraph SVG. +tpcc-lmdb-flamegraph: + CARGO_PROFILE_RELEASE_DEBUG=true $(CARGO) run -p tpcc --release --features pprof -- --backend kitesql-lmdb --measure-time $(TPCC_MEASURE_TIME) --num-ware $(TPCC_NUM_WARE) --pprof-output $(TPCC_PPROF_OUTPUT) + +## Execute the TPCC workload on SQLite with the profile selected by TPCC_SQLITE_PROFILE (default: balanced). +tpcc-sqlite: + $(CARGO) run -p tpcc --release -- --backend sqlite --sqlite-profile $(TPCC_SQLITE_PROFILE) --path kite_sql_tpcc.sqlite + +## Execute the TPCC workload on SQLite with the practical profile. 
+tpcc-sqlite-practical: + $(MAKE) tpcc-sqlite TPCC_SQLITE_PROFILE=practical + +## Execute the TPCC workload on SQLite with the balanced profile. +tpcc-sqlite-balanced: + $(MAKE) tpcc-sqlite TPCC_SQLITE_PROFILE=balanced ## Execute TPCC while mirroring every statement to an in-memory SQLite instance for validation. tpcc-dual: diff --git a/README.md b/README.md index b3c0f616..ff6898bf 100755 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ ## Introduction **KiteSQL** is a lightweight embedded relational database for Rust, inspired by **MyRocks** and **SQLite** and fully written in Rust. It is designed to work not only as a SQL engine, but also as a Rust-native data API that can be embedded directly into applications without relying on external services or heavyweight infrastructure. -KiteSQL supports direct SQL execution, typed ORM models, schema migration, and builder-style queries, so you can combine relational power with an API surface that feels natural in Rust. +KiteSQL supports direct SQL execution, typed ORM models, schema migration, and builder-style queries, so you can combine relational power with an API surface that feels natural in Rust. On native targets, KiteSQL ships with both RocksDB-backed and LMDB-backed persistent storage builders, plus an in-memory builder for tests and temporary workloads. ## Key Features - A lightweight embedded SQL database fully rewritten in Rust @@ -79,7 +79,8 @@ struct UserSummary { } fn main() -> Result<(), DatabaseError> { - let database = DataBaseBuilder::path("./data").build()?; + let database = DataBaseBuilder::path("./data").build_rocksdb()?; + // Or: let database = DataBaseBuilder::path("./data").build_lmdb()?; database.migrate::()?; @@ -128,8 +129,20 @@ fn main() -> Result<(), DatabaseError> { } ``` +## Storage Backends +- `build_rocksdb()` opens a persistent RocksDB-backed database. +- `build_lmdb()` opens a persistent LMDB-backed database. 
+- `build_in_memory()` opens an in-memory database for tests, examples, and temporary workloads. +- `build_optimistic()` is available on native targets when you specifically want optimistic transactions on top of RocksDB. +- Cargo features: + - `rocksdb` is enabled by default + - `lmdb` is optional + - `cargo check --no-default-features --features lmdb` builds an LMDB-only native configuration + +On native targets, `LMDB` shines when reads dominate, while `RocksDB` is usually the stronger choice when writes do. + 👉**more examples** -- [hello_word](examples/hello_world.rs) +- [hello_world](examples/hello_world.rs) - [transaction](examples/transaction.rs) @@ -149,7 +162,7 @@ console.log(rows.map((r) => r.values.map((v) => v.Int32 ?? v))); ## Python (PyO3) - Enable bindings with Cargo feature `python`. -- Constructor is explicit: `Database(path)`; in-memory usage is `Database.in_memory()`. +- Constructor is explicit: `Database(path, backend="rocksdb")`; use `backend="lmdb"` to open LMDB. In-memory usage is `Database.in_memory()`. - Minimal usage: ```python import kite_sql @@ -162,7 +175,7 @@ for row in db.run("select * from demo"): ``` ## TPC-C -Run `make tpcc` (or `cargo run -p tpcc --release`) to execute the benchmark against the default KiteSQL storage. +Run `make tpcc` (or `cargo run -p tpcc --release`) to execute the benchmark against the default KiteSQL storage. Use `--backend rocksdb` or `--backend lmdb` to compare the two persistent backends directly. Run `make tpcc-dual` to mirror every TPCC statement to an in-memory SQLite database alongside KiteSQL and assert the two engines return identical results; this target runs for 60 seconds (`--measure-time 60`). Use `cargo run -p tpcc --release -- --backend dual --measure-time ` for a custom duration. 
- i9-13900HX @@ -170,17 +183,16 @@ Run `make tpcc-dual` to mirror every TPCC statement to an in-memory SQLite datab - KIOXIA-EXCERIA PLUS G3 SSD - Tips: TPC-C currently only supports single thread -All cases have been fully optimized. -```shell -<90th Percentile RT (MaxRT)> - New-Order : 0.002 (0.005) - Payment : 0.001 (0.013) -Order-Status : 0.002 (0.006) - Delivery : 0.010 (0.023) - Stock-Level : 0.002 (0.017) - -27226 Tpmc -``` +Recent 720-second local comparison on the machine above: + +| Backend | TpmC | New-Order p90 | Payment p90 | Order-Status p90 | Delivery p90 | Stock-Level p90 | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| KiteSQL LMDB | 53510 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | +| KiteSQL RocksDB | 32248 | 0.001s | 0.001s | 0.002s | 0.011s | 0.003s | +| SQLite balanced | 36273 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| SQLite practical | 35516 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | + +The detailed raw outputs for both runs are recorded in [tpcc/README.md](tpcc/README.md). #### 👉[check more](tpcc/README.md) ## Roadmap diff --git a/benchmarks/query_benchmark.rs b/benchmarks/query_benchmark.rs index 3d730b33..a3118747 100644 --- a/benchmarks/query_benchmark.rs +++ b/benchmarks/query_benchmark.rs @@ -14,7 +14,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use indicatif::{ProgressBar, ProgressStyle}; -use kite_sql::db::{DataBaseBuilder, ResultIter}; +use kite_sql::db::DataBaseBuilder; use kite_sql::errors::DatabaseError; #[cfg(unix)] use pprof::criterion::{Output, PProfProfiler}; @@ -38,7 +38,7 @@ fn query_cases() -> Vec<(&'static str, &'static str)> { } fn init_kitesql_query_bench() -> Result<(), DatabaseError> { - let database = DataBaseBuilder::path(QUERY_BENCH_KITE_SQL_PATH).build()?; + let database = DataBaseBuilder::path(QUERY_BENCH_KITE_SQL_PATH).build_rocksdb()?; database .run("create table t1 (c1 int primary key, c2 int)")? 
.done()?; @@ -104,7 +104,7 @@ fn query_on_execute(c: &mut Criterion) { init_kitesql_query_bench().unwrap(); } let database = DataBaseBuilder::path(QUERY_BENCH_KITE_SQL_PATH) - .build() + .build_rocksdb() .unwrap(); println!("Table initialization completed"); diff --git a/examples/hello_world.rs b/examples/hello_world.rs index 4bb56224..7b7a5bd0 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -14,20 +14,29 @@ #[cfg(all(not(target_arch = "wasm32"), feature = "orm"))] mod app { - use kite_sql::db::{DataBaseBuilder, ResultIter}; + use kite_sql::db::{DataBaseBuilder, Database}; use kite_sql::errors::DatabaseError; + use kite_sql::storage::Storage; use kite_sql::Model; + use std::env; use std::fs; use std::io::ErrorKind; + use std::path::Path; const EXAMPLE_DB_PATH: &str = "./example_data/hello_world"; fn reset_example_dir() -> Result<(), DatabaseError> { - match fs::remove_dir_all(EXAMPLE_DB_PATH) { - Ok(()) => Ok(()), - Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), - Err(err) => Err(err.into()), + if let Err(err) = fs::remove_dir_all(EXAMPLE_DB_PATH) { + if err.kind() != ErrorKind::NotFound { + return Err(err.into()); + } } + + if let Some(parent) = Path::new(EXAMPLE_DB_PATH).parent() { + fs::create_dir_all(parent)?; + } + + Ok(()) } #[derive(Default, Debug, PartialEq, Model)] @@ -38,10 +47,7 @@ mod app { pub c2: String, } - pub fn run() -> Result<(), DatabaseError> { - reset_example_dir()?; - let database = DataBaseBuilder::path(EXAMPLE_DB_PATH).build()?; - + fn run_with_database(database: Database) -> Result<(), DatabaseError> { database.create_table_if_not_exists::()?; database.insert(&MyStruct { c1: 0, @@ -78,6 +84,29 @@ mod app { Ok(()) } + + pub fn run() -> Result<(), DatabaseError> { + reset_example_dir()?; + let backend = env::var("KITESQL_BACKEND").unwrap_or_else(|_| "rocksdb".to_string()); + + match backend.to_ascii_lowercase().as_str() { + #[cfg(feature = "rocksdb")] + "rocksdb" => 
run_with_database(DataBaseBuilder::path(EXAMPLE_DB_PATH).build_rocksdb()?), + #[cfg(feature = "lmdb")] + "lmdb" => run_with_database(DataBaseBuilder::path(EXAMPLE_DB_PATH).build_lmdb()?), + other => Err(DatabaseError::InvalidValue(format!( + "unsupported example backend '{other}', expected {}", + { + let mut expected = Vec::new(); + #[cfg(feature = "rocksdb")] + expected.push("rocksdb"); + #[cfg(feature = "lmdb")] + expected.push("lmdb"); + expected.join(" or ") + } + ))), + } + } } #[cfg(target_arch = "wasm32")] diff --git a/examples/transaction.rs b/examples/transaction.rs index 89e75c50..6e0db24e 100644 --- a/examples/transaction.rs +++ b/examples/transaction.rs @@ -14,25 +14,33 @@ #[cfg(not(target_arch = "wasm32"))] mod app { - use kite_sql::db::{DataBaseBuilder, ResultIter}; + use kite_sql::db::DataBaseBuilder; use kite_sql::errors::DatabaseError; use kite_sql::types::tuple::Tuple; use kite_sql::types::value::DataValue; use std::fs; use std::io::ErrorKind; + use std::path::Path; const EXAMPLE_DB_PATH: &str = "./example_data/transaction"; fn reset_example_dir() -> Result<(), DatabaseError> { - match fs::remove_dir_all(EXAMPLE_DB_PATH) { - Ok(()) => Ok(()), - Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), - Err(err) => Err(err.into()), + if let Err(err) = fs::remove_dir_all(EXAMPLE_DB_PATH) { + if err.kind() != ErrorKind::NotFound { + return Err(err.into()); + } } + + if let Some(parent) = Path::new(EXAMPLE_DB_PATH).parent() { + fs::create_dir_all(parent)?; + } + + Ok(()) } pub fn run() -> Result<(), DatabaseError> { reset_example_dir()?; + // Optimistic transactions are currently backed by RocksDB. let database = DataBaseBuilder::path(EXAMPLE_DB_PATH).build_optimistic()?; database .run("create table if not exists t1 (c1 int primary key, c2 int)")? 
diff --git a/kite_sql_serde_macros/Cargo.toml b/kite_sql_serde_macros/Cargo.toml index f0220c51..e590741e 100644 --- a/kite_sql_serde_macros/Cargo.toml +++ b/kite_sql_serde_macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kite_sql_serde_macros" -version = "0.1.2" +version = "0.2.0" edition = "2021" description = "Derive macros for KiteSQL" documentation = "https://docs.rs/kite_sql_serde_macros/latest/kite_sql_serde_macros/" diff --git a/kite_sql_serde_macros/src/orm.rs b/kite_sql_serde_macros/src/orm.rs index 4b885322..40b04056 100644 --- a/kite_sql_serde_macros/src/orm.rs +++ b/kite_sql_serde_macros/src/orm.rs @@ -176,7 +176,6 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { primary_key_value = Some(quote! { &self.#field_name }); - } else { } generics @@ -257,11 +256,11 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { } }); if is_unique { - let unique_index_name_value = format!("uk_{}_index", column_name); + let unique_index_name_value = format!("uk_{column_name}_index"); if !index_names.insert(unique_index_name_value.clone()) { return Err(Error::new_spanned( struct_name, - format!("duplicate ORM index name: {}", unique_index_name_value), + format!("duplicate ORM index name: {unique_index_name_value}"), )); } } @@ -272,7 +271,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { if !index_names.insert(index_name.clone()) { return Err(Error::new_spanned( struct_name, - format!("duplicate ORM index name: {}", index_name), + format!("duplicate ORM index name: {index_name}"), )); } create_index_statements.push(quote! 
{ @@ -325,8 +324,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { Error::new_spanned( struct_name, format!( - "unknown ORM index column `{}`; use a persisted field name or column name", - raw_column + "unknown ORM index column `{raw_column}`; use a persisted field name or column name", ), ) })?; diff --git a/scripts/run_tpcc_matrix.sh b/scripts/run_tpcc_matrix.sh new file mode 100755 index 00000000..444e8c8e --- /dev/null +++ b/scripts/run_tpcc_matrix.sh @@ -0,0 +1,177 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +NUM_WARE="${TPCC_NUM_WARE:-1}" +MAX_DUPLICATE_RETRY="${TPCC_DUPLICATE_RETRY:-1}" +MAIN_MEASURE_TIME="${TPCC_MAIN_MEASURE_TIME:-}" +STAMP="${TPCC_RESULT_STAMP:-$(date +%Y-%m-%d_%H-%M-%S)}" +RESULT_DIR="${TPCC_RESULT_DIR:-$ROOT_DIR/tpcc/results/$STAMP}" +LOG_DIR="$RESULT_DIR/logs" +TMP_DIR="$ROOT_DIR/target/tpcc-run-data" +BINARY="$ROOT_DIR/target/release/tpcc" +SUMMARY_FILE="$RESULT_DIR/summary.md" + +mkdir -p "$LOG_DIR" "$TMP_DIR" + +if [[ ! 
-x "$BINARY" ]]; then + echo "missing binary: $BINARY" >&2 + echo "build it first with: cargo build -p tpcc --release" >&2 + exit 1 +fi + +extract_tpmc() { + local log_file="$1" + awk '/<TpmC>/{getline; print $1; exit}' "$log_file" +} + +extract_p90() { + local log_file="$1" + local label="$2" + awk -v label="$label" ' + /<90th Percentile RT \(MaxRT\)>/ { in_block = 1; next } + in_block && index($0, label) { + gsub(/^[[:space:]]+/, "", $0) + print $3 + exit + } + ' "$log_file" +} + +should_retry_duplicate() { + local log_file="$1" + rg -q "UNIQUE constraint failed|duplicate key|primary key|Duplicate" "$log_file" +} + +run_variant() { + local name="$1" + local measure_label="$2" + local db_path="$3" + shift 3 + local -a cmd=("$@") + local log_file="$LOG_DIR/$name.log" + local status="ok" + local notes="-" + local attempts=0 + local max_attempts=$((MAX_DUPLICATE_RETRY + 1)) + + : > "$log_file" + + while (( attempts < max_attempts )); do + attempts=$((attempts + 1)) + rm -rf "$db_path" + + { + printf '## Attempt %s\n' "$attempts" + printf '$' + printf ' %q' "${cmd[@]}" + printf '\n\n' + } >> "$log_file" + + set +e + "${cmd[@]}" >> "$log_file" 2>&1 + local cmd_status=$?
+ set -e + + if [[ "$cmd_status" -eq 0 ]]; then + break + fi + + if (( attempts < max_attempts )) && should_retry_duplicate "$log_file"; then + notes="retry after duplicate-key failure" + printf '\n[runner] duplicate-key style failure detected, retrying %s from scratch\n\n' "$name" >> "$log_file" + continue + fi + + status="failed" + notes="$(tail -n 5 "$log_file" | tr '\n' ' ' | sed 's/[[:space:]]\\+/ /g; s/^ //; s/ $//')" + break + done + + local tpmc="-" + local new_order="-" + local payment="-" + local order_status="-" + local delivery="-" + local stock_level="-" + + if [[ "$status" == "ok" ]]; then + tpmc="$(extract_tpmc "$log_file" || echo -)" + new_order="$(extract_p90 "$log_file" "New-Order" || echo -)" + payment="$(extract_p90 "$log_file" "Payment" || echo -)" + order_status="$(extract_p90 "$log_file" "Order-Status" || echo -)" + delivery="$(extract_p90 "$log_file" "Delivery" || echo -)" + stock_level="$(extract_p90 "$log_file" "Stock-Level" || echo -)" + fi + + printf '| %s | %s | %s | %s | %s | %s | %s | %s | %s | %s | %s | [%s](./logs/%s.log) |\n' \ + "$name" \ + "$status" \ + "$attempts" \ + "$measure_label" \ + "$tpmc" \ + "$new_order" \ + "$payment" \ + "$order_status" \ + "$delivery" \ + "$stock_level" \ + "$notes" \ + "$name" \ + "$name" \ + >> "$SUMMARY_FILE" + + rm -rf "$db_path" + + cat "$log_file" +} + +cat > "$SUMMARY_FILE" < + Send) -> Result { - let database = DataBaseBuilder::path(path).build()?; + let database = DataBaseBuilder::path(path).build_rocksdb()?; Ok(KiteSQLBackend { inner: Arc::new(database), @@ -213,29 +214,22 @@ impl SimpleQueryHandler for SessionBackend { _ => { let mut guard = self.tx.lock(); - let mut tuples = Vec::new(); let response = if let Some(transaction) = guard.as_mut() { let mut iter = unsafe { transaction.as_mut().run(query) } .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - for tuple in iter.by_ref() { - tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?); - } - let schema = 
iter.schema().clone(); + let response = encode_query_result(&mut iter)?; iter.done() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - encode_tuples(&schema, tuples)? + response } else { let mut iter = self .inner .run(query) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - for tuple in iter.by_ref() { - tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?); - } - let schema = iter.schema().clone(); + let response = encode_query_result(&mut iter)?; iter.done() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - encode_tuples(&schema, tuples)? + response }; Ok(vec![Response::Query(response)]) } @@ -243,13 +237,25 @@ impl SimpleQueryHandler for SessionBackend { } } -fn encode_tuples<'a>(schema: &SchemaRef, tuples: Vec) -> PgWireResult> { - if tuples.is_empty() { - return Ok(QueryResponse::new(Arc::new(vec![]), stream::empty())); +fn encode_query_result<'a, I>(iter: &mut I) -> PgWireResult> +where + I: BorrowResultIter, +{ + let fields = encode_fields(iter.schema())?; + let mut results = Vec::new(); + + while let Some(tuple) = iter + .next_borrowed_tuple() + .map_err(|e| PgWireError::ApiError(Box::new(e)))? 
+ { + results.push(encode_tuple(fields.clone(), tuple)); } - let mut results = Vec::with_capacity(tuples.len()); - let schema = Arc::new( + Ok(QueryResponse::new(fields, stream::iter(results))) +} + +fn encode_fields(schema: &SchemaRef) -> PgWireResult>> { + Ok(Arc::new( schema .iter() .map(|column| { @@ -264,41 +270,37 @@ fn encode_tuples<'a>(schema: &SchemaRef, tuples: Vec) -> PgWireResult>>()?, - ); - - for tuple in tuples { - let mut encoder = DataRowEncoder::new(schema.clone()); - for value in tuple.values { - match value.logical_type() { - LogicalType::SqlNull => encoder.encode_field(&None::), - LogicalType::Boolean => encoder.encode_field(&value.bool()), - LogicalType::Tinyint => encoder.encode_field(&value.i8()), - LogicalType::UTinyint => encoder.encode_field(&value.u8().map(|v| v as i8)), - LogicalType::Smallint => encoder.encode_field(&value.i16()), - LogicalType::USmallint => encoder.encode_field(&value.u16().map(|v| v as i16)), - LogicalType::Integer => encoder.encode_field(&value.i32()), - LogicalType::UInteger => encoder.encode_field(&value.u32()), - LogicalType::Bigint => encoder.encode_field(&value.i64()), - LogicalType::UBigint => encoder.encode_field(&value.u64().map(|v| v as i64)), - LogicalType::Float => encoder.encode_field(&value.float()), - LogicalType::Double => encoder.encode_field(&value.double()), - LogicalType::Char(..) | LogicalType::Varchar(..) 
=> { - encoder.encode_field(&value.utf8()) - } - LogicalType::Date => encoder.encode_field(&value.date()), - LogicalType::DateTime => encoder.encode_field(&value.datetime()), - LogicalType::Time(_) => encoder.encode_field(&value.time()), - LogicalType::Decimal(_, _) => { - encoder.encode_field(&value.decimal().map(|decimal| decimal.to_string())) - } - _ => unreachable!(), - }?; - } + )) +} - results.push(encoder.finish()); +fn encode_tuple(schema: Arc>, tuple: &Tuple) -> PgWireResult { + let mut encoder = DataRowEncoder::new(schema); + for value in &tuple.values { + match value.logical_type() { + LogicalType::SqlNull => encoder.encode_field(&None::), + LogicalType::Boolean => encoder.encode_field(&value.bool()), + LogicalType::Tinyint => encoder.encode_field(&value.i8()), + LogicalType::UTinyint => encoder.encode_field(&value.u8().map(|v| v as i8)), + LogicalType::Smallint => encoder.encode_field(&value.i16()), + LogicalType::USmallint => encoder.encode_field(&value.u16().map(|v| v as i16)), + LogicalType::Integer => encoder.encode_field(&value.i32()), + LogicalType::UInteger => encoder.encode_field(&value.u32()), + LogicalType::Bigint => encoder.encode_field(&value.i64()), + LogicalType::UBigint => encoder.encode_field(&value.u64().map(|v| v as i64)), + LogicalType::Float => encoder.encode_field(&value.float()), + LogicalType::Double => encoder.encode_field(&value.double()), + LogicalType::Char(..) | LogicalType::Varchar(..) 
=> encoder.encode_field(&value.utf8()), + LogicalType::Date => encoder.encode_field(&value.date()), + LogicalType::DateTime => encoder.encode_field(&value.datetime()), + LogicalType::Time(_) => encoder.encode_field(&value.time()), + LogicalType::Decimal(_, _) => { + encoder.encode_field(&value.decimal().map(|decimal| decimal.to_string())) + } + _ => unreachable!(), + }?; } - Ok(QueryResponse::new(schema, stream::iter(results))) + encoder.finish() } fn into_pg_type(data_type: &LogicalType) -> PgWireResult { diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 2765e65a..f372b414 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use super::{Binder, QueryBindStep}; use crate::errors::DatabaseError; use crate::expression::function::scala::ScalarFunction; +use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::value::DataValue; @@ -34,10 +35,14 @@ impl> Binder<'_, '_, T, A> children: LogicalPlan, agg_calls: Vec, groupby_exprs: Vec, - ) -> LogicalPlan { + ) -> Result { self.context.step(QueryBindStep::Agg); - - AggregateOperator::build(children, agg_calls, groupby_exprs, false) + Ok(AggregateOperator::build( + children, + agg_calls, + groupby_exprs, + false, + )) } pub fn extract_select_aggregate( @@ -104,6 +109,18 @@ impl> Binder<'_, '_, T, A> Ok((return_having, return_orderby)) } + pub fn bind_aggregate_output_exprs<'c>( + &self, + exprs: impl IntoIterator, + ) -> Result<(), DatabaseError> { + let mut binder = + AggregateOutputBinder::new(&self.context.agg_calls, &self.context.group_by_exprs); + for expr in exprs { + binder.visit(expr)?; + } + Ok(()) + } + fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression) -> Result<(), DatabaseError> { match expr { ScalarExpression::AggCall { .. 
} => { @@ -452,3 +469,150 @@ impl> Binder<'_, '_, T, A> } } } + +struct AggregateOutputBinder<'a> { + agg_calls: &'a [ScalarExpression], + group_by_exprs: &'a [ScalarExpression], +} + +impl<'a> AggregateOutputBinder<'a> { + fn new(agg_calls: &'a [ScalarExpression], group_by_exprs: &'a [ScalarExpression]) -> Self { + Self { + agg_calls, + group_by_exprs, + } + } + + fn output_ref(&self, expr: &ScalarExpression) -> Option { + self.agg_calls + .iter() + .chain(self.group_by_exprs.iter()) + .position(|candidate| { + candidate == expr || candidate.unpack_alias_ref() == expr.unpack_alias_ref() + }) + .map(|position| { + let output_expr = self + .agg_calls + .iter() + .chain(self.group_by_exprs.iter()) + .nth(position) + .unwrap(); + ScalarExpression::column_expr(output_expr.output_column(), position) + }) + } +} + +impl<'a> VisitorMut<'a> for AggregateOutputBinder<'_> { + fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { + if let ScalarExpression::Alias { + expr: inner_expr, + alias: crate::expression::AliasType::Name(_), + } = expr + { + return self.visit(inner_expr); + } + + if let Some(output_ref) = self.output_ref(expr) { + *expr = output_ref; + return Ok(()); + } + walk_mut_expr(self, expr) + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::AggregateOutputBinder; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::errors::DatabaseError; + use crate::expression::agg::AggKind; + use crate::expression::visitor_mut::VisitorMut; + use crate::expression::{AliasType, ScalarExpression}; + use crate::types::LogicalType; + + fn test_column(name: &str, ty: LogicalType) -> ColumnRef { + ColumnRef::from(ColumnCatalog::new( + name.to_string(), + true, + ColumnDesc::new(ty, None, false, None).unwrap(), + )) + } + + fn test_count(expr: ScalarExpression) -> ScalarExpression { + ScalarExpression::AggCall { + distinct: false, + kind: AggKind::Count, + args: vec![expr], + ty: LogicalType::Bigint, + 
} + } + + #[test] + fn test_aggregate_output_binder_rewrites_agg_and_group_slots() -> Result<(), DatabaseError> { + let group_column = test_column("c1", LogicalType::Integer); + let agg_column = test_column("c2", LogicalType::Integer); + + let group_expr = ScalarExpression::column_expr(group_column, 0); + let agg_expr = test_count(ScalarExpression::column_expr(agg_column, 1)); + + let agg_output = ScalarExpression::Alias { + expr: Box::new(agg_expr.clone()), + alias: AliasType::Name("cnt".to_string()), + }; + let group_output = ScalarExpression::Alias { + expr: Box::new(group_expr.clone()), + alias: AliasType::Name("g".to_string()), + }; + + let mut binder = AggregateOutputBinder::new( + std::slice::from_ref(&agg_output), + std::slice::from_ref(&group_output), + ); + + let mut order_by_agg = ScalarExpression::Alias { + expr: Box::new(agg_expr), + alias: AliasType::Name("cnt".to_string()), + }; + binder.visit(&mut order_by_agg)?; + assert_eq!( + order_by_agg, + ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr(agg_output.output_column(), 0)), + alias: AliasType::Name("cnt".to_string()), + } + ); + + let mut order_by_group = group_expr; + binder.visit(&mut order_by_group)?; + assert_eq!( + order_by_group, + ScalarExpression::column_expr(group_output.output_column(), 1) + ); + + Ok(()) + } + + #[test] + fn test_aggregate_output_binder_matches_alias_expr_reference() -> Result<(), DatabaseError> { + let group_column = test_column("c1", LogicalType::Integer); + let group_expr = ScalarExpression::column_expr(group_column, 0); + let group_output = ScalarExpression::Alias { + expr: Box::new(group_expr.clone()), + alias: AliasType::Name("g".to_string()), + }; + + let mut binder = AggregateOutputBinder::new(&[], std::slice::from_ref(&group_output)); + let mut target = ScalarExpression::Alias { + expr: Box::new(ScalarExpression::Constant(1_i32.into())), + alias: AliasType::Expr(Box::new(group_expr)), + }; + + binder.visit(&mut target)?; + assert_eq!( + 
target, + ScalarExpression::column_expr(group_output.output_column(), 0) + ); + + Ok(()) + } +} diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs index 4d313161..8e823555 100644 --- a/src/binder/alter_table.rs +++ b/src/binder/alter_table.rs @@ -39,7 +39,7 @@ impl> Binder<'_, '_, T, A> ) -> Result { let mut expr = self.bind_expr(expr)?; - if !expr.referenced_columns(true).is_empty() { + if expr.any_referenced_column(true, |_| true) { return Err(DatabaseError::UnsupportedStmt( "column is not allowed to exist in default".to_string(), )); diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 515f58da..be6594ef 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -53,6 +53,11 @@ impl> Binder<'_, '_, T, A> let plan = match source { Source::Table(table) => TableScanOperator::build(table_name.clone(), table, true)?, Source::View(view) => LogicalPlan::clone(&view.plan), + Source::Schema(_) => { + return Err(DatabaseError::UnsupportedStmt( + "derived source cannot be rebound as a base relation".to_string(), + )) + } }; let mut columns = Vec::with_capacity(index_columns.len()); diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 4b22679d..4179c345 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -156,7 +156,7 @@ impl> Binder<'_, '_, T, A> ColumnOption::Default(expr) => { let mut expr = self.bind_expr(expr)?; - if !expr.referenced_columns(true).is_empty() { + if expr.any_referenced_column(true, |_| true) { return Err(DatabaseError::UnsupportedStmt( "column is not allowed to exist in `default`".to_string(), )); diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index 9d7b684f..9dada889 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -52,9 +52,10 @@ impl> Binder<'_, '_, T, A> column.set_ref_table(view_name.clone(), Ulid::new(), true); ScalarExpression::Alias { - expr: 
Box::new(ScalarExpression::column_expr(mapping_column.clone())), + expr: Box::new(ScalarExpression::column_expr(mapping_column.clone(), i)), alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( ColumnRef::from(column), + i, ))), } }) diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index a4766148..b52572c6 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -13,8 +13,11 @@ // limitations under the License. use crate::binder::{Binder, QueryBindStep}; +use crate::errors::DatabaseError; +use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; +use crate::planner::operator::sort::SortField; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::value::DataValue; @@ -24,9 +27,169 @@ impl> Binder<'_, '_, T, A> &mut self, children: LogicalPlan, select_list: Vec, - ) -> LogicalPlan { + ) -> Result { self.context.step(QueryBindStep::Distinct); - AggregateOperator::build(children, vec![], select_list, true) + Ok(AggregateOperator::build( + children, + vec![], + select_list, + true, + )) + } + + pub fn bind_distinct_output_exprs<'c>( + &self, + select_list: &[ScalarExpression], + exprs: impl IntoIterator, + ) -> Result<(), DatabaseError> { + let mut binder = DistinctOutputBinder::new(select_list); + for expr in exprs { + binder.visit(expr)?; + } + Ok(()) + } + + pub fn bind_distinct_orderby_exprs( + &self, + select_list: &[ScalarExpression], + orderby: &mut [SortField], + ) -> Result<(), DatabaseError> { + let binder = DistinctOutputBinder::new(select_list); + + for field in orderby { + field.expr = binder.output_ref(&field.expr).ok_or_else(|| { + DatabaseError::InvalidValue(format!( + "for SELECT DISTINCT, ORDER BY expressions must appear in select list: '{}'", + field.expr + )) + })?; + } + + Ok(()) + } +} + +struct DistinctOutputBinder<'a> { + select_list: &'a [ScalarExpression], +} + +impl<'a> 
DistinctOutputBinder<'a> { + fn new(select_list: &'a [ScalarExpression]) -> Self { + Self { select_list } + } + + fn output_ref(&self, expr: &ScalarExpression) -> Option { + self.select_list + .iter() + .position(|candidate| { + candidate == expr || candidate.unpack_alias_ref() == expr.unpack_alias_ref() + }) + .map(|position| { + let output_expr = &self.select_list[position]; + ScalarExpression::column_expr(output_expr.output_column(), position) + }) + } +} + +impl<'a> VisitorMut<'a> for DistinctOutputBinder<'_> { + fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { + if let ScalarExpression::Alias { + expr: inner_expr, + alias: crate::expression::AliasType::Name(_), + } = expr + { + return self.visit(inner_expr); + } + + if let Some(output_ref) = self.output_ref(expr) { + *expr = output_ref; + return Ok(()); + } + walk_mut_expr(self, expr) + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::DistinctOutputBinder; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::errors::DatabaseError; + use crate::expression::visitor_mut::VisitorMut; + use crate::expression::{AliasType, ScalarExpression}; + use crate::types::LogicalType; + + fn test_column(name: &str, ty: LogicalType) -> ColumnRef { + ColumnRef::from(ColumnCatalog::new( + name.to_string(), + true, + ColumnDesc::new(ty, None, false, None).unwrap(), + )) + } + + #[test] + fn test_distinct_output_binder_rewrites_output_slot() -> Result<(), DatabaseError> { + let left_column = test_column("c1", LogicalType::Integer); + let right_column = test_column("c2", LogicalType::Integer); + + let left_expr = ScalarExpression::column_expr(left_column, 0); + let right_expr = ScalarExpression::column_expr(right_column, 1); + let second_output = right_expr.clone(); + let select_output = ScalarExpression::Alias { + expr: Box::new(left_expr.clone()), + alias: AliasType::Name("v".to_string()), + }; + let select_list = [select_output.clone(), 
right_expr.clone()]; + + let mut binder = DistinctOutputBinder::new(&select_list); + + let mut order_by_alias = ScalarExpression::Alias { + expr: Box::new(left_expr), + alias: AliasType::Name("v".to_string()), + }; + binder.visit(&mut order_by_alias)?; + assert_eq!( + order_by_alias, + ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + select_output.output_column(), + 0 + )), + alias: AliasType::Name("v".to_string()), + } + ); + + let mut order_by_second = right_expr; + binder.visit(&mut order_by_second)?; + assert_eq!( + order_by_second, + ScalarExpression::column_expr(second_output.output_column(), 1) + ); + + Ok(()) + } + + #[test] + fn test_distinct_output_binder_matches_alias_expr_reference() -> Result<(), DatabaseError> { + let column = test_column("c1", LogicalType::Integer); + let expr = ScalarExpression::column_expr(column, 0); + let select_output = ScalarExpression::Alias { + expr: Box::new(expr.clone()), + alias: AliasType::Name("v".to_string()), + }; + + let mut binder = DistinctOutputBinder::new(std::slice::from_ref(&select_output)); + let mut target = ScalarExpression::Alias { + expr: Box::new(ScalarExpression::Constant(1_i32.into())), + alias: AliasType::Expr(Box::new(expr)), + }; + + binder.visit(&mut target)?; + assert_eq!( + target, + ScalarExpression::column_expr(select_output.output_column(), 0) + ); + + Ok(()) } } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index c8a3a239..5e6ea8c1 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -32,8 +32,10 @@ use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; use crate::expression::function::FunctionSummary; use crate::expression::{AliasType, ScalarExpression}; +use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; use crate::planner::{LogicalPlan, SchemaOutput}; use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; use 
crate::types::value::{DataValue, Utf8Type}; use crate::types::{ColumnId, LogicalType}; @@ -80,6 +82,51 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } } + fn find_column_in_schema<'b>( + schema_ref: &'b SchemaRef, + column_name: &str, + ) -> Option<(usize, &'b ColumnRef)> { + schema_ref + .iter() + .enumerate() + .find(|(_, column)| column.name() == column_name) + } + + fn find_column_in_scope( + context: &BinderContext<'a, T>, + table_schema_buf: &mut HashMap>, + column_name: &str, + ) -> Option { + let mut position_offset = 0; + + for bound_source in &context.bind_table { + let schema_buf = table_schema_buf + .entry(bound_source.table_name.clone()) + .or_default(); + let schema_ref = bound_source.source.schema_ref(schema_buf); + + if let Some((position, column)) = Self::find_column_in_schema(&schema_ref, column_name) + { + return Some(ScalarExpression::column_expr( + column.clone(), + position_offset + position, + )); + } + + position_offset += schema_ref.len(); + } + + None + } + + fn column_not_found_with_span(idents: &[Ident], column_name: &str) -> DatabaseError { + let err = DatabaseError::column_not_found(column_name.to_string()); + match idents.last() { + Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), + None => err, + } + } + pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result { match expr { Expr::Identifier(ident) => { @@ -212,6 +259,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } Expr::Subquery(subquery) => { let (sub_query, column, correlated) = self.bind_subquery(None, subquery)?; + let sub_query = ScalarSubqueryOperator::build(sub_query); let (expr, sub_query) = if !self.context.is_step(&QueryBindStep::Where) { self.bind_temp_table(column, sub_query)? 
} else { @@ -317,17 +365,43 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T expr: ScalarExpression, sub_query: LogicalPlan, ) -> Result<(ScalarExpression, LogicalPlan), DatabaseError> { + let (exprs, is_tuple) = match expr { + ScalarExpression::Tuple(exprs) => (exprs, true), + expr => (vec![expr], false), + }; + let mut alias_exprs = Vec::with_capacity(exprs.len()); + let mut alias_refs = Vec::with_capacity(exprs.len()); + + for (position, expr) in exprs.into_iter().enumerate() { + let (alias_expr, alias_ref) = self.bind_temp_table_alias(expr, position); + if !is_tuple { + let alias_plan = Self::build_project_plan(sub_query, vec![alias_expr.clone()]); + return Ok((alias_expr, alias_plan)); + } + alias_exprs.push(alias_expr); + alias_refs.push(alias_ref); + } + + let alias_plan = Self::build_project_plan(sub_query, alias_exprs); + Ok((ScalarExpression::Tuple(alias_refs), alias_plan)) + } + + fn bind_temp_table_alias( + &mut self, + expr: ScalarExpression, + position: usize, + ) -> (ScalarExpression, ScalarExpression) { let mut alias_column = ColumnCatalog::clone(&expr.output_column()); alias_column.set_ref_table(self.context.temp_table(), ColumnId::new(), true); - let alias_expr = ScalarExpression::Alias { - expr: Box::new(expr), - alias: AliasType::Expr(Box::new(ScalarExpression::column_expr(ColumnRef::from( - alias_column, - )))), - }; - let alias_plan = self.bind_project(sub_query, vec![alias_expr.clone()])?; - Ok((alias_expr, alias_plan)) + let alias_ref = ScalarExpression::column_expr(ColumnRef::from(alias_column), position); + ( + ScalarExpression::Alias { + expr: Box::new(expr), + alias: AliasType::Expr(Box::new(alias_ref.clone())), + }, + alias_ref, + ) } fn bind_subquery( @@ -375,13 +449,14 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let columns = sub_query_schema .iter() - .map(|column| ScalarExpression::column_expr(column.clone())) + .enumerate() + .map(|(position, column)| 
ScalarExpression::column_expr(column.clone(), position)) .collect::>(); ScalarExpression::Tuple(columns) } else { fn_check(1)?; - ScalarExpression::column_expr(sub_query_schema[0].clone()) + ScalarExpression::column_expr(sub_query_schema[0].clone(), 0) }; Ok((sub_query, expr, correlated)) } @@ -431,17 +506,28 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T }); } }; - try_alias!(self.context, full_name); + if full_name.0.is_none() { + try_alias!(self.context, full_name); + } if self.context.allow_default { try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let source = match self.context.bind_source(&table) { + let (schema_ref, position_offset) = match Self::resolve_source_columns_in_scope( + &self.context, + &mut self.table_schema_buf, + &table, + ) { Ok(source) => source, Err(err) => { if let Some(parent) = self.parent { self.context.mark_outer_ref(); - parent.context.bind_source(&table).map_err(|_| { + Self::resolve_source_columns_in_scope( + &parent.context, + &mut self.table_schema_buf, + &table, + ) + .map_err(|_| { if let [table_ident, _] = idents { attach_span_from_sqlparser_span_if_absent(err, table_ident.span) } else { @@ -457,67 +543,36 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } } }; - let schema_buf = self.table_schema_buf.entry(table.into()).or_default(); + let (position, column) = Self::find_column_in_schema(&schema_ref, full_name.1.as_str()) + .ok_or_else(|| Self::column_not_found_with_span(idents, full_name.1.as_str()))?; Ok(ScalarExpression::column_expr( - source.column(&full_name.1, schema_buf).ok_or_else(|| { - let err = DatabaseError::column_not_found(full_name.1.to_string()); - match idents.last() { - Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), - None => err, - } - })?, + column.clone(), + position_offset + position, )) } else { - let op = - |got_column: &mut Option, - context: 
&BinderContext<'a, T>, - table_schema_buf: &mut HashMap>| { - for ((table_name, alias, _), source) in context.bind_table.iter() { - if got_column.is_some() { - break; - } - if let Some(alias) = alias { - *got_column = context.expr_aliases.iter().find_map( - |((alias_table, alias_column), expr)| { - matches!( - alias_table - .as_ref() - .map(|table_name| table_name == alias.as_ref() - && alias_column == &full_name.1), - Some(true) - ) - .then(|| expr.clone()) - }, - ); - } else if let Some(column) = { - let schema_buf = - table_schema_buf.entry(table_name.clone()).or_default(); - source.column(&full_name.1, schema_buf) - } { - *got_column = Some(ScalarExpression::column_expr(column)); - } - } - }; // handle col syntax - let mut got_column = None; - - op(&mut got_column, &self.context, &mut self.table_schema_buf); + let mut got_column = Self::find_column_in_scope( + &self.context, + &mut self.table_schema_buf, + full_name.1.as_str(), + ); if got_column.is_none() { if let Some(parent) = self.parent { self.context.mark_outer_ref(); - op(&mut got_column, &parent.context, &mut self.table_schema_buf); + got_column = Self::find_column_in_scope( + &parent.context, + &mut self.table_schema_buf, + full_name.1.as_str(), + ); } } match got_column { Some(column) => Ok(column), - None => { - let err = DatabaseError::column_not_found(full_name.1.clone()); - Err(match idents.last() { - Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), - None => err, - }) - } + None => Err(Self::column_not_found_with_span( + idents, + full_name.1.as_str(), + )), } } } diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 528a38d4..07096c79 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::binder::{attach_span_if_absent, lower_case_name, Binder}; +use crate::binder::{ + attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_case_name, lower_ident, + Binder, +}; use crate::errors::DatabaseError; use crate::expression::simplify::ConstantCalculator; use crate::expression::visitor_mut::VisitorMut; @@ -26,6 +29,7 @@ use crate::storage::Transaction; use crate::types::tuple::SchemaRef; use crate::types::value::DataValue; use sqlparser::ast::{Expr, Ident, ObjectName, Query}; +use std::borrow::Cow; use std::slice; use std::sync::Arc; @@ -144,12 +148,14 @@ impl> Binder<'_, '_, T, A> is_overwrite: bool, ) -> Result { let table_name: Arc = lower_case_name(name)?.into(); - let source = self - .context - .source_and_bind(table_name.clone(), None, None, false)? - .ok_or(DatabaseError::TableNotFound)?; - let mut schema_buf = None; - let table_schema = source.schema_ref(&mut schema_buf); + let table_schema = { + let source = self + .context + .source(&table_name)? + .ok_or(DatabaseError::TableNotFound)?; + let mut schema_buf = None; + source.schema_ref(&mut schema_buf) + }; let mut input_plan = self.bind_query(query)?; let input_schema = input_plan.output_schema().clone(); @@ -162,35 +168,45 @@ impl> Binder<'_, '_, T, A> input_len, )); } - table_schema - .iter() - .take(input_len) - .cloned() - .collect::>() + Cow::Borrowed(&table_schema[..input_len]) } else { let mut columns = Vec::with_capacity(idents.len()); + let source = self + .context + .source(&table_name)? + .ok_or(DatabaseError::TableNotFound)?; + let mut schema_buf = None; for ident in idents { - match self.bind_column_ref_from_identifiers( - slice::from_ref(ident), - Some(table_name.to_string()), - )? { - ScalarExpression::ColumnRef { column, .. 
} => columns.push(column), - _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), - } + let column_name = lower_ident(ident); + let column = source + .column(&column_name, &mut schema_buf) + .ok_or_else(|| { + attach_span_from_sqlparser_span_if_absent( + DatabaseError::column_not_found(column_name), + ident.span, + ) + })?; + columns.push(column); } if input_len != columns.len() { return Err(DatabaseError::ValuesLenMismatch(columns.len(), input_len)); } - columns + Cow::Owned(columns) }; let projection = input_schema .iter() + .enumerate() .zip(target_columns.iter()) - .map(|(input_column, target_column)| ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(input_column.clone())), - alias: AliasType::Name(target_column.name().to_string()), - }) + .map( + |((position, input_column), target_column)| ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + input_column.clone(), + position, + )), + alias: AliasType::Name(target_column.name().to_string()), + }, + ) .collect::>(); input_plan = self.bind_project(input_plan, projection)?; diff --git a/src/binder/mod.rs b/src/binder/mod.rs index fd252c75..7f0dc79d 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -167,6 +167,37 @@ pub enum SubQueryType { pub enum Source<'a> { Table(&'a TableCatalog), View(&'a View), + Schema(SchemaRef), +} + +#[derive(Debug, Clone)] +pub struct BoundSource<'a> { + pub(crate) table_name: TableName, + pub(crate) alias: Option, + pub(crate) join_type: Option, + pub(crate) source: Source<'a>, +} + +impl BoundSource<'_> { + pub(crate) fn matches_name(&self, table_name: &str) -> bool { + self.table_name.as_ref() == table_name + || matches!(self.alias.as_ref(), Some(alias) if alias.as_ref() == table_name) + } + + pub(crate) fn same_binding( + &self, + table_name: &TableName, + alias: Option<&TableName>, + join_type: Option, + ) -> bool { + self.table_name == *table_name + && self.alias.as_ref() == alias + && self.join_type == join_type 
+ } + + pub(crate) fn visible_name(&self) -> &TableName { + self.alias.as_ref().unwrap_or(&self.table_name) + } } #[derive(Clone)] @@ -176,8 +207,9 @@ pub struct BinderContext<'a, T: Transaction> { pub(crate) table_cache: &'a TableCache, pub(crate) view_cache: &'a ViewCache, pub(crate) transaction: &'a T, - // Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain. - pub(crate) bind_table: BTreeMap<(TableName, Option, Option), Source<'a>>, + // Tips: retain binding order so wildcard expansion and position derivation + // follow FROM/JOIN order directly. + pub(crate) bind_table: Vec>, // alias expr_aliases: BTreeMap<(Option, String), ScalarExpression>, table_aliases: HashMap, @@ -207,24 +239,11 @@ impl Source<'_> { .get_or_insert_with(|| view.plan.output_schema_direct()) .columns() .find(|column| column.name() == name), + Source::Schema(schema_ref) => schema_ref.iter().find(|column| column.name() == name), } .cloned() } - pub(crate) fn columns<'a>( - &'a self, - schema_buf: &'a mut Option, - ) -> Box + 'a> { - match self { - Source::Table(table) => Box::new(table.columns()), - Source::View(view) => Box::new( - schema_buf - .get_or_insert_with(|| view.plan.output_schema_direct()) - .columns(), - ), - } - } - pub(crate) fn schema_ref(&self, schema_buf: &mut Option) -> SchemaRef { match self { Source::Table(table) => table.schema_ref().clone(), @@ -234,6 +253,7 @@ impl Source<'_> { SchemaOutput::SchemaRef(schema_ref) => schema_ref.clone(), } } + Source::Schema(schema_ref) => schema_ref.clone(), } } } @@ -352,23 +372,59 @@ impl<'a, T: Transaction> BinderContext<'a, T> { .map(Source::View); } if let Some(source) = &source { - self.bind_table.insert( - (table_name.clone(), alias.cloned(), join_type), + self.add_bound_source( + table_name.clone(), + alias.cloned(), + join_type, source.clone(), ); } Ok(source) } - pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source<'_>, DatabaseError> { - if 
let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| { - t.as_ref() == table_name - || matches!(alias.as_ref().map(|a| a.as_ref() == table_name), Some(true)) - }) { - Ok(source.1) + pub fn source(&self, table_name: &TableName) -> Result>, DatabaseError> { + let mut source = if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { + self.transaction.table(self.table_cache, real_name.clone()) } else { - Err(DatabaseError::invalid_table(table_name)) + self.transaction.table(self.table_cache, table_name.clone()) + }? + .map(Source::Table); + + if source.is_none() { + source = if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { + self.transaction + .view(self.table_cache, self.view_cache, real_name.clone()) + } else { + self.transaction + .view(self.table_cache, self.view_cache, table_name.clone()) + }? + .map(Source::View); + } + + Ok(source) + } + + pub fn add_bound_source( + &mut self, + table_name: TableName, + alias: Option, + join_type: Option, + source: Source<'a>, + ) { + if let Some(bound_source) = self + .bind_table + .iter_mut() + .find(|bound_source| bound_source.same_binding(&table_name, alias.as_ref(), join_type)) + { + bound_source.source = source; + return; } + self.bind_table.push(BoundSource { + table_name, + alias, + join_type, + source, + }); } // Tips: The order of this index is based on Aggregate being bound first. 
@@ -608,8 +664,13 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' } fn extend(&mut self, context: BinderContext<'a, T>) { - for (key, table) in context.bind_table { - self.context.bind_table.insert(key, table); + for bound_source in context.bind_table { + self.context.add_bound_source( + bound_source.table_name, + bound_source.alias, + bound_source.join_type, + bound_source.source, + ); } for (key, expr) in context.expr_aliases { self.context.expr_aliases.insert(key, expr); diff --git a/src/binder/select.rs b/src/binder/select.rs index 4049f987..ef6080cb 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -32,12 +32,14 @@ use super::{ Source, SubQueryType, }; -use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnSummary, TableName}; +use crate::catalog::{ + ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary, TableName, +}; use crate::errors::DatabaseError; use crate::execution::dql::join::joins_nullable; use crate::expression::agg::AggKind; use crate::expression::simplify::ConstantCalculator; -use crate::expression::visitor_mut::VisitorMut; +use crate::expression::visitor_mut::{walk_mut_expr, PositionShift, VisitorMut}; use crate::expression::{AliasType, BinaryOperator}; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::operator::except::ExceptOperator; @@ -59,7 +61,251 @@ use sqlparser::ast::{ TableAliasColumnDef, TableFactor, TableWithJoins, }; +struct RightSidePositionGlobalizer<'a> { + right_schema: &'a Schema, + left_len: usize, +} + +impl<'a> VisitorMut<'a> for RightSidePositionGlobalizer<'_> { + fn visit_column_ref( + &mut self, + column: &'a mut ColumnRef, + position: &'a mut usize, + ) -> Result<(), DatabaseError> { + if self.right_schema.contains(column) { + *position += self.left_len; + } + Ok(()) + } +} + +struct SplitScopePositionRebinder<'a> { + left_schema: &'a Schema, + right_schema: &'a Schema, +} + +impl VisitorMut<'_> for 
SplitScopePositionRebinder<'_> { + fn visit_column_ref( + &mut self, + column: &mut ColumnRef, + position: &mut usize, + ) -> Result<(), DatabaseError> { + if let Some(left_position) = self + .left_schema + .iter() + .position(|candidate| candidate.same_column(column)) + { + *position = left_position; + } else if let Some(right_position) = self + .right_schema + .iter() + .position(|candidate| candidate.same_column(column)) + { + *position = right_position; + } + Ok(()) + } +} + +struct ProjectionOutputBinder<'a> { + project_exprs: &'a [ScalarExpression], +} + +impl<'a> ProjectionOutputBinder<'a> { + fn new(project_exprs: &'a [ScalarExpression]) -> Self { + Self { project_exprs } + } + + fn output_ref(&self, expr: &ScalarExpression) -> Option { + self.project_exprs + .iter() + .position(|candidate| { + candidate == expr || candidate.unpack_alias_ref() == expr.unpack_alias_ref() + }) + .map(|position| { + let output_expr = &self.project_exprs[position]; + ScalarExpression::column_expr(output_expr.output_column(), position) + }) + } +} + +impl<'a> VisitorMut<'a> for ProjectionOutputBinder<'_> { + fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { + if let Some(output_ref) = self.output_ref(expr) { + *expr = output_ref; + return Ok(()); + } + walk_mut_expr(self, expr) + } +} + impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { + fn is_temp_alias_projection(exprs: &[ScalarExpression]) -> bool { + !exprs.is_empty() + && exprs.iter().all(|expr| { + matches!( + expr, + ScalarExpression::Alias { + alias: AliasType::Expr(alias_expr), + .. + } if matches!( + alias_expr.unpack_alias_ref(), + ScalarExpression::ColumnRef { column, .. } + if matches!( + &column.summary().relation, + crate::catalog::ColumnRelation::Table { is_temp: true, .. 
} + ) + ) + ) + }) + } + + fn is_joined_values_source(join_type: Option, source: &Source<'a>) -> bool { + join_type.is_some() + && matches!( + source, + Source::Schema(schema_ref) + if !schema_ref.is_empty() + && schema_ref.iter().all(|column| { + matches!( + &column.summary().relation, + ColumnRelation::Table { is_temp: true, .. } + ) && column.id() == Some(ColumnId::default()) + }) + ) + } + + fn bind_project_output_exprs<'c>( + project_exprs: &[ScalarExpression], + exprs: impl IntoIterator, + ) -> Result<(), DatabaseError> { + let mut binder = ProjectionOutputBinder::new(project_exprs); + for expr in exprs { + binder.visit(expr)?; + } + Ok(()) + } + + pub(crate) fn resolve_source_columns_in_scope( + context: &BinderContext<'a, T>, + table_schema_buf: &mut std::collections::HashMap>, + table_name: &str, + ) -> Result<(SchemaRef, usize), DatabaseError> { + let mut position_offset = 0; + + for bound_source in &context.bind_table { + let schema_buf = table_schema_buf + .entry(bound_source.table_name.clone()) + .or_default(); + let schema_ref = bound_source.source.schema_ref(schema_buf); + + if bound_source.matches_name(table_name) { + return Ok((schema_ref, position_offset)); + } + + position_offset += schema_ref.len(); + } + + Err(DatabaseError::invalid_table(table_name)) + } + + fn localize_join_condition_from_join_scope( + join_condition: &mut JoinCondition, + left_len: usize, + ) -> Result<(), DatabaseError> { + let JoinCondition::On { on, .. } = join_condition else { + return Ok(()); + }; + + let mut right_shift = PositionShift { + delta: -(left_len as isize), + }; + for (_, right_expr) in on { + right_shift.visit(right_expr)?; + } + + Ok(()) + } + + fn globalize_join_filter_from_split_scope( + join_condition: &mut JoinCondition, + left_len: usize, + right_schema: &Schema, + ) -> Result<(), DatabaseError> { + let JoinCondition::On { filter, .. 
} = join_condition else { + return Ok(()); + }; + + if let Some(expr) = filter { + RightSidePositionGlobalizer { + right_schema, + left_len, + } + .visit(expr)?; + } + + Ok(()) + } + + fn rebind_split_scope_positions( + expr: &mut ScalarExpression, + left_schema: &Schema, + right_schema: &Schema, + ) -> Result<(), DatabaseError> { + SplitScopePositionRebinder { + left_schema, + right_schema, + } + .visit(expr) + } + + fn build_join_from_split_scope_predicates( + mut children: LogicalPlan, + mut plan: LogicalPlan, + join_ty: JoinType, + predicates: impl IntoIterator, + rebind_positions: bool, + ) -> Result { + let left_schema = children.output_schema().clone(); + let right_schema = plan.output_schema().clone(); + let mut on_keys = Vec::new(); + let mut filter = Vec::new(); + + for mut predicate in predicates { + if rebind_positions { + Self::rebind_split_scope_positions( + &mut predicate, + left_schema.as_ref(), + right_schema.as_ref(), + )?; + } + Self::extract_join_keys( + predicate, + &mut on_keys, + &mut filter, + left_schema.as_ref(), + right_schema.as_ref(), + )?; + } + + let mut join_condition = JoinCondition::On { + on: on_keys, + filter: Self::combine_conjuncts(filter), + }; + Self::globalize_join_filter_from_split_scope( + &mut join_condition, + left_schema.len(), + right_schema.as_ref(), + )?; + + Ok(LJoinOperator::build( + children, + plan, + join_condition, + join_ty, + )) + } + pub(crate) fn bind_query(&mut self, query: &Query) -> Result { let origin_step = self.context.step_now(); @@ -117,11 +363,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' orderbys: &[OrderByExpr], ) -> Result { let saved_aliases = self.context.expr_aliases.clone(); - for column in plan.output_schema().iter() { + for (position, column) in plan.output_schema().iter().enumerate() { self.context.add_alias( None, column.name().to_string(), - ScalarExpression::column_expr(column.clone()), + ScalarExpression::column_expr(column.clone(), position), 
); } @@ -129,7 +375,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' self.context.expr_aliases = saved_aliases; Ok(match sort_fields { - Some(sort_fields) => self.bind_sort(plan, sort_fields), + Some(sort_fields) => self.bind_sort(plan, sort_fields)?, None => plan, }) } @@ -156,7 +402,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } plan }; - let mut select_list = self.normalize_select_item(&select.projection, &mut plan)?; + let select_bind_step = self.context.step_now(); + self.context.step(QueryBindStep::Project); + let mut select_list = self.normalize_select_item(&select.projection)?; + self.context.step(select_bind_step); if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; @@ -194,7 +443,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' plan, self.context.agg_calls.clone(), self.context.group_by_exprs.clone(), - ); + )?; + self.bind_aggregate_output_exprs(select_list.iter_mut())?; + if let Some(orderby) = having_orderby.1.as_mut() { + self.bind_aggregate_output_exprs(orderby.iter_mut().map(|field| &mut field.expr))?; + } } if let Some(having) = having_orderby.0 { @@ -202,11 +455,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } if let Some(Distinct::Distinct) = select.distinct { - plan = self.bind_distinct(plan, select_list.clone()); + plan = self.bind_distinct(plan, select_list.clone())?; + let distinct_outputs = select_list.clone(); + self.bind_distinct_output_exprs(&distinct_outputs, select_list.iter_mut())?; + if let Some(orderby) = having_orderby.1.as_mut() { + self.bind_distinct_orderby_exprs(&distinct_outputs, orderby)?; + } } if let Some(orderby) = having_orderby.1 { - plan = self.bind_sort(plan, orderby); + plan = self.bind_sort(plan, orderby)?; } if !select_list.is_empty() { @@ -294,24 +552,32 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> 
Binder<' let left_schema = left_plan.output_schema(); let right_schema = right_plan.output_schema(); - for (left_schema, right_schema) in left_schema.iter().zip(right_schema.iter()) { + for (position, (left_schema, right_schema)) in + left_schema.iter().zip(right_schema.iter()).enumerate() + { let cast_type = LogicalType::max_logical_type(left_schema.datatype(), right_schema.datatype())?; if &cast_type != left_schema.datatype() { left_cast.push(ScalarExpression::TypeCast { - expr: Box::new(ScalarExpression::column_expr(left_schema.clone())), + expr: Box::new(ScalarExpression::column_expr(left_schema.clone(), position)), ty: cast_type.clone(), }); } else { - left_cast.push(ScalarExpression::column_expr(left_schema.clone())); + left_cast.push(ScalarExpression::column_expr(left_schema.clone(), position)); } if &cast_type != right_schema.datatype() { right_cast.push(ScalarExpression::TypeCast { - expr: Box::new(ScalarExpression::column_expr(right_schema.clone())), + expr: Box::new(ScalarExpression::column_expr( + right_schema.clone(), + position, + )), ty: cast_type.clone(), }); } else { - right_cast.push(ScalarExpression::column_expr(right_schema.clone())); + right_cast.push(ScalarExpression::column_expr( + right_schema.clone(), + position, + )); } } @@ -420,7 +686,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let distinct_exprs = left_schema .iter() .cloned() - .map(ScalarExpression::column_expr) + .enumerate() + .map(|(position, column)| ScalarExpression::column_expr(column, position)) .collect_vec(); let union_op = Operator::Union(UnionOperator { @@ -437,7 +704,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }, ), distinct_exprs, - )) + )?) 
} } SetOperator::Except => { @@ -452,16 +719,18 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let left_distinct_exprs = left_schema .iter() .cloned() - .map(ScalarExpression::column_expr) + .enumerate() + .map(|(position, column)| ScalarExpression::column_expr(column, position)) .collect_vec(); let right_distinct_exprs = right_schema .iter() .cloned() - .map(ScalarExpression::column_expr) + .enumerate() + .map(|(position, column)| ScalarExpression::column_expr(column, position)) .collect_vec(); - left_plan = self.bind_distinct(left_plan, left_distinct_exprs); - right_plan = self.bind_distinct(right_plan, right_distinct_exprs); + left_plan = self.bind_distinct(left_plan, left_distinct_exprs)?; + right_plan = self.bind_distinct(right_plan, right_distinct_exprs)?; left_schema = left_plan.output_schema(); right_schema = right_plan.output_schema(); @@ -508,8 +777,28 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' TableFactor::Derived { subquery, alias, .. } => { - let mut plan = self.bind_query(subquery)?; - let mut tables = plan.referenced_table(); + let BinderContext { + table_cache, + view_cache, + transaction, + scala_functions, + table_functions, + temp_table_id, + .. + } = &self.context; + let mut binder = Binder::new( + BinderContext::new( + table_cache, + view_cache, + *transaction, + scala_functions, + table_functions, + temp_table_id.clone(), + ), + self.args, + Some(self), + ); + let mut plan = binder.bind_query(subquery)?; if let Some(TableAlias { name, @@ -517,15 +806,47 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' .. 
}) = alias { - if tables.len() > 1 { - return Err(DatabaseError::UnsupportedStmt( - "Implement virtual tables for multiple table aliases".to_string(), - )); - } + let source_name = self.context.temp_table(); let table_alias: Arc = name.value.to_lowercase().into(); - let table = tables.pop().unwrap_or_else(|| self.context.temp_table()); - plan = self.bind_alias(plan, alias_column, table_alias, table)?; + plan = self.bind_alias( + plan, + alias_column, + table_alias.clone(), + source_name.clone(), + )?; + self.context.add_bound_source( + table_alias.clone(), + Some(table_alias), + joint_type, + Source::Schema(plan.output_schema().clone()), + ); + } else { + let passthrough_source = { + let output_schema = plan.output_schema().clone(); + let mut names = output_schema + .iter() + .filter_map(|column| column.table_name().cloned()); + let first = names.next(); + if first.is_some() && names.all(|name| Some(name) == first) { + first + } else { + None + } + }; + let needs_virtual_source = passthrough_source.is_none(); + let source_name = + passthrough_source.unwrap_or_else(|| self.context.temp_table()); + + if needs_virtual_source { + plan = self.bind_schema_source(plan, source_name.clone()); + } + self.context.add_bound_source( + source_name.clone(), + None, + joint_type, + Source::Schema(plan.output_schema().clone()), + ); } plan } @@ -552,9 +873,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )?; } + let source = if table_alias.is_some() { + Source::Schema(plan.output_schema().clone()) + } else { + Source::Table(table) + }; self.context - .bind_table - .insert((table_name, table_alias, joint_type), Source::Table(table)); + .add_bound_source(table_name, table_alias, joint_type, source); plan } else { unreachable!() @@ -595,17 +920,22 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' for (alias, column) in aliases_with_columns { let mut alias_column = ColumnCatalog::clone(&column); 
alias_column.set_name(alias.clone()); + let is_temp = matches!( + &column.summary().relation, + ColumnRelation::Table { is_temp: true, .. } + ); alias_column.set_ref_table( table_alias.clone(), column.id().unwrap_or(ColumnId::new()), - false, + is_temp, ); let alias_column_expr = ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(column)), - alias: AliasType::Expr(Box::new(ScalarExpression::column_expr(ColumnRef::from( - alias_column, - )))), + expr: Box::new(ScalarExpression::column_expr(column, alias_exprs.len())), + alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( + ColumnRef::from(alias_column), + alias_exprs.len(), + ))), }; self.context.add_alias( Some(table_alias.to_string()), @@ -618,6 +948,30 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' self.bind_project(plan, alias_exprs) } + fn bind_schema_source(&mut self, mut plan: LogicalPlan, source_name: TableName) -> LogicalPlan { + let input_schema = plan.output_schema(); + let mut source_exprs = Vec::with_capacity(input_schema.len()); + + for (position, column) in input_schema.iter().cloned().enumerate() { + let mut source_column = ColumnCatalog::clone(&column); + source_column.set_ref_table( + source_name.clone(), + column.id().unwrap_or(ColumnId::new()), + true, + ); + + source_exprs.push(ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr(column, position)), + alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( + ColumnRef::from(source_column), + position, + ))), + }); + } + + Self::build_project_plan(plan, source_exprs) + } + pub(crate) fn _bind_single_table_ref( &mut self, join_type: Option, @@ -641,10 +995,21 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let mut plan = match source { Source::Table(table) => TableScanOperator::build(table_name.clone(), table, with_pk)?, Source::View(view) => LogicalPlan::clone(&view.plan), + Source::Schema(_) => { + return 
Err(DatabaseError::UnsupportedStmt( + "derived source cannot be rebound as a base relation".to_string(), + )) + } }; - if let Some(idents) = alias_idents { - plan = self.bind_alias(plan, idents, table_alias.unwrap(), table_name.clone())?; + if let (Some(idents), Some(alias_name)) = (alias_idents, table_alias) { + plan = self.bind_alias(plan, idents, alias_name.clone(), table_name.clone())?; + self.context.add_bound_source( + table_name, + Some(alias_name), + join_type, + Source::Schema(plan.output_schema().clone()), + ); } Ok(plan) } @@ -658,7 +1023,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' fn normalize_select_item( &mut self, items: &[SelectItem], - plan: &mut LogicalPlan, ) -> Result, DatabaseError> { let mut select_items = vec![]; @@ -678,20 +1042,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }); } SelectItem::Wildcard(_) => { - if let Operator::Project(op) = &plan.operator { - for expr in op.exprs.iter() { - select_items.push(expr.clone()); - } - continue; - } - for (table_name, alias, _) in self.context.bind_table.keys() { - let schema_buf = - self.table_schema_buf.entry(table_name.clone()).or_default(); + for visible_name in self + .context + .bind_table + .iter() + .filter(|bound_source| { + !Self::is_joined_values_source( + bound_source.join_type, + &bound_source.source, + ) + }) + .map(|bound_source| bound_source.visible_name()) + .unique() + .cloned() + { Self::bind_table_column_refs( &self.context, - schema_buf, + &mut self.table_schema_buf, &mut select_items, - alias.as_ref().unwrap_or(table_name).clone(), + visible_name, false, )?; } @@ -707,11 +1076,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' ))) } }; - let schema_buf = self.table_schema_buf.entry(table_name.clone()).or_default(); - Self::bind_table_column_refs( &self.context, - schema_buf, + &mut self.table_schema_buf, &mut select_items, table_name, true, @@ -726,7 +1093,7 @@ 
impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' #[allow(unused_assignments)] fn bind_table_column_refs( context: &BinderContext<'a, T>, - schema_buf: &mut Option, + table_schema_buf: &mut std::collections::HashMap>, exprs: &mut Vec, table_name: TableName, is_qualified_wildcard: bool, @@ -739,38 +1106,47 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' || Some(&table_name) == column.table_name() && !context.using.contains(column) }; - let bound_alias = context + let (schema_ref, position_offset) = + Self::resolve_source_columns_in_scope(context, table_schema_buf, table_name.as_ref())?; + let mut pushed_alias_columns = false; + + for alias_column in context .expr_aliases - .iter() - .filter(|(_, expr)| { - if let ScalarExpression::ColumnRef { column, .. } = expr.unpack_alias_ref() { - if fn_not_on_using(column) { - exprs.push(ScalarExpression::clone(expr)); - return true; - } - } - false + .keys() + .filter_map(|(alias_table, alias_column)| { + matches!(alias_table.as_deref(), Some(alias) if alias == table_name.as_ref()) + .then_some(alias_column.as_str()) }) - .count() - > 0; + { + let Some((position, column)) = schema_ref + .iter() + .enumerate() + .find(|(_, column)| column.name() == alias_column) + else { + continue; + }; + if !fn_not_on_using(column) { + continue; + } + exprs.push(ScalarExpression::column_expr( + column.clone(), + position_offset + position, + )); + pushed_alias_columns = true; + } - if bound_alias { + if pushed_alias_columns { return Ok(()); } - let mut source = None; - source = context.table(table_name.clone())?.map(Source::Table); - if source.is_none() { - source = context.view(table_name.clone())?.map(Source::View); - } - for column in source - .ok_or(DatabaseError::SourceNotFound)? 
- .columns(schema_buf) - { + for (position, column) in schema_ref.iter().enumerate() { if !fn_not_on_using(column) { continue; } - exprs.push(ScalarExpression::column_expr(column.clone())); + exprs.push(ScalarExpression::column_expr( + column.clone(), + position_offset + position, + )); } Ok(()) } @@ -836,7 +1212,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let mut right = binder.bind_single_table_ref(relation, Some(join_type))?; self.extend(binder.context); - let on = match joint_condition { + let mut on = match joint_condition { Some(constraint) => self.bind_join_constraint( join_type, left.output_schema(), @@ -845,6 +1221,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )?, None => JoinCondition::None, }; + Self::localize_join_condition_from_join_scope(&mut on, left.output_schema().len())?; Ok(LJoinOperator::build(left, right, on, join_type)) } @@ -860,18 +1237,23 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if let Some(sub_queries) = self.context.sub_queries_at_now() { for sub_query in sub_queries { - let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; - let mut filter = vec![]; - - let (mut plan, join_ty) = match sub_query { - SubQueryType::SubQuery { plan, .. } => (plan, JoinType::Inner), + let (plan, join_ty) = match sub_query { + SubQueryType::SubQuery { plan, correlated } => { + if correlated { + return Err(DatabaseError::UnsupportedStmt( + "correlated scalar subqueries in WHERE are not supported" + .to_string(), + )); + } + (plan, JoinType::Inner) + } SubQueryType::ExistsSubQuery { negated, plan, correlated, } => { children = if correlated { - Self::bind_correlated_exists(children, plan, negated)? + self.bind_correlated_exists(children, plan, negated)? 
} else { Self::bind_uncorrelated_exists(children, plan, negated) }; @@ -883,7 +1265,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' correlated, } => { if correlated { - children = Self::bind_correlated_in_subquery( + children = self.bind_correlated_in_subquery( children, plan, negated, @@ -900,26 +1282,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } }; - Self::extract_join_keys( - predicate.clone(), - &mut on_keys, - &mut filter, - children.output_schema(), - plan.output_schema(), - )?; - - // combine multiple filter exprs into one BinaryExpr - let join_filter = Self::combine_conjuncts(filter); - - children = LJoinOperator::build( + children = Self::build_join_from_split_scope_predicates( children, plan, - JoinCondition::On { - on: on_keys, - filter: join_filter, - }, join_ty, - ); + std::iter::once(predicate.clone()), + true, + )?; } return Ok(children); } @@ -927,40 +1296,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } fn bind_correlated_exists( + &self, mut children: LogicalPlan, plan: LogicalPlan, negated: bool, ) -> Result { - let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; - let mut filter = vec![]; let join_ty = if negated { JoinType::LeftAnti } else { JoinType::LeftSemi }; - let (mut plan, correlated_filters) = - Self::prepare_correlated_exists_plan(plan, children.output_schema())?; - - for expr in correlated_filters { - Self::extract_join_keys( - expr, - &mut on_keys, - &mut filter, - children.output_schema(), - plan.output_schema(), - )?; - } - - children = LJoinOperator::build( + let (plan, correlated_filters) = + Self::prepare_correlated_subquery_plan(plan, children.output_schema(), false)?; + Self::build_join_from_split_scope_predicates( children, plan, - JoinCondition::On { - on: on_keys, - filter: Self::combine_conjuncts(filter), - }, join_ty, - ); - Ok(children) + correlated_filters, + false, + ) } fn 
bind_uncorrelated_exists( @@ -993,6 +1347,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }, left_expr: Box::new(ScalarExpression::column_expr( agg.output_schema()[0].clone(), + 0, )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, @@ -1013,50 +1368,27 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } fn bind_correlated_in_subquery( + &self, mut children: LogicalPlan, plan: LogicalPlan, negated: bool, predicate: ScalarExpression, ) -> Result { - let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; - let mut filter = vec![]; let join_ty = if negated { JoinType::LeftAnti } else { JoinType::LeftSemi }; - let (mut plan, correlated_filters) = - Self::prepare_correlated_exists_plan(plan, children.output_schema())?; + let (plan, correlated_filters) = + Self::prepare_correlated_subquery_plan(plan, children.output_schema(), true)?; let predicate = Self::rewrite_correlated_in_predicate(predicate); - - Self::extract_join_keys( - predicate, - &mut on_keys, - &mut filter, - children.output_schema(), - plan.output_schema(), - )?; - - for expr in correlated_filters { - Self::extract_join_keys( - expr, - &mut on_keys, - &mut filter, - children.output_schema(), - plan.output_schema(), - )?; - } - - children = LJoinOperator::build( + Self::build_join_from_split_scope_predicates( children, plan, - JoinCondition::On { - on: on_keys, - filter: Self::combine_conjuncts(filter), - }, join_ty, - ); - Ok(children) + std::iter::once(predicate).chain(correlated_filters), + false, + ) } fn rewrite_correlated_in_predicate(predicate: ScalarExpression) -> ScalarExpression { @@ -1087,13 +1419,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } fn plan_has_correlated_refs(plan: &LogicalPlan, left_schema: &Schema) -> bool { - let contains = |column: &ColumnRef| { - left_schema - .iter() - .any(|left| left.summary() == column.summary()) 
- }; + let contains = |column: &ColumnRef| left_schema.contains(column); - if plan.operator.referenced_columns(true).iter().any(contains) { + if plan.operator.any_referenced_column(true, contains) { return true; } @@ -1108,11 +1436,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } fn expr_has_correlated_refs(expr: &ScalarExpression, left_schema: &Schema) -> bool { - expr.referenced_columns(true).iter().any(|column| { - left_schema - .iter() - .any(|left| left.summary() == column.summary()) - }) + expr.any_referenced_column(true, |column| left_schema.contains(column)) } fn split_conjuncts(expr: ScalarExpression, exprs: &mut Vec) { @@ -1142,9 +1466,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }) } - fn prepare_correlated_exists_plan( + fn prepare_correlated_subquery_plan( plan: LogicalPlan, left_schema: &Schema, + preserve_projection: bool, ) -> Result<(LogicalPlan, Vec), DatabaseError> { match plan.childrens.as_ref() { Childrens::Only(_) => {} @@ -1166,8 +1491,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' .. } => { let child = childrens.pop_only(); - let (child, mut correlated_filters) = - Self::prepare_correlated_exists_plan(child, left_schema)?; + let (child, mut correlated_filters) = Self::prepare_correlated_subquery_plan( + child, + left_schema, + preserve_projection, + )?; let mut local_filters = Vec::new(); let mut predicates = Vec::new(); Self::split_conjuncts(op.predicate, &mut predicates); @@ -1186,11 +1514,28 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok((plan, correlated_filters)) } LogicalPlan { - operator: Operator::Project(_), + operator: Operator::Project(op), childrens, .. 
+ } => { + let child = childrens.pop_only(); + let (child, mut correlated_filters) = Self::prepare_correlated_subquery_plan( + child, + left_schema, + preserve_projection, + )?; + + if !preserve_projection || Self::is_temp_alias_projection(&op.exprs) { + Ok((child, correlated_filters)) + } else { + Self::bind_project_output_exprs(&op.exprs, correlated_filters.iter_mut())?; + Ok(( + LogicalPlan::new(Operator::Project(op), Childrens::Only(Box::new(child))), + correlated_filters, + )) + } } - | LogicalPlan { + LogicalPlan { operator: Operator::Sort(_), childrens, .. @@ -1204,7 +1549,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' operator: Operator::TopK(_), childrens, .. - } => Self::prepare_correlated_exists_plan(childrens.pop_only(), left_schema), + } => Self::prepare_correlated_subquery_plan( + childrens.pop_only(), + left_schema, + preserve_projection, + ), plan => { if Self::plan_has_correlated_refs(&plan, left_schema) { Err(DatabaseError::UnsupportedStmt( @@ -1221,37 +1570,82 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' fn bind_having( &mut self, children: LogicalPlan, - having: ScalarExpression, + mut having: ScalarExpression, ) -> Result { self.context.step(QueryBindStep::Having); self.validate_having_orderby(&having)?; + self.bind_aggregate_output_exprs(std::iter::once(&mut having))?; Ok(FilterOperator::build(having, children, true)) } - pub(crate) fn bind_project( - &mut self, + pub(crate) fn build_project_plan( children: LogicalPlan, select_list: Vec, + ) -> LogicalPlan { + LogicalPlan::new( + Operator::Project(ProjectOperator { exprs: select_list }), + Childrens::Only(Box::new(children)), + ) + } + + pub(crate) fn bind_project( + &mut self, + mut children: LogicalPlan, + mut select_list: Vec, ) -> Result { self.context.step(QueryBindStep::Project); - Ok(LogicalPlan::new( - Operator::Project(ProjectOperator { exprs: select_list }), - Childrens::Only(Box::new(children)), - )) + if 
let Some(sub_queries) = self.context.sub_queries_at_now() { + for sub_query in sub_queries { + let SubQueryType::SubQuery { + mut plan, + correlated, + } = sub_query + else { + return Err(DatabaseError::UnsupportedStmt( + "only scalar subqueries are supported in SELECT list".to_string(), + )); + }; + + if correlated { + return Err(DatabaseError::UnsupportedStmt( + "correlated scalar subqueries in SELECT list are not supported".to_string(), + )); + } + + let left_len = children.output_schema().len(); + let right_schema = plan.output_schema().clone(); + for expr in select_list.iter_mut() { + RightSidePositionGlobalizer { + right_schema: right_schema.as_ref(), + left_len, + } + .visit(expr)?; + } + + children = + LJoinOperator::build(children, plan, JoinCondition::None, JoinType::Cross); + } + } + + Ok(Self::build_project_plan(children, select_list)) } - fn bind_sort(&mut self, children: LogicalPlan, sort_fields: Vec) -> LogicalPlan { + fn bind_sort( + &mut self, + children: LogicalPlan, + sort_fields: Vec, + ) -> Result { self.context.step(QueryBindStep::Sort); - LogicalPlan::new( + Ok(LogicalPlan::new( Operator::Sort(SortOperator { sort_fields, limit: None, }), Childrens::Only(Box::new(children)), - ) + )) } fn bind_non_negative_limit_value(&mut self, expr: &Expr) -> Result { @@ -1311,22 +1705,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } pub fn extract_select_join(&mut self, select_items: &mut [ScalarExpression]) { - let bind_tables = &self.context.bind_table; - if bind_tables.len() < 2 { + if self.context.bind_table.len() < 2 { return; } - let mut table_force_nullable = Vec::with_capacity(bind_tables.len()); + let mut table_force_nullable = Vec::with_capacity(self.context.bind_table.len()); let mut left_table_force_nullable = false; let mut left_table = None; - for ((table_name, _, join_option), table) in bind_tables { - if let Some(join_type) = join_option { - let (left_force_nullable, right_force_nullable) = 
joins_nullable(join_type); - table_force_nullable.push((table_name, table, right_force_nullable)); + for bound_source in &self.context.bind_table { + if let Some(join_type) = bound_source.join_type { + let (left_force_nullable, right_force_nullable) = joins_nullable(&join_type); + table_force_nullable.push(( + &bound_source.table_name, + &bound_source.source, + right_force_nullable, + )); left_table_force_nullable = left_force_nullable; } else { - left_table = Some((table_name, table)); + left_table = Some((&bound_source.table_name, &bound_source.source)); } } @@ -1394,15 +1791,21 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }) } JoinConstraint::Using(idents) => { - fn find_column<'a>(schema: &'a Schema, name: &'a str) -> Option<&'a ColumnRef> { - schema.iter().find(|column| column.name() == name) + fn find_column<'a>( + schema: &'a Schema, + name: &'a str, + ) -> Option<(usize, &'a ColumnRef)> { + schema + .iter() + .enumerate() + .find(|(_, column)| column.name() == name) } let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new(); for ident in idents { let name = lower_case_name(ident)?; - let (Some(left_column), Some(right_column)) = ( + let (Some((left_position, left_column)), Some((right_position, right_column))) = ( find_column(left_schema, &name), find_column(right_schema, &name), ) else { @@ -1413,8 +1816,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }; self.context.add_using(join_type, left_column, right_column); on_keys.push(( - ScalarExpression::column_expr(left_column.clone()), - ScalarExpression::column_expr(right_column.clone()), + ScalarExpression::column_expr(left_column.clone(), left_position), + ScalarExpression::column_expr( + right_column.clone(), + left_schema.len() + right_position, + ), )); } Ok(JoinCondition::On { @@ -1430,12 +1836,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let mut on_keys: 
Vec<(ScalarExpression, ScalarExpression)> = Vec::new(); for name in fn_names(left_schema).intersection(&fn_names(right_schema)) { - if let (Some(left_column), Some(right_column)) = ( - left_schema.iter().find(|column| column.name() == *name), - right_schema.iter().find(|column| column.name() == *name), + if let ( + Some((left_position, left_column)), + Some((right_position, right_column)), + ) = ( + left_schema + .iter() + .enumerate() + .find(|(_, column)| column.name() == *name), + right_schema + .iter() + .enumerate() + .find(|(_, column)| column.name() == *name), ) { - let left_expr = ScalarExpression::column_expr(left_column.clone()); - let right_expr = ScalarExpression::column_expr(right_column.clone()); + let left_expr = + ScalarExpression::column_expr(left_column.clone(), left_position); + let right_expr = ScalarExpression::column_expr( + right_column.clone(), + left_schema.len() + right_position, + ); self.context.add_using(join_type, left_column, right_column); on_keys.push((left_expr, right_expr)); @@ -1526,9 +1945,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } _other => { // example: baz > 1 - if left_expr.referenced_columns(true).iter().all(|column| { + if left_expr.all_referenced_columns(true, |column| { fn_or_contains(left_schema, right_schema, column.summary()) - }) && right_expr.referenced_columns(true).iter().all(|column| { + }) && right_expr.all_referenced_columns(true, |column| { fn_or_contains(left_schema, right_schema, column.summary()) }) { accum_filter.push(ScalarExpression::Binary { @@ -1569,9 +1988,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }); } _ => { - if left_expr.referenced_columns(true).iter().all(|column| { + if left_expr.all_referenced_columns(true, |column| { fn_or_contains(left_schema, right_schema, column.summary()) - }) && right_expr.referenced_columns(true).iter().all(|column| { + }) && right_expr.all_referenced_columns(true, |column| { 
fn_or_contains(left_schema, right_schema, column.summary()) }) { accum_filter.push(ScalarExpression::Binary { @@ -1586,11 +2005,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } expr => { - if expr - .referenced_columns(true) - .iter() - .all(|column| fn_or_contains(left_schema, right_schema, column.summary())) - { + if expr.all_referenced_columns(true, |column| { + fn_or_contains(left_schema, right_schema, column.summary()) + }) { // example: baz > 1 accum_filter.push(expr); } @@ -1603,8 +2020,27 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { + use super::{ProjectionOutputBinder, RightSidePositionGlobalizer}; use crate::binder::test::build_t1_table; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::errors::DatabaseError; + use crate::expression::visitor_mut::VisitorMut; + use crate::expression::{AliasType, ScalarExpression}; + use crate::planner::operator::join::{JoinCondition, JoinType}; + use crate::planner::operator::Operator; + use crate::planner::{Childrens, LogicalPlan}; + use crate::types::LogicalType; + + fn test_column(name: &str, position: usize) -> ScalarExpression { + ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new( + name.to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), + )), + position, + ) + } #[test] fn test_select_bind() -> Result<(), DatabaseError> { @@ -1635,4 +2071,161 @@ mod tests { Ok(()) } + + #[test] + fn test_right_side_position_globalizer_only_shifts_right_columns() -> Result<(), DatabaseError> + { + let left_column = ColumnRef::from(ColumnCatalog::new( + "left".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), + )); + let right_column = ColumnRef::from(ColumnCatalog::new( + "right".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), + )); + let 
right_schema = vec![right_column.clone()]; + let mut expr = ScalarExpression::Binary { + op: crate::expression::BinaryOperator::Eq, + left_expr: Box::new(ScalarExpression::column_expr(left_column, 0)), + right_expr: Box::new(ScalarExpression::column_expr(right_column.clone(), 0)), + evaluator: None, + ty: LogicalType::Boolean, + }; + + RightSidePositionGlobalizer { + right_schema: &right_schema, + left_len: 2, + } + .visit(&mut expr)?; + + let ScalarExpression::Binary { + left_expr, + right_expr, + .. + } = expr + else { + unreachable!() + }; + let ScalarExpression::ColumnRef { + position: left_position, + .. + } = left_expr.as_ref() + else { + unreachable!() + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. + } = right_expr.as_ref() + else { + unreachable!() + }; + assert_eq!((*left_position, *right_position), (0, 2)); + + Ok(()) + } + + #[test] + fn test_projection_output_binder_rewrites_to_project_slot() -> Result<(), DatabaseError> { + let project_output = ScalarExpression::Alias { + expr: Box::new(test_column("c1", 0)), + alias: AliasType::Name("v".to_string()), + }; + let mut expr = ScalarExpression::Alias { + expr: Box::new(test_column("c1", 0)), + alias: AliasType::Name("v".to_string()), + }; + + ProjectionOutputBinder::new(std::slice::from_ref(&project_output)).visit(&mut expr)?; + + assert_eq!( + expr, + ScalarExpression::column_expr(project_output.output_column(), 0) + ); + Ok(()) + } + + fn find_join(plan: &LogicalPlan) -> Option<(&JoinType, &JoinCondition)> { + if let Operator::Join(op) = &plan.operator { + return Some((&op.join_type, &op.on)); + } + + match plan.childrens.as_ref() { + Childrens::Only(child) => find_join(child), + Childrens::Twins { left, right } => find_join(left).or_else(|| find_join(right)), + Childrens::None => None, + } + } + + #[test] + fn test_scalar_subquery_in_where_binds_as_inner_join() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states.plan("select * from 
t1 where c1 = (select max(c3) from t2)")?; + let Some((join_type, join_condition)) = find_join(&plan) else { + panic!("expected scalar subquery to introduce a join") + }; + + assert_eq!(*join_type, JoinType::Inner); + assert!(matches!(join_condition, JoinCondition::On { .. })); + + Ok(()) + } + + fn find_top_join(plan: &LogicalPlan) -> Option<&LogicalPlan> { + if matches!(plan.operator, Operator::Join(_)) { + return Some(plan); + } + + match plan.childrens.as_ref() { + Childrens::Only(child) => find_top_join(child), + Childrens::Twins { .. } | Childrens::None => None, + } + } + + fn collect_column_positions(expr: &ScalarExpression, positions: &mut Vec) { + match expr.unpack_alias_ref() { + ScalarExpression::ColumnRef { position, .. } => positions.push(*position), + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => { + collect_column_positions(left_expr, positions); + collect_column_positions(right_expr, positions); + } + _ => {} + } + } + + #[test] + fn test_multiple_scalar_subqueries_in_where_rebind_positions() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = + table_states.plan("select * from t1 where c1 <= (select 4) and c1 > (select 1)")?; + let outer_join = + find_top_join(&plan).expect("expected scalar subqueries to introduce a join"); + let Operator::Join(op) = &outer_join.operator else { + panic!("expected join plan") + }; + let Childrens::Twins { left, .. } = outer_join.childrens.as_ref() else { + panic!("expected binary join") + }; + let JoinCondition::On { + filter: Some(filter), + .. 
+ } = &op.on + else { + panic!("expected join filter") + }; + let left_len = left.output_schema_direct().columns().count(); + + let mut positions = Vec::new(); + collect_column_positions(filter, &mut positions); + + assert_eq!(positions, vec![0, left_len - 1, 0, left_len]); + + Ok(()) + } } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 13a6db70..510f915c 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -66,6 +66,10 @@ pub struct ColumnSummary { } impl ColumnRef { + pub(crate) fn same_column(&self, other: &ColumnRef) -> bool { + self.summary() == other.summary() + } + pub(crate) fn nullable_for_join(&self, nullable: bool) -> Option { if self.nullable == nullable { return None; @@ -201,6 +205,39 @@ impl ColumnCatalog { } } +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::errors::DatabaseError; + use crate::types::LogicalType; + + #[test] + fn test_same_column_ignores_nullable_and_desc() -> Result<(), DatabaseError> { + let mut left = ColumnCatalog::new( + "c1".to_string(), + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + ); + let mut right = ColumnCatalog::new( + "c1".to_string(), + true, + ColumnDesc::new(LogicalType::Bigint, None, false, None)?, + ); + let left_ref = ColumnRef::from(left.clone()); + let right_ref = ColumnRef::from(right.clone()); + + assert_ne!(left_ref, right_ref); + assert!(left_ref.same_column(&right_ref)); + + left.set_name("c2".to_string()); + right.set_name("c3".to_string()); + let left_ref = ColumnRef::from(left); + let right_ref = ColumnRef::from(right); + assert!(!left_ref.same_column(&right_ref)); + Ok(()) + } +} + /// The descriptor of a column. 
#[derive(Debug, Clone, PartialEq, Eq, Hash, ReferenceSerialization)] pub struct ColumnDesc { diff --git a/src/db.rs b/src/db.rs index 8d2ad291..ccd26380 100644 --- a/src/db.rs +++ b/src/db.rs @@ -14,7 +14,7 @@ use crate::binder::{command_type, Binder, BinderContext, CommandType}; use crate::errors::DatabaseError; -use crate::execution::{build_write, Executor}; +use crate::execution::{build_write, ExecArena, Executor}; use crate::expression::function::scala::ScalarFunctionImpl; use crate::expression::function::table::TableFunctionImpl; use crate::expression::function::FunctionSummary; @@ -32,8 +32,10 @@ use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::parser::parse_sql; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; +#[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] +use crate::storage::lmdb::{LmdbConfig, LmdbStorage}; use crate::storage::memory::MemoryStorage; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] use crate::storage::rocksdb::{OptimisticRocksStorage, RocksStorage, StorageConfig}; use crate::storage::{StatisticsMetaCache, Storage, TableCache, Transaction, ViewCache}; use crate::types::tuple::{SchemaRef, Tuple}; @@ -115,8 +117,10 @@ pub struct DataBaseBuilder { scala_functions: ScalaFunctions, table_functions: TableFunctions, histogram_buckets: Option, - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] storage_config: StorageConfig, + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + lmdb_config: LmdbConfig, } impl DataBaseBuilder { @@ -130,8 +134,10 @@ impl DataBaseBuilder { scala_functions: Default::default(), table_functions: Default::default(), histogram_buckets: None, - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] storage_config: Default::default(), + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + lmdb_config: 
Default::default(), }; builder = builder.register_scala_function(CharLength::new("char_length".to_lowercase())); builder = @@ -168,9 +174,56 @@ impl DataBaseBuilder { } /// Enables or disables RocksDB statistics collection. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all( + not(target_arch = "wasm32"), + any(feature = "rocksdb", feature = "lmdb") + ))] pub fn storage_statistics(mut self, enable: bool) -> Self { - self.storage_config.enable_statistics = enable; + #[cfg(feature = "rocksdb")] + { + self.storage_config.enable_statistics = enable; + } + #[cfg(feature = "lmdb")] + { + self.lmdb_config.enable_statistics = enable; + } + self + } + + /// Sets the LMDB map size in bytes. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn lmdb_map_size(mut self, map_size: usize) -> Self { + self.lmdb_config.map_size = map_size; + self + } + + /// Sets the LMDB environment flags. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn lmdb_flags(mut self, flags: lmdb::EnvironmentFlags) -> Self { + self.lmdb_config.flags = flags; + self + } + + /// Enables or disables LMDB `NO_SYNC`. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn lmdb_no_sync(mut self, enable: bool) -> Self { + self.lmdb_config + .flags + .set(lmdb::EnvironmentFlags::NO_SYNC, enable); + self + } + + /// Sets the maximum number of LMDB readers. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn lmdb_max_readers(mut self, max_readers: u32) -> Self { + self.lmdb_config.max_readers = Some(max_readers); + self + } + + /// Sets the maximum number of LMDB named databases. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn lmdb_max_dbs(mut self, max_dbs: u32) -> Self { + self.lmdb_config.max_dbs = Some(max_dbs); self } @@ -198,8 +251,8 @@ impl DataBaseBuilder { } /// Builds a RocksDB-backed database. 
- #[cfg(not(target_arch = "wasm32"))] - pub fn build(self) -> Result, DatabaseError> { + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] + pub fn build_rocksdb(self) -> Result, DatabaseError> { let storage = RocksStorage::with_config(self.path, self.storage_config)?; Self::_build::( @@ -224,9 +277,22 @@ impl DataBaseBuilder { ) } - #[cfg(not(target_arch = "wasm32"))] + /// Builds a LMDB-backed database. + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + pub fn build_lmdb(self) -> Result, DatabaseError> { + let storage = LmdbStorage::with_config(self.path, self.lmdb_config)?; + + Self::_build::( + storage, + self.scala_functions, + self.table_functions, + self.histogram_buckets, + ) + } + + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] /// Builds a RocksDB-backed database that uses optimistic transactions. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] pub fn build_optimistic(self) -> Result, DatabaseError> { let storage = OptimisticRocksStorage::with_config(self.path, self.storage_config)?; @@ -280,10 +346,12 @@ fn default_optimizer_pipeline() -> HepOptimizerPipeline { .before_batch( "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![ - NormalizationRuleImpl::SimplifyFilter, - NormalizationRuleImpl::ConstantCalculation, - ], + vec![NormalizationRuleImpl::SimplifyFilter], + ) + .before_batch( + "Constant Calculation".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::ConstantCalculation], ) .before_batch( "Predicate Pushdown".to_string(), @@ -320,21 +388,10 @@ fn default_optimizer_pipeline() -> HepOptimizerPipeline { NormalizationRuleImpl::CombineFilter, ], ) - .after_batch( - "Eliminate Aggregate".to_string(), - HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::EliminateRedundantSort, - NormalizationRuleImpl::UseStreamDistinct, - ], - ) .after_batch( "Expression Remapper".to_string(), 
HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::BindExpressionPosition, - NormalizationRuleImpl::EvaluatorBind, - ], + vec![NormalizationRuleImpl::EvaluatorBind], ) .implementations(vec![ // DQL @@ -345,6 +402,7 @@ fn default_optimizer_pipeline() -> HepOptimizerPipeline { ImplementationRuleImpl::HashJoin, ImplementationRuleImpl::Limit, ImplementationRuleImpl::Projection, + ImplementationRuleImpl::ScalarSubquery, ImplementationRuleImpl::SeqScan, ImplementationRuleImpl::IndexScan, ImplementationRuleImpl::FunctionScan, @@ -443,12 +501,15 @@ impl State { Ok(best_plan) } - fn execute<'a, A: AsRef<[(&'static str, DataValue)]>>( + fn execute<'a, 'txn, A: AsRef<[(&'static str, DataValue)]>>( &'a self, - transaction: &'a mut S::TransactionType<'_>, + transaction: &'a mut S::TransactionType<'txn>, stmt: &Statement, params: A, - ) -> Result<(SchemaRef, Executor<'a>), DatabaseError> { + ) -> Result<(SchemaRef, Executor<'a, S::TransactionType<'txn>>), DatabaseError> + where + S: 'txn, + { let mut plan = self.build_plan( stmt, params, @@ -460,11 +521,14 @@ impl State { self.table_functions(), )?; let schema = plan.output_schema().clone(); - let executor = build_write( + let mut arena = ExecArena::default(); + let root = build_write( + &mut arena, plan, (&self.table_cache, &self.view_cache, &self.meta_cache), transaction, ); + let executor = Executor::new(arena, root); Ok((schema, executor)) } @@ -611,14 +675,38 @@ impl Database { } } -/// Common interface for result iterators returned by database execution APIs. -/// -/// A result iterator streams [`Tuple`] values and exposes the output schema of -/// the current statement. -pub trait ResultIter: Iterator> { +/// Borrowing interface for result iterators returned by database execution APIs. +pub trait BorrowResultIter { /// Returns the output schema for the current result set. fn schema(&self) -> &SchemaRef; + /// Returns the next row as a borrowed tuple. 
+ fn next_borrowed_tuple(&mut self) -> Result, DatabaseError>; + + /// Creates a mapped iterator that transforms borrowed tuples into owned output values. + fn map_result(self, mapper: F) -> MappedResultIter + where + Self: Sized, + F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, + { + let schema = self.schema().clone(); + MappedResultIter { + inner: self, + mapper, + schema, + _marker: PhantomData, + } + } + + /// Finishes consuming the iterator and flushes any remaining work. + fn done(self) -> Result<(), DatabaseError>; +} + +/// Common interface for owned-tuple result iterators. +/// +/// This remains for compatibility with existing callers that expect +/// `Iterator>`. +pub trait ResultIter: BorrowResultIter + Iterator> { #[cfg(feature = "orm")] /// Converts this iterator into a typed ORM iterator. /// @@ -632,9 +720,46 @@ pub trait ResultIter: Iterator> { { OrmIter::new(self) } +} - /// Finishes consuming the iterator and flushes any remaining work. - fn done(self) -> Result<(), DatabaseError>; +impl ResultIter for I where I: BorrowResultIter + Iterator> {} + +/// Typed adapter over a borrowing result iterator. +pub struct MappedResultIter { + inner: I, + mapper: F, + schema: SchemaRef, + _marker: PhantomData, +} + +impl MappedResultIter +where + I: BorrowResultIter, + F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, +{ + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + pub fn done(self) -> Result<(), DatabaseError> { + self.inner.done() + } +} + +impl Iterator for MappedResultIter +where + I: BorrowResultIter, + F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match self.inner.next_borrowed_tuple() { + Ok(Some(tuple)) => Some((self.mapper)(&self.schema, tuple)), + Ok(None) => None, + Err(err) => Some(Err(err)), + } + } } #[cfg(feature = "orm")] @@ -690,7 +815,7 @@ where /// Raw result iterator returned by [`Database::run`] and [`Database::execute`]. 
pub struct DatabaseIter<'a, S: Storage + 'a> { transaction: *mut S::TransactionType<'a>, - inner: *mut TransactionIter<'a>, + inner: *mut TransactionIter<'a, S::TransactionType<'a>>, _guard: Option, } @@ -705,6 +830,33 @@ impl Drop for DatabaseIter<'_, S> { } } +impl DatabaseIter<'_, S> { + #[inline] + pub fn schema(&self) -> &SchemaRef { + unsafe { (*self.inner).schema() } + } + + #[inline] + pub fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { + let result = unsafe { (*self.inner).next_borrowed_tuple() }; + if result.as_ref().is_ok_and(Option::is_none) { + self._guard = None; + } + result + } + + #[inline] + pub fn done(mut self) -> Result<(), DatabaseError> { + unsafe { + Box::from_raw(mem::replace(&mut self.inner, std::ptr::null_mut())).done()?; + } + unsafe { + Box::from_raw(mem::replace(&mut self.transaction, std::ptr::null_mut())).commit()?; + } + Ok(()) + } +} + impl Iterator for DatabaseIter<'_, S> { type Item = Result; @@ -717,19 +869,17 @@ impl Iterator for DatabaseIter<'_, S> { } } -impl ResultIter for DatabaseIter<'_, S> { +impl BorrowResultIter for DatabaseIter<'_, S> { fn schema(&self) -> &SchemaRef { - unsafe { (*self.inner).schema() } + DatabaseIter::schema(self) } - fn done(mut self) -> Result<(), DatabaseError> { - unsafe { - Box::from_raw(mem::replace(&mut self.inner, std::ptr::null_mut())).done()?; - } - unsafe { - Box::from_raw(mem::replace(&mut self.transaction, std::ptr::null_mut())).commit()?; - } - Ok(()) + fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { + DatabaseIter::next_borrowed_tuple(self) + } + + fn done(self) -> Result<(), DatabaseError> { + DatabaseIter::done(self) } } @@ -740,9 +890,12 @@ pub struct DBTransaction<'a, S: Storage + 'a> { state: Arc>, } -impl DBTransaction<'_, S> { +impl<'txn, S: Storage> DBTransaction<'txn, S> { /// Runs SQL inside the current transaction and returns the final result iterator. 
- pub fn run>(&mut self, sql: T) -> Result, DatabaseError> { + pub fn run<'a, T: AsRef>( + &'a mut self, + sql: T, + ) -> Result>, DatabaseError> { let sql = sql.as_ref(); let mut statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; let last_statement = statements @@ -761,11 +914,11 @@ impl DBTransaction<'_, S> { } /// Executes a prepared [`Statement`] inside the current transaction. - pub fn execute>( - &mut self, + pub fn execute<'a, A: AsRef<[(&'static str, DataValue)]>>( + &'a mut self, statement: &Statement, params: A, - ) -> Result, DatabaseError> { + ) -> Result>, DatabaseError> { if matches!(command_type(statement)?, CommandType::DDL) { return Err(DatabaseError::UnsupportedStmt( "`DDL` is not allowed to execute within a transaction".to_string(), @@ -784,47 +937,73 @@ impl DBTransaction<'_, S> { } /// Raw result iterator returned by [`DBTransaction::run`] and [`DBTransaction::execute`]. -pub struct TransactionIter<'a> { - executor: Executor<'a>, +pub struct TransactionIter<'a, T: Transaction + 'a> { + executor: Executor<'a, T>, schema: SchemaRef, } -impl<'a> TransactionIter<'a> { - fn new(schema: SchemaRef, executor: Executor<'a>) -> Self { +impl<'a, T: Transaction + 'a> TransactionIter<'a, T> { + fn new(schema: SchemaRef, executor: Executor<'a, T>) -> Self { Self { executor, schema } } + + #[inline] + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + #[inline] + pub fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { + self.executor.next_tuple() + } + + #[inline] + pub fn done(mut self) -> Result<(), DatabaseError> { + while self.next_borrowed_tuple()?.is_some() {} + Ok(()) + } } -impl Iterator for TransactionIter<'_> { +impl Iterator for TransactionIter<'_, T> { type Item = Result; fn next(&mut self) -> Option { - self.executor.next() + match self.executor.next_tuple() { + Ok(Some(tuple)) => Some(Ok(tuple.clone())), + Ok(None) => None, + Err(err) => Some(Err(err)), + } } } -impl ResultIter for TransactionIter<'_> { 
+impl BorrowResultIter for TransactionIter<'_, T> { fn schema(&self) -> &SchemaRef { - &self.schema + TransactionIter::schema(self) } - fn done(mut self) -> Result<(), DatabaseError> { - for result in self.by_ref() { - let _ = result?; - } - Ok(()) + fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { + TransactionIter::next_borrowed_tuple(self) + } + + fn done(self) -> Result<(), DatabaseError> { + TransactionIter::done(self) } } #[cfg(all(test, not(target_arch = "wasm32")))] pub(crate) mod test { + use crate::binder::{Binder, BinderContext}; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; - use crate::db::{DataBaseBuilder, DatabaseError, ResultIter}; + use crate::db::{DataBaseBuilder, DatabaseError}; + use crate::expression::ScalarExpression; + use crate::planner::operator::join::JoinCondition; + use crate::planner::operator::Operator; use crate::storage::{Storage, TableCache, Transaction}; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; use chrono::{Datelike, Local}; + use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tempfile::TempDir; @@ -857,7 +1036,7 @@ pub(crate) mod test { #[test] fn test_run_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let database = DataBaseBuilder::path(temp_dir.path()).build()?; + let database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; let mut transaction = database.storage.transaction()?; build_table(database.state.table_cache(), &mut transaction)?; @@ -873,7 +1052,7 @@ pub(crate) mod test { #[test] fn test_udf() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; let mut iter = kite_sql.run("select current_date()")?; assert_eq!( @@ -900,7 +1079,7 @@ 
pub(crate) mod test { #[test] fn test_udtf() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; let mut iter = kite_sql.run( "SELECT * FROM (select * from table(numbers(10)) a ORDER BY number LIMIT 5) OFFSET 3", )?; @@ -925,10 +1104,306 @@ pub(crate) mod test { Ok(()) } + #[test] + fn test_join_on_alias_right_key_is_localized() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + + kite_sql + .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? + .done()?; + kite_sql + .run("CREATE TABLE empty (e_id INT PRIMARY KEY, x INT)")? + .done()?; + + let stmt = crate::db::prepare( + "SELECT * FROM onecolumn AS a(aid, x) JOIN empty AS b(bid, y) ON a.x = b.y", + )?; + let transaction = kite_sql.storage.transaction()?; + let mut binder = Binder::new( + BinderContext::new( + kite_sql.state.table_cache(), + kite_sql.state.view_cache(), + &transaction, + kite_sql.state.scala_functions(), + kite_sql.state.table_functions(), + Arc::new(AtomicUsize::new(0)), + ), + &[], + None, + ); + let source_plan = binder.bind(&stmt)?; + let best_plan = kite_sql.state.build_plan( + &stmt, + [], + kite_sql.state.table_cache(), + kite_sql.state.view_cache(), + kite_sql.state.meta_cache(), + &transaction, + kite_sql.state.scala_functions(), + kite_sql.state.table_functions(), + )?; + + let join_plan = match source_plan.operator { + Operator::Project(_) => source_plan.childrens.pop_only(), + Operator::Join(_) => source_plan, + _ => unreachable!("expected a join plan"), + }; + let Operator::Join(join_op) = join_plan.operator else { + unreachable!("expected join operator"); + }; + let JoinCondition::On { on, filter } = join_op.on else { + 
unreachable!("expected join condition"); + }; + assert!(filter.is_none()); + assert_eq!(on.len(), 1); + let ScalarExpression::ColumnRef { + position: left_position, + .. + } = on[0].0.unpack_alias_ref() + else { + unreachable!("expected left join key column ref"); + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. + } = on[0].1.unpack_alias_ref() + else { + unreachable!("expected right join key column ref"); + }; + assert_eq!(*left_position, 1); + assert_eq!(*right_position, 1); + + let join_plan = match best_plan.operator { + Operator::Project(_) => best_plan.childrens.pop_only(), + Operator::Join(_) => best_plan, + _ => unreachable!("expected a join plan"), + }; + let Operator::Join(join_op) = join_plan.operator else { + unreachable!("expected join operator"); + }; + let JoinCondition::On { on, .. } = join_op.on else { + unreachable!("expected join condition"); + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. + } = on[0].1.unpack_alias_ref() + else { + unreachable!("expected right join key column ref"); + }; + assert_eq!(*right_position, 1); + + Ok(()) + } + + #[test] + fn test_join_on_with_right_filter_keeps_localized_key() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + + kite_sql + .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? + .done()?; + kite_sql + .run("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")? 
+ .done()?; + + let stmt = crate::db::prepare( + "SELECT o.x, t.y FROM onecolumn o INNER JOIN twocolumn t ON (o.x=t.x AND t.y=53)", + )?; + let transaction = kite_sql.storage.transaction()?; + let mut binder = Binder::new( + BinderContext::new( + kite_sql.state.table_cache(), + kite_sql.state.view_cache(), + &transaction, + kite_sql.state.scala_functions(), + kite_sql.state.table_functions(), + Arc::new(AtomicUsize::new(0)), + ), + &[], + None, + ); + let source_plan = binder.bind(&stmt)?; + let best_plan = kite_sql.state.build_plan( + &stmt, + [], + kite_sql.state.table_cache(), + kite_sql.state.view_cache(), + kite_sql.state.meta_cache(), + &transaction, + kite_sql.state.scala_functions(), + kite_sql.state.table_functions(), + )?; + + let join_plan = match source_plan.operator { + Operator::Project(_) => source_plan.childrens.pop_only(), + Operator::Join(_) => source_plan, + _ => unreachable!("expected a join plan"), + }; + let Operator::Join(join_op) = join_plan.operator else { + unreachable!("expected join operator"); + }; + let JoinCondition::On { on, filter } = join_op.on else { + unreachable!("expected join condition"); + }; + assert_eq!(on.len(), 1); + let ScalarExpression::ColumnRef { + position: left_position, + .. + } = on[0].0.unpack_alias_ref() + else { + unreachable!("expected left join key column ref"); + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. 
+ } = on[0].1.unpack_alias_ref() + else { + unreachable!("expected right join key column ref"); + }; + assert_eq!(*left_position, 1); + assert_eq!(*right_position, 1); + let Some(filter) = filter else { + unreachable!("expected join filter"); + }; + let mut referenced_columns = Vec::new(); + filter.visit_referenced_columns(true, &mut |column| { + referenced_columns.push(column.clone()); + true + }); + assert_eq!(referenced_columns.len(), 1); + assert_eq!(referenced_columns[0].name(), "y"); + + let join_plan = match best_plan.operator { + Operator::Project(_) => best_plan.childrens.pop_only(), + Operator::Join(_) => best_plan, + _ => unreachable!("expected a join plan"), + }; + let Operator::Join(join_op) = join_plan.operator else { + unreachable!("expected join operator"); + }; + let JoinCondition::On { on, filter } = join_op.on else { + unreachable!("expected join condition"); + }; + assert_eq!(on.len(), 1); + assert!(filter.is_none()); + let ScalarExpression::ColumnRef { + position: left_position, + .. + } = on[0].0.unpack_alias_ref() + else { + unreachable!("expected left join key column ref"); + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. + } = on[0].1.unpack_alias_ref() + else { + unreachable!("expected right join key column ref"); + }; + assert_eq!(*left_position, 0); + assert_eq!(*right_position, 0); + + Ok(()) + } + + #[test] + fn test_join_on_with_right_filter_keeps_localized_key_with_data() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + + kite_sql + .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? + .done()?; + kite_sql + .run("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")? + .done()?; + kite_sql + .run("INSERT INTO onecolumn(id, x) VALUES (0, 44), (1, NULL), (2, 42)")? 
+ .done()?; + kite_sql + .run( + "INSERT INTO twocolumn(t_id, x, y) VALUES (0,44,51), (1,NULL,52), (2,42,53), (3,45,45)", + )? + .done()?; + + let stmt = crate::db::prepare( + "SELECT o.x, t.y FROM onecolumn o INNER JOIN twocolumn t ON (o.x=t.x AND t.y=53)", + )?; + let transaction = kite_sql.storage.transaction()?; + let best_plan = kite_sql.state.build_plan( + &stmt, + [], + kite_sql.state.table_cache(), + kite_sql.state.view_cache(), + kite_sql.state.meta_cache(), + &transaction, + kite_sql.state.scala_functions(), + kite_sql.state.table_functions(), + )?; + let join_plan = match best_plan.operator { + Operator::Project(_) => best_plan.childrens.pop_only(), + Operator::Join(_) => best_plan, + _ => unreachable!("expected a join plan"), + }; + let Operator::Join(join_op) = join_plan.operator else { + unreachable!("expected join operator"); + }; + let JoinCondition::On { on, filter } = join_op.on else { + unreachable!("expected join condition"); + }; + assert_eq!(on.len(), 1); + assert!(filter.is_none()); + let ScalarExpression::ColumnRef { + position: left_position, + .. + } = on[0].0.unpack_alias_ref() + else { + unreachable!("expected left join key column ref"); + }; + let ScalarExpression::ColumnRef { + position: right_position, + .. + } = on[0].1.unpack_alias_ref() + else { + unreachable!("expected right join key column ref"); + }; + assert_eq!(*left_position, 0); + assert_eq!(*right_position, 0); + let (_, right_child) = join_plan.childrens.pop_twins(); + let Operator::Filter(filter_op) = right_child.operator else { + unreachable!("expected pushed-down filter on right child"); + }; + let ScalarExpression::Binary { + left_expr, + right_expr, + .. + } = filter_op.predicate + else { + unreachable!("expected binary filter predicate"); + }; + let ScalarExpression::ColumnRef { + position: filter_position, + .. 
+ } = left_expr.unpack_alias_ref() + else { + unreachable!("expected filter column ref"); + }; + assert_eq!(*filter_position, 1); + assert!(matches!( + *right_expr, + ScalarExpression::Constant(DataValue::Int32(53)) + )); + + Ok(()) + } + #[test] fn test_prepare_statment() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -1004,7 +1479,7 @@ pub(crate) mod test { #[test] fn test_run_multi_statement() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t_multi (a int primary key, b int)")? @@ -1030,7 +1505,7 @@ pub(crate) mod test { #[test] fn test_run_multi_statement_disallow_ddl() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; let err = match kite_sql.run("create table t_multi_ddl (a int primary key); select 1") { Ok(_) => panic!("multi-statement execution with DDL should be rejected"), @@ -1049,7 +1524,7 @@ pub(crate) mod test { #[test] fn test_bind_error_with_span() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t_bind_span(id int primary key)")? 
@@ -1059,7 +1534,7 @@ pub(crate) mod test { Ok(_) => panic!("expected bind error"), Err(err) => err, }; - println!("{}", err); + println!("{err}"); match err { DatabaseError::ColumnNotFound { span, .. } @@ -1082,7 +1557,7 @@ pub(crate) mod test { #[test] fn test_bind_function_error_with_span() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t_bind_fn_span(id int primary key)")? @@ -1092,7 +1567,7 @@ pub(crate) mod test { Ok(_) => panic!("expected function bind error"), Err(err) => err, }; - println!("{}", err); + println!("{err}"); match err { DatabaseError::FunctionNotFound { span, .. } => { @@ -1114,7 +1589,7 @@ pub(crate) mod test { #[test] fn test_transaction_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -1161,7 +1636,7 @@ pub(crate) mod test { #[test] fn test_transaction_run_multi_statement() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t_multi_tx (a int primary key, b int)")? 
@@ -1298,7 +1773,7 @@ pub(crate) mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let result = DataBaseBuilder::path(temp_dir.path()) .histogram_buckets(0) - .build(); + .build_rocksdb(); assert!(matches!( result, diff --git a/src/errors.rs b/src/errors.rs index 2d6f2cda..83ea114d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -211,7 +211,7 @@ pub enum DatabaseError { PrimaryKeyNotFound, #[error("primaryKey only allows single or multiple values")] PrimaryKeyTooManyLayers, - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] #[error("rocksdb: {0}")] RocksDB( #[source] @@ -430,7 +430,7 @@ fn build_sql_highlight(sql: &str, span: &SqlErrorSpan) -> Option { for (i, line) in lines.iter().enumerate() { let line_no = i + 1; - out.push_str(&format!("{line_no:>width$} | {line}\n", width = width)); + out.push_str(&format!("{line_no:>width$} | {line}\n")); if line_no == span.line { let char_len = line.chars().count(); diff --git a/src/execution/ddl/add_column.rs b/src/execution/ddl/add_column.rs index 625cf063..0d40bc7f 100644 --- a/src/execution/ddl/add_column.rs +++ b/src/execution/ddl/add_column.rs @@ -14,129 +14,124 @@ use super::rewrite_table_in_batches; use crate::errors::DatabaseError; -use crate::execution::{spawn_executor, Executor, WriteExecutor}; -use crate::storage::{StatisticsMetaCache, TableCache, ViewCache}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::planner::operator::alter_table::add_column::AddColumnOperator; +use crate::storage::Transaction; use crate::types::index::{Index, IndexType}; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; -use crate::{ - planner::operator::alter_table::add_column::AddColumnOperator, storage::Transaction, throw, -}; use itertools::Itertools; pub struct AddColumn { - op: AddColumnOperator, + op: Option, } impl From for AddColumn { fn from(op: 
AddColumnOperator) -> Self { - Self { op } + Self { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for AddColumn { - fn execute_mut( + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let AddColumnOperator { - table_name, - column, - if_not_exists, - } = &self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::AddColumn(self)) + } +} - let table_catalog = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) - .cloned() - .ok_or(DatabaseError::TableNotFound) - ); - if table_catalog.get_column_by_name(column.name()).is_some() { - if *if_not_exists { - co.yield_(Ok(TupleBuilder::build_result("1".to_string()))) - .await; - return; - } - co.yield_(Err(DatabaseError::DuplicateColumn( - column.name().to_string(), - ))) - .await; - return; +impl AddColumn { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let table_cache = arena.table_cache(); + let Some(AddColumnOperator { + table_name, + column, + if_not_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; + + let table_catalog = arena + .transaction_mut() + .table(table_cache, table_name.clone())? 
+ .cloned() + .ok_or(DatabaseError::TableNotFound)?; + if table_catalog.get_column_by_name(column.name()).is_some() { + if if_not_exists { + TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); + arena.resume(); + return Ok(()); } + return Err(DatabaseError::DuplicateColumn(column.name().to_string())); + } - let schema = table_catalog.schema_ref().clone(); - let old_deserializers = schema - .iter() - .map(|column_ref| column_ref.datatype().serializable()) - .collect_vec(); - let serializers = schema - .iter() - .map(|column_ref| column_ref.datatype().serializable()) - .chain(::std::iter::once(column.datatype().serializable())) - .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); - let default_value = throw!(co, column.default_value()); + let schema = table_catalog.schema_ref().clone(); + let old_deserializers = schema + .iter() + .map(|column_ref| column_ref.datatype().serializable()) + .collect_vec(); + let serializers = schema + .iter() + .map(|column_ref| column_ref.datatype().serializable()) + .chain(::std::iter::once(column.datatype().serializable())) + .collect_vec(); + let pk_ty = table_catalog.primary_keys_type().clone(); + let default_value = column.default_value()?; - let col_id = throw!( - co, - unsafe { &mut (*transaction) }.add_column( - cache.0, - table_name, - column, - *if_not_exists - ) - ); - let unique_meta = if column.desc().is_unique() { - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) + let col_id = + arena + .transaction_mut() + .add_column(table_cache, &table_name, &column, if_not_exists)?; + let unique_meta = if column.desc().is_unique() { + arena + .transaction_mut() + .table(table_cache, table_name.clone())? 
.and_then(|table| table.get_unique_index(&col_id)) .cloned() - } else { - None - }; - let default_for_index = default_value.clone(); + } else { + None + }; + let default_for_index = default_value.clone(); - throw!( - co, - rewrite_table_in_batches( - unsafe { &mut (*transaction) }, - table_name, - &pk_ty, - &old_deserializers, - schema.len(), - schema.len(), - &serializers, - |mut tuple| { - if let Some(value) = &default_value { - tuple.values.push(value.clone()); - } else { - tuple.values.push(DataValue::Null); - } - Ok(tuple) - }, - |transaction, tuple| { - if let (Some(unique_meta), Some(value), Some(tuple_id)) = ( - unique_meta.as_ref(), - default_for_index.as_ref(), - tuple.pk.as_ref(), - ) { - let index = Index::new(unique_meta.id, value, IndexType::Unique); - transaction.add_index(table_name, index, tuple_id)?; - } - Ok(()) - }, - ) - ); + rewrite_table_in_batches( + arena.transaction_mut(), + &table_name, + &pk_ty, + &old_deserializers, + schema.len(), + schema.len(), + &serializers, + |tuple| { + if let Some(value) = &default_value { + tuple.values.push(value.clone()); + } else { + tuple.values.push(DataValue::Null); + } + Ok(()) + }, + |transaction, tuple| { + if let (Some(unique_meta), Some(value), Some(tuple_id)) = ( + unique_meta.as_ref(), + default_for_index.as_ref(), + tuple.pk.as_ref(), + ) { + let index = Index::new(unique_meta.id, value, IndexType::Unique); + transaction.add_index(&table_name, index, tuple_id)?; + } + Ok(()) + }, + )?; - co.yield_(Ok(TupleBuilder::build_result("1".to_string()))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/change_column.rs b/src/execution/ddl/change_column.rs index 339553af..bc5301eb 100644 --- a/src/execution/ddl/change_column.rs +++ b/src/execution/ddl/change_column.rs @@ -14,163 +14,150 @@ use super::{rewrite_table_in_batches, visit_table_in_batches}; use crate::errors::DatabaseError; -use 
crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::alter_table::change_column::{ChangeColumnOperator, NotNullChange}; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; use itertools::Itertools; pub struct ChangeColumn { - op: ChangeColumnOperator, + op: Option, } impl From for ChangeColumn { fn from(op: ChangeColumnOperator) -> Self { - Self { op } + Self { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for ChangeColumn { - fn execute_mut( + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let ChangeColumnOperator { - table_name, - old_column_name, - new_column_name, - data_type, - default_change, - not_null_change, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::ChangeColumn(self)) + } +} + +impl ChangeColumn { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let table_cache = arena.table_cache(); + let Some(ChangeColumnOperator { + table_name, + old_column_name, + new_column_name, + data_type, + default_change, + not_null_change, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - let table_catalog = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) - .cloned() - .ok_or(DatabaseError::TableNotFound) - ); - let schema = table_catalog.schema_ref().clone(); - let (column_index, old_column) = throw!( - co, - schema - .iter() - .enumerate() - .find(|(_, column)| column.name() == old_column_name) - .map(|(index, column)| (index, 
column.clone())) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.clone())) - ); - let needs_data_rewrite = old_column.datatype() != &data_type; - let needs_not_null_validation = matches!(not_null_change, NotNullChange::Set); + let table_catalog = arena + .transaction_mut() + .table(table_cache, table_name.clone())? + .cloned() + .ok_or(DatabaseError::TableNotFound)?; + let schema = table_catalog.schema_ref().clone(); + let (column_index, old_column) = schema + .iter() + .enumerate() + .find(|(_, column)| column.name() == old_column_name) + .map(|(index, column)| (index, column.clone())) + .ok_or_else(|| DatabaseError::column_not_found(old_column_name.clone()))?; + let needs_data_rewrite = old_column.datatype() != &data_type; + let needs_not_null_validation = matches!(not_null_change, NotNullChange::Set); - if needs_data_rewrite { - let Some(column_id) = old_column.id() else { - co.yield_(Err(DatabaseError::column_not_found( - old_column_name.clone(), - ))) - .await; - return; - }; - let affected_index = table_catalog - .indexes() - .find(|index_meta| index_meta.column_ids.contains(&column_id)); - if let Some(index_meta) = affected_index { - co.yield_(Err(DatabaseError::UnsupportedStmt(format!( - "cannot alter type of indexed column `{}`; drop index `{}` first", - old_column_name, index_meta.name - )))) - .await; - return; - } + if needs_data_rewrite { + let Some(column_id) = old_column.id() else { + return Err(DatabaseError::column_not_found(old_column_name.clone())); + }; + let affected_index = table_catalog + .indexes() + .find(|index_meta| index_meta.column_ids.contains(&column_id)); + if let Some(index_meta) = affected_index { + return Err(DatabaseError::UnsupportedStmt(format!( + "cannot alter type of indexed column `{}`; drop index `{}` first", + old_column_name, index_meta.name + ))); } + } - let old_deserializers = schema + let old_deserializers = schema + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let 
pk_ty = table_catalog.primary_keys_type().clone(); + + if needs_data_rewrite { + let serializers = schema .iter() - .map(|column| column.datatype().serializable()) + .enumerate() + .map(|(index, column)| { + if index == column_index { + data_type.serializable() + } else { + column.datatype().serializable() + } + }) .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); - - if needs_data_rewrite { - let serializers = schema - .iter() - .enumerate() - .map(|(index, column)| { - if index == column_index { - data_type.serializable() - } else { - column.datatype().serializable() - } - }) - .collect_vec(); - let target_column_name = new_column_name.clone(); - let target_data_type = data_type.clone(); - throw!( - co, - rewrite_table_in_batches( - unsafe { &mut (*transaction) }, - &table_name, - &pk_ty, - &old_deserializers, - schema.len(), - schema.len(), - &serializers, - |mut tuple| { - tuple.values[column_index] = - tuple.values[column_index].clone().cast(&target_data_type)?; - if needs_not_null_validation && tuple.values[column_index].is_null() { - return Err(DatabaseError::not_null_column( - target_column_name.clone(), - )); - } - Ok(tuple) - }, - |_, _| Ok(()), - ) - ); - } else if needs_not_null_validation { - let target_column_name = new_column_name.clone(); - throw!( - co, - visit_table_in_batches( - unsafe { &*transaction }, - &table_name, - &pk_ty, - &old_deserializers, - schema.len(), - schema.len(), - |tuple| { - if tuple.values[column_index].is_null() { - return Err(DatabaseError::not_null_column( - target_column_name.clone(), - )); - } - Ok(()) - }, - ) - ); - } + let target_column_name = new_column_name.clone(); + let target_data_type = data_type.clone(); + rewrite_table_in_batches( + arena.transaction_mut(), + &table_name, + &pk_ty, + &old_deserializers, + schema.len(), + schema.len(), + &serializers, + |tuple| { + tuple.values[column_index] = + tuple.values[column_index].clone().cast(&target_data_type)?; + if needs_not_null_validation 
&& tuple.values[column_index].is_null() { + return Err(DatabaseError::not_null_column(target_column_name.clone())); + } + Ok(()) + }, + |_, _| Ok(()), + )?; + } else if needs_not_null_validation { + let target_column_name = new_column_name.clone(); + visit_table_in_batches( + arena.transaction(), + &table_name, + &pk_ty, + &old_deserializers, + schema.len(), + schema.len(), + |tuple| { + if tuple.values[column_index].is_null() { + return Err(DatabaseError::not_null_column(target_column_name.clone())); + } + Ok(()) + }, + )?; + } - throw!( - co, - unsafe { &mut (*transaction) }.change_column( - cache.0, - &table_name, - &old_column_name, - &new_column_name, - &data_type, - &default_change, - ¬_null_change, - ) - ); + arena.transaction_mut().change_column( + table_cache, + &table_name, + &old_column_name, + &new_column_name, + &data_type, + &default_change, + ¬_null_change, + )?; - co.yield_(Ok(TupleBuilder::build_result(format!("{table_name}")))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/create_index.rs b/src/execution/ddl/create_index.rs index ad21246c..8d8cc2e2 100644 --- a/src/execution/ddl/create_index.rs +++ b/src/execution/ddl/create_index.rs @@ -12,105 +12,128 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::DatabaseError; -use crate::execution::{build_read, spawn_executor, Executor, WriteExecutor}; -use crate::expression::{BindPosition, ScalarExpression}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::expression::ScalarExpression; use crate::planner::operator::create_index::CreateIndexOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::index::Index; -use crate::types::tuple::Tuple; +use crate::types::tuple::SchemaRef; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; use crate::types::ColumnId; -use std::borrow::Cow; pub struct CreateIndex { - op: CreateIndexOperator, - input: LogicalPlan, + op: Option, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, } impl From<(CreateIndexOperator, LogicalPlan)> for CreateIndex { - fn from((op, input): (CreateIndexOperator, LogicalPlan)) -> Self { - Self { op, input } + fn from((op, mut input): (CreateIndexOperator, LogicalPlan)) -> Self { + Self { + op: Some(op), + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: 0, + } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateIndex { - fn execute_mut( + fn into_executor( mut self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let CreateIndexOperator { - table_name, - index_name, - columns, - if_not_exists, - ty, - } = self.op; + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("create index input plan initialized"), + cache, + transaction, + ); + 
arena.push(ExecNode::CreateIndex(self)) + } +} + +impl CreateIndex { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let table_cache = arena.table_cache(); - let (column_ids, mut column_exprs): (Vec, Vec) = columns - .into_iter() - .filter_map(|column| { - column - .id() - .map(|id| (id, ScalarExpression::column_expr(column))) + let Some(CreateIndexOperator { + table_name, + index_name, + columns, + if_not_exists, + ty, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; + + let (column_ids, column_exprs): (Vec, Vec) = columns + .into_iter() + .filter_map(|column| { + column.id().and_then(|id| { + self.input_schema + .iter() + .position(|schema_column| schema_column == &column) + .map(|position| (id, ScalarExpression::column_expr(column, position))) }) - .unzip(); - let schema = self.input.output_schema().clone(); - throw!( - co, - BindPosition::bind_exprs( - column_exprs.iter_mut(), - || schema.iter().map(Cow::Borrowed), - |a, b| a == b - ) - ); - let index_id = match unsafe { &mut (*transaction) }.add_index_meta( - cache.0, - &table_name, - index_name, - column_ids, - ty, - ) { - Ok(index_id) => index_id, - Err(DatabaseError::DuplicateIndex(index_name)) => { - if if_not_exists { - return; - } else { - throw!(co, Err(DatabaseError::DuplicateIndex(index_name))) - } + }) + .unzip(); + let index_id = match arena.transaction_mut().add_index_meta( + table_cache, + &table_name, + index_name, + column_ids, + ty, + ) { + Ok(index_id) => index_id, + Err(DatabaseError::DuplicateIndex(index_name)) => { + if if_not_exists { + arena.finish(); + return Ok(()); + } else { + return Err(DatabaseError::DuplicateIndex(index_name)); } - err => throw!(co, err), - }; - let mut coroutine = build_read(self.input, cache, transaction); - - for tuple in coroutine.by_ref() { - let tuple: Tuple = throw!(co, tuple); + } + Err(err) => return Err(err), + }; - let Some(value) = 
DataValue::values_to_tuple(throw!( - co, - Projection::projection(&tuple, &column_exprs, &schema) - )) else { + while arena.next_tuple(self.input)? { + let (value, tuple_pk) = { + let tuple = arena.result_tuple(); + let Some(value) = DataValue::values_to_tuple(Projection::projection( + tuple, + &column_exprs, + &self.input_schema, + )?) else { continue; }; - let tuple_id = if let Some(tuple_id) = tuple.pk.as_ref() { - tuple_id - } else { + let Some(tuple_pk) = tuple.pk.clone() else { continue; }; - let index = Index::new(index_id, &value, ty); - throw!( - co, - unsafe { &mut (*transaction) }.add_index(table_name.as_ref(), index, tuple_id) - ); - } - co.yield_(Ok(TupleBuilder::build_result("1".to_string()))) - .await; - }) + (value, tuple_pk) + }; + let index = Index::new(index_id, &value, ty); + arena + .transaction_mut() + .add_index(table_name.as_ref(), index, &tuple_pk)?; + } + + TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/create_table.rs b/src/execution/ddl/create_table.rs index 9b8695ac..8ee40b20 100644 --- a/src/execution/ddl/create_table.rs +++ b/src/execution/ddl/create_table.rs @@ -12,47 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::create_table::CreateTableOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct CreateTable { - op: CreateTableOperator, + op: Option, } impl From for CreateTable { fn from(op: CreateTableOperator) -> Self { - CreateTable { op } + CreateTable { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateTable { - fn execute_mut( + fn into_executor( self, - (table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let CreateTableOperator { - table_name, - columns, - if_not_exists, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::CreateTable(self)) + } +} + +impl CreateTable { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(CreateTableOperator { + table_name, + columns, + if_not_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - let _ = throw!( - co, - unsafe { &mut (*transaction) }.create_table( - table_cache, - table_name.clone(), - columns, - if_not_exists - ) - ); + arena.transaction_mut().create_table( + arena.table_cache(), + table_name.clone(), + columns, + if_not_exists, + )?; - co.yield_(Ok(TupleBuilder::build_result(format!("{table_name}")))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/create_view.rs b/src/execution/ddl/create_view.rs index ec1cf0e4..a199667b 100644 
--- a/src/execution/ddl/create_view.rs +++ b/src/execution/ddl/create_view.rs @@ -12,38 +12,49 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::create_view::CreateViewOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct CreateView { - op: CreateViewOperator, + op: Option, } impl From for CreateView { fn from(op: CreateViewOperator) -> Self { - CreateView { op } + CreateView { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateView { - fn execute_mut( + fn into_executor( self, - (_, view_cache, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let CreateViewOperator { view, or_replace } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::CreateView(self)) + } +} - let result_tuple = TupleBuilder::build_result(format!("{}", view.name)); - throw!( - co, - unsafe { &mut (*transaction) }.create_view(view_cache, view, or_replace) - ); +impl CreateView { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(CreateViewOperator { view, or_replace }) = self.op.take() else { + arena.finish(); + return Ok(()); + }; + let view_name = view.name.to_string(); + arena + .transaction_mut() + .create_view(arena.view_cache(), view, or_replace)?; - co.yield_(Ok(result_tuple)).await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), view_name); + arena.resume(); + Ok(()) } } diff 
--git a/src/execution/ddl/drop_column.rs b/src/execution/ddl/drop_column.rs index f0eabdfd..42a45762 100644 --- a/src/execution/ddl/drop_column.rs +++ b/src/execution/ddl/drop_column.rs @@ -14,104 +14,107 @@ use super::rewrite_table_in_batches; use crate::errors::DatabaseError; -use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::alter_table::drop_column::DropColumnOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; use itertools::Itertools; pub struct DropColumn { - op: DropColumnOperator, + op: Option, } impl From for DropColumn { fn from(op: DropColumnOperator) -> Self { - Self { op } + Self { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn { - fn execute_mut( + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let DropColumnOperator { - table_name, - column_name, - if_exists, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::DropColumn(self)) + } +} + +impl DropColumn { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let table_cache = arena.table_cache(); + let meta_cache = arena.meta_cache(); + let Some(DropColumnOperator { + table_name, + column_name, + if_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - let table_catalog = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) - .cloned() - .ok_or(DatabaseError::TableNotFound) - ); - let tuple_columns = table_catalog.schema_ref().clone(); - if let 
Some((column_index, is_primary)) = tuple_columns + let table_catalog = arena + .transaction_mut() + .table(table_cache, table_name.clone())? + .cloned() + .ok_or(DatabaseError::TableNotFound)?; + let tuple_columns = table_catalog.schema_ref().clone(); + if let Some((column_index, is_primary)) = tuple_columns + .iter() + .enumerate() + .find(|(_, column)| column.name() == column_name) + .map(|(i, column)| (i, column.desc().is_primary())) + { + if is_primary { + return Err(DatabaseError::invalid_column( + "drop of primary key column is not allowed.".to_owned(), + )); + } + let old_deserializers = tuple_columns + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let serializers = tuple_columns .iter() .enumerate() - .find(|(_, column)| column.name() == column_name) - .map(|(i, column)| (i, column.desc().is_primary())) - { - if is_primary { - throw!( - co, - Err(DatabaseError::invalid_column( - "drop of primary key column is not allowed.".to_owned(), - )) - ); - } - let old_deserializers = tuple_columns - .iter() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let serializers = tuple_columns - .iter() - .enumerate() - .filter(|(i, _)| *i != column_index) - .map(|(_, column)| column.datatype().serializable()) - .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); - throw!( - co, - rewrite_table_in_batches( - unsafe { &mut (*transaction) }, - &table_name, - &pk_ty, - &old_deserializers, - tuple_columns.len(), - tuple_columns.len(), - &serializers, - |mut tuple| { - let _ = tuple.values.remove(column_index); - Ok(tuple) - }, - |_, _| Ok(()), - ) - ); - throw!( - co, - unsafe { &mut (*transaction) }.drop_column( - cache.0, - cache.2, - &table_name, - &column_name - ) - ); + .filter(|(i, _)| *i != column_index) + .map(|(_, column)| column.datatype().serializable()) + .collect_vec(); + let pk_ty = table_catalog.primary_keys_type().clone(); + rewrite_table_in_batches( + arena.transaction_mut(), + 
&table_name, + &pk_ty, + &old_deserializers, + tuple_columns.len(), + tuple_columns.len(), + &serializers, + |tuple| { + let _ = tuple.values.remove(column_index); + Ok(()) + }, + |_, _| Ok(()), + )?; + arena.transaction_mut().drop_column( + table_cache, + meta_cache, + &table_name, + &column_name, + )?; - co.yield_(Ok(TupleBuilder::build_result("1".to_string()))) - .await; - } else if !if_exists { - co.yield_(Err(DatabaseError::column_not_found(column_name))) - .await; - } - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); + arena.resume(); + Ok(()) + } else if !if_exists { + Err(DatabaseError::column_not_found(column_name)) + } else { + arena.finish(); + Ok(()) + } } } diff --git a/src/execution/ddl/drop_index.rs b/src/execution/ddl/drop_index.rs index 1253d70f..04b31f8c 100644 --- a/src/execution/ddl/drop_index.rs +++ b/src/execution/ddl/drop_index.rs @@ -12,48 +12,58 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::drop_index::DropIndexOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct DropIndex { - op: DropIndexOperator, + op: Option, } impl From for DropIndex { fn from(op: DropIndexOperator) -> Self { - Self { op } + Self { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropIndex { - fn execute_mut( + fn into_executor( self, - (table_cache, _, meta_cache): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let DropIndexOperator { - table_name, - index_name, - if_exists, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::DropIndex(self)) + } +} + +impl DropIndex { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(DropIndexOperator { + table_name, + index_name, + if_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - throw!( - co, - unsafe { &mut (*transaction) }.drop_index( - table_cache, - meta_cache, - table_name, - &index_name, - if_exists - ) - ); + arena.transaction_mut().drop_index( + arena.table_cache(), + arena.meta_cache(), + table_name, + &index_name, + if_exists, + )?; - co.yield_(Ok(TupleBuilder::build_result(index_name.to_string()))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), index_name.to_string()); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/drop_table.rs b/src/execution/ddl/drop_table.rs index 36330258..4597d94a 100644 --- 
a/src/execution/ddl/drop_table.rs +++ b/src/execution/ddl/drop_table.rs @@ -12,45 +12,53 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::drop_table::DropTableOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct DropTable { - op: DropTableOperator, + op: Option, } impl From for DropTable { fn from(op: DropTableOperator) -> Self { - DropTable { op } + DropTable { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropTable { - fn execute_mut( + fn into_executor( self, - (table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let DropTableOperator { - table_name, - if_exists, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::DropTable(self)) + } +} + +impl DropTable { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(DropTableOperator { + table_name, + if_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - throw!( - co, - unsafe { &mut (*transaction) }.drop_table( - table_cache, - table_name.clone(), - if_exists - ) - ); + arena + .transaction_mut() + .drop_table(arena.table_cache(), table_name.clone(), if_exists)?; - co.yield_(Ok(TupleBuilder::build_result(format!("{table_name}")))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); + arena.resume(); + Ok(()) } } diff --git 
a/src/execution/ddl/drop_view.rs b/src/execution/ddl/drop_view.rs index 280715d3..f2c771c1 100644 --- a/src/execution/ddl/drop_view.rs +++ b/src/execution/ddl/drop_view.rs @@ -12,46 +12,55 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::drop_view::DropViewOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct DropView { - op: DropViewOperator, + op: Option, } impl From for DropView { fn from(op: DropViewOperator) -> Self { - DropView { op } + DropView { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropView { - fn execute_mut( + fn into_executor( self, - (table_cache, view_cache, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let DropViewOperator { - view_name, - if_exists, - } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::DropView(self)) + } +} + +impl DropView { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(DropViewOperator { + view_name, + if_exists, + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; - throw!( - co, - unsafe { &mut (*transaction) }.drop_view( - view_cache, - table_cache, - view_name.clone(), - if_exists - ) - ); + let table_cache = arena.table_cache(); + let view_cache = arena.view_cache(); + arena + .transaction_mut() + .drop_view(view_cache, table_cache, view_name.clone(), if_exists)?; - 
co.yield_(Ok(TupleBuilder::build_result(format!("{view_name}")))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{view_name}")); + arena.resume(); + Ok(()) } } diff --git a/src/execution/ddl/mod.rs b/src/execution/ddl/mod.rs index 8609a647..9f209b01 100644 --- a/src/execution/ddl/mod.rs +++ b/src/execution/ddl/mod.rs @@ -34,6 +34,7 @@ use std::collections::Bound; const REWRITE_BATCH_SIZE: usize = 1024; +#[allow(clippy::too_many_arguments)] fn read_tuple_batch( transaction: &T, table_name: &TableName, @@ -44,33 +45,53 @@ fn read_tuple_batch( start_after: Option<&TupleId>, batch: &mut Vec, batch_size: usize, -) -> Result<(), DatabaseError> { +) -> Result { let table_codec = unsafe { &*transaction.table_codec() }; let lower = if let Some(last_pk) = start_after { - Bound::Excluded(table_codec.encode_tuple_key(table_name.as_ref(), last_pk)?) + table_codec.with_tuple_key_unchecked(table_name.as_ref(), last_pk, |key| { + Ok::<_, DatabaseError>(Bound::Excluded(key.to_vec())) + })? } else { - let (min, _) = table_codec.tuple_bound(table_name.as_ref()); - Bound::Included(min) + Bound::Unbounded }; - let (_, max) = table_codec.tuple_bound(table_name.as_ref()); - let mut iter = transaction.range(lower, Bound::Included(max))?; - batch.clear(); - while batch.len() < batch_size { - let Some((key, value)) = iter.try_next()? 
else { - break; + table_codec.with_tuple_bound(table_name.as_ref(), |min, max| { + let lower = match &lower { + Bound::Included(bytes) => Bound::Included(bytes.as_slice()), + Bound::Excluded(bytes) => Bound::Excluded(bytes.as_slice()), + Bound::Unbounded => Bound::Included(min), }; - let tuple_id = TableCodec::decode_tuple_key(&key, pk_ty)?; - batch.push(TableCodec::decode_tuple( - old_deserializers, - Some(tuple_id), - &value, - old_values_len, - old_total_len, - )?); - } + let mut iter = transaction.range(lower, Bound::Included(max))?; + let mut len = 0; + + while len < batch_size { + let Some((key, value)) = iter.try_next()? else { + break; + }; + let tuple_id = TableCodec::decode_tuple_key(key, pk_ty)?; + let tuple = if len < batch.len() { + &mut batch[len] + } else { + batch.push(Tuple { + pk: None, + values: Vec::with_capacity(old_values_len), + }); + batch + .last_mut() + .expect("batch contains the tuple that was just pushed") + }; + TableCodec::decode_tuple_into( + tuple, + old_deserializers, + Some(tuple_id), + value, + old_total_len, + )?; + len += 1; + } - Ok(()) + Ok(len) + }) } pub(crate) fn visit_table_in_batches( @@ -84,13 +105,13 @@ pub(crate) fn visit_table_in_batches( ) -> Result<(), DatabaseError> where T: Transaction, - F: FnMut(Tuple) -> Result<(), DatabaseError>, + F: FnMut(&Tuple) -> Result<(), DatabaseError>, { let mut last_pk = None; let mut batch = Vec::with_capacity(REWRITE_BATCH_SIZE); loop { - read_tuple_batch( + let batch_len = read_tuple_batch( transaction, table_name, pk_ty, @@ -101,13 +122,12 @@ where &mut batch, REWRITE_BATCH_SIZE, )?; - let batch_len = batch.len(); if batch_len == 0 { break; } - last_pk = batch.last().and_then(|tuple| tuple.pk.clone()); + last_pk = batch.get(batch_len - 1).and_then(|tuple| tuple.pk.clone()); - for tuple in batch.drain(..) 
{ + for tuple in batch.iter().take(batch_len) { visit(tuple)?; } @@ -119,6 +139,7 @@ where Ok(()) } +#[allow(clippy::too_many_arguments)] pub(crate) fn rewrite_table_in_batches( transaction: &mut T, table_name: &TableName, @@ -132,14 +153,14 @@ pub(crate) fn rewrite_table_in_batches( ) -> Result<(), DatabaseError> where T: Transaction, - F: FnMut(Tuple) -> Result, + F: FnMut(&mut Tuple) -> Result<(), DatabaseError>, G: FnMut(&mut T, &Tuple) -> Result<(), DatabaseError>, { let mut last_pk = None; let mut batch = Vec::with_capacity(REWRITE_BATCH_SIZE); loop { - read_tuple_batch( + let batch_len = read_tuple_batch( transaction, table_name, pk_ty, @@ -150,16 +171,15 @@ where &mut batch, REWRITE_BATCH_SIZE, )?; - let batch_len = batch.len(); if batch_len == 0 { break; } - last_pk = batch.last().and_then(|tuple| tuple.pk.clone()); + last_pk = batch.get(batch_len - 1).and_then(|tuple| tuple.pk.clone()); - for tuple in batch.drain(..) { - let tuple = rewrite(tuple)?; + for tuple in batch.iter_mut().take(batch_len) { + rewrite(tuple)?; transaction.append_tuple(table_name.as_ref(), tuple.clone(), new_serializers, true)?; - after_write(transaction, &tuple)?; + after_write(transaction, tuple)?; } if batch_len < REWRITE_BATCH_SIZE { diff --git a/src/execution/ddl/truncate.rs b/src/execution/ddl/truncate.rs index 29b7e61f..ac5ac4d3 100644 --- a/src/execution/ddl/truncate.rs +++ b/src/execution/ddl/truncate.rs @@ -12,35 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::truncate::TruncateOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct Truncate { - op: TruncateOperator, + op: Option, } impl From for Truncate { fn from(op: TruncateOperator) -> Self { - Truncate { op } + Truncate { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Truncate { - fn execute_mut( + fn into_executor( self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let TruncateOperator { table_name } = self.op; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::Truncate(self)) + } +} - throw!(co, unsafe { &mut (*transaction) }.drop_data(&table_name)); +impl Truncate { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(TruncateOperator { table_name }) = self.op.take() else { + arena.finish(); + return Ok(()); + }; + arena.transaction_mut().drop_data(&table_name)?; - co.yield_(Ok(TupleBuilder::build_result(format!("{table_name}")))) - .await; - }) + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); + arena.resume(); + Ok(()) } } diff --git a/src/execution/dml/analyze.rs b/src/execution/dml/analyze.rs index b79b8fa6..2f55d67b 100644 --- a/src/execution/dml/analyze.rs +++ b/src/execution/dml/analyze.rs @@ -15,20 +15,18 @@ use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, spawn_executor, Executor, 
WriteExecutor}; -use crate::expression::{BindPosition, ScalarExpression}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::expression::ScalarExpression; use crate::optimizer::core::histogram::HistogramBuilder; use crate::optimizer::core::statistics_meta::StatisticsMeta; use crate::planner::operator::analyze::AnalyzeOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::{StatisticsMetaCache, Transaction}; use crate::types::index::IndexId; -use crate::types::tuple::Tuple; +use crate::types::tuple::SchemaRef; use crate::types::value::{DataValue, Utf8Type}; use itertools::Itertools; use sqlparser::ast::CharLengthUnits; -use std::borrow::Cow; use std::fmt::{self, Formatter}; use std::sync::Arc; @@ -36,7 +34,9 @@ const DEFAULT_NUM_OF_BUCKETS: usize = 100; pub struct Analyze { table_name: TableName, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: Option, histogram_buckets: Option, } @@ -48,101 +48,93 @@ impl From<(AnalyzeOperator, LogicalPlan)> for Analyze { index_metas, histogram_buckets, }, - input, + mut input, ): (AnalyzeOperator, LogicalPlan), ) -> Self { let _ = index_metas; Analyze { table_name, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: None, histogram_buckets, } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Analyze { - fn execute_mut( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Analyze { - table_name, - mut input, - histogram_buckets, - } = self; - - let schema = input.output_schema().clone(); - let mut builders = Vec::new(); - let table = throw!( - co, - throw!( - co, - unsafe { &mut 
(*transaction) }.table(cache.0, table_name.clone()) - ) - .cloned() - .ok_or(DatabaseError::TableNotFound) - ); - - for index in table.indexes() { - builders.push(State { - is_bound_position: false, - index_id: index.id, - exprs: throw!(co, index.column_exprs(&table)), - builder: HistogramBuilder::new(index, None), - histogram_buckets, - }); - } + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("analyze input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::Analyze(self)) + } +} + +impl Analyze { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; + + let mut builders = Vec::new(); + let table = arena + .transaction_mut() + .table(arena.table_cache(), self.table_name.clone())? + .cloned() + .ok_or(DatabaseError::TableNotFound)?; + + for index in table.indexes() { + builders.push(State { + index_id: index.id, + exprs: index.column_exprs(&table)?, + builder: HistogramBuilder::new(index, None), + histogram_buckets: self.histogram_buckets, + }); + } + + while arena.next_tuple(input)? { + let tuple = arena.result_tuple(); + for State { exprs, builder, .. } in builders.iter_mut() { + let values = Projection::projection(tuple, exprs, &self.input_schema)?; - let mut coroutine = build_read(input, cache, transaction); - - for tuple in coroutine.by_ref() { - let tuple = throw!(co, tuple); - - for State { - is_bound_position, - exprs, - builder, - .. 
- } in builders.iter_mut() - { - if !*is_bound_position { - throw!( - co, - BindPosition::bind_exprs( - exprs.iter_mut(), - || schema.iter().map(Cow::Borrowed), - |a, b| a == b - ) - ); - *is_bound_position = true; - } - let values = throw!(co, Projection::projection(&tuple, exprs, &schema)); - - if values.len() == 1 { - throw!(co, builder.append(&values[0])); - } else { - throw!( - co, - builder.append(&Arc::new(DataValue::Tuple(values, false))) - ); - } + if values.len() == 1 { + builder.append(&values[0])?; + } else { + builder.append(&Arc::new(DataValue::Tuple(values, false)))?; } } - drop(coroutine); - let values = throw!( - co, - Self::persist_statistics_meta(&table_name, builders, cache.2, transaction) - ); - - co.yield_(Ok(Tuple::new(None, values))).await; - }) + } + let values = Self::persist_statistics_meta( + &self.table_name, + builders, + arena.meta_cache(), + arena.transaction_mut(), + )?; + + let output = arena.result_tuple_mut(); + output.pk = None; + output.values = values; + arena.resume(); + Ok(()) } } struct State { - is_bound_position: bool, index_id: IndexId, exprs: Vec, builder: HistogramBuilder, @@ -150,11 +142,11 @@ struct State { } impl Analyze { - fn persist_statistics_meta( + fn persist_statistics_meta( table_name: &TableName, builders: Vec, cache: &StatisticsMetaCache, - transaction: *mut T, + transaction: &mut U, ) -> Result, DatabaseError> { let mut values = Vec::with_capacity(builders.len()); @@ -167,9 +159,9 @@ impl Analyze { { let (histogram, sketch) = builder.build(histogram_buckets.unwrap_or(DEFAULT_NUM_OF_BUCKETS))?; - let meta = StatisticsMeta::new(histogram); + let meta = StatisticsMeta::new(histogram, sketch); - unsafe { &mut (*transaction) }.save_statistics_meta(cache, table_name, meta, sketch)?; + transaction.save_statistics_meta(cache, table_name, meta)?; values.push(DataValue::Utf8 { value: format!("{table_name}/{index_id}"), ty: Utf8Type::Variable(None), @@ -193,7 +185,7 @@ impl fmt::Display for AnalyzeOperator { 
#[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::errors::DatabaseError; use crate::execution::dml::analyze::DEFAULT_NUM_OF_BUCKETS; use crate::expression::range_detacher::Range; @@ -218,7 +210,7 @@ mod test { let buckets = 10; let kite_sql = DataBaseBuilder::path(temp_dir.path()) .histogram_buckets(buckets) - .build()?; + .build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -257,7 +249,7 @@ mod test { fn test_meta_loader_uses_cache() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -282,13 +274,15 @@ mod test { drop(transaction); let mut transaction = kite_sql.storage.transaction()?; - let (min, max) = unsafe { &*transaction.table_codec() }.statistics_index_bound("t1", 1); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; - let mut keys: Vec> = Vec::new(); - while let Some((key, _)) = iter.try_next()? { - keys.push(key); - } - drop(iter); + let keys: Vec> = unsafe { &*transaction.table_codec() } + .with_statistics_index_bound("t1", 1, |min, max| { + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut keys = Vec::new(); + while let Some((key, _)) = iter.try_next()? 
{ + keys.push(key.to_vec()); + } + Ok(keys) + })?; for key in keys { transaction.remove(&key)?; } @@ -306,7 +300,7 @@ mod test { fn test_meta_loader_negative_cache() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -330,7 +324,7 @@ mod test { fn test_clean_expired_index() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int)")? @@ -346,34 +340,38 @@ mod test { kite_sql.run("analyze table t1")?.done()?; let transaction = kite_sql.storage.transaction()?; - let (min, max) = unsafe { &*transaction.table_codec() }.statistics_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; - let mut count = 0; - while iter.try_next()?.is_some() { - count += 1; - } + let count = + unsafe { &*transaction.table_codec() }.with_statistics_bound("t1", |min, max| { + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut count = 0; + while iter.try_next()?.is_some() { + count += 1; + } + Ok(count) + })?; assert!(count > 3); - drop(iter); kite_sql.run("alter table t1 drop column b")?.done()?; kite_sql.run("analyze table t1")?.done()?; let transaction = kite_sql.storage.transaction()?; - let (min, max) = unsafe { &*transaction.table_codec() }.statistics_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; - let mut keys = 0; - while iter.try_next()?.is_some() { - keys += 1; - } + let keys = + unsafe { &*transaction.table_codec() 
}.with_statistics_bound("t1", |min, max| { + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut keys = 0; + while iter.try_next()?.is_some() { + keys += 1; + } + Ok(keys) + })?; let table_name = "t1".to_string().into(); let loader = transaction.meta_loader(kite_sql.state.meta_cache()); let statistics_meta = loader.load(&table_name, 0)?.unwrap(); - let statistics_sketch = transaction - .statistics_sketch(table_name.as_ref(), 0)? - .unwrap(); let expected_keys = 1 + 1 - + statistics_sketch.storage_page_count(COUNT_MIN_SKETCH_STORAGE_PAGE_LEN) + + statistics_meta + .sketch() + .storage_page_count(COUNT_MIN_SKETCH_STORAGE_PAGE_LEN) + statistics_meta.histogram().buckets_len(); assert_eq!(keys, expected_keys); diff --git a/src/execution/dml/copy_from_file.rs b/src/execution/dml/copy_from_file.rs index e213826e..d81f0f53 100644 --- a/src/execution/dml/copy_from_file.rs +++ b/src/execution/dml/copy_from_file.rs @@ -15,91 +15,109 @@ use crate::binder::copy::FileFormat; use crate::catalog::PrimaryKeyIndices; use crate::errors::DatabaseError; -use crate::execution::{spawn_executor, Executor, WriteExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::copy_from_file::CopyFromFileOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; use itertools::Itertools; use std::fs::File; use std::io::BufReader; -use std::sync::mpsc; -use std::sync::mpsc::Sender; -use std::thread; pub struct CopyFromFile { - op: CopyFromFileOperator, - size: usize, + op: Option, } impl From for CopyFromFile { fn from(op: CopyFromFileOperator) -> Self { - CopyFromFile { op, size: 0 } + CopyFromFile { op: Some(op) } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CopyFromFile { - fn execute_mut( + fn into_executor( self, - 
(table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let serializers = self - .op - .schema_ref - .iter() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let (tx, rx) = mpsc::channel(); - let (tx1, rx1) = mpsc::channel(); - let table = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(table_cache, self.op.table.clone()) - ) - .ok_or(DatabaseError::TableNotFound) - ); - let primary_keys_indices = table.primary_keys_indices().clone(); - let handle = thread::spawn(|| self.read_file_blocking(tx, primary_keys_indices)); - let mut size = 0_usize; - while let Ok(chunk) = rx.recv() { - throw!( - co, - unsafe { &mut (*transaction) }.append_tuple( - table.name(), - chunk, - &serializers, - false - ) - ); - size += 1; - } - throw!(co, handle.join().unwrap()); - - let handle = thread::spawn(move || return_result(size, tx1)); - while let Ok(chunk) = rx1.recv() { - co.yield_(Ok(chunk)).await; - } - throw!(co, handle.join().unwrap()) - }) + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::CopyFromFile(self)) } } impl CopyFromFile { - /// Read records from file using blocking IO. - /// - /// The read data chunks will be sent through `tx`. + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(op) = self.op.take() else { + arena.finish(); + return Ok(()); + }; + let serializers = op + .schema_ref + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let table = arena + .transaction_mut() + .table(arena.table_cache(), op.table.clone())? 
+ .ok_or(DatabaseError::TableNotFound)?; + let primary_keys_indices = table.primary_keys_indices().clone(); + + let file = File::open(op.source.path)?; + let mut buf_reader = BufReader::new(file); + let mut reader = match op.source.format { + FileFormat::Csv { + delimiter, + quote, + escape, + header, + } => csv::ReaderBuilder::new() + .delimiter(delimiter as u8) + .quote(quote as u8) + .escape(escape.map(|c| c as u8)) + .has_headers(header) + .from_reader(&mut buf_reader), + }; + + let column_count = op.schema_ref.len(); + let tuple_builder = TupleBuilder::new(&op.schema_ref, Some(&primary_keys_indices)); + let mut size = 0_usize; + + for record in reader.records() { + let record = record?; + + if !(record.len() == column_count + || record.len() == column_count + 1 && record.get(column_count) == Some("")) + { + return Err(DatabaseError::MisMatch("columns", "values")); + } + + let chunk = tuple_builder.build_with_row(record.iter())?; + arena + .transaction_mut() + .append_tuple(table.name(), chunk, &serializers, false)?; + size += 1; + } + + TupleBuilder::build_result_into(arena.result_tuple_mut(), size.to_string()); + arena.resume(); + Ok(()) + } + + #[allow(dead_code)] fn read_file_blocking( mut self, - tx: Sender, + tx: std::sync::mpsc::Sender, pk_indices: PrimaryKeyIndices, ) -> Result<(), DatabaseError> { - let file = File::open(self.op.source.path)?; + let Some(op) = self.op.take() else { + return Ok(()); + }; + let file = File::open(op.source.path)?; let mut buf_reader = BufReader::new(file); - let mut reader = match self.op.source.format { + let mut reader = match op.source.format { FileFormat::Csv { delimiter, quote, @@ -113,11 +131,10 @@ impl CopyFromFile { .from_reader(&mut buf_reader), }; - let column_count = self.op.schema_ref.len(); - let tuple_builder = TupleBuilder::new(&self.op.schema_ref, Some(&pk_indices)); + let column_count = op.schema_ref.len(); + let tuple_builder = TupleBuilder::new(&op.schema_ref, Some(&pk_indices)); for record in 
reader.records() { - // read records and push raw str rows into data chunk builder let record = record?; if !(record.len() == column_count @@ -126,7 +143,6 @@ impl CopyFromFile { return Err(DatabaseError::MisMatch("columns", "values")); } - self.size += 1; tx.send(tuple_builder.build_with_row(record.iter())?) .map_err(|_| DatabaseError::ChannelClose)?; } @@ -134,19 +150,12 @@ impl CopyFromFile { } } -fn return_result(size: usize, tx: Sender) -> Result<(), DatabaseError> { - let tuple = TupleBuilder::build_result(size.to_string()); - - tx.send(tuple).map_err(|_| DatabaseError::ChannelClose)?; - Ok(()) -} - #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::*; use crate::binder::copy::ExtSource; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::errors::DatabaseError; use crate::storage::Storage; use crate::types::LogicalType; @@ -223,19 +232,16 @@ mod tests { }, schema_ref: Arc::new(columns), }; - let executor = CopyFromFile { - op: op.clone(), - size: 0, - }; - let temp_dir = TempDir::new().unwrap(); - let db = DataBaseBuilder::path(temp_dir.path()).build()?; + let tmp_dir = TempDir::new()?; + let db = DataBaseBuilder::path(tmp_dir.path()).build_rocksdb()?; db.run("create table test_copy (a int primary key, b float, c varchar(10))")? 
.done()?; + let storage = db.storage; let mut transaction = storage.transaction()?; - - let mut executor_iter = executor.execute_mut( + let mut executor = crate::execution::execute_mut( + CopyFromFile::from(op), ( db.state.table_cache(), db.state.view_cache(), @@ -243,11 +249,9 @@ mod tests { ), &mut transaction, ); - let tuple = executor_iter - .next() - .expect("executor should yield once") - .unwrap(); - assert_eq!(tuple, TupleBuilder::build_result(2.to_string())); + + let result = executor.next().expect("copy from file should yield once")?; + assert_eq!(result.values[0].to_string(), "2"); Ok(()) } diff --git a/src/execution/dml/copy_to_file.rs b/src/execution/dml/copy_to_file.rs index 87e55be3..b1dfdbe4 100644 --- a/src/execution/dml/copy_to_file.rs +++ b/src/execution/dml/copy_to_file.rs @@ -14,63 +14,75 @@ use crate::binder::copy::FileFormat; use crate::errors::DatabaseError; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::copy_to_file::CopyToFileOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; pub struct CopyToFile { op: CopyToFileOperator, - input: LogicalPlan, + input_plan: Option, + input: Option, } impl From<(CopyToFileOperator, LogicalPlan)> for CopyToFile { fn from((op, input): (CopyToFileOperator, LogicalPlan)) -> Self { - CopyToFile { op, input } + CopyToFile { + op, + input_plan: Some(input), + input: None, + } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for CopyToFile { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move 
|co| async move { - let this = self; - let mut writer = throw!(co, this.create_writer()); - let CopyToFile { input, op } = this; - - let coroutine = build_read(input, cache, transaction); - - for tuple in coroutine { - let tuple = throw!(co, tuple); - - throw!( - co, - writer - .write_record( - tuple - .values - .iter() - .map(|v| v.to_string()) - .collect::>() - ) - .map_err(DatabaseError::from) - ); - } - - throw!(co, writer.flush().map_err(DatabaseError::from)); - - co.yield_(Ok(TupleBuilder::build_result(format!("{op}")))) - .await; - }) + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("copy to file input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::CopyToFile(self)) } } impl CopyToFile { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; + + let mut writer = self.create_writer()?; + while arena.next_tuple(input)? 
{ + let tuple = arena.result_tuple(); + writer.write_record( + tuple + .values + .iter() + .map(|v| v.to_string()) + .collect::>(), + )?; + } + writer.flush().map_err(DatabaseError::from)?; + + TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{}", self.op)); + arena.resume(); + Ok(()) + } + fn create_writer(&self) -> Result, DatabaseError> { let mut writer = match self.op.target.format { FileFormat::Csv { @@ -104,7 +116,7 @@ mod tests { use super::*; use crate::binder::copy::ExtSource; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::errors::DatabaseError; use crate::planner::operator::table_scan::TableScanOperator; use crate::storage::Storage; @@ -180,7 +192,7 @@ mod tests { }; let temp_dir = TempDir::new().unwrap(); - let db = DataBaseBuilder::path(temp_dir.path()).build()?; + let db = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; db.run("create table t1 (a int primary key, b float, c varchar(10))")? 
.done()?; db.run("insert into t1 values (1, 1.1, 'foo')")?.done()?; @@ -195,9 +207,15 @@ mod tests { let executor = CopyToFile { op: op.clone(), - input: TableScanOperator::build("t1".to_string().into(), table, true)?, + input_plan: Some(TableScanOperator::build( + "t1".to_string().into(), + table, + true, + )?), + input: None, }; - let mut executor = executor.execute( + let mut executor = crate::execution::execute( + executor, ( db.state.table_cache(), db.state.view_cache(), @@ -222,10 +240,7 @@ mod tests { let record3 = records.next().unwrap()?; assert_eq!(record3, vec!["3", "2.1", "Kite"]); - assert!(records.next().is_none()); - - assert_eq!(tuple, TupleBuilder::build_result(format!("{op}"))); - + assert_eq!(tuple.values[0].to_string(), format!("{op}")); Ok(()) } } diff --git a/src/execution/dml/delete.rs b/src/execution/dml/delete.rs index 35fd7e4a..f8da2325 100644 --- a/src/execution/dml/delete.rs +++ b/src/execution/dml/delete.rs @@ -15,128 +15,133 @@ use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, spawn_executor, Executor, WriteExecutor}; -use crate::expression::{BindPosition, ScalarExpression}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::expression::ScalarExpression; use crate::planner::operator::delete::DeleteOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::index::{Index, IndexId, IndexType}; -use crate::types::tuple::Tuple; +use crate::types::tuple::SchemaRef; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; -use std::borrow::Cow; use std::collections::HashMap; pub struct Delete { table_name: TableName, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: Option, } impl 
From<(DeleteOperator, LogicalPlan)> for Delete { - fn from((DeleteOperator { table_name, .. }, input): (DeleteOperator, LogicalPlan)) -> Self { - Delete { table_name, input } + fn from((DeleteOperator { table_name, .. }, mut input): (DeleteOperator, LogicalPlan)) -> Self { + Delete { + table_name, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: None, + } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Delete { - fn execute_mut( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Delete { - table_name, - mut input, - } = self; + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("delete input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::Delete(self)) + } +} - let schema = input.output_schema().clone(); - let table = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) - .ok_or(DatabaseError::TableNotFound) - ); - let mut indexes: HashMap = HashMap::new(); +impl Delete { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; - let mut deleted_count = 0; - let mut coroutine = build_read(input, cache, transaction); + let table = arena + .transaction_mut() + .table(arena.table_cache(), self.table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let mut indexes: HashMap = HashMap::new(); - for tuple in coroutine.by_ref() { - let tuple: Tuple = throw!(co, tuple); + let mut deleted_count = 0; - for index_meta in table.indexes() { - if let Some(Value { exprs, values, .. 
}) = indexes.get_mut(&index_meta.id) { - let Some(data_value) = DataValue::values_to_tuple(throw!( - co, - Projection::projection(&tuple, exprs, &schema) - )) else { - continue; - }; - values.push(data_value); - } else { - let mut values = Vec::with_capacity(table.indexes().len()); - let mut exprs = throw!(co, index_meta.column_exprs(table)); - throw!( - co, - BindPosition::bind_exprs( - exprs.iter_mut(), - || schema.iter().map(Cow::Borrowed), - |a, b| a == b - ) - ); - let Some(data_value) = DataValue::values_to_tuple(throw!( - co, - Projection::projection(&tuple, &exprs, &schema) - )) else { - continue; - }; - values.push(data_value); + while arena.next_tuple(input)? { + let tuple = arena.result_tuple().clone(); + for index_meta in table.indexes() { + if let Some(Value { exprs, values, .. }) = indexes.get_mut(&index_meta.id) { + let Some(data_value) = DataValue::values_to_tuple(Projection::projection( + &tuple, + exprs, + &self.input_schema, + )?) else { + continue; + }; + values.push(data_value); + } else { + let mut values = Vec::with_capacity(table.indexes().len()); + let exprs = index_meta.column_exprs(table)?; + let Some(data_value) = DataValue::values_to_tuple(Projection::projection( + &tuple, + &exprs, + &self.input_schema, + )?) else { + continue; + }; + values.push(data_value); - indexes.insert( - index_meta.id, - Value { - exprs, - values, - index_ty: index_meta.ty, - }, - ); - } - } - if let Some(tuple_id) = &tuple.pk { - for ( - index_id, + indexes.insert( + index_meta.id, Value { - values, index_ty, .. 
+ exprs, + values, + index_ty: index_meta.ty, }, - ) in indexes.iter_mut() - { - for value in values { - throw!( - co, - unsafe { &mut (*transaction) }.del_index( - &table_name, - &Index::new(*index_id, value, *index_ty), - tuple_id, - ) - ); - } - } - - throw!( - co, - unsafe { &mut (*transaction) }.remove_tuple(&table_name, tuple_id) ); - deleted_count += 1; } } - drop(coroutine); - co.yield_(Ok(TupleBuilder::build_result(deleted_count.to_string()))) - .await; - }) + if let Some(tuple_id) = &tuple.pk { + for ( + index_id, + Value { + values, index_ty, .. + }, + ) in indexes.iter_mut() + { + for value in values { + arena.transaction_mut().del_index( + &self.table_name, + &Index::new(*index_id, value, *index_ty), + tuple_id, + )?; + } + } + + arena + .transaction_mut() + .remove_tuple(&self.table_name, tuple_id)?; + deleted_count += 1; + } + } + + TupleBuilder::build_result_into(arena.result_tuple_mut(), deleted_count.to_string()); + arena.resume(); + Ok(()) } } diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs index ad7cc1aa..3b7fb94b 100644 --- a/src/execution/dml/insert.rs +++ b/src/execution/dml/insert.rs @@ -15,24 +15,24 @@ use crate::catalog::{ColumnCatalog, TableName}; use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, spawn_executor, Executor, WriteExecutor}; -use crate::expression::BindPosition; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; use crate::planner::operator::insert::InsertOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::index::Index; +use crate::types::tuple::SchemaRef; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; use crate::types::ColumnId; use itertools::Itertools; -use std::borrow::Cow; use 
std::collections::HashMap; pub struct Insert { table_name: TableName, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: Option, is_overwrite: bool, is_mapping_by_name: bool, } @@ -45,12 +45,14 @@ impl From<(InsertOperator, LogicalPlan)> for Insert { is_overwrite, is_mapping_by_name, }, - input, + mut input, ): (InsertOperator, LogicalPlan), ) -> Self { Insert { table_name, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: None, is_overwrite, is_mapping_by_name, } @@ -74,125 +76,114 @@ impl ColumnCatalog { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { - fn execute_mut( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Insert { - table_name, - mut input, - is_overwrite, - is_mapping_by_name, - } = self; - - let schema = input.output_schema().clone(); + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("insert input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::Insert(self)) + } +} - if let Some(table_catalog) = throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) +impl Insert { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; + + if let Some(table_catalog) = arena + .transaction_mut() + .table(arena.table_cache(), self.table_name.clone())? 
.cloned() - { - if table_catalog.primary_keys().is_empty() { - throw!(co, Err(DatabaseError::not_null())) - } + { + if table_catalog.primary_keys().is_empty() { + return Err(DatabaseError::not_null()); + } - // Index values must be projected from the full table schema, because - // omitted input columns may be filled by defaults before index maintenance. - let table_schema = table_catalog.schema_ref(); - let mut index_metas = Vec::new(); - for index_meta in table_catalog.indexes() { - let mut exprs = throw!(co, index_meta.column_exprs(&table_catalog)); - throw!( - co, - BindPosition::bind_exprs( - exprs.iter_mut(), - || table_schema.iter().map(Cow::Borrowed), - |a, b| a == b - ) - ); - index_metas.push((index_meta, exprs)); - } + let table_schema = table_catalog.schema_ref(); + let mut index_metas = Vec::new(); + for index_meta in table_catalog.indexes() { + let exprs = index_meta.column_exprs(&table_catalog)?; + index_metas.push((index_meta, exprs)); + } - let serializers = table_catalog - .columns() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let pk_indices = table_catalog.primary_keys_indices(); - let mut coroutine = build_read(input, cache, transaction); - let mut inserted_count = 0; + let serializers = table_catalog + .columns() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let pk_indices = table_catalog.primary_keys_indices(); + let mut inserted_count = 0; - for tuple in coroutine.by_ref() { - let Tuple { values, .. } = throw!(co, tuple); + while arena.next_tuple(input)? 
{ + let values = arena.result_tuple().values.clone(); - let mut tuple_map = HashMap::new(); - for (i, value) in values.into_iter().enumerate() { - tuple_map.insert(schema[i].key(is_mapping_by_name), value); - } - let mut values = Vec::with_capacity(table_catalog.columns_len()); - - for col in table_catalog.columns() { - let mut value = { - let mut value = tuple_map.remove(&col.key(is_mapping_by_name)); - - if value.is_none() { - value = throw!(co, col.default_value()); - } - value.unwrap_or(DataValue::Null) - }; - if !value.is_null() && &value.logical_type() != col.datatype() { - value = throw!(co, value.cast(col.datatype())); - } - throw!(co, value.check_len(col.datatype())); - if value.is_null() && !col.nullable() { - co.yield_(Err(DatabaseError::not_null_column(col.name().to_string()))) - .await; - return; + let mut tuple_map = HashMap::new(); + for (i, value) in values.into_iter().enumerate() { + tuple_map.insert(self.input_schema[i].key(self.is_mapping_by_name), value); + } + let mut values = Vec::with_capacity(table_catalog.columns_len()); + + for col in table_catalog.columns() { + let mut value = { + let mut value = tuple_map.remove(&col.key(self.is_mapping_by_name)); + + if value.is_none() { + value = col.default_value()?; } - values.push(value) + value.unwrap_or(DataValue::Null) + }; + if !value.is_null() && &value.logical_type() != col.datatype() { + value = value.cast(col.datatype())?; } - let pk = Tuple::primary_projection(pk_indices, &values); - let tuple = Tuple::new(Some(pk), values); - - for (index_meta, exprs) in index_metas.iter() { - let values = throw!( - co, - Projection::projection(&tuple, exprs, table_schema.as_slice()) - ); - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; - let tuple_id = throw!( - co, - tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound) - ); - let index = Index::new(index_meta.id, &value, index_meta.ty); - throw!( - co, - unsafe { &mut (*transaction) }.add_index(&table_name, index, 
tuple_id) - ); + value.check_len(col.datatype())?; + if value.is_null() && !col.nullable() { + return Err(DatabaseError::not_null_column(col.name().to_string())); } - throw!( - co, - unsafe { &mut (*transaction) }.append_tuple( - &table_name, - tuple, - &serializers, - is_overwrite - ) - ); - inserted_count += 1; + values.push(value) } - drop(coroutine); - - co.yield_(Ok(TupleBuilder::build_result(inserted_count.to_string()))) - .await; - } else { - co.yield_(Ok(TupleBuilder::build_result("0".to_string()))) - .await; + let pk = Tuple::primary_projection(pk_indices, &values); + let tuple = Tuple::new(Some(pk), values); + + for (index_meta, exprs) in index_metas.iter() { + let values = Projection::projection(&tuple, exprs, table_schema.as_slice())?; + let Some(value) = DataValue::values_to_tuple(values) else { + continue; + }; + let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; + let index = Index::new(index_meta.id, &value, index_meta.ty); + arena + .transaction_mut() + .add_index(&self.table_name, index, tuple_id)?; + } + arena.transaction_mut().append_tuple( + &self.table_name, + tuple, + &serializers, + self.is_overwrite, + )?; + inserted_count += 1; } - }) + + TupleBuilder::build_result_into(arena.result_tuple_mut(), inserted_count.to_string()); + arena.resume(); + Ok(()) + } else { + TupleBuilder::build_result_into(arena.result_tuple_mut(), "0".to_string()); + arena.resume(); + Ok(()) + } } } diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index c9114e7b..626baf57 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -15,24 +15,25 @@ use crate::catalog::{ColumnRef, TableName}; use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, spawn_executor, Executor, WriteExecutor}; -use crate::expression::{BindPosition, ScalarExpression}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; 
+use crate::expression::ScalarExpression; use crate::planner::operator::update::UpdateOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::index::Index; +use crate::types::tuple::SchemaRef; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; use itertools::Itertools; -use std::borrow::Cow; use std::collections::HashMap; pub struct Update { table_name: TableName, value_exprs: Vec<(ColumnRef, ScalarExpression)>, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: Option, } impl From<(UpdateOperator, LogicalPlan)> for Update { @@ -42,139 +43,132 @@ impl From<(UpdateOperator, LogicalPlan)> for Update { table_name, value_exprs, }, - input, + mut input, ): (UpdateOperator, LogicalPlan), ) -> Self { Update { table_name, value_exprs, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: None, } } } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { - fn execute_mut( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Update { - table_name, - value_exprs, - mut input, - } = self; + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("update input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::Update(self)) + } +} - let mut exprs_map = HashMap::with_capacity(value_exprs.len()); - for (column, expr) in value_exprs { - exprs_map.insert(column.id(), expr); - } +impl Update { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = 
self.input.take() else { + arena.finish(); + return Ok(()); + }; - let input_schema = input.output_schema().clone(); + let mut exprs_map = HashMap::with_capacity(self.value_exprs.len()); + for (column, expr) in self.value_exprs.drain(..) { + exprs_map.insert(column.id(), expr); + } - if let Some(table_catalog) = throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, table_name.clone()) - ) + if let Some(table_catalog) = arena + .transaction_mut() + .table(arena.table_cache(), self.table_name.clone())? .cloned() - { - let serializers = input_schema - .iter() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let mut index_metas = Vec::new(); - for index_meta in table_catalog.indexes() { - let mut exprs = throw!(co, index_meta.column_exprs(&table_catalog)); - throw!( - co, - BindPosition::bind_exprs( - exprs.iter_mut(), - || input_schema.iter().map(Cow::Borrowed), - |a, b| a == b - ) - ); - index_metas.push((index_meta, exprs)); - } - - let mut coroutine = build_read(input, cache, transaction); - let mut updated_count = 0; + { + let serializers = self + .input_schema + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let mut index_metas = Vec::new(); + for index_meta in table_catalog.indexes() { + let exprs = index_meta.column_exprs(&table_catalog)?; + index_metas.push((index_meta, exprs)); + } - for tuple in coroutine.by_ref() { - let mut tuple: Tuple = throw!(co, tuple); + let mut updated_count = 0; - let mut is_overwrite = true; + while arena.next_tuple(input)? 
{ + let mut tuple = arena.result_tuple().clone(); + let mut is_overwrite = true; - let old_pk = throw!( - co, - tuple.pk.clone().ok_or(DatabaseError::PrimaryKeyNotFound) - ); - for (index_meta, exprs) in index_metas.iter() { - let values = - throw!(co, Projection::projection(&tuple, exprs, &input_schema)); - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; - let index = Index::new(index_meta.id, &value, index_meta.ty); - throw!( - co, - unsafe { &mut (*transaction) }.del_index(&table_name, &index, &old_pk) - ); - } - for (i, column) in input_schema.iter().enumerate() { - if let Some(expr) = exprs_map.get(&column.id()) { - tuple.values[i] = throw!(co, expr.eval(Some((&tuple, &input_schema)))); - } + let old_pk = tuple.pk.clone().ok_or(DatabaseError::PrimaryKeyNotFound)?; + for (index_meta, exprs) in index_metas.iter() { + let values = Projection::projection(&tuple, exprs, &self.input_schema)?; + let Some(value) = DataValue::values_to_tuple(values) else { + continue; + }; + let index = Index::new(index_meta.id, &value, index_meta.ty); + arena + .transaction_mut() + .del_index(&self.table_name, &index, &old_pk)?; + } + for (i, column) in self.input_schema.iter().enumerate() { + if let Some(expr) = exprs_map.get(&column.id()) { + let value = expr.eval(Some((&tuple, &self.input_schema)))?; + tuple.values[i] = value; } + } - tuple.pk = Some(Tuple::primary_projection( - table_catalog.primary_keys_indices(), - &tuple.values, - )); - let new_pk = throw!( - co, - tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound) - ); + tuple.pk = Some(Tuple::primary_projection( + table_catalog.primary_keys_indices(), + &tuple.values, + )); + let new_pk = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; - if new_pk != &old_pk { - throw!( - co, - unsafe { &mut (*transaction) }.remove_tuple(&table_name, &old_pk) - ); - is_overwrite = false; - } - for (index_meta, exprs) in index_metas.iter() { - let values = - throw!(co, 
Projection::projection(&tuple, exprs, &input_schema)); - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; - let index = Index::new(index_meta.id, &value, index_meta.ty); - throw!( - co, - unsafe { &mut (*transaction) }.add_index(&table_name, index, new_pk) - ); - } - - throw!( - co, - unsafe { &mut (*transaction) }.append_tuple( - &table_name, - tuple, - &serializers, - is_overwrite - ) - ); - updated_count += 1; + if new_pk != &old_pk { + arena + .transaction_mut() + .remove_tuple(&self.table_name, &old_pk)?; + is_overwrite = false; + } + for (index_meta, exprs) in index_metas.iter() { + let values = Projection::projection(&tuple, exprs, &self.input_schema)?; + let Some(value) = DataValue::values_to_tuple(values) else { + continue; + }; + let index = Index::new(index_meta.id, &value, index_meta.ty); + arena + .transaction_mut() + .add_index(&self.table_name, index, new_pk)?; } - drop(coroutine); - co.yield_(Ok(TupleBuilder::build_result(updated_count.to_string()))) - .await; - } else { - co.yield_(Ok(TupleBuilder::build_result("0".to_string()))) - .await; + arena.transaction_mut().append_tuple( + &self.table_name, + tuple, + &serializers, + is_overwrite, + )?; + updated_count += 1; } - }) + + TupleBuilder::build_result_into(arena.result_tuple_mut(), updated_count.to_string()); + arena.resume(); + Ok(()) + } else { + TupleBuilder::build_result_into(arena.result_tuple_mut(), "0".to_string()); + arena.resume(); + Ok(()) + } } } diff --git a/src/execution/dql/aggregate/hash_agg.rs b/src/execution/dql/aggregate/hash_agg.rs index a38c6d58..b7d67248 100644 --- a/src/execution/dql/aggregate/hash_agg.rs +++ b/src/execution/dql/aggregate/hash_agg.rs @@ -14,22 +14,26 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::{create_accumulators, Accumulator}; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; 
use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; use crate::types::value::DataValue; use ahash::{HashMap, HashMapExt}; use itertools::Itertools; -use std::collections::hash_map::Entry; +use std::collections::hash_map::{Entry, IntoIter as HashMapIntoIter}; + +type HashAggOutput = HashMapIntoIter, Vec>>; pub struct HashAggExecutor { agg_calls: Vec, groupby_exprs: Vec, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, + output: Option, } impl From<(AggregateOperator, LogicalPlan)> for HashAggExecutor { @@ -40,84 +44,96 @@ impl From<(AggregateOperator, LogicalPlan)> for HashAggExecutor { groupby_exprs, .. }, - input, + mut input, ): (AggregateOperator, LogicalPlan), ) -> Self { HashAggExecutor { agg_calls, groupby_exprs, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: 0, + output: None, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashAggExecutor { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let HashAggExecutor { - agg_calls, - groupby_exprs, - mut input, - } = self; + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("hash aggregate input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::HashAgg(self)) + } +} - let schema_ref = input.output_schema().clone(); +impl HashAggExecutor { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { 
+ if self.output.is_none() { let mut group_hash_accs: HashMap, Vec>> = HashMap::new(); - let mut executor = build_read(input, cache, transaction); - - for result in executor.by_ref() { - let tuple = throw!(co, result); - let mut values = Vec::with_capacity(agg_calls.len()); - - for expr in agg_calls.iter() { - if let ScalarExpression::AggCall { args, .. } = expr { - if args.len() > 1 { - throw!(co, Err(DatabaseError::UnsupportedStmt( - "currently aggregate functions only support a single Column as a parameter" - .to_string() - ))) - } - values.push(throw!(co, args[0].eval(Some((&tuple, &schema_ref))))); - } else { - unreachable!() - } - } - let group_keys: Vec = throw!( - co, - groupby_exprs - .iter() - .map(|expr| expr.eval(Some((&tuple, &schema_ref)))) - .try_collect() - ); + while arena.next_tuple(self.input)? { + let tuple = arena.result_tuple(); + let group_keys = self + .groupby_exprs + .iter() + .map(|expr| expr.eval(Some((tuple, &self.input_schema)))) + .try_collect()?; let entry = match group_hash_accs.entry(group_keys) { Entry::Occupied(entry) => entry.into_mut(), - Entry::Vacant(entry) => { - entry.insert(throw!(co, create_accumulators(&agg_calls))) - } + Entry::Vacant(entry) => entry.insert(create_accumulators(&self.agg_calls)?), }; - for (acc, value) in entry.iter_mut().zip_eq(values.iter()) { - throw!(co, acc.update_value(value)); + + for (acc, expr) in entry.iter_mut().zip_eq(self.agg_calls.iter()) { + let ScalarExpression::AggCall { args, .. 
} = expr else { + unreachable!() + }; + if args.len() > 1 { + return Err(DatabaseError::UnsupportedStmt( + "currently aggregate functions only support a single Column as a parameter" + .to_string(), + )); + } + let value = args[0].eval(Some((tuple, &self.input_schema)))?; + acc.update_value(&value)?; } } - for (group_keys, accs) in group_hash_accs { - // Tips: Accumulator First - let values: Vec = throw!( - co, - accs.iter() - .map(|acc| acc.evaluate()) - .chain(group_keys.into_iter().map(Ok)) - .try_collect() - ); - co.yield_(Ok(Tuple::new(None, values))).await; - } - }) + self.output = Some(group_hash_accs.into_iter()); + } + + let Some((group_keys, accs)) = self.output.as_mut().and_then(Iterator::next) else { + arena.finish(); + return Ok(()); + }; + + let output = arena.result_tuple_mut(); + + output.pk = None; + output.values.clear(); + output.values.reserve(accs.len() + group_keys.len()); + + for acc in accs.iter() { + output.values.push(acc.evaluate()?); + } + output.values.extend(group_keys); + arena.resume(); + Ok(()) } } @@ -127,7 +143,7 @@ mod test { use crate::errors::DatabaseError; use crate::execution::dql::aggregate::hash_agg::HashAggExecutor; use crate::execution::dql::test::build_integers; - use crate::execution::{try_collect, ReadExecutor}; + use crate::execution::try_collect; use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; use crate::optimizer::heuristic::batch::HepBatchStrategy; @@ -196,11 +212,11 @@ mod test { }; let plan = LogicalPlan::new( Operator::Aggregate(AggregateOperator { - groupby_exprs: vec![ScalarExpression::column_expr(t1_schema[0].clone())], + groupby_exprs: vec![ScalarExpression::column_expr(t1_schema[0].clone(), 0)], agg_calls: vec![ScalarExpression::AggCall { distinct: false, kind: AggKind::Sum, - args: vec![ScalarExpression::column_expr(t1_schema[1].clone())], + args: vec![ScalarExpression::column_expr(t1_schema[1].clone(), 1)], ty: LogicalType::Integer, }], is_distinct: false, @@ -212,11 +228,7 
@@ mod test { .before_batch( "Expression Remapper".to_string(), HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::BindExpressionPosition, - // TIPS: This rule is necessary - NormalizationRuleImpl::EvaluatorBind, - ], + vec![NormalizationRuleImpl::EvaluatorBind], ) .build(); let plan = pipeline @@ -226,10 +238,11 @@ mod test { let Operator::Aggregate(op) = plan.operator else { unreachable!() }; - let tuples = try_collect( - HashAggExecutor::from((op, plan.childrens.pop_only())) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let tuples = try_collect(crate::execution::execute( + HashAggExecutor::from((op, plan.childrens.pop_only())), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; assert_eq!(tuples.len(), 2); diff --git a/src/execution/dql/aggregate/mod.rs b/src/execution/dql/aggregate/mod.rs index 4d2b2f2b..4bbba9e6 100644 --- a/src/execution/dql/aggregate/mod.rs +++ b/src/execution/dql/aggregate/mod.rs @@ -63,6 +63,7 @@ fn create_accumulator(expr: &ScalarExpression) -> Result, D } } +#[inline] pub(crate) fn create_accumulators( exprs: &[ScalarExpression], ) -> Result>, DatabaseError> { diff --git a/src/execution/dql/aggregate/simple_agg.rs b/src/execution/dql/aggregate/simple_agg.rs index c48aa0af..5f880b67 100644 --- a/src/execution/dql/aggregate/simple_agg.rs +++ b/src/execution/dql/aggregate/simple_agg.rs @@ -12,70 +12,91 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use crate::errors::DatabaseError; use crate::execution::dql::aggregate::create_accumulators; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; -use crate::types::value::DataValue; -use itertools::Itertools; +use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; pub struct SimpleAggExecutor { agg_calls: Vec, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: Option, } impl From<(AggregateOperator, LogicalPlan)> for SimpleAggExecutor { fn from( - (AggregateOperator { agg_calls, .. }, input): (AggregateOperator, LogicalPlan), + (AggregateOperator { agg_calls, .. 
}, mut input): (AggregateOperator, LogicalPlan), ) -> Self { - SimpleAggExecutor { agg_calls, input } + SimpleAggExecutor { + agg_calls, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: None, + } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SimpleAggExecutor { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let SimpleAggExecutor { - agg_calls, - mut input, - } = self; - - let mut accs = throw!(co, create_accumulators(&agg_calls)); - let schema = input.output_schema().clone(); - - let mut executor = build_read(input, cache, transaction); + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("simple aggregate input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::SimpleAgg(self)) + } +} - for tuple in executor.by_ref() { - let tuple = throw!(co, tuple); +impl SimpleAggExecutor { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; - let values: Vec = throw!( - co, - agg_calls - .iter() - .map(|expr| match expr { - ScalarExpression::AggCall { args, .. } => { - args[0].eval(Some((&tuple, &schema))) - } - _ => unreachable!(), - }) - .try_collect() - ); + let mut accs = create_accumulators(&self.agg_calls)?; - for (acc, value) in accs.iter_mut().zip_eq(values.iter()) { - throw!(co, acc.update_value(value)); + while arena.next_tuple(input)? { + let tuple = arena.result_tuple(); + for (acc, expr) in accs.iter_mut().zip(self.agg_calls.iter()) { + let ScalarExpression::AggCall { args, .. 
} = expr else { + unreachable!() + }; + if args.len() > 1 { + return Err(DatabaseError::UnsupportedStmt( + "currently aggregate functions only support a single Column as a parameter" + .to_string(), + )); } + + let value = args[0].eval(Some((tuple, &self.input_schema)))?; + acc.update_value(&value)?; } - let values: Vec = - throw!(co, accs.into_iter().map(|acc| acc.evaluate()).try_collect()); + } - co.yield_(Ok(Tuple::new(None, values))).await; - }) + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + output.values.reserve(accs.len()); + for acc in accs { + output.values.push(acc.evaluate()?); + } + arena.resume(); + Ok(()) } } diff --git a/src/execution/dql/aggregate/stream_distinct.rs b/src/execution/dql/aggregate/stream_distinct.rs index d1e73200..2d3758a4 100644 --- a/src/execution/dql/aggregate/stream_distinct.rs +++ b/src/execution/dql/aggregate/stream_distinct.rs @@ -12,62 +12,84 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; +use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::DataValue; use itertools::Itertools; pub struct StreamDistinctExecutor { groupby_exprs: Vec, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, + last_keys: Option>, + scratch: Tuple, } impl From<(AggregateOperator, LogicalPlan)> for StreamDistinctExecutor { - fn from((op, input): (AggregateOperator, LogicalPlan)) -> Self { + fn from((op, mut input): (AggregateOperator, LogicalPlan)) -> Self { StreamDistinctExecutor { groupby_exprs: op.groupby_exprs, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: 0, + last_keys: None, + scratch: Tuple::default(), } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for StreamDistinctExecutor { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let StreamDistinctExecutor { - groupby_exprs, - mut input, - } = self; - - let schema_ref = input.output_schema().clone(); - let mut executor = build_read(input, cache, transaction); - let mut last_keys: Option> = None; - - for result in executor.by_ref() { - let tuple = throw!(co, result); - let group_keys: Vec = throw!( - co, - groupby_exprs - .iter() - .map(|expr| expr.eval(Some((&tuple, &schema_ref)))) - .try_collect() - ); 
- - if last_keys.as_ref() != Some(&group_keys) { - last_keys = Some(group_keys.clone()); - co.yield_(Ok(Tuple::new(tuple.pk, group_keys))).await; - } + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("stream distinct input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::StreamDistinct(self)) + } +} + +impl StreamDistinctExecutor { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + loop { + if !arena.next_tuple(self.input)? { + arena.finish(); + return Ok(()); + } + std::mem::swap(&mut self.scratch, arena.result_tuple_mut()); + let tuple = &self.scratch; + let group_keys = self + .groupby_exprs + .iter() + .map(|expr| expr.eval(Some((tuple, &self.input_schema)))) + .try_collect()?; + + if self.last_keys.as_ref() != Some(&group_keys) { + self.last_keys = Some(group_keys.clone()); + let output = arena.result_tuple_mut(); + output.pk.clone_from(&tuple.pk); + output.values = group_keys; + arena.resume(); + return Ok(()); } - }) + } } } @@ -76,7 +98,7 @@ mod tests { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::errors::DatabaseError; use crate::execution::dql::aggregate::stream_distinct::StreamDistinctExecutor; - use crate::execution::{try_collect, ReadExecutor}; + use crate::execution::try_collect; use crate::expression::ScalarExpression; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; @@ -121,10 +143,7 @@ mod tests { .before_batch( "Expression Remapper".to_string(), HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::BindExpressionPosition, - NormalizationRuleImpl::EvaluatorBind, - ], + vec![NormalizationRuleImpl::EvaluatorBind], ) .build() .instantiate(plan) @@ -154,7 +173,7 @@ mod tests { Childrens::None, ); let agg = AggregateOperator { - groupby_exprs: 
vec![ScalarExpression::column_expr(schema_ref[0].clone())], + groupby_exprs: vec![ScalarExpression::column_expr(schema_ref[0].clone(), 0)], agg_calls: vec![], is_distinct: true, }; @@ -166,10 +185,11 @@ mod tests { let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; let mut transaction = storage.transaction()?; - let tuples = try_collect( - StreamDistinctExecutor::from((agg, plan.childrens.pop_only())) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let tuples = try_collect(crate::execution::execute( + StreamDistinctExecutor::from((agg, plan.childrens.pop_only())), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; let actual = tuples .into_iter() @@ -204,8 +224,8 @@ mod tests { ); let agg = AggregateOperator { groupby_exprs: vec![ - ScalarExpression::column_expr(schema_ref[0].clone()), - ScalarExpression::column_expr(schema_ref[1].clone()), + ScalarExpression::column_expr(schema_ref[0].clone(), 0), + ScalarExpression::column_expr(schema_ref[1].clone(), 1), ], agg_calls: vec![], is_distinct: true, @@ -218,22 +238,21 @@ mod tests { let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; let mut transaction = storage.transaction()?; - let tuples = try_collect( - StreamDistinctExecutor::from((agg, plan.childrens.pop_only())) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let tuples = try_collect(crate::execution::execute( + StreamDistinctExecutor::from((agg, plan.childrens.pop_only())), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; - let actual = tuples - .into_iter() - .map(|tuple| { - tuple - .values - .into_iter() - .flat_map(|value| value.i32()) - .collect_vec() - }) - .collect_vec(); - assert_eq!(actual, vec![vec![1, 1], vec![1, 2], vec![2, 1]]); + let actual = tuples.into_iter().map(|tuple| tuple.values).collect_vec(); + assert_eq!( + actual, + vec![ + vec![DataValue::Int32(1), 
DataValue::Int32(1)], + vec![DataValue::Int32(1), DataValue::Int32(2)], + vec![DataValue::Int32(2), DataValue::Int32(1)], + ] + ); Ok(()) } diff --git a/src/execution/dql/describe.rs b/src/execution/dql/describe.rs index 353833fc..e1011bcd 100644 --- a/src/execution/dql/describe.rs +++ b/src/execution/dql/describe.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::{ColumnCatalog, TableName}; -use crate::execution::DatabaseError; -use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::describe::DescribeOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; use sqlparser::ast::CharLengthUnits; use std::sync::LazyLock; @@ -43,82 +41,112 @@ static EMPTY_KEY_TYPE: LazyLock = LazyLock::new(|| DataValue::Utf8 { pub struct Describe { table_name: TableName, + columns: Option>, + cursor: usize, } impl From for Describe { fn from(op: DescribeOperator) -> Self { Describe { table_name: op.table_name, + columns: None, + cursor: 0, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Describe { - fn execute( + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let table = throw!( - co, - throw!( - co, - unsafe { &mut (*transaction) }.table(cache.0, self.table_name.clone()) - ) - .ok_or(DatabaseError::TableNotFound) - ); - let key_fn = |column: &ColumnCatalog| { - if column.desc().is_primary() { - PRIMARY_KEY_TYPE.clone() - } else if column.desc().is_unique() { - 
UNIQUE_KEY_TYPE.clone() - } else { - EMPTY_KEY_TYPE.clone() - } - }; + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::Describe(self)) + } +} + +impl Describe { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.columns.is_none() { + let table = arena + .transaction_mut() + .table(arena.table_cache(), self.table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + self.columns = Some(table.columns().cloned().collect()); + } + + let Some(column) = self + .columns + .as_ref() + .and_then(|columns| columns.get(self.cursor)) + .cloned() + else { + arena.finish(); + return Ok(()); + }; + + self.cursor += 1; + + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + fill_describe_row(&mut output.values, &column); + + arena.resume(); + Ok(()) + } +} + +fn fill_describe_row(values: &mut Vec, column: &ColumnCatalog) { + let datatype = column.datatype(); + let default = column + .desc() + .default + .as_ref() + .map(|expr| format!("{expr}")) + .unwrap_or_else(|| "null".to_string()); + + values.push(DataValue::Utf8 { + value: column.name().to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); + values.push(DataValue::Utf8 { + value: datatype.to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); + values.push(DataValue::Utf8 { + value: datatype + .raw_len() + .map(|len| len.to_string()) + .unwrap_or_else(|| "variable".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); + values.push(DataValue::Utf8 { + value: column.nullable().to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); + values.push(key_value(column)); + values.push(DataValue::Utf8 { + value: default, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); +} - for column in 
table.columns() { - let datatype = column.datatype(); - let default = column - .desc() - .default - .as_ref() - .map(|expr| format!("{expr}")) - .unwrap_or_else(|| "null".to_string()); - let values = vec![ - DataValue::Utf8 { - value: column.name().to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - DataValue::Utf8 { - value: datatype.to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - DataValue::Utf8 { - value: datatype - .raw_len() - .map(|len| len.to_string()) - .unwrap_or_else(|| "variable".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - DataValue::Utf8 { - value: column.nullable().to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - key_fn(column), - DataValue::Utf8 { - value: default, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - ]; - co.yield_(Ok(Tuple::new(None, values))).await; - } - }) +fn key_value(column: &ColumnCatalog) -> DataValue { + if column.desc().is_primary() { + PRIMARY_KEY_TYPE.clone() + } else if column.desc().is_unique() { + UNIQUE_KEY_TYPE.clone() + } else { + EMPTY_KEY_TYPE.clone() } } diff --git a/src/execution/dql/dummy.rs b/src/execution/dql/dummy.rs index 750d0097..6d7c26cd 100644 --- a/src/execution/dql/dummy.rs +++ b/src/execution/dql/dummy.rs @@ -12,20 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{spawn_executor, Executor, ReadExecutor}; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; +use crate::storage::Transaction; use crate::types::tuple::Tuple; -pub struct Dummy {} +pub struct Dummy { + row: Option, +} + +impl Default for Dummy { + fn default() -> Self { + Self { + row: Some(Tuple::new(None, Vec::new())), + } + } +} impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Dummy { - fn execute( + fn into_executor( self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, _: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - co.yield_(Ok(Tuple::new(None, Vec::new()))).await; - }) + ) -> ExecId { + arena.push(ExecNode::Dummy(self)) + } +} + +impl Dummy { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(row) = self.row.take() else { + arena.finish(); + return Ok(()); + }; + arena.produce_tuple(row); + Ok(()) } } diff --git a/src/execution/dql/except.rs b/src/execution/dql/except.rs index 70700396..912eca2c 100644 --- a/src/execution/dql/except.rs +++ b/src/execution/dql/except.rs @@ -12,58 +12,92 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; +use crate::types::tuple::Tuple; use ahash::{HashMap, HashMapExt}; pub struct Except { - left_input: LogicalPlan, - right_input: LogicalPlan, + left_plan: Option, + right_plan: Option, + left_input: ExecId, + right_input: ExecId, + except_col: HashMap, + built: bool, } impl From<(LogicalPlan, LogicalPlan)> for Except { fn from((left_input, right_input): (LogicalPlan, LogicalPlan)) -> Self { Except { - left_input, - right_input, + left_plan: Some(left_input), + right_plan: Some(right_input), + left_input: 0, + right_input: 0, + except_col: HashMap::new(), + built: false, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Except { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Except { - left_input, - right_input, - } = self; - - let mut coroutine = build_read(right_input, cache, transaction); - - let mut except_col = HashMap::new(); + ) -> ExecId { + self.left_input = build_read( + arena, + self.left_plan + .take() + .expect("except left input plan initialized"), + cache, + transaction, + ); + self.right_input = build_read( + arena, + self.right_plan + .take() + .expect("except right input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Except(self)) + } +} - for tuple in coroutine.by_ref() { - let tuple = throw!(co, tuple); - *except_col.entry(tuple).or_insert(0usize) += 1; +impl Except { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + 
&mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if !self.built { + while arena.next_tuple(self.right_input)? { + *self + .except_col + .entry(arena.result_tuple().clone()) + .or_insert(0) += 1; } + self.built = true; + } - let coroutine = build_read(left_input, cache, transaction); + loop { + if !arena.next_tuple(self.left_input)? { + arena.finish(); + return Ok(()); + } + let tuple = arena.result_tuple(); - for tuple in coroutine { - let tuple = throw!(co, tuple); - if let Some(count) = except_col.get_mut(&tuple) { - if *count > 0 { - *count -= 1; - continue; - } + if let Some(count) = self.except_col.get_mut(tuple) { + if *count > 0 { + *count -= 1; + continue; } - co.yield_(Ok(tuple)).await; } - }) + + arena.resume(); + return Ok(()); + } } } diff --git a/src/execution/dql/explain.rs b/src/execution/dql/explain.rs index aa1ad3a2..d1e5f4df 100644 --- a/src/execution/dql/explain.rs +++ b/src/execution/dql/explain.rs @@ -12,37 +12,54 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; use sqlparser::ast::CharLengthUnits; pub struct Explain { - plan: LogicalPlan, + plan: Option, } impl From for Explain { fn from(plan: LogicalPlan) -> Self { - Explain { plan } + Explain { plan: Some(plan) } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Explain { - fn execute( + fn into_executor( self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, _: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let values = vec![DataValue::Utf8 { - value: self.plan.explain(0), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }]; + ) -> ExecId { + arena.push(ExecNode::Explain(self)) + } +} + +impl Explain { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(plan) = self.plan.take() else { + arena.finish(); + return Ok(()); + }; + + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + output.values.push(DataValue::Utf8 { + value: plan.explain(0), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); - co.yield_(Ok(Tuple::new(None, values))).await; - }) + arena.resume(); + Ok(()) } } diff --git a/src/execution/dql/filter.rs b/src/execution/dql/filter.rs index d212807e..d316c28a 100644 --- a/src/execution/dql/filter.rs +++ b/src/execution/dql/filter.rs @@ -12,49 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use crate::planner::operator::filter::FilterOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; pub struct Filter { predicate: ScalarExpression, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, } impl From<(FilterOperator, LogicalPlan)> for Filter { - fn from((FilterOperator { predicate, .. }, input): (FilterOperator, LogicalPlan)) -> Self { - Filter { predicate, input } + fn from((FilterOperator { predicate, .. }, mut input): (FilterOperator, LogicalPlan)) -> Self { + let input_schema = input.output_schema().clone(); + Filter { + predicate, + input_schema, + input_plan: Some(input), + input: 0, + } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Filter { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Filter { - predicate, - mut input, - } = self; - - let schema = input.output_schema().clone(); - - let executor = build_read(input, cache, transaction); + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("filter input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Filter(self)) + } +} - for tuple in executor { - let tuple = throw!(co, tuple); +impl Filter { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + loop { + if !arena.next_tuple(self.input)? 
{ + arena.finish(); + return Ok(()); + }; + let tuple = arena.result_tuple(); - if throw!( - co, - throw!(co, predicate.eval(Some((&tuple, &schema)))).is_true() - ) { - co.yield_(Ok(tuple)).await; - } + if self + .predicate + .eval(Some((tuple, &self.input_schema)))? + .is_true()? + { + arena.resume(); + return Ok(()); } - }) + } } } diff --git a/src/execution/dql/function_scan.rs b/src/execution/dql/function_scan.rs index d3008b95..e8c1f689 100644 --- a/src/execution/dql/function_scan.rs +++ b/src/execution/dql/function_scan.rs @@ -12,35 +12,54 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::function::table::TableFunction; use crate::planner::operator::function_scan::FunctionScanOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; +use crate::types::tuple::Tuple; pub struct FunctionScan { table_function: TableFunction, + iter: Option>>>, } impl From for FunctionScan { fn from(op: FunctionScanOperator) -> Self { FunctionScan { table_function: op.table_function, + iter: None, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for FunctionScan { - fn execute( + fn into_executor( self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, _: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let TableFunction { args, inner } = self.table_function; - for tuple in throw!(co, inner.eval(&args)) { - co.yield_(tuple).await; - } - }) + ) -> ExecId { + arena.push(ExecNode::FunctionScan(self)) + } +} + +impl FunctionScan { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> 
Result<(), DatabaseError> { + if self.iter.is_none() { + let TableFunction { args, inner } = &self.table_function; + self.iter = Some(inner.eval(args)?); + } + + let tuple = self.iter.as_mut().and_then(Iterator::next).transpose()?; + let Some(tuple) = tuple else { + arena.finish(); + return Ok(()); + }; + arena.produce_tuple(tuple); + Ok(()) } } diff --git a/src/execution/dql/index_scan.rs b/src/execution/dql/index_scan.rs index 1fb046ba..b5181599 100644 --- a/src/execution/dql/index_scan.rs +++ b/src/execution/dql/index_scan.rs @@ -12,30 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::range_detacher::Range; use crate::planner::operator::table_scan::TableScanOperator; -use crate::storage::{Iter, StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::{IndexIter, Iter, Transaction}; use crate::types::index::IndexMetaRef; use crate::types::serialize::TupleValueSerializableImpl; -pub(crate) struct IndexScan { - op: TableScanOperator, +pub(crate) struct IndexScan<'a, T: Transaction + 'a> { + op: Option, index_by: IndexMetaRef, ranges: Vec, covered_deserializers: Option>, cover_mapping: Option>, + iter: Option>, } -impl +impl<'a, T: Transaction + 'a> From<( TableScanOperator, IndexMetaRef, Range, Option>, Option>, - )> for IndexScan + )> for IndexScan<'a, T> { fn from( (op, index_by, range, covered_deserializers, cover_mapping): ( @@ -52,48 +53,64 @@ impl }; IndexScan { - op, + op: Some(op), index_by, ranges, covered_deserializers, cover_mapping, + iter: None, } } } -impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for IndexScan { - fn execute( +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for IndexScan<'a, T> { + fn into_executor( self, - (table_cache, 
_, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let TableScanOperator { + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::IndexScan(self)) + } +} + +impl<'a, T: Transaction + 'a> IndexScan<'a, T> { + pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { + if self.iter.is_none() { + let Some(TableScanOperator { table_name, columns, limit, with_pk, .. - } = self.op; - - let mut iter = throw!( - co, - unsafe { &(*transaction) }.read_by_index( - table_cache, - table_name, - limit, - columns, - self.index_by, - self.ranges, - with_pk, - self.covered_deserializers, - self.cover_mapping, - ) - ); + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; + self.iter = Some(arena.transaction().read_by_index( + arena.table_cache(), + table_name, + limit, + columns, + self.index_by.clone(), + std::mem::take(&mut self.ranges), + with_pk, + self.covered_deserializers.take(), + self.cover_mapping.take(), + )?); + } - while let Some(tuple) = throw!(co, iter.next_tuple()) { - co.yield_(Ok(tuple)).await; - } - }) + if self + .iter + .as_mut() + .expect("index scan iterator initialized") + .next_tuple_into(arena.result_tuple_mut())? + { + arena.resume(); + } else { + arena.finish(); + } + Ok(()) } } diff --git a/src/execution/dql/join/hash/full_join.rs b/src/execution/dql/join/hash/full_join.rs index 35cab767..fdaa97c4 100644 --- a/src/execution/dql/join/hash/full_join.rs +++ b/src/execution/dql/join/hash/full_join.rs @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; +use crate::errors::DatabaseError; +use crate::execution::dql::join::hash::{ + filter, FilterArgs, JoinProbeState, LeftDropState, LeftDropTuples, ProbeState, +}; use crate::execution::dql::join::hash_join::BuildState; -use crate::execution::dql::sort::BumpVec; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; use crate::types::tuple::Tuple; use crate::types::value::DataValue; -use ahash::HashMap; use fixedbitset::FixedBitSet; pub(crate) struct FullJoinState { @@ -28,79 +27,104 @@ pub(crate) struct FullJoinState { pub(crate) bits: FixedBitSet, } -impl<'a> JoinProbeState<'a> for FullJoinState { - fn probe( +impl JoinProbeState for FullJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - let left_schema_len = self.left_schema_len; - let bits_ptr: *mut FixedBitSet = &mut self.bits; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + if probe_state.is_keys_has_null { + if probe_state.emitted_unmatched { + probe_state.finished = true; + return Ok(None); + } + probe_state.emitted_unmatched = true; + probe_state.finished = true; + return Ok(Some(Self::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); + } - spawn_executor(move |co| async move { - let ProbeArgs { probe_tuple, .. } = probe_args; + let Some(build_state) = build_state else { + if probe_state.emitted_unmatched { + probe_state.finished = true; + return Ok(None); + } + probe_state.emitted_unmatched = true; + probe_state.finished = true; + return Ok(Some(Self::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); + }; - if let ProbeArgs { - is_keys_has_null: false, - build_state: Some(build_state), - .. 
- } = probe_args - { - let mut has_filtered = false; - for (i, Tuple { values, pk }) in build_state.tuples.iter() { - let full_values = - Vec::from_iter(values.iter().chain(probe_tuple.values.iter()).cloned()); + if probe_state.index < build_state.tuples.len() { + let (i, Tuple { values, pk }) = &build_state.tuples[probe_state.index]; + probe_state.index += 1; + let full_values = Vec::from_iter( + values + .iter() + .chain(probe_state.probe_tuple.values.iter()) + .cloned(), + ); - match &filter_args { - None => (), - Some(filter_args) => { - if !throw!(co, filter(&full_values, filter_args)) { - has_filtered = true; - unsafe { - (*bits_ptr).set(*i, true); - } - co.yield_(Ok(Self::full_right_row(left_schema_len, &probe_tuple))) - .await; - continue; - } - } - } - co.yield_(Ok(Tuple::new(pk.clone(), full_values))).await; + if let Some(filter_args) = filter_args { + if !filter(&full_values, filter_args)? { + probe_state.has_filtered = true; + self.bits.set(*i, true); + return Ok(Some(Self::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); } - build_state.is_used = !has_filtered; - build_state.has_filted = has_filtered; - return; } + build_state.is_used = true; + build_state.has_filted = probe_state.has_filtered; + return Ok(Some(Tuple::new(pk.clone(), full_values))); + } - co.yield_(Ok(Self::full_right_row(left_schema_len, &probe_tuple))) - .await; - }) + build_state.is_used = !probe_state.has_filtered; + build_state.has_filted = probe_state.has_filtered; + probe_state.finished = true; + Ok(None) } - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - _filter_args: Option<&'a FilterArgs>, - ) -> Option> { + left_drop_state: &mut LeftDropState, + _filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { let full_schema_len = self.right_schema_len + self.left_schema_len; - let bits_ptr: *mut FixedBitSet = &mut self.bits; - Some(spawn_executor(move |co| async move { - for (_, state) in _build_map { - if 
state.is_used { - continue; - } - for (i, mut left_tuple) in state.tuples { - unsafe { - if !(*bits_ptr).contains(i) && state.has_filted { - continue; - } + loop { + if let Some(LeftDropTuples { + tuples, has_filted, .. + }) = left_drop_state.current.as_mut() + { + for (i, mut left_tuple) in tuples.by_ref() { + if !self.bits.contains(i) && *has_filted { + continue; } left_tuple.values.resize(full_schema_len, DataValue::Null); - co.yield_(Ok(left_tuple)).await; + return Ok(Some(left_tuple)); } + left_drop_state.current = None; + } + + let Some((_, state)) = left_drop_state.states.next() else { + return Ok(None); + }; + + if state.is_used { + continue; } - })) + left_drop_state.current = Some(LeftDropTuples { + tuples: state.tuples.into_iter(), + has_filted: state.has_filted, + }); + } } } diff --git a/src/execution/dql/join/hash/inner_join.rs b/src/execution/dql/join/hash/inner_join.rs index 183c90f6..d9413fea 100644 --- a/src/execution/dql/join/hash/inner_join.rs +++ b/src/execution/dql/join/hash/inner_join.rs @@ -12,45 +12,50 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; +use crate::errors::DatabaseError; +use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeState}; +use crate::execution::dql::join::hash_join::BuildState; use crate::types::tuple::Tuple; pub(crate) struct InnerJoinState; -impl<'a> JoinProbeState<'a> for InnerJoinState { - fn probe( +impl JoinProbeState for InnerJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let ProbeArgs { - is_keys_has_null: false, - probe_tuple, - build_state: Some(build_state), - .. 
- } = probe_args - else { - return; - }; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + if probe_state.is_keys_has_null { + probe_state.finished = true; + return Ok(None); + } - build_state.is_used = true; - for (_, Tuple { values, pk }) in build_state.tuples.iter() { - let full_values = - Vec::from_iter(values.iter().chain(probe_tuple.values.iter()).cloned()); + let Some(build_state) = build_state else { + probe_state.finished = true; + return Ok(None); + }; - match &filter_args { - None => (), - Some(filter_args) => { - if !throw!(co, filter(&full_values, filter_args)) { - continue; - } - } + build_state.is_used = true; + while probe_state.index < build_state.tuples.len() { + let (_, Tuple { values, pk }) = &build_state.tuples[probe_state.index]; + probe_state.index += 1; + let full_values = Vec::from_iter( + values + .iter() + .chain(probe_state.probe_tuple.values.iter()) + .cloned(), + ); + + if let Some(filter_args) = filter_args { + if !filter(&full_values, filter_args)? { + continue; } - co.yield_(Ok(Tuple::new(pk.clone(), full_values))).await; } - }) + return Ok(Some(Tuple::new(pk.clone(), full_values))); + } + + probe_state.finished = true; + Ok(None) } } diff --git a/src/execution/dql/join/hash/left_anti_join.rs b/src/execution/dql/join/hash/left_anti_join.rs index bf9ff524..f2291c30 100644 --- a/src/execution/dql/join/hash/left_anti_join.rs +++ b/src/execution/dql/join/hash/left_anti_join.rs @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use crate::errors::DatabaseError; use crate::execution::dql::join::hash::left_semi_join::LeftSemiJoinState; -use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; -use crate::execution::dql::join::hash_join::BuildState; -use crate::execution::dql::sort::BumpVec; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; +use crate::execution::dql::join::hash::{ + filter, FilterArgs, JoinProbeState, LeftDropState, LeftDropTuples, ProbeState, +}; +use crate::types::tuple::Tuple; use crate::types::value::DataValue; -use ahash::HashMap; use fixedbitset::FixedBitSet; pub(crate) struct LeftAntiJoinState { @@ -27,40 +26,32 @@ pub(crate) struct LeftAntiJoinState { pub(crate) inner: LeftSemiJoinState, } -impl<'a> JoinProbeState<'a> for LeftAntiJoinState { - fn probe( +impl JoinProbeState for LeftAntiJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - self.inner.probe(probe_args, filter_args) + probe_state: &mut ProbeState, + build_state: Option<&mut crate::execution::dql::join::hash_join::BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + self.inner.probe_next(probe_state, build_state, filter_args) } - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - filter_args: Option<&'a FilterArgs>, - ) -> Option> { - let bits_ptr: *mut FixedBitSet = &mut self.inner.bits; + left_drop_state: &mut LeftDropState, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + let bits: &FixedBitSet = &self.inner.bits; let right_schema_len = self.right_schema_len; - Some(spawn_executor(move |co| async move { - for ( - _, - BuildState { - tuples, - is_used, - has_filted, - }, - ) in _build_map + + loop { + if let Some(LeftDropTuples { + tuples, has_filted, .. 
+ }) = left_drop_state.current.as_mut() { - if is_used { - continue; - } - for (i, tuple) in tuples { - unsafe { - if (*bits_ptr).contains(i) && has_filted { - continue; - } + for (i, tuple) in tuples.by_ref() { + if bits.contains(i) && *has_filted { + continue; } if let Some(filter_args) = filter_args { let full_values = Vec::from_iter( @@ -70,13 +61,26 @@ impl<'a> JoinProbeState<'a> for LeftAntiJoinState { .cloned() .chain((0..right_schema_len).map(|_| DataValue::Null)), ); - if !throw!(co, filter(&full_values, filter_args)) { + if !filter(&full_values, filter_args)? { continue; } } - co.yield_(Ok(tuple)).await; + return Ok(Some(tuple)); } + left_drop_state.current = None; + } + + let Some((_, state)) = left_drop_state.states.next() else { + return Ok(None); + }; + + if state.is_used { + continue; } - })) + left_drop_state.current = Some(LeftDropTuples { + tuples: state.tuples.into_iter(), + has_filted: state.has_filted, + }); + } } } diff --git a/src/execution/dql/join/hash/left_join.rs b/src/execution/dql/join/hash/left_join.rs index bb4fc291..5fb657f6 100644 --- a/src/execution/dql/join/hash/left_join.rs +++ b/src/execution/dql/join/hash/left_join.rs @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; +use crate::errors::DatabaseError; +use crate::execution::dql::join::hash::{ + filter, FilterArgs, JoinProbeState, LeftDropState, LeftDropTuples, ProbeState, +}; use crate::execution::dql::join::hash_join::BuildState; -use crate::execution::dql::sort::BumpVec; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; use crate::types::tuple::Tuple; use crate::types::value::DataValue; -use ahash::HashMap; use fixedbitset::FixedBitSet; pub(crate) struct LeftJoinState { @@ -28,71 +27,83 @@ pub(crate) struct LeftJoinState { pub(crate) bits: FixedBitSet, } -impl<'a> JoinProbeState<'a> for LeftJoinState { - fn probe( +impl JoinProbeState for LeftJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - let bits_ptr: *mut FixedBitSet = &mut self.bits; - spawn_executor(move |co| async move { - let ProbeArgs { - is_keys_has_null: false, - probe_tuple, - build_state: Some(build_state), - .. 
- } = probe_args - else { - return; - }; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + if probe_state.is_keys_has_null { + probe_state.finished = true; + return Ok(None); + } - let mut has_filted = false; - for (i, Tuple { values, pk }) in build_state.tuples.iter() { - let full_values = - Vec::from_iter(values.iter().chain(probe_tuple.values.iter()).cloned()); + let Some(build_state) = build_state else { + probe_state.finished = true; + return Ok(None); + }; - match &filter_args { - None => (), - Some(filter_args) => { - if !throw!(co, filter(&full_values, filter_args)) { - has_filted = true; - unsafe { - (*bits_ptr).set(*i, true); - } - continue; - } - } + while probe_state.index < build_state.tuples.len() { + let (i, Tuple { values, pk }) = &build_state.tuples[probe_state.index]; + probe_state.index += 1; + let full_values = Vec::from_iter( + values + .iter() + .chain(probe_state.probe_tuple.values.iter()) + .cloned(), + ); + + if let Some(filter_args) = filter_args { + if !filter(&full_values, filter_args)? 
{ + probe_state.has_filtered = true; + self.bits.set(*i, true); + continue; } - co.yield_(Ok(Tuple::new(pk.clone(), full_values))).await; } - build_state.is_used = !has_filted; - build_state.has_filted = has_filted; - }) + build_state.is_used = true; + return Ok(Some(Tuple::new(pk.clone(), full_values))); + } + + build_state.is_used = !probe_state.has_filtered; + build_state.has_filted = probe_state.has_filtered; + probe_state.finished = true; + Ok(None) } - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - _filter_args: Option<&'a FilterArgs>, - ) -> Option> { + left_drop_state: &mut LeftDropState, + _filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { let full_schema_len = self.right_schema_len + self.left_schema_len; - let bits_ptr: *mut FixedBitSet = &mut self.bits; - Some(spawn_executor(move |co| async move { - for (_, state) in _build_map { - if state.is_used { - continue; - } - for (i, mut left_tuple) in state.tuples { - unsafe { - if !(*bits_ptr).contains(i) && state.has_filted { - continue; - } + loop { + if let Some(LeftDropTuples { + tuples, has_filted, .. 
+ }) = left_drop_state.current.as_mut() + { + for (i, mut left_tuple) in tuples.by_ref() { + if !self.bits.contains(i) && *has_filted { + continue; } left_tuple.values.resize(full_schema_len, DataValue::Null); - co.yield_(Ok(left_tuple)).await; + return Ok(Some(left_tuple)); } + left_drop_state.current = None; + } + + let Some((_, state)) = left_drop_state.states.next() else { + return Ok(None); + }; + + if state.is_used { + continue; } - })) + left_drop_state.current = Some(LeftDropTuples { + tuples: state.tuples.into_iter(), + has_filted: state.has_filted, + }); + } } } diff --git a/src/execution/dql/join/hash/left_semi_join.rs b/src/execution/dql/join/hash/left_semi_join.rs index a60f4eca..43e1e066 100644 --- a/src/execution/dql/join/hash/left_semi_join.rs +++ b/src/execution/dql/join/hash/left_semi_join.rs @@ -12,89 +12,89 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; +use crate::errors::DatabaseError; +use crate::execution::dql::join::hash::{ + filter, FilterArgs, JoinProbeState, LeftDropState, LeftDropTuples, ProbeState, +}; use crate::execution::dql::join::hash_join::BuildState; -use crate::execution::dql::sort::BumpVec; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; use crate::types::tuple::Tuple; -use crate::types::value::DataValue; -use ahash::HashMap; use fixedbitset::FixedBitSet; pub(crate) struct LeftSemiJoinState { pub(crate) bits: FixedBitSet, } -impl<'a> JoinProbeState<'a> for LeftSemiJoinState { - fn probe( +impl JoinProbeState for LeftSemiJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - let bits_ptr: *mut FixedBitSet = &mut self.bits; - spawn_executor(move |co| async move { - let ProbeArgs { - is_keys_has_null: false, - probe_tuple, - build_state: Some(build_state), - .. 
- } = probe_args - else { - return; - }; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + if probe_state.is_keys_has_null { + probe_state.finished = true; + return Ok(None); + } - let mut has_filted = false; - for (i, Tuple { values, .. }) in build_state.tuples.iter() { - let full_values = - Vec::from_iter(values.iter().chain(probe_tuple.values.iter()).cloned()); + let Some(build_state) = build_state else { + probe_state.finished = true; + return Ok(None); + }; - match &filter_args { - None => (), - Some(filter_args) => { - if !throw!(co, filter(&full_values, filter_args)) { - has_filted = true; - unsafe { - (*bits_ptr).set(*i, true); - } - continue; - } - } + while probe_state.index < build_state.tuples.len() { + let (i, Tuple { values, .. }) = &build_state.tuples[probe_state.index]; + probe_state.index += 1; + let full_values = Vec::from_iter( + values + .iter() + .chain(probe_state.probe_tuple.values.iter()) + .cloned(), + ); + + if let Some(filter_args) = filter_args { + if !filter(&full_values, filter_args)? { + probe_state.has_filtered = true; + self.bits.set(*i, true); } } - build_state.is_used = true; - build_state.has_filted = has_filted; - }) + } + build_state.is_used = true; + build_state.has_filted = probe_state.has_filtered; + probe_state.finished = true; + + Ok(None) } - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - _filter_args: Option<&'a FilterArgs>, - ) -> Option> { - let bits_ptr: *mut FixedBitSet = &mut self.bits; - Some(spawn_executor(move |co| async move { - for ( - _, - BuildState { - tuples, - is_used, - has_filted, - }, - ) in _build_map + left_drop_state: &mut LeftDropState, + _filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + loop { + if let Some(LeftDropTuples { + tuples, has_filted, .. 
+ }) = left_drop_state.current.as_mut() { - if !is_used { - continue; - } - for (i, tuple) in tuples { - unsafe { - if (*bits_ptr).contains(i) && has_filted { - continue; - } + for (i, tuple) in tuples.by_ref() { + if self.bits.contains(i) && *has_filted { + continue; } - co.yield_(Ok(tuple)).await; + return Ok(Some(tuple)); } + left_drop_state.current = None; + } + + let Some((_, state)) = left_drop_state.states.next() else { + return Ok(None); + }; + + if !state.is_used { + continue; } - })) + left_drop_state.current = Some(LeftDropTuples { + tuples: state.tuples.into_iter(), + has_filted: state.has_filted, + }); + } } } diff --git a/src/execution/dql/join/hash/mod.rs b/src/execution/dql/join/hash/mod.rs index 85f2346b..fc23dd62 100644 --- a/src/execution/dql/join/hash/mod.rs +++ b/src/execution/dql/join/hash/mod.rs @@ -28,38 +28,50 @@ use crate::execution::dql::join::hash::left_semi_join::LeftSemiJoinState; use crate::execution::dql::join::hash::right_join::RightJoinState; use crate::execution::dql::join::hash_join::BuildState; use crate::execution::dql::sort::BumpVec; -use crate::execution::Executor; use crate::expression::ScalarExpression; use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::DataValue; -use ahash::HashMap; +use std::collections::hash_map::IntoIter as HashMapIntoIter; -#[derive(Debug)] -pub(crate) struct ProbeArgs<'a> { +pub(crate) struct FilterArgs { + pub(crate) full_schema: SchemaRef, + pub(crate) filter_expr: ScalarExpression, +} + +pub(crate) struct ProbeState { pub(crate) is_keys_has_null: bool, pub(crate) probe_tuple: Tuple, - pub(crate) build_state: Option<&'a mut BuildState>, + pub(crate) index: usize, + pub(crate) has_filtered: bool, + pub(crate) produced: bool, + pub(crate) finished: bool, + pub(crate) emitted_unmatched: bool, } -pub(crate) struct FilterArgs { - pub(crate) full_schema: SchemaRef, - pub(crate) filter_expr: ScalarExpression, +pub(crate) struct LeftDropState { + pub(crate) states: HashMapIntoIter, 
BuildState>, + pub(crate) current: Option, +} + +pub(crate) struct LeftDropTuples { + pub(crate) tuples: std::vec::IntoIter<(usize, Tuple)>, + pub(crate) has_filted: bool, } -pub(crate) trait JoinProbeState<'a> { - fn probe( +pub(crate) trait JoinProbeState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a>; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError>; - #[allow(clippy::mutable_key_type)] - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - _filter_args: Option<&'a FilterArgs>, - ) -> Option> { - None + _left_drop_state: &mut LeftDropState, + _filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + Ok(None) } } @@ -72,34 +84,51 @@ pub(crate) enum JoinProbeStateImpl { LeftAnti(LeftAntiJoinState), } -impl<'a> JoinProbeState<'a> for JoinProbeStateImpl { - fn probe( +impl JoinProbeState for JoinProbeStateImpl { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { match self { - JoinProbeStateImpl::Inner(state) => state.probe(probe_args, filter_args), - JoinProbeStateImpl::Left(state) => state.probe(probe_args, filter_args), - JoinProbeStateImpl::Right(state) => state.probe(probe_args, filter_args), - JoinProbeStateImpl::Full(state) => state.probe(probe_args, filter_args), - JoinProbeStateImpl::LeftSemi(state) => state.probe(probe_args, filter_args), - JoinProbeStateImpl::LeftAnti(state) => state.probe(probe_args, filter_args), + JoinProbeStateImpl::Inner(state) => { + state.probe_next(probe_state, build_state, filter_args) + } + JoinProbeStateImpl::Left(state) => { + state.probe_next(probe_state, build_state, filter_args) + } + JoinProbeStateImpl::Right(state) => { 
+ state.probe_next(probe_state, build_state, filter_args) + } + JoinProbeStateImpl::Full(state) => { + state.probe_next(probe_state, build_state, filter_args) + } + JoinProbeStateImpl::LeftSemi(state) => { + state.probe_next(probe_state, build_state, filter_args) + } + JoinProbeStateImpl::LeftAnti(state) => { + state.probe_next(probe_state, build_state, filter_args) + } } } - fn left_drop( + fn left_drop_next( &mut self, - _build_map: HashMap, BuildState>, - filter_args: Option<&'a FilterArgs>, - ) -> Option> { + left_drop_state: &mut LeftDropState, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { match self { - JoinProbeStateImpl::Inner(state) => state.left_drop(_build_map, filter_args), - JoinProbeStateImpl::Left(state) => state.left_drop(_build_map, filter_args), - JoinProbeStateImpl::Right(state) => state.left_drop(_build_map, filter_args), - JoinProbeStateImpl::Full(state) => state.left_drop(_build_map, filter_args), - JoinProbeStateImpl::LeftSemi(state) => state.left_drop(_build_map, filter_args), - JoinProbeStateImpl::LeftAnti(state) => state.left_drop(_build_map, filter_args), + JoinProbeStateImpl::Inner(state) => state.left_drop_next(left_drop_state, filter_args), + JoinProbeStateImpl::Left(state) => state.left_drop_next(left_drop_state, filter_args), + JoinProbeStateImpl::Right(state) => state.left_drop_next(left_drop_state, filter_args), + JoinProbeStateImpl::Full(state) => state.left_drop_next(left_drop_state, filter_args), + JoinProbeStateImpl::LeftSemi(state) => { + state.left_drop_next(left_drop_state, filter_args) + } + JoinProbeStateImpl::LeftAnti(state) => { + state.left_drop_next(left_drop_state, filter_args) + } } } } diff --git a/src/execution/dql/join/hash/right_join.rs b/src/execution/dql/join/hash/right_join.rs index 325ba836..da07be48 100644 --- a/src/execution/dql/join/hash/right_join.rs +++ b/src/execution/dql/join/hash/right_join.rs @@ -12,68 +12,84 @@ // See the License for the specific language governing permissions 
and // limitations under the License. +use crate::errors::DatabaseError; use crate::execution::dql::join::hash::full_join::FullJoinState; -use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeArgs}; -use crate::execution::{spawn_executor, Executor}; -use crate::throw; +use crate::execution::dql::join::hash::{filter, FilterArgs, JoinProbeState, ProbeState}; +use crate::execution::dql::join::hash_join::BuildState; use crate::types::tuple::Tuple; pub(crate) struct RightJoinState { pub(crate) left_schema_len: usize, } -impl<'a> JoinProbeState<'a> for RightJoinState { - fn probe( +impl JoinProbeState for RightJoinState { + fn probe_next( &mut self, - probe_args: ProbeArgs<'a>, - filter_args: Option<&'a FilterArgs>, - ) -> Executor<'a> { - let left_schema_len = self.left_schema_len; + probe_state: &mut ProbeState, + build_state: Option<&mut BuildState>, + filter_args: Option<&FilterArgs>, + ) -> Result, DatabaseError> { + if probe_state.is_keys_has_null { + if probe_state.emitted_unmatched { + probe_state.finished = true; + return Ok(None); + } + probe_state.emitted_unmatched = true; + probe_state.finished = true; + return Ok(Some(FullJoinState::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); + } - spawn_executor(move |co| async move { - let ProbeArgs { probe_tuple, .. } = probe_args; + let Some(build_state) = build_state else { + if probe_state.emitted_unmatched { + probe_state.finished = true; + return Ok(None); + } + probe_state.emitted_unmatched = true; + probe_state.finished = true; + return Ok(Some(FullJoinState::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); + }; - if let ProbeArgs { - is_keys_has_null: false, - build_state: Some(build_state), - .. 
- } = probe_args - { - let mut has_filtered = false; - let mut produced = false; - for (_, Tuple { values, pk }) in build_state.tuples.iter() { - let full_values = - Vec::from_iter(values.iter().chain(probe_tuple.values.iter()).cloned()); + while probe_state.index < build_state.tuples.len() { + let (_, Tuple { values, pk }) = &build_state.tuples[probe_state.index]; + probe_state.index += 1; + let full_values = Vec::from_iter( + values + .iter() + .chain(probe_state.probe_tuple.values.iter()) + .cloned(), + ); - match &filter_args { - None => (), - Some(filter_args) => { - if !throw!(co, filter(&full_values, filter_args)) { - has_filtered = true; - continue; - } - } - } - produced = true; - co.yield_(Ok(Tuple::new(pk.clone(), full_values))).await; - } - if !produced { - co.yield_(Ok(FullJoinState::full_right_row( - left_schema_len, - &probe_tuple, - ))) - .await; + if let Some(filter_args) = filter_args { + if !filter(&full_values, filter_args)? { + probe_state.has_filtered = true; + continue; } - build_state.is_used = produced; - build_state.has_filted = has_filtered; - return; } + probe_state.produced = true; + build_state.is_used = true; + build_state.has_filted = probe_state.has_filtered; + return Ok(Some(Tuple::new(pk.clone(), full_values))); + } + + build_state.is_used = probe_state.produced; + build_state.has_filted = probe_state.has_filtered; + + if !probe_state.produced && !probe_state.emitted_unmatched { + probe_state.emitted_unmatched = true; + probe_state.finished = true; + return Ok(Some(FullJoinState::full_right_row( + self.left_schema_len, + &probe_state.probe_tuple, + ))); + } - co.yield_(Ok(FullJoinState::full_right_row( - left_schema_len, - &probe_tuple, - ))) - .await; - }) + probe_state.finished = true; + Ok(None) } } diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index 3670ed05..86d56555 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -21,55 +21,146 @@ use 
crate::execution::dql::join::hash::left_join::LeftJoinState; use crate::execution::dql::join::hash::left_semi_join::LeftSemiJoinState; use crate::execution::dql::join::hash::right_join::RightJoinState; use crate::execution::dql::join::hash::{ - FilterArgs, JoinProbeState, JoinProbeStateImpl, ProbeArgs, + FilterArgs, JoinProbeState, JoinProbeStateImpl, LeftDropState, ProbeState, }; use crate::execution::dql::join::joins_nullable; use crate::execution::dql::sort::BumpVec; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; +use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::DataValue; use ahash::{HashMap, HashMapExt}; use bumpalo::Bump; use fixedbitset::FixedBitSet; +use std::mem::transmute; use std::sync::Arc; pub struct HashJoin { - on: JoinCondition, + state: HashJoinState, ty: JoinType, - left_input: LogicalPlan, - right_input: LogicalPlan, - bump: Bump, + on_left_keys: Vec, + on_right_keys: Vec, + full_schema: SchemaRef, + filter: Option, + left_schema_len: usize, + right_schema_len: usize, + left_input_plan: Option, + right_input_plan: Option, + left_input: ExecId, + right_input: ExecId, + bump: Box, + init_error: Option, +} + +enum HashJoinState { + Build, + Probe { + build_map: HashMap, BuildState>, + join_impl: JoinProbeStateImpl, + probe_buf: BumpVec<'static, DataValue>, + probe_state: Option, + }, + LeftDrop { + join_impl: JoinProbeStateImpl, + left_drop: LeftDropState, + }, + End, } impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { fn from( - (JoinOperator { on, join_type, .. 
}, left_input, right_input): ( + (JoinOperator { on, join_type, .. }, mut left_input, mut right_input): ( JoinOperator, LogicalPlan, LogicalPlan, ), ) -> Self { + let ((on_left_keys, on_right_keys), filter_expr) = match on { + JoinCondition::On { on, filter } => (on.into_iter().unzip(), filter), + JoinCondition::None => ((vec![], vec![]), None), + }; + + let init_error = if join_type == JoinType::Cross { + Some(DatabaseError::UnsupportedStmt( + "Cross join should not be executed by HashJoin".to_string(), + )) + } else if on_left_keys.is_empty() || on_right_keys.is_empty() { + Some(DatabaseError::UnsupportedStmt( + "`NestLoopJoin` should be used when there is no equivalent condition".to_string(), + )) + } else { + None + }; + + let (left_force_nullable, right_force_nullable) = joins_nullable(&join_type); + + let mut full_schema_ref = Vec::clone(left_input.output_schema()); + let left_schema_len = full_schema_ref.len(); + + force_nullable(&mut full_schema_ref, left_force_nullable); + full_schema_ref.extend_from_slice(right_input.output_schema()); + force_nullable( + &mut full_schema_ref[left_schema_len..], + right_force_nullable, + ); + let right_schema_len = full_schema_ref.len() - left_schema_len; + HashJoin { - on, + state: HashJoinState::Build, ty: join_type, - left_input, - right_input, - bump: Default::default(), + on_left_keys, + on_right_keys, + full_schema: Arc::new(full_schema_ref.clone()), + filter: filter_expr.map(|filter_expr| FilterArgs { + full_schema: Arc::new(full_schema_ref), + filter_expr, + }), + left_schema_len, + right_schema_len, + left_input_plan: Some(left_input), + right_input_plan: Some(right_input), + left_input: 0, + right_input: 0, + bump: Box::::default(), + init_error, + } + } +} + +fn force_nullable(schema: &mut [ColumnRef], force_nullable: bool) { + for column in schema.iter_mut() { + if let Some(new_column) = column.nullable_for_join(force_nullable) { + *column = new_column; } } } impl HashJoin { + #[allow(clippy::mutable_key_type)] 
+ fn own_bump_vec(buf: BumpVec<'_, DataValue>) -> BumpVec<'static, DataValue> { + unsafe { transmute::, BumpVec<'static, DataValue>>(buf) } + } + + #[allow(clippy::mutable_key_type)] + fn own_build_map( + build_map: HashMap, BuildState>, + ) -> HashMap, BuildState> { + unsafe { + transmute::< + HashMap, BuildState>, + HashMap, BuildState>, + >(build_map) + } + } + fn eval_keys( on_keys: &[ScalarExpression], tuple: &Tuple, schema: &[ColumnRef], - build_buf: &mut BumpVec, + build_buf: &mut BumpVec<'_, DataValue>, ) -> Result<(), DatabaseError> { build_buf.clear(); for expr in on_keys { @@ -77,162 +168,64 @@ impl HashJoin { } Ok(()) } -} -#[derive(Default, Debug)] -pub(crate) struct BuildState { - pub(crate) tuples: Vec<(usize, Tuple)>, - pub(crate) is_used: bool, - pub(crate) has_filted: bool, -} - -impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { - #[allow(clippy::mutable_key_type)] - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let HashJoin { - on, - ty, - mut left_input, - mut right_input, - mut bump, - } = self; - - if ty == JoinType::Cross { - unreachable!("Cross join should not be in HashJoinExecutor"); - } - let ((on_left_keys, on_right_keys), filter): ( - (Vec, Vec), - _, - ) = match on { - JoinCondition::On { on, filter } => (on.into_iter().unzip(), filter), - JoinCondition::None => unreachable!("HashJoin must has on condition"), - }; - if on_left_keys.is_empty() || on_right_keys.is_empty() { - throw!( - co, - Err(DatabaseError::UnsupportedStmt( - "`NestLoopJoin` should be used when there is no equivalent condition" - .to_string() - )) - ) - } - debug_assert!(!on_left_keys.is_empty()); - debug_assert!(!on_right_keys.is_empty()); - - let fn_process = |schema: &mut [ColumnRef], force_nullable| { - for column in schema.iter_mut() { - if let Some(new_column) = column.nullable_for_join(force_nullable) { - *column = 
new_column; - } - } - }; - - let (left_force_nullable, right_force_nullable) = joins_nullable(&ty); + fn initialize_build<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if !matches!(self.state, HashJoinState::Build) { + return Ok(()); + } - let mut full_schema_ref = Vec::clone(left_input.output_schema()); - let left_schema_len = full_schema_ref.len(); + // build phase: + // 1.construct hashtable, one hash key may contains multiple rows indices. + // 2.merged all left tuples. + #[allow(clippy::mutable_key_type)] + let mut build_map = HashMap::new(); + let mut build_buf = BumpVec::with_capacity_in(self.on_left_keys.len(), &self.bump); + let mut build_count = 0usize; + + while arena.next_tuple(self.left_input)? { + let tuple = arena.result_tuple().clone(); + Self::eval_keys( + &self.on_left_keys, + &tuple, + &self.full_schema[0..self.left_schema_len], + &mut build_buf, + )?; - fn_process(&mut full_schema_ref, left_force_nullable); - full_schema_ref.extend_from_slice(right_input.output_schema()); - fn_process( - &mut full_schema_ref[left_schema_len..], - right_force_nullable, - ); - let right_schema_len = full_schema_ref.len() - left_schema_len; - let full_schema_ref = Arc::new(full_schema_ref); - - // build phase: - // 1.construct hashtable, one hash key may contains multiple rows indices. - // 2.merged all left tuples. 
- let mut coroutine = build_read(left_input, cache, transaction); - let mut build_map = HashMap::new(); - let bump_ptr: *mut Bump = &mut bump; - let build_map_ptr: *mut HashMap, BuildState> = &mut build_map; - - let mut buf_row = - BumpVec::with_capacity_in(on_left_keys.len(), unsafe { &mut (*bump_ptr) }); - - let mut build_count = 0; - for tuple in coroutine.by_ref() { - let tuple: Tuple = throw!(co, tuple); - throw!( - co, - Self::eval_keys( - &on_left_keys, - &tuple, - &full_schema_ref[0..left_schema_len], - &mut buf_row, - ) - ); - - let build_map_ref = unsafe { &mut (*build_map_ptr) }; - match build_map_ref.get_mut(&buf_row) { - None => { - build_map_ref.insert( - buf_row.clone(), - BuildState { - tuples: vec![(build_count, tuple)], - ..Default::default() - }, - ); - } - Some(BuildState { tuples, .. }) => tuples.push((build_count, tuple)), - } - build_count += 1; - } - let mut join_impl = - Self::create_join_impl(self.ty, left_schema_len, right_schema_len, build_count); - let mut filter_arg = filter.map(|expr| FilterArgs { - full_schema: full_schema_ref.clone(), - filter_expr: expr, - }); - let filter_arg_ptr: *mut Option = &mut filter_arg; - - // probe phase - let mut coroutine = build_read(right_input, cache, transaction); - - for tuple in coroutine.by_ref() { - let tuple: Tuple = throw!(co, tuple); - - throw!( - co, - Self::eval_keys( - &on_right_keys, - &tuple, - &full_schema_ref[left_schema_len..], - &mut buf_row - ) - ); - let build_value = unsafe { (*build_map_ptr).get_mut(&buf_row) }; - - let probe_args = ProbeArgs { - is_keys_has_null: buf_row.iter().any(|value| value.is_null()), - probe_tuple: tuple, - build_state: build_value, - }; - let executor = - join_impl.probe(probe_args, unsafe { &mut (*filter_arg_ptr) }.as_ref()); - for tuple in executor { - co.yield_(tuple).await; - } - } - if let Some(executor) = - join_impl.left_drop(build_map, unsafe { &mut (*filter_arg_ptr) }.as_ref()) - { - for tuple in executor { - co.yield_(tuple).await; + match 
build_map.get_mut(&build_buf) { + None => { + build_map.insert( + Self::own_bump_vec(build_buf.clone()), + BuildState { + tuples: vec![(build_count, tuple)], + ..Default::default() + }, + ); } + Some(BuildState { tuples, .. }) => tuples.push((build_count, tuple)), } - }) + build_count += 1; + } + + self.state = HashJoinState::Probe { + join_impl: Self::create_join_impl( + self.ty, + self.left_schema_len, + self.right_schema_len, + build_count, + ), + probe_buf: Self::own_bump_vec(BumpVec::with_capacity_in( + self.on_right_keys.len(), + &self.bump, + )), + build_map: Self::own_build_map(build_map), + probe_state: None, + }; + Ok(()) } -} -impl HashJoin { fn create_join_impl( ty: JoinType, left_schema_len: usize, @@ -266,13 +259,156 @@ impl HashJoin { } } +#[derive(Default, Debug)] +pub(crate) struct BuildState { + pub(crate) tuples: Vec<(usize, Tuple)>, + pub(crate) is_used: bool, + pub(crate) has_filted: bool, +} + +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, + transaction: *mut T, + ) -> ExecId { + self.left_input = build_read( + arena, + self.left_input_plan + .take() + .expect("hash join left input plan initialized"), + cache, + transaction, + ); + self.right_input = build_read( + arena, + self.right_input_plan + .take() + .expect("hash join right input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::HashJoin(self)) + } +} + +impl HashJoin { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if let Some(err) = self.init_error.take() { + return Err(err); + } + + self.initialize_build(arena)?; + let mut state = std::mem::replace(&mut self.state, HashJoinState::End); + + loop { + match state { + HashJoinState::Build => unreachable!("hash join must be initialized before probe"), + HashJoinState::Probe { + mut build_map, + mut join_impl, + mut 
probe_buf, + mut probe_state, + } => { + let probe_finished = loop { + if probe_state.is_none() { + if !arena.next_tuple(self.right_input)? { + break true; + } + let tuple = arena.result_tuple().clone(); + Self::eval_keys( + &self.on_right_keys, + &tuple, + &self.full_schema[self.left_schema_len..], + &mut probe_buf, + )?; + probe_state = Some(ProbeState { + is_keys_has_null: probe_buf.iter().any(DataValue::is_null), + probe_tuple: tuple, + index: 0, + has_filtered: false, + produced: false, + finished: false, + emitted_unmatched: false, + }); + } + + let Some(probe) = probe_state.as_mut() else { + continue; + }; + let build_state = if probe.is_keys_has_null { + None + } else { + build_map.get_mut(&probe_buf) + }; + + if let Some(tuple) = + join_impl.probe_next(probe, build_state, self.filter.as_ref())? + { + if probe.finished { + probe_state = None; + } + self.state = HashJoinState::Probe { + build_map, + join_impl, + probe_buf, + probe_state, + }; + arena.produce_tuple(tuple); + return Ok(()); + } + + if probe.finished { + probe_state = None; + } + }; + + debug_assert!(probe_finished); + state = HashJoinState::LeftDrop { + join_impl, + left_drop: LeftDropState { + states: build_map.into_iter(), + current: None, + }, + }; + } + HashJoinState::LeftDrop { + mut join_impl, + mut left_drop, + } => { + if let Some(tuple) = + join_impl.left_drop_next(&mut left_drop, self.filter.as_ref())? 
+ { + self.state = HashJoinState::LeftDrop { + join_impl, + left_drop, + }; + arena.produce_tuple(tuple); + return Ok(()); + } + state = HashJoinState::End; + } + HashJoinState::End => { + self.state = HashJoinState::End; + arena.finish(); + return Ok(()); + } + } + } + } +} + #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::errors::DatabaseError; use crate::execution::dql::join::hash_join::HashJoin; use crate::execution::dql::test::build_integers; - use crate::execution::{try_collect, ReadExecutor}; + use crate::execution::try_collect; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; @@ -297,10 +433,7 @@ mod test { .before_batch( "Expression Remapper".to_string(), HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::BindExpressionPosition, - NormalizationRuleImpl::EvaluatorBind, - ], + vec![NormalizationRuleImpl::EvaluatorBind], ) .build() .instantiate(plan) @@ -327,8 +460,8 @@ mod test { ]; let on_keys = vec![( - ScalarExpression::column_expr(t1_columns[0].clone()), - ScalarExpression::column_expr(t2_columns[0].clone()), + ScalarExpression::column_expr(t1_columns[0].clone(), 0), + ScalarExpression::column_expr(t2_columns[0].clone(), 0), )]; let values_t1 = LogicalPlan { @@ -420,8 +553,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = HashJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + HashJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!(tuples.len(), 3); @@ -471,12 +607,13 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - //Outer { let executor = 
HashJoin::from((op.clone(), left.clone(), right.clone())); - let tuples = try_collect( - executor.execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let tuples = try_collect(crate::execution::execute( + executor, + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; assert_eq!(tuples.len(), 4); @@ -497,13 +634,14 @@ mod test { build_integers(vec![Some(3), Some(5), Some(7), None, None, None]) ); } - // Semi { let mut executor = HashJoin::from((op.clone(), left.clone(), right.clone())); executor.ty = JoinType::LeftSemi; - let mut tuples = try_collect( - executor.execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let mut tuples = try_collect(crate::execution::execute( + executor, + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; let arena = Bump::new(); assert_eq!(tuples.len(), 2); @@ -522,13 +660,14 @@ mod test { build_integers(vec![Some(1), Some(3), Some(5)]) ); } - // Anti { let mut executor = HashJoin::from((op, left, right)); executor.ty = JoinType::LeftAnti; - let tuples = try_collect( - executor.execute((&table_cache, &view_cache, &meta_cache), &mut transaction), - )?; + let tuples = try_collect(crate::execution::execute( + executor, + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ))?; assert_eq!(tuples.len(), 1); assert_eq!( @@ -569,8 +708,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = HashJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + HashJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!(tuples.len(), 4); @@ -616,12 +758,12 @@ mod test { ))]; let on_keys = vec![( - ScalarExpression::column_expr(left_columns[0].clone()), - ScalarExpression::column_expr(right_columns[0].clone()), + 
ScalarExpression::column_expr(left_columns[0].clone(), 0), + ScalarExpression::column_expr(right_columns[0].clone(), 0), )]; let filter_expr = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(left_columns[1].clone())), + left_expr: Box::new(ScalarExpression::column_expr(left_columns[1].clone(), 1)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Boolean, @@ -669,8 +811,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = HashJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + HashJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!(tuples.len(), 1); @@ -715,8 +860,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = HashJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + HashJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!(tuples.len(), 5); diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 9eb45425..0438ad8a 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -19,12 +19,11 @@ use super::joins_nullable; use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use 
crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; use crate::types::tuple::{Schema, SchemaRef, Tuple}; use crate::types::value::DataValue; use fixedbitset::FixedBitSet; @@ -86,12 +85,38 @@ impl EqualCondition { /// |--------------------------------|----------------|----------------| /// | Full | left | right | pub struct NestedLoopJoin { - left_input: LogicalPlan, - right_input: LogicalPlan, + left_input_plan: Option, + right_input_plan: LogicalPlan, output_schema_ref: SchemaRef, ty: JoinType, filter: Option, eq_cond: EqualCondition, + left_input: ExecId, + state: NestedLoopJoinState, +} + +enum NestedLoopJoinState { + PullLeft { + right_bitmap: Option, + }, + ScanRight { + active_left: ActiveLeftState, + right_bitmap: Option, + }, + EmitRightUnmatched { + right_input: ExecId, + right_bitmap: FixedBitSet, + right_emit_index: usize, + }, + End, +} + +struct ActiveLeftState { + left_tuple: Tuple, + right_input: ExecId, + right_index: usize, + has_matched: bool, + first_matches: Vec, } impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { @@ -126,152 +151,260 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { ); NestedLoopJoin { - ty: join_type, - left_input, - right_input, + left_input_plan: Some(left_input), + right_input_plan: right_input, output_schema_ref, + ty: join_type, filter, eq_cond, + left_input: 0, + state: NestedLoopJoinState::PullLeft { right_bitmap: None }, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let NestedLoopJoin { - ty, - 
left_input, - right_input, - output_schema_ref, - filter, - eq_cond, - .. - } = self; - - let right_schema_len = eq_cond.right_schema.len(); - let mut left_coroutine = build_read(left_input, cache, transaction); - let mut bitmap: Option = None; - let mut first_matches = Vec::new(); - - for left_tuple in left_coroutine.by_ref() { - let left_tuple: Tuple = throw!(co, left_tuple); - let mut has_matched = false; - - let mut right_coroutine = build_read(right_input.clone(), cache, transaction); - let mut right_idx = 0; - - for right_tuple in right_coroutine.by_ref() { - let right_tuple: Tuple = throw!(co, right_tuple); - - let tuple = match ( - filter.as_ref(), - throw!(co, eq_cond.equals(&left_tuple, &right_tuple)), - ) { - (None, true) if matches!(ty, JoinType::RightOuter) => { - has_matched = true; - Self::emit_tuple(&right_tuple, &left_tuple, ty, true) - } - (None, true) => { - has_matched = true; - Self::emit_tuple(&left_tuple, &right_tuple, ty, true) + ) -> ExecId { + self.left_input = build_read( + arena, + self.left_input_plan + .take() + .expect("nested loop join left input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::NestedLoopJoin(self)) + } +} + +impl NestedLoopJoin { + fn build_right_input<'a, T: Transaction + 'a>(&self, arena: &mut ExecArena<'a, T>) -> ExecId { + let cache = (arena.table_cache(), arena.view_cache(), arena.meta_cache()); + let transaction = arena.transaction_mut() as *mut T; + build_read(arena, self.right_input_plan.clone(), cache, transaction) + } + + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let mut state = std::mem::replace(&mut self.state, NestedLoopJoinState::End); + + loop { + match state { + NestedLoopJoinState::PullLeft { right_bitmap } => { + if !arena.next_tuple(self.left_input)? 
{ + if matches!(self.ty, JoinType::Full) { + state = NestedLoopJoinState::EmitRightUnmatched { + right_input: self.build_right_input(arena), + right_bitmap: right_bitmap.unwrap_or_default(), + right_emit_index: 0, + }; + continue; } - (Some(filter), true) => { - let new_tuple = Self::merge_tuple(&left_tuple, &right_tuple, &ty); - let value = - throw!(co, filter.eval(Some((&new_tuple, &output_schema_ref)))); - match &value { - DataValue::Boolean(true) => { - let tuple = match ty { - JoinType::LeftAnti => None, - JoinType::LeftSemi if has_matched => None, - JoinType::RightOuter => { - Self::emit_tuple(&right_tuple, &left_tuple, ty, true) - } - _ => Self::emit_tuple(&left_tuple, &right_tuple, ty, true), - }; - has_matched = true; - tuple + self.state = NestedLoopJoinState::End; + arena.finish(); + return Ok(()); + } + let left_tuple = arena.result_tuple().clone(); + + state = NestedLoopJoinState::ScanRight { + active_left: ActiveLeftState { + left_tuple, + right_input: self.build_right_input(arena), + right_index: 0, + has_matched: false, + first_matches: Vec::new(), + }, + right_bitmap, + }; + } + NestedLoopJoinState::ScanRight { + mut active_left, + mut right_bitmap, + } => { + while arena.next_tuple(active_left.right_input)? 
{ + let right_tuple = arena.result_tuple().clone(); + let idx = active_left.right_index; + active_left.right_index += 1; + + let tuple = match ( + self.filter.as_ref(), + self.eq_cond.equals(&active_left.left_tuple, &right_tuple)?, + ) { + (None, true) if matches!(self.ty, JoinType::RightOuter) => { + active_left.has_matched = true; + Self::emit_tuple( + &right_tuple, + &active_left.left_tuple, + self.ty, + true, + ) + } + (None, true) => { + active_left.has_matched = true; + Self::emit_tuple( + &active_left.left_tuple, + &right_tuple, + self.ty, + true, + ) + } + (Some(filter), true) => { + let new_tuple = Self::merge_tuple( + &active_left.left_tuple, + &right_tuple, + &self.ty, + ); + let value = + filter.eval(Some((&new_tuple, &self.output_schema_ref)))?; + match &value { + DataValue::Boolean(true) => { + let tuple = match self.ty { + JoinType::LeftAnti => None, + JoinType::LeftSemi if active_left.has_matched => None, + JoinType::RightOuter => Self::emit_tuple( + &right_tuple, + &active_left.left_tuple, + self.ty, + true, + ), + _ => Self::emit_tuple( + &active_left.left_tuple, + &right_tuple, + self.ty, + true, + ), + }; + active_left.has_matched = true; + tuple + } + DataValue::Boolean(false) | DataValue::Null => None, + _ => return Err(DatabaseError::InvalidType), } - DataValue::Boolean(false) | DataValue::Null => None, - _ => { - co.yield_(Err(DatabaseError::InvalidType)).await; - return; + } + _ => None, + }; + + if let Some(tuple) = tuple { + if matches!(self.ty, JoinType::Full) { + if let Some(bits) = right_bitmap.as_mut() { + bits.insert(idx); + } else { + active_left.first_matches.push(idx); } } + + self.state = if matches!(self.ty, JoinType::LeftSemi) { + NestedLoopJoinState::PullLeft { right_bitmap } + } else { + NestedLoopJoinState::ScanRight { + active_left, + right_bitmap, + } + }; + arena.produce_tuple(tuple); + return Ok(()); } - _ => None, - }; - if let Some(tuple) = tuple { - co.yield_(Ok(tuple)).await; - if matches!(ty, JoinType::LeftSemi) { + 
if matches!(self.ty, JoinType::LeftAnti) && active_left.has_matched { break; } - if let Some(bits) = bitmap.as_mut() { - bits.insert(right_idx); - } else if matches!(ty, JoinType::Full) { - first_matches.push(right_idx); - } } - if matches!(ty, JoinType::LeftAnti) && has_matched { - break; - } - right_idx += 1; - } - if matches!(ty, JoinType::Full) && bitmap.is_none() { - bitmap = Some(FixedBitSet::with_capacity(right_idx)); - } - - // handle no matched tuple case - let tuple = match ty { - JoinType::LeftAnti if !has_matched => Some(left_tuple.clone()), - JoinType::LeftOuter - | JoinType::LeftSemi - | JoinType::RightOuter - | JoinType::Full - if !has_matched => - { - let right_tuple = Tuple::new(None, vec![DataValue::Null; right_schema_len]); - if matches!(ty, JoinType::RightOuter) { - Self::emit_tuple(&right_tuple, &left_tuple, ty, false) + if matches!(self.ty, JoinType::Full) { + if let Some(bits) = right_bitmap.as_mut() { + for idx in active_left.first_matches { + bits.insert(idx); + } } else { - Self::emit_tuple(&left_tuple, &right_tuple, ty, false) + let mut bits = FixedBitSet::with_capacity(active_left.right_index); + for idx in active_left.first_matches { + bits.insert(idx); + } + right_bitmap = Some(bits); } } - _ => None, - }; - if let Some(tuple) = tuple { - co.yield_(Ok(tuple)).await; - } - } + let right_schema_len = self.eq_cond.right_schema.len(); + let tuple = match self.ty { + JoinType::LeftAnti if !active_left.has_matched => { + Some(active_left.left_tuple) + } + JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::RightOuter + | JoinType::Full + if !active_left.has_matched => + { + let right_tuple = + Tuple::new(None, vec![DataValue::Null; right_schema_len]); + if matches!(self.ty, JoinType::RightOuter) { + Self::emit_tuple( + &right_tuple, + &active_left.left_tuple, + self.ty, + false, + ) + } else { + Self::emit_tuple( + &active_left.left_tuple, + &right_tuple, + self.ty, + false, + ) + } + } + _ => None, + }; - if matches!(ty, 
JoinType::Full) { - for idx in first_matches.into_iter() { - bitmap.as_mut().unwrap().insert(idx); + self.state = NestedLoopJoinState::PullLeft { right_bitmap }; + if let Some(tuple) = tuple { + arena.produce_tuple(tuple); + return Ok(()); + } + state = std::mem::replace(&mut self.state, NestedLoopJoinState::End); } - - let mut right_coroutine = build_read(right_input.clone(), cache, transaction); - for (idx, right_tuple) in right_coroutine.by_ref().enumerate() { - if !bitmap.as_ref().unwrap().contains(idx) { - let mut right_tuple: Tuple = throw!(co, right_tuple); - let mut values = vec![DataValue::Null; right_schema_len]; - values.append(&mut right_tuple.values); - - co.yield_(Ok(Tuple::new(right_tuple.pk, values))).await; + NestedLoopJoinState::EmitRightUnmatched { + right_input, + right_bitmap, + mut right_emit_index, + } => { + while arena.next_tuple(right_input)? { + let mut right_tuple = arena.result_tuple().clone(); + let idx = right_emit_index; + right_emit_index += 1; + + if !right_bitmap.contains(idx) { + let mut values = vec![DataValue::Null; self.eq_cond.left_schema.len()]; + values.append(&mut right_tuple.values); + self.state = NestedLoopJoinState::EmitRightUnmatched { + right_input, + right_bitmap, + right_emit_index, + }; + arena.produce_tuple(Tuple::new(right_tuple.pk, values)); + return Ok(()); + } } + + self.state = NestedLoopJoinState::End; + arena.finish(); + return Ok(()); + } + NestedLoopJoinState::End => { + self.state = NestedLoopJoinState::End; + arena.finish(); + return Ok(()); } } - }) + } } -} -impl NestedLoopJoin { /// Emit a tuple according to the join type. /// /// `left_tuple`: left tuple to be included. 
@@ -380,13 +513,11 @@ impl NestedLoopJoin { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::execution::dql::test::build_integers; - use crate::execution::{try_collect, ReadExecutor}; - use crate::expression::ScalarExpression; + use crate::execution::try_collect; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::normalization::NormalizationRuleImpl; @@ -397,12 +528,10 @@ mod test { use crate::storage::Storage; use crate::types::evaluator::int32::Int32GtBinaryEvaluator; use crate::types::evaluator::BinaryEvaluatorBox; - use crate::types::value::DataValue; use crate::types::LogicalType; use crate::utils::lru::SharedLruCache; use std::collections::HashSet; use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; fn optimize_exprs(plan: LogicalPlan) -> Result { @@ -410,10 +539,7 @@ mod test { .before_batch( "Expression Remapper".to_string(), HepBatchStrategy::once_topdown(), - vec![ - NormalizationRuleImpl::BindExpressionPosition, - NormalizationRuleImpl::EvaluatorBind, - ], + vec![NormalizationRuleImpl::EvaluatorBind], ) .build() .instantiate(plan) @@ -456,8 +582,8 @@ mod test { let on_keys = if eq { vec![( - ScalarExpression::column_expr(t1_columns[1].clone()), - ScalarExpression::column_expr(t2_columns[1].clone()), + ScalarExpression::column_expr(t1_columns[1].clone(), 1), + ScalarExpression::column_expr(t2_columns[1].clone(), 1), )] } else { vec![] @@ -527,12 +653,14 @@ mod test { let filter = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(ColumnRef::from( - ColumnCatalog::new("c1".to_owned(), true, desc.clone()), - ))), - right_expr: Box::new(ScalarExpression::column_expr(ColumnRef::from( - 
ColumnCatalog::new("c4".to_owned(), true, desc.clone()), - ))), + left_expr: Box::new(ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), + 0, + )), + right_expr: Box::new(ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), + 3, + )), evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), ty: LogicalType::Boolean, }; @@ -588,13 +716,22 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = HashSet::with_capacity(1); - let tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); valid_result(&mut expected_set, &tuples); @@ -628,8 +765,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!( @@ -638,15 +778,38 @@ mod test { ); let mut expected_set = HashSet::with_capacity(4); - let tuple = build_integers(vec![Some(0), Some(2), Some(4), None, None, None]); - expected_set.insert(tuple); - let tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); - - let tuple = build_integers(vec![Some(1), Some(3), 
Some(5), None, None, None]); - expected_set.insert(tuple); - let tuple = build_integers(vec![Some(3), Some(5), Some(7), None, None, None]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(0), + Some(2), + Some(4), + None, + None, + None, + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(3), + Some(5), + None, + None, + None, + ])); + expected_set.insert(build_integers(vec![ + Some(3), + Some(5), + Some(7), + None, + None, + None, + ])); valid_result(&mut expected_set, &tuples); @@ -680,14 +843,22 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = HashSet::with_capacity(1); - - let tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); valid_result(&mut expected_set, &tuples); @@ -721,18 +892,38 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = HashSet::with_capacity(3); - - let tuple = build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); - let 
tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(0), + Some(2), + Some(4), + Some(0), + Some(2), + Some(4), + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(3), + Some(5), + Some(1), + Some(3), + Some(5), + ])); valid_result(&mut expected_set, &tuples); Ok(()) @@ -765,8 +956,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!(tuples.len(), 16); @@ -801,8 +995,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = HashSet::with_capacity(1); @@ -840,8 +1037,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = 
HashSet::with_capacity(3); @@ -881,19 +1081,46 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; let mut expected_set = HashSet::with_capacity(4); - let tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(1), Some(3), Some(5)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(1), Some(1), Some(1)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(4), Some(6), Some(8)]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(1), + Some(3), + Some(5), + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(1), + Some(1), + Some(1), + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(4), + Some(6), + Some(8), + ])); valid_result(&mut expected_set, &tuples); @@ -927,8 +1154,11 @@ mod test { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); - let executor = NestedLoopJoin::from((op, left, right)) - .execute((&table_cache, &view_cache, &meta_cache), &mut transaction); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); let tuples = try_collect(executor)?; assert_eq!( @@ -937,21 +1167,62 @@ mod test { ); let mut expected_set = HashSet::with_capacity(7); - let tuple = build_integers(vec![Some(0), 
Some(2), Some(4), None, None, None]); - expected_set.insert(tuple); - let tuple = build_integers(vec![Some(1), Some(2), Some(5), Some(0), Some(2), Some(4)]); - expected_set.insert(tuple); - - let tuple = build_integers(vec![Some(1), Some(3), Some(5), None, None, None]); - expected_set.insert(tuple); - let tuple = build_integers(vec![Some(3), Some(5), Some(7), None, None, None]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(1), Some(3), Some(5)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(4), Some(6), Some(8)]); - expected_set.insert(tuple); - let tuple = build_integers(vec![None, None, None, Some(1), Some(1), Some(1)]); - expected_set.insert(tuple); + expected_set.insert(build_integers(vec![ + Some(0), + Some(2), + Some(4), + None, + None, + None, + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(2), + Some(5), + Some(0), + Some(2), + Some(4), + ])); + expected_set.insert(build_integers(vec![ + Some(1), + Some(3), + Some(5), + None, + None, + None, + ])); + expected_set.insert(build_integers(vec![ + Some(3), + Some(5), + Some(7), + None, + None, + None, + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(1), + Some(3), + Some(5), + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(4), + Some(6), + Some(8), + ])); + expected_set.insert(build_integers(vec![ + None, + None, + None, + Some(1), + Some(1), + Some(1), + ])); valid_result(&mut expected_set, &tuples); @@ -959,7 +1230,100 @@ mod test { } #[test] - fn test_right_join_using_preserves_right_side_values() -> Result<(), DatabaseError> { + fn test_nested_right_join_filter_only_left_columns() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let storage = RocksStorage::new(temp_dir.path())?; + let mut transaction = storage.transaction()?; + let meta_cache = 
Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + + let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; + let left_columns = vec![ + ColumnRef::from(ColumnCatalog::new("k".to_string(), true, desc.clone())), + ColumnRef::from(ColumnCatalog::new("v".to_string(), true, desc.clone())), + ]; + let right_columns = vec![ColumnRef::from(ColumnCatalog::new( + "rk".to_string(), + true, + desc.clone(), + ))]; + + let on_keys = vec![( + ScalarExpression::column_expr(left_columns[0].clone(), 0), + ScalarExpression::column_expr(right_columns[0].clone(), 0), + )]; + let filter_expr = ScalarExpression::Binary { + op: crate::expression::BinaryOperator::Gt, + left_expr: Box::new(ScalarExpression::column_expr(left_columns[1].clone(), 1)), + right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), + evaluator: None, + ty: LogicalType::Boolean, + }; + + let left = LogicalPlan { + operator: Operator::Values(ValuesOperator { + rows: vec![ + vec![DataValue::Int32(2), DataValue::Int32(0)], + vec![DataValue::Int32(2), DataValue::Int32(5)], + ], + schema_ref: Arc::new(left_columns), + }), + childrens: Box::new(Childrens::None), + physical_option: None, + _output_schema_ref: None, + }; + let right = LogicalPlan { + operator: Operator::Values(ValuesOperator { + rows: vec![vec![DataValue::Int32(2)]], + schema_ref: Arc::new(right_columns), + }), + childrens: Box::new(Childrens::None), + physical_option: None, + _output_schema_ref: None, + }; + + let plan = LogicalPlan::new( + Operator::Join(JoinOperator { + on: JoinCondition::On { + on: on_keys, + filter: Some(filter_expr), + }, + join_type: JoinType::RightOuter, + }), + Childrens::Twins { + left: Box::new(left), + right: Box::new(right), + }, + ); + let plan = optimize_exprs(plan)?; + + let Operator::Join(op) = plan.operator else { + unreachable!() + 
}; + let (left, right) = plan.childrens.pop_twins(); + let executor = crate::execution::execute( + NestedLoopJoin::from((op, left, right)), + (&table_cache, &view_cache, &meta_cache), + &mut transaction, + ); + let tuples = try_collect(executor)?; + + assert_eq!(tuples.len(), 1); + assert_eq!( + tuples[0].values, + vec![ + DataValue::Int32(2), + DataValue::Int32(5), + DataValue::Int32(2) + ] + ); + + Ok(()) + } + + #[test] + fn test_right_join_using_keeps_left_visible_column_binding() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let db = DataBaseBuilder::path(temp_dir.path()).build_in_memory()?; diff --git a/src/execution/dql/limit.rs b/src/execution/dql/limit.rs index aec7f045..9cd3e814 100644 --- a/src/execution/dql/limit.rs +++ b/src/execution/dql/limit.rs @@ -12,14 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::limit::LimitOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; +use crate::storage::Transaction; pub struct Limit { offset: Option, limit: Option, - input: LogicalPlan, + input_plan: Option, + input: ExecId, + skipped: usize, + emitted: usize, } impl From<(LimitOperator, LogicalPlan)> for Limit { @@ -27,44 +31,60 @@ impl From<(LimitOperator, LogicalPlan)> for Limit { Limit { offset, limit, - input, + input_plan: Some(input), + input: 0, + skipped: 0, + emitted: 0, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Limit { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, 
transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Limit { - offset, - limit, - input, - } = self; - - if limit.is_some() && limit.unwrap() == 0 { - return; - } + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("limit input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Limit(self)) + } +} - let offset_val = offset.unwrap_or(0); - let offset_limit = offset_val.saturating_add(limit.unwrap_or(usize::MAX)) - 1; +impl Limit { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let offset = self.offset.unwrap_or(0); + let limit = self.limit.unwrap_or(usize::MAX); - let mut i = 0; - let executor = build_read(input, cache, transaction); + if limit == 0 || self.emitted >= limit { + arena.finish(); + return Ok(()); + } - for tuple in executor { - i += 1; - if i - 1 < offset_val { - continue; - } else if i - 1 > offset_limit { - break; - } + loop { + if !arena.next_tuple(self.input)? 
{ + arena.finish(); + return Ok(()); + } - co.yield_(tuple).await; + if self.skipped < offset { + self.skipped += 1; + continue; } - }) + + self.emitted += 1; + arena.resume(); + return Ok(()); + } } } diff --git a/src/execution/dql/mod.rs b/src/execution/dql/mod.rs index 67959f7f..86595e43 100644 --- a/src/execution/dql/mod.rs +++ b/src/execution/dql/mod.rs @@ -23,6 +23,7 @@ pub(crate) mod index_scan; pub(crate) mod join; pub(crate) mod limit; pub(crate) mod projection; +pub(crate) mod scalar_subquery; pub(crate) mod seq_scan; pub(crate) mod show_table; pub(crate) mod show_view; diff --git a/src/execution/dql/projection.rs b/src/execution/dql/projection.rs index c5b9f49b..eeca5b65 100644 --- a/src/execution/dql/projection.rs +++ b/src/execution/dql/projection.rs @@ -14,46 +14,78 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::expression::ScalarExpression; use crate::planner::operator::project::ProjectOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; use crate::types::tuple::Tuple; use crate::types::value::DataValue; pub struct Projection { exprs: Vec, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, + scratch: Tuple, } impl From<(ProjectOperator, LogicalPlan)> for Projection { - fn from((ProjectOperator { exprs }, input): (ProjectOperator, LogicalPlan)) -> Self { - Projection { exprs, input } + fn from((ProjectOperator { exprs }, mut input): (ProjectOperator, LogicalPlan)) -> Self { + Projection { + exprs, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: 0, + scratch: Tuple::default(), + } } } impl<'a, T: Transaction + 'a> 
ReadExecutor<'a, T> for Projection { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Projection { exprs, mut input } = self; - let schema = input.output_schema().clone(); - let executor = build_read(input, cache, transaction); - - for tuple in executor { - let tuple = throw!(co, tuple); - let values = throw!(co, Self::projection(&tuple, &exprs, &schema)); - co.yield_(Ok(Tuple::new(tuple.pk, values))).await; - } - }) + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("projection input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Projection(self)) } } impl Projection { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if !arena.next_tuple(self.input)? { + arena.finish(); + return Ok(()); + } + + std::mem::swap(&mut self.scratch, arena.result_tuple_mut()); + let tuple = &self.scratch; + let output = arena.result_tuple_mut(); + output.pk.clone_from(&tuple.pk); + output.values.clear(); + output.values.reserve(self.exprs.len()); + for expr in self.exprs.iter() { + output + .values + .push(expr.eval(Some((tuple, &self.input_schema)))?); + } + arena.resume(); + Ok(()) + } + pub fn projection( tuple: &Tuple, exprs: &[ScalarExpression], diff --git a/src/execution/dql/scalar_subquery.rs b/src/execution/dql/scalar_subquery.rs new file mode 100644 index 00000000..152a8436 --- /dev/null +++ b/src/execution/dql/scalar_subquery.rs @@ -0,0 +1,88 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; +use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; +use crate::planner::LogicalPlan; +use crate::storage::Transaction; +use crate::types::value::DataValue; + +pub struct ScalarSubquery { + input_plan: Option, + input: Option, + value_count: usize, +} + +impl From<(ScalarSubqueryOperator, LogicalPlan)> for ScalarSubquery { + fn from((_, mut input): (ScalarSubqueryOperator, LogicalPlan)) -> Self { + Self { + value_count: input.output_schema().len(), + input_plan: Some(input), + input: None, + } + } +} + +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ScalarSubquery { + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, + transaction: *mut T, + ) -> ExecId { + self.input = Some(build_read( + arena, + self.input_plan + .take() + .expect("scalar subquery input plan initialized"), + cache, + transaction, + )); + arena.push(ExecNode::ScalarSubquery(self)) + } +} + +impl ScalarSubquery { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + let Some(input) = self.input.take() else { + arena.finish(); + return Ok(()); + }; + + let has_first = arena.next_tuple(input)?; + if !has_first { + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + output + .values + .extend((0..self.value_count).map(|_| DataValue::Null)); + arena.resume(); + return 
Ok(()); + } + + if arena.next_tuple(input)? { + return Err(DatabaseError::InvalidValue( + "scalar subquery returned more than one row".to_string(), + )); + } + + arena.resume(); + Ok(()) + } +} diff --git a/src/execution/dql/seq_scan.rs b/src/execution/dql/seq_scan.rs index ca4aa5b2..fe7b1631 100644 --- a/src/execution/dql/seq_scan.rs +++ b/src/execution/dql/seq_scan.rs @@ -12,50 +12,69 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::table_scan::TableScanOperator; -use crate::storage::{Iter, StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; +use crate::storage::{Iter, Transaction, TupleIter}; -pub(crate) struct SeqScan { - op: TableScanOperator, +pub(crate) struct SeqScan<'a, T: Transaction + 'a> { + op: Option, + iter: Option>, } -impl From for SeqScan { +impl<'a, T: Transaction + 'a> From for SeqScan<'a, T> { fn from(op: TableScanOperator) -> Self { - SeqScan { op } + SeqScan { + op: Some(op), + iter: None, + } } } -impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SeqScan { - fn execute( +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SeqScan<'a, T> { + fn into_executor( self, - (table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let TableScanOperator { + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, + _: *mut T, + ) -> ExecId { + arena.push(ExecNode::SeqScan(self)) + } +} + +impl<'a, T: Transaction + 'a> SeqScan<'a, T> { + pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { + if self.iter.is_none() { + let Some(TableScanOperator { table_name, columns, limit, with_pk, .. 
- } = self.op; - - let mut iter = throw!( - co, - unsafe { &mut (*transaction) }.read( - table_cache, - table_name, - limit, - columns, - with_pk - ) - ); + }) = self.op.take() + else { + arena.finish(); + return Ok(()); + }; + self.iter = Some(arena.transaction_mut().read( + arena.table_cache(), + table_name, + limit, + columns, + with_pk, + )?); + } - while let Some(tuple) = throw!(co, iter.next_tuple()) { - co.yield_(Ok(tuple)).await; - } - }) + if self + .iter + .as_mut() + .expect("seq scan iterator initialized") + .next_tuple_into(arena.result_tuple_mut())? + { + arena.resume(); + } else { + arena.finish(); + } + Ok(()) } } diff --git a/src/execution/dql/show_table.rs b/src/execution/dql/show_table.rs index be562a9c..14196691 100644 --- a/src/execution/dql/show_table.rs +++ b/src/execution/dql/show_table.rs @@ -13,33 +13,41 @@ // limitations under the License. use crate::catalog::TableMeta; -use crate::execution::{spawn_executor, Executor, ReadExecutor}; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::errors::DatabaseError; +use crate::execution::ExecArena; +use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; use sqlparser::ast::CharLengthUnits; -pub struct ShowTables; +pub struct ShowTables { + pub(crate) metas: Option>, +} + +impl ShowTables { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.metas.is_none() { + self.metas = Some(arena.transaction_mut().table_metas()?.into_iter()); + } -impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ShowTables { - fn execute( - self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let metas = throw!(co, unsafe { &mut (*transaction) }.table_metas()); + let Some(TableMeta { table_name }) = 
self.metas.as_mut().and_then(|metas| metas.next()) + else { + arena.finish(); + return Ok(()); + }; - for TableMeta { table_name } in metas { - let values = vec![DataValue::Utf8 { - value: table_name.to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }]; + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + output.values.push(DataValue::Utf8 { + value: table_name.to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); - co.yield_(Ok(Tuple::new(None, values))).await; - } - }) + arena.resume(); + Ok(()) } } diff --git a/src/execution/dql/show_view.rs b/src/execution/dql/show_view.rs index df88e3c6..6d80fa2f 100644 --- a/src/execution/dql/show_view.rs +++ b/src/execution/dql/show_view.rs @@ -13,33 +13,45 @@ // limitations under the License. use crate::catalog::view::View; -use crate::execution::{spawn_executor, Executor, ReadExecutor}; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::errors::DatabaseError; +use crate::execution::ExecArena; +use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; use sqlparser::ast::CharLengthUnits; -pub struct ShowViews; +pub struct ShowViews { + pub(crate) metas: Option>, +} + +impl ShowViews { + pub(crate) fn next_tuple<'a, T: Transaction>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.metas.is_none() { + self.metas = Some( + arena + .transaction_mut() + .views(arena.table_cache())? + .into_iter(), + ); + } -impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ShowViews { - fn execute( - self, - (table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), - transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let metas = throw!(co, unsafe { &mut (*transaction) }.views(table_cache)); + let Some(View { name, .. 
}) = self.metas.as_mut().and_then(|metas| metas.next()) else { + arena.finish(); + return Ok(()); + }; - for View { name, .. } in metas { - let values = vec![DataValue::Utf8 { - value: name.to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }]; + let output = arena.result_tuple_mut(); + output.pk = None; + output.values.clear(); + output.values.push(DataValue::Utf8 { + value: name.to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); - co.yield_(Ok(Tuple::new(None, values))).await; - } - }) + arena.resume(); + Ok(()) } } diff --git a/src/execution/dql/sort.rs b/src/execution/dql/sort.rs index 83e1e6e1..6de82cfe 100644 --- a/src/execution/dql/sort.rs +++ b/src/execution/dql/sort.rs @@ -13,16 +13,15 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::LogicalPlan; use crate::storage::table_codec::BumpBytes; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::{Schema, Tuple}; +use crate::storage::Transaction; +use crate::types::tuple::{Schema, SchemaRef, Tuple}; use bumpalo::Bump; use std::cmp::Ordering; -use std::mem::MaybeUninit; +use std::mem::{transmute, MaybeUninit}; pub(crate) type BumpVec<'bump, T> = bumpalo::collections::Vec<'bump, T>; @@ -276,45 +275,57 @@ impl SortBy { } pub struct Sort { - arena: Bump, + output: Option>>, + arena: Box, sort_fields: Vec, limit: Option, - input: LogicalPlan, + input_schema: SchemaRef, + input: ExecId, + input_plan: Option, } impl From<(SortOperator, LogicalPlan)> for Sort { - fn from((SortOperator { sort_fields, limit }, input): (SortOperator, LogicalPlan)) -> Self { + fn from((SortOperator { sort_fields, limit }, mut 
input): (SortOperator, LogicalPlan)) -> Self { Sort { - arena: Default::default(), + output: None, + arena: Box::::default(), sort_fields, limit, - input, + input_schema: input.output_schema().clone(), + input: 0, + input_plan: Some(input), } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Sort { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Sort { - arena, - sort_fields, - limit, - mut input, - } = self; - - let arena: *const Bump = &arena; - let schema = input.output_schema().clone(); - let mut tuples = NullableVec::new(unsafe { &*arena }); - - let mut coroutine = build_read(input, cache, transaction); - - for (offset, tuple) in coroutine.by_ref().enumerate() { - tuples.put((offset, throw!(co, tuple))); + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan.take().expect("sort input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Sort(self)) + } +} + +impl Sort { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.output.is_none() { + let mut tuples = NullableVec::new(&self.arena); + + while arena.next_tuple(self.input)? 
{ + let offset = tuples.len(); + tuples.put((offset, arena.result_tuple().clone())); } let sort_by = if tuples.len() > 256 { @@ -322,18 +333,29 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Sort { } else { SortBy::Fast }; - let mut limit = limit.unwrap_or(tuples.len()); - - for tuple in throw!( - co, - sort_by.sorted_tuples(unsafe { &*arena }, &schema, &sort_fields, tuples) - ) { - if limit != 0 { - co.yield_(Ok(tuple)).await; - limit -= 1; - } - } - }) + let limit = self.limit.unwrap_or(tuples.len()); + let rows = sort_by.sorted_tuples( + &self.arena, + &self.input_schema, + &self.sort_fields, + tuples, + )?; + let rows: Box + '_> = Box::new(rows.take(limit)); + // The arena lives at a stable boxed address, so we can keep the old iterator shape + // and resume it across executor polls. + self.output = Some(unsafe { + transmute:: + '_>, Box>>( + rows, + ) + }); + } + + if let Some(tuple) = self.output.as_mut().and_then(std::iter::Iterator::next) { + arena.produce_tuple(tuple); + } else { + arena.finish(); + } + Ok(()) } } @@ -381,7 +403,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(0), + position: 0, }, asc, nulls_first, @@ -540,7 +562,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(0), + position: 0, }, asc: asc_1, nulls_first: nulls_first_1, @@ -552,7 +574,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(1), + position: 1, }, asc: asc_2, nulls_first: nulls_first_2, diff --git a/src/execution/dql/top_k.rs b/src/execution/dql/top_k.rs index 6f34b02a..07cffa0e 100644 --- a/src/execution/dql/top_k.rs +++ b/src/execution/dql/top_k.rs @@ -14,17 +14,17 @@ use crate::errors::DatabaseError; use crate::execution::dql::sort::BumpVec; -use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::execution::{build_read, ExecArena, ExecId, 
ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::sort::SortField; use crate::planner::operator::top_k::TopKOperator; use crate::planner::LogicalPlan; use crate::storage::table_codec::BumpBytes; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::{Schema, Tuple}; +use crate::storage::Transaction; +use crate::types::tuple::{Schema, SchemaRef, Tuple}; use bumpalo::Bump; use std::cmp::Ordering; -use std::collections::BTreeSet; +use std::collections::{btree_set::IntoIter as BTreeSetIntoIter, BTreeSet}; +use std::mem::transmute; #[derive(Eq, PartialEq, Debug)] struct CmpItem<'a> { @@ -89,11 +89,14 @@ fn top_sort<'a>( } pub struct TopK { - arena: Bump, + output: Option>>>, + arena: Box, sort_fields: Vec, limit: usize, offset: Option, - input: LogicalPlan, + input_schema: SchemaRef, + input_plan: Option, + input: ExecId, } impl From<(TopKOperator, LogicalPlan)> for TopK { @@ -104,66 +107,80 @@ impl From<(TopKOperator, LogicalPlan)> for TopK { limit, offset, }, - input, + mut input, ): (TopKOperator, LogicalPlan), ) -> Self { TopK { - arena: Default::default(), + output: None, + arena: Box::::default(), sort_fields, limit, offset, - input, + input_schema: input.output_schema().clone(), + input_plan: Some(input), + input: 0, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for TopK { - #[allow(clippy::mutable_key_type)] - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let TopK { - arena, - sort_fields, - limit, - offset, - mut input, - } = self; - - let arena: *const Bump = &arena; + ) -> ExecId { + self.input = build_read( + arena, + self.input_plan + .take() + .expect("top-k input plan initialized"), + cache, + transaction, + ); + 
arena.push(ExecNode::TopK(self)) + } +} - let schema = input.output_schema().clone(); - let keep_count = offset.unwrap_or(0) + limit; +impl TopK { + #[allow(clippy::mutable_key_type)] + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.output.is_none() { + let keep_count = self.offset.unwrap_or(0) + self.limit; let mut set = BTreeSet::new(); - let coroutine = build_read(input, cache, transaction); - - for tuple in coroutine { - throw!( - co, - top_sort( - unsafe { &*arena }, - &schema, - &sort_fields, - &mut set, - throw!(co, tuple), - keep_count, - ) - ); + + while arena.next_tuple(self.input)? { + top_sort( + &self.arena, + &self.input_schema, + &self.sort_fields, + &mut set, + arena.result_tuple().clone(), + keep_count, + )?; } - let mut i: usize = 0; + let offset = self.offset.unwrap_or(0); + let rows = set.into_iter().skip(offset); + // The arena lives at a stable boxed address, so we can keep the old set/key shape + // and resume iteration across executor polls. 
+ self.output = Some(unsafe { + transmute::< + std::iter::Skip>>, + std::iter::Skip>>, + >(rows) + }); + } - while let Some(item) = set.pop_first() { - i += 1; - if i - 1 < offset.unwrap_or(0) { - continue; - } - co.yield_(Ok(item.tuple)).await; - } - }) + if let Some(item) = self.output.as_mut().and_then(std::iter::Iterator::next) { + arena.produce_tuple(item.tuple); + } else { + arena.finish(); + } + Ok(()) } } @@ -192,7 +209,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(0), + position: 0, }, asc, nulls_first, @@ -384,7 +401,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(0), + position: 0, }, asc: asc_1, nulls_first: nulls_first_1, @@ -396,7 +413,7 @@ mod test { false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), ))), - position: Some(1), + position: 1, }, asc: asc_2, nulls_first: nulls_first_2, diff --git a/src/execution/dql/union.rs b/src/execution/dql/union.rs index 1b63522a..78c72e38 100644 --- a/src/execution/dql/union.rs +++ b/src/execution/dql/union.rs @@ -12,44 +12,74 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::execution::{build_read, spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; +use crate::storage::Transaction; pub struct Union { - left_input: LogicalPlan, - right_input: LogicalPlan, + left_plan: Option, + right_plan: Option, + left_input: ExecId, + right_input: ExecId, + reading_left: bool, } impl From<(LogicalPlan, LogicalPlan)> for Union { fn from((left_input, right_input): (LogicalPlan, LogicalPlan)) -> Self { Union { - left_input, - right_input, + left_plan: Some(left_input), + right_plan: Some(right_input), + left_input: 0, + right_input: 0, + reading_left: true, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Union { - fn execute( - self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + fn into_executor( + mut self, + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let Union { - left_input, - right_input, - } = self; - let mut left = build_read(left_input, cache, transaction); - - for tuple in left.by_ref() { - co.yield_(tuple).await; - } - let right = build_read(right_input, cache, transaction); + ) -> ExecId { + self.left_input = build_read( + arena, + self.left_plan + .take() + .expect("union left input plan initialized"), + cache, + transaction, + ); + self.right_input = build_read( + arena, + self.right_plan + .take() + .expect("union right input plan initialized"), + cache, + transaction, + ); + arena.push(ExecNode::Union(self)) + } +} - for tuple in right { - co.yield_(tuple).await; +impl Union { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + if self.reading_left { + if 
arena.next_tuple(self.left_input)? { + arena.resume(); + return Ok(()); } - }) + self.reading_left = false; + } + if arena.next_tuple(self.right_input)? { + arena.resume(); + } else { + arena.finish(); + } + Ok(()) } } diff --git a/src/execution/dql/values.rs b/src/execution/dql/values.rs index 10f97cf2..9e462ff4 100644 --- a/src/execution/dql/values.rs +++ b/src/execution/dql/values.rs @@ -12,44 +12,61 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::execution::{spawn_executor, Executor, ReadExecutor}; +use crate::errors::DatabaseError; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; use crate::planner::operator::values::ValuesOperator; -use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; -use crate::throw; -use crate::types::tuple::Tuple; +use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; use crate::types::value::DataValue; use std::mem; pub struct Values { - op: ValuesOperator, + rows: std::vec::IntoIter>, + schema_ref: SchemaRef, } impl From for Values { - fn from(op: ValuesOperator) -> Self { - Values { op } + fn from(ValuesOperator { rows, schema_ref }: ValuesOperator) -> Self { + Values { + rows: rows.into_iter(), + schema_ref, + } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Values { - fn execute( + fn into_executor( self, - _: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + _: ExecutionCaches<'a>, _: *mut T, - ) -> Executor<'a> { - spawn_executor(move |co| async move { - let ValuesOperator { rows, schema_ref } = self.op; + ) -> ExecId { + arena.push(ExecNode::Values(self)) + } +} - for mut values in rows { - for (i, value) in values.iter_mut().enumerate() { - let ty = schema_ref[i].datatype().clone(); +impl Values { + pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + 
let Some(mut values) = self.rows.next() else { + arena.finish(); + return Ok(()); + }; - if value.logical_type() != ty { - *value = throw!(co, mem::replace(value, DataValue::Null).cast(&ty)); - } - } + for (i, value) in values.iter_mut().enumerate() { + let ty = self.schema_ref[i].datatype().clone(); - co.yield_(Ok(Tuple::new(None, values))).await; + if value.logical_type() != ty { + *value = mem::replace(value, DataValue::Null).cast(&ty)?; } - }) + } + + let output = arena.result_tuple_mut(); + output.pk = None; + output.values = values; + arena.resume(); + Ok(()) } } diff --git a/src/execution/execute_macro.rs b/src/execution/execute_macro.rs deleted file mode 100644 index 2c6752eb..00000000 --- a/src/execution/execute_macro.rs +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#[macro_export] -macro_rules! 
throw { - ($co:expr, $code:expr) => { - match $code { - Ok(item) => item, - Err(err) => { - $co.yield_(Err(err)).await; - return; - } - } - }; -} diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 25e830c6..6fb8f7ea 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -15,7 +15,6 @@ pub(crate) mod ddl; pub(crate) mod dml; pub(crate) mod dql; -pub(crate) mod execute_macro; use self::ddl::add_column::AddColumn; use self::ddl::change_column::ChangeColumn; @@ -48,6 +47,7 @@ use crate::execution::dql::index_scan::IndexScan; use crate::execution::dql::join::hash_join::HashJoin; use crate::execution::dql::limit::Limit; use crate::execution::dql::projection::Projection; +use crate::execution::dql::scalar_subquery::ScalarSubquery; use crate::execution::dql::seq_scan::SeqScan; use crate::execution::dql::show_table::ShowTables; use crate::execution::dql::show_view::ShowViews; @@ -61,40 +61,324 @@ use crate::planner::LogicalPlan; use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::types::index::IndexInfo; use crate::types::tuple::Tuple; -use genawaiter::sync::{Co, Gen}; -use std::future::Future; -pub type Executor<'a> = Box> + 'a>; +pub(crate) type ExecutionCaches<'a> = (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache); +pub(crate) type ExecId = usize; -pub(crate) fn spawn_executor<'a, F, Fut>(producer: F) -> Executor<'a> -where - F: FnOnce(Co>) -> Fut + 'a, - Fut: Future + 'a, -{ - Box::new(Gen::new(producer).into_iter().fuse()) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ExecStatus { + Continue, + End, +} + +#[derive(Debug, Default)] +pub(crate) struct ExecResult { + pub(crate) tuple: Tuple, + pub(crate) status: Option, +} + +pub struct Executor<'a, T: Transaction + 'a> { + arena: ExecArena<'a, T>, + root: ExecId, +} + +impl<'a, T: Transaction + 'a> Executor<'a, T> { + pub(crate) fn new(arena: ExecArena<'a, T>, root: ExecId) -> Self { + Self { arena, root } + } + + pub(crate) fn 
next_tuple(&mut self) -> Result, DatabaseError> { + if !self.arena.next_tuple(self.root)? { + return Ok(None); + } + Ok(Some(self.arena.result_tuple())) + } +} + +impl Iterator for Executor<'_, T> { + type Item = Result; + + fn next(&mut self) -> Option { + match self.next_tuple() { + Ok(Some(tuple)) => Some(Ok(tuple.clone())), + Ok(None) => None, + Err(err) => Some(Err(err)), + } + } +} + +#[allow(clippy::large_enum_variant)] +pub(crate) enum ExecNode<'a, T: Transaction + 'a> { + AddColumn(AddColumn), + Analyze(Analyze), + ChangeColumn(ChangeColumn), + CopyFromFile(CopyFromFile), + CopyToFile(CopyToFile), + CreateIndex(CreateIndex), + CreateTable(CreateTable), + CreateView(CreateView), + Delete(Delete), + Describe(Describe), + DropColumn(DropColumn), + DropIndex(DropIndex), + DropTable(DropTable), + DropView(DropView), + Dummy(Dummy), + Except(Except), + Explain(Explain), + Filter(Filter), + FunctionScan(FunctionScan), + HashAgg(HashAggExecutor), + HashJoin(HashJoin), + IndexScan(IndexScan<'a, T>), + Insert(Insert), + Limit(Limit), + NestedLoopJoin(NestedLoopJoin), + Projection(Projection), + ScalarSubquery(ScalarSubquery), + SeqScan(SeqScan<'a, T>), + ShowTables(ShowTables), + ShowViews(ShowViews), + SimpleAgg(SimpleAggExecutor), + Sort(Sort), + StreamDistinct(StreamDistinctExecutor), + TopK(TopK), + Truncate(Truncate), + Union(Union), + Update(Update), + Values(Values), + Empty, +} + +pub(crate) trait ExecNodeRunner<'a, T: Transaction + 'a> { + fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError>; +} + +macro_rules! impl_exec_node_runner { + ($($ty:ty),* $(,)?) 
=> { + $( + impl<'a, T: Transaction + 'a> ExecNodeRunner<'a, T> for $ty { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + ) -> Result<(), DatabaseError> { + <$ty>::next_tuple(self, arena) + } + } + )* + }; +} + +impl_exec_node_runner!( + AddColumn, + Analyze, + ChangeColumn, + CopyFromFile, + CopyToFile, + CreateIndex, + CreateTable, + CreateView, + Delete, + Describe, + DropColumn, + DropIndex, + DropTable, + DropView, + Dummy, + Except, + Explain, + Filter, + FunctionScan, + HashAggExecutor, + HashJoin, + IndexScan<'a, T>, + Insert, + Limit, + NestedLoopJoin, + Projection, + ScalarSubquery, + SeqScan<'a, T>, + ShowTables, + ShowViews, + SimpleAggExecutor, + Sort, + StreamDistinctExecutor, + TopK, + Truncate, + Union, + Update, + Values, +); + +impl<'a, T: Transaction + 'a> ExecNodeRunner<'a, T> for ExecNode<'a, T> { + fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { + match self { + ExecNode::AddColumn(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Analyze(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::ChangeColumn(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::CopyFromFile(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::CopyToFile(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::CreateIndex(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::CreateTable(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::CreateView(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Delete(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Describe(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::DropColumn(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::DropIndex(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::DropTable(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::DropView(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Dummy(exec) => 
ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Except(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Explain(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Filter(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::FunctionScan(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::HashAgg(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::HashJoin(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::IndexScan(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Insert(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Limit(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::NestedLoopJoin(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Projection(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::ScalarSubquery(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::SeqScan(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::ShowTables(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::ShowViews(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::SimpleAgg(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Sort(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::StreamDistinct(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::TopK(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Truncate(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Union(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Update(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Values(exec) => ExecNodeRunner::next_tuple(exec, arena), + ExecNode::Empty => unreachable!("executor node re-entered while active"), + } + } +} + +pub(crate) struct ExecArena<'a, T: Transaction + 'a> { + nodes: Vec>, + result: ExecResult, + cache: Option>, + transaction: *mut T, } -pub trait ReadExecutor<'a, T: Transaction + 'a> { - fn execute( +impl<'a, T: Transaction + 'a> 
Default for ExecArena<'a, T> { + fn default() -> Self { + Self { + nodes: Vec::new(), + result: ExecResult::default(), + cache: None, + transaction: std::ptr::null_mut(), + } + } +} + +impl<'a, T: Transaction + 'a> ExecArena<'a, T> { + pub(crate) fn init_context(&mut self, cache: ExecutionCaches<'a>, transaction: *mut T) { + if let Some(current) = self.cache { + debug_assert!(std::ptr::eq(current.0, cache.0)); + debug_assert!(std::ptr::eq(current.1, cache.1)); + debug_assert!(std::ptr::eq(current.2, cache.2)); + debug_assert_eq!(self.transaction, transaction); + } else { + self.cache = Some(cache); + self.transaction = transaction; + } + } + + pub(crate) fn push(&mut self, node: ExecNode<'a, T>) -> ExecId { + let id = self.nodes.len(); + self.nodes.push(node); + id + } + + pub(crate) fn table_cache(&self) -> &'a TableCache { + self.cache.expect("execution arena context initialized").0 + } + + pub(crate) fn view_cache(&self) -> &'a ViewCache { + self.cache.expect("execution arena context initialized").1 + } + + pub(crate) fn meta_cache(&self) -> &'a StatisticsMetaCache { + self.cache.expect("execution arena context initialized").2 + } + + pub(crate) fn transaction(&self) -> &'a T { + unsafe { &*self.transaction } + } + + pub(crate) fn transaction_mut(&mut self) -> &'a mut T { + unsafe { &mut *self.transaction } + } + + #[inline] + pub(crate) fn result_tuple(&self) -> &Tuple { + &self.result.tuple + } + + #[inline] + pub(crate) fn result_tuple_mut(&mut self) -> &mut Tuple { + &mut self.result.tuple + } + + #[inline] + pub(crate) fn resume(&mut self) { + self.result.status = Some(ExecStatus::Continue); + } + + #[inline] + pub(crate) fn finish(&mut self) { + self.result.status = Some(ExecStatus::End); + } + + #[inline] + pub(crate) fn produce_tuple(&mut self, tuple: Tuple) { + self.result.tuple = tuple; + self.resume(); + } + + pub(crate) fn next_tuple(&mut self, id: ExecId) -> Result { + self.result.status = None; + let mut node = std::mem::replace(&mut 
self.nodes[id], ExecNode::Empty); + let result = ExecNodeRunner::next_tuple(&mut node, self); + self.nodes[id] = node; + result?; + + match self.result.status.unwrap_or(ExecStatus::End) { + ExecStatus::Continue => Ok(true), + ExecStatus::End => Ok(false), + } + } +} + +pub(crate) trait ReadExecutor<'a, T: Transaction + 'a>: Sized { + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a>; + ) -> ExecId; } -pub trait WriteExecutor<'a, T: Transaction + 'a> { - fn execute_mut( +pub(crate) trait WriteExecutor<'a, T: Transaction + 'a>: Sized { + fn into_executor( self, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + arena: &mut ExecArena<'a, T>, + cache: ExecutionCaches<'a>, transaction: *mut T, - ) -> Executor<'a>; + ) -> ExecId; } -pub fn build_read<'a, T: Transaction + 'a>( +pub(crate) fn build_read<'a, T: Transaction + 'a>( + arena: &mut ExecArena<'a, T>, plan: LogicalPlan, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + cache: ExecutionCaches<'a>, transaction: *mut T, -) -> Executor<'a> { +) -> ExecId { + arena.init_context(cache, transaction); + let LogicalPlan { operator, childrens, @@ -103,12 +387,12 @@ pub fn build_read<'a, T: Transaction + 'a>( } = plan; match operator { - Operator::Dummy => Dummy {}.execute(cache, transaction), + Operator::Dummy => Dummy::default().into_executor(arena, cache, transaction), Operator::Aggregate(op) => { let input = childrens.pop_only(); if op.groupby_exprs.is_empty() { - SimpleAggExecutor::from((op, input)).execute(cache, transaction) + SimpleAggExecutor::from((op, input)).into_executor(arena, cache, transaction) } else if op.is_distinct && op.agg_calls.is_empty() && matches!( @@ -119,15 +403,15 @@ pub fn build_read<'a, T: Transaction + 'a>( }) ) { - StreamDistinctExecutor::from((op, input)).execute(cache, transaction) + 
StreamDistinctExecutor::from((op, input)).into_executor(arena, cache, transaction) } else { - HashAggExecutor::from((op, input)).execute(cache, transaction) + HashAggExecutor::from((op, input)).into_executor(arena, cache, transaction) } } Operator::Filter(op) => { let input = childrens.pop_only(); - Filter::from((op, input)).execute(cache, transaction) + Filter::from((op, input)).into_executor(arena, cache, transaction) } Operator::Join(op) => { let (left_input, right_input) = childrens.pop_twins(); @@ -143,17 +427,28 @@ pub fn build_read<'a, T: Transaction + 'a>( }) ) => { - HashJoin::from((op, left_input, right_input)).execute(cache, transaction) - } - _ => { - NestedLoopJoin::from((op, left_input, right_input)).execute(cache, transaction) + HashJoin::from((op, left_input, right_input)).into_executor( + arena, + cache, + transaction, + ) } + _ => NestedLoopJoin::from((op, left_input, right_input)).into_executor( + arena, + cache, + transaction, + ), } } Operator::Project(op) => { let input = childrens.pop_only(); - Projection::from((op, input)).execute(cache, transaction) + Projection::from((op, input)).into_executor(arena, cache, transaction) + } + Operator::ScalarSubquery(op) => { + let input = childrens.pop_only(); + + ScalarSubquery::from((op, input)).into_executor(arena, cache, transaction) } Operator::TableScan(op) => { if let Some(PhysicalOption { @@ -176,56 +471,61 @@ pub fn build_read<'a, T: Transaction + 'a>( covered_deserializers, cover_mapping, )) - .execute(cache, transaction); + .into_executor(arena, cache, transaction); } } - SeqScan::from(op).execute(cache, transaction) + SeqScan::from(op).into_executor(arena, cache, transaction) + } + Operator::FunctionScan(op) => { + FunctionScan::from(op).into_executor(arena, cache, transaction) } - Operator::FunctionScan(op) => FunctionScan::from(op).execute(cache, transaction), Operator::Sort(op) => { let input = childrens.pop_only(); - Sort::from((op, input)).execute(cache, transaction) + Sort::from((op, 
input)).into_executor(arena, cache, transaction) } Operator::Limit(op) => { let input = childrens.pop_only(); - Limit::from((op, input)).execute(cache, transaction) + Limit::from((op, input)).into_executor(arena, cache, transaction) } Operator::TopK(op) => { let input = childrens.pop_only(); - TopK::from((op, input)).execute(cache, transaction) + TopK::from((op, input)).into_executor(arena, cache, transaction) } - Operator::Values(op) => Values::from(op).execute(cache, transaction), - Operator::ShowTable => ShowTables.execute(cache, transaction), - Operator::ShowView => ShowViews.execute(cache, transaction), + Operator::Values(op) => Values::from(op).into_executor(arena, cache, transaction), + Operator::ShowTable => arena.push(ExecNode::ShowTables(ShowTables { metas: None })), + Operator::ShowView => arena.push(ExecNode::ShowViews(ShowViews { metas: None })), Operator::Explain => { let input = childrens.pop_only(); - Explain::from(input).execute(cache, transaction) + Explain::from(input).into_executor(arena, cache, transaction) } - Operator::Describe(op) => Describe::from(op).execute(cache, transaction), + Operator::Describe(op) => Describe::from(op).into_executor(arena, cache, transaction), Operator::Union(_) => { let (left_input, right_input) = childrens.pop_twins(); - Union::from((left_input, right_input)).execute(cache, transaction) + Union::from((left_input, right_input)).into_executor(arena, cache, transaction) } Operator::Except(_) => { let (left_input, right_input) = childrens.pop_twins(); - Except::from((left_input, right_input)).execute(cache, transaction) + Except::from((left_input, right_input)).into_executor(arena, cache, transaction) } _ => unreachable!(), } } -pub fn build_write<'a, T: Transaction + 'a>( +pub(crate) fn build_write<'a, T: Transaction + 'a>( + arena: &mut ExecArena<'a, T>, plan: LogicalPlan, - cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + cache: ExecutionCaches<'a>, transaction: *mut T, -) -> Executor<'a> { +) -> 
ExecId { + arena.init_context(cache, transaction); + let LogicalPlan { operator, childrens, @@ -237,45 +537,49 @@ pub fn build_write<'a, T: Transaction + 'a>( Operator::Insert(op) => { let input = childrens.pop_only(); - Insert::from((op, input)).execute_mut(cache, transaction) + Insert::from((op, input)).into_executor(arena, cache, transaction) } Operator::Update(op) => { let input = childrens.pop_only(); - Update::from((op, input)).execute_mut(cache, transaction) + Update::from((op, input)).into_executor(arena, cache, transaction) } Operator::Delete(op) => { let input = childrens.pop_only(); - Delete::from((op, input)).execute_mut(cache, transaction) + Delete::from((op, input)).into_executor(arena, cache, transaction) + } + Operator::AddColumn(op) => AddColumn::from(op).into_executor(arena, cache, transaction), + Operator::ChangeColumn(op) => { + ChangeColumn::from(op).into_executor(arena, cache, transaction) } - Operator::AddColumn(op) => AddColumn::from(op).execute_mut(cache, transaction), - Operator::ChangeColumn(op) => ChangeColumn::from(op).execute_mut(cache, transaction), - Operator::DropColumn(op) => DropColumn::from(op).execute_mut(cache, transaction), - Operator::CreateTable(op) => CreateTable::from(op).execute_mut(cache, transaction), + Operator::DropColumn(op) => DropColumn::from(op).into_executor(arena, cache, transaction), + Operator::CreateTable(op) => CreateTable::from(op).into_executor(arena, cache, transaction), Operator::CreateIndex(op) => { let input = childrens.pop_only(); - CreateIndex::from((op, input)).execute_mut(cache, transaction) + CreateIndex::from((op, input)).into_executor(arena, cache, transaction) + } + Operator::CreateView(op) => CreateView::from(op).into_executor(arena, cache, transaction), + Operator::DropTable(op) => DropTable::from(op).into_executor(arena, cache, transaction), + Operator::DropView(op) => DropView::from(op).into_executor(arena, cache, transaction), + Operator::DropIndex(op) => 
DropIndex::from(op).into_executor(arena, cache, transaction), + Operator::Truncate(op) => Truncate::from(op).into_executor(arena, cache, transaction), + Operator::CopyFromFile(op) => { + CopyFromFile::from(op).into_executor(arena, cache, transaction) } - Operator::CreateView(op) => CreateView::from(op).execute_mut(cache, transaction), - Operator::DropTable(op) => DropTable::from(op).execute_mut(cache, transaction), - Operator::DropView(op) => DropView::from(op).execute_mut(cache, transaction), - Operator::DropIndex(op) => DropIndex::from(op).execute_mut(cache, transaction), - Operator::Truncate(op) => Truncate::from(op).execute_mut(cache, transaction), - Operator::CopyFromFile(op) => CopyFromFile::from(op).execute_mut(cache, transaction), Operator::CopyToFile(op) => { let input = childrens.pop_only(); - CopyToFile::from((op, input)).execute(cache, transaction) + CopyToFile::from((op, input)).into_executor(arena, cache, transaction) } - Operator::Analyze(op) => { let input = childrens.pop_only(); - Analyze::from((op, input)).execute_mut(cache, transaction) + Analyze::from((op, input)).into_executor(arena, cache, transaction) } operator => build_read( + arena, LogicalPlan { operator, childrens, @@ -289,6 +593,44 @@ pub fn build_write<'a, T: Transaction + 'a>( } #[cfg(all(test, not(target_arch = "wasm32")))] -pub fn try_collect(executor: Executor) -> Result, DatabaseError> { - executor.collect() +pub(crate) fn execute<'a, T, E>( + executor: E, + cache: ExecutionCaches<'a>, + transaction: *mut T, +) -> Executor<'a, T> +where + T: Transaction + 'a, + E: ReadExecutor<'a, T>, +{ + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = executor.into_executor(&mut arena, cache, transaction); + Executor::new(arena, root) +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +pub(crate) fn execute_mut<'a, T, E>( + executor: E, + cache: ExecutionCaches<'a>, + transaction: *mut T, +) -> Executor<'a, T> +where + T: Transaction + 'a, + E: 
WriteExecutor<'a, T>, +{ + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = executor.into_executor(&mut arena, cache, transaction); + Executor::new(arena, root) +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +pub fn try_collect(executor: Executor<'_, T>) -> Result, DatabaseError> { + let mut executor = executor; + let mut tuples = Vec::new(); + + while let Some(tuple) = executor.next_tuple()? { + tuples.push(tuple.clone()); + } + Ok(tuples) } diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 6a99912a..cdf30624 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -52,9 +52,7 @@ impl ScalarExpression { let Some((tuple, _)) = tuple else { return Ok(DataValue::Null); }; - let position = position - .ok_or_else(|| DatabaseError::UnbindExpressionPosition(self.clone()))?; - Ok(tuple.into()[position].clone()) + Ok(tuple.into()[*position].clone()) } ScalarExpression::Alias { expr, alias } => { let Some((tuple, schema)) = tuple else { diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 4c4942d8..a8a5a27e 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -13,12 +13,12 @@ // limitations under the License. 
use self::agg::AggKind; -use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnSummary}; +use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::errors::DatabaseError; use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::visitor::{walk_expr, Visitor}; -use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; +use crate::expression::visitor_mut::VisitorMut; use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory, UnaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -28,10 +28,8 @@ use sqlparser::ast::TrimWhereField; use sqlparser::ast::{ BinaryOperator as SqlBinaryOperator, CharLengthUnits, UnaryOperator as SqlUnaryOperator, }; -use std::borrow::Cow; use std::fmt::{Debug, Formatter}; use std::hash::Hash; -use std::slice::IterMut; use std::{fmt, mem}; pub mod agg; @@ -57,7 +55,7 @@ pub enum ScalarExpression { Constant(DataValue), ColumnRef { column: ColumnRef, - position: Option, + position: usize, }, Alias { expr: Box, @@ -148,79 +146,6 @@ pub enum ScalarExpression { }, } -#[derive(Clone)] -pub struct BindPosition< - T: Clone, - F: Clone + Fn() -> T, - E: Fn(&ColumnSummary, &ColumnSummary) -> bool, -> { - fn_output_columns: F, - fn_eq: E, -} - -impl< - 'a, - 'b, - T: Iterator> + Clone, - F: Clone + Fn() -> T, - E: Clone + Fn(&ColumnSummary, &ColumnSummary) -> bool, - > VisitorMut<'a> for BindPosition -{ - fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { - walk_mut_expr(&mut self.clone(), expr)?; - - let column = expr.output_column(); - - if let Some((pos, _)) = (self.fn_output_columns)() - .find_position(|c| (self.fn_eq)(c.summary(), column.summary())) - { - *expr = ScalarExpression::ColumnRef { - column, - position: Some(pos), - }; - } - Ok(()) - } - - fn visit_alias( - &mut self, - expr: &'a mut ScalarExpression, - ty: &'a mut AliasType, - ) -> Result<(), DatabaseError> 
{ - if let AliasType::Expr(inner_expr) = ty { - self.visit(inner_expr)?; - } - self.visit(expr)?; - Ok(()) - } -} - -impl<'b, T, F, E> BindPosition -where - T: Iterator> + Clone, - F: Clone + Fn() -> T, - E: Clone + Fn(&ColumnSummary, &ColumnSummary) -> bool, -{ - pub fn new(output_columns: F, fn_eq: E) -> BindPosition { - BindPosition { - fn_output_columns: output_columns, - fn_eq, - } - } - - pub fn bind_exprs( - exprs: IterMut, - fn_schema: F, - fn_eq: E, - ) -> Result<(), DatabaseError> { - let mut bind_schema_position = BindPosition::new(fn_schema, fn_eq); - for expr in exprs { - bind_schema_position.visit(expr)?; - } - Ok(()) - } -} - pub struct BindEvaluator; impl VisitorMut<'_> for BindEvaluator { @@ -311,10 +236,21 @@ impl Visitor<'_> for HasCountStar { } impl ScalarExpression { - pub fn column_expr(column: ColumnRef) -> ScalarExpression { - ScalarExpression::ColumnRef { - column, - position: None, + pub fn column_expr(column: ColumnRef, position: usize) -> ScalarExpression { + ScalarExpression::ColumnRef { column, position } + } + + pub(crate) fn eq_ignore_colref_pos(&self, other: &ScalarExpression) -> bool { + match (self.unpack_alias_ref(), other.unpack_alias_ref()) { + ( + ScalarExpression::ColumnRef { + column: lhs_column, .. + }, + ScalarExpression::ColumnRef { + column: rhs_column, .. 
+ }, + ) => lhs_column.same_column(rhs_column), + (lhs, rhs) => lhs == rhs, } } @@ -400,43 +336,90 @@ impl ScalarExpression { } } - pub fn referenced_columns(&self, only_column_ref: bool) -> Vec { - struct ColumnRefCollector(Vec); - impl Visitor<'_> for ColumnRefCollector { - fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { - self.0.push(col.clone()); + pub fn visit_referenced_columns( + &self, + only_column_ref: bool, + f: &mut impl FnMut(&ColumnRef) -> bool, + ) -> bool { + struct ColumnRefVisitor<'a, F> { + f: &'a mut F, + keep_going: bool, + } + impl bool> Visitor<'_> for ColumnRefVisitor<'_, F> { + fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { + if self.keep_going { + walk_expr(self, expr)?; + } Ok(()) } - fn visit_alias( - &mut self, - expr: &ScalarExpression, - ty: &AliasType, - ) -> Result<(), DatabaseError> { - if let AliasType::Expr(alias_expr) = ty { - self.0.push(alias_expr.output_column()); - } - self.visit(expr) + fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { + self.keep_going = (self.f)(col); + Ok(()) } } - struct OutputColumnCollector(Vec); - impl Visitor<'_> for OutputColumnCollector { + struct OutputColumnVisitor<'a, F> { + f: &'a mut F, + keep_going: bool, + } + impl bool> Visitor<'_> for OutputColumnVisitor<'_, F> { fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { - self.0.push(expr.output_column()); - walk_expr(self, expr) + if !self.keep_going { + return Ok(()); + } + + let output = expr.output_column(); + self.keep_going = (self.f)(&output); + if self.keep_going { + walk_expr(self, expr)?; + } + Ok(()) } } + if only_column_ref { - let mut collector = ColumnRefCollector(Vec::new()); - collector.visit(self).unwrap(); - collector.0 + let mut visitor = ColumnRefVisitor { + f, + keep_going: true, + }; + visitor.visit(self).unwrap(); + visitor.keep_going } else { - let mut collector = OutputColumnCollector(Vec::new()); - 
collector.visit(self).unwrap(); - collector.0 + let mut visitor = OutputColumnVisitor { + f, + keep_going: true, + }; + visitor.visit(self).unwrap(); + visitor.keep_going } } + pub fn any_referenced_column( + &self, + only_column_ref: bool, + mut predicate: impl FnMut(&ColumnRef) -> bool, + ) -> bool { + let mut found = false; + self.visit_referenced_columns(only_column_ref, &mut |column| { + found = predicate(column); + !found + }); + found + } + + pub fn all_referenced_columns( + &self, + only_column_ref: bool, + mut predicate: impl FnMut(&ColumnRef) -> bool, + ) -> bool { + let mut all = true; + self.visit_referenced_columns(only_column_ref, &mut |column| { + all = predicate(column); + all + }); + all + } + pub fn has_table_ref_column(&self) -> bool { struct TableRefChecker { found: bool, @@ -840,6 +823,38 @@ mod test { use std::sync::Arc; use tempfile::TempDir; + #[test] + fn test_eq_ignore_colref_pos() -> Result<(), DatabaseError> { + let left = ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new( + "c1".to_string(), + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + )), + 0, + ); + let right = ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new( + "c1".to_string(), + true, + ColumnDesc::new(LogicalType::Bigint, None, false, None)?, + )), + 2, + ); + let different = ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new( + "c2".to_string(), + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + )), + 0, + ); + + assert!(left.eq_ignore_colref_pos(&right)); + assert!(!left.eq_ignore_colref_pos(&different)); + Ok(()) + } + #[test] fn test_serialization() -> Result<(), DatabaseError> { fn fn_assert( @@ -900,33 +915,39 @@ mod test { )?; fn_assert( &mut cursor, - ScalarExpression::column_expr(ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::Table { - column_id: c3_column_id, - table_name: "t1".to_string().into(), - 
is_temp: false, + ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c3".to_string(), + relation: ColumnRelation::Table { + column_id: c3_column_id, + table_name: "t1".to_string().into(), + is_temp: false, + }, }, - }, - false, - ColumnDesc::new(LogicalType::Integer, None, false, None)?, - false, - ))), + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + false, + )), + 0, + ), Some((&transaction, &table_cache)), &mut reference_tables, )?; fn_assert( &mut cursor, - ScalarExpression::column_expr(ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c4".to_string(), - relation: ColumnRelation::None, - }, - false, - ColumnDesc::new(LogicalType::Boolean, None, false, None)?, - false, - ))), + ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c4".to_string(), + relation: ColumnRelation::None, + }, + false, + ColumnDesc::new(LogicalType::Boolean, None, false, None)?, + false, + )), + 1, + ), Some((&transaction, &table_cache)), &mut reference_tables, )?; diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index 60066e6f..0bb3e725 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -207,13 +207,15 @@ impl<'a> RangeDetacher<'a> { Self::merge_binary(*op, left_binary, right_binary) } (None, None) => { - if let (Some(col), Some(val)) = - (left_expr.unpack_col(false), right_expr.unpack_val()) - { + if let (Some(col), Some(val)) = ( + left_expr.unpack_bound_col(false).map(|(column, _)| column), + right_expr.unpack_val(), + ) { return self.new_range(*op, col, val, false); - } else if let (Some(val), Some(col)) = - (left_expr.unpack_val(), right_expr.unpack_col(false)) - { + } else if let (Some(val), Some(col)) = ( + left_expr.unpack_val(), + right_expr.unpack_bound_col(false).map(|(column, _)| column), + ) { return self.new_range(*op, col, val, true); } diff --git 
a/src/expression/simplify.rs b/src/expression/simplify.rs index f63cc604..55907286 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -152,19 +152,22 @@ impl VisitorMut<'_> for Simplify { self.fix_expr(right_expr, left_expr, op)?; if Self::is_arithmetic(op) { - match (left_expr.unpack_col(false), right_expr.unpack_col(false)) { - (Some(col), None) => { + match ( + left_expr.unpack_bound_col(false), + right_expr.unpack_bound_col(false), + ) { + (Some((col, position)), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::column_expr(col), + column_expr: ScalarExpression::column_expr(col, position), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), is_column_left: true, })); } - (None, Some(col)) => { + (None, Some((col, position))) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::column_expr(col), + column_expr: ScalarExpression::column_expr(col, position), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -176,19 +179,22 @@ impl VisitorMut<'_> for Simplify { return Ok(()); } - match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { - (Some(col), None) => { + match ( + left_expr.unpack_bound_col(true), + right_expr.unpack_bound_col(true), + ) { + (Some((col, position)), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::column_expr(col), + column_expr: ScalarExpression::column_expr(col, position), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), is_column_left: true, })); } - (None, Some(col)) => { + (None, Some((col, position))) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::column_expr(col), + column_expr: ScalarExpression::column_expr(col, position), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -503,11 +509,11 @@ 
impl ScalarExpression { } } - pub(crate) fn unpack_col(&self, is_deep: bool) -> Option { + pub(crate) fn unpack_bound_col(&self, is_deep: bool) -> Option<(ColumnRef, usize)> { match self { - ScalarExpression::ColumnRef { column, .. } => Some(column.clone()), - ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_deep), - ScalarExpression::Unary { expr, .. } => expr.unpack_col(is_deep), + ScalarExpression::ColumnRef { column, position } => Some((column.clone(), *position)), + ScalarExpression::Alias { expr, .. } => expr.unpack_bound_col(is_deep), + ScalarExpression::Unary { expr, .. } => expr.unpack_bound_col(is_deep), ScalarExpression::Binary { left_expr, right_expr, @@ -518,8 +524,8 @@ impl ScalarExpression { } left_expr - .unpack_col(true) - .or_else(|| right_expr.unpack_col(true)) + .unpack_bound_col(true) + .or_else(|| right_expr.unpack_bound_col(true)) } _ => None, } diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index b7a949f8..ab56977c 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -39,8 +39,11 @@ pub trait Visitor<'a>: Sized { fn visit_alias( &mut self, expr: &'a ScalarExpression, - _ty: &'a AliasType, + ty: &'a AliasType, ) -> Result<(), DatabaseError> { + if let AliasType::Expr(alias_expr) = ty { + self.visit(alias_expr)?; + } self.visit(expr) } diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index 87325600..36222659 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -23,6 +23,25 @@ use crate::types::value::DataValue; use crate::types::LogicalType; use sqlparser::ast::TrimWhereField; +pub(crate) struct PositionShift { + pub(crate) delta: isize, +} + +impl VisitorMut<'_> for PositionShift { + fn visit_column_ref( + &mut self, + _column: &mut ColumnRef, + position: &mut usize, + ) -> Result<(), DatabaseError> { + if self.delta.is_negative() { + *position = position.saturating_sub(self.delta.unsigned_abs()); + } else { + *position += self.delta as 
usize; + } + Ok(()) + } +} + pub trait VisitorMut<'a>: Sized { fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { walk_mut_expr(self, expr) @@ -32,15 +51,22 @@ pub trait VisitorMut<'a>: Sized { Ok(()) } - fn visit_column_ref(&mut self, _column: &'a mut ColumnRef) -> Result<(), DatabaseError> { + fn visit_column_ref( + &mut self, + _column: &'a mut ColumnRef, + _position: &'a mut usize, + ) -> Result<(), DatabaseError> { Ok(()) } fn visit_alias( &mut self, expr: &'a mut ScalarExpression, - _ty: &'a mut AliasType, + ty: &'a mut AliasType, ) -> Result<(), DatabaseError> { + if let AliasType::Expr(alias_expr) = ty { + self.visit(alias_expr)?; + } self.visit(expr) } @@ -268,7 +294,9 @@ pub fn walk_mut_expr<'a, V: VisitorMut<'a>>( ) -> Result<(), DatabaseError> { match expr { ScalarExpression::Constant(value) => visitor.visit_constant(value), - ScalarExpression::ColumnRef { column, .. } => visitor.visit_column_ref(column), + ScalarExpression::ColumnRef { column, position } => { + visitor.visit_column_ref(column, position) + } ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), diff --git a/src/lib.rs b/src/lib.rs index 89b7e902..04ff5651 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,10 +13,14 @@ // limitations under the License. //! KiteSQL is a high-performance SQL database -//! that can be embedded in Rust code (based on RocksDB by default), +//! that can be embedded in Rust code with RocksDB, LMDB, or in-memory storage, //! making it possible to call SQL just like calling a function. //! It supports most of the syntax of SQL 2016. //! +//! Native storage backends are feature-gated: +//! - `rocksdb` is enabled by default +//! - `lmdb` is optional +//! //! KiteSQL provides thread-safe API: [`DataBase::run`](db::Database::run) for running SQL //! //! 
KiteSQL uses [`DataBaseBuilder`](db::DataBaseBuilder) for instance construction, @@ -65,7 +69,8 @@ //! //! #[cfg(feature = "orm")] //! fn main() -> Result<(), DatabaseError> { -//! let database = DataBaseBuilder::path("./hello_world").build()?; +//! let database = DataBaseBuilder::path("./hello_world").build_rocksdb()?; +//! // Or: let database = DataBaseBuilder::path("./hello_world").build_lmdb()?; //! //! database.create_table_if_not_exists::()?; //! database.insert(&MyStruct { diff --git a/src/optimizer/core/cm_sketch.rs b/src/optimizer/core/cm_sketch.rs index 5ad5e95e..fcc3c03f 100644 --- a/src/optimizer/core/cm_sketch.rs +++ b/src/optimizer/core/cm_sketch.rs @@ -107,8 +107,8 @@ impl CountMinSketch { width, k_num, page_len, - hasher_0: hashers[0].clone(), - hasher_1: hashers[1].clone(), + hasher_0: hashers[0], + hasher_1: hashers[1], }; let pages = counters .into_iter() diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index 082e6f04..c61e83f3 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -18,7 +18,7 @@ use crate::expression::range_detacher::Range; use crate::expression::BinaryOperator; use crate::optimizer::core::cm_sketch::CountMinSketch; use crate::storage::table_codec::BumpBytes; -use crate::types::evaluator::EvaluatorFactory; +use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; use crate::types::index::{IndexId, IndexMeta}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -41,6 +41,13 @@ pub struct HistogramBuilder { value_index: usize, } +struct BoundComparator { + lt: BinaryEvaluatorBox, + lte: BinaryEvaluatorBox, + gt: BinaryEvaluatorBox, + gte: BinaryEvaluatorBox, +} + #[derive(Debug, Clone, PartialEq, ReferenceSerialization)] pub struct HistogramMeta { index_id: IndexId, @@ -222,22 +229,57 @@ impl HistogramBuilder { } } +impl BoundComparator { + fn new(ty: LogicalType) -> Result { + Ok(Self { + lt: EvaluatorFactory::binary_create(ty.clone(), 
BinaryOperator::Lt)?, + lte: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::LtEq)?, + gt: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Gt)?, + gte: EvaluatorFactory::binary_create(ty, BinaryOperator::GtEq)?, + }) + } + + fn lt(&self, value: &DataValue, target: &DataValue) -> Result { + Ok(matches!( + self.lt.binary_eval(value, target)?, + DataValue::Boolean(true) + )) + } + + fn lte(&self, value: &DataValue, target: &DataValue) -> Result { + Ok(matches!( + self.lte.binary_eval(value, target)?, + DataValue::Boolean(true) + )) + } + + fn gt(&self, value: &DataValue, target: &DataValue) -> Result { + Ok(matches!( + self.gt.binary_eval(value, target)?, + DataValue::Boolean(true) + )) + } + + fn gte(&self, value: &DataValue, target: &DataValue) -> Result { + Ok(matches!( + self.gte.binary_eval(value, target)?, + DataValue::Boolean(true) + )) + } +} + fn is_under( + comparator: &BoundComparator, value: &DataValue, target: &Bound, is_min: bool, ) -> Result { let _is_under = |value: &DataValue, target: &DataValue, is_min: bool| { - let evaluator = EvaluatorFactory::binary_create( - value.logical_type(), - if is_min { - BinaryOperator::Lt - } else { - BinaryOperator::LtEq - }, - )?; - let value = evaluator.0.binary_eval(value, target)?; - Ok::(matches!(value, DataValue::Boolean(true))) + if is_min { + comparator.lt(value, target) + } else { + comparator.lte(value, target) + } }; Ok(match target { @@ -248,21 +290,17 @@ fn is_under( } fn is_above( + comparator: &BoundComparator, value: &DataValue, target: &Bound, is_min: bool, ) -> Result { let _is_above = |value: &DataValue, target: &DataValue, is_min: bool| { - let evaluator = EvaluatorFactory::binary_create( - value.logical_type(), - if is_min { - BinaryOperator::GtEq - } else { - BinaryOperator::Gt - }, - )?; - let value = evaluator.0.binary_eval(value, target)?; - Ok::(matches!(value, DataValue::Boolean(true))) + if is_min { + comparator.gte(value, target) + } else { + comparator.gt(value, 
target) + } }; Ok(match target { Bound::Included(target) => _is_above(value, target, is_min)?, @@ -316,17 +354,15 @@ impl Histogram { self.meta.buckets_len } - pub fn collect_count( + pub fn collect_count( &self, ranges: &[Range], - estimate: &mut F, - ) -> Result - where - F: FnMut(&DataValue) -> Result, - { + sketch: &CountMinSketch, + ) -> Result { if self.buckets.is_empty() || ranges.is_empty() { return Ok(0); } + let comparator = BoundComparator::new(self.buckets[0].upper.logical_type())?; let mut count = 0; let mut binary_i = 0; @@ -340,7 +376,8 @@ impl Histogram { &mut bucket_i, &mut bucket_idxs, &mut count, - estimate, + sketch, + &comparator, )?; if is_dummy { return Ok(0); @@ -354,18 +391,17 @@ impl Histogram { + count) } - fn _collect_count( + #[allow(clippy::too_many_arguments)] + fn _collect_count( &self, ranges: &[Range], binary_i: &mut usize, bucket_i: &mut usize, bucket_idxs: &mut Vec, count: &mut usize, - estimate: &mut F, - ) -> Result - where - F: FnMut(&DataValue) -> Result, - { + sketch: &CountMinSketch, + comparator: &BoundComparator, + ) -> Result { let float_value = |value: &DataValue, prefix_len: usize| { let value = match value.logical_type() { LogicalType::Varchar(..) | LogicalType::Char(..) => match value { @@ -472,22 +508,23 @@ impl Histogram { _ => false, }; - if (is_above(&bucket.lower, min, true)? || is_eq(&bucket.lower, min)) - && (is_under(&bucket.upper, max, false)? || is_eq(&bucket.upper, max)) + if (is_above(comparator, &bucket.lower, min, true)? || is_eq(&bucket.lower, min)) + && (is_under(comparator, &bucket.upper, max, false)? + || is_eq(&bucket.upper, max)) { bucket_idxs.push(mem::replace(bucket_i, *bucket_i + 1)); - } else if is_above(&bucket.lower, max, false)? { + } else if is_above(comparator, &bucket.lower, max, false)? { *binary_i += 1; - } else if is_under(&bucket.upper, min, true)? { + } else if is_under(comparator, &bucket.upper, min, true)? { *bucket_i += 1; - } else if is_above(&bucket.lower, min, true)? 
{ + } else if is_above(comparator, &bucket.lower, min, true)? { let (temp_ratio, option) = match max { Bound::Included(val) => { (calc_fraction(&bucket.lower, &bucket.upper, val)?, None) } Bound::Excluded(val) => ( calc_fraction(&bucket.lower, &bucket.upper, val)?, - Some(estimate(val)?), + Some(sketch.estimate(val)), ), Bound::Unbounded => unreachable!(), }; @@ -497,14 +534,14 @@ impl Histogram { temp_count = temp_count.saturating_sub(count); } *bucket_i += 1; - } else if is_under(&bucket.upper, max, false)? { + } else if is_under(comparator, &bucket.upper, max, false)? { let (temp_ratio, option) = match min { Bound::Included(val) => { (calc_fraction(&bucket.lower, &bucket.upper, val)?, None) } Bound::Excluded(val) => ( calc_fraction(&bucket.lower, &bucket.upper, val)?, - Some(estimate(val)?), + Some(sketch.estimate(val)), ), Bound::Unbounded => unreachable!(), }; @@ -521,7 +558,7 @@ impl Histogram { } Bound::Excluded(val) => ( calc_fraction(&bucket.lower, &bucket.upper, val)?, - Some(estimate(val)?), + Some(sketch.estimate(val)), ), Bound::Unbounded => unreachable!(), }; @@ -531,7 +568,7 @@ impl Histogram { } Bound::Excluded(val) => ( calc_fraction(&bucket.lower, &bucket.upper, val)?, - Some(estimate(val)?), + Some(sketch.estimate(val)), ), Bound::Unbounded => unreachable!(), }; @@ -549,7 +586,7 @@ impl Histogram { *count += cmp::max(temp_count, 0); } Range::Eq(value) => { - *count += estimate(value)?; + *count += sketch.estimate(value); *binary_i += 1 } Range::Dummy => return Ok(true), @@ -823,7 +860,6 @@ mod tests { builder.append(&DataValue::Null)?; let (histogram, sketch) = builder.build(4)?; - let mut estimate = |value: &DataValue| Ok(sketch.estimate(value)); let count_1 = histogram.collect_count( &[ @@ -833,7 +869,7 @@ mod tests { max: Bound::Excluded(DataValue::Int32(12)), }, ], - &mut estimate, + &sketch, )?; assert_eq!(count_1, 9); @@ -843,7 +879,7 @@ mod tests { min: Bound::Included(DataValue::Int32(4)), max: Bound::Unbounded, }], - &mut estimate, + 
&sketch, )?; assert_eq!(count_2, 11); @@ -853,7 +889,7 @@ mod tests { min: Bound::Excluded(DataValue::Int32(7)), max: Bound::Unbounded, }], - &mut estimate, + &sketch, )?; assert_eq!(count_3, 7); @@ -863,7 +899,7 @@ mod tests { min: Bound::Unbounded, max: Bound::Included(DataValue::Int32(11)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_4, 12); @@ -873,7 +909,7 @@ mod tests { min: Bound::Unbounded, max: Bound::Excluded(DataValue::Int32(8)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_5, 8); @@ -883,7 +919,7 @@ mod tests { min: Bound::Included(DataValue::Int32(2)), max: Bound::Unbounded, }], - &mut estimate, + &sketch, )?; assert_eq!(count_6, 13); @@ -893,7 +929,7 @@ mod tests { min: Bound::Excluded(DataValue::Int32(1)), max: Bound::Unbounded, }], - &mut estimate, + &sketch, )?; assert_eq!(count_7, 13); @@ -903,7 +939,7 @@ mod tests { min: Bound::Unbounded, max: Bound::Included(DataValue::Int32(12)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_8, 13); @@ -913,7 +949,7 @@ mod tests { min: Bound::Unbounded, max: Bound::Excluded(DataValue::Int32(13)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_9, 13); @@ -923,7 +959,7 @@ mod tests { min: Bound::Excluded(DataValue::Int32(0)), max: Bound::Excluded(DataValue::Int32(3)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_10, 2); @@ -933,7 +969,7 @@ mod tests { min: Bound::Included(DataValue::Int32(1)), max: Bound::Included(DataValue::Int32(2)), }], - &mut estimate, + &sketch, )?; assert_eq!(count_11, 2); diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs deleted file mode 100644 index 914475b5..00000000 --- a/src/optimizer/core/memo.rs +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::errors::DatabaseError; -use crate::optimizer::core::pattern::PatternMatcher; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; -use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::optimizer::heuristic::matcher::PlanMatcher; -use crate::optimizer::rule::implementation::ImplementationRuleImpl; -use crate::planner::operator::PhysicalOption; -use crate::planner::{Childrens, LogicalPlan}; -use crate::storage::Transaction; -use std::cmp::Ordering; -use std::collections::HashMap; - -#[derive(Debug, Clone)] -pub struct Expression { - pub(crate) op: PhysicalOption, - pub(crate) cost: Option, - // TODO: output rows -} - -#[derive(Debug, Clone)] -pub struct GroupExpression { - exprs: Vec, -} - -impl GroupExpression { - pub(crate) fn append_expr(&mut self, expr: Expression) { - self.exprs.push(expr); - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) struct NodePath(Vec); - -impl NodePath { - fn root() -> Self { - Self(Vec::new()) - } - - fn child(&self, idx: usize) -> Self { - let mut path = self.0.clone(); - path.push(idx); - Self(path) - } -} - -#[derive(Debug)] -pub struct Memo { - groups: HashMap, -} - -impl Memo { - pub(crate) fn new( - plan: &LogicalPlan, - loader: &StatisticMetaLoader<'_, T>, - implementations: &[ImplementationRuleImpl], - ) -> Result { - let mut groups = HashMap::new(); - Self::collect(plan, NodePath::root(), loader, implementations, &mut groups)?; - Ok(Memo { groups }) - } - - fn collect( - plan: &LogicalPlan, - path: NodePath, - loader: 
&StatisticMetaLoader<'_, T>, - implementations: &[ImplementationRuleImpl], - groups: &mut HashMap, - ) -> Result<(), DatabaseError> { - for rule in implementations { - if PlanMatcher::new(rule.pattern(), plan).match_opt_expr() { - let group_expr = groups - .entry(path.clone()) - .or_insert_with(|| GroupExpression { exprs: vec![] }); - rule.to_expression(&plan.operator, loader, group_expr)?; - } - } - - match plan.childrens.as_ref() { - Childrens::Only(child) => { - Self::collect(child, path.child(0), loader, implementations, groups)?; - } - Childrens::Twins { left, right } => { - Self::collect(left, path.child(0), loader, implementations, groups)?; - Self::collect(right, path.child(1), loader, implementations, groups)?; - } - Childrens::None => {} - } - - Ok(()) - } - - pub(crate) fn annotate_plan(&self, plan: &mut LogicalPlan) { - Self::annotate(plan, &NodePath::root(), self); - } - - fn annotate(plan: &mut LogicalPlan, path: &NodePath, memo: &Memo) { - if let Some(option) = memo.cheapest_physical_option(path) { - plan.physical_option = Some(option); - } - - match plan.childrens.as_mut() { - Childrens::Only(child) => Self::annotate(child, &path.child(0), memo), - Childrens::Twins { left, right } => { - Self::annotate(left, &path.child(0), memo); - Self::annotate(right, &path.child(1), memo); - } - Childrens::None => {} - } - } - - pub(crate) fn cheapest_physical_option(&self, path: &NodePath) -> Option { - self.groups.get(path).and_then(|exprs| { - exprs - .exprs - .iter() - .min_by(|expr_1, expr_2| match (expr_1.cost, expr_2.cost) { - (Some(cost_1), Some(cost_2)) => cost_1.cmp(&cost_2), - (None, Some(_)) => Ordering::Greater, - (Some(_), None) => Ordering::Less, - (None, None) => Ordering::Equal, - }) - .map(|expr| expr.op.clone()) - }) - } -} - -#[cfg(all(test, not(target_arch = "wasm32")))] -mod tests { - use super::NodePath; - use crate::binder::{Binder, BinderContext}; - use crate::db::{DataBaseBuilder, ResultIter}; - use crate::errors::DatabaseError; - use 
crate::expression::range_detacher::Range; - use crate::expression::ScalarExpression; - use crate::optimizer::core::memo::Memo; - use crate::optimizer::heuristic::batch::HepBatchStrategy; - use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; - use crate::optimizer::rule::implementation::ImplementationRuleImpl; - use crate::optimizer::rule::normalization::NormalizationRuleImpl; - use crate::planner::operator::sort::SortField; - use crate::planner::operator::{PhysicalOption, PlanImpl, SortOption}; - use crate::storage::rocksdb::RocksTransaction; - use crate::storage::{Storage, Transaction}; - use crate::types::index::{IndexInfo, IndexMeta, IndexType}; - use crate::types::value::DataValue; - use crate::types::LogicalType; - use std::ops::Bound; - use std::sync::atomic::AtomicUsize; - use std::sync::Arc; - use tempfile::TempDir; - - // Tips: This test may occasionally encounter errors; you can repeat the test multiple times. - #[test] - fn test_build_memo() -> Result<(), DatabaseError> { - let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let database = DataBaseBuilder::path(temp_dir.path()).build()?; - database - .run("create table t1 (c1 int primary key, c2 int)")? - .done()?; - database - .run("create table t2 (c3 int primary key, c4 int)")? - .done()?; - - for i in 0..1000 { - database - .run(format!("insert into t1 values({}, {})", i, i + 1).as_str())? - .done()?; - } - database.run("analyze table t1")?.done()?; - - let transaction = database.storage.transaction()?; - let c1_column = transaction - .table(database.state.table_cache(), "t1".to_string().into())? 
- .unwrap() - .get_column_by_name("c1") - .unwrap(); - let sort_fields = vec![SortField::new( - ScalarExpression::column_expr(c1_column.clone()), - true, - false, - )]; - let scala_functions = Default::default(); - let table_functions = Default::default(); - let mut binder = Binder::new( - BinderContext::new( - database.state.table_cache(), - database.state.view_cache(), - &transaction, - &scala_functions, - &table_functions, - Arc::new(AtomicUsize::new(0)), - ), - &[], - None, - ); - // where: c1 => 2, (40, +inf) - let stmt = crate::parser::parse_sql( - "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", - )?; - let plan = binder.bind(&stmt[0])?; - let pipeline = HepOptimizerPipeline::builder() - .before_batch( - "Simplify Filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .before_batch( - "Predicate Pushdown".to_string(), - HepBatchStrategy::fix_point_topdown(10), - vec![ - NormalizationRuleImpl::PushPredicateThroughJoin, - NormalizationRuleImpl::PushJoinPredicateIntoScan, - NormalizationRuleImpl::PushPredicateIntoScan, - ], - ) - .build(); - let mut best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; - let rules = vec![ - ImplementationRuleImpl::Projection, - ImplementationRuleImpl::Filter, - ImplementationRuleImpl::HashJoin, - ImplementationRuleImpl::SeqScan, - ImplementationRuleImpl::IndexScan, - ]; - - let memo = Memo::new( - &best_plan, - &transaction.meta_loader(database.state.meta_cache()), - &rules, - )?; - Memo::annotate_plan(&memo, &mut best_plan); - let exprs = memo - .groups - .get(&NodePath(vec![0, 0, 0])) - .expect("missing group"); - - assert_eq!(exprs.exprs.len(), 2); - assert_eq!(exprs.exprs[0].cost, Some(1000)); - assert_eq!( - exprs.exprs[0].op, - PhysicalOption::new(PlanImpl::SeqScan, SortOption::None) - ); - assert!(exprs.exprs[1].cost.unwrap() >= 960); - assert!(matches!( - exprs.exprs[1].op, - PhysicalOption { - plan: 
PlanImpl::IndexScan(..), - .. - } - )); - assert_eq!( - best_plan - .childrens - .pop_only() - .childrens - .pop_twins() - .0 - .childrens - .pop_only() - .physical_option, - Some(PhysicalOption::new( - PlanImpl::IndexScan(Box::new(IndexInfo { - meta: Arc::new(IndexMeta { - id: 0, - column_ids: vec![c1_column.id().unwrap()], - table_name: "t1".to_string().into(), - pk_ty: LogicalType::Integer, - value_ty: LogicalType::Integer, - name: "pk_index".to_string(), - ty: IndexType::PrimaryKey { is_multiple: false }, - }), - sort_option: SortOption::OrderBy { - fields: sort_fields.clone(), - ignore_prefix_len: 0, - }, - range: Some(Range::SortedRanges(vec![ - Range::Eq(DataValue::Int32(2)), - Range::Scope { - min: Bound::Excluded(DataValue::Int32(40)), - max: Bound::Unbounded, - } - ])), - covered_deserializers: None, - cover_mapping: None, - sort_elimination_hint: None, - stream_distinct_hint: None, - })), - SortOption::OrderBy { - fields: sort_fields, - ignore_prefix_len: 0, - } - )) - ); - - Ok(()) - } -} diff --git a/src/optimizer/core/mod.rs b/src/optimizer/core/mod.rs index 4be7d875..87ee196a 100644 --- a/src/optimizer/core/mod.rs +++ b/src/optimizer/core/mod.rs @@ -14,7 +14,6 @@ pub(crate) mod cm_sketch; pub(crate) mod histogram; -pub(crate) mod memo; pub(crate) mod pattern; pub(crate) mod rule; pub(crate) mod statistics_meta; diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index 5e2d0d60..303cb578 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -13,28 +13,54 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::GroupExpression; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::planner::operator::Operator; +use crate::planner::operator::{Operator, PhysicalOption}; use crate::planner::LogicalPlan; use crate::storage::Transaction; +use std::cmp::Ordering; + +pub type BestPhysicalOption = Option<(PhysicalOption, Option)>; // TODO: Use indexing and other methods for matching optimization to avoid traversal pub trait MatchPattern { fn pattern(&self) -> &Pattern; } -pub trait NormalizationRule: MatchPattern { +pub trait NormalizationRule { /// Returns true when the plan tree is modified. fn apply(&self, plan: &mut LogicalPlan) -> Result; } +fn compare_costs(candidate_cost: Option, best_cost: Option) -> Ordering { + match (candidate_cost, best_cost) { + (Some(candidate_cost), Some(best_cost)) => candidate_cost.cmp(&best_cost), + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (None, None) => Ordering::Equal, + } +} + +pub fn keep_best_physical_option( + best_physical_option: &mut BestPhysicalOption, + option: PhysicalOption, + cost: Option, +) { + let should_replace = match best_physical_option.as_ref() { + Some((_, best_cost)) => compare_costs(cost, *best_cost).is_lt(), + None => true, + }; + + if should_replace { + *best_physical_option = Some((option, cost)); + } +} + pub trait ImplementationRule: MatchPattern { - fn to_expression( + fn update_best_option( &self, op: &Operator, loader: &StatisticMetaLoader, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError>; } diff --git a/src/optimizer/core/statistics_meta.rs b/src/optimizer/core/statistics_meta.rs index f16ad8c3..7697ddb8 100644 --- a/src/optimizer/core/statistics_meta.rs +++ b/src/optimizer/core/statistics_meta.rs @@ -40,24 +40,18 @@ impl<'a, T: Transaction> StatisticMetaLoader<'a, T> { ) -> 
Result, DatabaseError> { let key = (table_name.clone(), index_id); match self.cache.get(&key) { - Some(Some(entry)) => return Ok(Some(entry.meta())), + Some(Some(entry)) => return Ok(Some(entry)), Some(None) => return Ok(None), - None => {} + _ => {} } let Some(statistics_meta) = self.tx.statistics_meta(table_name.as_ref(), index_id)? else { self.cache.put(key, None); return Ok(None); }; - self.cache.put( - key.clone(), - Some(StatisticsMetaCacheValue::new(statistics_meta)), - ); + self.cache.put(key.clone(), Some(statistics_meta)); - Ok(self - .cache - .get(&key) - .and_then(|entry| entry.as_ref().map(|entry| entry.meta()))) + Ok(self.cache.get(&key).and_then(|entry| entry.as_ref())) } pub fn collect_count( @@ -66,7 +60,7 @@ impl<'a, T: Transaction> StatisticMetaLoader<'a, T> { index_id: IndexId, range: &Range, ) -> Result, DatabaseError> { - let Some(statistics_meta) = self.load(table_name, index_id)? else { + let Some(entry) = self.load(table_name, index_id)? else { return Ok(None); }; let ranges = if let Range::SortedRanges(ranges) = range { @@ -74,66 +68,12 @@ impl<'a, T: Transaction> StatisticMetaLoader<'a, T> { } else { slice::from_ref(range) }; - let mut sketch = None; - let mut estimate = |value: &DataValue| -> Result { - if sketch.is_none() { - sketch = self.load_sketch(table_name, index_id)?; - } - sketch - .as_ref() - .ok_or_else(|| { - DatabaseError::InvalidValue("statistics sketch is incomplete".to_string()) - }) - .map(|sketch| sketch.estimate(value)) - }; - statistics_meta + entry .histogram() - .collect_count(ranges, &mut estimate) + .collect_count(ranges, entry.sketch()) .map(Some) } - - fn load_sketch( - &self, - table_name: &TableName, - index_id: IndexId, - ) -> Result>, DatabaseError> { - let key = (table_name.clone(), index_id); - match self.cache.get(&key) { - Some(Some(entry)) => { - if let Some(sketch) = entry.sketch() { - return Ok(Some(sketch)); - } - } - Some(None) => return Ok(None), - None => {} - } - - let Some(sketch) = 
self.tx.statistics_sketch(table_name.as_ref(), index_id)? else { - return Ok(None); - }; - let meta = match self.cache.get(&key) { - Some(Some(entry)) => entry.meta().clone(), - Some(None) | None => { - let Some(meta) = self.tx.statistics_meta(table_name.as_ref(), index_id)? else { - self.cache.put(key, None); - return Ok(None); - }; - meta - } - }; - self.cache.put( - key.clone(), - Some(StatisticsMetaCacheValue::new(meta).with_sketch(sketch)), - ); - - Ok(match self.cache.get(&key) { - Some(Some(entry)) => entry.sketch(), - Some(None) | None => { - return Ok(None); - } - }) - } } #[derive(Debug, Clone, ReferenceSerialization)] @@ -167,28 +107,35 @@ impl StatisticsMetaRoot { pub struct StatisticsMeta { index_id: IndexId, histogram: Histogram, + sketch: CountMinSketch, } impl StatisticsMeta { - pub fn new(histogram: Histogram) -> Self { + pub fn new(histogram: Histogram, sketch: CountMinSketch) -> Self { StatisticsMeta { index_id: histogram.index_id(), histogram, + sketch, } } pub fn from_parts( root: StatisticsMetaRoot, buckets: Vec, + sketch: CountMinSketch, ) -> Result { let histogram = Histogram::from_parts(root.into_histogram_meta(), buckets)?; - Ok(Self::new(histogram)) + Ok(Self::new(histogram, sketch)) } - pub fn into_parts(self) -> (StatisticsMetaRoot, Vec) { + pub fn into_parts(self) -> (StatisticsMetaRoot, Vec, CountMinSketch) { let (histogram_meta, buckets) = self.histogram.into_parts(); - (StatisticsMetaRoot::new(histogram_meta), buckets) + ( + StatisticsMetaRoot::new(histogram_meta), + buckets, + self.sketch, + ) } pub fn index_id(&self) -> IndexId { @@ -198,30 +145,9 @@ impl StatisticsMeta { pub fn histogram(&self) -> &Histogram { &self.histogram } -} - -#[derive(Debug, Clone)] -pub struct StatisticsMetaCacheValue { - meta: StatisticsMeta, - sketch: Option>, -} - -impl StatisticsMetaCacheValue { - pub fn new(meta: StatisticsMeta) -> Self { - Self { meta, sketch: None } - } - - pub fn with_sketch(mut self, sketch: CountMinSketch) -> Self { - 
self.sketch = Some(sketch); - self - } - pub fn meta(&self) -> &StatisticsMeta { - &self.meta - } - - pub fn sketch(&self) -> Option<&CountMinSketch> { - self.sketch.as_ref() + pub fn sketch(&self) -> &CountMinSketch { + &self.sketch } } @@ -268,12 +194,17 @@ mod tests { builder.append(&Arc::new(DataValue::Null))?; builder.append(&Arc::new(DataValue::Null))?; - let (histogram, _) = builder.build(4)?; - let meta = StatisticsMeta::new(histogram.clone()); - let (root, buckets) = meta.into_parts(); - let statistics_meta = StatisticsMeta::from_parts(root, buckets)?; + let (histogram, sketch) = builder.build(4)?; + let expected_estimate = sketch.estimate(&DataValue::Int32(7)); + let meta = StatisticsMeta::new(histogram.clone(), sketch); + let (root, buckets, sketch) = meta.into_parts(); + let statistics_meta = StatisticsMeta::from_parts(root, buckets, sketch)?; assert_eq!(histogram, statistics_meta.histogram); + assert_eq!( + expected_estimate, + statistics_meta.sketch().estimate(&DataValue::Int32(7)) + ); Ok(()) } diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs index 46e0e556..25efa273 100644 --- a/src/optimizer/heuristic/batch.rs +++ b/src/optimizer/heuristic/batch.rs @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::optimizer::rule::normalization::NormalizationRuleImpl; +use crate::optimizer::rule::normalization::{ + NormalizationPassKind, NormalizationRuleImpl, NormalizationRuleRootTag, WholeTreePassKind, +}; +use crate::planner::operator::Operator; +use std::array; /// A batch of rules. 
#[derive(Clone)] @@ -20,19 +24,102 @@ pub struct HepBatch { #[allow(dead_code)] pub name: String, pub strategy: HepBatchStrategy, + pub steps: Vec, +} + +#[derive(Clone)] +pub enum HepBatchStep { + WholeTree(HepWholeTreePass), + LocalRewrite(HepLocalRewriteBatch), +} + +#[derive(Clone)] +pub struct HepWholeTreePass { + pub kind: WholeTreePassKind, pub rules: Vec, } +#[derive(Clone)] +pub struct HepLocalRewriteBatch { + pub rules: Vec, + groups: [Vec; NormalizationRuleRootTag::COUNT], +} + +impl HepLocalRewriteBatch { + fn new(rules: Vec) -> Self { + let mut groups = array::from_fn(|_| Vec::new()); + for (idx, rule) in rules.iter().enumerate() { + groups[rule.root_tag() as usize].push(idx); + } + Self { rules, groups } + } + + fn push(&mut self, rule: NormalizationRuleImpl) { + let idx = self.rules.len(); + self.groups[rule.root_tag() as usize].push(idx); + self.rules.push(rule); + } + + pub fn len(&self) -> usize { + self.rules.len() + } + + pub fn next_matching_rule_index_from( + &self, + operator: &Operator, + start_idx: usize, + ) -> Option { + let any_group = &self.groups[NormalizationRuleRootTag::Any as usize]; + let any_idx = any_group.iter().copied().find(|idx| *idx >= start_idx); + + let specific_idx = NormalizationRuleRootTag::from_operator(operator).and_then(|tag| { + self.groups[tag as usize] + .iter() + .copied() + .find(|idx| *idx >= start_idx) + }); + + match (any_idx, specific_idx) { + (Some(any_idx), Some(specific_idx)) => Some(any_idx.min(specific_idx)), + (Some(any_idx), None) => Some(any_idx), + (None, Some(specific_idx)) => Some(specific_idx), + (None, None) => None, + } + } +} + impl HepBatch { pub fn new( name: String, strategy: HepBatchStrategy, rules: Vec, ) -> Self { + let mut steps = Vec::new(); + + for rule in rules { + match rule.pass_kind() { + NormalizationPassKind::WholeTreePass(kind) => match steps.last_mut() { + Some(HepBatchStep::WholeTree(pass)) if pass.kind == kind => { + pass.rules.push(rule); + } + _ => 
steps.push(HepBatchStep::WholeTree(HepWholeTreePass { + kind, + rules: vec![rule], + })), + }, + NormalizationPassKind::LocalRewrite => match steps.last_mut() { + Some(HepBatchStep::LocalRewrite(local_rules)) => local_rules.push(rule), + _ => steps.push(HepBatchStep::LocalRewrite(HepLocalRewriteBatch::new(vec![ + rule, + ]))), + }, + } + } + Self { name, strategy, - rules, + steps, } } } diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 3a2bb222..754abc9c 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -13,39 +13,46 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::Memo; -use crate::optimizer::core::pattern::PatternMatcher; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::{ + BestPhysicalOption, ImplementationRule, MatchPattern, NormalizationRule, +}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::optimizer::heuristic::batch::{HepBatch, HepBatchStrategy}; -use crate::optimizer::heuristic::matcher::PlanMatcher; -use crate::optimizer::rule::implementation::ImplementationRuleImpl; -use crate::optimizer::rule::normalization::NormalizationRuleImpl; +use crate::optimizer::heuristic::batch::{ + HepBatch, HepBatchStep, HepBatchStrategy, HepLocalRewriteBatch, HepWholeTreePass, +}; +use crate::optimizer::rule::implementation::{ImplementationRuleImpl, ImplementationRuleRootTag}; use crate::optimizer::rule::normalization::{ - annotate_sort_preserving_indexes, annotate_stream_distinct_indexes, + apply_annotated_post_rules, apply_scan_order_hint, constant_calculation_current, + evaluator_bind_current, NormalizationRuleImpl, OrderHintKind, ScanOrderHint, WholeTreePassKind, }; +use crate::planner::operator::join::JoinCondition; +use crate::planner::operator::table_scan::TableScanOperator; +use crate::planner::operator::{Operator, PhysicalOption, 
PlanImpl, SortOption}; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; +use std::array; use std::ops::Not; +type ScanHintApplier<'a> = dyn Fn(&mut TableScanOperator) + 'a; + pub struct HepOptimizer<'a> { before_batches: &'a [HepBatch], after_batches: &'a [HepBatch], - implementations: &'a [ImplementationRuleImpl], + implementation_index: &'a ImplementationRuleIndex, plan: LogicalPlan, } impl<'a> HepOptimizer<'a> { - pub fn new( + fn new( plan: LogicalPlan, before_batches: &'a [HepBatch], after_batches: &'a [HepBatch], - implementations: &'a [ImplementationRuleImpl], + implementation_index: &'a ImplementationRuleIndex, ) -> Self { Self { before_batches, after_batches, - implementations, + implementation_index, plan, } } @@ -55,13 +62,18 @@ impl<'a> HepOptimizer<'a> { loader: Option<&StatisticMetaLoader<'_, T>>, ) -> Result { Self::apply_batches(&mut self.plan, self.before_batches)?; - annotate_sort_preserving_indexes(&mut self.plan); - annotate_stream_distinct_indexes(&mut self.plan); if let Some(loader) = loader { - if self.implementations.is_empty().not() { - let memo = Memo::new(&self.plan, loader, self.implementations)?; - Memo::annotate_plan(&memo, &mut self.plan); + if self.implementation_index.is_empty().not() { + let apply_no_sort_hints = |_scan_op: &mut TableScanOperator| {}; + let apply_no_stream_distinct_hints = |_scan_op: &mut TableScanOperator| {}; + Self::annotate_hints_and_physical_options( + &mut self.plan, + loader, + self.implementation_index, + &apply_no_sort_hints, + &apply_no_stream_distinct_hints, + )?; } } Self::apply_batches(&mut self.plan, self.after_batches)?; @@ -89,44 +101,596 @@ impl<'a> HepOptimizer<'a> { #[inline] fn apply_batch(plan: &mut LogicalPlan, batch: &HepBatch) -> Result { let mut applied = false; - for rule in &batch.rules { - if Self::apply_rule(plan, rule)? 
{ - applied = true; + for step in &batch.steps { + match step { + HepBatchStep::WholeTree(pass) => { + if Self::apply_whole_tree_pass(plan, pass)? { + plan.reset_output_schema_cache_recursive(); + applied = true; + } + } + HepBatchStep::LocalRewrite(rules) => { + if Self::apply_local_rules(plan, rules)? { + applied = true; + } + } } } Ok(applied) } - fn apply_rule( + fn apply_whole_tree_pass( plan: &mut LogicalPlan, - rule: &NormalizationRuleImpl, + pass: &HepWholeTreePass, ) -> Result { - if PlanMatcher::new(rule.pattern(), plan).match_opt_expr() && rule.apply(plan)? { - plan.reset_output_schema_cache_recursive(); - return Ok(true); + match pass.kind { + WholeTreePassKind::ColumnPruning => { + let mut applied = false; + for rule in &pass.rules { + applied |= rule.apply(plan)?; + } + Ok(applied) + } + WholeTreePassKind::ExpressionRewrite => { + let has_constant_calculation = pass + .rules + .iter() + .any(|rule| matches!(rule, NormalizationRuleImpl::ConstantCalculation)); + let has_evaluator_bind = pass + .rules + .iter() + .any(|rule| matches!(rule, NormalizationRuleImpl::EvaluatorBind)); + + Self::apply_expression_rewrite_pass( + plan, + has_constant_calculation, + has_evaluator_bind, + )?; + Ok(true) + } } + } + fn apply_expression_rewrite_pass( + plan: &mut LogicalPlan, + has_constant_calculation: bool, + has_evaluator_bind: bool, + ) -> Result<(), DatabaseError> { + Self::apply_expression_rewrite_pass_inner( + plan, + has_constant_calculation, + has_evaluator_bind, + ) + } + + fn apply_expression_rewrite_pass_inner( + plan: &mut LogicalPlan, + has_constant_calculation: bool, + has_evaluator_bind: bool, + ) -> Result<(), DatabaseError> { match plan.childrens.as_mut() { Childrens::Only(child) => { - if Self::apply_rule(child, rule)? 
{ - plan.reset_output_schema_cache(); - return Ok(true); - } - Ok(false) + Self::apply_expression_rewrite_pass_inner( + child, + has_constant_calculation, + has_evaluator_bind, + )?; } Childrens::Twins { left, right } => { - if Self::apply_rule(left, rule)? { - plan.reset_output_schema_cache(); - return Ok(true); + Self::apply_expression_rewrite_pass_inner( + left, + has_constant_calculation, + has_evaluator_bind, + )?; + Self::apply_expression_rewrite_pass_inner( + right, + has_constant_calculation, + has_evaluator_bind, + )?; + } + Childrens::None => {} + } + + if has_constant_calculation { + constant_calculation_current(plan)?; + } + if has_evaluator_bind { + evaluator_bind_current(plan)?; + } + + Ok(()) + } + + fn annotate_hints_and_physical_options<'plan, T: Transaction>( + plan: &'plan mut LogicalPlan, + loader: &StatisticMetaLoader<'_, T>, + implementation_index: &ImplementationRuleIndex, + inherited_sort_hints: &'plan ScanHintApplier<'plan>, + inherited_stream_distinct_hints: &'plan ScanHintApplier<'plan>, + ) -> Result<(), DatabaseError> { + if let Operator::TableScan(scan_op) = &mut plan.operator { + inherited_sort_hints(scan_op); + inherited_stream_distinct_hints(scan_op); + } + + { + let LogicalPlan { + operator, + physical_option, + .. + } = plan; + if let Some(option) = implementation_index.direct_physical_option(operator) { + *physical_option = Some(option); + } else { + let mut best_physical_option: BestPhysicalOption = None; + for rule in implementation_index.for_matching_operator(operator) { + rule.update_best_option(operator, loader, &mut best_physical_option)?; } - if Self::apply_rule(right, rule)? { - plan.reset_output_schema_cache(); - return Ok(true); + if let Some((option, _)) = best_physical_option { + *physical_option = Some(option); } - Ok(false) } - Childrens::None => Ok(false), } + + { + let LogicalPlan { + operator, + childrens, + .. 
+ } = plan; + Self::with_child_sort_hints(operator, inherited_sort_hints, |child_sort_hints| { + Self::with_child_stream_distinct_hints( + operator, + inherited_stream_distinct_hints, + |child_stream_distinct_hints| match &mut **childrens { + Childrens::Only(child) => Self::annotate_hints_and_physical_options( + child, + loader, + implementation_index, + child_sort_hints, + child_stream_distinct_hints, + ), + Childrens::Twins { left, right } => { + Self::annotate_hints_and_physical_options( + left, + loader, + implementation_index, + child_sort_hints, + child_stream_distinct_hints, + )?; + Self::annotate_hints_and_physical_options( + right, + loader, + implementation_index, + child_sort_hints, + child_stream_distinct_hints, + ) + } + Childrens::None => Ok(()), + }, + ) + })?; + } + + apply_annotated_post_rules(plan)?; + + Ok(()) + } + + fn with_child_sort_hints<'plan, R>( + operator: &'plan Operator, + inherited_sort_hints: &'plan ScanHintApplier<'plan>, + f: impl for<'b> FnOnce(&'b ScanHintApplier<'plan>) -> R, + ) -> R { + let propagate_hints = matches!( + operator, + Operator::Filter(_) + | Operator::Project(_) + | Operator::Limit(_) + | Operator::TopK(_) + | Operator::Sort(_) + ); + + match operator { + Operator::Sort(op) => { + let child_sort_hints = |scan_op: &mut TableScanOperator| { + inherited_sort_hints(scan_op); + apply_scan_order_hint( + scan_op, + ScanOrderHint::sort_fields(&op.sort_fields), + OrderHintKind::SortElimination, + ); + }; + f(&child_sort_hints) + } + _ if propagate_hints => f(inherited_sort_hints), + _ => { + let no_sort_hints = |_scan_op: &mut TableScanOperator| {}; + f(&no_sort_hints) + } + } + } + + fn with_child_stream_distinct_hints<'plan, R>( + operator: &'plan Operator, + inherited_stream_distinct_hints: &'plan ScanHintApplier<'plan>, + f: impl for<'b> FnOnce(&'b ScanHintApplier<'plan>) -> R, + ) -> R { + let propagate_hints = matches!( + operator, + Operator::Filter(_) + | Operator::Project(_) + | Operator::Limit(_) + | 
Operator::TopK(_) + | Operator::Sort(_) + ); + + match operator { + Operator::Aggregate(op) + if op.is_distinct && op.agg_calls.is_empty() && !op.groupby_exprs.is_empty() => + { + let child_stream_distinct_hints = |scan_op: &mut TableScanOperator| { + apply_scan_order_hint( + scan_op, + ScanOrderHint::distinct_groupby(&op.groupby_exprs), + OrderHintKind::StreamDistinct, + ); + }; + f(&child_stream_distinct_hints) + } + _ if propagate_hints => f(inherited_stream_distinct_hints), + _ => { + let no_stream_distinct_hints = |_scan_op: &mut TableScanOperator| {}; + f(&no_stream_distinct_hints) + } + } + } + + fn apply_local_rules( + plan: &mut LogicalPlan, + rules: &HepLocalRewriteBatch, + ) -> Result { + let mut applied_rules = vec![false; rules.len()]; + Self::apply_local_rules_inner(plan, rules, &mut applied_rules) + } + + fn apply_local_rules_inner( + plan: &mut LogicalPlan, + rules: &HepLocalRewriteBatch, + applied_rules: &mut [bool], + ) -> Result { + let mut applied = false; + let mut next_rule_idx = 0; + + while let Some(idx) = rules.next_matching_rule_index_from(&plan.operator, next_rule_idx) { + let rule = rules.rules[idx]; + next_rule_idx = idx + 1; + if applied_rules[idx] { + continue; + } + if rule.apply(plan)? 
{ + plan.reset_output_schema_cache_recursive(); + applied_rules[idx] = true; + applied = true; + } + } + + match plan.childrens.as_mut() { + Childrens::Only(child) => { + let child_applied = Self::apply_local_rules_inner(child, rules, applied_rules)?; + applied |= child_applied; + } + Childrens::Twins { left, right } => { + let left_applied = Self::apply_local_rules_inner(left, rules, applied_rules)?; + let right_applied = Self::apply_local_rules_inner(right, rules, applied_rules)?; + applied |= left_applied || right_applied; + } + Childrens::None => {} + } + + if applied { + plan.reset_output_schema_cache(); + } + + Ok(applied) + } +} + +#[derive(Clone, Default)] +struct ImplementationRuleIndex { + groups: [Vec; ImplementationRuleRootTag::COUNT], +} + +impl ImplementationRuleIndex { + fn new(implementations: Vec) -> Self { + let mut groups = array::from_fn(|_| Vec::new()); + for implementation in implementations { + groups[implementation.root_tag() as usize].push(implementation); + } + Self { groups } + } + + fn is_empty(&self) -> bool { + self.groups.iter().all(Vec::is_empty) + } + + fn contains(&self, implementation: ImplementationRuleImpl) -> bool { + self.groups[implementation.root_tag() as usize].contains(&implementation) + } + + fn direct_physical_option(&self, operator: &Operator) -> Option { + match operator { + Operator::Aggregate(op) + if !op.groupby_exprs.is_empty() + && self.contains(ImplementationRuleImpl::GroupByAggregate) => + { + Some(PhysicalOption::new( + PlanImpl::HashAggregate, + SortOption::None, + )) + } + Operator::Aggregate(op) + if op.groupby_exprs.is_empty() + && self.contains(ImplementationRuleImpl::SimpleAggregate) => + { + Some(PhysicalOption::new( + PlanImpl::SimpleAggregate, + SortOption::None, + )) + } + Operator::Dummy if self.contains(ImplementationRuleImpl::Dummy) => { + Some(PhysicalOption::new(PlanImpl::Dummy, SortOption::None)) + } + Operator::Filter(_) if self.contains(ImplementationRuleImpl::Filter) => { + 
Some(PhysicalOption::new(PlanImpl::Filter, SortOption::Follow)) + } + Operator::Join(join_op) if self.contains(ImplementationRuleImpl::HashJoin) => { + let plan = match &join_op.on { + JoinCondition::On { on, .. } if !on.is_empty() => PlanImpl::HashJoin, + _ => PlanImpl::NestLoopJoin, + }; + Some(PhysicalOption::new(plan, SortOption::None)) + } + Operator::Limit(_) if self.contains(ImplementationRuleImpl::Limit) => { + Some(PhysicalOption::new(PlanImpl::Limit, SortOption::Follow)) + } + Operator::Project(_) if self.contains(ImplementationRuleImpl::Projection) => { + Some(PhysicalOption::new(PlanImpl::Project, SortOption::Follow)) + } + Operator::ScalarSubquery(_) + if self.contains(ImplementationRuleImpl::ScalarSubquery) => + { + Some(PhysicalOption::new( + PlanImpl::ScalarSubquery, + SortOption::Follow, + )) + } + Operator::FunctionScan(_) if self.contains(ImplementationRuleImpl::FunctionScan) => { + Some(PhysicalOption::new( + PlanImpl::FunctionScan, + SortOption::None, + )) + } + Operator::Sort(op) if self.contains(ImplementationRuleImpl::Sort) => { + Some(PhysicalOption::new( + PlanImpl::Sort, + SortOption::OrderBy { + fields: op.sort_fields.clone(), + ignore_prefix_len: 0, + }, + )) + } + Operator::TopK(op) if self.contains(ImplementationRuleImpl::TopK) => { + Some(PhysicalOption::new( + PlanImpl::TopK, + SortOption::OrderBy { + fields: op.sort_fields.clone(), + ignore_prefix_len: 0, + }, + )) + } + Operator::Values(_) if self.contains(ImplementationRuleImpl::Values) => { + Some(PhysicalOption::new(PlanImpl::Values, SortOption::None)) + } + Operator::Analyze(_) if self.contains(ImplementationRuleImpl::Analyze) => { + Some(PhysicalOption::new(PlanImpl::Analyze, SortOption::None)) + } + Operator::CopyFromFile(_) if self.contains(ImplementationRuleImpl::CopyFromFile) => { + Some(PhysicalOption::new( + PlanImpl::CopyFromFile, + SortOption::None, + )) + } + Operator::CopyToFile(_) if self.contains(ImplementationRuleImpl::CopyToFile) => { + 
Some(PhysicalOption::new(PlanImpl::CopyToFile, SortOption::None)) + } + Operator::Delete(_) if self.contains(ImplementationRuleImpl::Delete) => { + Some(PhysicalOption::new(PlanImpl::Delete, SortOption::None)) + } + Operator::Insert(_) if self.contains(ImplementationRuleImpl::Insert) => { + Some(PhysicalOption::new(PlanImpl::Insert, SortOption::None)) + } + Operator::Update(_) if self.contains(ImplementationRuleImpl::Update) => { + Some(PhysicalOption::new(PlanImpl::Update, SortOption::None)) + } + Operator::AddColumn(_) if self.contains(ImplementationRuleImpl::AddColumn) => { + Some(PhysicalOption::new(PlanImpl::AddColumn, SortOption::None)) + } + Operator::ChangeColumn(_) if self.contains(ImplementationRuleImpl::ChangeColumn) => { + Some(PhysicalOption::new( + PlanImpl::ChangeColumn, + SortOption::None, + )) + } + Operator::CreateTable(_) if self.contains(ImplementationRuleImpl::CreateTable) => { + Some(PhysicalOption::new(PlanImpl::CreateTable, SortOption::None)) + } + Operator::DropColumn(_) if self.contains(ImplementationRuleImpl::DropColumn) => { + Some(PhysicalOption::new(PlanImpl::DropColumn, SortOption::None)) + } + Operator::DropTable(_) if self.contains(ImplementationRuleImpl::DropTable) => { + Some(PhysicalOption::new(PlanImpl::DropTable, SortOption::None)) + } + Operator::Truncate(_) if self.contains(ImplementationRuleImpl::Truncate) => { + Some(PhysicalOption::new(PlanImpl::Truncate, SortOption::None)) + } + _ => None, + } + } + + fn for_matching_operator<'b>( + &'b self, + operator: &'b Operator, + ) -> impl Iterator + 'b { + ImplementationRuleRootTag::from_operator(operator) + .into_iter() + .flat_map(move |tag| self.groups[tag as usize].iter()) + .filter(move |rule| (rule.pattern().predicate)(operator)) + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use crate::binder::{Binder, BinderContext}; + use crate::db::DataBaseBuilder; + use crate::errors::DatabaseError; + use crate::expression::range_detacher::Range; + use 
crate::expression::ScalarExpression; + use crate::optimizer::heuristic::batch::HepBatchStrategy; + use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; + use crate::optimizer::rule::implementation::ImplementationRuleImpl; + use crate::optimizer::rule::normalization::NormalizationRuleImpl; + use crate::planner::operator::sort::SortField; + use crate::planner::operator::{PhysicalOption, PlanImpl, SortOption}; + use crate::storage::{Storage, Transaction}; + use crate::types::index::{IndexInfo, IndexMeta, IndexType}; + use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::ops::Bound; + use std::sync::atomic::AtomicUsize; + use std::sync::Arc; + use tempfile::TempDir; + + #[test] + fn test_find_best_selects_cheapest_scan() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + database + .run("create table t1 (c1 int primary key, c2 int)")? + .done()?; + database + .run("create table t2 (c3 int primary key, c4 int)")? + .done()?; + + for i in 0..1000 { + database + .run(format!("insert into t1 values({}, {})", i, i + 1).as_str())? + .done()?; + } + database.run("analyze table t1")?.done()?; + + let transaction = database.storage.transaction()?; + let c1_column = transaction + .table(database.state.table_cache(), "t1".to_string().into())? 
+ .unwrap() + .get_column_by_name("c1") + .unwrap(); + let sort_fields = vec![SortField::new( + ScalarExpression::column_expr(c1_column.clone(), 0), + true, + false, + )]; + let scala_functions = Default::default(); + let table_functions = Default::default(); + let mut binder = Binder::new( + BinderContext::new( + database.state.table_cache(), + database.state.view_cache(), + &transaction, + &scala_functions, + &table_functions, + Arc::new(AtomicUsize::new(0)), + ), + &[], + None, + ); + let stmt = crate::parser::parse_sql( + "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", + )?; + let plan = binder.bind(&stmt[0])?; + let pipeline = HepOptimizerPipeline::builder() + .before_batch( + "Simplify Filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::SimplifyFilter], + ) + .before_batch( + "Predicate Pushdown".to_string(), + HepBatchStrategy::fix_point_topdown(10), + vec![ + NormalizationRuleImpl::PushPredicateThroughJoin, + NormalizationRuleImpl::PushJoinPredicateIntoScan, + NormalizationRuleImpl::PushPredicateIntoScan, + ], + ) + .implementations(vec![ + ImplementationRuleImpl::Projection, + ImplementationRuleImpl::Filter, + ImplementationRuleImpl::HashJoin, + ImplementationRuleImpl::SeqScan, + ImplementationRuleImpl::IndexScan, + ]) + .build(); + + let best_plan = pipeline + .instantiate(plan) + .find_best(Some(&transaction.meta_loader(database.state.meta_cache())))?; + + assert_eq!( + best_plan + .childrens + .pop_only() + .childrens + .pop_twins() + .0 + .childrens + .pop_only() + .physical_option, + Some(PhysicalOption::new( + PlanImpl::IndexScan(Box::new(IndexInfo { + meta: Arc::new(IndexMeta { + id: 0, + column_ids: vec![c1_column.id().unwrap()], + table_name: "t1".to_string().into(), + pk_ty: LogicalType::Integer, + value_ty: LogicalType::Integer, + name: "pk_index".to_string(), + ty: IndexType::PrimaryKey { is_multiple: false }, + }), + sort_option: SortOption::OrderBy { + fields: 
sort_fields.clone(), + ignore_prefix_len: 0, + }, + range: Some(Range::SortedRanges(vec![ + Range::Eq(DataValue::Int32(2)), + Range::Scope { + min: Bound::Excluded(DataValue::Int32(40)), + max: Bound::Unbounded, + } + ])), + covered_deserializers: None, + cover_mapping: None, + sort_elimination_hint: None, + stream_distinct_hint: None, + })), + SortOption::OrderBy { + fields: sort_fields, + ignore_prefix_len: 0, + } + )) + ); + + Ok(()) } } @@ -134,7 +698,7 @@ impl<'a> HepOptimizer<'a> { pub struct HepOptimizerPipeline { before_batches: Vec, after_batches: Vec, - implementations: Vec, + implementation_index: ImplementationRuleIndex, } impl HepOptimizerPipeline { @@ -154,7 +718,7 @@ impl HepOptimizerPipeline { Self { before_batches, after_batches, - implementations, + implementation_index: ImplementationRuleIndex::new(implementations), } } @@ -163,7 +727,7 @@ impl HepOptimizerPipeline { plan, &self.before_batches, &self.after_batches, - &self.implementations, + &self.implementation_index, ) } } diff --git a/src/optimizer/rule/implementation/ddl/add_column.rs b/src/optimizer/rule/implementation/ddl/add_column.rs index 51614e77..bd17bb3b 100644 --- a/src/optimizer/rule/implementation/ddl/add_column.rs +++ b/src/optimizer/rule/implementation/ddl/add_column.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/ddl/change_column.rs b/src/optimizer/rule/implementation/ddl/change_column.rs index d004fae8..443b53ed 100644 --- a/src/optimizer/rule/implementation/ddl/change_column.rs +++ b/src/optimizer/rule/implementation/ddl/change_column.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/ddl/create_table.rs b/src/optimizer/rule/implementation/ddl/create_table.rs index 4b255076..267e0c19 100644 --- a/src/optimizer/rule/implementation/ddl/create_table.rs +++ b/src/optimizer/rule/implementation/ddl/create_table.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/ddl/drop_column.rs b/src/optimizer/rule/implementation/ddl/drop_column.rs index 7440a10a..f79a04d4 100644 --- a/src/optimizer/rule/implementation/ddl/drop_column.rs +++ b/src/optimizer/rule/implementation/ddl/drop_column.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/ddl/drop_table.rs b/src/optimizer/rule/implementation/ddl/drop_table.rs index ec995230..3914cad6 100644 --- a/src/optimizer/rule/implementation/ddl/drop_table.rs +++ b/src/optimizer/rule/implementation/ddl/drop_table.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/ddl/truncate.rs b/src/optimizer/rule/implementation/ddl/truncate.rs index 281a1996..e4b31596 100644 --- a/src/optimizer/rule/implementation/ddl/truncate.rs +++ b/src/optimizer/rule/implementation/ddl/truncate.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/analyze.rs b/src/optimizer/rule/implementation/dml/analyze.rs index b2e0881d..6682f285 100644 --- a/src/optimizer/rule/implementation/dml/analyze.rs +++ b/src/optimizer/rule/implementation/dml/analyze.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/copy_from_file.rs b/src/optimizer/rule/implementation/dml/copy_from_file.rs index 3fd5c500..18ab769f 100644 --- a/src/optimizer/rule/implementation/dml/copy_from_file.rs +++ b/src/optimizer/rule/implementation/dml/copy_from_file.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/copy_to_file.rs b/src/optimizer/rule/implementation/dml/copy_to_file.rs index 10ac05c4..06785524 100644 --- a/src/optimizer/rule/implementation/dml/copy_to_file.rs +++ b/src/optimizer/rule/implementation/dml/copy_to_file.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/delete.rs b/src/optimizer/rule/implementation/dml/delete.rs index 0e1a4e17..845f3e92 100644 --- a/src/optimizer/rule/implementation/dml/delete.rs +++ b/src/optimizer/rule/implementation/dml/delete.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/insert.rs b/src/optimizer/rule/implementation/dml/insert.rs index 27994273..b697450c 100644 --- a/src/optimizer/rule/implementation/dml/insert.rs +++ b/src/optimizer/rule/implementation/dml/insert.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dml/update.rs b/src/optimizer/rule/implementation/dml/update.rs index 128f3a1b..55d21dd8 100644 --- a/src/optimizer/rule/implementation/dml/update.rs +++ b/src/optimizer/rule/implementation/dml/update.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/aggregate.rs b/src/optimizer/rule/implementation/dql/aggregate.rs index bdcd75c8..10a7ec68 100644 --- a/src/optimizer/rule/implementation/dql/aggregate.rs +++ b/src/optimizer/rule/implementation/dql/aggregate.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/dummy.rs b/src/optimizer/rule/implementation/dql/dummy.rs index d3a85057..afcacaae 100644 --- a/src/optimizer/rule/implementation/dql/dummy.rs +++ b/src/optimizer/rule/implementation/dql/dummy.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/filter.rs b/src/optimizer/rule/implementation/dql/filter.rs index 24cdb124..b720c1a1 100644 --- a/src/optimizer/rule/implementation/dql/filter.rs +++ b/src/optimizer/rule/implementation/dql/filter.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/function_scan.rs b/src/optimizer/rule/implementation/dql/function_scan.rs index d4b0fdbf..ea982c82 100644 --- a/src/optimizer/rule/implementation/dql/function_scan.rs +++ b/src/optimizer/rule/implementation/dql/function_scan.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/join.rs b/src/optimizer/rule/implementation/dql/join.rs index e8e10bd6..de6cdbdf 100644 --- a/src/optimizer/rule/implementation/dql/join.rs +++ b/src/optimizer/rule/implementation/dql/join.rs @@ -13,9 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::join::{JoinCondition, JoinOperator}; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; @@ -37,11 +36,11 @@ impl MatchPattern for JoinImplementation { } impl ImplementationRule for JoinImplementation { - fn to_expression( + fn update_best_option( &self, op: &Operator, _: &StatisticMetaLoader<'_, T>, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { let mut physical_option = PhysicalOption::new(PlanImpl::NestLoopJoin, SortOption::None); @@ -54,10 +53,11 @@ impl ImplementationRule for JoinImplementation { physical_option.plan = PlanImpl::HashJoin; } } - group_expr.append_expr(Expression { - op: physical_option, - cost: None, - }); + crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + physical_option, + None, + ); Ok(()) } } diff --git a/src/optimizer/rule/implementation/dql/limit.rs b/src/optimizer/rule/implementation/dql/limit.rs index 3d6f28ee..47308921 100644 --- a/src/optimizer/rule/implementation/dql/limit.rs +++ b/src/optimizer/rule/implementation/dql/limit.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/mod.rs b/src/optimizer/rule/implementation/dql/mod.rs index 967594aa..eb18286e 100644 --- a/src/optimizer/rule/implementation/dql/mod.rs +++ b/src/optimizer/rule/implementation/dql/mod.rs @@ -19,6 +19,7 @@ pub(crate) mod function_scan; pub(crate) mod join; pub(crate) mod limit; pub(crate) mod projection; +pub(crate) mod scalar_subquery; pub(crate) mod sort; pub(crate) mod table_scan; pub(crate) mod top_k; diff --git a/src/optimizer/rule/implementation/dql/projection.rs b/src/optimizer/rule/implementation/dql/projection.rs index 4744650d..672d81a4 100644 --- a/src/optimizer/rule/implementation/dql/projection.rs +++ b/src/optimizer/rule/implementation/dql/projection.rs @@ -13,8 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/dql/scalar_subquery.rs b/src/optimizer/rule/implementation/dql/scalar_subquery.rs new file mode 100644 index 00000000..75a04732 --- /dev/null +++ b/src/optimizer/rule/implementation/dql/scalar_subquery.rs @@ -0,0 +1,37 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use crate::errors::DatabaseError; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; +use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::statistics_meta::StatisticMetaLoader; +use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; +use crate::single_mapping; +use crate::storage::Transaction; +use std::sync::LazyLock; + +static SCALAR_SUBQUERY_PATTERN: LazyLock = LazyLock::new(|| Pattern { + predicate: |op| matches!(op, Operator::ScalarSubquery(_)), + children: PatternChildrenPredicate::None, +}); + +#[derive(Clone)] +pub struct ScalarSubqueryImplementation; + +single_mapping!( + ScalarSubqueryImplementation, + SCALAR_SUBQUERY_PATTERN, + PhysicalOption::new(PlanImpl::ScalarSubquery, SortOption::Follow) +); diff --git a/src/optimizer/rule/implementation/dql/sort.rs b/src/optimizer/rule/implementation/dql/sort.rs index 625c2ada..9b8028ed 100644 --- a/src/optimizer/rule/implementation/dql/sort.rs +++ b/src/optimizer/rule/implementation/dql/sort.rs @@ -13,9 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::storage::Transaction; @@ -36,23 +35,24 @@ impl MatchPattern for SortImplementation { } impl ImplementationRule for SortImplementation { - fn to_expression( + fn update_best_option( &self, op: &Operator, _: &StatisticMetaLoader<'_, T>, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::Sort(op) = op { - group_expr.append_expr(Expression { - op: PhysicalOption::new( + crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + PhysicalOption::new( PlanImpl::Sort, SortOption::OrderBy { fields: op.sort_fields.clone(), ignore_prefix_len: 0, }, ), - cost: None, - }); + None, + ); } Ok(()) diff --git a/src/optimizer/rule/implementation/dql/table_scan.rs b/src/optimizer/rule/implementation/dql/table_scan.rs index 737ea73f..e15e488a 100644 --- a/src/optimizer/rule/implementation/dql/table_scan.rs +++ b/src/optimizer/rule/implementation/dql/table_scan.rs @@ -13,9 +13,8 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::storage::Transaction; @@ -37,11 +36,11 @@ impl MatchPattern for SeqScanImplementation { } impl ImplementationRule for SeqScanImplementation { - fn to_expression( + fn update_best_option( &self, op: &Operator, loader: &StatisticMetaLoader, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TableScan(scan_op) = op { let cost = scan_op @@ -53,10 +52,11 @@ impl ImplementationRule for SeqScanImplementation { .flatten() .map(|statistics_meta| statistics_meta.histogram().values_len()); - group_expr.append_expr(Expression { - op: PhysicalOption::new(PlanImpl::SeqScan, SortOption::None), + crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + PhysicalOption::new(PlanImpl::SeqScan, SortOption::None), cost, - }); + ); Ok(()) } else { unreachable!("invalid operator!") @@ -73,11 +73,11 @@ impl MatchPattern for IndexScanImplementation { } impl ImplementationRule for IndexScanImplementation { - fn to_expression( + fn update_best_option( &self, op: &Operator, loader: &StatisticMetaLoader<'_, T>, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TableScan(scan_op) = op { for index_info in scan_op.index_infos.iter() { @@ -126,13 +126,14 @@ impl ImplementationRule for IndexScanImplementation { } } - group_expr.append_expr(Expression { - op: PhysicalOption::new( + 
crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + PhysicalOption::new( PlanImpl::IndexScan(Box::new(index_info.clone())), index_info.sort_option.clone(), ), cost, - }); + ); } Ok(()) diff --git a/src/optimizer/rule/implementation/dql/top_k.rs b/src/optimizer/rule/implementation/dql/top_k.rs index d79cd2d3..39c4b85b 100644 --- a/src/optimizer/rule/implementation/dql/top_k.rs +++ b/src/optimizer/rule/implementation/dql/top_k.rs @@ -13,9 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::storage::Transaction; @@ -36,23 +35,24 @@ impl MatchPattern for TopKImplementation { } impl ImplementationRule for TopKImplementation { - fn to_expression( + fn update_best_option( &self, op: &Operator, _: &StatisticMetaLoader<'_, T>, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TopK(op) = op { - group_expr.append_expr(Expression { - op: PhysicalOption::new( + crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + PhysicalOption::new( PlanImpl::TopK, SortOption::OrderBy { fields: op.sort_fields.clone(), ignore_prefix_len: 0, }, ), - cost: None, - }); + None, + ); } Ok(()) diff --git a/src/optimizer/rule/implementation/dql/values.rs b/src/optimizer/rule/implementation/dql/values.rs index d7713ab9..69faffa8 100644 --- a/src/optimizer/rule/implementation/dql/values.rs +++ b/src/optimizer/rule/implementation/dql/values.rs @@ -13,8 +13,8 @@ // limitations under the 
License. use crate::errors::DatabaseError; -use crate::optimizer::core::memo::{Expression, GroupExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::BestPhysicalOption; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; diff --git a/src/optimizer/rule/implementation/macros.rs b/src/optimizer/rule/implementation/macros.rs index dd29d4d1..14f86e86 100644 --- a/src/optimizer/rule/implementation/macros.rs +++ b/src/optimizer/rule/implementation/macros.rs @@ -22,17 +22,18 @@ macro_rules! single_mapping { } impl ImplementationRule for $ty { - fn to_expression( + fn update_best_option( &self, _: &Operator, _: &StatisticMetaLoader<'_, T>, - group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { //TODO: CostModel - group_expr.append_expr(Expression { - op: $option, - cost: None, - }); + $crate::optimizer::core::rule::keep_best_physical_option( + best_physical_option, + $option, + None, + ); Ok(()) } diff --git a/src/optimizer/rule/implementation/mod.rs b/src/optimizer/rule/implementation/mod.rs index 3f9081c6..cebc4d7c 100644 --- a/src/optimizer/rule/implementation/mod.rs +++ b/src/optimizer/rule/implementation/mod.rs @@ -18,9 +18,8 @@ pub(crate) mod dql; pub(crate) mod macros; use crate::errors::DatabaseError; -use crate::optimizer::core::memo::GroupExpression; use crate::optimizer::core::pattern::Pattern; -use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; +use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::optimizer::rule::implementation::ddl::add_column::AddColumnImplementation; use 
crate::optimizer::rule::implementation::ddl::change_column::ChangeColumnImplementation; @@ -43,6 +42,7 @@ use crate::optimizer::rule::implementation::dql::function_scan::FunctionScanImpl use crate::optimizer::rule::implementation::dql::join::JoinImplementation; use crate::optimizer::rule::implementation::dql::limit::LimitImplementation; use crate::optimizer::rule::implementation::dql::projection::ProjectionImplementation; +use crate::optimizer::rule::implementation::dql::scalar_subquery::ScalarSubqueryImplementation; use crate::optimizer::rule::implementation::dql::sort::SortImplementation; use crate::optimizer::rule::implementation::dql::table_scan::{ IndexScanImplementation, SeqScanImplementation, @@ -52,7 +52,79 @@ use crate::optimizer::rule::implementation::dql::values::ValuesImplementation; use crate::planner::operator::Operator; use crate::storage::Transaction; -#[derive(Debug, Copy, Clone)] +#[repr(usize)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ImplementationRuleRootTag { + Aggregate = 0, + Dummy, + Filter, + Join, + Limit, + Project, + ScalarSubquery, + TableScan, + FunctionScan, + Sort, + TopK, + Values, + Analyze, + CopyFromFile, + CopyToFile, + Delete, + Insert, + Update, + AddColumn, + ChangeColumn, + CreateTable, + DropColumn, + DropTable, + Truncate, +} + +impl ImplementationRuleRootTag { + pub const COUNT: usize = Self::Truncate as usize + 1; + + pub fn from_operator(operator: &Operator) -> Option { + match operator { + Operator::Aggregate(_) => Some(Self::Aggregate), + Operator::Dummy => Some(Self::Dummy), + Operator::Filter(_) => Some(Self::Filter), + Operator::Join(_) => Some(Self::Join), + Operator::Limit(_) => Some(Self::Limit), + Operator::Project(_) => Some(Self::Project), + Operator::ScalarSubquery(_) => Some(Self::ScalarSubquery), + Operator::TableScan(_) => Some(Self::TableScan), + Operator::FunctionScan(_) => Some(Self::FunctionScan), + Operator::Sort(_) => Some(Self::Sort), + Operator::TopK(_) => Some(Self::TopK), + 
Operator::Values(_) => Some(Self::Values), + Operator::Analyze(_) => Some(Self::Analyze), + Operator::CopyFromFile(_) => Some(Self::CopyFromFile), + Operator::CopyToFile(_) => Some(Self::CopyToFile), + Operator::Delete(_) => Some(Self::Delete), + Operator::Insert(_) => Some(Self::Insert), + Operator::Update(_) => Some(Self::Update), + Operator::AddColumn(_) => Some(Self::AddColumn), + Operator::ChangeColumn(_) => Some(Self::ChangeColumn), + Operator::CreateTable(_) => Some(Self::CreateTable), + Operator::DropColumn(_) => Some(Self::DropColumn), + Operator::DropTable(_) => Some(Self::DropTable), + Operator::Truncate(_) => Some(Self::Truncate), + Operator::ShowTable + | Operator::ShowView + | Operator::Explain + | Operator::Describe(_) + | Operator::Except(_) + | Operator::Union(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropView(_) + | Operator::DropIndex(_) => None, + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ImplementationRuleImpl { // DQL GroupByAggregate, @@ -62,6 +134,7 @@ pub enum ImplementationRuleImpl { HashJoin, Limit, Projection, + ScalarSubquery, SeqScan, FunctionScan, IndexScan, @@ -94,6 +167,7 @@ impl MatchPattern for ImplementationRuleImpl { ImplementationRuleImpl::HashJoin => JoinImplementation.pattern(), ImplementationRuleImpl::Limit => LimitImplementation.pattern(), ImplementationRuleImpl::Projection => ProjectionImplementation.pattern(), + ImplementationRuleImpl::ScalarSubquery => ScalarSubqueryImplementation.pattern(), ImplementationRuleImpl::SeqScan => SeqScanImplementation.pattern(), ImplementationRuleImpl::IndexScan => IndexScanImplementation.pattern(), ImplementationRuleImpl::FunctionScan => FunctionScanImplementation.pattern(), @@ -116,88 +190,143 @@ impl MatchPattern for ImplementationRuleImpl { } } +impl ImplementationRuleImpl { + pub fn root_tag(&self) -> ImplementationRuleRootTag { + match self { + ImplementationRuleImpl::GroupByAggregate | ImplementationRuleImpl::SimpleAggregate => { 
+ ImplementationRuleRootTag::Aggregate + } + ImplementationRuleImpl::Dummy => ImplementationRuleRootTag::Dummy, + ImplementationRuleImpl::Filter => ImplementationRuleRootTag::Filter, + ImplementationRuleImpl::HashJoin => ImplementationRuleRootTag::Join, + ImplementationRuleImpl::Limit => ImplementationRuleRootTag::Limit, + ImplementationRuleImpl::Projection => ImplementationRuleRootTag::Project, + ImplementationRuleImpl::ScalarSubquery => ImplementationRuleRootTag::ScalarSubquery, + ImplementationRuleImpl::SeqScan | ImplementationRuleImpl::IndexScan => { + ImplementationRuleRootTag::TableScan + } + ImplementationRuleImpl::FunctionScan => ImplementationRuleRootTag::FunctionScan, + ImplementationRuleImpl::Sort => ImplementationRuleRootTag::Sort, + ImplementationRuleImpl::TopK => ImplementationRuleRootTag::TopK, + ImplementationRuleImpl::Values => ImplementationRuleRootTag::Values, + ImplementationRuleImpl::Analyze => ImplementationRuleRootTag::Analyze, + ImplementationRuleImpl::CopyFromFile => ImplementationRuleRootTag::CopyFromFile, + ImplementationRuleImpl::CopyToFile => ImplementationRuleRootTag::CopyToFile, + ImplementationRuleImpl::Delete => ImplementationRuleRootTag::Delete, + ImplementationRuleImpl::Insert => ImplementationRuleRootTag::Insert, + ImplementationRuleImpl::Update => ImplementationRuleRootTag::Update, + ImplementationRuleImpl::AddColumn => ImplementationRuleRootTag::AddColumn, + ImplementationRuleImpl::ChangeColumn => ImplementationRuleRootTag::ChangeColumn, + ImplementationRuleImpl::CreateTable => ImplementationRuleRootTag::CreateTable, + ImplementationRuleImpl::DropColumn => ImplementationRuleRootTag::DropColumn, + ImplementationRuleImpl::DropTable => ImplementationRuleRootTag::DropTable, + ImplementationRuleImpl::Truncate => ImplementationRuleRootTag::Truncate, + } + } +} + impl ImplementationRule for ImplementationRuleImpl { - fn to_expression( + fn update_best_option( &self, operator: &Operator, loader: &StatisticMetaLoader<'_, T>, - 
group_expr: &mut GroupExpression, + best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { match self { - ImplementationRuleImpl::GroupByAggregate => { - GroupByAggregateImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::SimpleAggregate => { - SimpleAggregateImplementation.to_expression(operator, loader, group_expr)? - } + ImplementationRuleImpl::GroupByAggregate => GroupByAggregateImplementation + .update_best_option(operator, loader, best_physical_option)?, + ImplementationRuleImpl::SimpleAggregate => SimpleAggregateImplementation + .update_best_option(operator, loader, best_physical_option)?, ImplementationRuleImpl::Dummy => { - DummyImplementation.to_expression(operator, loader, group_expr)? + DummyImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Filter => { - FilterImplementation.to_expression(operator, loader, group_expr)? + FilterImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::HashJoin => { - JoinImplementation.to_expression(operator, loader, group_expr)? + JoinImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Limit => { - LimitImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::Projection => { - ProjectionImplementation.to_expression(operator, loader, group_expr)? + LimitImplementation.update_best_option(operator, loader, best_physical_option)? } + ImplementationRuleImpl::Projection => ProjectionImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::ScalarSubquery => ScalarSubqueryImplementation + .update_best_option(operator, loader, best_physical_option)?, ImplementationRuleImpl::SeqScan => { - SeqScanImplementation.to_expression(operator, loader, group_expr)? 
- } - ImplementationRuleImpl::IndexScan => { - IndexScanImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::FunctionScan => { - FunctionScanImplementation.to_expression(operator, loader, group_expr)? + SeqScanImplementation.update_best_option(operator, loader, best_physical_option)? } + ImplementationRuleImpl::IndexScan => IndexScanImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::FunctionScan => FunctionScanImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, ImplementationRuleImpl::Sort => { - SortImplementation.to_expression(operator, loader, group_expr)? + SortImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::TopK => { - TopKImplementation.to_expression(operator, loader, group_expr)? + TopKImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Values => { - ValuesImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::CopyFromFile => { - CopyFromFileImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::CopyToFile => { - CopyToFileImplementation.to_expression(operator, loader, group_expr)? + ValuesImplementation.update_best_option(operator, loader, best_physical_option)? } + ImplementationRuleImpl::CopyFromFile => CopyFromFileImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::CopyToFile => CopyToFileImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, ImplementationRuleImpl::Delete => { - DeleteImplementation.to_expression(operator, loader, group_expr)? + DeleteImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Insert => { - InsertImplementation.to_expression(operator, loader, group_expr)? 
+ InsertImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Update => { - UpdateImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::AddColumn => { - AddColumnImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::ChangeColumn => { - ChangeColumnImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::CreateTable => { - CreateTableImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::DropColumn => { - DropColumnImplementation.to_expression(operator, loader, group_expr)? - } - ImplementationRuleImpl::DropTable => { - DropTableImplementation.to_expression(operator, loader, group_expr)? + UpdateImplementation.update_best_option(operator, loader, best_physical_option)? } + ImplementationRuleImpl::AddColumn => AddColumnImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::ChangeColumn => ChangeColumnImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::CreateTable => CreateTableImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::DropColumn => DropColumnImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::DropTable => DropTableImplementation.update_best_option( + operator, + loader, + best_physical_option, + )?, ImplementationRuleImpl::Truncate => { - TruncateImplementation.to_expression(operator, loader, group_expr)? + TruncateImplementation.update_best_option(operator, loader, best_physical_option)? } ImplementationRuleImpl::Analyze => { - AnalyzeImplementation.to_expression(operator, loader, group_expr)? + AnalyzeImplementation.update_best_option(operator, loader, best_physical_option)? 
} } diff --git a/src/optimizer/rule/normalization/agg_elimination.rs b/src/optimizer/rule/normalization/agg_elimination.rs index fffc69f2..b6354420 100644 --- a/src/optimizer/rule/normalization/agg_elimination.rs +++ b/src/optimizer/rule/normalization/agg_elimination.rs @@ -12,31 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::sort::SortField; +use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::planner::{Childrens, LogicalPlan}; -use std::sync::LazyLock; - -static REDUNDANT_SORT_PATTERN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Sort(_) | Operator::TopK(_)), - children: PatternChildrenPredicate::None, -}); pub struct EliminateRedundantSort; -impl MatchPattern for EliminateRedundantSort { - fn pattern(&self) -> &Pattern { - &REDUNDANT_SORT_PATTERN - } -} - impl NormalizationRule for EliminateRedundantSort { fn apply(&self, plan: &mut LogicalPlan) -> Result { let (sort_fields, topk_limit) = match &plan.operator { @@ -72,57 +59,32 @@ impl NormalizationRule for EliminateRedundantSort { } } -pub fn annotate_sort_preserving_indexes(plan: &mut LogicalPlan) { - fn visit(plan: &mut LogicalPlan) { - if let Operator::Sort(sort_op) = &plan.operator { - let sort_fields = sort_op.sort_fields.clone(); - mark_sort_preserving_indexes(plan, &sort_fields); - } - match plan.childrens.as_mut() { - Childrens::Only(child) => 
visit(child), - Childrens::Twins { left, right } => { - visit(left); - visit(right); - } - Childrens::None => {} - } - } - visit(plan); -} - fn mark_sort_preserving_indexes(plan: &mut LogicalPlan, required: &[SortField]) { mark_order_hint(plan, required, OrderHintKind::SortElimination); } -pub fn annotate_stream_distinct_indexes(plan: &mut LogicalPlan) { - fn visit(plan: &mut LogicalPlan) { - if let Operator::Aggregate(op) = &plan.operator { - if op.is_distinct && op.agg_calls.is_empty() && !op.groupby_exprs.is_empty() { - if let Childrens::Only(child) = plan.childrens.as_mut() { - let required = distinct_sort_fields(&op.groupby_exprs); - mark_order_hint(child, &required, OrderHintKind::StreamDistinct); - } - } - } - - match plan.childrens.as_mut() { - Childrens::Only(child) => visit(child), - Childrens::Twins { left, right } => { - visit(left); - visit(right); - } - Childrens::None => {} - } - } - visit(plan); -} - #[derive(Copy, Clone)] -enum OrderHintKind { +pub(crate) enum OrderHintKind { SortElimination, StreamDistinct, } +#[derive(Copy, Clone)] +pub(crate) enum ScanOrderHint<'a> { + SortFields(&'a [SortField]), + DistinctGroupBy(&'a [ScalarExpression]), +} + +impl<'a> ScanOrderHint<'a> { + pub(crate) fn sort_fields(fields: &'a [SortField]) -> Self { + Self::SortFields(fields) + } + + pub(crate) fn distinct_groupby(groupby_exprs: &'a [ScalarExpression]) -> Self { + Self::DistinctGroupBy(groupby_exprs) + } +} + fn mark_order_hint(plan: &mut LogicalPlan, required: &[SortField], hint: OrderHintKind) { if required.is_empty() { return; @@ -139,43 +101,80 @@ fn mark_order_hint(plan: &mut LogicalPlan, required: &[SortField], hint: OrderHi } } Operator::TableScan(scan_op) => { - let table_columns: Vec = scan_op.columns.values().cloned().collect(); - let required_from_table = required.iter().all(|field| { - let referenced = field.expr.referenced_columns(true); - referenced - .iter() - .all(|column| table_columns.contains(column)) - }); - if !required_from_table { - 
return; - } - for index_info in scan_op.index_infos.iter_mut() { - if covers(required, &index_info.sort_option) { - let covered = required.len(); - match hint { - OrderHintKind::SortElimination => { - index_info.sort_elimination_hint = Some( - index_info - .sort_elimination_hint - .map_or(covered, |old| old.max(covered)), - ); - } - OrderHintKind::StreamDistinct => { - index_info.stream_distinct_hint = Some( - index_info - .stream_distinct_hint - .map_or(covered, |old| old.max(covered)), - ); - } - } + apply_scan_order_hint(scan_op, ScanOrderHint::sort_fields(required), hint); + } + _ => {} + } +} + +pub(crate) fn apply_scan_order_hint( + scan_op: &mut TableScanOperator, + required: ScanOrderHint<'_>, + hint: OrderHintKind, +) { + let required_from_table = match required { + ScanOrderHint::SortFields(fields) => fields.iter().all(|field| { + field.expr.all_referenced_columns(true, |column| { + scan_op + .columns + .values() + .any(|table_column| table_column == column) + }) + }), + ScanOrderHint::DistinctGroupBy(groupby_exprs) => groupby_exprs.iter().all(|expr| { + expr.all_referenced_columns(true, |column| { + scan_op + .columns + .values() + .any(|table_column| table_column == column) + }) + }), + }; + if !required_from_table { + return; + } + for index_info in scan_op.index_infos.iter_mut() { + if hint_covers(required, &index_info.sort_option) { + let covered = hint_len(required); + match hint { + OrderHintKind::SortElimination => { + index_info.sort_elimination_hint = Some( + index_info + .sort_elimination_hint + .map_or(covered, |old| old.max(covered)), + ); + } + OrderHintKind::StreamDistinct => { + index_info.stream_distinct_hint = Some( + index_info + .stream_distinct_hint + .map_or(covered, |old| old.max(covered)), + ); } } } - _ => {} } } -fn distinct_sort_fields(groupby_exprs: &[ScalarExpression]) -> Vec { +fn hint_len(required: ScanOrderHint<'_>) -> usize { + match required { + ScanOrderHint::SortFields(fields) => fields.len(), + 
ScanOrderHint::DistinctGroupBy(groupby_exprs) => groupby_exprs.len(), + } +} + +fn hint_covers(required: ScanOrderHint<'_>, provided: &SortOption) -> bool { + match required { + ScanOrderHint::SortFields(fields) => covers(fields, provided, sort_field_matches), + ScanOrderHint::DistinctGroupBy(groupby_exprs) => { + covers(groupby_exprs, provided, |expr, field| { + field.asc && !field.nulls_first && expr.eq_ignore_colref_pos(&field.expr) + }) + } + } +} + +pub(crate) fn distinct_sort_fields(groupby_exprs: &[ScalarExpression]) -> Vec { groupby_exprs .iter() .cloned() @@ -183,24 +182,8 @@ fn distinct_sort_fields(groupby_exprs: &[ScalarExpression]) -> Vec { .collect() } -static STREAM_DISTINCT_PATTERN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| match op { - Operator::Aggregate(op) => { - op.is_distinct && op.agg_calls.is_empty() && !op.groupby_exprs.is_empty() - } - _ => false, - }, - children: PatternChildrenPredicate::None, -}); - pub struct UseStreamDistinct; -impl MatchPattern for UseStreamDistinct { - fn pattern(&self) -> &Pattern { - &STREAM_DISTINCT_PATTERN - } -} - impl NormalizationRule for UseStreamDistinct { fn apply(&self, plan: &mut LogicalPlan) -> Result { let Operator::Aggregate(op) = &plan.operator else { @@ -236,20 +219,36 @@ impl NormalizationRule for UseStreamDistinct { } } +pub(crate) fn apply_annotated_post_rules(plan: &mut LogicalPlan) -> Result { + let mut changed = false; + + if EliminateRedundantSort.apply(plan)? { + plan.reset_output_schema_cache_recursive(); + changed = true; + } + if UseStreamDistinct.apply(plan)? { + changed = true; + } + + Ok(changed) +} + fn ensure_stream_distinct_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { if let Some(PhysicalOption { plan: PlanImpl::IndexScan(index_info), .. 
}) = plan.physical_option.as_ref() { - if covers(required, &index_info.sort_option) { + if covers(required, &index_info.sort_option, sort_field_matches) { return true; } } if let Some(physical_option) = plan.physical_option.as_ref() { match physical_option.sort_option() { - SortOption::OrderBy { .. } if covers(required, physical_option.sort_option()) => { + SortOption::OrderBy { .. } + if covers(required, physical_option.sort_option(), sort_field_matches) => + { return true } SortOption::OrderBy { .. } => {} @@ -273,7 +272,7 @@ fn ensure_index_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { .. }) = plan.physical_option.as_ref() { - if covers(required, &index_info.sort_option) { + if covers(required, &index_info.sort_option, sort_field_matches) { return true; } } @@ -291,7 +290,17 @@ fn ensure_index_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { false } -fn covers(required: &[SortField], provided: &SortOption) -> bool { +fn sort_field_matches(required: &SortField, provided: &SortField) -> bool { + required.asc == provided.asc + && required.nulls_first == provided.nulls_first + && required.expr.eq_ignore_colref_pos(&provided.expr) +} + +pub(crate) fn covers( + required: &[T], + provided: &SortOption, + mut matches: impl FnMut(&T, &SortField) -> bool, +) -> bool { if required.is_empty() { return true; } @@ -313,7 +322,7 @@ fn covers(required: &[SortField], provided: &SortOption) -> bool { if required .iter() .zip(fields.iter().skip(skip)) - .all(|(lhs, rhs)| lhs == rhs) + .all(|(lhs, rhs)| matches(lhs, rhs)) { return true; } @@ -348,8 +357,12 @@ mod tests { use ulid::Ulid; fn make_sort_field(name: &str) -> SortField { + make_sort_field_with_position(name, 0) + } + + fn make_sort_field_with_position(name: &str, position: usize) -> SortField { let column = ColumnRef::from(ColumnCatalog::new_dummy(name.to_string())); - SortField::new(ScalarExpression::column_expr(column), true, false) + SortField::new(ScalarExpression::column_expr(column, 
position), true, false) } fn build_plan( @@ -425,7 +438,7 @@ mod tests { columns.insert(0, c1.clone()); let sort_fields = vec![SortField::new( - ScalarExpression::column_expr(c1.clone()), + ScalarExpression::column_expr(c1.clone(), 0), true, false, )]; @@ -465,7 +478,7 @@ mod tests { let plan = LogicalPlan::new( Operator::Aggregate(AggregateOperator { - groupby_exprs: vec![ScalarExpression::column_expr(c1)], + groupby_exprs: vec![ScalarExpression::column_expr(c1, 0)], agg_calls: vec![], is_distinct: true, }), @@ -512,18 +525,47 @@ mod tests { fn remove_sort_when_prefix_can_be_ignored() -> Result<(), DatabaseError> { let c1 = make_sort_field("c1"); let c2 = make_sort_field("c2"); - let mut plan = build_plan(vec![c2.clone()], vec![c1, c2], 1); - super::annotate_sort_preserving_indexes(&mut plan); + let mut plan = build_plan(vec![c2.clone()], vec![c1, c2.clone()], 1); + super::mark_sort_preserving_indexes(&mut plan, &[c2]); let rule = EliminateRedundantSort; assert!(rule.apply(&mut plan)?); Ok(()) } + #[test] + fn remove_topk_when_index_matches_same_column_with_different_positions( + ) -> Result<(), DatabaseError> { + let required = make_sort_field_with_position("no_o_id", 0); + let provided_prefix_1 = make_sort_field_with_position("no_w_id", 0); + let provided_prefix_2 = make_sort_field_with_position("no_d_id", 1); + let provided_target = make_sort_field_with_position("no_o_id", 2); + + let mut plan = build_plan( + vec![required.clone()], + vec![provided_prefix_1, provided_prefix_2, provided_target], + 2, + ); + plan.operator = Operator::TopK(TopKOperator { + sort_fields: vec![required], + limit: 1, + offset: None, + }); + + let rule = EliminateRedundantSort; + assert!(rule.apply(&mut plan)?); + assert!(matches!(plan.operator, Operator::Limit(_))); + Ok(()) + } + #[test] fn annotate_sets_sort_hint_on_table_scan() -> Result<(), DatabaseError> { let column = ColumnRef::from(ColumnCatalog::new_dummy("c1".to_string())); - let sort_field = 
SortField::new(ScalarExpression::column_expr(column.clone()), true, false); + let sort_field = SortField::new( + ScalarExpression::column_expr(column.clone(), 0), + true, + false, + ); let (index_info, _) = build_index_info(vec![sort_field.clone()], 0); let mut columns = BTreeMap::new(); @@ -549,7 +591,11 @@ mod tests { Childrens::Only(Box::new(table_scan)), ); - super::annotate_sort_preserving_indexes(&mut plan); + let sort_fields = match &plan.operator { + Operator::Sort(sort_op) => sort_op.sort_fields.clone(), + _ => unreachable!("expected sort operator"), + }; + super::mark_sort_preserving_indexes(&mut plan, &sort_fields); let table_plan = plan.childrens.pop_only(); match table_plan.operator { @@ -568,8 +614,14 @@ mod tests { #[test] fn annotate_sets_stream_distinct_hint_on_table_scan() -> Result<(), DatabaseError> { let (mut plan, _) = build_distinct_scan_plan(); + let required = match &plan.operator { + Operator::Aggregate(op) => super::distinct_sort_fields(&op.groupby_exprs), + _ => unreachable!("expected aggregate operator"), + }; + if let Childrens::Only(child) = plan.childrens.as_mut() { + super::mark_order_hint(child, &required, super::OrderHintKind::StreamDistinct); + } - super::annotate_stream_distinct_indexes(&mut plan); let child = plan.childrens.pop_only(); let Operator::TableScan(scan_op) = child.operator else { unreachable!() @@ -613,8 +665,8 @@ mod tests { fn keep_sort_when_order_not_covered() -> Result<(), DatabaseError> { let c1 = make_sort_field("c1"); let c2 = make_sort_field("c2"); - let mut plan = build_plan(vec![c2.clone()], vec![c1.clone(), c2], 0); - super::annotate_sort_preserving_indexes(&mut plan); + let mut plan = build_plan(vec![c2.clone()], vec![c1.clone(), c2.clone()], 0); + super::mark_sort_preserving_indexes(&mut plan, &[c2]); let rule = EliminateRedundantSort; assert!(!rule.apply(&mut plan)?); @@ -625,7 +677,11 @@ mod tests { #[test] fn promote_index_to_remove_sort() -> Result<(), DatabaseError> { let column = 
ColumnRef::from(ColumnCatalog::new_dummy("c_first".to_string())); - let sort_field = SortField::new(ScalarExpression::column_expr(column.clone()), true, false); + let sort_field = SortField::new( + ScalarExpression::column_expr(column.clone(), 0), + true, + false, + ); let (mut index_info, _) = build_index_info(vec![sort_field.clone()], 0); index_info.range = Some(Range::Scope { min: Bound::Unbounded, @@ -672,7 +728,11 @@ mod tests { Childrens::Only(Box::new(filter)), ); - super::annotate_sort_preserving_indexes(&mut plan); + let sort_fields = match &plan.operator { + Operator::Sort(sort_op) => sort_op.sort_fields.clone(), + _ => unreachable!("expected sort operator"), + }; + super::mark_sort_preserving_indexes(&mut plan, &sort_fields); let rule = EliminateRedundantSort; assert!(rule.apply(&mut plan)?); assert!(matches!(plan.operator, Operator::Filter(_))); diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index d018c345..2420f100 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -17,60 +17,386 @@ use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::visitor::Visitor; use crate::expression::{HasCountStar, ScalarExpression}; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; +use crate::optimizer::rule::normalization::{remap_expr_positions, remap_exprs_positions}; +use crate::planner::operator::join::JoinCondition; +use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::LogicalType; +use bumpalo::Bump; use sqlparser::ast::CharLengthUnits; use std::collections::HashSet; -use std::sync::LazyLock; - 
-static COLUMN_PRUNING_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |_| true, - children: PatternChildrenPredicate::None, -}); #[derive(Clone)] pub struct ColumnPruning; -macro_rules! trans_references { - ($columns:expr) => {{ - let mut column_references = HashSet::with_capacity($columns.len()); - for column in $columns { - column_references.insert(column.summary()); +type BumpUsizeVec<'bump> = bumpalo::collections::Vec<'bump, usize>; + +struct ApplyOutcome<'bump> { + changed: bool, + removed_positions: BumpUsizeVec<'bump>, +} + +impl<'bump> ApplyOutcome<'bump> { + fn new(arena: &'bump Bump) -> Self { + Self { + changed: false, + removed_positions: BumpUsizeVec::new_in(arena), } - column_references - }}; + } +} + +struct JoinChildrenOutcome<'bump> { + changed: bool, + left_removed_positions: BumpUsizeVec<'bump>, + right_removed_positions: BumpUsizeVec<'bump>, +} + +impl<'bump> JoinChildrenOutcome<'bump> { + fn new(arena: &'bump Bump) -> Self { + Self { + changed: false, + left_removed_positions: BumpUsizeVec::new_in(arena), + right_removed_positions: BumpUsizeVec::new_in(arena), + } + } } impl ColumnPruning { - fn clear_exprs(column_references: &HashSet<&ColumnSummary>, exprs: &mut Vec) { + fn copy_removed_positions<'bump>( + removed_positions: &[usize], + arena: &'bump Bump, + ) -> BumpUsizeVec<'bump> { + let mut copied = BumpUsizeVec::with_capacity_in(removed_positions.len(), arena); + copied.extend_from_slice(removed_positions); + copied + } + + fn extend_operator_referenced_columns<'a>( + operator: &'a Operator, + referenced_columns: &mut HashSet<&'a ColumnSummary>, + ) { + match operator { + Operator::Aggregate(op) => { + Self::extend_expr_referenced_columns( + op.agg_calls.iter().chain(op.groupby_exprs.iter()), + referenced_columns, + ); + } + Operator::Filter(op) => { + Self::extend_expr_referenced_columns([&op.predicate], referenced_columns); + } + Operator::Join(op) => { + if let JoinCondition::On { on, filter } = &op.on { + for (left_expr, 
right_expr) in on { + Self::extend_expr_referenced_columns( + [left_expr, right_expr], + referenced_columns, + ); + } + if let Some(filter_expr) = filter { + Self::extend_expr_referenced_columns([filter_expr], referenced_columns); + } + } + } + Operator::Project(op) => { + Self::extend_expr_referenced_columns(op.exprs.iter(), referenced_columns); + } + Operator::TableScan(op) => { + referenced_columns.extend(op.columns.values().map(|column| column.summary())); + } + Operator::FunctionScan(op) => { + Self::extend_expr_referenced_columns( + op.table_function.args.iter(), + referenced_columns, + ); + } + Operator::Sort(op) => { + Self::extend_expr_referenced_columns( + op.sort_fields.iter().map(|field| &field.expr), + referenced_columns, + ); + } + Operator::TopK(op) => { + Self::extend_expr_referenced_columns( + op.sort_fields.iter().map(|field| &field.expr), + referenced_columns, + ); + } + Operator::Values(op) => { + referenced_columns.extend(op.schema_ref.iter().map(|column| column.summary())); + } + Operator::Union(op) => { + referenced_columns.extend( + op.left_schema_ref + .iter() + .chain(op._right_schema_ref.iter()) + .map(|column| column.summary()), + ); + } + Operator::Except(op) => { + referenced_columns.extend( + op.left_schema_ref + .iter() + .chain(op._right_schema_ref.iter()) + .map(|column| column.summary()), + ); + } + Operator::Delete(op) => { + referenced_columns.extend(op.primary_keys.iter().map(|column| column.summary())); + } + Operator::Dummy + | Operator::Limit(_) + | Operator::ScalarSubquery(_) + | Operator::Analyze(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::AddColumn(_) + | Operator::ChangeColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | 
Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => {} + } + } + + fn extend_expr_referenced_columns<'a>( + exprs: impl IntoIterator, + referenced_columns: &mut HashSet<&'a ColumnSummary>, + ) { + struct ColumnSummaryCollector<'a, 'b> { + referenced_columns: &'b mut HashSet<&'a ColumnSummary>, + } + + impl<'a> Visitor<'a> for ColumnSummaryCollector<'a, '_> { + fn visit_column_ref( + &mut self, + column: &'a crate::catalog::ColumnRef, + ) -> Result<(), DatabaseError> { + self.referenced_columns.insert(column.summary()); + Ok(()) + } + } + + let mut collector = ColumnSummaryCollector { referenced_columns }; + for expr in exprs { + collector.visit(expr).unwrap(); + } + } + + fn output_column_is_required( + expr: &ScalarExpression, + column_references: &HashSet<&ColumnSummary>, + ) -> bool { + column_references.contains(expr.output_column().summary()) + } + + fn clear_exprs( + column_references: &HashSet<&ColumnSummary>, + exprs: &mut Vec, + removed_positions: &mut BumpUsizeVec<'_>, + ) { + removed_positions.clear(); + let mut position = 0; exprs.retain(|expr| { - if column_references.contains(expr.output_column().summary()) { - return true; + let keep = Self::output_column_is_required(expr, column_references); + if !keep { + removed_positions.push(position); } - expr.referenced_columns(false) - .iter() - .any(|column| column_references.contains(column.summary())) + position += 1; + keep + }); + } + + fn remap_operator_after_child_change( + operator: &mut Operator, + removed_positions: &[usize], + ) -> Result<(), DatabaseError> { + match operator { + Operator::Aggregate(op) => { + Self::remap_exprs_after_child_change( + op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()), + removed_positions, + )?; + } + Operator::Filter(op) => { + remap_expr_positions(&mut op.predicate, removed_positions)?; + } + Operator::Project(op) => { + remap_exprs_positions(op.exprs.iter_mut(), removed_positions)?; + } + Operator::ScalarSubquery(_) => {} + Operator::Sort(op) => { + 
Self::remap_exprs_after_child_change( + op.sort_fields.iter_mut().map(|field| &mut field.expr), + removed_positions, + )?; + } + Operator::TopK(op) => { + Self::remap_exprs_after_child_change( + op.sort_fields.iter_mut().map(|field| &mut field.expr), + removed_positions, + )?; + } + Operator::Update(op) => { + Self::remap_exprs_after_child_change( + op.value_exprs.iter_mut().map(|(_, expr)| expr), + removed_positions, + )?; + } + Operator::Limit(_) + | Operator::Explain + | Operator::Insert(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::Dummy + | Operator::TableScan(_) + | Operator::Join(_) + | Operator::Values(_) + | Operator::FunctionScan(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::Describe(_) + | Operator::Union(_) + | Operator::Except(_) + | Operator::AddColumn(_) + | Operator::ChangeColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => {} + } + + Ok(()) + } + + fn remap_exprs_after_child_change<'a>( + exprs: impl IntoIterator, + removed_positions: &[usize], + ) -> Result<(), DatabaseError> { + if removed_positions.is_empty() { + return Ok(()); + } + remap_exprs_positions(exprs, removed_positions) + } + + fn apply_only_child<'bump>( + referenced_columns: HashSet<&ColumnSummary>, + all_referenced: bool, + childrens: &mut Childrens, + arena: &'bump Bump, + ) -> Result, DatabaseError> { + let Childrens::Only(child) = childrens else { + return Ok(ApplyOutcome::new(arena)); + }; + Self::_apply(referenced_columns, all_referenced, child.as_mut(), arena) + } + + fn apply_join_children<'bump>( + referenced_columns: HashSet<&ColumnSummary>, + all_referenced: bool, + childrens: &mut Childrens, + arena: &'bump Bump, + ) -> Result, DatabaseError> { + let Childrens::Twins { left, right } = childrens else 
{ + return Ok(JoinChildrenOutcome::new(arena)); + }; + let left_outcome = Self::_apply( + referenced_columns.clone(), + all_referenced, + left.as_mut(), + arena, + )?; + let right_outcome = + Self::_apply(referenced_columns, all_referenced, right.as_mut(), arena)?; + + Ok(JoinChildrenOutcome { + changed: left_outcome.changed || right_outcome.changed, + left_removed_positions: left_outcome.removed_positions, + right_removed_positions: right_outcome.removed_positions, }) } - fn _apply( - column_references: HashSet<&ColumnSummary>, + #[allow(clippy::needless_lifetimes)] + fn apply_twins<'bump>( + referenced_columns: HashSet<&ColumnSummary>, all_referenced: bool, - plan: &mut LogicalPlan, + childrens: &mut Childrens, + arena: &'bump Bump, ) -> Result { + let Childrens::Twins { left, right } = childrens else { + return Ok(false); + }; + + let left_changed = Self::_apply( + referenced_columns.clone(), + all_referenced, + left.as_mut(), + arena, + )? + .changed; + let right_changed = + Self::_apply(referenced_columns, all_referenced, right.as_mut(), arena)?.changed; + + Ok(left_changed || right_changed) + } + + fn merge_removed_positions<'bump>( + left_removed_positions: &[usize], + right_removed_positions: &[usize], + right_offset: usize, + arena: &'bump Bump, + ) -> BumpUsizeVec<'bump> { + let mut removed_positions = BumpUsizeVec::with_capacity_in( + left_removed_positions.len() + right_removed_positions.len(), + arena, + ); + removed_positions.extend_from_slice(left_removed_positions); + removed_positions.extend( + right_removed_positions + .iter() + .map(|position| position + right_offset), + ); + removed_positions + } + + fn _apply<'bump>( + required_columns: HashSet<&ColumnSummary>, + all_referenced: bool, + plan: &mut LogicalPlan, + arena: &'bump Bump, + ) -> Result, DatabaseError> { let mut changed = false; - let operator = &mut plan.operator; + let mut output_removed_positions = BumpUsizeVec::new_in(arena); + let (operator, childrens) = (&mut plan.operator, 
plan.childrens.as_mut()); match operator { Operator::Aggregate(op) => { if !all_referenced { - let before = op.agg_calls.len(); - Self::clear_exprs(&column_references, &mut op.agg_calls); - if op.agg_calls.len() != before { + Self::clear_exprs( + &required_columns, + &mut op.agg_calls, + &mut output_removed_positions, + ); + if !output_removed_positions.is_empty() { changed = true; } @@ -91,17 +417,26 @@ impl ColumnPruning { changed = true; } } - let is_distinct = op.is_distinct; - let referenced_columns = operator.referenced_columns(false); - let mut new_column_references = trans_references!(&referenced_columns); - // on distinct - if is_distinct { - for summary in column_references { - new_column_references.insert(summary); - } - } + let child_outcome = { + let mut child_required = if op.is_distinct { + required_columns + } else { + HashSet::new() + }; + Self::extend_expr_referenced_columns( + op.agg_calls.iter().chain(op.groupby_exprs.iter()), + &mut child_required, + ); - changed |= Self::recollect_apply(new_column_references, false, plan)?; + Self::apply_only_child(child_required, false, childrens, arena)? 
+ }; + if child_outcome.changed { + Self::remap_operator_after_child_change( + operator, + &child_outcome.removed_positions, + )?; + changed = true; + } } Operator::Project(op) => { let mut has_count_star = HasCountStar::default(); @@ -110,57 +445,161 @@ impl ColumnPruning { } if !has_count_star.value { if !all_referenced { - let before = op.exprs.len(); - Self::clear_exprs(&column_references, &mut op.exprs); - if op.exprs.len() != before { + Self::clear_exprs( + &required_columns, + &mut op.exprs, + &mut output_removed_positions, + ); + if !output_removed_positions.is_empty() { changed = true; } } - let referenced_columns = operator.referenced_columns(false); - let new_column_references = trans_references!(&referenced_columns); + let child_outcome = { + let mut child_required = HashSet::new(); + Self::extend_expr_referenced_columns(op.exprs.iter(), &mut child_required); - changed |= Self::recollect_apply(new_column_references, false, plan)?; + Self::apply_only_child(child_required, false, childrens, arena)? + }; + if child_outcome.changed { + Self::remap_operator_after_child_change( + operator, + &child_outcome.removed_positions, + )?; + changed = true; + } } } Operator::TableScan(op) => { if !all_referenced { - let before = op.columns.len(); - op.columns - .retain(|_, column| column_references.contains(column.summary())); - if op.columns.len() != before { + output_removed_positions.clear(); + op.columns.retain(|position, column| { + let keep = required_columns.contains(column.summary()); + if !keep { + output_removed_positions.push(*position); + } + keep + }); + if !output_removed_positions.is_empty() { changed = true; } } } Operator::Sort(_) | Operator::Limit(_) + | Operator::ScalarSubquery(_) | Operator::Join(_) | Operator::Filter(_) | Operator::Union(_) | Operator::Except(_) | Operator::TopK(_) => { - let temp_columns = operator.referenced_columns(false); - // this is magic!!! do not delete!!! 
- let mut column_references = column_references; - for column in temp_columns.iter() { - column_references.insert(column.summary()); + if matches!(operator, Operator::Join(_)) { + let (child_outcome, old_left_outputs_len) = { + let mut child_required = required_columns.clone(); + Self::extend_operator_referenced_columns(operator, &mut child_required); + let old_left_outputs_len = match childrens { + Childrens::Twins { left, .. } => left.output_schema().len(), + _ => 0, + }; + let child_outcome = Self::apply_join_children( + child_required, + all_referenced, + childrens, + arena, + )?; + (child_outcome, old_left_outputs_len) + }; + if child_outcome.changed { + let JoinChildrenOutcome { + changed: _, + left_removed_positions, + right_removed_positions, + } = child_outcome; + if let Operator::Join(op) = operator { + match &mut op.on { + JoinCondition::On { on, filter } => { + for (left_expr, right_expr) in on { + remap_expr_positions(left_expr, &left_removed_positions)?; + remap_expr_positions(right_expr, &right_removed_positions)?; + } + if let Some(filter) = filter { + if !left_removed_positions.is_empty() + || !right_removed_positions.is_empty() + { + let removed_positions = Self::merge_removed_positions( + &left_removed_positions, + &right_removed_positions, + old_left_outputs_len, + arena, + ); + remap_expr_positions(filter, &removed_positions)?; + } + } + } + JoinCondition::None => {} + } + if !matches!(op.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + output_removed_positions = Self::merge_removed_positions( + &left_removed_positions, + &right_removed_positions, + old_left_outputs_len, + arena, + ); + } else { + output_removed_positions = + Self::copy_removed_positions(&left_removed_positions, arena); + } + } + changed = true; + } + } else if matches!(operator, Operator::Union(_) | Operator::Except(_)) { + let mut child_required = required_columns; + Self::extend_operator_referenced_columns(operator, &mut child_required); + changed |= 
Self::apply_twins(child_required, all_referenced, childrens, arena)?; + } else { + let child_outcome = { + let mut child_required = required_columns; + Self::extend_operator_referenced_columns(operator, &mut child_required); + Self::apply_only_child(child_required, all_referenced, childrens, arena)? + }; + if child_outcome.changed { + let removed_positions = child_outcome.removed_positions; + Self::remap_operator_after_child_change(operator, &removed_positions)?; + output_removed_positions = removed_positions; + changed = true; + } } - changed |= Self::recollect_apply(column_references.clone(), all_referenced, plan)?; } // Last Operator Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => (), Operator::Explain => { - changed |= Self::recollect_apply(column_references, true, plan)?; + let child_outcome = + Self::apply_only_child(required_columns, true, childrens, arena)?; + if child_outcome.changed { + Self::remap_operator_after_child_change( + operator, + &child_outcome.removed_positions, + )?; + changed = true; + } } // DDL Based on Other Plan Operator::Insert(_) | Operator::Update(_) | Operator::Delete(_) | Operator::Analyze(_) => { - let referenced_columns = operator.referenced_columns(false); - let new_column_references = trans_references!(&referenced_columns); + let child_outcome = { + let mut child_required = HashSet::new(); + Self::extend_operator_referenced_columns(operator, &mut child_required); - changed |= Self::recollect_apply(new_column_references, true, plan)?; + Self::apply_only_child(child_required, true, childrens, arena)? 
+ }; + if child_outcome.changed { + Self::remap_operator_after_child_change( + operator, + &child_outcome.removed_positions, + )?; + changed = true; + } } // DDL Single Plan Operator::CreateTable(_) @@ -180,47 +619,18 @@ impl ColumnPruning { | Operator::Describe(_) => (), } - Ok(changed) - } - - fn recollect_apply( - referenced_columns: HashSet<&ColumnSummary>, - all_referenced: bool, - plan: &mut LogicalPlan, - ) -> Result { - Self::for_each_child(plan, |child| { - Self::_apply(referenced_columns.clone(), all_referenced, child) + Ok(ApplyOutcome { + changed, + removed_positions: output_removed_positions, }) } - - fn for_each_child( - plan: &mut LogicalPlan, - mut f: impl FnMut(&mut LogicalPlan) -> Result, - ) -> Result { - let mut changed = false; - match plan.childrens.as_mut() { - Childrens::Only(child) => { - changed |= f(child.as_mut())?; - } - Childrens::Twins { left, right } => { - changed |= f(left.as_mut())?; - changed |= f(right.as_mut())?; - } - Childrens::None => (), - } - Ok(changed) - } -} - -impl MatchPattern for ColumnPruning { - fn pattern(&self) -> &Pattern { - &COLUMN_PRUNING_RULE - } } impl NormalizationRule for ColumnPruning { fn apply(&self, plan: &mut LogicalPlan) -> Result { - Self::_apply(HashSet::new(), true, plan) + let arena = Bump::new(); + let outcome = Self::_apply(HashSet::<&ColumnSummary>::new(), true, plan, &arena)?; + Ok(outcome.changed) } } @@ -233,9 +643,129 @@ mod tests { use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; - use crate::planner::Childrens; + use crate::planner::{Childrens, LogicalPlan}; use crate::storage::rocksdb::RocksTransaction; + fn optimize_column_pruning(sql: &str) -> Result { + let table_state = build_t1_table()?; + let plan = table_state.plan(sql)?; + + HepOptimizerPipeline::builder() + .before_batch( + format!("column_pruning::{sql}"), + HepBatchStrategy::once_topdown(), + 
vec![NormalizationRuleImpl::ColumnPruning], + ) + .build() + .instantiate(plan) + .find_best::(None) + } + + fn contains_operator(plan: &LogicalPlan, predicate: impl Fn(&Operator) -> bool + Copy) -> bool { + predicate(&plan.operator) + || plan + .childrens + .iter() + .any(|child| contains_operator(child, predicate)) + } + + fn collect_scan_columns(plan: &LogicalPlan, table_name: &str, scans: &mut Vec>) { + if let Operator::TableScan(op) = &plan.operator { + if op.table_name.to_string() == table_name { + scans.push( + op.columns + .values() + .map(|column| column.name().to_string()) + .collect(), + ); + } + } + + for child in plan.childrens.iter() { + collect_scan_columns(child, table_name, scans); + } + } + + fn assert_single_scan_columns(plan: &LogicalPlan, table_name: &str, expected: &[&str]) { + let mut scans = Vec::new(); + collect_scan_columns(plan, table_name, &mut scans); + assert_eq!( + scans.len(), + 1, + "expected exactly one scan for table {table_name}" + ); + let expected = expected + .iter() + .map(|name| name.to_string()) + .collect::>(); + assert_eq!(scans.pop().unwrap(), expected); + } + + #[test] + fn test_column_pruning_project_single_side() -> Result<(), DatabaseError> { + let best_plan = optimize_column_pruning("select c1 from t1")?; + + assert!(contains_operator(&best_plan, |op| matches!( + op, + Operator::Project(_) + ))); + assert_single_scan_columns(&best_plan, "t1", &["c1"]); + + Ok(()) + } + + #[test] + fn test_column_pruning_filter_single_side() -> Result<(), DatabaseError> { + let best_plan = optimize_column_pruning("select c1 from t1 where c2 > 1")?; + + assert!(contains_operator(&best_plan, |op| matches!( + op, + Operator::Filter(_) + ))); + assert_single_scan_columns(&best_plan, "t1", &["c1", "c2"]); + + Ok(()) + } + + #[test] + fn test_column_pruning_aggregate_single_side() -> Result<(), DatabaseError> { + let best_plan = optimize_column_pruning("select sum(c1) from t1")?; + + assert!(contains_operator(&best_plan, |op| matches!( + 
op, + Operator::Aggregate(_) + ))); + assert_single_scan_columns(&best_plan, "t1", &["c1"]); + + Ok(()) + } + + #[test] + fn test_column_pruning_sort_single_side() -> Result<(), DatabaseError> { + let best_plan = optimize_column_pruning("select c1 from t1 order by c2")?; + + assert!(contains_operator(&best_plan, |op| matches!( + op, + Operator::Sort(_) + ))); + assert_single_scan_columns(&best_plan, "t1", &["c1", "c2"]); + + Ok(()) + } + + #[test] + fn test_column_pruning_limit_single_side() -> Result<(), DatabaseError> { + let best_plan = optimize_column_pruning("select c1 from t1 limit 1")?; + + assert!(contains_operator(&best_plan, |op| matches!( + op, + Operator::Limit(_) + ))); + assert_single_scan_columns(&best_plan, "t1", &["c1"]); + + Ok(()) + } + #[test] fn test_column_pruning() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index fcd9034e..a548145e 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -13,47 +13,17 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::expression::{BinaryOperator, ScalarExpression}; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; use crate::optimizer::rule::normalization::{is_subset_exprs, strip_alias}; +use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::Operator; -use crate::planner::LogicalPlan; +use crate::planner::{Childrens, LogicalPlan}; +use crate::types::value::DataValue; use crate::types::LogicalType; -use std::collections::HashSet; -use std::sync::LazyLock; - -static COLLAPSE_PROJECT_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Project(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Project(_)), - children: PatternChildrenPredicate::None, - }]), -}); - -static COMBINE_FILTERS_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Filter(_)) || is_passthrough_project_operator(op), - children: PatternChildrenPredicate::None, - }]), -}); - -static COLLAPSE_GROUP_BY_AGG: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| match op { - Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(), - _ => false, - }, - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| match op { - Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(), - _ => false, - }, - children: PatternChildrenPredicate::None, - }]), -}); +use std::mem; fn is_passthrough_project(op: &ProjectOperator) -> bool { 
op.exprs @@ -61,34 +31,87 @@ fn is_passthrough_project(op: &ProjectOperator) -> bool { .all(|expr| matches!(strip_alias(expr), ScalarExpression::ColumnRef { .. })) } -fn is_passthrough_project_operator(op: &Operator) -> bool { - matches!(op, Operator::Project(project_op) if is_passthrough_project(project_op)) +fn passthrough_source_position(expr: &ScalarExpression) -> Option { + match strip_alias(expr) { + ScalarExpression::ColumnRef { position, .. } => Some(*position), + _ => None, + } } -/// Combine two adjacent project operators into one. -pub struct CollapseProject; +fn rewrite_column_position(expr: &mut ScalarExpression, new_position: usize) { + match expr { + ScalarExpression::ColumnRef { position, .. } => { + *position = new_position; + } + ScalarExpression::Alias { expr, alias } => { + rewrite_column_position(expr, new_position); + if let AliasType::Expr(alias_expr) = alias { + rewrite_column_position(alias_expr, new_position); + } + } + _ => {} + } +} -impl MatchPattern for CollapseProject { - fn pattern(&self) -> &Pattern { - &COLLAPSE_PROJECT_RULE +fn remap_passthrough_project_exprs( + parent_exprs: &mut [ScalarExpression], + child_exprs: &[ScalarExpression], +) -> bool { + let mut remapped_positions = Vec::with_capacity(parent_exprs.len()); + + for parent_expr in parent_exprs.iter() { + let Some(position) = child_exprs + .iter() + .find(|child_expr| parent_expr.eq_ignore_colref_pos(child_expr)) + .and_then(passthrough_source_position) + else { + return false; + }; + remapped_positions.push(position); } + + for (parent_expr, position) in parent_exprs.iter_mut().zip(remapped_positions) { + rewrite_column_position(parent_expr, position); + } + + true +} + +fn groupby_exprs_match( + parent_exprs: &[ScalarExpression], + child_exprs: &[ScalarExpression], +) -> bool { + parent_exprs.len() == child_exprs.len() + && parent_exprs + .iter() + .zip(child_exprs.iter()) + .all(|(parent_expr, child_expr)| parent_expr.eq_ignore_colref_pos(child_expr)) } +/// Combine 
two adjacent project operators into one. +pub struct CollapseProject; + impl NormalizationRule for CollapseProject { fn apply(&self, plan: &mut LogicalPlan) -> Result { - let parent_exprs = match &plan.operator { - Operator::Project(op) => op.exprs.clone(), - _ => return Ok(false), + let Operator::Project(parent_op) = &mut plan.operator else { + return Ok(false); }; let mut removed = false; - while let Some(child) = only_child_mut(plan) { + loop { + let Childrens::Only(child) = plan.childrens.as_mut() else { + break; + }; match &child.operator { Operator::Project(child_op) if is_passthrough_project(child_op) - && is_subset_exprs(&parent_exprs, &child_op.exprs) => + && is_subset_exprs(&parent_op.exprs, &child_op.exprs) + && remap_passthrough_project_exprs( + &mut parent_op.exprs, + &child_op.exprs, + ) => { - removed |= replace_with_only_child(child); + removed |= replace_with_only_child(child.as_mut()); } _ => break, } @@ -101,35 +124,45 @@ impl NormalizationRule for CollapseProject { /// Combine two adjacent filter operators into one. 
pub struct CombineFilter; -impl MatchPattern for CombineFilter { - fn pattern(&self) -> &Pattern { - &COMBINE_FILTERS_RULE - } -} - impl NormalizationRule for CombineFilter { fn apply(&self, plan: &mut LogicalPlan) -> Result { - let (parent_predicate, parent_having) = match &plan.operator { - Operator::Filter(op) => (op.predicate.clone(), op.having), - _ => return Ok(false), + let parent_filter = match mem::replace(&mut plan.operator, Operator::Dummy) { + Operator::Filter(op) => op, + operator => { + plan.operator = operator; + return Ok(false); + } }; + let parent_filter = parent_filter; let cursor = match only_child_mut(plan) { Some(child) => child, - None => return Ok(false), + None => { + plan.operator = Operator::Filter(parent_filter); + return Ok(false); + } }; loop { match &mut cursor.operator { Operator::Filter(child_op) => { + let FilterOperator { + predicate, + having, + is_optimized: _, + } = parent_filter; + let child_predicate = mem::replace( + &mut child_op.predicate, + ScalarExpression::Constant(DataValue::Boolean(true)), + ); child_op.predicate = ScalarExpression::Binary { op: BinaryOperator::And, - left_expr: Box::new(parent_predicate), - right_expr: Box::new(child_op.predicate.clone()), + left_expr: Box::new(predicate), + right_expr: Box::new(child_predicate), evaluator: None, ty: LogicalType::Boolean, }; - child_op.having = parent_having || child_op.having; + child_op.having = having || child_op.having; return Ok(replace_with_only_child(plan)); } @@ -137,9 +170,13 @@ impl NormalizationRule for CombineFilter { if replace_with_only_child(cursor) { continue; } + plan.operator = Operator::Filter(parent_filter); + return Ok(false); + } + _ => { + plan.operator = Operator::Filter(parent_filter); return Ok(false); } - _ => return Ok(false), } } } @@ -147,37 +184,32 @@ impl NormalizationRule for CombineFilter { pub struct CollapseGroupByAgg; -impl MatchPattern for CollapseGroupByAgg { - fn pattern(&self) -> &Pattern { - &COLLAPSE_GROUP_BY_AGG - } -} - 
impl NormalizationRule for CollapseGroupByAgg { fn apply(&self, plan: &mut LogicalPlan) -> Result { - if let Operator::Aggregate(op) = plan.operator.clone() { + let can_collapse = { + let LogicalPlan { + operator, + childrens, + .. + } = plan; + let Operator::Aggregate(op) = operator else { + return Ok(false); + }; if !op.agg_calls.is_empty() { return Ok(false); } - if let Some(child) = only_child_mut(plan) { - if let Operator::Aggregate(child_op) = child.operator.clone() { - if op.groupby_exprs.len() != child_op.groupby_exprs.len() { - return Ok(false); - } - let mut expr_set = HashSet::new(); + let Childrens::Only(child) = childrens.as_ref() else { + return Ok(false); + }; + let Operator::Aggregate(child_op) = &child.operator else { + return Ok(false); + }; + groupby_exprs_match(&op.groupby_exprs, &child_op.groupby_exprs) + }; - for expr in op.groupby_exprs.iter() { - expr_set.insert(expr); - } - for expr in child_op.groupby_exprs.iter() { - expr_set.remove(expr); - } - if expr_set.is_empty() { - return Ok(replace_with_only_child(plan)); - } - } - } + if can_collapse { + return Ok(replace_with_only_child(plan)); } Ok(false) @@ -187,15 +219,29 @@ impl NormalizationRule for CollapseGroupByAgg { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use crate::binder::test::build_t1_table; + use crate::catalog::{ColumnCatalog, ColumnRef}; use crate::errors::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; + use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; + use crate::optimizer::rule::normalization::combine_operators::{ + CollapseGroupByAgg, CollapseProject, + }; use crate::optimizer::rule::normalization::NormalizationRuleImpl; + use crate::planner::operator::aggregate::AggregateOperator; + use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::Operator; - use 
crate::planner::Childrens; + use crate::planner::{Childrens, LogicalPlan}; use crate::storage::rocksdb::RocksTransaction; + fn column_expr(name: &str, position: usize) -> ScalarExpression { + ScalarExpression::column_expr( + ColumnRef::from(ColumnCatalog::new_dummy(name.to_string())), + position, + ) + } + #[test] fn test_collapse_project() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; @@ -218,12 +264,14 @@ mod tests { unreachable!("Should be a project operator") } - let scan_op = best_plan.childrens.pop_only(); - if let Operator::TableScan(_) = &scan_op.operator { - assert!(matches!(scan_op.childrens.as_ref(), Childrens::None)); - } else { - unreachable!("Should be a scan operator") - } + let alias_project = best_plan.childrens.pop_only(); + assert!( + matches!(alias_project.operator, Operator::Project(_)), + "Derived-table alias projection should be preserved" + ); + let scan_op = alias_project.childrens.pop_only(); + assert!(matches!(scan_op.operator, Operator::TableScan(_))); + assert!(matches!(scan_op.childrens.as_ref(), Childrens::None)); Ok(()) } @@ -265,6 +313,37 @@ mod tests { Ok(()) } + #[test] + fn test_collapse_project_remaps_reordered_passthrough_positions() -> Result<(), DatabaseError> { + let child = LogicalPlan::new( + Operator::Project(ProjectOperator { + exprs: vec![column_expr("c2", 1), column_expr("c1", 0)], + }), + Childrens::Only(Box::new(LogicalPlan::new(Operator::Dummy, Childrens::None))), + ); + let mut plan = LogicalPlan::new( + Operator::Project(ProjectOperator { + exprs: vec![column_expr("c2", 0)], + }), + Childrens::Only(Box::new(child)), + ); + + assert!(CollapseProject.apply(&mut plan)?); + + let Operator::Project(op) = &plan.operator else { + unreachable!("expected project"); + }; + let ScalarExpression::ColumnRef { position, .. 
} = &op.exprs[0] else { + unreachable!("expected column ref"); + }; + assert_eq!(*position, 1); + assert!(matches!( + plan.childrens.pop_only().operator, + Operator::Dummy + )); + Ok(()) + } + #[test] fn test_combine_filter() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; @@ -324,4 +403,25 @@ mod tests { } unreachable!("Should be a agg operator") } + + #[test] + fn test_collapse_group_by_agg_ignores_columnref_position() -> Result<(), DatabaseError> { + let child = AggregateOperator::build( + LogicalPlan::new(Operator::Dummy, Childrens::None), + vec![], + vec![column_expr("c2", 1)], + false, + ); + let mut plan = AggregateOperator::build(child, vec![], vec![column_expr("c2", 0)], true); + + assert!(CollapseGroupByAgg.apply(&mut plan)?); + let Operator::Aggregate(op) = &plan.operator else { + unreachable!("expected aggregate"); + }; + let ScalarExpression::ColumnRef { position, .. } = &op.groupby_exprs[0] else { + unreachable!("expected column ref"); + }; + assert_eq!(*position, 1); + Ok(()) + } } diff --git a/src/optimizer/rule/normalization/compilation_in_advance.rs b/src/optimizer/rule/normalization/compilation_in_advance.rs index 35610428..8a2b9361 100644 --- a/src/optimizer/rule/normalization/compilation_in_advance.rs +++ b/src/optimizer/rule/normalization/compilation_in_advance.rs @@ -14,177 +14,99 @@ use crate::errors::DatabaseError; use crate::expression::visitor_mut::VisitorMut; -use crate::expression::{BindEvaluator, BindPosition, ScalarExpression}; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::expression::BindEvaluator; +use crate::optimizer::core::rule::NormalizationRule; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; -use std::borrow::Cow; -use std::sync::LazyLock; - -static BIND_EXPRESSION_POSITION: LazyLock = LazyLock::new(|| Pattern 
{ - predicate: |_| true, - children: PatternChildrenPredicate::None, -}); - -static EVALUATOR_BIND_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |_| true, - children: PatternChildrenPredicate::None, -}); #[derive(Clone)] -pub struct BindExpressionPosition; +pub struct EvaluatorBind; -impl BindExpressionPosition { - fn _apply( - output_exprs: &mut Vec, - plan: &mut LogicalPlan, - ) -> Result<(), DatabaseError> { - let mut left_len = 0; - match plan.childrens.as_mut() { - Childrens::Only(child) => { - Self::_apply(output_exprs, child)?; - } - Childrens::Twins { left, right } => { - Self::_apply(output_exprs, left)?; - if matches!( - plan.operator, - Operator::Join(_) | Operator::Union(_) | Operator::Except(_) - ) { - let mut second_output_exprs = Vec::new(); - Self::_apply(&mut second_output_exprs, right)?; - left_len = output_exprs.len(); - output_exprs.append(&mut second_output_exprs); - } - } - Childrens::None => {} - } - let mut bind_position = BindPosition::new( - || { - output_exprs - .iter() - .map(|expr| Cow::Owned(expr.output_column())) - }, - |a, b| a == b, - ); - let operator = &mut plan.operator; - match operator { - Operator::Join(op) => { - match &mut op.on { - JoinCondition::On { on, filter } => { - let mut left_bind_position = BindPosition::new( - || { - output_exprs[0..left_len] - .iter() - .map(|expr| Cow::Owned(expr.output_column())) - }, - |a, b| a == b, - ); - let mut right_bind_position = BindPosition::new( - || { - output_exprs[left_len..] 
- .iter() - .map(|expr| Cow::Owned(expr.output_column())) - }, - |a, b| a == b, - ); - for (left_expr, right_expr) in on { - left_bind_position.visit(left_expr)?; - right_bind_position.visit(right_expr)?; - } - if let Some(expr) = filter { - bind_position.visit(expr)?; - } - } - JoinCondition::None => {} - } +pub(crate) fn evaluator_bind_current(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + let operator = &mut plan.operator; - return Ok(()); - } - Operator::Aggregate(op) => { - for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { - bind_position.visit(expr)?; + match operator { + Operator::Join(op) => { + match &mut op.on { + JoinCondition::On { on, filter } => { + for (left_expr, right_expr) in on { + BindEvaluator.visit(left_expr)?; + BindEvaluator.visit(right_expr)?; + } + if let Some(expr) = filter { + BindEvaluator.visit(expr)?; + } } + JoinCondition::None => {} } - Operator::Filter(op) => { - bind_position.visit(&mut op.predicate)?; - } - Operator::Project(op) => { - for expr in op.exprs.iter_mut() { - bind_position.visit(expr)?; - } + + return Ok(()); + } + Operator::Aggregate(op) => { + for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { + BindEvaluator.visit(expr)?; } - Operator::Sort(op) => { - for sort_field in op.sort_fields.iter_mut() { - bind_position.visit(&mut sort_field.expr)?; - } + } + Operator::Filter(op) => { + BindEvaluator.visit(&mut op.predicate)?; + } + Operator::Project(op) => { + for expr in op.exprs.iter_mut() { + BindEvaluator.visit(expr)?; } - Operator::TopK(op) => { - for sort_field in op.sort_fields.iter_mut() { - bind_position.visit(&mut sort_field.expr)?; - } + } + Operator::Sort(op) => { + for sort_field in op.sort_fields.iter_mut() { + BindEvaluator.visit(&mut sort_field.expr)?; } - Operator::FunctionScan(op) => { - for expr in op.table_function.args.iter_mut() { - bind_position.visit(expr)?; - } + } + Operator::TopK(op) => { + for sort_field in op.sort_fields.iter_mut() { + 
BindEvaluator.visit(&mut sort_field.expr)?; } - Operator::Update(op) => { - for (_, expr) in op.value_exprs.iter_mut() { - bind_position.visit(expr)?; - } + } + Operator::FunctionScan(op) => { + for expr in op.table_function.args.iter_mut() { + BindEvaluator.visit(expr)?; } - Operator::Dummy - | Operator::TableScan(_) - | Operator::Limit(_) - | Operator::Values(_) - | Operator::ShowTable - | Operator::ShowView - | Operator::Explain - | Operator::Describe(_) - | Operator::Insert(_) - | Operator::Delete(_) - | Operator::Analyze(_) - | Operator::AddColumn(_) - | Operator::ChangeColumn(_) - | Operator::DropColumn(_) - | Operator::CreateTable(_) - | Operator::CreateIndex(_) - | Operator::CreateView(_) - | Operator::DropTable(_) - | Operator::DropView(_) - | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) - | Operator::Union(_) - | Operator::Except(_) => (), } - if let Some(exprs) = operator.output_exprs() { - *output_exprs = exprs; + Operator::Update(op) => { + for (_, expr) in op.value_exprs.iter_mut() { + BindEvaluator.visit(expr)?; + } } - - Ok(()) + Operator::Dummy + | Operator::TableScan(_) + | Operator::Limit(_) + | Operator::ScalarSubquery(_) + | Operator::Values(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::ChangeColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::Union(_) + | Operator::Except(_) => (), } -} -impl MatchPattern for BindExpressionPosition { - fn pattern(&self) -> &Pattern { - &BIND_EXPRESSION_POSITION - } + Ok(()) } -impl NormalizationRule for BindExpressionPosition { - fn 
apply(&self, plan: &mut LogicalPlan) -> Result { - Self::_apply(&mut Vec::new(), plan)?; - Ok(true) - } -} - -#[derive(Clone)] -pub struct EvaluatorBind; - impl EvaluatorBind { fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { match plan.childrens.as_mut() { @@ -201,92 +123,7 @@ impl EvaluatorBind { Childrens::None => {} } - let operator = &mut plan.operator; - - match operator { - Operator::Join(op) => { - match &mut op.on { - JoinCondition::On { on, filter } => { - for (left_expr, right_expr) in on { - BindEvaluator.visit(left_expr)?; - BindEvaluator.visit(right_expr)?; - } - if let Some(expr) = filter { - BindEvaluator.visit(expr)?; - } - } - JoinCondition::None => {} - } - - return Ok(()); - } - Operator::Aggregate(op) => { - for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { - BindEvaluator.visit(expr)?; - } - } - Operator::Filter(op) => { - BindEvaluator.visit(&mut op.predicate)?; - } - Operator::Project(op) => { - for expr in op.exprs.iter_mut() { - BindEvaluator.visit(expr)?; - } - } - Operator::Sort(op) => { - for sort_field in op.sort_fields.iter_mut() { - BindEvaluator.visit(&mut sort_field.expr)?; - } - } - Operator::TopK(op) => { - for sort_field in op.sort_fields.iter_mut() { - BindEvaluator.visit(&mut sort_field.expr)?; - } - } - Operator::FunctionScan(op) => { - for expr in op.table_function.args.iter_mut() { - BindEvaluator.visit(expr)?; - } - } - Operator::Update(op) => { - for (_, expr) in op.value_exprs.iter_mut() { - BindEvaluator.visit(expr)?; - } - } - Operator::Dummy - | Operator::TableScan(_) - | Operator::Limit(_) - | Operator::Values(_) - | Operator::ShowTable - | Operator::ShowView - | Operator::Explain - | Operator::Describe(_) - | Operator::Insert(_) - | Operator::Delete(_) - | Operator::Analyze(_) - | Operator::AddColumn(_) - | Operator::ChangeColumn(_) - | Operator::DropColumn(_) - | Operator::CreateTable(_) - | Operator::CreateIndex(_) - | Operator::CreateView(_) - | Operator::DropTable(_) - | 
Operator::DropView(_) - | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) - | Operator::Union(_) - | Operator::Except(_) => (), - } - - Ok(()) - } -} - -impl MatchPattern for EvaluatorBind { - fn pattern(&self) -> &Pattern { - &EVALUATOR_BIND_RULE + evaluator_bind_current(plan) } } diff --git a/src/optimizer/rule/normalization/min_max_top_k.rs b/src/optimizer/rule/normalization/min_max_top_k.rs index e678c43d..048a39cd 100644 --- a/src/optimizer/rule/normalization/min_max_top_k.rs +++ b/src/optimizer/rule/normalization/min_max_top_k.rs @@ -15,28 +15,15 @@ use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child, wrap_child_with}; use crate::planner::operator::sort::SortField; use crate::planner::operator::top_k::TopKOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; -use std::sync::LazyLock; - -static MIN_MAX_TOPK_PATTERN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Aggregate(_)), - children: PatternChildrenPredicate::None, -}); pub struct MinMaxToTopK; -impl MatchPattern for MinMaxToTopK { - fn pattern(&self) -> &Pattern { - &MIN_MAX_TOPK_PATTERN - } -} - impl NormalizationRule for MinMaxToTopK { fn apply(&self, plan: &mut LogicalPlan) -> Result { let Operator::Aggregate(op) = &plan.operator else { @@ -98,9 +85,7 @@ mod tests { use crate::planner::operator::Operator; use crate::planner::Childrens; - fn find_aggregate<'a>( - plan: &'a crate::planner::LogicalPlan, - ) -> &'a crate::planner::LogicalPlan { + fn find_aggregate(plan: &crate::planner::LogicalPlan) -> &crate::planner::LogicalPlan { if matches!(plan.operator, Operator::Aggregate(_)) { 
return plan; } @@ -110,9 +95,9 @@ mod tests { } } - fn find_aggregate_mut<'a>( - plan: &'a mut crate::planner::LogicalPlan, - ) -> &'a mut crate::planner::LogicalPlan { + fn find_aggregate_mut( + plan: &mut crate::planner::LogicalPlan, + ) -> &mut crate::planner::LogicalPlan { if matches!(plan.operator, Operator::Aggregate(_)) { return plan; } diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index f5276cbb..692c23d6 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -13,20 +13,16 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; use crate::expression::{AliasType, ScalarExpression}; -use crate::optimizer::core::pattern::Pattern; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::rule::normalization::column_pruning::ColumnPruning; use crate::optimizer::rule::normalization::combine_operators::{ CollapseGroupByAgg, CollapseProject, CombineFilter, }; -use crate::optimizer::rule::normalization::compilation_in_advance::{ - BindExpressionPosition, EvaluatorBind, -}; +use crate::optimizer::rule::normalization::compilation_in_advance::EvaluatorBind; +use crate::planner::operator::Operator; -use crate::optimizer::rule::normalization::agg_elimination::{ - EliminateRedundantSort, UseStreamDistinct, -}; use crate::optimizer::rule::normalization::min_max_top_k::MinMaxToTopK; use crate::optimizer::rule::normalization::pushdown_limit::{ LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, @@ -47,7 +43,11 @@ mod pushdown_limit; mod pushdown_predicates; mod simplification; mod top_k; -pub use agg_elimination::{annotate_sort_preserving_indexes, annotate_stream_distinct_indexes}; +pub(crate) use agg_elimination::{ + apply_annotated_post_rules, apply_scan_order_hint, OrderHintKind, ScanOrderHint, +}; 
+pub(crate) use compilation_in_advance::evaluator_bind_current; +pub(crate) use simplification::constant_calculation_current; #[derive(Debug, Copy, Clone)] pub enum NormalizationRuleImpl { @@ -69,35 +69,107 @@ pub enum NormalizationRuleImpl { SimplifyFilter, ConstantCalculation, // CompilationInAdvance - BindExpressionPosition, EvaluatorBind, MinMaxToTopK, TopK, - EliminateRedundantSort, - UseStreamDistinct, } -impl MatchPattern for NormalizationRuleImpl { - fn pattern(&self) -> &Pattern { +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum WholeTreePassKind { + ColumnPruning, + ExpressionRewrite, +} + +#[repr(usize)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum NormalizationRuleRootTag { + Any = 0, + Aggregate, + Filter, + Join, + Limit, + Project, + SortLike, +} + +impl NormalizationRuleRootTag { + pub const COUNT: usize = Self::SortLike as usize + 1; + + pub fn from_operator(operator: &Operator) -> Option { + match operator { + Operator::Aggregate(_) => Some(Self::Aggregate), + Operator::Filter(_) => Some(Self::Filter), + Operator::Join(_) => Some(Self::Join), + Operator::Limit(_) => Some(Self::Limit), + Operator::Project(_) => Some(Self::Project), + Operator::Sort(_) | Operator::TopK(_) => Some(Self::SortLike), + Operator::Dummy + | Operator::TableScan(_) + | Operator::ScalarSubquery(_) + | Operator::Values(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::ChangeColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::FunctionScan(_) + | Operator::Update(_) + | Operator::Union(_) + | Operator::Except(_) => None, + } + } +} + +#[derive(Debug, Copy, 
Clone, Eq, PartialEq)] +pub enum NormalizationPassKind { + WholeTreePass(WholeTreePassKind), + LocalRewrite, +} + +impl NormalizationRuleImpl { + pub fn pass_kind(&self) -> NormalizationPassKind { + match self { + NormalizationRuleImpl::ColumnPruning => { + NormalizationPassKind::WholeTreePass(WholeTreePassKind::ColumnPruning) + } + NormalizationRuleImpl::ConstantCalculation | NormalizationRuleImpl::EvaluatorBind => { + NormalizationPassKind::WholeTreePass(WholeTreePassKind::ExpressionRewrite) + } + _ => NormalizationPassKind::LocalRewrite, + } + } + + pub fn root_tag(&self) -> NormalizationRuleRootTag { match self { - NormalizationRuleImpl::ColumnPruning => ColumnPruning.pattern(), - NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(), - NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(), - NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(), - NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), - NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(), - NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(), - NormalizationRuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.pattern(), - NormalizationRuleImpl::PushJoinPredicateIntoScan => PushJoinPredicateIntoScan.pattern(), - NormalizationRuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(), - NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.pattern(), - NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.pattern(), - NormalizationRuleImpl::BindExpressionPosition => BindExpressionPosition.pattern(), - NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.pattern(), - NormalizationRuleImpl::MinMaxToTopK => MinMaxToTopK.pattern(), - NormalizationRuleImpl::TopK => TopK.pattern(), - NormalizationRuleImpl::EliminateRedundantSort => EliminateRedundantSort.pattern(), - NormalizationRuleImpl::UseStreamDistinct => UseStreamDistinct.pattern(), + 
NormalizationRuleImpl::ColumnPruning => NormalizationRuleRootTag::Any, + NormalizationRuleImpl::CollapseProject => NormalizationRuleRootTag::Project, + NormalizationRuleImpl::CollapseGroupByAgg => NormalizationRuleRootTag::Aggregate, + NormalizationRuleImpl::CombineFilter => NormalizationRuleRootTag::Filter, + NormalizationRuleImpl::LimitProjectTranspose + | NormalizationRuleImpl::PushLimitThroughJoin + | NormalizationRuleImpl::PushLimitIntoTableScan + | NormalizationRuleImpl::TopK => NormalizationRuleRootTag::Limit, + NormalizationRuleImpl::PushPredicateThroughJoin + | NormalizationRuleImpl::PushPredicateIntoScan + | NormalizationRuleImpl::SimplifyFilter => NormalizationRuleRootTag::Filter, + NormalizationRuleImpl::PushJoinPredicateIntoScan => NormalizationRuleRootTag::Join, + NormalizationRuleImpl::ConstantCalculation => NormalizationRuleRootTag::Any, + NormalizationRuleImpl::EvaluatorBind => NormalizationRuleRootTag::Any, + NormalizationRuleImpl::MinMaxToTopK => NormalizationRuleRootTag::Aggregate, } } } @@ -119,12 +191,9 @@ impl NormalizationRule for NormalizationRuleImpl { NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.apply(plan), NormalizationRuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(plan), NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(plan), - NormalizationRuleImpl::BindExpressionPosition => BindExpressionPosition.apply(plan), NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(plan), NormalizationRuleImpl::MinMaxToTopK => MinMaxToTopK.apply(plan), NormalizationRuleImpl::TopK => TopK.apply(plan), - NormalizationRuleImpl::EliminateRedundantSort => EliminateRedundantSort.apply(plan), - NormalizationRuleImpl::UseStreamDistinct => UseStreamDistinct.apply(plan), } } } @@ -156,13 +225,65 @@ pub fn is_subset_exprs(left: &[ScalarExpression], right: &[ScalarExpression]) -> let lhs_stripped = strip_alias(lhs); right.iter().any(|rhs| { let rhs_stripped = strip_alias(rhs); - if lhs_stripped == rhs_stripped 
{ + if lhs_stripped.eq_ignore_colref_pos(rhs_stripped) { return true; } - if matches!(lhs, ScalarExpression::ColumnRef { .. }) { - return lhs_stripped == strip_all_alias(rhs); + if matches!(lhs_stripped, ScalarExpression::ColumnRef { .. }) { + return lhs_stripped.eq_ignore_colref_pos(strip_all_alias(rhs)); } false }) }) } + +pub(crate) fn remap_position(position: &mut usize, removed_positions: &[usize]) { + match removed_positions.binary_search(position) { + Ok(_) => { + debug_assert!( + false, + "encountered a reference to pruned output slot {position}" + ); + } + Err(shift) => { + *position -= shift; + } + } +} + +struct PositionRemapper<'a> { + removed_positions: &'a [usize], +} + +impl<'a> VisitorMut<'a> for PositionRemapper<'_> { + fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { + match expr { + ScalarExpression::ColumnRef { position, .. } => { + remap_position(position, self.removed_positions); + Ok(()) + } + ScalarExpression::Alias { expr, alias } => match alias { + AliasType::Expr(alias_expr) => self.visit(alias_expr), + AliasType::Name(_) => self.visit(expr), + }, + _ => walk_mut_expr(self, expr), + } + } +} + +pub(crate) fn remap_expr_positions( + expr: &mut ScalarExpression, + removed_positions: &[usize], +) -> Result<(), DatabaseError> { + PositionRemapper { removed_positions }.visit(expr) +} + +pub(crate) fn remap_exprs_positions<'a>( + exprs: impl IntoIterator, + removed_positions: &[usize], +) -> Result<(), DatabaseError> { + let mut remapper = PositionRemapper { removed_positions }; + for expr in exprs { + remapper.visit(expr)?; + } + Ok(()) +} diff --git a/src/optimizer/rule/normalization/pushdown_limit.rs b/src/optimizer/rule/normalization/pushdown_limit.rs index 145283ac..7b5507b1 100644 --- a/src/optimizer/rule/normalization/pushdown_limit.rs +++ b/src/optimizer/rule/normalization/pushdown_limit.rs @@ -13,47 +13,14 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::pattern::Pattern; -use crate::optimizer::core::pattern::PatternChildrenPredicate; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child, wrap_child_with}; use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; -use std::sync::LazyLock; - -static LIMIT_PROJECT_TRANSPOSE_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Limit(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Project(_)), - children: PatternChildrenPredicate::None, - }]), -}); - -static PUSH_LIMIT_THROUGH_JOIN_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Limit(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Join(_)), - children: PatternChildrenPredicate::None, - }]), -}); - -static PUSH_LIMIT_INTO_TABLE_SCAN_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Limit(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::TableScan(_)), - children: PatternChildrenPredicate::None, - }]), -}); pub struct LimitProjectTranspose; -impl MatchPattern for LimitProjectTranspose { - fn pattern(&self) -> &Pattern { - &LIMIT_PROJECT_TRANSPOSE_RULE - } -} - impl NormalizationRule for LimitProjectTranspose { fn apply(&self, plan: &mut LogicalPlan) -> Result { let operator = std::mem::replace(&mut plan.operator, Operator::Dummy); @@ -95,12 +62,6 @@ impl NormalizationRule for LimitProjectTranspose { /// TODO: if join condition is empty. 
pub struct PushLimitThroughJoin; -impl MatchPattern for PushLimitThroughJoin { - fn pattern(&self) -> &Pattern { - &PUSH_LIMIT_THROUGH_JOIN_RULE - } -} - impl NormalizationRule for PushLimitThroughJoin { fn apply(&self, plan: &mut LogicalPlan) -> Result { let limit_op = match &plan.operator { @@ -131,12 +92,6 @@ impl NormalizationRule for PushLimitThroughJoin { /// Push down `Limit` past a `Scan`. pub struct PushLimitIntoScan; -impl MatchPattern for PushLimitIntoScan { - fn pattern(&self) -> &Pattern { - &PUSH_LIMIT_INTO_TABLE_SCAN_RULE - } -} - impl NormalizationRule for PushLimitIntoScan { fn apply(&self, plan: &mut LogicalPlan) -> Result { let (offset, limit) = match &plan.operator { diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index 94a12d34..aa9e7b09 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -15,55 +15,23 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::range_detacher::{Range, RangeDetacher}; +use crate::expression::visitor_mut::{PositionShift, VisitorMut}; use crate::expression::{BinaryOperator, ScalarExpression}; -use crate::optimizer::core::pattern::Pattern; -use crate::optimizer::core::pattern::PatternChildrenPredicate; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{ - left_child, only_child_mut, replace_with_only_child, right_child, wrap_child_with, + left_child, replace_with_only_child, right_child, wrap_child_with, }; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::{JoinCondition, JoinType}; use crate::planner::operator::{Operator, SortOption}; -use crate::planner::{LogicalPlan, SchemaOutput}; +use crate::planner::{Childrens, LogicalPlan, SchemaOutput}; use crate::types::index::{IndexInfo, 
IndexMetaRef, IndexType}; use crate::types::value::DataValue; use crate::types::LogicalType; use itertools::Itertools; use std::ops::Bound; -use std::sync::LazyLock; use std::{mem, slice}; -static PUSH_PREDICATE_THROUGH_JOIN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Join(_)), - children: PatternChildrenPredicate::None, - }]), -}); - -static PUSH_PREDICATE_INTO_SCAN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::TableScan(_)), - children: PatternChildrenPredicate::None, - }]), -}); - -static JOIN_WITH_FILTER_PATTERN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Join(_)), - children: PatternChildrenPredicate::None, -}); - -#[allow(dead_code)] -static PUSH_PREDICATE_THROUGH_NON_JOIN: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Project(_)), - children: PatternChildrenPredicate::None, - }]), -}); - fn split_conjunctive_predicates(expr: &ScalarExpression) -> Vec { match expr { ScalarExpression::Binary { @@ -98,13 +66,6 @@ fn reduce_filters(filters: Vec, having: bool) -> Option bool { - left.iter().all(|l| right.contains(l)) -} - fn plan_output_columns(plan: &LogicalPlan) -> Vec { match plan.output_schema_direct() { SchemaOutput::Schema(schema) => schema, @@ -122,26 +83,25 @@ fn plan_output_columns(plan: &LogicalPlan) -> Vec { /// attributes of the left or right side of sub query when applicable. 
pub struct PushPredicateThroughJoin; -impl MatchPattern for PushPredicateThroughJoin { - fn pattern(&self) -> &Pattern { - &PUSH_PREDICATE_THROUGH_JOIN - } -} - impl NormalizationRule for PushPredicateThroughJoin { fn apply(&self, plan: &mut LogicalPlan) -> Result { - let filter_op = match &plan.operator { - Operator::Filter(op) => op.clone(), - _ => return Ok(false), - }; - let mut applied = false; let parent_replacement = { - let join_plan = match only_child_mut(plan) { - Some(child) => child, - None => return Ok(false), + let LogicalPlan { + operator, + childrens, + .. + } = plan; + let filter_op = match operator { + Operator::Filter(op) => op, + _ => return Ok(false), + }; + + let Childrens::Only(join_plan) = childrens.as_mut() else { + return Ok(false); }; + let join_plan = join_plan.as_mut(); let join_op = match &join_plan.operator { Operator::Join(op) => op, @@ -167,12 +127,13 @@ impl NormalizationRule for PushPredicateThroughJoin { .unwrap_or_default(); let filter_exprs = split_conjunctive_predicates(&filter_op.predicate); - let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs - .into_iter() - .partition(|f| is_subset_cols(&f.referenced_columns(true), &left_columns)); - let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest - .into_iter() - .partition(|f| is_subset_cols(&f.referenced_columns(true), &right_columns)); + let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs.into_iter().partition(|f| { + f.all_referenced_columns(true, |column| left_columns.contains(column)) + }); + let (right_filters, common_filters): (Vec<_>, Vec<_>) = + rest.into_iter().partition(|f| { + f.all_referenced_columns(true, |column| right_columns.contains(column)) + }); let mut new_ops = (None, None, None); let replace_filters = match join_op.join_type { @@ -239,92 +200,97 @@ impl NormalizationRule for PushPredicateThroughJoin { pub struct PushPredicateIntoScan; -impl MatchPattern for PushPredicateIntoScan { - fn pattern(&self) -> &Pattern { - 
&PUSH_PREDICATE_INTO_SCAN - } -} - impl NormalizationRule for PushPredicateIntoScan { fn apply(&self, plan: &mut LogicalPlan) -> Result { - if let Operator::Filter(op) = plan.operator.clone() { - if let Some(child) = only_child_mut(plan) { - if let Operator::TableScan(scan_op) = &mut child.operator { - let mut changed = false; - for IndexInfo { - meta, - range, - covered_deserializers, - cover_mapping, - sort_option, - sort_elimination_hint: _, - stream_distinct_hint: _, - } in &mut scan_op.index_infos - { - if range.is_some() { - continue; - } - let SortOption::OrderBy { - ignore_prefix_len, .. - } = sort_option - else { - return Err(DatabaseError::InvalidIndex); - }; - *range = match meta.ty { - IndexType::PrimaryKey { is_multiple: false } - | IndexType::Unique - | IndexType::Normal => { - RangeDetacher::new(meta.table_name.as_ref(), &meta.column_ids[0]) - .detach(&op.predicate)? - } - IndexType::PrimaryKey { is_multiple: true } | IndexType::Composite => { - Self::composite_range(&op, meta, ignore_prefix_len)? - } - }; - if range.is_none() { - continue; - } - changed = true; - - *covered_deserializers = None; - *cover_mapping = None; - - // try index covered - let mut mapping_slots = vec![usize::MAX; scan_op.columns.len()]; - let mut needs_mapping = false; - let index_column_types = match &meta.value_ty { - LogicalType::Tuple(tys) => tys, - ty => slice::from_ref(ty), - }; - let mut deserializers = Vec::with_capacity(meta.column_ids.len()); - - for (idx, column_id) in meta.column_ids.iter().enumerate() { - if let Some((scan_idx, column)) = - scan_op.columns.values().enumerate().find(|(_, column)| { - column.id().map(|id| id == *column_id).unwrap_or(false) - }) - { - mapping_slots[scan_idx] = idx; - needs_mapping |= scan_idx != idx; - deserializers.push(column.datatype().serializable()); - } else { - deserializers.push(index_column_types[idx].skip_serializable()); - } - } + let LogicalPlan { + operator, + childrens, + .. 
+ } = plan; + let filter_op = match operator { + Operator::Filter(op) => op, + _ => return Ok(false), + }; + let Childrens::Only(child) = childrens.as_mut() else { + return Ok(false); + }; + let child = child.as_mut(); + let Operator::TableScan(scan_op) = &mut child.operator else { + return Ok(false); + }; - if mapping_slots.iter().all(|slot| *slot != usize::MAX) { - *covered_deserializers = Some(deserializers); - if needs_mapping { - *cover_mapping = Some(mapping_slots); - } - } - } - return Ok(changed); + let mut changed = false; + for IndexInfo { + meta, + range, + covered_deserializers, + cover_mapping, + sort_option, + sort_elimination_hint: _, + stream_distinct_hint: _, + } in &mut scan_op.index_infos + { + if range.is_some() { + continue; + } + let SortOption::OrderBy { + ignore_prefix_len, .. + } = sort_option + else { + return Err(DatabaseError::InvalidIndex); + }; + *range = match meta.ty { + IndexType::PrimaryKey { is_multiple: false } + | IndexType::Unique + | IndexType::Normal => { + RangeDetacher::new(meta.table_name.as_ref(), &meta.column_ids[0]) + .detach(&filter_op.predicate)? + } + IndexType::PrimaryKey { is_multiple: true } | IndexType::Composite => { + Self::composite_range(filter_op, meta, ignore_prefix_len)? 
+ } + }; + if range.is_none() { + continue; + } + changed = true; + + *covered_deserializers = None; + *cover_mapping = None; + + // try index covered + let mut mapping_slots = vec![usize::MAX; scan_op.columns.len()]; + let mut needs_mapping = false; + let index_column_types = match &meta.value_ty { + LogicalType::Tuple(tys) => tys, + ty => slice::from_ref(ty), + }; + let mut deserializers = Vec::with_capacity(meta.column_ids.len()); + + for (idx, column_id) in meta.column_ids.iter().enumerate() { + if let Some((scan_idx, column)) = scan_op + .columns + .values() + .enumerate() + .find(|(_, column)| column.id().map(|id| id == *column_id).unwrap_or(false)) + { + mapping_slots[scan_idx] = idx; + needs_mapping |= scan_idx != idx; + deserializers.push(column.datatype().serializable()); + } else { + deserializers.push(index_column_types[idx].skip_serializable()); + } + } + + if mapping_slots.iter().all(|slot| *slot != usize::MAX) { + *covered_deserializers = Some(deserializers); + if needs_mapping { + *cover_mapping = Some(mapping_slots); } } } - Ok(false) + Ok(changed) } } @@ -388,12 +354,6 @@ impl PushPredicateIntoScan { pub struct PushJoinPredicateIntoScan; -impl MatchPattern for PushJoinPredicateIntoScan { - fn pattern(&self) -> &Pattern { - &JOIN_WITH_FILTER_PATTERN - } -} - impl NormalizationRule for PushJoinPredicateIntoScan { fn apply(&self, plan: &mut LogicalPlan) -> Result { let (join_type, filter_expr) = { @@ -427,12 +387,13 @@ impl NormalizationRule for PushJoinPredicateIntoScan { .unwrap_or_default(); let filter_exprs = split_conjunctive_predicates(&filter_expr); - let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs - .into_iter() - .partition(|expr| is_subset_cols(&expr.referenced_columns(true), &left_columns)); - let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest - .into_iter() - .partition(|expr| is_subset_cols(&expr.referenced_columns(true), &right_columns)); + let (left_filters, rest): (Vec<_>, Vec<_>) = 
filter_exprs.into_iter().partition(|expr| { + expr.all_referenced_columns(true, |column| left_columns.contains(column)) + }); + let (right_filters, common_filters): (Vec<_>, Vec<_>) = + rest.into_iter().partition(|expr| { + expr.all_referenced_columns(true, |column| right_columns.contains(column)) + }); let (push_left, push_right) = match join_type { JoinType::Inner => (true, true), @@ -457,11 +418,19 @@ impl NormalizationRule for PushJoinPredicateIntoScan { remaining_filters.extend(left_remain); } - let (right_push, right_remain) = if push_right { + let (mut right_push, right_remain) = if push_right { (right_filters, Vec::new()) } else { (Vec::new(), right_filters) }; + if !right_push.is_empty() { + let mut localizer = PositionShift { + delta: -(left_columns.len() as isize), + }; + for expr in &mut right_push { + localizer.visit(expr)?; + } + } if let Some(filter_op) = reduce_filters(right_push, false) { new_ops.1 = Some(Operator::Filter(filter_op)); } else { @@ -687,14 +656,14 @@ mod tests { let c1_gt = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(c1_ref.clone())), + left_expr: Box::new(ScalarExpression::column_expr(c1_ref.clone(), 0)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(0))), evaluator: None, ty: LogicalType::Boolean, }; let c2_gt = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(c2_ref.clone())), + left_expr: Box::new(ScalarExpression::column_expr(c2_ref.clone(), 1)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(0))), evaluator: None, ty: LogicalType::Boolean, diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index ccd0560b..fe8ad75e 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -15,65 +15,56 @@ use crate::errors::DatabaseError; use 
crate::expression::simplify::{ConstantCalculator, Simplify}; use crate::expression::visitor_mut::VisitorMut; -use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; -use std::sync::LazyLock; - -static CONSTANT_CALCULATION_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |_| true, - children: PatternChildrenPredicate::None, -}); - -static SIMPLIFY_FILTER_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| !matches!(op, Operator::Aggregate(_)), - children: PatternChildrenPredicate::Recursive, - }]), -}); #[derive(Copy, Clone)] pub struct ConstantCalculation; -impl ConstantCalculation { - fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { - let operator = &mut plan.operator; +pub(crate) fn constant_calculation_current(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + let operator = &mut plan.operator; - match operator { - Operator::Aggregate(op) => { - for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { - ConstantCalculator.visit(expr)?; - } - } - Operator::Filter(op) => { - ConstantCalculator.visit(&mut op.predicate)?; + match operator { + Operator::Aggregate(op) => { + for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { + ConstantCalculator.visit(expr)?; } - Operator::Join(op) => { - if let JoinCondition::On { on, filter } = &mut op.on { - for (left_expr, right_expr) in on { - ConstantCalculator.visit(left_expr)?; - ConstantCalculator.visit(right_expr)?; - } - if let Some(expr) = filter { - ConstantCalculator.visit(expr)?; - } + } + Operator::Filter(op) => { + ConstantCalculator.visit(&mut 
op.predicate)?; + } + Operator::Join(op) => { + if let JoinCondition::On { on, filter } = &mut op.on { + for (left_expr, right_expr) in on { + ConstantCalculator.visit(left_expr)?; + ConstantCalculator.visit(right_expr)?; } - } - Operator::Project(op) => { - for expr in &mut op.exprs { + if let Some(expr) = filter { ConstantCalculator.visit(expr)?; } } - Operator::Sort(op) => { - for field in &mut op.sort_fields { - ConstantCalculator.visit(&mut field.expr)?; - } + } + Operator::Project(op) => { + for expr in &mut op.exprs { + ConstantCalculator.visit(expr)?; + } + } + Operator::Sort(op) => { + for field in &mut op.sort_fields { + ConstantCalculator.visit(&mut field.expr)?; } - _ => (), } + _ => (), + } + + Ok(()) +} + +impl ConstantCalculation { + fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + constant_calculation_current(plan)?; match plan.childrens.as_mut() { Childrens::Only(child) => Self::_apply(child.as_mut())?, Childrens::Twins { left, right } => { @@ -87,12 +78,6 @@ impl ConstantCalculation { } } -impl MatchPattern for ConstantCalculation { - fn pattern(&self) -> &Pattern { - &CONSTANT_CALCULATION_RULE - } -} - impl NormalizationRule for ConstantCalculation { fn apply(&self, plan: &mut LogicalPlan) -> Result { Self::_apply(plan)?; @@ -103,9 +88,17 @@ impl NormalizationRule for ConstantCalculation { #[derive(Copy, Clone)] pub struct SimplifyFilter; -impl MatchPattern for SimplifyFilter { - fn pattern(&self) -> &Pattern { - &SIMPLIFY_FILTER_RULE +fn has_aggregate_descendant(plan: &LogicalPlan) -> bool { + if matches!(plan.operator, Operator::Aggregate(_)) { + return true; + } + + match plan.childrens.as_ref() { + Childrens::Only(child) => has_aggregate_descendant(child), + Childrens::Twins { left, right } => { + has_aggregate_descendant(left) || has_aggregate_descendant(right) + } + Childrens::None => false, } } @@ -115,6 +108,11 @@ impl NormalizationRule for SimplifyFilter { if filter_op.is_optimized { return Ok(false); } + if let 
Some(child) = plan.childrens.iter().next() { + if has_aggregate_descendant(child) { + return Ok(false); + } + } ConstantCalculator.visit(&mut filter_op.predicate)?; Simplify::default().visit(&mut filter_op.predicate)?; filter_op.is_optimized = true; @@ -358,9 +356,10 @@ mod test { op: UnaryOperator::Minus, expr: Box::new(ScalarExpression::Binary { op: BinaryOperator::Plus, - left_expr: Box::new(ScalarExpression::column_expr(ColumnRef::from( - c1_col - ))), + left_expr: Box::new(ScalarExpression::column_expr( + ColumnRef::from(c1_col), + 0 + )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -368,7 +367,9 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new(ScalarExpression::column_expr(ColumnRef::from(c2_col))), + right_expr: Box::new( + ScalarExpression::column_expr(ColumnRef::from(c2_col), 1,) + ), evaluator: None, ty: LogicalType::Boolean, } diff --git a/src/optimizer/rule/normalization/top_k.rs b/src/optimizer/rule/normalization/top_k.rs index 4775024a..193ec959 100644 --- a/src/optimizer/rule/normalization/top_k.rs +++ b/src/optimizer/rule/normalization/top_k.rs @@ -13,31 +13,14 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::optimizer::core::pattern::Pattern; -use crate::optimizer::core::pattern::PatternChildrenPredicate; -use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; use crate::planner::operator::top_k::TopKOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; -use std::sync::LazyLock; - -static TOP_K_RULE: LazyLock = LazyLock::new(|| Pattern { - predicate: |op| matches!(op, Operator::Limit(_)), - children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Sort(_)), - children: PatternChildrenPredicate::None, - }]), -}); pub struct TopK; -impl MatchPattern for TopK { - fn pattern(&self) -> &Pattern { - &TOP_K_RULE - } -} - impl NormalizationRule for TopK { fn apply(&self, plan: &mut LogicalPlan) -> Result { let (offset, limit) = match &plan.operator { diff --git a/src/orm/dql.rs b/src/orm/dql.rs index e96d8235..4edf5b87 100644 --- a/src/orm/dql.rs +++ b/src/orm/dql.rs @@ -152,7 +152,9 @@ impl<'a, S: Storage> DBTransaction<'a, S> { } /// Fetches all rows for a model inside the current transaction. - pub fn fetch(&mut self) -> Result, M>, DatabaseError> { + pub fn fetch( + &mut self, + ) -> Result>, M>, DatabaseError> { orm_list::<_, M>(self) } @@ -164,7 +166,8 @@ impl<'a, S: Storage> DBTransaction<'a, S> { /// Lists all table names inside the current transaction. pub fn show_tables( &mut self, - ) -> Result, String>, DatabaseError> { + ) -> Result>, String>, DatabaseError> + { Ok(ProjectValueIter::new( self.execute(&orm_show_tables_statement(), &[])?, )) @@ -173,7 +176,8 @@ impl<'a, S: Storage> DBTransaction<'a, S> { /// Lists all view names inside the current transaction. 
pub fn show_views( &mut self, - ) -> Result, String>, DatabaseError> { + ) -> Result>, String>, DatabaseError> + { Ok(ProjectValueIter::new( self.execute(&orm_show_views_statement(), &[])?, )) @@ -182,7 +186,8 @@ impl<'a, S: Storage> DBTransaction<'a, S> { /// Describes the schema of the model table inside the current transaction. pub fn describe( &mut self, - ) -> Result, DescribeColumn>, DatabaseError> { + ) -> Result>, DescribeColumn>, DatabaseError> + { Ok(self .execute(&orm_describe_statement(M::table_name()), &[])? .orm::()) diff --git a/src/orm/mod.rs b/src/orm/mod.rs index 904507af..14ee4095 100644 --- a/src/orm/mod.rs +++ b/src/orm/mod.rs @@ -2,7 +2,8 @@ use crate::catalog::{ColumnRef, TableCatalog}; use crate::db::{ - DBTransaction, Database, DatabaseIter, OrmIter, ResultIter, Statement, TransactionIter, + BorrowResultIter, DBTransaction, Database, DatabaseIter, OrmIter, ResultIter, Statement, + TransactionIter, }; use crate::errors::DatabaseError; use crate::storage::{Storage, Transaction}; @@ -303,10 +304,12 @@ trait ValueExpressionOps: Sized { }) } + #[allow(clippy::wrong_self_convention)] fn is_null_expr(self) -> QueryExpr { QueryExpr::from_expr(Expr::IsNull(Box::new(self.into_query_value().into_expr()))) } + #[allow(clippy::wrong_self_convention)] fn is_not_null_expr(self) -> QueryExpr { QueryExpr::from_expr(Expr::IsNotNull(Box::new( self.into_query_value().into_expr(), @@ -466,21 +469,25 @@ impl Field { } /// Builds `field + value`. + #[allow(clippy::should_implement_trait)] pub fn add>(self, value: V) -> QueryValue { ValueExpressionOps::add_expr(self, value) } /// Builds `field - value`. + #[allow(clippy::should_implement_trait)] pub fn sub>(self, value: V) -> QueryValue { ValueExpressionOps::sub_expr(self, value) } /// Builds `field * value`. + #[allow(clippy::should_implement_trait)] pub fn mul>(self, value: V) -> QueryValue { ValueExpressionOps::mul_expr(self, value) } /// Builds `field / value`. 
+ #[allow(clippy::should_implement_trait)] pub fn div>(self, value: V) -> QueryValue { ValueExpressionOps::div_expr(self, value) } @@ -491,6 +498,7 @@ impl Field { } /// Builds unary `-field`. + #[allow(clippy::should_implement_trait)] pub fn neg(self) -> QueryValue { ValueExpressionOps::neg_expr(self) } @@ -817,6 +825,7 @@ impl QueryExpr { /// let users = database.from::().filter(expr).fetch()?; /// # Ok::<(), kite_sql::errors::DatabaseError>(()) /// ``` + #[allow(clippy::should_implement_trait)] pub fn not(self) -> QueryExpr { QueryExpr::from_expr(Expr::UnaryOp { op: sqlparser::ast::UnaryOperator::Not, @@ -1040,21 +1049,25 @@ impl QueryValue { } /// Builds `expr + value`. + #[allow(clippy::should_implement_trait)] pub fn add>(self, value: V) -> QueryValue { ValueExpressionOps::add_expr(self, value) } /// Builds `expr - value`. + #[allow(clippy::should_implement_trait)] pub fn sub>(self, value: V) -> QueryValue { ValueExpressionOps::sub_expr(self, value) } /// Builds `expr * value`. + #[allow(clippy::should_implement_trait)] pub fn mul>(self, value: V) -> QueryValue { ValueExpressionOps::mul_expr(self, value) } /// Builds `expr / value`. + #[allow(clippy::should_implement_trait)] pub fn div>(self, value: V) -> QueryValue { ValueExpressionOps::div_expr(self, value) } @@ -1065,6 +1078,7 @@ impl QueryValue { } /// Builds unary `-expr`. 
+ #[allow(clippy::should_implement_trait)] pub fn neg(self) -> QueryValue { ValueExpressionOps::neg_expr(self) } @@ -1365,7 +1379,7 @@ impl<'a, S: Storage> StatementSource for &'a Database { } impl<'a, 'tx, S: Storage> StatementSource for &'a mut DBTransaction<'tx, S> { - type Iter = TransactionIter<'a>; + type Iter = TransactionIter<'a, S::TransactionType<'tx>>; fn execute_statement>( self, @@ -1459,6 +1473,7 @@ struct UpdateAssignment { value: UpdateAssignmentValue, } +#[allow(clippy::large_enum_variant)] enum UpdateAssignmentValue { Param(DataValue), Expr(QueryValue), @@ -3506,10 +3521,12 @@ impl> QueryBuilder { self.push_filter(left.into().lte(right), FilterMode::Replace) } + #[allow(clippy::wrong_self_convention)] fn is_null>(self, value: V) -> Self { self.push_filter(value.into().is_null(), FilterMode::Replace) } + #[allow(clippy::wrong_self_convention)] fn is_not_null>(self, value: V) -> Self { self.push_filter(value.into().is_not_null(), FilterMode::Replace) } @@ -3647,8 +3664,7 @@ impl> QueryBuilder { Some(DataValue::UInt64(value)) => *value as usize, other => { return Err(DatabaseError::InvalidValue(format!( - "unexpected count result: {:?}", - other + "unexpected count result: {other:?}" ))) } }, @@ -3942,6 +3958,7 @@ fn model_insert_columns() -> Vec { .collect() } +#[allow(clippy::too_many_arguments)] fn select_query( source: &QuerySource, joins: Vec, diff --git a/src/planner/mod.rs b/src/planner/mod.rs index c474187f..0c0ec3b0 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -161,9 +161,11 @@ impl LogicalPlan { mut childrens_iter: ChildrensIter, ) -> SchemaOutput { match operator { - Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) => { - childrens_iter.next().unwrap().output_schema_direct() - } + Operator::Filter(_) + | Operator::Sort(_) + | Operator::Limit(_) + | Operator::TopK(_) + | Operator::ScalarSubquery(_) => childrens_iter.next().unwrap().output_schema_direct(), Operator::Aggregate(op) => 
SchemaOutput::Schema( op.agg_calls .iter() diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 86afe50a..081b2232 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -32,6 +32,7 @@ pub mod insert; pub mod join; pub mod limit; pub mod project; +pub mod scalar_subquery; pub mod sort; pub mod table_scan; pub mod top_k; @@ -43,8 +44,8 @@ pub mod values; use self::{ aggregate::AggregateOperator, alter_table::add_column::AddColumnOperator, alter_table::change_column::ChangeColumnOperator, filter::FilterOperator, join::JoinOperator, - limit::LimitOperator, project::ProjectOperator, sort::SortOperator, - table_scan::TableScanOperator, + limit::LimitOperator, project::ProjectOperator, scalar_subquery::ScalarSubqueryOperator, + sort::SortOperator, table_scan::TableScanOperator, }; use crate::catalog::ColumnRef; use crate::expression::ScalarExpression; @@ -71,7 +72,6 @@ use crate::planner::operator::union::UnionOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::values::ValuesOperator; use crate::types::index::IndexInfo; -use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; use std::fmt::Formatter; @@ -84,6 +84,7 @@ pub enum Operator { Filter(FilterOperator), Join(JoinOperator), Project(ProjectOperator), + ScalarSubquery(ScalarSubqueryOperator), TableScan(TableScanOperator), FunctionScan(FunctionScanOperator), Sort(SortOperator), @@ -156,6 +157,7 @@ pub enum PlanImpl { HashJoin, NestLoopJoin, Project, + ScalarSubquery, SeqScan, FunctionScan, IndexScan(Box), @@ -179,25 +181,28 @@ pub enum PlanImpl { } impl Operator { - pub fn output_exprs(&self) -> Option> { + pub fn output_exprs(&self, output_exprs: &mut Vec) -> bool { match self { - Operator::Dummy => None, - Operator::Aggregate(op) => Some( - op.agg_calls - .iter() - .chain(op.groupby_exprs.iter()) - .cloned() - .collect_vec(), - ), - Operator::Filter(_) | Operator::Join(_) => None, - 
Operator::Project(op) => Some(op.exprs.clone()), - Operator::TableScan(op) => Some( - op.columns - .values() - .map(|column| ScalarExpression::column_expr(column.clone())) - .collect_vec(), - ), - Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) => None, + Operator::Dummy => false, + Operator::Aggregate(op) => { + output_exprs.clear(); + output_exprs.extend(op.agg_calls.iter().chain(op.groupby_exprs.iter()).cloned()); + true + } + Operator::Filter(_) | Operator::Join(_) | Operator::ScalarSubquery(_) => false, + Operator::Project(op) => { + output_exprs.clear(); + output_exprs.extend(op.exprs.iter().cloned()); + true + } + Operator::TableScan(op) => { + output_exprs.clear(); + output_exprs.extend(op.columns.values().enumerate().map(|(position, column)| { + ScalarExpression::column_expr(column.clone(), position) + })); + true + } + Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) => false, Operator::Values(ValuesOperator { schema_ref, .. }) | Operator::Union(UnionOperator { left_schema_ref: schema_ref, @@ -206,21 +211,31 @@ impl Operator { | Operator::Except(ExceptOperator { left_schema_ref: schema_ref, .. 
- }) => Some( - schema_ref - .iter() - .cloned() - .map(ScalarExpression::column_expr) - .collect_vec(), - ), - Operator::FunctionScan(op) => Some( - op.table_function - .inner - .output_schema() - .iter() - .map(|column| ScalarExpression::column_expr(column.clone())) - .collect_vec(), - ), + }) => { + output_exprs.clear(); + output_exprs.extend( + schema_ref + .iter() + .cloned() + .enumerate() + .map(|(position, column)| ScalarExpression::column_expr(column, position)), + ); + true + } + Operator::FunctionScan(op) => { + output_exprs.clear(); + output_exprs.extend( + op.table_function + .inner + .output_schema() + .iter() + .enumerate() + .map(|(position, column)| { + ScalarExpression::column_expr(column.clone(), position) + }), + ); + true + } Operator::ShowTable | Operator::ShowView | Operator::Explain @@ -240,59 +255,60 @@ impl Operator { | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => None, + | Operator::CopyToFile(_) => false, } } - pub fn referenced_columns(&self, only_column_ref: bool) -> Vec { + pub fn visit_referenced_columns( + &self, + only_column_ref: bool, + f: &mut impl FnMut(&ColumnRef) -> bool, + ) -> bool { match self { Operator::Aggregate(op) => op .agg_calls .iter() .chain(op.groupby_exprs.iter()) - .flat_map(|expr| expr.referenced_columns(only_column_ref)) - .collect_vec(), - Operator::Filter(op) => op.predicate.referenced_columns(only_column_ref), + .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + Operator::Filter(op) => op.predicate.visit_referenced_columns(only_column_ref, f), Operator::Join(op) => { - let mut exprs = Vec::new(); - if let JoinCondition::On { on, filter } = &op.on { for (left_expr, right_expr) in on { - exprs.append(&mut left_expr.referenced_columns(only_column_ref)); - exprs.append(&mut right_expr.referenced_columns(only_column_ref)); + if !left_expr.visit_referenced_columns(only_column_ref, f) + || 
!right_expr.visit_referenced_columns(only_column_ref, f) + { + return false; + } } if let Some(filter_expr) = filter { - exprs.append(&mut filter_expr.referenced_columns(only_column_ref)); + return filter_expr.visit_referenced_columns(only_column_ref, f); } } - exprs + true } Operator::Project(op) => op .exprs .iter() - .flat_map(|expr| expr.referenced_columns(only_column_ref)) - .collect_vec(), - Operator::TableScan(op) => op.columns.values().cloned().collect_vec(), + .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + Operator::ScalarSubquery(_) => true, + Operator::TableScan(op) => op.columns.values().all(f), Operator::FunctionScan(op) => op .table_function .args .iter() - .flat_map(|expr| expr.referenced_columns(only_column_ref)) - .collect_vec(), + .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), Operator::Sort(op) => op .sort_fields .iter() .map(|field| &field.expr) - .flat_map(|expr| expr.referenced_columns(only_column_ref)) - .collect_vec(), + .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), Operator::TopK(op) => op .sort_fields .iter() .map(|field| &field.expr) - .flat_map(|expr| expr.referenced_columns(only_column_ref)) - .collect_vec(), - Operator::Values(ValuesOperator { schema_ref, .. }) => Vec::clone(schema_ref), + .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + Operator::Values(ValuesOperator { schema_ref, .. 
}) => schema_ref.iter().all(f), Operator::Union(UnionOperator { left_schema_ref, _right_schema_ref, @@ -303,10 +319,9 @@ impl Operator { }) => left_schema_ref .iter() .chain(_right_schema_ref.iter()) - .cloned() - .collect_vec(), - Operator::Analyze(_) => vec![], - Operator::Delete(op) => op.primary_keys.clone(), + .all(f), + Operator::Analyze(_) => true, + Operator::Delete(op) => op.primary_keys.iter().all(f), Operator::Dummy | Operator::Limit(_) | Operator::ShowTable @@ -326,9 +341,35 @@ impl Operator { | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => vec![], + | Operator::CopyToFile(_) => true, } } + + pub fn any_referenced_column( + &self, + only_column_ref: bool, + mut predicate: impl FnMut(&ColumnRef) -> bool, + ) -> bool { + let mut found = false; + self.visit_referenced_columns(only_column_ref, &mut |column| { + found = predicate(column); + !found + }); + found + } + + pub fn all_referenced_columns( + &self, + only_column_ref: bool, + mut predicate: impl FnMut(&ColumnRef) -> bool, + ) -> bool { + let mut all = true; + self.visit_referenced_columns(only_column_ref, &mut |column| { + all = predicate(column); + all + }); + all + } } impl fmt::Display for Operator { @@ -339,6 +380,7 @@ impl fmt::Display for Operator { Operator::Filter(op) => write!(f, "{op}"), Operator::Join(op) => write!(f, "{op}"), Operator::Project(op) => write!(f, "{op}"), + Operator::ScalarSubquery(op) => write!(f, "{op}"), Operator::TableScan(op) => write!(f, "{op}"), Operator::FunctionScan(op) => write!(f, "{op}"), Operator::Sort(op) => write!(f, "{op}"), @@ -411,6 +453,7 @@ impl fmt::Display for PlanImpl { PlanImpl::HashJoin => write!(f, "HashJoin"), PlanImpl::NestLoopJoin => write!(f, "NestLoopJoin"), PlanImpl::Project => write!(f, "Project"), + PlanImpl::ScalarSubquery => write!(f, "ScalarSubquery"), PlanImpl::SeqScan => write!(f, "SeqScan"), PlanImpl::FunctionScan => write!(f, "FunctionScan"), PlanImpl::IndexScan(index) => 
write!(f, "IndexScan By {index}"), diff --git a/src/planner/operator/scalar_subquery.rs b/src/planner/operator/scalar_subquery.rs new file mode 100644 index 00000000..b3ffe3f7 --- /dev/null +++ b/src/planner/operator/scalar_subquery.rs @@ -0,0 +1,37 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::Operator; +use crate::planner::{Childrens, LogicalPlan}; +use kite_sql_serde_macros::ReferenceSerialization; +use std::fmt; +use std::fmt::Formatter; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] +pub struct ScalarSubqueryOperator; + +impl ScalarSubqueryOperator { + pub fn build(children: LogicalPlan) -> LogicalPlan { + LogicalPlan::new( + Operator::ScalarSubquery(ScalarSubqueryOperator), + Childrens::Only(Box::new(children)), + ) + } +} + +impl fmt::Display for ScalarSubqueryOperator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "ScalarSubquery") + } +} diff --git a/src/planner/operator/table_scan.rs b/src/planner/operator/table_scan.rs index aefd78f2..2536421f 100644 --- a/src/planner/operator/table_scan.rs +++ b/src/planner/operator/table_scan.rs @@ -68,7 +68,7 @@ impl TableScanOperator { DatabaseError::column_not_found(format!("index column id: {col_id} not found")) })?; sort_fields.push(SortField { - expr: ScalarExpression::column_expr(column.clone()), + expr: ScalarExpression::column_expr(column.clone(), sort_fields.len()), asc: true, nulls_first: false, }) diff 
--git a/src/python.rs b/src/python.rs index a9f83040..4c959bf6 100644 --- a/src/python.rs +++ b/src/python.rs @@ -14,9 +14,12 @@ #![cfg(all(not(target_arch = "wasm32"), feature = "python"))] -use crate::db::{DataBaseBuilder, Database, DatabaseIter, ResultIter}; +use crate::db::{DataBaseBuilder, Database, DatabaseIter}; use crate::errors::DatabaseError; +#[cfg(feature = "lmdb")] +use crate::storage::lmdb::LmdbStorage; use crate::storage::memory::MemoryStorage; +#[cfg(feature = "rocksdb")] use crate::storage::rocksdb::RocksStorage; use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::DataValue; @@ -61,7 +64,7 @@ fn data_value_to_py(py: Python<'_>, value: &DataValue) -> PyResult { Ok(object) } -fn tuple_to_python_row(py: Python<'_>, tuple: Tuple) -> PyResult { +fn tuple_to_python_row(py: Python<'_>, tuple: &Tuple) -> PyResult { let row = PyDict::new(py); match tuple.pk.as_ref() { @@ -93,13 +96,24 @@ fn schema_to_python(py: Python<'_>, schema: &SchemaRef) -> PyResult), Memory(Database), + #[cfg(feature = "rocksdb")] Rocks(Database), } impl PythonDatabaseInner { fn run(&self, sql: &str) -> Result { match self { + #[cfg(feature = "lmdb")] + PythonDatabaseInner::Lmdb(db) => { + let iter = db.run(sql)?; + // DatabaseIter owns state internally; only the type carries the lifetime. + let iter_static: DatabaseIter<'static, LmdbStorage> = + unsafe { std::mem::transmute(iter) }; + Ok(PythonResultIterInner::Lmdb(iter_static)) + } PythonDatabaseInner::Memory(db) => { let iter = db.run(sql)?; // DatabaseIter owns state internally; only the type carries the lifetime. @@ -107,6 +121,7 @@ impl PythonDatabaseInner { unsafe { std::mem::transmute(iter) }; Ok(PythonResultIterInner::Memory(iter_static)) } + #[cfg(feature = "rocksdb")] PythonDatabaseInner::Rocks(db) => { let iter = db.run(sql)?; // DatabaseIter owns state internally; only the type carries the lifetime. 
@@ -119,28 +134,40 @@ impl PythonDatabaseInner { } enum PythonResultIterInner { + #[cfg(feature = "lmdb")] + Lmdb(DatabaseIter<'static, LmdbStorage>), Memory(DatabaseIter<'static, MemoryStorage>), + #[cfg(feature = "rocksdb")] Rocks(DatabaseIter<'static, RocksStorage>), } impl PythonResultIterInner { - fn next_tuple(&mut self) -> Option> { + fn next_tuple(&mut self) -> Result, DatabaseError> { match self { - PythonResultIterInner::Memory(iter) => iter.next(), - PythonResultIterInner::Rocks(iter) => iter.next(), + #[cfg(feature = "lmdb")] + PythonResultIterInner::Lmdb(iter) => iter.next_borrowed_tuple(), + PythonResultIterInner::Memory(iter) => iter.next_borrowed_tuple(), + #[cfg(feature = "rocksdb")] + PythonResultIterInner::Rocks(iter) => iter.next_borrowed_tuple(), } } fn schema(&self) -> &SchemaRef { match self { + #[cfg(feature = "lmdb")] + PythonResultIterInner::Lmdb(iter) => iter.schema(), PythonResultIterInner::Memory(iter) => iter.schema(), + #[cfg(feature = "rocksdb")] PythonResultIterInner::Rocks(iter) => iter.schema(), } } fn done(self) -> Result<(), DatabaseError> { match self { + #[cfg(feature = "lmdb")] + PythonResultIterInner::Lmdb(iter) => iter.done(), PythonResultIterInner::Memory(iter) => iter.done(), + #[cfg(feature = "rocksdb")] PythonResultIterInner::Rocks(iter) => iter.done(), } } @@ -154,9 +181,34 @@ pub struct PythonDatabase { #[pymethods] impl PythonDatabase { #[new] - pub fn new(path: String) -> PyResult { - let inner = - PythonDatabaseInner::Rocks(DataBaseBuilder::path(path).build().map_err(to_py_err)?); + #[pyo3(signature = (path, backend=None))] + pub fn new(path: String, backend: Option<&str>) -> PyResult { + let backend = backend.unwrap_or("rocksdb").to_ascii_lowercase(); + let inner = match backend.as_str() { + #[cfg(feature = "rocksdb")] + "rocksdb" => PythonDatabaseInner::Rocks( + DataBaseBuilder::path(path) + .build_rocksdb() + .map_err(to_py_err)?, + ), + #[cfg(feature = "lmdb")] + "lmdb" => PythonDatabaseInner::Lmdb( + 
DataBaseBuilder::path(path) + .build_lmdb() + .map_err(to_py_err)?, + ), + other => { + let mut expected = Vec::new(); + #[cfg(feature = "rocksdb")] + expected.push("rocksdb"); + #[cfg(feature = "lmdb")] + expected.push("lmdb"); + return Err(PyValueError::new_err(format!( + "unsupported backend '{other}', expected {}", + expected.join(" or ") + ))); + } + }; Ok(PythonDatabase { inner }) } @@ -178,9 +230,7 @@ impl PythonDatabase { pub fn execute(&self, sql: &str) -> PyResult<()> { let mut iter = self.inner.run(sql).map_err(to_py_err)?; - while let Some(tuple) = iter.next_tuple() { - tuple.map_err(to_py_err)?; - } + while iter.next_tuple().map_err(to_py_err)?.is_some() {} iter.done().map_err(to_py_err)?; Ok(()) @@ -211,9 +261,8 @@ impl PythonResultIter { pub fn next(&mut self, py: Python<'_>) -> PyResult> { let iter = self.inner_mut()?; - match iter.next_tuple() { - Some(Ok(tuple)) => tuple_to_python_row(py, tuple).map(Some), - Some(Err(err)) => Err(to_py_err(err)), + match iter.next_tuple().map_err(to_py_err)? { + Some(tuple) => tuple_to_python_row(py, tuple).map(Some), None => Ok(None), } } @@ -230,8 +279,8 @@ impl PythonResultIter { .ok_or_else(|| PyValueError::new_err("iterator already consumed"))?; let mut rows = Vec::new(); - while let Some(tuple) = iter.next_tuple() { - rows.push(tuple_to_python_row(py, tuple.map_err(to_py_err)?)?); + while let Some(tuple) = iter.next_tuple().map_err(to_py_err)? 
{ + rows.push(tuple_to_python_row(py, tuple)?); } iter.done().map_err(to_py_err)?; @@ -281,12 +330,12 @@ mod tests { py: Python<'_>, module: &Bound<'_, PyModule>, script: &'static CStr, - use_memory: bool, + backend: &str, db_path: &str, ) -> PyResult<()> { let locals = PyDict::new(py); locals.set_item("kite_sql", module)?; - locals.set_item("use_memory", use_memory)?; + locals.set_item("backend", backend)?; locals.set_item("db_path", db_path)?; py.run(script, None, Some(&locals)) } @@ -296,12 +345,24 @@ mod tests { module: &Bound<'_, PyModule>, script: &'static CStr, ) -> PyResult<()> { - run_script(py, module, script, true, "")?; + run_script(py, module, script, "memory", "")?; + + #[cfg(feature = "rocksdb")] + { + let temp_dir = TempDir::new() + .map_err(|e| PyRuntimeError::new_err(format!("create tempdir: {e}")))?; + let path = temp_dir.path().to_string_lossy().to_string(); - let temp_dir = - TempDir::new().map_err(|e| PyRuntimeError::new_err(format!("create tempdir: {e}")))?; - let path = temp_dir.path().to_string_lossy().to_string(); - run_script(py, module, script, false, &path)?; + run_script(py, module, script, "rocksdb", &path)?; + } + + #[cfg(feature = "lmdb")] + { + let temp_dir = TempDir::new() + .map_err(|e| PyRuntimeError::new_err(format!("create tempdir: {e}")))?; + let path = temp_dir.path().to_string_lossy().to_string(); + run_script(py, module, script, "lmdb", &path)?; + } Ok(()) } @@ -315,7 +376,7 @@ mod tests { &module, c_str!( r#" -db = kite_sql.Database.in_memory() if use_memory else kite_sql.Database(db_path) +db = kite_sql.Database.in_memory() if backend == "memory" else kite_sql.Database(db_path, backend) db.execute("drop table if exists my_struct") db.execute("create table my_struct (c1 int primary key, c2 int)") db.execute("insert into my_struct values(0, 0), (1, 1)") @@ -361,7 +422,7 @@ db.execute("drop table my_struct") &module, c_str!( r#" -db = kite_sql.Database.in_memory() if use_memory else kite_sql.Database(db_path) +db = 
kite_sql.Database.in_memory() if backend == "memory" else kite_sql.Database(db_path, backend) db.execute("drop table if exists t1") db.execute("create table t1(id int primary key, c1 int, c2 int)") @@ -419,4 +480,26 @@ db.execute("drop table t1") Ok(()) }) } + + #[test] + fn test_python_rejects_unknown_backend() -> PyResult<()> { + Python::with_gil(|py| { + let module = register_module(py)?; + let locals = PyDict::new(py); + locals.set_item("kite_sql", module)?; + py.run( + c_str!( + r#" +try: + kite_sql.Database("/tmp/kitesql-python-invalid", "unknown") + raise AssertionError("expected constructor to reject unknown backend") +except ValueError as exc: + assert "unsupported backend" in str(exc) +"# + ), + None, + Some(&locals), + ) + }) + } } diff --git a/src/storage/lmdb.rs b/src/storage/lmdb.rs new file mode 100644 index 00000000..5b916537 --- /dev/null +++ b/src/storage/lmdb.rs @@ -0,0 +1,459 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use crate::errors::DatabaseError; +use crate::storage::table_codec::{Bytes, TableCodec}; +use crate::storage::{reuse_bound_as_excluded, InnerIter, KeyValueRef, Storage, Transaction}; +use lmdb::{ + Cursor, Database, DatabaseFlags, Environment, EnvironmentFlags, RoCursor, RwTransaction, + Transaction as _, WriteFlags, +}; +use std::cmp::Ordering; +use std::collections::Bound; +use std::fmt::{self, Display, Formatter}; +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; + +const DEFAULT_MAP_SIZE: usize = 16 * 1024 * 1024 * 1024; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LmdbConfig { + pub enable_statistics: bool, + pub map_size: usize, + pub flags: EnvironmentFlags, + pub max_readers: Option, + pub max_dbs: Option, +} + +impl Default for LmdbConfig { + fn default() -> Self { + Self { + enable_statistics: false, + map_size: DEFAULT_MAP_SIZE, + flags: EnvironmentFlags::empty(), + max_readers: None, + max_dbs: None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LmdbMetrics { + pub map_size: usize, + pub page_size: u32, + pub depth: u32, + pub branch_pages: usize, + pub leaf_pages: usize, + pub overflow_pages: usize, + pub entries: usize, +} + +impl Display for LmdbMetrics { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + writeln!(f, "")?; + writeln!( + f, + "map_size={} page_size={} depth={}", + self.map_size, self.page_size, self.depth + )?; + write!( + f, + "branch_pages={} leaf_pages={} overflow_pages={} entries={}", + self.branch_pages, self.leaf_pages, self.overflow_pages, self.entries + ) + } +} + +#[derive(Clone)] +pub struct LmdbStorage { + env: Arc, + db: Database, + config: LmdbConfig, +} + +impl LmdbStorage { + pub fn new(path: impl Into + Send) -> Result { + Self::with_config(path, LmdbConfig::default()) + } + + pub fn with_config( + path: impl Into + Send, + config: LmdbConfig, + ) -> Result { + let path = path.into(); + fs::create_dir_all(&path)?; + + let mut builder = Environment::new(); + 
builder.set_map_size(config.map_size); + builder.set_flags(config.flags); + if let Some(max_readers) = config.max_readers { + builder.set_max_readers(max_readers); + } + if let Some(max_dbs) = config.max_dbs { + builder.set_max_dbs(max_dbs); + } + let env = builder.open(&path).map_err(map_lmdb_err)?; + let db = env + .create_db(None, DatabaseFlags::empty()) + .map_err(map_lmdb_err)?; + + Ok(Self { + env: Arc::new(env), + db, + config, + }) + } +} + +impl Storage for LmdbStorage { + type Metrics = LmdbMetrics; + + type TransactionType<'a> + = LmdbTransaction<'a> + where + Self: 'a; + + fn transaction(&self) -> Result, DatabaseError> { + let tx = self.env.begin_rw_txn().map_err(map_lmdb_err)?; + + Ok(LmdbTransaction { + tx, + db: self.db, + table_codec: Default::default(), + }) + } + + fn metrics(&self) -> Option { + if !self.config.enable_statistics { + return None; + } + let stat = self.env.stat().ok()?; + + Some(LmdbMetrics { + map_size: self.config.map_size, + page_size: stat.page_size(), + depth: stat.depth(), + branch_pages: stat.branch_pages(), + leaf_pages: stat.leaf_pages(), + overflow_pages: stat.overflow_pages(), + entries: stat.entries(), + }) + } +} + +pub struct LmdbTransaction<'env> { + tx: RwTransaction<'env>, + db: Database, + table_codec: TableCodec, +} + +pub struct LmdbIter<'txn> { + _cursor: RoCursor<'txn>, + iter: lmdb::Iter<'txn>, + pending: Option<(&'txn [u8], &'txn [u8])>, + max: Bound, + done: bool, +} + +impl LmdbIter<'_> { + fn next_visible(&mut self) -> Option<(&[u8], &[u8])> { + if let Some(entry) = self.pending.take() { + if within_upper_bound(entry.0, &self.max) { + return Some(entry); + } + self.done = true; + return None; + } + + if let Some((key, value)) = self.iter.next() { + if !within_upper_bound(key, &self.max) { + self.done = true; + return None; + } + return Some((key, value)); + } + + self.done = true; + None + } +} + +impl InnerIter for LmdbIter<'_> { + fn try_next(&mut self) -> Result>, DatabaseError> { + if self.done { + 
return Ok(None); + } + Ok(self.next_visible()) + } +} + +impl Transaction for LmdbTransaction<'_> { + type BorrowedBytes<'a> + = &'a [u8] + where + Self: 'a; + + type IterType<'a> + = LmdbIter<'a> + where + Self: 'a; + + fn table_codec(&self) -> *const TableCodec { + &self.table_codec + } + + fn get_borrowed<'a>( + &'a self, + key: &[u8], + ) -> Result>, DatabaseError> { + match self.tx.get(self.db, &key) { + Ok(value) => Ok(Some(value)), + Err(lmdb::Error::NotFound) => Ok(None), + Err(err) => Err(map_lmdb_err(err)), + } + } + + fn set(&mut self, key: &[u8], value: &[u8]) -> Result<(), DatabaseError> { + self.tx + .put(self.db, &key, &value, lmdb::WriteFlags::empty()) + .map_err(map_lmdb_err)?; + Ok(()) + } + + fn remove(&mut self, key: &[u8]) -> Result<(), DatabaseError> { + match self.tx.del(self.db, &key, None) { + Ok(()) | Err(lmdb::Error::NotFound) => Ok(()), + Err(err) => Err(map_lmdb_err(err)), + } + } + + fn range<'txn, 'key>( + &'txn self, + min: Bound<&'key [u8]>, + max: Bound<&'key [u8]>, + ) -> Result, DatabaseError> { + let mut cursor = self.tx.open_ro_cursor(self.db).map_err(map_lmdb_err)?; + let (pending, done) = initial_entry(&mut cursor, &min).map_err(map_lmdb_err)?; + let iter = cursor.iter(); + + Ok(LmdbIter { + _cursor: cursor, + iter, + pending, + max: owned_bound(max), + done, + }) + } + + fn remove_range(&mut self, min: Bound<&[u8]>, max: Bound<&[u8]>) -> Result<(), DatabaseError> { + let mut cursor = self.tx.open_rw_cursor(self.db).map_err(map_lmdb_err)?; + let upper = owned_bound(max); + let mut lower = owned_bound(min); + let mut seek_key = Bytes::new(); + + loop { + let entry = cursor_seek(&mut cursor, &lower, &mut seek_key).map_err(map_lmdb_err)?; + let Some((key, _)) = entry else { + return Ok(()); + }; + if !within_upper_bound(key, &upper) { + return Ok(()); + } + + reuse_bound_as_excluded(&mut lower, key); + cursor.del(WriteFlags::empty()).map_err(map_lmdb_err)?; + } + } + + fn commit(self) -> Result<(), DatabaseError> { + 
self.tx.commit().map_err(map_lmdb_err)?; + Ok(()) + } +} + +fn initial_entry<'txn>( + cursor: &mut RoCursor<'txn>, + min: &Bound<&[u8]>, +) -> Result<(Option>, bool), lmdb::Error> { + match min { + Bound::Unbounded => Ok((None, false)), + Bound::Included(min) => match cursor.get(Some(*min), None, lmdb_sys::MDB_SET_RANGE) { + Ok((key, value)) => Ok((Some((key.unwrap_or_default(), value)), false)), + Err(lmdb::Error::NotFound) => Ok((None, true)), + Err(err) => Err(err), + }, + Bound::Excluded(min) => match cursor.get(Some(*min), None, lmdb_sys::MDB_SET_RANGE) { + Ok((key, value)) => { + let key = key.unwrap_or_default(); + if key == *min { + Ok((None, false)) + } else { + Ok((Some((key, value)), false)) + } + } + Err(lmdb::Error::NotFound) => Ok((None, true)), + Err(err) => Err(err), + }, + } +} + +fn cursor_seek<'txn>( + cursor: &mut lmdb::RwCursor<'txn>, + lower: &Bound, + seek_key: &mut Bytes, +) -> Result>, lmdb::Error> { + match lower { + Bound::Unbounded => match cursor.get(None, None, lmdb_sys::MDB_FIRST) { + Ok((key, value)) => Ok(Some((key.unwrap_or_default(), value))), + Err(lmdb::Error::NotFound) => Ok(None), + Err(err) => Err(err), + }, + Bound::Included(min) => { + match cursor.get(Some(min.as_slice()), None, lmdb_sys::MDB_SET_RANGE) { + Ok((key, value)) => Ok(Some((key.unwrap_or_default(), value))), + Err(lmdb::Error::NotFound) => Ok(None), + Err(err) => Err(err), + } + } + Bound::Excluded(min) => { + seek_key.clear(); + seek_key.extend_from_slice(min.as_slice()); + seek_key.push(0); + match cursor.get(Some(seek_key.as_slice()), None, lmdb_sys::MDB_SET_RANGE) { + Ok((key, value)) => Ok(Some((key.unwrap_or_default(), value))), + Err(lmdb::Error::NotFound) => Ok(None), + Err(err) => Err(err), + } + } + } +} + +fn owned_bound(bound: Bound<&[u8]>) -> Bound { + match bound { + Bound::Included(bytes) => Bound::Included(bytes.to_vec()), + Bound::Excluded(bytes) => Bound::Excluded(bytes.to_vec()), + Bound::Unbounded => Bound::Unbounded, + } +} + +fn 
within_upper_bound(key: &[u8], max: &Bound) -> bool { + match max { + Bound::Included(max) => key.cmp(max.as_slice()) != Ordering::Greater, + Bound::Excluded(max) => key.cmp(max.as_slice()) == Ordering::Less, + Bound::Unbounded => true, + } +} + +fn map_lmdb_err(err: impl std::fmt::Display) -> DatabaseError { + DatabaseError::InvalidValue(format!("lmdb: {err}")) +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::{LmdbConfig, LmdbStorage}; + use crate::db::DataBaseBuilder; + use lmdb::EnvironmentFlags; + use tempfile::TempDir; + + #[test] + fn lmdb_backend_smoke() { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let db_path = temp_dir.path().join("kite_sql.lmdb"); + let kite_sql = DataBaseBuilder::path(db_path).build_lmdb().unwrap(); + + kite_sql + .run("create table t1 (a int primary key, b int)") + .unwrap() + .done() + .unwrap(); + kite_sql + .run("insert into t1 values (1, 10), (2, 20), (3, 30)") + .unwrap() + .done() + .unwrap(); + + let mut iter = kite_sql.run("select b from t1 where a = 2").unwrap(); + let tuple = iter.next().unwrap().unwrap(); + assert_eq!(tuple.values[0].to_string(), "20"); + iter.done().unwrap(); + } + + #[test] + fn build_with_lmdb_storage() { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let db_path = temp_dir.path().join("kite_sql.lmdb"); + let storage = LmdbStorage::new(db_path).unwrap(); + let kite_sql = DataBaseBuilder::path(temp_dir.path()) + .build_with_storage(storage) + .unwrap(); + + kite_sql + .run("create table t1 (a int primary key)") + .unwrap() + .done() + .unwrap(); + } + + #[test] + fn collect_lmdb_metrics_snapshot() { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let db_path = temp_dir.path().join("kite_sql.lmdb"); + let kite_sql = DataBaseBuilder::path(db_path) + .storage_statistics(true) + .lmdb_flags(EnvironmentFlags::NO_SYNC) + .lmdb_map_size(64 * 1024 * 1024) 
+ .build_lmdb() + .unwrap(); + + kite_sql + .run("create table t_metrics (a int primary key, b int)") + .unwrap() + .done() + .unwrap(); + kite_sql + .run("insert into t_metrics values (1, 10), (2, 20), (3, 30)") + .unwrap() + .done() + .unwrap(); + + let metrics = kite_sql.storage_metrics().unwrap(); + assert_eq!(metrics.map_size, 64 * 1024 * 1024); + assert!(metrics.entries > 0); + } + + #[test] + fn build_lmdb_with_config() { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let db_path = temp_dir.path().join("kite_sql.lmdb"); + let storage = LmdbStorage::with_config( + db_path, + LmdbConfig { + map_size: 32 * 1024 * 1024, + flags: EnvironmentFlags::NO_SYNC, + ..LmdbConfig::default() + }, + ) + .unwrap(); + let kite_sql = DataBaseBuilder::path(temp_dir.path()) + .build_with_storage(storage) + .unwrap(); + + kite_sql + .run("create table t1 (a int primary key)") + .unwrap() + .done() + .unwrap(); + } +} diff --git a/src/storage/memory.rs b/src/storage/memory.rs index ee50c19e..ba8f44b9 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -13,9 +13,9 @@ // limitations under the License. 
use crate::errors::DatabaseError; -use crate::storage::table_codec::{BumpBytes, Bytes, TableCodec}; +use crate::storage::table_codec::{Bytes, TableCodec}; use crate::storage::{EmptyStorageMetrics, InnerIter, Storage, Transaction}; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::collections::{BTreeMap, Bound, VecDeque}; use std::rc::Rc; @@ -53,15 +53,36 @@ pub struct MemoryTransaction { pub struct MemoryIter { entries: VecDeque<(Bytes, Bytes)>, + current: Option<(Bytes, Bytes)>, +} + +pub struct MemoryValue<'a> { + value: Ref<'a, Vec>, +} + +impl AsRef<[u8]> for MemoryValue<'_> { + fn as_ref(&self) -> &[u8] { + self.value.as_slice() + } } impl InnerIter for MemoryIter { - fn try_next(&mut self) -> Result, DatabaseError> { - Ok(self.entries.pop_front()) + fn try_next(&mut self) -> Result>, DatabaseError> { + self.current = self.entries.pop_front(); + + Ok(self + .current + .as_ref() + .map(|(key, value)| (key.as_slice(), value.as_slice()))) } } impl Transaction for MemoryTransaction { + type BorrowedBytes<'a> + = MemoryValue<'a> + where + Self: 'a; + type IterType<'a> = MemoryIter where @@ -71,11 +92,19 @@ impl Transaction for MemoryTransaction { &self.table_codec } - fn get(&self, key: &[u8]) -> Result, DatabaseError> { - Ok(self.inner.borrow().get(key).cloned()) + fn get_borrowed<'a>( + &'a self, + key: &[u8], + ) -> Result>, DatabaseError> { + let map = self.inner.borrow(); + let Ok(value) = Ref::filter_map(map, |map| map.get(key)) else { + return Ok(None); + }; + + Ok(Some(MemoryValue { value })) } - fn set(&mut self, key: BumpBytes, value: BumpBytes) -> Result<(), DatabaseError> { + fn set(&mut self, key: &[u8], value: &[u8]) -> Result<(), DatabaseError> { self.inner.borrow_mut().insert(key.to_vec(), value.to_vec()); Ok(()) } @@ -85,20 +114,20 @@ impl Transaction for MemoryTransaction { Ok(()) } - fn range<'a>( - &'a self, - min: Bound>, - max: Bound>, - ) -> Result, DatabaseError> { + fn range<'txn, 'key>( + &'txn self, + min: Bound<&'key 
[u8]>, + max: Bound<&'key [u8]>, + ) -> Result, DatabaseError> { let map = self.inner.borrow(); let start = match &min { - Bound::Included(b) => Bound::Included(b.as_ref()), - Bound::Excluded(b) => Bound::Excluded(b.as_ref()), + Bound::Included(b) => Bound::Included(*b), + Bound::Excluded(b) => Bound::Excluded(*b), Bound::Unbounded => Bound::Unbounded, }; let end = match &max { - Bound::Included(b) => Bound::Included(b.as_ref()), - Bound::Excluded(b) => Bound::Excluded(b.as_ref()), + Bound::Included(b) => Bound::Included(*b), + Bound::Excluded(b) => Bound::Excluded(*b), Bound::Unbounded => Bound::Unbounded, }; @@ -107,7 +136,10 @@ impl Transaction for MemoryTransaction { .map(|(k, v)| (k.clone(), v.clone())) .collect(); - Ok(MemoryIter { entries }) + Ok(MemoryIter { + entries, + current: None, + }) } fn commit(self) -> Result<(), DatabaseError> { @@ -119,9 +151,8 @@ impl Transaction for MemoryTransaction { mod wasm_tests { use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::expression::range_detacher::Range; - use crate::storage::Iter; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -197,10 +228,10 @@ mod wasm_tests { true, )?; - let option_1 = iter.next_tuple()?; + let option_1 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_1.unwrap().pk, Some(DataValue::Int32(2))); - let option_2 = iter.next_tuple()?; + let option_2 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_2, None); Ok(()) @@ -239,7 +270,7 @@ mod wasm_tests { )?; let mut result = Vec::new(); - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? 
{ result.push(tuple.pk.unwrap()); } @@ -253,9 +284,8 @@ mod wasm_tests { mod native_tests { use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::expression::range_detacher::Range; - use crate::storage::Iter; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -330,10 +360,10 @@ mod native_tests { true, )?; - let option_1 = iter.next_tuple()?; + let option_1 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_1.unwrap().pk, Some(DataValue::Int32(2))); - let option_2 = iter.next_tuple()?; + let option_2 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_2, None); Ok(()) @@ -372,7 +402,7 @@ mod native_tests { )?; let mut result = Vec::new(); - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? { result.push(tuple.pk.unwrap()); } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index ec24f0f5..c910bbfe 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] +pub mod lmdb; pub mod memory; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] pub mod rocksdb; pub(crate) mod table_codec; @@ -25,12 +27,12 @@ use crate::expression::ScalarExpression; use crate::optimizer::core::cm_sketch::{ CountMinSketch, CountMinSketchPage, COUNT_MIN_SKETCH_STORAGE_PAGE_LEN, }; -use crate::optimizer::core::statistics_meta::{ - StatisticMetaLoader, StatisticsMeta, StatisticsMetaCacheValue, -}; +use crate::optimizer::core::statistics_meta::{StatisticMetaLoader, StatisticsMeta}; use crate::planner::operator::alter_table::change_column::{DefaultChange, NotNullChange}; use crate::serdes::ReferenceTables; -use crate::storage::table_codec::{BumpBytes, Bytes, StatisticsCodecType, TableCodec}; +use crate::storage::table_codec::{ + BumpBytes, Bytes, StatisticsCodecType, TableCodec, BOUND_MAX_TAG, +}; use crate::types::index::{Index, IndexId, IndexMeta, IndexMetaRef, IndexType}; use crate::types::serialize::TupleValueSerializableImpl; use crate::types::tuple::{Tuple, TupleId}; @@ -44,11 +46,12 @@ use std::io::Cursor; use std::mem; use std::ops::SubAssign; use std::sync::Arc; + +pub type KeyValueRef<'a> = (&'a [u8], &'a [u8]); use std::vec::IntoIter; use ulid::Generator; -pub(crate) type StatisticsMetaCache = - SharedLruCache<(TableName, IndexId), Option>; +pub(crate) type StatisticsMetaCache = SharedLruCache<(TableName, IndexId), Option>; pub(crate) type TableCache = SharedLruCache; pub(crate) type ViewCache = SharedLruCache; @@ -100,6 +103,10 @@ pub trait Storage: Clone { pub(crate) type Bounds = (Option, Option); pub trait Transaction: Sized { + type BorrowedBytes<'a>: AsRef<[u8]> + where + Self: 'a; + type IterType<'a>: InnerIter where Self: 'a; @@ -130,17 +137,17 @@ pub trait Transaction: Sized { let deserializers = Self::create_deserializers(&columns, table); let pk_ty = with_pk.then(|| table.primary_keys_type().clone()); - let (min, max) = 
unsafe { &*self.table_codec() }.tuple_bound(&table_name); - let iter = self.range(Bound::Included(min), Bound::Included(max))?; - - Ok(TupleIter { - offset: bounds.0.unwrap_or(0), - limit: bounds.1, - pk_ty, - deserializers, - values_len: columns.len(), - total_len: table.columns_len(), - iter, + unsafe { &*self.table_codec() }.with_tuple_bound(&table_name, |min, max| { + let iter = self.range(Bound::Included(min), Bound::Included(max))?; + + Ok(TupleIter { + offset: bounds.0.unwrap_or(0), + limit: bounds.1, + pk_ty, + deserializers, + total_len: table.columns_len(), + iter, + }) }) } @@ -158,8 +165,6 @@ pub trait Transaction: Sized { cover_mapping_indices: Option>, ) -> Result, DatabaseError> { debug_assert!(columns.keys().all_unique()); - let values_len = columns.len(); - let table = self .table(table_cache, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; @@ -203,7 +208,6 @@ pub trait Transaction: Sized { index_meta, table_name, deserializers, - values_len, total_len: table.columns_len(), tx: self, cover_mapping, @@ -212,6 +216,8 @@ pub trait Transaction: Sized { inner, ranges: ranges.into_iter(), state: IndexIterState::Init, + encode_min_buffer: Bytes::new(), + encode_max_buffer: Bytes::new(), }) } @@ -248,9 +254,12 @@ pub trait Transaction: Sized { ) -> Result { if let Some(mut table) = self.table(table_cache, table_name.clone())?.cloned() { let index_meta = table.add_index_meta(index_name, column_ids, ty)?; - let (key, value) = - unsafe { &*self.table_codec() }.encode_index_meta(table_name, index_meta)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() }.encode_index_meta_value(index_meta)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + table_name, + index_meta.id, + |key| self.set(key, value.as_slice()), + )?; table_cache.remove(table_name); Ok(index_meta.id) @@ -268,21 +277,21 @@ pub trait Transaction: Sized { if matches!(index.ty, IndexType::PrimaryKey { .. 
}) { return Ok(()); } - let (key, value) = - unsafe { &*self.table_codec() }.encode_index(table_name, &index, tuple_id)?; + let mut value = BumpBytes::new_in(unsafe { &*self.table_codec() }.arena()); + bincode::serialize_into(&mut value, tuple_id)?; - if matches!(index.ty, IndexType::Unique) { - if let Some(bytes) = self.get(&key)? { - return if bytes != value.as_slice() { - Err(DatabaseError::DuplicateUniqueValue) - } else { - Ok(()) - }; + unsafe { &*self.table_codec() }.with_index_key(table_name, &index, Some(tuple_id), |key| { + if matches!(index.ty, IndexType::Unique) { + if let Some(bytes) = self.get_borrowed(key)? { + return if bytes.as_ref() != value.as_slice() { + Err(DatabaseError::DuplicateUniqueValue) + } else { + Ok(()) + }; + } } - } - self.set(key, value)?; - - Ok(()) + self.set(key, value.as_slice()) + }) } fn del_index( @@ -294,11 +303,12 @@ pub trait Transaction: Sized { if matches!(index.ty, IndexType::PrimaryKey { .. }) { return Ok(()); } - self.remove(&unsafe { &*self.table_codec() }.encode_index_key( + unsafe { &*self.table_codec() }.with_index_key( table_name, index, Some(tuple_id), - )?)?; + |key| self.remove(key), + )?; Ok(()) } @@ -306,26 +316,23 @@ pub trait Transaction: Sized { fn append_tuple( &mut self, table_name: &str, - mut tuple: Tuple, + tuple: Tuple, serializers: &[TupleValueSerializableImpl], is_overwrite: bool, ) -> Result<(), DatabaseError> { - let (key, value) = - unsafe { &*self.table_codec() }.encode_tuple(table_name, &mut tuple, serializers)?; - - if !is_overwrite && self.get(&key)?.is_some() { - return Err(DatabaseError::DuplicatePrimaryKey); - } - self.set(key, value)?; + let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; + let value = tuple.serialize_to(serializers, unsafe { &*self.table_codec() }.arena())?; - Ok(()) + unsafe { &*self.table_codec() }.with_tuple_key(table_name, tuple_id, |key| { + if !is_overwrite && self.exists(key)? 
{ + return Err(DatabaseError::DuplicatePrimaryKey); + } + self.set(key, value.as_slice()) + }) } fn remove_tuple(&mut self, table_name: &str, tuple_id: &TupleId) -> Result<(), DatabaseError> { - let key = unsafe { &*self.table_codec() }.encode_tuple_key(table_name, tuple_id)?; - self.remove(&key)?; - - Ok(()) + unsafe { &*self.table_codec() }.with_tuple_key(table_name, tuple_id, |key| self.remove(key)) } fn rewrite_table_metadata( @@ -334,31 +341,37 @@ pub trait Transaction: Sized { table: &TableCatalog, ) -> Result<(), DatabaseError> { let table_name = table.name().clone(); - let (column_min, column_max) = - unsafe { &*self.table_codec() }.columns_bound(table_name.as_ref()); - self._drop_data(column_min, column_max)?; + unsafe { &*self.table_codec() }.with_columns_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - let (index_meta_min, index_meta_max) = - unsafe { &*self.table_codec() }.index_meta_bound(table_name.as_ref()); - self._drop_data(index_meta_min, index_meta_max)?; + unsafe { &*self.table_codec() } + .with_index_meta_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; let mut reference_tables = ReferenceTables::new(); let _ = reference_tables.push_or_replace(table.name()); for column in table.columns() { - let (key, value) = - unsafe { &*self.table_codec() }.encode_column(column, &mut reference_tables)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() } + .encode_column_value(column, &mut reference_tables)?; + unsafe { &*self.table_codec() } + .with_column_key(column, |key| self.set(key, value.as_slice()))?; } for index_meta in table.indexes() { - let (key, value) = - unsafe { &*self.table_codec() }.encode_index_meta(table.name(), index_meta)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() }.encode_index_meta_value(index_meta)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + 
table.name(), + index_meta.id, + |key| self.set(key, value.as_slice()), + )?; } table_cache.remove(table.name()); Ok(()) } + #[allow(clippy::too_many_arguments)] fn change_column( &mut self, table_cache: &TableCache, @@ -472,15 +485,19 @@ pub trait Transaction: Sized { vec![col_id], IndexType::Unique, )?; - let (key, value) = - unsafe { &*self.table_codec() }.encode_index_meta(table_name, meta_ref)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + table_name, + meta_ref.id, + |key| self.set(key, value.as_slice()), + )?; } let column = table.get_column_by_id(&col_id).unwrap(); - let (key, value) = unsafe { &*self.table_codec() } - .encode_column(column, &mut ReferenceTables::new())?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() } + .encode_column_value(column, &mut ReferenceTables::new())?; + unsafe { &*self.table_codec() } + .with_column_key(column, |key| self.set(key, value.as_slice()))?; table_cache.remove(table_name); Ok(col_id) @@ -499,21 +516,23 @@ pub trait Transaction: Sized { if let Some(table_catalog) = self.table(table_cache, table_name.clone())?.cloned() { let column = table_catalog.get_column_by_name(column_name).unwrap(); - let (key, _) = unsafe { &*self.table_codec() } - .encode_column(column, &mut ReferenceTables::new())?; - self.remove(&key)?; + unsafe { &*self.table_codec() }.with_column_key(column, |key| self.remove(key))?; for index_meta in table_catalog.indexes.iter() { if !index_meta.column_ids.contains(&column.id().unwrap()) { continue; } - let (index_meta_key, _) = - unsafe { &*self.table_codec() }.encode_index_meta(table_name, index_meta)?; - self.remove(&index_meta_key)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + table_name, + index_meta.id, + |key| self.remove(key), + )?; - let (index_min, index_max) = - unsafe { &*self.table_codec() }.index_bound(table_name, index_meta.id)?; - 
self._drop_data(index_min, index_max)?; + unsafe { &*self.table_codec() }.with_index_bound( + table_name, + index_meta.id, + |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), + )?; self.remove_statistics_meta(meta_cache, table_name, index_meta.id)?; } @@ -531,16 +550,18 @@ pub trait Transaction: Sized { view: View, or_replace: bool, ) -> Result<(), DatabaseError> { - let (view_key, value) = unsafe { &*self.table_codec() }.encode_view(&view)?; + let value = unsafe { &*self.table_codec() }.encode_view_value(&view)?; - let already_exists = self.get(&view_key)?.is_some(); + let already_exists = + unsafe { &*self.table_codec() }.with_view_key(&view.name, |key| self.exists(key))?; if !or_replace && already_exists { return Err(DatabaseError::ViewExists); } if !already_exists { self.check_name_hash(&view.name)?; } - self.set(view_key, value)?; + unsafe { &*self.table_codec() } + .with_view_key(&view.name, |key| self.set(key, value.as_slice()))?; let _ = view_cache.put(view.name.clone(), view); Ok(()) @@ -559,9 +580,11 @@ pub trait Transaction: Sized { TableCodec::check_primary_key_type(column.datatype())?; } - let (table_key, value) = unsafe { &*self.table_codec() } - .encode_root_table(&TableMeta::empty(table_name.clone()))?; - if self.get(&table_key)?.is_some() { + let value = unsafe { &*self.table_codec() } + .encode_root_table_value(&TableMeta::empty(table_name.clone()))?; + let exists = unsafe { &*self.table_codec() } + .with_root_table_key(table_name.as_ref(), |key| self.exists(key))?; + if exists { if if_not_exists { return Ok(table_name); } @@ -569,13 +592,15 @@ pub trait Transaction: Sized { } self.check_name_hash(&table_name)?; self.create_index_meta_from_column(&mut table_catalog)?; - self.set(table_key, value)?; + unsafe { &*self.table_codec() } + .with_root_table_key(table_name.as_ref(), |key| self.set(key, value.as_slice()))?; let mut reference_tables = ReferenceTables::new(); for column in table_catalog.columns() { - let (key, value) 
= - unsafe { &*self.table_codec() }.encode_column(column, &mut reference_tables)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() } + .encode_column_value(column, &mut reference_tables)?; + unsafe { &*self.table_codec() } + .with_column_key(column, |key| self.set(key, value.as_slice()))?; } debug_assert_eq!(reference_tables.len(), 1); table_cache.put(table_name.clone(), table_catalog); @@ -584,15 +609,16 @@ pub trait Transaction: Sized { } fn check_name_hash(&mut self, table_name: &TableName) -> Result<(), DatabaseError> { - let (hash_key, value) = unsafe { &*self.table_codec() }.encode_table_hash(table_name); - if self.get(&hash_key)?.is_some() { + if unsafe { &*self.table_codec() } + .with_table_hash_key(table_name, |key| self.exists(key))? + { return Err(DatabaseError::DuplicateSourceHash(table_name.to_string())); } - self.set(hash_key, value) + unsafe { &*self.table_codec() }.with_table_hash_key(table_name, |key| self.set(key, &[])) } fn drop_name_hash(&mut self, table_name: &TableName) -> Result<(), DatabaseError> { - self.remove(&unsafe { &*self.table_codec() }.encode_table_hash_key(table_name)) + unsafe { &*self.table_codec() }.with_table_hash_key(table_name, |key| self.remove(key)) } fn drop_view( @@ -614,7 +640,8 @@ pub trait Transaction: Sized { } } - self.remove(&unsafe { &*self.table_codec() }.encode_view_key(view_name.as_ref()))?; + unsafe { &*self.table_codec() } + .with_view_key(view_name.as_ref(), |key| self.remove(key))?; view_cache.remove(&view_name); Ok(()) @@ -644,13 +671,17 @@ pub trait Transaction: Sized { } let index_id = index_meta.id; - let index_meta_key = - unsafe { &*self.table_codec() }.encode_index_meta_key(table_name.as_ref(), index_id)?; - self.remove(&index_meta_key)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + table_name.as_ref(), + index_id, + |key| self.remove(key), + )?; - let (index_min, index_max) = - unsafe { &*self.table_codec() }.index_bound(table_name.as_ref(), index_id)?; - 
self._drop_data(index_min, index_max)?; + unsafe { &*self.table_codec() }.with_index_bound( + table_name.as_ref(), + index_id, + |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), + )?; self.remove_statistics_meta(meta_cache, &table_name, index_id)?; @@ -675,32 +706,34 @@ pub trait Transaction: Sized { self.drop_name_hash(&table_name)?; self.drop_data(table_name.as_ref())?; - let (column_min, column_max) = - unsafe { &*self.table_codec() }.columns_bound(table_name.as_ref()); - self._drop_data(column_min, column_max)?; + unsafe { &*self.table_codec() }.with_columns_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - let (index_meta_min, index_meta_max) = - unsafe { &*self.table_codec() }.index_meta_bound(table_name.as_ref()); - self._drop_data(index_meta_min, index_meta_max)?; + unsafe { &*self.table_codec() } + .with_index_meta_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - self.remove(&unsafe { &*self.table_codec() }.encode_root_table_key(table_name.as_ref()))?; + unsafe { &*self.table_codec() } + .with_root_table_key(table_name.as_ref(), |key| self.remove(key))?; table_cache.remove(&table_name); Ok(()) } fn drop_data(&mut self, table_name: &str) -> Result<(), DatabaseError> { - let (tuple_min, tuple_max) = unsafe { &*self.table_codec() }.tuple_bound(table_name); - self._drop_data(tuple_min, tuple_max)?; + unsafe { &*self.table_codec() }.with_tuple_bound(table_name, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - let (index_min, index_max) = unsafe { &*self.table_codec() }.all_index_bound(table_name); - self._drop_data(index_min, index_max)?; + unsafe { &*self.table_codec() }.with_all_index_bound(table_name, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - let (statistics_min, statistics_max) = - unsafe { &*self.table_codec() 
}.statistics_bound(table_name); - self._drop_data(statistics_min, statistics_max)?; - - Ok(()) + unsafe { &*self.table_codec() }.with_statistics_bound(table_name, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + }) } fn view<'a>( @@ -712,27 +745,29 @@ pub trait Transaction: Sized { if let Some(view) = view_cache.get(&view_name) { return Ok(Some(view)); } - let Some(bytes) = self.get(&unsafe { &*self.table_codec() }.encode_view_key(&view_name))? - else { - return Ok(None); - }; - Ok(Some(view_cache.get_or_insert(view_name.clone(), |_| { - TableCodec::decode_view(&bytes, (self, table_cache)) - })?)) + unsafe { &*self.table_codec() }.with_view_key(&view_name, |key| { + let Some(bytes) = self.get_borrowed(key)? else { + return Ok(None); + }; + Ok(Some(view_cache.get_or_insert(view_name.clone(), |_| { + TableCodec::decode_view(bytes.as_ref(), (self, table_cache)) + })?)) + }) } fn views(&self, table_cache: &TableCache) -> Result, DatabaseError> { let mut metas = vec![]; - let (min, max) = unsafe { &*self.table_codec() }.view_bound(); - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; + unsafe { &*self.table_codec() }.with_view_bound(|min, max| { + let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - while let Some((_, value)) = iter.try_next().ok().flatten() { - let meta = TableCodec::decode_view(&value, (self, table_cache))?; + while let Some((_, value)) = iter.try_next().ok().flatten() { + let meta = TableCodec::decode_view(value, (self, table_cache))?; - metas.push(meta); - } + metas.push(meta); + } - Ok(metas) + Ok(metas) + }) } fn table<'a>( @@ -756,16 +791,17 @@ pub trait Transaction: Sized { fn table_metas(&self) -> Result, DatabaseError> { let mut metas = vec![]; - let (min, max) = unsafe { &*self.table_codec() }.root_table_bound(); - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; + unsafe { &*self.table_codec() }.with_root_table_bound(|min, max| { + let mut iter = 
self.range(Bound::Included(min), Bound::Included(max))?; - while let Some((_, value)) = iter.try_next().ok().flatten() { - let meta = TableCodec::decode_root_table::(&value)?; + while let Some((_, value)) = iter.try_next().ok().flatten() { + let meta = TableCodec::decode_root_table::(value)?; - metas.push(meta); - } + metas.push(meta); + } - Ok(metas) + Ok(metas) + }) } fn save_statistics_meta( @@ -773,45 +809,47 @@ pub trait Transaction: Sized { meta_cache: &StatisticsMetaCache, table_name: &TableName, statistics_meta: StatisticsMeta, - cm_sketch: CountMinSketch, ) -> Result<(), DatabaseError> { let index_id = statistics_meta.index_id(); - let (root, buckets) = statistics_meta.clone().into_parts(); - let (key, value) = - unsafe { &*self.table_codec() }.encode_statistics_meta(table_name.as_ref(), &root)?; - self.set(key, value)?; - let cached_sketch = cm_sketch.clone(); + let cached_meta = statistics_meta.clone(); + let (root, buckets, cm_sketch) = statistics_meta.into_parts(); + let value = unsafe { &*self.table_codec() }.encode_statistics_meta_value(&root)?; + unsafe { &*self.table_codec() }.with_statistics_meta_key( + table_name.as_ref(), + index_id, + |key| self.set(key, value.as_slice()), + )?; let (sketch_meta, sketch_pages) = cm_sketch.into_storage_parts(COUNT_MIN_SKETCH_STORAGE_PAGE_LEN); - let (key, value) = unsafe { &*self.table_codec() }.encode_statistics_sketch_meta( + let value = + unsafe { &*self.table_codec() }.encode_statistics_sketch_meta_value(&sketch_meta)?; + unsafe { &*self.table_codec() }.with_statistics_sketch_meta_key( table_name.as_ref(), index_id, - &sketch_meta, + |key| self.set(key, value.as_slice()), )?; - self.set(key, value)?; for sketch_page in sketch_pages { - let (key, value) = unsafe { &*self.table_codec() }.encode_statistics_sketch_page( + let value = unsafe { &*self.table_codec() } + .encode_statistics_sketch_page_value(&sketch_page)?; + unsafe { &*self.table_codec() }.with_statistics_sketch_page_key( table_name.as_ref(), 
index_id, &sketch_page, + |key| self.set(key, value.as_slice()), )?; - self.set(key, value)?; } for (ordinal, bucket) in buckets.iter().enumerate() { - let (key, value) = unsafe { &*self.table_codec() }.encode_statistics_bucket( + let value = unsafe { &*self.table_codec() }.encode_statistics_bucket_value(bucket)?; + unsafe { &*self.table_codec() }.with_statistics_bucket_key( table_name.as_ref(), index_id, ordinal as u32, - bucket, + |key| self.set(key, value.as_slice()), )?; - self.set(key, value)?; } - meta_cache.put( - (table_name.clone(), index_id), - Some(StatisticsMetaCacheValue::new(statistics_meta).with_sketch(cached_sketch)), - ); + meta_cache.put((table_name.clone(), index_id), Some(cached_meta)); Ok(()) } @@ -821,75 +859,47 @@ pub trait Transaction: Sized { table_name: &str, index_id: IndexId, ) -> Result, DatabaseError> { - let (min, max) = - unsafe { &*self.table_codec() }.statistics_index_bound(table_name, index_id); - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - let mut root = None; - let mut buckets = Vec::new(); - let mut has_extra = false; - - while let Some((key, value)) = iter.try_next()? { - match unsafe { &*self.table_codec() } - .decode_statistics_codec_type(table_name, index_id, &key)? - { - StatisticsCodecType::Root => { - root = Some(TableCodec::decode_statistics_meta::(&value)?); - } - StatisticsCodecType::SketchMeta | StatisticsCodecType::SketchPage => { - has_extra = true - } - StatisticsCodecType::Bucket => { - buckets.push(TableCodec::decode_statistics_bucket::(&value)?); + unsafe { &*self.table_codec() }.with_statistics_index_bound( + table_name, + index_id, + |min, max| { + let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; + let mut root = None; + let mut buckets = Vec::new(); + let mut sketch_meta = None; + let mut sketch_pages = Vec::::new(); + + while let Some((key, value)) = iter.try_next()? { + match unsafe { &*self.table_codec() }.decode_statistics_codec_type(key)? 
{ + StatisticsCodecType::Root => { + root = Some(TableCodec::decode_statistics_meta::(value)?); + } + StatisticsCodecType::SketchMeta => { + sketch_meta = + Some(TableCodec::decode_statistics_sketch_meta::(value)?); + } + StatisticsCodecType::SketchPage => { + sketch_pages + .push(TableCodec::decode_statistics_sketch_page::(value)?); + } + StatisticsCodecType::Bucket => { + buckets.push(TableCodec::decode_statistics_bucket::(value)?); + } + } } - } - } - - match root { - Some(root) => StatisticsMeta::from_parts(root, buckets).map(Some), - None if !has_extra && buckets.is_empty() => Ok(None), - _ => Err(DatabaseError::InvalidValue( - "statistics meta is incomplete".to_string(), - )), - } - } - fn statistics_sketch( - &self, - table_name: &str, - index_id: IndexId, - ) -> Result>, DatabaseError> { - let (min, max) = - unsafe { &*self.table_codec() }.statistics_index_bound(table_name, index_id); - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - let mut sketch_meta = None; - let mut sketch_pages = Vec::::new(); - let mut has_root_or_bucket = false; - - while let Some((key, value)) = iter.try_next()? { - match unsafe { &*self.table_codec() } - .decode_statistics_codec_type(table_name, index_id, &key)? 
- { - StatisticsCodecType::Root | StatisticsCodecType::Bucket => { - has_root_or_bucket = true - } - StatisticsCodecType::SketchMeta => { - sketch_meta = Some(TableCodec::decode_statistics_sketch_meta::(&value)?); - } - StatisticsCodecType::SketchPage => { - sketch_pages.push(TableCodec::decode_statistics_sketch_page::(&value)?); + match (root, sketch_meta) { + (Some(root), Some(sketch_meta)) => { + let sketch = CountMinSketch::from_storage_parts(sketch_meta, sketch_pages)?; + StatisticsMeta::from_parts(root, buckets, sketch).map(Some) + } + (None, None) if buckets.is_empty() && sketch_pages.is_empty() => Ok(None), + _ => Err(DatabaseError::InvalidValue( + "statistics meta is incomplete".to_string(), + )), } - } - } - - match sketch_meta { - Some(sketch_meta) => { - CountMinSketch::from_storage_parts(sketch_meta, sketch_pages).map(Some) - } - None if !has_root_or_bucket && sketch_pages.is_empty() => Ok(None), - _ => Err(DatabaseError::InvalidValue( - "statistics sketch is incomplete".to_string(), - )), - } + }, + ) } fn remove_statistics_meta( @@ -898,9 +908,11 @@ pub trait Transaction: Sized { table_name: &TableName, index_id: IndexId, ) -> Result<(), DatabaseError> { - let (min, max) = - unsafe { &*self.table_codec() }.statistics_index_bound(table_name.as_ref(), index_id); - self._drop_data(min, max)?; + unsafe { &*self.table_codec() }.with_statistics_index_bound( + table_name.as_ref(), + index_id, + |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), + )?; meta_cache.remove(&(table_name.clone(), index_id)); @@ -922,47 +934,30 @@ pub trait Transaction: Sized { &self, table_name: &TableName, ) -> Result, Vec)>, DatabaseError> { - let (table_min, table_max) = unsafe { &*self.table_codec() }.table_bound(table_name); - let mut column_iter = self.range( - Bound::Included(table_min.clone()), - Bound::Included(table_max), - )?; - - let mut columns = Vec::new(); - let mut index_metas = Vec::new(); - let mut reference_tables = ReferenceTables::new(); 
- let _ = reference_tables.push_or_replace(table_name); - - // Tips: only `Column`, `IndexMeta`, `TableMeta` - while let Some((key, value)) = column_iter.try_next().ok().flatten() { - if key.starts_with(&table_min) { - let mut cursor = Cursor::new(value); - columns.push(TableCodec::decode_column::( - &mut cursor, - &reference_tables, - )?); - } else { - index_metas.push(Arc::new(TableCodec::decode_index_meta::(&value)?)); + unsafe { &*self.table_codec() }.with_table_bound(table_name, |table_min, table_max| { + let mut column_iter = + self.range(Bound::Included(table_min), Bound::Included(table_max))?; + + let mut columns = Vec::new(); + let mut index_metas = Vec::new(); + let mut reference_tables = ReferenceTables::new(); + let _ = reference_tables.push_or_replace(table_name); + + // Tips: only `Column`, `IndexMeta`, `TableMeta` + while let Some((key, value)) = column_iter.try_next().ok().flatten() { + if key.starts_with(table_min) { + let mut cursor = Cursor::new(value); + columns.push(TableCodec::decode_column::( + &mut cursor, + &reference_tables, + )?); + } else { + index_metas.push(Arc::new(TableCodec::decode_index_meta::(value)?)); + } } - } - - Ok((!columns.is_empty()).then_some((columns, index_metas))) - } - - fn _drop_data(&mut self, min: BumpBytes, max: BumpBytes) -> Result<(), DatabaseError> { - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - let mut data_keys = vec![]; - - while let Some((key, _)) = iter.try_next()? { - data_keys.push(key); - } - drop(iter); - for key in data_keys { - self.remove(&key)? 
- } - - Ok(()) + Ok((!columns.is_empty()).then_some((columns, index_metas))) + }) } fn create_index_meta_from_column( @@ -985,9 +980,12 @@ pub trait Transaction: Sized { }; let meta_ref = table.add_index_meta(format!("uk_{}_index", col.name()), vec![col_id], index_ty)?; - let (key, value) = - unsafe { &*self.table_codec() }.encode_index_meta(&table_name, meta_ref)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; + unsafe { &*self.table_codec() }.with_index_meta_key( + &table_name, + meta_ref.id, + |key| self.set(key, value.as_slice()), + )?; } let primary_keys = table .primary_keys() @@ -998,40 +996,156 @@ pub trait Transaction: Sized { is_multiple: primary_keys.len() != 1, }; let meta_ref = table.add_index_meta("pk_index".to_string(), primary_keys, pk_index_ty)?; - let (key, value) = - unsafe { &*self.table_codec() }.encode_index_meta(&table_name, meta_ref)?; - self.set(key, value)?; + let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; + unsafe { &*self.table_codec() }.with_index_meta_key(&table_name, meta_ref.id, |key| { + self.set(key, value.as_slice()) + })?; Ok(()) } - fn get(&self, key: &[u8]) -> Result, DatabaseError>; + fn get_borrowed<'a>( + &'a self, + key: &[u8], + ) -> Result>, DatabaseError>; + + fn exists(&self, key: &[u8]) -> Result { + Ok(self.get_borrowed(key)?.is_some()) + } - fn set(&mut self, key: BumpBytes, value: BumpBytes) -> Result<(), DatabaseError>; + fn set(&mut self, key: &[u8], value: &[u8]) -> Result<(), DatabaseError>; fn remove(&mut self, key: &[u8]) -> Result<(), DatabaseError>; - fn range<'a>( - &'a self, - min: Bound>, - max: Bound>, - ) -> Result, DatabaseError>; + fn range<'txn, 'key>( + &'txn self, + min: Bound<&'key [u8]>, + max: Bound<&'key [u8]>, + ) -> Result, DatabaseError>; + + fn remove_range(&mut self, min: Bound<&[u8]>, max: Bound<&[u8]>) -> Result<(), DatabaseError> { + const DELETE_BATCH_SIZE: usize = 1024; + + let mut lower = 
owned_bound(min); + let upper = owned_bound(max); + let mut data_keys = Vec::with_capacity(DELETE_BATCH_SIZE); + + loop { + data_keys.clear(); + let mut iter = + self.range(bytes_bound_as_slice(&lower), bytes_bound_as_slice(&upper))?; + + while data_keys.len() < DELETE_BATCH_SIZE { + let Some((key, _)) = iter.try_next()? else { + break; + }; + data_keys.push(key.to_vec()); + } + drop(iter); + + let Some(last_key) = data_keys.pop() else { + return Ok(()); + }; + let batch_full = data_keys.len() + 1 == DELETE_BATCH_SIZE; + + for key in data_keys.drain(..) { + self.remove(&key)?; + } + self.remove(&last_key)?; + + if !batch_full { + return Ok(()); + } + reuse_bound_as_excluded(&mut lower, &last_key); + } + } fn commit(self) -> Result<(), DatabaseError>; } -trait IndexImpl<'bytes, T: Transaction + 'bytes> { - fn index_lookup( +fn owned_bound(bound: Bound<&[u8]>) -> Bound { + match bound { + Bound::Included(bytes) => Bound::Included(bytes.to_vec()), + Bound::Excluded(bytes) => Bound::Excluded(bytes.to_vec()), + Bound::Unbounded => Bound::Unbounded, + } +} + +pub(crate) fn reuse_bound_as_excluded(bound: &mut Bound, key: &[u8]) { + let mut bytes = match mem::replace(bound, Bound::Unbounded) { + Bound::Included(bytes) | Bound::Excluded(bytes) => bytes, + Bound::Unbounded => Vec::new(), + }; + bytes.clear(); + bytes.extend_from_slice(key); + *bound = Bound::Excluded(bytes); +} + +fn bytes_bound_as_slice(bound: &Bound) -> Bound<&[u8]> { + match bound { + Bound::Included(bytes) => Bound::Included(bytes.as_slice()), + Bound::Excluded(bytes) => Bound::Excluded(bytes.as_slice()), + Bound::Unbounded => Bound::Unbounded, + } +} + +#[inline] +fn fill_default_bound<'a>(bound: Bound<&'a [u8]>, default: &'a [u8]) -> Bound<&'a [u8]> { + match bound { + Bound::Included(bytes) => Bound::Included(bytes), + Bound::Excluded(bytes) => Bound::Excluded(bytes), + Bound::Unbounded => Bound::Included(default), + } +} + +#[inline] +fn encode_bound_key(buffer: &mut Bytes, key: &[u8], is_upper: 
bool) { + buffer.clear(); + buffer.extend_from_slice(key); + if is_upper { + buffer.push(BOUND_MAX_TAG); + } +} + +#[inline] +fn encode_bound<'a>( + bound: Bound, + is_upper: bool, + buffer: &'a mut Bytes, + params: &IndexImplParams<'_, impl Transaction>, + inner: &IndexImplEnum, +) -> Result, DatabaseError> { + match bound { + Bound::Included(mut val) => { + val = params.try_cast(val)?; + inner.bound_key(params, &val, is_upper, buffer)?; + Ok(Bound::Included(buffer.as_slice())) + } + Bound::Excluded(mut val) => { + val = params.try_cast(val)?; + inner.bound_key(params, &val, is_upper, buffer)?; + Ok(Bound::Excluded(buffer.as_slice())) + } + Bound::Unbounded => Ok(Bound::Unbounded), + } +} + +trait IndexImpl { + fn index_lookup_into( &self, - key: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + key: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result; + ) -> Result<(), DatabaseError>; fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) -> Result, DatabaseError>; fn bound_key( @@ -1039,7 +1153,8 @@ trait IndexImpl<'bytes, T: Transaction + 'bytes> { params: &IndexImplParams, value: &DataValue, is_upper: bool, - ) -> Result, DatabaseError>; + out: &mut Bytes, + ) -> Result<(), DatabaseError>; } enum IndexImplEnum { @@ -1094,7 +1209,6 @@ struct IndexImplParams<'a, T: Transaction> { index_meta: IndexMetaRef, table_name: &'a str, deserializers: Vec, - values_len: usize, total_len: usize, tx: &'a T, cover_mapping: Option, @@ -1121,56 +1235,74 @@ impl IndexImplParams<'_, T> { Ok(val) } - fn get_tuple_by_id(&self, tuple_id: &TupleId) -> Result, DatabaseError> { - let key = unsafe { &*self.table_codec() }.encode_tuple_key(self.table_name, tuple_id)?; - - self.tx - .get(&key)? 
- .map(|bytes| { - TableCodec::decode_tuple( - &self.deserializers, - Some(tuple_id.clone()), - &bytes, - self.values_len, - self.total_len, - ) - }) - .transpose() + fn get_tuple_by_id_into( + &self, + tuple_id: &TupleId, + tuple: &mut Tuple, + ) -> Result { + unsafe { &*self.table_codec() }.with_tuple_key_unchecked(self.table_name, tuple_id, |key| { + let Some(bytes) = self.tx.get_borrowed(key)? else { + return Ok(false); + }; + TableCodec::decode_tuple_into( + tuple, + &self.deserializers, + Some(tuple_id.clone()), + bytes.as_ref(), + self.total_len, + )?; + Ok(true) + }) } } enum IndexResult<'a, T: Transaction + 'a> { - Tuple(Option), + Hit, + Miss, Scope(T::IterType<'a>), } -impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for IndexImplEnum { - fn index_lookup( +impl IndexImpl for IndexImplEnum { + fn index_lookup_into( &self, - key: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + key: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { + ) -> Result<(), DatabaseError> { match self { - IndexImplEnum::PrimaryKey(inner) => inner.index_lookup(key, value, params), - IndexImplEnum::Unique(inner) => inner.index_lookup(key, value, params), - IndexImplEnum::Normal(inner) => inner.index_lookup(key, value, params), - IndexImplEnum::Composite(inner) => inner.index_lookup(key, value, params), - IndexImplEnum::Covered(inner) => inner.index_lookup(key, value, params), + IndexImplEnum::PrimaryKey(inner) => inner.index_lookup_into(tuple, key, value, params), + IndexImplEnum::Unique(inner) => inner.index_lookup_into(tuple, key, value, params), + IndexImplEnum::Normal(inner) => inner.index_lookup_into(tuple, key, value, params), + IndexImplEnum::Composite(inner) => inner.index_lookup_into(tuple, key, value, params), + IndexImplEnum::Covered(inner) => inner.index_lookup_into(tuple, key, value, params), } } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) 
-> Result, DatabaseError> { match self { - IndexImplEnum::PrimaryKey(inner) => inner.eq_to_res(value, params), - IndexImplEnum::Unique(inner) => inner.eq_to_res(value, params), - IndexImplEnum::Normal(inner) => inner.eq_to_res(value, params), - IndexImplEnum::Composite(inner) => inner.eq_to_res(value, params), - IndexImplEnum::Covered(inner) => inner.eq_to_res(value, params), + IndexImplEnum::PrimaryKey(inner) => { + inner.eq_to_res(tuple, value, params, encode_min, encode_max) + } + IndexImplEnum::Unique(inner) => { + inner.eq_to_res(tuple, value, params, encode_min, encode_max) + } + IndexImplEnum::Normal(inner) => { + inner.eq_to_res(tuple, value, params, encode_min, encode_max) + } + IndexImplEnum::Composite(inner) => { + inner.eq_to_res(tuple, value, params, encode_min, encode_max) + } + IndexImplEnum::Covered(inner) => { + inner.eq_to_res(tuple, value, params, encode_min, encode_max) + } } } @@ -1179,54 +1311,67 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for IndexImplEnum { params: &IndexImplParams, value: &DataValue, is_upper: bool, - ) -> Result, DatabaseError> { + out: &mut Bytes, + ) -> Result<(), DatabaseError> { match self { - IndexImplEnum::PrimaryKey(inner) => inner.bound_key(params, value, is_upper), - IndexImplEnum::Unique(inner) => inner.bound_key(params, value, is_upper), - IndexImplEnum::Normal(inner) => inner.bound_key(params, value, is_upper), - IndexImplEnum::Composite(inner) => inner.bound_key(params, value, is_upper), - IndexImplEnum::Covered(inner) => inner.bound_key(params, value, is_upper), + IndexImplEnum::PrimaryKey(inner) => inner.bound_key(params, value, is_upper, out), + IndexImplEnum::Unique(inner) => inner.bound_key(params, value, is_upper, out), + IndexImplEnum::Normal(inner) => inner.bound_key(params, value, is_upper, out), + IndexImplEnum::Composite(inner) => inner.bound_key(params, value, is_upper, out), + IndexImplEnum::Covered(inner) => inner.bound_key(params, value, is_upper, out), } } } -impl<'bytes, T: 
Transaction + 'bytes> IndexImpl<'bytes, T> for PrimaryKeyIndexImpl { - fn index_lookup( +impl IndexImpl for PrimaryKeyIndexImpl { + fn index_lookup_into( &self, - key: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + key: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { + ) -> Result<(), DatabaseError> { let tuple_id = TableCodec::decode_tuple_key(key, ¶ms.index_meta.pk_ty)?; - TableCodec::decode_tuple( + TableCodec::decode_tuple_into( + tuple, ¶ms.deserializers, Some(tuple_id), value, - params.values_len, params.total_len, ) } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + _: &mut Bytes, + _: &mut Bytes, ) -> Result, DatabaseError> { let tuple_id = value.clone(); - let tuple = params - .tx - .get(&unsafe { &*params.table_codec() }.encode_tuple_key(params.table_name, value)?)? - .map(|bytes| { - TableCodec::decode_tuple( + let found = unsafe { &*params.table_codec() }.with_tuple_key_unchecked( + params.table_name, + value, + |key| { + let Some(bytes) = params.tx.get_borrowed(key)? 
else { + return Ok(false); + }; + TableCodec::decode_tuple_into( + tuple, ¶ms.deserializers, Some(tuple_id.clone()), - &bytes, - params.values_len, + bytes.as_ref(), params.total_len, - ) - }) - .transpose()?; - Ok(IndexResult::Tuple(tuple)) + )?; + Ok(true) + }, + )?; + Ok(if found { + IndexResult::Hit + } else { + IndexResult::Miss + }) } fn bound_key( @@ -1234,45 +1379,69 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for PrimaryKeyIndexIm params: &IndexImplParams, value: &DataValue, _: bool, - ) -> Result, DatabaseError> { - unsafe { &*params.table_codec() }.encode_tuple_key(params.table_name, value) + out: &mut Bytes, + ) -> Result<(), DatabaseError> { + unsafe { &*params.table_codec() }.with_tuple_key_unchecked( + params.table_name, + value, + |key| { + out.clear(); + out.extend_from_slice(key); + Ok(()) + }, + ) } } #[inline(always)] fn secondary_index_lookup( - bytes: &Bytes, + tuple: &mut Tuple, + bytes: &[u8], params: &IndexImplParams, -) -> Result { +) -> Result<(), DatabaseError> { let tuple_id = TableCodec::decode_index(bytes)?; - params - .get_tuple_by_id(&tuple_id)? - .ok_or(DatabaseError::TupleIdNotFound(tuple_id)) + if params.get_tuple_by_id_into(&tuple_id, tuple)? { + Ok(()) + } else { + Err(DatabaseError::TupleIdNotFound(tuple_id)) + } } -impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for UniqueIndexImpl { - fn index_lookup( +impl IndexImpl for UniqueIndexImpl { + fn index_lookup_into( &self, - _: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + _: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { - secondary_index_lookup(value, params) + ) -> Result<(), DatabaseError> { + secondary_index_lookup(tuple, value, params) } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + _: &mut Bytes, + _: &mut Bytes, ) -> Result, DatabaseError> { - let Some(bytes) = params.tx.get(&self.bound_key(params, value, false)?)? 
else { - return Ok(IndexResult::Tuple(None)); + let index = Index::new(params.index_meta.id, value, IndexType::Unique); + let Some(bytes) = unsafe { &*params.table_codec() }.with_index_key( + params.table_name, + &index, + None, + |key| params.tx.get_borrowed(key), + )? + else { + return Ok(IndexResult::Miss); }; - let tuple_id = TableCodec::decode_index(&bytes)?; - let tuple = params - .get_tuple_by_id(&tuple_id)? - .ok_or(DatabaseError::TupleIdNotFound(tuple_id))?; - Ok(IndexResult::Tuple(Some(tuple))) + let tuple_id = TableCodec::decode_index(bytes.as_ref())?; + if params.get_tuple_by_id_into(&tuple_id, tuple)? { + Ok(IndexResult::Hit) + } else { + Err(DatabaseError::TupleIdNotFound(tuple_id)) + } } fn bound_key( @@ -1280,29 +1449,38 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for UniqueIndexImpl { params: &IndexImplParams, value: &DataValue, _: bool, - ) -> Result, DatabaseError> { + out: &mut Bytes, + ) -> Result<(), DatabaseError> { let index = Index::new(params.index_meta.id, value, IndexType::Unique); - unsafe { &*params.table_codec() }.encode_index_key(params.table_name, &index, None) + unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + out.clear(); + out.extend_from_slice(key); + Ok(()) + }) } } -impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for NormalIndexImpl { - fn index_lookup( +impl IndexImpl for NormalIndexImpl { + fn index_lookup_into( &self, - _: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + _: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { - secondary_index_lookup(value, params) + ) -> Result<(), DatabaseError> { + secondary_index_lookup(tuple, value, params) } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(self, value, params) + eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) } fn 
bound_key( @@ -1310,33 +1488,36 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for NormalIndexImpl { params: &IndexImplParams, value: &DataValue, is_upper: bool, - ) -> Result, DatabaseError> { + out: &mut Bytes, + ) -> Result<(), DatabaseError> { let index = Index::new(params.index_meta.id, value, IndexType::Normal); - - unsafe { &*params.table_codec() }.encode_index_bound_key( - params.table_name, - &index, - is_upper, - ) + unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + encode_bound_key(out, key, is_upper); + Ok(()) + }) } } -impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CompositeIndexImpl { - fn index_lookup( +impl IndexImpl for CompositeIndexImpl { + fn index_lookup_into( &self, - _: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + _: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { - secondary_index_lookup(value, params) + ) -> Result<(), DatabaseError> { + secondary_index_lookup(tuple, value, params) } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(self, value, params) + eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) } fn bound_key( @@ -1344,24 +1525,24 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CompositeIndexImp params: &IndexImplParams, value: &DataValue, is_upper: bool, - ) -> Result, DatabaseError> { + out: &mut Bytes, + ) -> Result<(), DatabaseError> { let index = Index::new(params.index_meta.id, value, IndexType::Composite); - - unsafe { &*params.table_codec() }.encode_index_bound_key( - params.table_name, - &index, - is_upper, - ) + unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + encode_bound_key(out, key, is_upper); + Ok(()) + }) } } -impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CoveredIndexImpl { - fn 
index_lookup( +impl IndexImpl for CoveredIndexImpl { + fn index_lookup_into( &self, - key: &Bytes, - value: &Bytes, + tuple: &mut Tuple, + key: &[u8], + value: &[u8], params: &IndexImplParams, - ) -> Result { + ) -> Result<(), DatabaseError> { let mapping = params .cover_mapping .as_ref() @@ -1373,21 +1554,23 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CoveredIndexImpl } else { None }; - let values = match key { + tuple.pk = tuple_id; + tuple.values = match key { DataValue::Tuple(vals, _) => vals, - v => { - vec![v] - } + v => vec![v], }; - Ok(Tuple::new(tuple_id, values)) + Ok(()) } fn eq_to_res<'a>( &self, + tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(self, value, params) + eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) } fn bound_key( @@ -1395,29 +1578,33 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CoveredIndexImpl params: &IndexImplParams, value: &DataValue, is_upper: bool, - ) -> Result, DatabaseError> { + out: &mut Bytes, + ) -> Result<(), DatabaseError> { let index = Index::new(params.index_meta.id, value, params.index_meta.ty); - - unsafe { &*params.table_codec() }.encode_index_bound_key( - params.table_name, - &index, - is_upper, - ) + unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + encode_bound_key(out, key, is_upper); + Ok(()) + }) } } #[inline(always)] fn eq_to_res_scope<'a, T: Transaction + 'a>( - index_impl: &impl IndexImpl<'a, T>, + tuple: &mut Tuple, + index_impl: &impl IndexImpl, value: &DataValue, params: &IndexImplParams<'a, T>, + encode_min: &mut Bytes, + encode_max: &mut Bytes, ) -> Result, DatabaseError> { - let min = index_impl.bound_key(params, value, false)?; - let max = index_impl.bound_key(params, value, true)?; - - let iter = params - .tx - .range(Bound::Included(min), Bound::Included(max))?; + let _ = 
tuple; + index_impl.bound_key(params, value, false, encode_min)?; + index_impl.bound_key(params, value, true, encode_max)?; + + let iter = params.tx.range( + Bound::Included(encode_min.as_slice()), + Bound::Included(encode_max.as_slice()), + )?; Ok(IndexResult::Scope(iter)) } @@ -1426,16 +1613,15 @@ pub struct TupleIter<'a, T: Transaction + 'a> { limit: Option, pk_ty: Option, deserializers: Vec, - values_len: usize, total_len: usize, iter: T::IterType<'a>, } impl<'a, T: Transaction + 'a> Iter for TupleIter<'a, T> { - fn next_tuple(&mut self) -> Result, DatabaseError> { + fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result { while self.offset > 0 { if self.iter.try_next()?.is_none() { - return Ok(None); + return Ok(false); } self.offset -= 1; } @@ -1444,27 +1630,27 @@ impl<'a, T: Transaction + 'a> Iter for TupleIter<'a, T> { while let Some((key, value)) = self.iter.try_next()? { if let Some(limit) = self.limit.as_mut() { if *limit == 0 { - return Ok(None); + return Ok(false); } *limit -= 1; } let tuple_id = if let Some(pk_ty) = &self.pk_ty { - Some(TableCodec::decode_tuple_key(&key, pk_ty)?) + Some(TableCodec::decode_tuple_key(key, pk_ty)?) 
} else { None }; - let tuple = TableCodec::decode_tuple( + TableCodec::decode_tuple_into( + tuple, &self.deserializers, tuple_id, - &value, - self.values_len, + value, self.total_len, )?; - return Ok(Some(tuple)); + return Ok(true); } - Ok(None) + Ok(false) } } @@ -1476,6 +1662,8 @@ pub struct IndexIter<'a, T: Transaction> { // for buffering data ranges: IntoIter, state: IndexIterState<'a, T>, + encode_min_buffer: Bytes, + encode_max_buffer: Bytes, } pub enum IndexIterState<'a, T: Transaction + 'a> { @@ -1504,16 +1692,11 @@ impl<'a, T: Transaction + 'a> IndexIter<'a, T> { /// expression -> index value -> tuple impl Iter for IndexIter<'_, T> { - fn next_tuple(&mut self) -> Result, DatabaseError> { - fn check_bound<'a>(value: &mut Bound>, bound: BumpBytes<'a>) { - if matches!(value, Bound::Unbounded) { - let _ = mem::replace(value, Bound::Included(bound)); - } - } + fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result { if matches!(self.limit, Some(0)) { self.state = IndexIterState::Over; - return Ok(None); + return Ok(false); } loop { @@ -1527,58 +1710,58 @@ impl Iter for IndexIter<'_, T> { Range::Scope { min, max } => { let table_name = &self.params.table_name; let index_meta = &self.params.index_meta; - let bound_encode = - |bound: Bound, - is_upper: bool| - -> Result<_, DatabaseError> { - match bound { - Bound::Included(mut val) => { - val = self.params.try_cast(val)?; - - Ok(Bound::Included(self.inner.bound_key( - &self.params, - &val, - is_upper, - )?)) - } - Bound::Excluded(mut val) => { - val = self.params.try_cast(val)?; - - Ok(Bound::Excluded(self.inner.bound_key( - &self.params, - &val, - is_upper, - )?)) - } - Bound::Unbounded => Ok(Bound::Unbounded), - } - }; - let (bound_min, bound_max) = - if matches!(index_meta.ty, IndexType::PrimaryKey { .. }) { - unsafe { &*self.params.table_codec() }.tuple_bound(table_name) - } else { - unsafe { &*self.params.table_codec() } - .index_bound(table_name, index_meta.id)? 
- }; - let mut encode_min = bound_encode(min, false)?; - check_bound(&mut encode_min, bound_min); - - let mut encode_max = bound_encode(max, true)?; - check_bound(&mut encode_max, bound_max); - let iter = self.params.tx.range(encode_min, encode_max)?; + let encode_min = encode_bound( + min, + false, + &mut self.encode_min_buffer, + &self.params, + &self.inner, + )?; + let encode_max = encode_bound( + max, + true, + &mut self.encode_max_buffer, + &self.params, + &self.inner, + )?; + + let table_codec = unsafe { &*self.params.table_codec() }; + let tx = self.params.tx; + let open_iter = move |bound_min: &[u8], bound_max: &[u8]| { + tx.range( + fill_default_bound(encode_min, bound_min), + fill_default_bound(encode_max, bound_max), + ) + }; + let iter = if matches!(index_meta.ty, IndexType::PrimaryKey { .. }) { + table_codec.with_tuple_bound(table_name, open_iter)? + } else { + table_codec.with_index_bound( + table_name, + index_meta.id, + open_iter, + )? + }; self.state = IndexIterState::Range(iter); } Range::Eq(mut val) => { val = self.params.try_cast(val)?; - match self.inner.eq_to_res(&val, &self.params)? { - IndexResult::Tuple(tuple) => { + match self.inner.eq_to_res( + tuple, + &val, + &self.params, + &mut self.encode_min_buffer, + &mut self.encode_max_buffer, + )? 
{ + IndexResult::Hit => { if Self::offset_move(&mut self.offset) { continue; } Self::limit_sub(&mut self.limit); - return Ok(tuple); + return Ok(true); } + IndexResult::Miss => return Ok(false), IndexResult::Scope(iter) => { self.state = IndexIterState::Range(iter); } @@ -1593,24 +1776,35 @@ impl Iter for IndexIter<'_, T> { continue; } Self::limit_sub(&mut self.limit); - let tuple = self.inner.index_lookup(&key, &value, &self.params)?; + self.inner + .index_lookup_into(tuple, key, value, &self.params)?; - return Ok(Some(tuple)); + return Ok(true); } self.state = IndexIterState::Init; } - IndexIterState::Over => return Ok(None), + IndexIterState::Over => return Ok(false), } } } } pub trait InnerIter { - fn try_next(&mut self) -> Result, DatabaseError>; + fn try_next(&mut self) -> Result>, DatabaseError>; } pub trait Iter { - fn next_tuple(&mut self) -> Result, DatabaseError>; + fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result; +} + +#[cfg(test)] +pub(crate) fn next_tuple_for_test(iter: &mut I) -> Result, DatabaseError> { + let mut tuple = Tuple::default(); + if iter.next_tuple_into(&mut tuple)? 
{ + Ok(Some(tuple)) + } else { + Ok(None) + } } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -1626,7 +1820,7 @@ mod test { use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; use crate::storage::table_codec::TableCodec; use crate::storage::{ - IndexIter, InnerIter, Iter, StatisticsMetaCache, Storage, TableCache, Transaction, + IndexIter, InnerIter, StatisticsMetaCache, Storage, TableCache, Transaction, }; use crate::types::index::{Index, IndexMeta, IndexType}; use crate::types::tuple::Tuple; @@ -1827,12 +2021,22 @@ mod test { true, )?; - assert_eq!(tuple_iter.next_tuple()?.unwrap(), tuples[0]); - assert_eq!(tuple_iter.next_tuple()?.unwrap(), tuples[1]); - assert_eq!(tuple_iter.next_tuple()?.unwrap(), tuples[2]); + assert_eq!( + super::next_tuple_for_test(&mut tuple_iter)?.unwrap(), + tuples[0] + ); + assert_eq!( + super::next_tuple_for_test(&mut tuple_iter)?.unwrap(), + tuples[1] + ); + assert_eq!( + super::next_tuple_for_test(&mut tuple_iter)?.unwrap(), + tuples[2] + ); - let (min, max) = table_codec.tuple_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_tuple_bound("t1", |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); @@ -1853,11 +2057,18 @@ mod test { true, )?; - assert_eq!(tuple_iter.next_tuple()?.unwrap(), tuples[0]); - assert_eq!(tuple_iter.next_tuple()?.unwrap(), tuples[2]); + assert_eq!( + super::next_tuple_for_test(&mut tuple_iter)?.unwrap(), + tuples[0] + ); + assert_eq!( + super::next_tuple_for_test(&mut tuple_iter)?.unwrap(), + tuples[2] + ); - let (min, max) = table_codec.tuple_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_tuple_bound("t1", |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); @@ 
-1935,8 +2146,9 @@ mod test { &Arc::new(SharedLruCache::new(4, 1, RandomState::new())?), )?; { - let (min, max) = table_codec.index_meta_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_index_meta_bound("t1", |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); @@ -1976,8 +2188,9 @@ mod test { assert_eq!(i2_meta.name, "i2".to_string()); assert_eq!(i2_meta.ty, IndexType::Composite); - let (min, max) = table_codec.index_meta_bound("t1"); - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_index_meta_bound("t1", |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); @@ -2073,12 +2286,22 @@ mod test { { let mut index_iter = build_index_iter(&transaction, &table_cache, c3_column_id)?; - assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[0]); - assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[2]); - assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[1]); + assert_eq!( + super::next_tuple_for_test(&mut index_iter)?.unwrap(), + tuples[0] + ); + assert_eq!( + super::next_tuple_for_test(&mut index_iter)?.unwrap(), + tuples[2] + ); + assert_eq!( + super::next_tuple_for_test(&mut index_iter)?.unwrap(), + tuples[1] + ); - let (min, max) = table_codec.index_bound("t1", 1)?; - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_index_bound("t1", 1, |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); @@ -2092,11 +2315,18 @@ mod test { let mut index_iter = build_index_iter(&transaction, &table_cache, c3_column_id)?; - assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[2]); - 
assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[1]); + assert_eq!( + super::next_tuple_for_test(&mut index_iter)?.unwrap(), + tuples[2] + ); + assert_eq!( + super::next_tuple_for_test(&mut index_iter)?.unwrap(), + tuples[1] + ); - let (min, max) = table_codec.index_bound("t1", 1)?; - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut iter = table_codec.with_index_bound("t1", 1, |min, max| { + transaction.range(Bound::Included(min), Bound::Included(max)) + })?; let (_, value) = iter.try_next()?.unwrap(); dbg!(value); diff --git a/src/storage/rocksdb.rs b/src/storage/rocksdb.rs index 555ea343..4d062db6 100644 --- a/src/storage/rocksdb.rs +++ b/src/storage/rocksdb.rs @@ -13,11 +13,11 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::storage::table_codec::{BumpBytes, Bytes, TableCodec}; +use crate::storage::table_codec::{Bytes, TableCodec}; use crate::storage::{InnerIter, Storage, Transaction}; use rocksdb::{ statistics::{StatsLevel, Ticker}, - DBIteratorWithThreadMode, Direction, IteratorMode, OptimisticTransactionDB, Options, + DBPinnableSlice, DBRawIteratorWithThreadMode, OptimisticTransactionDB, Options, ReadOptions, SliceTransform, TransactionDB, }; use std::collections::Bound; @@ -374,9 +374,14 @@ pub struct RocksTransaction<'db> { #[macro_export] macro_rules! impl_transaction { ($tx:ident, $iter:ident) => { - impl<'txn> Transaction for $tx<'txn> { + impl<'storage> Transaction for $tx<'storage> { + type BorrowedBytes<'a> + = DBPinnableSlice<'a> + where + Self: 'a; + type IterType<'iter> - = $iter<'txn, 'iter> + = $iter<'storage, 'iter> where Self: 'iter; @@ -386,12 +391,15 @@ macro_rules! impl_transaction { } #[inline] - fn get(&self, key: &[u8]) -> Result, DatabaseError> { - Ok(self.tx.get(key)?) + fn get_borrowed<'a>( + &'a self, + key: &[u8], + ) -> Result>, DatabaseError> { + Ok(self.tx.get_pinned(key)?) 
} #[inline] - fn set(&mut self, key: BumpBytes, value: BumpBytes) -> Result<(), DatabaseError> { + fn set(&mut self, key: &[u8], value: &[u8]) -> Result<(), DatabaseError> { self.tx.put(key, value)?; Ok(()) @@ -406,44 +414,45 @@ macro_rules! impl_transaction { // Tips: rocksdb has weak support for `Include` and `Exclude`, so precision will be lost #[inline] - fn range<'a>( + fn range<'a, 'key>( &'a self, - min: Bound>, - max: Bound>, + min: Bound<&'key [u8]>, + max: Bound<&'key [u8]>, ) -> Result, DatabaseError> { - let min = match min { - Bound::Included(bytes) => Some(bytes), - Bound::Excluded(mut bytes) => { - // the prefix is the same, but the length is larger - bytes.push(0u8); - Some(bytes) - } - Bound::Unbounded => None, - }; - let lower = min - .as_ref() - .map(|bytes| IteratorMode::From(bytes, Direction::Forward)) - .unwrap_or(IteratorMode::Start); - - if let (Some(min_bytes), Bound::Included(max_bytes) | Bound::Excluded(max_bytes)) = - (&min, &max) + let mut read_opts = ReadOptions::default(); + if let ( + Bound::Included(min_bytes) | Bound::Excluded(min_bytes), + Bound::Included(max_bytes) | Bound::Excluded(max_bytes), + ) = (&min, &max) { let len = min_bytes .iter() .zip(max_bytes.iter()) .take_while(|(x, y)| x == y) .count(); - if len >= ROCKSDB_FIXED_PREFIX_LEN { - let mut iter = self.tx.prefix_iterator(&min_bytes[..len]); - iter.set_mode(lower); + read_opts.set_prefix_same_as_start(true); + } + } - return Ok($iter { upper: max, iter }); + let mut iter = self.tx.raw_iterator_opt(read_opts); + match &min { + Bound::Included(bytes) => iter.seek(*bytes), + Bound::Excluded(bytes) => { + iter.seek(*bytes); + if iter.key() == Some(*bytes) { + iter.next(); + } } + Bound::Unbounded => iter.seek_to_first(), } - let iter = self.tx.iterator(lower); - Ok($iter { upper: max, iter }) + Ok($iter { + upper: owned_bound(max), + iter, + advanced: false, + done: false, + }) } fn commit(self) -> Result<(), DatabaseError> { @@ -458,64 +467,102 @@ 
impl_transaction!(RocksTransaction, RocksIter); impl_transaction!(OptimisticRocksTransaction, OptimisticRocksIter); pub struct OptimisticRocksIter<'txn, 'iter> { - upper: Bound>, - iter: DBIteratorWithThreadMode<'iter, rocksdb::Transaction<'txn, OptimisticTransactionDB>>, + upper: Bound, + iter: DBRawIteratorWithThreadMode<'iter, rocksdb::Transaction<'txn, OptimisticTransactionDB>>, + advanced: bool, + done: bool, } impl InnerIter for OptimisticRocksIter<'_, '_> { #[inline] - fn try_next(&mut self) -> Result, DatabaseError> { - if let Some(result) = self.iter.by_ref().next() { - return next(self.upper.as_ref(), result?); - } - Ok(None) + fn try_next(&mut self) -> Result>, DatabaseError> { + next( + &mut self.iter, + &self.upper, + &mut self.advanced, + &mut self.done, + ) } } pub struct RocksIter<'txn, 'iter> { - upper: Bound>, - iter: DBIteratorWithThreadMode< + upper: Bound, + iter: DBRawIteratorWithThreadMode< 'iter, rocksdb::Transaction<'txn, TransactionDB>, >, + advanced: bool, + done: bool, } impl InnerIter for RocksIter<'_, '_> { #[inline] - fn try_next(&mut self) -> Result, DatabaseError> { - if let Some(result) = self.iter.by_ref().next() { - return next(self.upper.as_ref(), result?); - } - Ok(None) + fn try_next(&mut self) -> Result>, DatabaseError> { + next( + &mut self.iter, + &self.upper, + &mut self.advanced, + &mut self.done, + ) } } #[inline] -fn next( - upper: Bound<&BumpBytes<'_>>, - (key, value): (Box<[u8]>, Box<[u8]>), -) -> Result, DatabaseError> { +fn next<'a, D: rocksdb::DBAccess>( + iter: &'a mut DBRawIteratorWithThreadMode<'_, D>, + upper: &Bound, + advanced: &mut bool, + done: &mut bool, +) -> Result>, DatabaseError> { + if *done { + return Ok(None); + } + if *advanced { + iter.next(); + } + if !iter.valid() { + *done = true; + iter.status()?; + return Ok(None); + } + + let Some((key, value)) = iter.item() else { + *done = true; + iter.status()?; + return Ok(None); + }; let upper_bound_check = match upper { - Bound::Included(upper) => 
key.as_ref() <= upper.as_slice(), - Bound::Excluded(upper) => key.as_ref() < upper.as_slice(), + Bound::Included(upper) => key <= upper.as_slice(), + Bound::Excluded(upper) => key < upper.as_slice(), Bound::Unbounded => true, }; if !upper_bound_check { + *done = true; return Ok(None); } - Ok(Some((Vec::from(key), Vec::from(value)))) + + *advanced = true; + Ok(Some((key, value))) +} + +fn owned_bound(bound: Bound<&[u8]>) -> Bound { + match bound { + Bound::Included(bytes) => Bound::Included(bytes.to_vec()), + Bound::Excluded(bytes) => Bound::Excluded(bytes.to_vec()), + Bound::Unbounded => Bound::Unbounded, + } } #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::{DataBaseBuilder, ResultIter}; + use crate::db::DataBaseBuilder; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::storage::rocksdb::RocksStorage; use crate::storage::{ - IndexImplEnum, IndexImplParams, IndexIter, IndexIterState, Iter, PrimaryKeyIndexImpl, - Storage, Transaction, + IndexImplEnum, IndexImplParams, IndexIter, IndexIterState, PrimaryKeyIndexImpl, Storage, + Transaction, }; use crate::types::index::{IndexMeta, IndexType}; use crate::types::tuple::Tuple; @@ -542,7 +589,7 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let kite_sql = DataBaseBuilder::path(temp_dir.path()) .storage_statistics(true) - .build()?; + .build_rocksdb()?; kite_sql .run("create table t_metrics (a int primary key, b int)")? 
.done()?; @@ -629,10 +676,10 @@ mod test { true, )?; - let option_1 = iter.next_tuple()?; + let option_1 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_1.unwrap().pk, Some(DataValue::Int32(2))); - let option_2 = iter.next_tuple()?; + let option_2 = crate::storage::next_tuple_for_test(&mut iter)?; assert_eq!(option_2, None); Ok(()) @@ -641,7 +688,7 @@ mod test { #[test] fn test_index_iter_pk() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key)")? @@ -667,7 +714,6 @@ mod test { .columns() .map(|column| column.datatype().serializable()) .collect_vec(); - let values_len = deserializers.len(); let mut iter = IndexIter { offset: 0, limit: None, @@ -683,7 +729,6 @@ mod test { }), table_name: &table.name, deserializers, - values_len, total_len: table.columns_len(), tx: &transaction, cover_mapping: None, @@ -699,10 +744,12 @@ mod test { .into_iter(), state: IndexIterState::Init, inner: IndexImplEnum::PrimaryKey(PrimaryKeyIndexImpl), + encode_min_buffer: Vec::new(), + encode_max_buffer: Vec::new(), }; let mut result = Vec::new(); - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? { result.push(tuple.pk.unwrap()); } @@ -714,7 +761,7 @@ mod test { #[test] fn test_read_by_index() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int unique)")? .done()?; @@ -745,7 +792,7 @@ mod test { ) .unwrap(); - while let Some(tuple) = iter.next_tuple()? 
{ + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? { assert_eq!(tuple.pk, Some(DataValue::Int32(1))); assert_eq!(tuple.values, vec![DataValue::Int32(1), DataValue::Int32(1)]) } @@ -772,7 +819,7 @@ mod test { ) .unwrap(); - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? { assert_eq!(tuple.pk, Some(DataValue::Int32(3))); assert_eq!(tuple.values, vec![DataValue::Int32(3)]) } @@ -784,7 +831,7 @@ mod test { #[test] fn test_read_by_index_cover() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; kite_sql .run("create table t1 (a int primary key, b int unique)")? .done()?; @@ -868,7 +915,7 @@ mod test { Some(reordered_deserializers), Some(cover_mapping), )?; - let first_tuple = iter.next_tuple()?.unwrap(); + let first_tuple = crate::storage::next_tuple_for_test(&mut iter)?.unwrap(); assert_eq!( first_tuple.values, vec![DataValue::Int32(0), DataValue::Int32(0)] @@ -892,7 +939,7 @@ mod test { )?; let mut tuples = Vec::new(); - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? { tuples.push(tuple); } @@ -925,7 +972,7 @@ mod test { Some(vec![0]), )?; let mut row_count = 0; - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = crate::storage::next_tuple_for_test(&mut iter)? 
{ assert_eq!(tuple.values.len(), 1); row_count += 1; } diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 32efe26e..78390e1b 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -27,8 +27,9 @@ use crate::types::value::{DataValue, TupleMappingRef}; use crate::types::LogicalType; use bumpalo::Bump; use siphasher::sip::SipHasher; +use std::cell::RefCell; use std::hash::{Hash, Hasher}; -use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::io::{Cursor, Read, Seek, SeekFrom}; use std::sync::LazyLock; pub(crate) const BOUND_MIN_TAG: u8 = u8::MIN; @@ -53,6 +54,15 @@ pub type BumpBytes<'bump> = bumpalo::collections::Vec<'bump, u8>; #[derive(Default)] pub struct TableCodec { arena: Bump, + key_buffer: RefCell, +} + +#[derive(Default)] +struct KeyBuffer { + lower: Bytes, + upper: Bytes, + cached_table_name: String, + cached_table_hash: [u8; TABLE_NAME_HASH_LEN], } #[derive(Copy, Clone)] @@ -99,6 +109,11 @@ impl StatisticsCodecType { } impl TableCodec { + #[inline] + pub fn arena(&self) -> &Bump { + &self.arena + } + fn hash_bytes(table_name: &str) -> [u8; 8] { let mut hasher = SipHasher::new(); table_name.hash(&mut hasher); @@ -149,302 +164,523 @@ impl TableCodec { Ok(()) } - /// TableName + Type - /// - /// Tips: - /// 1. Root & View & Hash full key = key_prefix - /// 2. 
table-name hash is fixed-width (8 bytes), so [prefix_extractor](https://github.com/facebook/rocksdb/wiki/Prefix-Seek#defining-a-prefix) can be enabled in rocksdb - fn key_prefix(&self, ty: CodecType, name: &str) -> BumpBytes<'_> { - let mut table_bytes = BumpBytes::new_in(&self.arena); - table_bytes.extend_from_slice(Self::hash_bytes(name).as_slice()); - + #[inline] + fn write_key_prefix(out: &mut Bytes, ty: CodecType, table_hash: [u8; TABLE_NAME_HASH_LEN]) { + out.clear(); match ty { CodecType::Column => { - table_bytes.push(b'0'); + out.extend_from_slice(&table_hash); + out.push(b'0'); } CodecType::IndexMeta => { - table_bytes.push(b'1'); + out.extend_from_slice(&table_hash); + out.push(b'1'); } CodecType::Statistics => { - table_bytes.push(b'3'); + out.extend_from_slice(&table_hash); + out.push(b'3'); } CodecType::Index => { - table_bytes.push(b'7'); + out.extend_from_slice(&table_hash); + out.push(b'7'); } CodecType::Tuple => { - table_bytes.push(b'8'); + out.extend_from_slice(&table_hash); + out.push(b'8'); } CodecType::Root => { - let mut bytes = BumpBytes::new_in(&self.arena); - - bytes.extend_from_slice(&ROOT_BYTES); - bytes.push(BOUND_MIN_TAG); - bytes.extend_from_slice(&table_bytes); - - return bytes; + out.extend_from_slice(ROOT_BYTES.as_slice()); + out.push(BOUND_MIN_TAG); + out.extend_from_slice(&table_hash); } CodecType::View => { - let mut bytes = BumpBytes::new_in(&self.arena); - - bytes.extend_from_slice(&VIEW_BYTES); - bytes.push(BOUND_MIN_TAG); - bytes.extend_from_slice(&table_bytes); - - return bytes; + out.extend_from_slice(VIEW_BYTES.as_slice()); + out.push(BOUND_MIN_TAG); + out.extend_from_slice(&table_hash); } CodecType::Hash => { - let mut bytes = BumpBytes::new_in(&self.arena); - - bytes.extend_from_slice(&HASH_BYTES); - bytes.push(BOUND_MIN_TAG); - bytes.append(&mut table_bytes); - bytes.extend_from_slice(&table_bytes); - - return bytes; + out.extend_from_slice(HASH_BYTES.as_slice()); + out.push(BOUND_MIN_TAG); + 
out.extend_from_slice(&table_hash); } } - - table_bytes } - pub fn tuple_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = self.key_prefix(CodecType::Tuple, table_name); + #[inline] + fn write_global_bound_prefix(out: &mut Bytes, prefix: &[u8], bound: u8) { + out.clear(); + out.extend_from_slice(prefix); + out.push(bound); + } - key_prefix.push(bound_id); - key_prefix + fn with_table_hash( + &self, + table_name: &str, + f: impl FnOnce(&mut KeyBuffer, [u8; TABLE_NAME_HASH_LEN]) -> Result, + ) -> Result { + let mut key_buffer = self.key_buffer.borrow_mut(); + let table_hash = if key_buffer.cached_table_name != table_name { + key_buffer.cached_table_name.clear(); + key_buffer.cached_table_name.push_str(table_name); + key_buffer.cached_table_hash = Self::hash_bytes(table_name); + key_buffer.cached_table_hash + } else { + key_buffer.cached_table_hash }; - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(&mut key_buffer, table_hash) } - pub fn index_meta_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = self.key_prefix(CodecType::IndexMeta, table_name); + /// Key: `{TableName}{TUPLE_TAG}{BOUND_MIN_TAG}{RowID}`. 
+ pub fn with_tuple_key( + &self, + table_name: &str, + tuple_id: &TupleId, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + Self::check_primary_key(tuple_id, 0)?; + self.with_tuple_key_unchecked(table_name, tuple_id, f) + } - key_prefix.push(bound_id); - key_prefix - }; + #[inline] + pub(crate) fn with_tuple_key_unchecked( + &self, + table_name: &str, + tuple_id: &TupleId, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Tuple, table_hash); + lower.push(BOUND_MIN_TAG); + tuple_id.memcomparable_encode(lower)?; - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice()) + }) } - pub fn index_bound( + /// Range bounds covering all tuple keys for a table. + pub fn with_tuple_bound( &self, table_name: &str, - index_id: IndexId, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let op = |bound_id| -> Result { - let mut key_prefix = self.key_prefix(CodecType::Index, table_name); - - key_prefix.write_all(&[BOUND_MIN_TAG])?; - key_prefix.write_all(&index_id.to_le_bytes()[..])?; - key_prefix.write_all(&[bound_id])?; - Ok(key_prefix) - }; + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Tuple, table_hash); + lower.push(BOUND_MIN_TAG); - Ok((op(BOUND_MIN_TAG)?, op(BOUND_MAX_TAG)?)) - } + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::Tuple, table_hash); + upper.push(BOUND_MAX_TAG); - pub fn all_index_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = self.key_prefix(CodecType::Index, table_name); + f(lower.as_slice(), upper.as_slice()) + }) + } - key_prefix.push(bound_id); - key_prefix - }; + /// Key: `{TableName}{INDEX_META_TAG}{BOUND_MIN_TAG}{IndexID}`. 
+ pub fn with_index_meta_key( + &self, + table_name: &str, + index_id: IndexId, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::IndexMeta, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice()) + }) } - pub fn root_table_bound(&self) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = BumpBytes::new_in(&self.arena); + /// Range bounds covering all index metadata for a table. + pub fn with_index_meta_bound( + &self, + table_name: &str, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::IndexMeta, table_hash); + lower.push(BOUND_MIN_TAG); - key_prefix.extend_from_slice(&ROOT_BYTES); - key_prefix.push(bound_id); - key_prefix - }; + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::IndexMeta, table_hash); + upper.push(BOUND_MAX_TAG); - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice(), upper.as_slice()) + }) } - pub fn table_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let mut column_prefix = self.key_prefix(CodecType::Column, table_name); - column_prefix.push(BOUND_MIN_TAG); + /// Range bounds covering a single secondary index. 
+ pub fn with_index_bound( + &self, + table_name: &str, + index_id: IndexId, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Index, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); + lower.push(BOUND_MIN_TAG); + + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::Index, table_hash); + upper.push(BOUND_MIN_TAG); + upper.extend_from_slice(&index_id.to_le_bytes()); + upper.push(BOUND_MAX_TAG); + + f(lower.as_slice(), upper.as_slice()) + }) + } + + /// Range bounds covering all secondary indexes for a table. + pub fn with_all_index_bound( + &self, + table_name: &str, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Index, table_hash); + lower.push(BOUND_MIN_TAG); - let mut index_prefix = self.key_prefix(CodecType::IndexMeta, table_name); - index_prefix.push(BOUND_MAX_TAG); + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::Index, table_hash); + upper.push(BOUND_MAX_TAG); - (column_prefix, index_prefix) + f(lower.as_slice(), upper.as_slice()) + }) } - pub fn columns_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = self.key_prefix(CodecType::Column, table_name); + /// Non-unique index key: + /// `{TableName}{INDEX_TAG}{BOUND_MIN_TAG}{IndexID}{BOUND_MIN_TAG}{DataValue...}{TupleId}` + /// + /// Unique index key: + /// `{TableName}{INDEX_TAG}{BOUND_MIN_TAG}{IndexID}{BOUND_MIN_TAG}{DataValue}` + pub fn with_index_key( + &self, + table_name: &str, + index: &Index, + tuple_id: Option<&TupleId>, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let 
lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Index, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index.id.to_le_bytes()); + lower.push(BOUND_MIN_TAG); + index.value.memcomparable_encode(lower)?; + + if let Some(tuple_id) = tuple_id { + if matches!(index.ty, IndexType::Normal | IndexType::Composite) { + tuple_id.memcomparable_encode(lower)?; + } + } + + f(lower.as_slice()) + }) + } - key_prefix.push(bound_id); - key_prefix - }; + /// Key: `{TableName}{COLUMN_TAG}{BOUND_MIN_TAG}{ColumnId}`. + pub fn with_column_key( + &self, + col: &ColumnRef, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + if let ColumnRelation::Table { + column_id, + table_name, + is_temp: false, + } = &col.summary().relation + { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Column, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&column_id.to_bytes()); - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice()) + }) + } else { + Err(DatabaseError::invalid_column( + "column does not belong to table".to_string(), + )) + } } - pub fn statistics_bound(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = self.key_prefix(CodecType::Statistics, table_name); + /// Range bounds covering all column metadata for a table. 
+ pub fn with_columns_bound( + &self, + table_name: &str, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Column, table_hash); + lower.push(BOUND_MIN_TAG); - key_prefix.push(bound_id); - key_prefix - }; + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::Column, table_hash); + upper.push(BOUND_MAX_TAG); - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice(), upper.as_slice()) + }) } - pub fn view_bound(&self) -> (BumpBytes<'_>, BumpBytes<'_>) { - let op = |bound_id| { - let mut key_prefix = BumpBytes::new_in(&self.arena); + /// Range bounds spanning a table's `Column` and `IndexMeta` metadata. + pub fn with_table_bound( + &self, + table_name: &str, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Column, table_hash); + lower.push(BOUND_MIN_TAG); - key_prefix.extend_from_slice(&VIEW_BYTES); - key_prefix.push(bound_id); - key_prefix - }; + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::IndexMeta, table_hash); + upper.push(BOUND_MAX_TAG); - (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + f(lower.as_slice(), upper.as_slice()) + }) } - /// Key: {TableName}{TUPLE_TAG}{BOUND_MIN_TAG}{RowID}(Sorted) - /// Value: Tuple - pub fn encode_tuple( + /// Range bounds covering all statistics keys for a table. 
+ pub fn with_statistics_bound( &self, table_name: &str, - tuple: &mut Tuple, - types: &[TupleValueSerializableImpl], - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; - let key = self.encode_tuple_key(table_name, tuple_id)?; + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + + let upper = &mut key_buffer.upper; + Self::write_key_prefix(upper, CodecType::Statistics, table_hash); + upper.push(BOUND_MAX_TAG); - Ok((key, tuple.serialize_to(types, &self.arena)?)) + f(lower.as_slice(), upper.as_slice()) + }) } - pub fn encode_tuple_key( + /// Range bounds covering all statistics keys for one index. + pub fn with_statistics_index_bound( &self, table_name: &str, - tuple_id: &TupleId, - ) -> Result, DatabaseError> { - Self::check_primary_key(tuple_id, 0)?; - - let mut key_prefix = self.key_prefix(CodecType::Tuple, table_name); - key_prefix.push(BOUND_MIN_TAG); + index_id: IndexId, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); - tuple_id.memcomparable_encode(&mut key_prefix)?; + let upper = &mut key_buffer.upper; + upper.clear(); + upper.extend_from_slice(lower); + upper.push(BOUND_MAX_TAG); - Ok(key_prefix) + f(lower.as_slice(), upper.as_slice()) + }) } - pub fn decode_tuple_key(bytes: &[u8], pk_ty: &LogicalType) -> Result { - DataValue::memcomparable_decode(&mut Cursor::new(&bytes[TUPLE_KEY_PREFIX_LEN..]), pk_ty) - } + /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{ROOT_TAG}`. 
+ pub fn with_statistics_meta_key( + &self, + table_name: &str, + index_id: IndexId, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); + lower.push(StatisticsCodecType::Root.tag()); - #[inline] - pub fn decode_tuple( - deserializers: &[TupleValueSerializableImpl], - tuple_id: Option, - bytes: &[u8], - values_len: usize, - total_len: usize, - ) -> Result { - Tuple::deserialize_from(deserializers, tuple_id, bytes, values_len, total_len) + f(lower.as_slice()) + }) } - pub fn encode_index_meta_key( + /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_META_TAG}`. + pub fn with_statistics_sketch_meta_key( &self, table_name: &str, index_id: IndexId, - ) -> Result, DatabaseError> { - let mut key_prefix = self.key_prefix(CodecType::IndexMeta, table_name); + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); + lower.push(StatisticsCodecType::SketchMeta.tag()); - key_prefix.write_all(&[BOUND_MIN_TAG])?; - key_prefix.write_all(&index_id.to_le_bytes()[..])?; - Ok(key_prefix) + f(lower.as_slice()) + }) } - /// Key: {TableName}{INDEX_META_TAG}{BOUND_MIN_TAG}{IndexID} - /// Value: IndexMeta - pub fn encode_index_meta( + /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_PAGE_TAG}{BOUND_MIN_TAG}{ROW_ID}{BOUND_MIN_TAG}{PAGE_ID}`. 
+ pub fn with_statistics_sketch_page_key( &self, table_name: &str, - index_meta: &IndexMeta, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key_bytes = self.encode_index_meta_key(table_name, index_meta.id)?; - - let mut value_bytes = BumpBytes::new_in(&self.arena); - index_meta.encode(&mut value_bytes, true, &mut ReferenceTables::new())?; + index_id: IndexId, + sketch_page: &CountMinSketchPage, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); + lower.push(StatisticsCodecType::SketchPage.tag()); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&(sketch_page.row_idx() as u32).to_be_bytes()); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&(sketch_page.page_idx() as u32).to_be_bytes()); + + f(lower.as_slice()) + }) + } + + /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{BUCKET_TAG}{BOUND_MIN_TAG}{ORDINAL}`. + pub fn with_statistics_bucket_key( + &self, + table_name: &str, + index_id: IndexId, + ordinal: u32, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Statistics, table_hash); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&index_id.to_le_bytes()); + lower.push(StatisticsCodecType::Bucket.tag()); + lower.push(BOUND_MIN_TAG); + lower.extend_from_slice(&ordinal.to_be_bytes()); + + f(lower.as_slice()) + }) + } + + /// Key: `View{BOUND_MIN_TAG}{ViewNameHash}`. 
+ pub fn with_view_key( + &self, + view_name: &str, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(view_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::View, table_hash); - Ok((key_bytes, value_bytes)) + f(lower.as_slice()) + }) } - pub fn decode_index_meta(bytes: &[u8]) -> Result { - IndexMeta::decode::(&mut Cursor::new(bytes), None, &EMPTY_REFERENCE_TABLES) + /// Range bounds covering all view definitions. + pub fn with_view_bound( + &self, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + let mut key_buffer = self.key_buffer.borrow_mut(); + Self::write_global_bound_prefix( + &mut key_buffer.lower, + VIEW_BYTES.as_slice(), + BOUND_MIN_TAG, + ); + Self::write_global_bound_prefix( + &mut key_buffer.upper, + VIEW_BYTES.as_slice(), + BOUND_MAX_TAG, + ); + + f(key_buffer.lower.as_slice(), key_buffer.upper.as_slice()) } - /// NonUnique Index: - /// Key: {TableName}{INDEX_TAG}{BOUND_MIN_TAG}{IndexID}{BOUND_MIN_TAG}{DataValue1}{BOUND_MIN_TAG}{DataValue2} .. {TupleId} - /// Value: TupleID - /// - /// Unique Index: - /// Key: {TableName}{INDEX_TAG}{BOUND_MIN_TAG}{IndexID}{BOUND_MIN_TAG}{DataValue} - /// Value: TupleID - /// - /// Tips: The unique index has only one ColumnID and one corresponding DataValue, - /// so it can be positioned directly. - pub fn encode_index( + /// Key: `Root{BOUND_MIN_TAG}{TableNameHash}`. 
+ pub fn with_root_table_key( &self, - name: &str, - index: &Index, - tuple_id: &TupleId, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_index_key(name, index, Some(tuple_id))?; - let mut bytes = BumpBytes::new_in(&self.arena); + table_name: &str, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Root, table_hash); + + f(lower.as_slice()) + }) + } - bincode::serialize_into(&mut bytes, tuple_id)?; + /// Range bounds covering all root table metadata. + pub fn with_root_table_bound( + &self, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + let mut key_buffer = self.key_buffer.borrow_mut(); + Self::write_global_bound_prefix( + &mut key_buffer.lower, + ROOT_BYTES.as_slice(), + BOUND_MIN_TAG, + ); + Self::write_global_bound_prefix( + &mut key_buffer.upper, + ROOT_BYTES.as_slice(), + BOUND_MAX_TAG, + ); - Ok((key, bytes)) + f(key_buffer.lower.as_slice(), key_buffer.upper.as_slice()) } - pub fn encode_index_bound_key( + /// Key: `Hash{BOUND_MIN_TAG}{TableNameHash}`. 
+ pub fn with_table_hash_key( &self, - name: &str, - index: &Index, - is_upper: bool, - ) -> Result, DatabaseError> { - let mut key_prefix = self.key_prefix(CodecType::Index, name); - key_prefix.push(BOUND_MIN_TAG); - key_prefix.extend_from_slice(&index.id.to_le_bytes()); - key_prefix.push(BOUND_MIN_TAG); - - index.value.memcomparable_encode(&mut key_prefix)?; - if is_upper { - key_prefix.push(BOUND_MAX_TAG) - } + table_name: &str, + f: impl FnOnce(&[u8]) -> Result, + ) -> Result { + self.with_table_hash(table_name, |key_buffer, table_hash| { + let lower = &mut key_buffer.lower; + Self::write_key_prefix(lower, CodecType::Hash, table_hash); - Ok(key_prefix) + f(lower.as_slice()) + }) + } + pub fn decode_tuple_key(bytes: &[u8], pk_ty: &LogicalType) -> Result { + DataValue::memcomparable_decode(&mut Cursor::new(&bytes[TUPLE_KEY_PREFIX_LEN..]), pk_ty) + } + + #[inline] + pub fn decode_tuple_into( + tuple: &mut Tuple, + deserializers: &[TupleValueSerializableImpl], + tuple_id: Option, + bytes: &[u8], + total_len: usize, + ) -> Result<(), DatabaseError> { + tuple.pk = tuple_id; + tuple.deserialize_from_into(deserializers, bytes, total_len) } - pub fn encode_index_key( + pub fn encode_index_meta_value( &self, - name: &str, - index: &Index, - tuple_id: Option<&TupleId>, + index_meta: &IndexMeta, ) -> Result, DatabaseError> { - let mut key_prefix = self.encode_index_bound_key(name, index, false)?; + let mut value_bytes = BumpBytes::new_in(&self.arena); + index_meta.encode(&mut value_bytes, true, &mut ReferenceTables::new())?; + Ok(value_bytes) + } - if let Some(tuple_id) = tuple_id { - if matches!(index.ty, IndexType::Normal | IndexType::Composite) { - tuple_id.memcomparable_encode(&mut key_prefix)?; - } - } - Ok(key_prefix) + pub fn decode_index_meta(bytes: &[u8]) -> Result { + IndexMeta::decode::(&mut Cursor::new(bytes), None, &EMPTY_REFERENCE_TABLES) } pub fn decode_index_key( @@ -461,35 +697,14 @@ impl TableCodec { Ok(bincode::deserialize_from(&mut Cursor::new(bytes))?) 
} - /// Key: {TableName}{COLUMN_TAG}{BOUND_MIN_TAG}{ColumnId} - /// Value: ColumnCatalog - /// - /// Tips: the `0` for bound range - pub fn encode_column( + pub fn encode_column_value( &self, col: &ColumnRef, reference_tables: &mut ReferenceTables, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - if let ColumnRelation::Table { - column_id, - table_name, - is_temp: false, - } = &col.summary().relation - { - let mut key_prefix = self.key_prefix(CodecType::Column, table_name); - - key_prefix.write_all(&[BOUND_MIN_TAG])?; - key_prefix.write_all(&column_id.to_bytes()[..])?; - - let mut column_bytes = BumpBytes::new_in(&self.arena); - col.encode(&mut column_bytes, true, reference_tables)?; - - Ok((key_prefix, column_bytes)) - } else { - Err(DatabaseError::invalid_column( - "column does not belong to table".to_string(), - )) - } + ) -> Result, DatabaseError> { + let mut column_bytes = BumpBytes::new_in(&self.arena); + col.encode(&mut column_bytes, true, reference_tables)?; + Ok(column_bytes) } pub fn decode_column( @@ -500,42 +715,13 @@ impl TableCodec { ColumnRef::decode::(reader, None, reference_tables) } - /// Key: {TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID} - /// Value: StatisticsMetaRoot - pub fn encode_statistics_meta( + pub fn encode_statistics_meta_value( &self, - table_name: &str, statistics_meta: &StatisticsMetaRoot, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_statistics_meta_key(table_name, statistics_meta.index_id()); + ) -> Result, DatabaseError> { let mut value = BumpBytes::new_in(&self.arena); - statistics_meta.encode(&mut value, true, &mut ReferenceTables::new())?; - - Ok((key, value)) - } - - fn encode_statistics_index_prefix(&self, table_name: &str, index_id: IndexId) -> BumpBytes<'_> { - let mut key = self.key_prefix(CodecType::Statistics, table_name); - - key.push(BOUND_MIN_TAG); - key.extend(index_id.to_le_bytes()); - key - } - - fn encode_statistics_key_prefix( - &self, - table_name: 
&str, - index_id: IndexId, - ty: StatisticsCodecType, - ) -> BumpBytes<'_> { - let mut key = self.encode_statistics_index_prefix(table_name, index_id); - key.push(ty.tag()); - key - } - - pub fn encode_statistics_meta_key(&self, table_name: &str, index_id: IndexId) -> BumpBytes<'_> { - self.encode_statistics_key_prefix(table_name, index_id, StatisticsCodecType::Root) + Ok(value) } pub fn decode_statistics_meta( @@ -544,28 +730,13 @@ impl TableCodec { StatisticsMetaRoot::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) } - /// Key: {TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_META_TAG} - /// Value: CountMinSketchMeta - pub fn encode_statistics_sketch_meta( + pub fn encode_statistics_sketch_meta_value( &self, - table_name: &str, - index_id: IndexId, sketch_meta: &CountMinSketchMeta, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_statistics_sketch_meta_key(table_name, index_id); + ) -> Result, DatabaseError> { let mut value = BumpBytes::new_in(&self.arena); - sketch_meta.encode(&mut value, true, &mut ReferenceTables::new())?; - - Ok((key, value)) - } - - pub fn encode_statistics_sketch_meta_key( - &self, - table_name: &str, - index_id: IndexId, - ) -> BumpBytes<'_> { - self.encode_statistics_key_prefix(table_name, index_id, StatisticsCodecType::SketchMeta) + Ok(value) } pub fn decode_statistics_sketch_meta( @@ -574,40 +745,13 @@ impl TableCodec { CountMinSketchMeta::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) } - /// Key: {TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_PAGE_TAG}{BOUND_MIN_TAG}{ROW_ID}{BOUND_MIN_TAG}{PAGE_ID} - /// Value: CountMinSketchPage - pub fn encode_statistics_sketch_page( + pub fn encode_statistics_sketch_page_value( &self, - table_name: &str, - index_id: IndexId, sketch_page: &CountMinSketchPage, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_statistics_sketch_page_key(table_name, index_id, 
sketch_page)?; + ) -> Result, DatabaseError> { let mut value = BumpBytes::new_in(&self.arena); - sketch_page.encode(&mut value, true, &mut ReferenceTables::new())?; - - Ok((key, value)) - } - - pub fn encode_statistics_sketch_page_key( - &self, - table_name: &str, - index_id: IndexId, - sketch_page: &CountMinSketchPage, - ) -> Result, DatabaseError> { - let mut key = self.encode_statistics_key_prefix( - table_name, - index_id, - StatisticsCodecType::SketchPage, - ); - - key.write_all(&[BOUND_MIN_TAG])?; - key.write_all(&(sketch_page.row_idx() as u32).to_be_bytes())?; - key.write_all(&[BOUND_MIN_TAG])?; - key.write_all(&(sketch_page.page_idx() as u32).to_be_bytes())?; - - Ok(key) + Ok(value) } pub fn decode_statistics_sketch_page( @@ -616,44 +760,20 @@ impl TableCodec { CountMinSketchPage::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) } - pub fn encode_statistics_bucket( + pub fn encode_statistics_bucket_value( &self, - table_name: &str, - index_id: IndexId, - ordinal: u32, bucket: &Bucket, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_statistics_bucket_key(table_name, index_id, ordinal); + ) -> Result, DatabaseError> { let mut value = BumpBytes::new_in(&self.arena); - bucket.encode(&mut value, true, &mut ReferenceTables::new())?; - - Ok((key, value)) - } - - pub fn encode_statistics_bucket_key( - &self, - table_name: &str, - index_id: IndexId, - ordinal: u32, - ) -> BumpBytes<'_> { - let mut key = - self.encode_statistics_key_prefix(table_name, index_id, StatisticsCodecType::Bucket); - - key.push(BOUND_MIN_TAG); - key.extend_from_slice(&ordinal.to_be_bytes()); - key + Ok(value) } pub(crate) fn decode_statistics_codec_type( &self, - table_name: &str, - index_id: IndexId, key: &[u8], ) -> Result { - let prefix_len = self - .encode_statistics_index_prefix(table_name, index_id) - .len(); + let prefix_len = TUPLE_KEY_PREFIX_LEN + INDEX_ID_LEN; let Some(tag) = key.get(prefix_len).copied() else { return 
Err(DatabaseError::InvalidValue( "statistics key is too short".to_string(), @@ -678,26 +798,7 @@ impl TableCodec { Ok(u32::from_be_bytes(ordinal_bytes.try_into().unwrap())) } - pub fn statistics_index_bound( - &self, - table_name: &str, - index_id: IndexId, - ) -> (BumpBytes<'_>, BumpBytes<'_>) { - let min = self.encode_statistics_index_prefix(table_name, index_id); - let mut max = min.clone(); - max.push(BOUND_MAX_TAG); - - (min, max) - } - - /// Key: View{BOUND_MIN_TAG}{ViewName} - /// Value: View - pub fn encode_view( - &self, - view: &View, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_view_key(&view.name); - + pub fn encode_view_value(&self, view: &View) -> Result, DatabaseError> { let mut reference_tables = ReferenceTables::new(); let mut bytes = BumpBytes::new_in(&self.arena); bytes.resize(4, 0u8); @@ -710,11 +811,7 @@ impl TableCodec { }; bytes[..4].copy_from_slice(&(reference_tables_pos as u32).to_le_bytes()); - Ok((key, bytes)) - } - - pub fn encode_view_key(&self, view_name: &str) -> BumpBytes<'_> { - self.key_prefix(CodecType::View, view_name) + Ok(bytes) } pub fn decode_view( @@ -734,21 +831,13 @@ impl TableCodec { View::decode(&mut cursor, Some(drive), &reference_tables) } - /// Key: Root{BOUND_MIN_TAG}{TableName} - /// Value: TableMeta - pub fn encode_root_table( + pub fn encode_root_table_value( &self, meta: &TableMeta, - ) -> Result<(BumpBytes<'_>, BumpBytes<'_>), DatabaseError> { - let key = self.encode_root_table_key(&meta.table_name); - + ) -> Result, DatabaseError> { let mut meta_bytes = BumpBytes::new_in(&self.arena); meta.encode(&mut meta_bytes, true, &mut ReferenceTables::new())?; - Ok((key, meta_bytes)) - } - - pub fn encode_root_table_key(&self, table_name: &str) -> BumpBytes<'_> { - self.key_prefix(CodecType::Root, table_name) + Ok(meta_bytes) } pub fn decode_root_table(bytes: &[u8]) -> Result { @@ -756,17 +845,6 @@ impl TableCodec { TableMeta::decode::(&mut bytes, None, &EMPTY_REFERENCE_TABLES) } 
- - pub fn encode_table_hash_key(&self, table_name: &str) -> BumpBytes<'_> { - self.key_prefix(CodecType::Hash, table_name) - } - - pub fn encode_table_hash(&self, table_name: &str) -> (BumpBytes<'_>, BumpBytes<'_>) { - ( - self.key_prefix(CodecType::Hash, table_name), - BumpBytes::new_in(&self.arena), - ) - } } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -781,7 +859,7 @@ mod tests { use crate::optimizer::core::statistics_meta::StatisticsMeta; use crate::serdes::ReferenceTables; use crate::storage::rocksdb::RocksTransaction; - use crate::storage::table_codec::{BumpBytes, TableCodec}; + use crate::storage::table_codec::{Bytes, TableCodec}; use crate::storage::Storage; use crate::types::index::{Index, IndexMeta, IndexType}; use crate::types::tuple::Tuple; @@ -813,32 +891,33 @@ mod tests { #[test] fn test_table_codec_tuple() -> Result<(), DatabaseError> { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let table_catalog = build_table_codec(); - let mut tuple = Tuple::new( + let expected = Tuple::new( Some(DataValue::Int32(0)), vec![DataValue::Int32(0), DataValue::Decimal(Decimal::new(1, 0))], ); - let (_, bytes) = table_codec.encode_tuple( - &table_catalog.name, - &mut tuple, + let bytes = expected.serialize_to( &[ LogicalType::Integer.serializable(), LogicalType::Decimal(None, None).serializable(), ], + table_codec.arena(), )?; let deserializers = table_catalog .columns() .map(|column| column.datatype().serializable()) .collect_vec(); - tuple.pk = None; + let mut tuple = Tuple::default(); + TableCodec::decode_tuple_into(&mut tuple, &deserializers, None, &bytes, 2)?; assert_eq!( - TableCodec::decode_tuple(&deserializers, None, &bytes, deserializers.len(), 2,)?, - tuple + tuple, + Tuple::new( + None, + vec![DataValue::Int32(0), DataValue::Decimal(Decimal::new(1, 0))] + ) ); Ok(()) @@ -846,12 +925,10 @@ mod tests { #[test] fn test_root_catalog() { - let table_codec = TableCodec { - arena: 
Default::default(), - }; + let table_codec = TableCodec::default(); let table_catalog = build_table_codec(); - let (_, bytes) = table_codec - .encode_root_table(&TableMeta { + let bytes = table_codec + .encode_root_table_value(&TableMeta { table_name: table_catalog.name.clone(), }) .unwrap(); @@ -863,9 +940,7 @@ mod tests { #[test] fn test_table_codec_statistics_meta() -> Result<(), DatabaseError> { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let index_meta = IndexMeta { id: 0, column_ids: vec![Ulid::new()], @@ -881,9 +956,9 @@ mod tests { builder.append(&DataValue::Int32(value))?; } let (histogram, sketch) = builder.build(2)?; - let (root, buckets) = StatisticsMeta::new(histogram).into_parts(); + let (root, buckets, _) = StatisticsMeta::new(histogram, sketch.clone()).into_parts(); - let (_, root_bytes) = table_codec.encode_statistics_meta("t1", &root)?; + let root_bytes = table_codec.encode_statistics_meta_value(&root)?; let decoded_root = TableCodec::decode_statistics_meta::(&root_bytes)?; assert_eq!(decoded_root.index_id(), root.index_id()); assert_eq!( @@ -896,29 +971,33 @@ mod tests { ); let (sketch_meta, mut sketch_pages) = sketch.clone().into_storage_parts(1); - let (_, sketch_meta_bytes) = - table_codec.encode_statistics_sketch_meta("t1", 0, &sketch_meta)?; + let sketch_meta_bytes = table_codec.encode_statistics_sketch_meta_value(&sketch_meta)?; let decoded_sketch_meta = TableCodec::decode_statistics_sketch_meta::(&sketch_meta_bytes)?; assert_eq!(decoded_sketch_meta.width(), sketch_meta.width()); assert_eq!(decoded_sketch_meta.k_num(), sketch_meta.k_num()); let first_sketch_page = sketch_pages.next().unwrap(); - let (_, sketch_page_bytes) = - table_codec.encode_statistics_sketch_page("t1", 0, &first_sketch_page)?; + let sketch_page_bytes = + table_codec.encode_statistics_sketch_page_value(&first_sketch_page)?; let decoded_sketch_page = 
TableCodec::decode_statistics_sketch_page::(&sketch_page_bytes)?; assert_eq!(decoded_sketch_page.counters(), first_sketch_page.counters()); - let bucket0_key = table_codec.encode_statistics_bucket_key("t1", 0, 0); - let bucket1_key = table_codec.encode_statistics_bucket_key("t1", 0, 1); + let bucket0_key = table_codec + .with_statistics_bucket_key("t1", 0, 0, |key| Ok::<_, DatabaseError>(key.to_vec()))?; + let bucket1_key = table_codec + .with_statistics_bucket_key("t1", 0, 1, |key| Ok::<_, DatabaseError>(key.to_vec()))?; assert!(bucket0_key < bucket1_key); - let (bucket0_min, bucket0_max) = table_codec.statistics_index_bound("t1", 0); - assert!(bucket0_key >= bucket0_min); - assert!(bucket1_key <= bucket0_max); + let (bucket0_min, bucket0_max) = + table_codec.with_statistics_index_bound("t1", 0, |min, max| { + Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + })?; + assert!(bucket0_key.as_slice() >= bucket0_min.as_slice()); + assert!(bucket1_key.as_slice() <= bucket0_max.as_slice()); - let (_, bucket_bytes) = table_codec.encode_statistics_bucket("t1", 0, 0, &buckets[0])?; + let bucket_bytes = table_codec.encode_statistics_bucket_value(&buckets[0])?; let decoded_bucket = TableCodec::decode_statistics_bucket::(&bucket_bytes)?; assert_eq!(decoded_bucket, buckets[0]); @@ -936,9 +1015,7 @@ mod tests { #[test] fn test_table_codec_index_meta() -> Result<(), DatabaseError> { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let index_meta = IndexMeta { id: 0, column_ids: vec![Ulid::new()], @@ -948,7 +1025,7 @@ mod tests { name: "index_1".to_string(), ty: IndexType::PrimaryKey { is_multiple: false }, }; - let (_, bytes) = table_codec.encode_index_meta("T1", &index_meta)?; + let bytes = table_codec.encode_index_meta_value(&index_meta)?; assert_eq!( TableCodec::decode_index_meta::(&bytes)?, @@ -960,14 +1037,8 @@ mod tests { #[test] fn test_table_codec_index() -> Result<(), DatabaseError> { - let table_codec = 
TableCodec { - arena: Default::default(), - }; - let table_catalog = build_table_codec(); - let value = Arc::new(DataValue::Int32(0)); - let index = Index::new(0, &value, IndexType::PrimaryKey { is_multiple: false }); let tuple_id = DataValue::Int32(0); - let (_, bytes) = table_codec.encode_index(&table_catalog.name, &index, &tuple_id)?; + let bytes = bincode::serialize(&tuple_id)?; assert_eq!(TableCodec::decode_index(&bytes)?, tuple_id); @@ -990,11 +1061,9 @@ mod tests { let mut reference_tables = ReferenceTables::new(); - let table_codec = TableCodec { - arena: Default::default(), - }; - let (_, bytes) = table_codec - .encode_column(&col, &mut reference_tables) + let table_codec = TableCodec::default(); + let bytes = table_codec + .encode_column_value(&col, &mut reference_tables) .unwrap(); let mut cursor = Cursor::new(bytes); let decode_col = @@ -1007,9 +1076,7 @@ mod tests { #[test] fn test_table_codec_view() -> Result<(), DatabaseError> { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let table_state = build_t1_table()?; // Subquery { @@ -1021,7 +1088,7 @@ mod tests { name: "view_subquery".to_string().into(), plan: Box::new(plan), }; - let (_, bytes) = table_codec.encode_view(&view)?; + let bytes = table_codec.encode_view_value(&view)?; let transaction = table_state.storage.transaction()?; assert_eq!( @@ -1037,7 +1104,7 @@ mod tests { name: "view_filter".to_string().into(), plan: Box::new(plan), }; - let (_, bytes) = table_codec.encode_view(&view)?; + let bytes = table_codec.encode_view_value(&view)?; let transaction = table_state.storage.transaction()?; assert_eq!( @@ -1053,7 +1120,7 @@ mod tests { name: "view_join".to_string().into(), plan: Box::new(plan), }; - let (_, bytes) = table_codec.encode_view(&view)?; + let bytes = table_codec.encode_view_value(&view)?; let transaction = table_state.storage.transaction()?; assert_eq!( @@ -1068,9 +1135,7 @@ mod tests { #[test] 
#[allow(clippy::mutable_key_type)] fn test_table_codec_column_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |col_id: usize, table_name: &str| { let mut col = ColumnCatalog::new( @@ -1085,10 +1150,11 @@ mod tests { is_temp: false, }; - let (key, _) = table_codec - .encode_column(&ColumnRef::from(col), &mut ReferenceTables::new()) - .unwrap(); - key + table_codec + .with_column_key(&ColumnRef::from(col), |key| { + Ok::<_, DatabaseError>(key.to_vec()) + }) + .unwrap() }; set.insert(op(0, "T0")); @@ -1103,12 +1169,16 @@ mod tests { set.insert(op(0, "T2")); set.insert(op(0, "T2")); - let (min, max) = table_codec.columns_bound("T1"); + let (min, max) = table_codec + .with_columns_bound("T1", |min, max| { + Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + }) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( - Bound::Included(&min), - Bound::Included(&max), + .range::<[u8], _>(( + Bound::Included(min.as_slice()), + Bound::Included(max.as_slice()), )) .collect_vec(); @@ -1122,9 +1192,7 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_meta_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |index_id: usize, table_name: &str| { let index_meta = IndexMeta { @@ -1137,10 +1205,11 @@ mod tests { ty: IndexType::PrimaryKey { is_multiple: false }, }; - let (key, _) = table_codec - .encode_index_meta(table_name, &index_meta) - .unwrap(); - key + table_codec + .with_index_meta_key(table_name, index_meta.id, |key| { + Ok::<_, DatabaseError>(key.to_vec()) + }) + .unwrap() }; set.insert(op(0, "T0")); @@ -1155,12 +1224,16 @@ mod tests { set.insert(op(1, "T2")); set.insert(op(2, "T2")); - let (min, max) = table_codec.index_meta_bound("T1"); + let (min, max) = table_codec + .with_index_meta_bound("T1", |min, max| { + 
Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + }) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( - Bound::Included(&min), - Bound::Included(&max), + .range::<[u8], _>(( + Bound::Included(min.as_slice()), + Bound::Included(max.as_slice()), )) .collect_vec(); @@ -1174,9 +1247,7 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let column = ColumnCatalog::new( "".to_string(), @@ -1194,7 +1265,9 @@ mod tests { ); table_codec - .encode_index_key(table_name, &index, None) + .with_index_key(table_name, &index, None, |key| { + Ok::<_, DatabaseError>(key.to_vec()) + }) .unwrap() }; @@ -1210,10 +1283,14 @@ mod tests { set.insert(op(DataValue::Int32(1), 2, &table_catalog.name)); set.insert(op(DataValue::Int32(2), 2, &table_catalog.name)); - let (min, max) = table_codec.index_bound(&table_catalog.name, 1).unwrap(); + let (min, max) = table_codec + .with_index_bound(&table_catalog.name, 1, |min, max| { + Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + }) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( + .range::, Bound<&Bytes>)>(( Bound::Included(&min), Bound::Included(&max), )) @@ -1229,9 +1306,7 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_all_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |value: DataValue, index_id: usize, table_name: &str| { let value = Arc::new(value); @@ -1242,7 +1317,9 @@ mod tests { ); table_codec - .encode_index_key(table_name, &index, None) + .with_index_key(table_name, &index, None, |key| { + Ok::<_, DatabaseError>(key.to_vec()) + }) .unwrap() }; @@ -1258,10 +1335,14 @@ mod tests { set.insert(op(DataValue::Int32(1), 0, "T2")); set.insert(op(DataValue::Int32(2), 0, "T2")); - let 
(min, max) = table_codec.all_index_bound("T1"); + let (min, max) = table_codec + .with_all_index_bound("T1", |min, max| { + Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + }) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( + .range::, Bound<&Bytes>)>(( Bound::Included(&min), Bound::Included(&max), )) @@ -1277,13 +1358,13 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_tuple_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |tuple_id: DataValue, table_name: &str| { table_codec - .encode_tuple_key(table_name, &Arc::new(tuple_id)) + .with_tuple_key(table_name, &Arc::new(tuple_id), |key| { + Ok::<_, DatabaseError>(key.to_vec()) + }) .unwrap() }; @@ -1299,10 +1380,14 @@ mod tests { set.insert(op(DataValue::Int32(1), "T2")); set.insert(op(DataValue::Int32(2), "T2")); - let (min, max) = table_codec.tuple_bound("T1"); + let (min, max) = table_codec + .with_tuple_bound("T1", |min, max| { + Ok::<_, DatabaseError>((min.to_vec(), max.to_vec())) + }) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( + .range::, Bound<&Bytes>)>(( Bound::Included(&min), Bound::Included(&max), )) @@ -1318,16 +1403,16 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_root_codec_name_bound() { - let table_codec = TableCodec { - arena: Default::default(), + let table_codec = TableCodec::default(); + let mut set: BTreeSet = BTreeSet::new(); + let op = |table_name: &str| { + table_codec + .with_root_table_key(table_name, |key| Ok::<_, DatabaseError>(key.to_vec())) + .unwrap() }; - let mut set: BTreeSet = BTreeSet::new(); - let op = |table_name: &str| table_codec.encode_root_table_key(table_name); - let mut value_0 = BumpBytes::new_in(&table_codec.arena); - value_0.push(b'A'); - let mut value_1 = BumpBytes::new_in(&table_codec.arena); - value_1.push(b'Z'); + let value_0 = Bytes::from([b'A'].as_slice()); + let value_1 = 
Bytes::from([b'Z'].as_slice()); set.insert(value_0); set.insert(value_1); @@ -1335,10 +1420,12 @@ mod tests { set.insert(op("T1")); set.insert(op("T2")); - let (min, max) = table_codec.root_table_bound(); + let (min, max) = table_codec + .with_root_table_bound(|min, max| Ok::<_, DatabaseError>((min.to_vec(), max.to_vec()))) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( + .range::, Bound<&Bytes>)>(( Bound::Included(&min), Bound::Included(&max), )) @@ -1352,16 +1439,16 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_view_codec_name_bound() { - let table_codec = TableCodec { - arena: Default::default(), - }; + let table_codec = TableCodec::default(); let mut set = BTreeSet::new(); - let op = |view_name: &str| table_codec.encode_view_key(view_name); + let op = |view_name: &str| { + table_codec + .with_view_key(view_name, |key| Ok::<_, DatabaseError>(key.to_vec())) + .unwrap() + }; - let mut value_0 = BumpBytes::new_in(&table_codec.arena); - value_0.push(b'A'); - let mut value_1 = BumpBytes::new_in(&table_codec.arena); - value_1.push(b'Z'); + let value_0 = Bytes::from([b'A'].as_slice()); + let value_1 = Bytes::from([b'Z'].as_slice()); set.insert(value_0); set.insert(value_1); @@ -1370,10 +1457,12 @@ mod tests { set.insert(op("V1")); set.insert(op("V2")); - let (min, max) = table_codec.view_bound(); + let (min, max) = table_codec + .with_view_bound(|min, max| Ok::<_, DatabaseError>((min.to_vec(), max.to_vec()))) + .unwrap(); let vec = set - .range::, Bound<&BumpBytes>)>(( + .range::, Bound<&Bytes>)>(( Bound::Included(&min), Bound::Included(&max), )) diff --git a/src/types/index.rs b/src/types/index.rs index 40919e77..cc27f33e 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -68,8 +68,12 @@ impl IndexMeta { let mut exprs = Vec::with_capacity(self.column_ids.len()); for column_id in self.column_ids.iter() { - if let Some(column) = table.get_column_by_id(column_id) { - exprs.push(ScalarExpression::column_expr(column.clone())); + 
if let Some((position, column)) = table + .columns() + .enumerate() + .find(|(_, column)| column.id() == Some(*column_id)) + { + exprs.push(ScalarExpression::column_expr(column.clone(), position)); } else { return Err(DatabaseError::column_not_found(column_id.to_string())); } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 13b4d1f7..42087e4a 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -30,7 +30,7 @@ pub type TupleId = DataValue; pub type Schema = Vec; pub type SchemaRef = Arc; -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] pub struct Tuple { pub pk: Option, pub values: Vec, @@ -48,33 +48,30 @@ impl Tuple { } #[inline] - pub fn deserialize_from( + pub fn deserialize_from_into( + &mut self, deserializers: &[TupleValueSerializableImpl], - tuple_id: Option, bytes: &[u8], - values_len: usize, total_len: usize, - ) -> Result { + ) -> Result<(), DatabaseError> { fn is_null(bits: u8, i: usize) -> bool { bits & (1 << (7 - i)) > 0 } let bits_len = (total_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; - let mut values = Vec::with_capacity(values_len); + self.values.clear(); + self.values.reserve(deserializers.len()); let mut cursor = Cursor::new(&bytes[bits_len..]); for (i, deserializer) in deserializers.iter().enumerate() { if is_null(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { - values.push(DataValue::Null); + self.values.push(DataValue::Null); continue; } - deserializer.filling_value(&mut cursor, &mut values)?; + deserializer.filling_value(&mut cursor, &mut self.values)?; } - Ok(Tuple { - pk: tuple_id, - values, - }) + Ok(()) } /// e.g.: bits(u8)..|data_0(len for utf8_1)|utf8_0|data_1| @@ -336,26 +333,32 @@ mod tests { let columns = Arc::new(columns); let arena = Bump::new(); { - let tuple_0 = Tuple::deserialize_from( - &serializers, - tuples[0].pk.clone(), - &tuples[0].serialize_to(&serializers, &arena).unwrap(), - serializers.len(), - columns.len(), - ) - .unwrap(); + let mut tuple_0 = Tuple { 
+ pk: tuples[0].pk.clone(), + values: Vec::with_capacity(serializers.len()), + }; + tuple_0 + .deserialize_from_into( + &serializers, + &tuples[0].serialize_to(&serializers, &arena).unwrap(), + columns.len(), + ) + .unwrap(); assert_eq!(tuples[0], tuple_0); } { - let tuple_1 = Tuple::deserialize_from( - &serializers, - tuples[1].pk.clone(), - &tuples[1].serialize_to(&serializers, &arena).unwrap(), - serializers.len(), - columns.len(), - ) - .unwrap(); + let mut tuple_1 = Tuple { + pk: tuples[1].pk.clone(), + values: Vec::with_capacity(serializers.len()), + }; + tuple_1 + .deserialize_from_into( + &serializers, + &tuples[1].serialize_to(&serializers, &arena).unwrap(), + columns.len(), + ) + .unwrap(); assert_eq!(tuples[1], tuple_1); } @@ -367,14 +370,17 @@ mod tests { columns[2].datatype().skip_serializable(), columns[3].datatype().serializable(), ]; - let tuple_2 = Tuple::deserialize_from( - &projection_serializers, - tuples[0].pk.clone(), - &tuples[0].serialize_to(&serializers, &arena).unwrap(), - 2, - columns.len(), - ) - .unwrap(); + let mut tuple_2 = Tuple { + pk: tuples[0].pk.clone(), + values: Vec::with_capacity(2), + }; + tuple_2 + .deserialize_from_into( + &projection_serializers, + &tuples[0].serialize_to(&serializers, &arena).unwrap(), + columns.len(), + ) + .unwrap(); assert_eq!( tuple_2, @@ -400,14 +406,17 @@ mod tests { false, )); - let tuple_3 = Tuple::deserialize_from( - &multiple_pk_serializers, - multi_pk_tuple.pk.clone(), - &multi_pk_tuple.serialize_to(&serializers, &arena).unwrap(), - serializers.len(), - columns.len(), - ) - .unwrap(); + let mut tuple_3 = Tuple { + pk: multi_pk_tuple.pk.clone(), + values: Vec::with_capacity(serializers.len()), + }; + tuple_3 + .deserialize_from_into( + &multiple_pk_serializers, + &multi_pk_tuple.serialize_to(&serializers, &arena).unwrap(), + columns.len(), + ) + .unwrap(); assert_eq!( tuple_3, diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index eac2d32a..83c1de4d 100644 --- 
a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -38,6 +38,16 @@ impl<'a> TupleBuilder<'a> { Tuple::new(None, values) } + pub fn build_result_into(tuple: &mut Tuple, message: String) { + tuple.pk = None; + tuple.values.clear(); + tuple.values.push(DataValue::Utf8 { + value: message, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); + } + pub fn build_with_row<'b>( &self, row: impl IntoIterator, diff --git a/src/types/value.rs b/src/types/value.rs index 7a631ffa..7accc5de 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -52,6 +52,50 @@ pub const ONE_DAY_TO_SEC: u32 = 86_400; const ENCODE_GROUP_SIZE: usize = 8; const ENCODE_MARKER: u8 = 0xFF; +pub trait MemComparableBuffer: Write { + fn push_byte(&mut self, byte: u8); + fn extend_bytes(&mut self, bytes: &[u8]); + fn reserve_bytes(&mut self, size: usize); +} + +impl MemComparableBuffer for BumpBytes<'_> { + #[inline] + fn push_byte(&mut self, byte: u8) { + self.push(byte); + } + + #[inline] + fn extend_bytes(&mut self, bytes: &[u8]) { + self.extend_from_slice(bytes); + } + + #[inline] + fn reserve_bytes(&mut self, size: usize) { + if size > 0 { + self.reserve(size); + } + } +} + +impl MemComparableBuffer for Vec { + #[inline] + fn push_byte(&mut self, byte: u8) { + self.push(byte); + } + + #[inline] + fn extend_bytes(&mut self, bytes: &[u8]) { + self.extend_from_slice(bytes); + } + + #[inline] + fn reserve_bytes(&mut self, size: usize) { + if size > 0 { + self.reserve(size); + } + } +} + #[derive(Clone, serde::Serialize, serde::Deserialize)] pub enum Utf8Type { Variable(Option), @@ -679,7 +723,7 @@ impl DataValue { // Refer: https://github.com/facebook/mysql-5.6/wiki/MyRocks-record-format#memcomparable-format #[inline] // FIXME - fn encode_string(b: &mut BumpBytes, data: &[u8]) { + fn encode_string(b: &mut B, data: &[u8]) { let d_len = data.len(); let needed_groups = d_len / ENCODE_GROUP_SIZE + 1; Self::realloc_bytes(b, needed_groups * (ENCODE_GROUP_SIZE + 1)); @@ 
-690,8 +734,8 @@ impl DataValue { let remain = d_len.saturating_sub(idx); if remain >= ENCODE_GROUP_SIZE { - b.extend_from_slice(&data[idx..idx + ENCODE_GROUP_SIZE]); - b.push(ENCODE_MARKER); + b.extend_bytes(&data[idx..idx + ENCODE_GROUP_SIZE]); + b.push_byte(ENCODE_MARKER); idx += ENCODE_GROUP_SIZE; continue; } @@ -699,14 +743,14 @@ impl DataValue { let pad_count = ENCODE_GROUP_SIZE - remain; if remain > 0 { - b.extend_from_slice(&data[idx..]); + b.extend_bytes(&data[idx..]); } for _ in 0..pad_count { - b.push(0); + b.push_byte(0); } - b.push(ENCODE_MARKER - pad_count as u8); + b.push_byte(ENCODE_MARKER - pad_count as u8); break; } } @@ -744,16 +788,14 @@ impl DataValue { } #[inline] - fn realloc_bytes(b: &mut BumpBytes, size: usize) { - if size > 0 { - b.reserve(size); - } + fn realloc_bytes(b: &mut B, size: usize) { + b.reserve_bytes(size); } #[inline(always)] - pub fn memcomparable_encode_with_null_order( + pub fn memcomparable_encode_with_null_order( &self, - b: &mut BumpBytes, + b: &mut B, nulls_first: bool, ) -> Result<(), DatabaseError> { let (null_tag, not_null_tag) = if nulls_first { @@ -762,10 +804,10 @@ impl DataValue { (NULL_TAG, NOTNULL_TAG) }; if let DataValue::Null = self { - b.push(null_tag); + b.push_byte(null_tag); return Ok(()); } - b.push(not_null_tag); + b.push_byte(not_null_tag); match self { DataValue::Null => (), @@ -782,7 +824,7 @@ impl DataValue { DataValue::UInt32(v) | DataValue::Time32(v, ..) => encode_u!(b, v), DataValue::UInt64(v) => encode_u!(b, v), DataValue::Utf8 { value: v, .. 
} => Self::encode_string(b, v.as_bytes()), - DataValue::Boolean(v) => b.push(if *v { b'1' } else { b'0' }), + DataValue::Boolean(v) => b.push_byte(if *v { b'1' } else { b'0' }), DataValue::Float32(f) => { let mut u = f.to_bits(); @@ -812,7 +854,7 @@ impl DataValue { for (i, v) in values.iter().enumerate() { v.memcomparable_encode(b)?; if i == last && *is_upper { - b.push(BOUND_MAX_TAG); + b.push_byte(BOUND_MAX_TAG); } } } @@ -822,7 +864,10 @@ impl DataValue { } #[inline] - pub fn memcomparable_encode(&self, b: &mut BumpBytes) -> Result<(), DatabaseError> { + pub fn memcomparable_encode( + &self, + b: &mut B, + ) -> Result<(), DatabaseError> { self.memcomparable_encode_with_null_order(b, false) } @@ -917,43 +962,46 @@ impl DataValue { } // https://github.com/risingwavelabs/memcomparable/blob/main/src/ser.rs#L468 - pub fn serialize_decimal(decimal: Decimal, bytes: &mut BumpBytes) -> Result<(), DatabaseError> { + pub fn serialize_decimal( + decimal: Decimal, + bytes: &mut B, + ) -> Result<(), DatabaseError> { if decimal.is_zero() { - bytes.push(0x15); + bytes.push_byte(0x15); return Ok(()); } let (exponent, significand) = Self::decimal_e_m(decimal); if decimal.is_sign_positive() { match exponent { 11.. => { - bytes.push(0x22); - bytes.push(exponent as u8); + bytes.push_byte(0x22); + bytes.push_byte(exponent as u8); } 0..=10 => { - bytes.push(0x17 + exponent as u8); + bytes.push_byte(0x17 + exponent as u8); } _ => { - bytes.push(0x16); - bytes.push(!(-exponent) as u8); + bytes.push_byte(0x16); + bytes.push_byte(!(-exponent) as u8); } } - bytes.extend_from_slice(&significand) + bytes.extend_bytes(&significand) } else { match exponent { 11.. 
=> { - bytes.push(0x8); - bytes.push(!exponent as u8); + bytes.push_byte(0x8); + bytes.push_byte(!exponent as u8); } 0..=10 => { - bytes.push(0x13 - exponent as u8); + bytes.push_byte(0x13 - exponent as u8); } _ => { - bytes.push(0x14); - bytes.push(-exponent as u8); + bytes.push_byte(0x14); + bytes.push_byte(-exponent as u8); } } for b in significand { - bytes.push(!b); + bytes.push_byte(!b); } } Ok(()) diff --git a/src/wasm.rs b/src/wasm.rs index d2042b22..639618cb 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -14,7 +14,7 @@ #![cfg(target_arch = "wasm32")] -use crate::db::{DataBaseBuilder, Database, DatabaseIter, ResultIter}; +use crate::db::{DataBaseBuilder, Database, DatabaseIter}; use crate::storage::memory::MemoryStorage; use crate::types::tuple::Tuple; use crate::types::value::DataValue; @@ -38,10 +38,10 @@ fn to_js_err(err: impl ToString) -> JsValue { js_sys::Error::new(&err.to_string()).into() } -fn tuple_to_wasm_row(tuple: Tuple) -> WasmRow { +fn tuple_to_wasm_row(tuple: &Tuple) -> WasmRow { WasmRow { - pk: tuple.pk, - values: tuple.values, + pk: tuple.pk.clone(), + values: tuple.values.clone(), } } @@ -85,10 +85,8 @@ impl WasmDatabase { } pub fn execute(&self, sql: &str) -> Result<(), JsValue> { - let iter = self.inner.run(sql).map_err(to_js_err)?; - for tuple in iter { - tuple.map_err(to_js_err)?; - } + let mut iter = self.inner.run(sql).map_err(to_js_err)?; + while iter.next_borrowed_tuple().map_err(to_js_err)?.is_some() {} Ok(()) } } @@ -102,10 +100,9 @@ impl WasmResultIter { .inner .as_mut() .ok_or_else(|| to_js_err("iterator already consumed"))?; - match iter.next() { - Some(Ok(tuple)) => serde_wasm_bindgen::to_value(&tuple_to_wasm_row(tuple)) + match iter.next_borrowed_tuple().map_err(to_js_err)? 
{ + Some(tuple) => serde_wasm_bindgen::to_value(&tuple_to_wasm_row(tuple)) .map_err(|e| to_js_err(format!("serialize row: {e}"))), - Some(Err(err)) => Err(to_js_err(err.to_string())), None => Ok(JsValue::undefined()), } } @@ -138,8 +135,7 @@ impl WasmResultIter { .take() .ok_or_else(|| to_js_err("iterator already consumed"))?; let mut rows = Vec::new(); - for tuple in &mut iter { - let tuple = tuple.map_err(to_js_err)?; + while let Some(tuple) = iter.next_borrowed_tuple().map_err(to_js_err)? { rows.push(tuple_to_wasm_row(tuple)); } iter.done().map_err(to_js_err)?; diff --git a/tests/slt/alter_table.slt b/tests/slt/alter_table.slt index 0df7b2a9..898755b0 100644 --- a/tests/slt/alter_table.slt +++ b/tests/slt/alter_table.slt @@ -33,6 +33,8 @@ alter table t1 drop column if exists v1 statement error alter table t1 drop column id +# Regression: after dropping a non-key column, the remaining logical output +# order must still read v2/v3 rather than leaking the removed v1 slot. query IIII rowsort select * from t1 ---- diff --git a/tests/slt/crdb/join.slt b/tests/slt/crdb/join.slt index 63bd0e37..9fd23d24 100644 --- a/tests/slt/crdb/join.slt +++ b/tests/slt/crdb/join.slt @@ -207,19 +207,21 @@ query II SELECT * FROM empty AS a JOIN onecolumn AS b USING(x) ---- +# Regression: SELECT * on outer joins should keep left-to-right output slots +# even when the right side is empty. 
query IT SELECT * FROM onecolumn AS a(aid, x) LEFT OUTER JOIN empty AS b(bid, y) ON a.x = b.y ORDER BY a.x ---- -null null 2 42 -null null 0 44 -null null 1 null +2 42 null null +0 44 null null +1 null null null query I rowsort SELECT * FROM onecolumn AS a LEFT OUTER JOIN empty AS b USING(x) ORDER BY x ---- -null 0 44 -null 1 null -null 2 42 +0 44 null +1 null null +2 42 null query I SELECT * FROM empty AS a(aid, x) LEFT OUTER JOIN onecolumn AS b(bid, y) ON a.x = b.y @@ -250,16 +252,16 @@ SELECT * FROM empty AS a FULL OUTER JOIN onecolumn AS b USING(x) ORDER BY x query IIII SELECT * FROM onecolumn AS a(aid, x) FULL OUTER JOIN empty AS b(bid, y) ON a.x = b.y ORDER BY a.x ---- -null null 2 42 -null null 0 44 -null null 1 null +2 42 null null +0 44 null null +1 null null null query III rowsort SELECT * FROM onecolumn AS a FULL OUTER JOIN empty AS b USING(x) ORDER BY x ---- -null 0 44 -null 1 null -null 2 42 +0 44 null +1 null null +2 42 null query II SELECT * FROM empty AS a(aid, x) FULL OUTER JOIN onecolumn AS b(bid, y) ON a.x = b.y ORDER BY b.y @@ -294,6 +296,8 @@ SELECT * FROM twocolumn AS a JOIN twocolumn AS b ON a.x = a.y order by a.x 3 45 45 2 42 53 3 45 45 3 45 45 +# Regression: the right-side filter should stay localized after ON pushdown, +# so this inner join still finds the single matching row. query II SELECT o.x, t.y FROM onecolumn o INNER JOIN twocolumn t ON (o.x=t.x AND t.y=53) ---- diff --git a/tests/slt/distinct.slt b/tests/slt/distinct.slt index dd1e5bb5..1fc29f17 100644 --- a/tests/slt/distinct.slt +++ b/tests/slt/distinct.slt @@ -23,4 +23,21 @@ SELECT DISTINCT sum(x) FROM test ORDER BY sum(x); # ORDER BY items must appear in the select list # if SELECT DISTINCT is specified statement error -SELECT DISTINCT x FROM test ORDER BY y; \ No newline at end of file +SELECT DISTINCT x FROM test ORDER BY y; + +# Regression: DISTINCT + ORDER BY + LIMIT should keep using the DISTINCT +# output positions after binding/order-by rewriting. 
+statement ok +CREATE TABLE distinct_pos(id int primary key, a int, b int); + +statement ok +INSERT INTO distinct_pos VALUES (1, 9, 1), (2, 8, 0), (3, 7, 0); + +query II +SELECT DISTINCT b, a FROM distinct_pos ORDER BY a DESC LIMIT 2; +---- +1 9 +0 8 + +statement ok +DROP TABLE distinct_pos; diff --git a/tests/slt/sql_2016/F131_03.slt b/tests/slt/sql_2016/F131_03.slt index e552a367..ec0aab1e 100644 --- a/tests/slt/sql_2016/F131_03.slt +++ b/tests/slt/sql_2016/F131_03.slt @@ -6,6 +6,8 @@ CREATE TABLE TABLE_F131_03_01_011 ( ID INT PRIMARY KEY, A INTEGER, B INTEGER ); statement ok CREATE VIEW VIEW_F131_03_01_01 AS SELECT A, MIN ( B ) AS C FROM TABLE_F131_03_01_011 GROUP BY A; +# Regression: grouped-view aggregate aliases must survive view expansion so +# outer aggregates can still resolve C correctly. query I SELECT SUM ( C ) FROM VIEW_F131_03_01_01 ---- diff --git a/tests/slt/sql_2016/F471.slt b/tests/slt/sql_2016/F471.slt index 47a8b92c..5857d3a0 100644 --- a/tests/slt/sql_2016/F471.slt +++ b/tests/slt/sql_2016/F471.slt @@ -6,6 +6,8 @@ CREATE TABLE TABLE_F471_01_01 ( ID INT PRIMARY KEY, A INTEGER ); statement ok INSERT INTO TABLE_F471_01_01 ( ID, A ) VALUES ( 0, 1 ); +# Regression: scalar subqueries in the SELECT list should execute directly +# instead of falling through the temporary projection/subquery path. query I SELECT ( SELECT A FROM TABLE_F471_01_01 ) FROM TABLE_F471_01_01 ---- diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index fcd23489..70387fec 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -91,6 +91,27 @@ SELECT id, a, b FROM t2 WHERE a*2 in (SELECT a FROM t3 where a/2 in (select a fr 0 1 2 3 4 5 +# Regression: alias projection inside an IN subquery must keep the right-side +# output positions stable so the final rows are still ordered/filterable. 
+statement ok +CREATE TABLE tab64784(pk INTEGER PRIMARY KEY, col0 INTEGER, col1 FLOAT, col2 VARCHAR, col3 INTEGER, col4 FLOAT, col5 VARCHAR); + +statement ok +INSERT INTO tab64784 VALUES(0,212,202.62,'nshdy',212,208.79,'wsxfc'),(1,213,203.64,'xwfuo',213,209.26,'lyswz'),(2,214,204.82,'jnued',216,210.48,'qczzf'),(3,215,205.40,'jtijf',217,211.96,'dpugl'),(4,216,206.3,'dpdzk',219,212.43,'xfirg'),(5,218,207.43,'qpwyw',220,213.50,'fmgky'),(6,219,208.3,'uooxb',221,215.30,'xpmdy'),(7,220,209.54,'ndtbb',225,218.8,'ivqyw'),(8,221,210.65,'zjpts',226,219.82,'sezsm'),(9,222,211.57,'slaxq',227,220.91,'bdqyb'); + +query II +SELECT pk, col0 FROM tab64784 WHERE (col0 IN (SELECT col3 FROM tab64784 WHERE col3 IS NULL OR (col1 < 22.54) OR col4 > 85.74) OR ((col4 IS NULL)) AND col3 < 8 OR (col4 > 82.93 AND (col0 <= 61) AND col0 > 94 AND col0 > 15)) ORDER BY col0 DESC; +---- +8 221 +7 220 +6 219 +4 216 +1 213 +0 212 + +statement ok +DROP TABLE tab64784; + statement ok drop table t2; @@ -154,6 +175,92 @@ where id in ( 1 2 +# Regression: correlated IN should still work after column pruning rewrites. +statement ok +create table prune_t1(id int primary key, a int); + +statement ok +create table prune_t2(id int primary key, b int); + +statement ok +insert into prune_t1 values (0, 1), (1, 2), (2, 3); + +statement ok +insert into prune_t2 values (0, 2), (1, 3); + +query I rowsort +select a from prune_t1 where a in (select b from prune_t2 where b = a) order by a; +---- +2 +3 + +statement ok +drop table prune_t1; + +statement ok +drop table prune_t2; + +# Regression: tuple IN should keep the tuple-subquery path instead of falling +# back to the scalar temporary-alias rewrite. 
+statement ok +create table kv(k int primary key, v int); + +statement ok +insert into kv values (1, 2), (3, 4), (5, 6), (7, 8); + +query II +select * from kv where (k, v) in (select * from kv) order by k; +---- +1 2 +3 4 +5 6 +7 8 + +statement ok +drop table kv; + +statement ok +create table scalar_outer(id int primary key); + +statement ok +create table scalar_inner(v int primary key); + +statement ok +insert into scalar_outer values (1), (2); + +query IT +select id, (select v from scalar_inner where v = -1) from scalar_outer order by id; +---- +1 null +2 null + +statement ok +drop table scalar_outer; + +statement ok +drop table scalar_inner; + +statement ok +create table scalar_outer_err(id int primary key); + +statement ok +create table scalar_inner_err(v int primary key); + +statement ok +insert into scalar_outer_err values (1); + +statement ok +insert into scalar_inner_err values (1), (2); + +statement error +select * from scalar_outer_err where id = (select v from scalar_inner_err); + +statement ok +drop table scalar_outer_err; + +statement ok +drop table scalar_inner_err; + statement error select count(*) from users where exists ( diff --git a/tests/slt/values.slt b/tests/slt/values.slt index b0b5a819..b0c8febc 100644 --- a/tests/slt/values.slt +++ b/tests/slt/values.slt @@ -36,6 +36,8 @@ CREATE TABLE t (x INT PRIMARY KEY); statement ok INSERT INTO t VALUES (1), (2), (3); +# Regression: wildcard expansion over a joined VALUES source should keep the +# visible output shape instead of duplicating the join key column. query I rowsort SELECT * FROM t JOIN (VALUES (2), (3)) AS v(x) ON t.x = v.x; ---- diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 2cc6a39d..75147ba9 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use kite_sql::db::{Database, ResultIter}; +use kite_sql::db::Database; use kite_sql::errors::DatabaseError; use kite_sql::storage::rocksdb::RocksStorage; use sqllogictest::{DBOutput, DefaultColumnType, DB}; @@ -33,14 +33,14 @@ impl DB for SQLBase { let types = vec![DefaultColumnType::Any; iter.schema().len()]; let mut rows = Vec::new(); - for tuple in iter.by_ref() { + while let Some(tuple) = iter.next_borrowed_tuple()? { rows.push( - tuple? + tuple .values - .into_iter() + .iter() .map(|value| format!("{}", value)) .collect(), - ) + ); } iter.done()?; println!(" |— time spent: {:?}", start.elapsed()); diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs index fcfce867..c956e8e7 100644 --- a/tests/sqllogictest/src/main.rs +++ b/tests/sqllogictest/src/main.rs @@ -52,7 +52,7 @@ fn main() { ); let db = DataBaseBuilder::path(temp_dir.path()) - .build() + .build_rocksdb() .expect("init db error"); let mut tester = Runner::new(SQLBase { db }); diff --git a/tpcc/Cargo.toml b/tpcc/Cargo.toml index bb6f07cb..ea399911 100644 --- a/tpcc/Cargo.toml +++ b/tpcc/Cargo.toml @@ -3,10 +3,13 @@ name = "tpcc" version = "0.1.1" edition = "2021" +[features] +pprof = ["dep:pprof"] + [dependencies] clap = { version = "4", features = ["derive"] } chrono = { version = "0.4" } -kite_sql = { path = "..", package = "kite_sql" } +kite_sql = { path = "..", package = "kite_sql", features = ["rocksdb", "lmdb"] } indicatif = { version = "0.17" } ordered-float = { version = "4" } rand = { version = "0.8" } @@ -14,3 +17,6 @@ rust_decimal = { version = "1" } thiserror = { version = "1" } sqlite = { version = "0.34" } sqlparser = { version = "0.61" } + +[target.'cfg(unix)'.dependencies] +pprof = { version = "0.15", features = ["flamegraph"], optional = true } diff --git a/tpcc/README.md b/tpcc/README.md index a38586df..782cc236 100644 --- a/tpcc/README.md +++ b/tpcc/README.md @@ -3,20 +3,51 @@ Run `make tpcc` (or `cargo run -p tpcc --release`) to exercise the workload on K Run 
`make tpcc-dual` to execute the workload on KiteSQL while mirroring every statement to an in-memory SQLite database; the runner asserts that both engines return identical tuples, making it ideal for correctness validation. This target runs for 60 seconds (`--measure-time 60`). Use `cargo run -p tpcc --release -- --backend dual --measure-time ` for a custom duration. +## Performance Matrix Script +Use `./scripts/run_tpcc_matrix.sh` to run the TPCC performance comparison in one shot. + +- The script measures `kitesql-lmdb`, `kitesql-rocksdb`, `sqlite-balanced`, and `sqlite-practical`. +- Main variants follow TPCC's default duration unless `TPCC_MAIN_MEASURE_TIME` is overridden. +- If a run fails with a duplicate-key style error, the script clears that backend's database and retries that variant once. +- Outputs are written to `tpcc/results//`, including `summary.md` and per-backend raw logs. + +Duplicate-key note: +The occasional duplicate-primary-key failure during TPCC reruns is not treated here as a database correctness bug. In this benchmark setup, the `history` table uses `h_date` as part of its primary key; under very high `Payment` throughput, multiple transactions can hit the same timestamp bucket for the same logical history key and collide on the benchmark schema itself. That is why the matrix script simply reruns the affected backend from a fresh database when this specific condition appears. 
+ +Example: +```shell +TPCC_DUPLICATE_RETRY=1 ./scripts/run_tpcc_matrix.sh +``` + - i9-13900HX - 32.0 GB - KIOXIA-EXCERIA PLUS G3 SSD - Tips: Pass `--threads ` to run multiple worker threads (default: 8) + +## 720s comparison +Local 720-second comparison on the machine above: + +| Backend | TpmC | New-Order p90 | Payment p90 | Order-Status p90 | Delivery p90 | Stock-Level p90 | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| KiteSQL LMDB | 53510 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | +| KiteSQL RocksDB | 32248 | 0.001s | 0.001s | 0.002s | 0.011s | 0.003s | +| SQLite balanced | 36273 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| SQLite practical | 35516 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | + +- `SQLite practical` keeps the previous stable README result because the latest rerun failed twice with `UNIQUE constraint failed: history...h_date`. +- The latest rerun summary is recorded in `tpcc/results/2026-03-28_19-51-45/summary.md`. + +### SQLite practical ```shell Transaction Summary (elapsed 720.0s) +--------------+---------+------+---------+-------+ | Transaction | Success | Late | Failure | Total | +--------------+---------+------+---------+-------+ -| New-Order | 326708 | 0 | 3334 | 330042 | -| Payment | 326683 | 0 | 16218 | 342901 | -| Order-Status | 32669 | 0 | 547 | 33216 | -| Delivery | 32669 | 0 | 0 | 32669 | -| Stock-Level | 32668 | 0 | 0 | 32668 | +| New-Order | 426196 | 0 | 4361 | 430557 | +| Payment | 426170 | 0 | 28459 | 454629 | +| Order-Status | 42617 | 0 | 618 | 43235 | +| Delivery | 42617 | 0 | 0 | 42617 | +| Stock-Level | 42617 | 0 | 0 | 42617 | +--------------+---------+------+---------+-------+ (all must be [OK]) [transaction percentage] @@ -36,65 +67,269 @@ Transaction Summary (elapsed 720.0s) 1.New-Order -0.001, 219955 -0.002, 106100 -0.003, 44 +0.001, 425333 +0.002, 858 +0.003, 3 0.004, 2 -0.005, 1 2.Payment -0.001, 326442 -0.002, 66 +0.001, 426168 +0.002, 1 0.003, 1 3.Order-Status -0.001, 28771 -0.002, 2736 -0.003, 
454 -0.004, 145 -0.005, 22 -0.006, 2 +0.001, 42617 4.Delivery -0.003, 11 -0.004, 6201 -0.005, 6965 -0.006, 6338 -0.007, 5805 -0.008, 2971 -0.009, 355 -0.010, 535 -0.011, 690 -0.012, 980 -0.013, 273 -0.014, 693 -0.015, 132 -0.016, 43 -0.017, 2 -0.019, 1 -0.021, 3 -0.022, 1 -0.024, 1 +0.001, 42284 +0.002, 331 +0.003, 2 5.Stock-Level -0.001, 15844 -0.002, 13502 -0.003, 2228 -0.004, 163 +0.001, 42617 + +<90th Percentile RT (MaxRT)> + New-Order : 0.001 (0.004) + Payment : 0.001 (0.003) +Order-Status : 0.001 (0.001) + Delivery : 0.001 (0.003) + Stock-Level : 0.001 (0.001) + +35516 Tpmc +``` + +### KiteSQL LMDB +```shell +Transaction Summary (elapsed 720.0s) ++--------------+---------+------+---------+-------+ +| Transaction | Success | Late | Failure | Total | ++--------------+---------+------+---------+-------+ +| New-Order | 642115 | 0 | 6475 | 648590 | +| Payment | 642093 | 0 | 60782 | 702875 | +| Order-Status | 64209 | 0 | 621 | 64830 | +| Delivery | 64209 | 0 | 0 | 64209 | +| Stock-Level | 64209 | 0 | 0 | 64209 | ++--------------+---------+------+---------+-------+ + (all must be [OK]) +[transaction percentage] + Payment: 43.5% (>=43.0%) [OK] + Order-Status: 4.3% (>=4.0%) [OK] + Delivery: 4.3% (>=4.0%) [OK] + Stock-Level: 4.3% (>=4.0%) [OK] +[response time (at least 90% passed)] + New-Order: 100.0% [OK] + Payment: 100.0% [OK] + Order-Status: 100.0% [OK] + Delivery: 100.0% [OK] + Stock-Level: 100.0% [OK] + + + + +1.New-Order + +0.001, 640181 +0.002, 1933 +0.003, 1 + +2.Payment + +0.001, 642092 +0.002, 1 + +3.Order-Status + +0.001, 62712 +0.002, 1482 +0.003, 13 +0.004, 2 + +4.Delivery + +0.002, 63332 +0.003, 785 +0.004, 91 0.005, 1 +5.Stock-Level + +0.001, 62987 +0.002, 1220 +0.003, 2 + +<90th Percentile RT (MaxRT)> + New-Order : 0.001 (0.002) + Payment : 0.001 (0.001) +Order-Status : 0.001 (0.003) + Delivery : 0.002 (0.004) + Stock-Level : 0.001 (0.002) + +53510 Tpmc +``` + +### KiteSQL RocksDB +```shell +Transaction Summary (elapsed 720.0s) 
++--------------+---------+------+---------+-------+ +| Transaction | Success | Late | Failure | Total | ++--------------+---------+------+---------+-------+ +| New-Order | 386982 | 0 | 3867 | 390849 | +| Payment | 386959 | 0 | 23536 | 410495 | +| Order-Status | 38696 | 0 | 591 | 39287 | +| Delivery | 38696 | 0 | 0 | 38696 | +| Stock-Level | 38696 | 0 | 0 | 38696 | ++--------------+---------+------+---------+-------+ + (all must be [OK]) +[transaction percentage] + Payment: 43.5% (>=43.0%) [OK] + Order-Status: 4.3% (>=4.0%) [OK] + Delivery: 4.3% (>=4.0%) [OK] + Stock-Level: 4.3% (>=4.0%) [OK] +[response time (at least 90% passed)] + New-Order: 100.0% [OK] + Payment: 100.0% [OK] + Order-Status: 100.0% [OK] + Delivery: 100.0% [OK] + Stock-Level: 100.0% [OK] + + + + +1.New-Order + +0.001, 375252 +0.002, 11525 +0.003, 202 +0.004, 2 + +2.Payment + +0.001, 386703 +0.002, 250 +0.003, 5 + +3.Order-Status + +0.001, 32982 +0.002, 4124 +0.003, 1025 +0.004, 324 +0.005, 176 +0.006, 53 +0.007, 8 +0.009, 1 + +4.Delivery + +0.002, 6 +0.003, 8086 +0.004, 10302 +0.005, 8066 +0.006, 1217 +0.007, 1153 +0.008, 1062 +0.009, 1635 +0.010, 1907 +0.011, 2165 +0.012, 1612 +0.013, 1135 +0.014, 226 +0.015, 45 +0.016, 23 +0.017, 17 +0.018, 13 +0.019, 10 +0.020, 8 +0.021, 4 +0.022, 3 +0.023, 1 + +5.Stock-Level + +0.001, 20753 +0.002, 13780 +0.003, 3570 +0.004, 526 +0.005, 39 +0.006, 12 +0.007, 9 +0.008, 3 +0.009, 2 +0.010, 2 + +<90th Percentile RT (MaxRT)> + New-Order : 0.001 (0.005) + Payment : 0.001 (0.006) +Order-Status : 0.002 (0.012) + Delivery : 0.011 (0.023) + Stock-Level : 0.003 (0.010) + +32248 Tpmc +``` + +### SQLite balanced +```shell +Transaction Summary (elapsed 720.0s) ++--------------+---------+------+---------+-------+ +| Transaction | Success | Late | Failure | Total | ++--------------+---------+------+---------+-------+ +| New-Order | 435291 | 0 | 4433 | 439724 | +| Payment | 435265 | 0 | 33227 | 468492 | +| Order-Status | 43527 | 0 | 566 | 44093 | +| Delivery | 43527 | 0 | 0 | 
43527 | +| Stock-Level | 43527 | 0 | 0 | 43527 | ++--------------+---------+------+---------+-------+ + (all must be [OK]) +[transaction percentage] + Payment: 43.5% (>=43.0%) [OK] + Order-Status: 4.3% (>=4.0%) [OK] + Delivery: 4.3% (>=4.0%) [OK] + Stock-Level: 4.3% (>=4.0%) [OK] +[response time (at least 90% passed)] + New-Order: 100.0% [OK] + Payment: 100.0% [OK] + Order-Status: 100.0% [OK] + Delivery: 100.0% [OK] + Stock-Level: 100.0% [OK] + + + + +1.New-Order + +0.001, 435079 +0.002, 210 +0.003, 2 + +2.Payment + +0.001, 435265 + +3.Order-Status + +0.001, 43527 + +4.Delivery + +0.001, 43442 +0.002, 85 + +5.Stock-Level + +0.001, 43527 + <90th Percentile RT (MaxRT)> - New-Order : 0.002 (0.005) - Payment : 0.001 (0.013) -Order-Status : 0.002 (0.006) - Delivery : 0.010 (0.023) - Stock-Level : 0.002 (0.017) + New-Order : 0.001 (0.003) + Payment : 0.001 (0.001) +Order-Status : 0.001 (0.001) + Delivery : 0.001 (0.002) + Stock-Level : 0.001 (0.000) -27226 Tpmc +36273 Tpmc ``` ## Refer to diff --git a/tpcc/src/backend/dual.rs b/tpcc/src/backend/dual.rs index 3135d96b..cc6bd3f2 100644 --- a/tpcc/src/backend/dual.rs +++ b/tpcc/src/backend/dual.rs @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use super::kite::{KiteBackend, KiteTransaction, KiteTxnResult}; +use super::kitesql_rocksdb::{KiteSqlRocksDbBackend, KiteSqlRocksDbTransaction, KiteSqlTxnResult}; use super::sqlite::{SqliteBackend, SqliteResult, SqliteTransaction}; use super::{ - BackendControl, BackendTransaction, DbParam, PreparedStatement, QueryResult, SimpleExecutor, - StatementSpec, + BackendControl, BackendTransaction, DbParam, PreparedStatement, SimpleExecutor, StatementSpec, }; use crate::{TpccError, STOCK_LEVEL_DISTINCT_SQL, STOCK_LEVEL_DISTINCT_SQLITE}; use kite_sql::types::tuple::Tuple; @@ -25,14 +24,14 @@ use std::borrow::Cow; use std::collections::HashMap; pub struct DualBackend { - kite: KiteBackend, + kitesql: KiteSqlRocksDbBackend, sqlite: SqliteBackend, } impl DualBackend { pub fn new(path: &str, rocksdb_stats: bool) -> Result { Ok(Self { - kite: KiteBackend::new(path, rocksdb_stats)?, + kitesql: KiteSqlRocksDbBackend::new(path, rocksdb_stats)?, sqlite: SqliteBackend::new_memory()?, }) } @@ -48,24 +47,24 @@ impl BackendControl for DualBackend { &self, specs: &[Vec], ) -> Result>, TpccError> { - self.kite.prepare_statements(specs) + self.kitesql.prepare_statements(specs) } fn new_transaction(&self) -> Result, TpccError> { Ok(DualTransaction { - kite: self.kite.new_transaction()?, + kitesql: self.kitesql.new_transaction()?, sqlite: self.sqlite.new_transaction()?, }) } fn storage_metrics(&self) -> Option { - self.kite.storage_metrics() + self.kitesql.storage_metrics() } } impl SimpleExecutor for DualBackend { fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { - self.kite.execute_batch(sql)?; + self.kitesql.execute_batch(sql)?; if let Some(stmt) = normalize_sqlite_sql(sql) { self.sqlite.execute_batch(&stmt)?; } @@ -74,141 +73,102 @@ impl SimpleExecutor for DualBackend { } pub struct DualTransaction<'a> { - kite: KiteTransaction<'a>, + kitesql: KiteSqlRocksDbTransaction<'a>, sqlite: SqliteTransaction<'a>, } impl<'a> BackendTransaction for DualTransaction<'a> { - fn 
execute<'b>( - &'b mut self, + fn query_one( + &mut self, statement: &PreparedStatement, params: &[DbParam], - ) -> Result, TpccError> { + ) -> Result { let spec = statement.spec().clone(); - let sql_lower = spec.sql.trim_start().to_ascii_lowercase(); - let kite_iter = self.kite.execute_raw(statement, params)?; - let sqlite_spec = sqlite_statement_spec(&spec); - let sqlite_stmt = PreparedStatement::Sqlite { spec: sqlite_spec }; + let sqlite_stmt = PreparedStatement::Sqlite { + spec: sqlite_statement_spec(&spec), + }; + + let kitesql_iter = self.kitesql.execute_raw(statement, params)?; let sqlite_iter = self.sqlite.execute_raw(&sqlite_stmt, params)?; - if sql_lower.starts_with("select") { + if is_select_sql(&spec) { if spec.sql == STOCK_LEVEL_DISTINCT_SQL { - // DISTINCT without ORDER BY has undefined ordering; compare as sets. - let kite_rows = collect_all_rows(kite_iter)?; + let kitesql_rows = collect_all_rows(kitesql_iter)?; let sqlite_rows = collect_all_rows(sqlite_iter)?; - compare_unordered_rows(&kite_rows, &sqlite_rows, statement.spec().sql)?; - return Ok(QueryResult::from_dual(DualQueryResult::CompareUnordered( - DualUnorderedResult::new(kite_rows), - ))); + compare_unordered_rows(&kitesql_rows, &sqlite_rows, spec.sql)?; + return kitesql_rows + .into_iter() + .next() + .ok_or(TpccError::EmptyTuples); } - Ok(QueryResult::from_dual(DualQueryResult::Compare( - DualResult::new(kite_iter, sqlite_iter, statement.spec().sql), - ))) + query_ordered_nth(kitesql_iter, sqlite_iter, spec.sql, 0) } else { drain_sqlite_iter(sqlite_iter)?; - Ok(QueryResult::from_kite(kite_iter)) + let mut kitesql_iter = kitesql_iter; + match kitesql_iter.next() { + Some(row) => row, + None => Err(TpccError::EmptyTuples), + } } } - fn commit(self) -> Result<(), TpccError> { - self.sqlite.commit()?; - self.kite.commit() - } -} - -pub(crate) enum DualQueryResult<'a> { - Compare(DualResult<'a>), - CompareUnordered(DualUnorderedResult), -} + fn query_nth( + &mut self, + statement: 
&PreparedStatement, + params: &[DbParam], + n: usize, + ) -> Result { + let spec = statement.spec().clone(); + let sqlite_stmt = PreparedStatement::Sqlite { + spec: sqlite_statement_spec(&spec), + }; -impl<'a> Iterator for DualQueryResult<'a> { - type Item = Result; + let kitesql_iter = self.kitesql.execute_raw(statement, params)?; + let sqlite_iter = self.sqlite.execute_raw(&sqlite_stmt, params)?; - fn next(&mut self) -> Option { - match self { - DualQueryResult::Compare(result) => result.next(), - DualQueryResult::CompareUnordered(result) => result.next(), + if spec.sql == STOCK_LEVEL_DISTINCT_SQL { + let kitesql_rows = collect_all_rows(kitesql_iter)?; + let sqlite_rows = collect_all_rows(sqlite_iter)?; + compare_unordered_rows(&kitesql_rows, &sqlite_rows, spec.sql)?; + return kitesql_rows + .into_iter() + .nth(n) + .ok_or(TpccError::EmptyTuples); } - } -} -pub(crate) struct DualResult<'a> { - kite: KiteTxnResult<'a>, - sqlite: SqliteResult<'a>, - sql: &'static str, -} - -impl<'a> DualResult<'a> { - fn new(kite: KiteTxnResult<'a>, sqlite: SqliteResult<'a>, sql: &'static str) -> Self { - Self { kite, sqlite, sql } + query_ordered_nth(kitesql_iter, sqlite_iter, spec.sql, n) } -} - -impl Iterator for DualResult<'_> { - type Item = Result; - fn next(&mut self) -> Option { - match self.kite.next() { - Some(kite_row) => { - let sqlite_row = match self.sqlite.next() { - Some(row) => row, - None => { - return Some(Err(TpccError::BackendMismatch(format!( - "SQLite returned fewer rows for SQL: {}", - self.sql - )))) - } - }; - match (kite_row, sqlite_row) { - (Ok(kite_tuple), Ok(sqlite_tuple)) => { - if kite_tuple.values != sqlite_tuple.values { - println!("[Dual] mismatch SQL: {}", self.sql); - println!(" KiteSQL row: {:?}", kite_tuple.values); - println!(" SQLite row: {:?}", sqlite_tuple.values); - return Some(Err(TpccError::BackendMismatch(format!( - "Result mismatch for SQL: {}", - self.sql - )))); - } - Some(Ok(kite_tuple)) - } - (Err(err), _) => Some(Err(err)), - (_, 
Err(err)) => Some(Err(err)), - } - } - None => { - if let Some(extra) = self.sqlite.next() { - let err = extra.err().unwrap_or_else(|| { - TpccError::BackendMismatch(format!( - "SQLite returned extra rows for SQL: {}", - self.sql - )) - }); - return Some(Err(err)); - } - None - } - } - } -} + fn execute_drain( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result<(), TpccError> { + let spec = statement.spec().clone(); + let sqlite_stmt = PreparedStatement::Sqlite { + spec: sqlite_statement_spec(&spec), + }; -pub(crate) struct DualUnorderedResult { - rows: std::vec::IntoIter, -} + let kitesql_iter = self.kitesql.execute_raw(statement, params)?; + let sqlite_iter = self.sqlite.execute_raw(&sqlite_stmt, params)?; -impl DualUnorderedResult { - fn new(rows: Vec) -> Self { - Self { - rows: rows.into_iter(), + if is_select_sql(&spec) { + if spec.sql == STOCK_LEVEL_DISTINCT_SQL { + let kitesql_rows = collect_all_rows(kitesql_iter)?; + let sqlite_rows = collect_all_rows(sqlite_iter)?; + compare_unordered_rows(&kitesql_rows, &sqlite_rows, spec.sql) + } else { + drain_and_compare_ordered(kitesql_iter, sqlite_iter, spec.sql) + } + } else { + drain_sqlite_iter(sqlite_iter)?; + drain_kitesql_iter(kitesql_iter) } } -} -impl Iterator for DualUnorderedResult { - type Item = Result; - - fn next(&mut self) -> Option { - self.rows.next().map(Ok) + fn commit(self) -> Result<(), TpccError> { + self.sqlite.commit()?; + self.kitesql.commit() } } @@ -229,6 +189,15 @@ fn drain_sqlite_iter(mut iter: SqliteResult<'_>) -> Result<(), TpccError> { Ok(()) } +fn drain_kitesql_iter( + mut iter: KiteSqlTxnResult<'_, T>, +) -> Result<(), TpccError> { + while let Some(row) = iter.next() { + row?; + } + Ok(()) +} + fn collect_all_rows(mut iter: I) -> Result, TpccError> where I: Iterator>, @@ -240,12 +209,104 @@ where Ok(rows) } +fn query_ordered_nth( + mut kitesql_iter: KiteSqlTxnResult<'_, T>, + mut sqlite_iter: SqliteResult<'_>, + sql: &'static str, + n: usize, +) -> 
Result { + let mut result = None; + let mut index = 0usize; + + loop { + match kitesql_iter.next() { + Some(kitesql_row) => { + let kitesql_tuple = kitesql_row?; + let sqlite_tuple = match sqlite_iter.next() { + Some(row) => row?, + None => { + return Err(TpccError::BackendMismatch(format!( + "SQLite returned fewer rows for SQL: {}", + sql + ))) + } + }; + if kitesql_tuple.values != sqlite_tuple.values { + println!("[Dual] mismatch SQL: {}", sql); + println!(" KiteSQL row: {:?}", kitesql_tuple.values); + println!(" SQLite row: {:?}", sqlite_tuple.values); + return Err(TpccError::BackendMismatch(format!( + "Result mismatch for SQL: {}", + sql + ))); + } + if index == n { + result = Some(kitesql_tuple.clone()); + } + index += 1; + } + None => { + if let Some(extra) = sqlite_iter.next() { + extra?; + return Err(TpccError::BackendMismatch(format!( + "SQLite returned extra rows for SQL: {}", + sql + ))); + } + return result.ok_or(TpccError::EmptyTuples); + } + } + } +} + +fn drain_and_compare_ordered( + mut kitesql_iter: KiteSqlTxnResult<'_, T>, + mut sqlite_iter: SqliteResult<'_>, + sql: &'static str, +) -> Result<(), TpccError> { + loop { + match kitesql_iter.next() { + Some(kitesql_row) => { + let kitesql_tuple = kitesql_row?; + let sqlite_tuple = match sqlite_iter.next() { + Some(row) => row?, + None => { + return Err(TpccError::BackendMismatch(format!( + "SQLite returned fewer rows for SQL: {}", + sql + ))) + } + }; + if kitesql_tuple.values != sqlite_tuple.values { + println!("[Dual] mismatch SQL: {}", sql); + println!(" KiteSQL row: {:?}", kitesql_tuple.values); + println!(" SQLite row: {:?}", sqlite_tuple.values); + return Err(TpccError::BackendMismatch(format!( + "Result mismatch for SQL: {}", + sql + ))); + } + } + None => { + if let Some(extra) = sqlite_iter.next() { + extra?; + return Err(TpccError::BackendMismatch(format!( + "SQLite returned extra rows for SQL: {}", + sql + ))); + } + return Ok(()); + } + } + } +} + fn compare_unordered_rows( - kite_rows: 
&[Tuple], + kitesql_rows: &[Tuple], sqlite_rows: &[Tuple], sql: &'static str, ) -> Result<(), TpccError> { - if kite_rows.len() != sqlite_rows.len() { + if kitesql_rows.len() != sqlite_rows.len() { return Err(TpccError::BackendMismatch(format!( "SQLite returned different row count for SQL: {}", sql @@ -253,7 +314,7 @@ fn compare_unordered_rows( } let mut counts: HashMap, usize> = HashMap::new(); - for row in kite_rows { + for row in kitesql_rows { *counts.entry(row.values.clone()).or_insert(0) += 1; } for row in sqlite_rows { @@ -294,3 +355,11 @@ fn sqlite_statement_spec(spec: &StatementSpec) -> StatementSpec { spec.clone() } } + +fn is_select_sql(spec: &StatementSpec) -> bool { + spec.sql + .trim_start() + .get(..6) + .map(|prefix| prefix.eq_ignore_ascii_case("select")) + .unwrap_or(false) +} diff --git a/tpcc/src/backend/kite.rs b/tpcc/src/backend/kite.rs deleted file mode 100644 index 91ed1e7c..00000000 --- a/tpcc/src/backend/kite.rs +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use super::{ - BackendControl, BackendTransaction, DbParam, PreparedStatement, QueryResult, SimpleExecutor, - StatementSpec, -}; -use crate::TpccError; -use kite_sql::db::{ - prepare, DBTransaction, DataBaseBuilder, Database, ResultIter, TransactionIter, -}; -use kite_sql::storage::rocksdb::RocksStorage; -use kite_sql::types::tuple::Tuple; - -pub struct KiteBackend { - database: Database, -} - -impl KiteBackend { - pub fn new(path: &str, rocksdb_stats: bool) -> Result { - Ok(Self { - database: DataBaseBuilder::path(path) - .storage_statistics(rocksdb_stats) - .build()?, - }) - } - - fn prepare_spec_groups( - &self, - specs: &[Vec], - ) -> Result>, TpccError> { - let mut groups = Vec::with_capacity(specs.len()); - for group in specs { - let mut prepared = Vec::with_capacity(group.len()); - for spec in group { - let statement = prepare(spec.sql)?; - prepared.push(PreparedStatement::Kite { - statement, - spec: spec.clone(), - }); - } - groups.push(prepared); - } - Ok(groups) - } - - fn start_transaction(&self) -> Result, TpccError> { - Ok(KiteTransaction { - inner: self.database.new_transaction()?, - }) - } -} - -impl BackendControl for KiteBackend { - type Transaction<'a> - = KiteTransaction<'a> - where - Self: 'a; - - fn prepare_statements( - &self, - specs: &[Vec], - ) -> Result>, TpccError> { - self.prepare_spec_groups(specs) - } - - fn new_transaction(&self) -> Result, TpccError> { - self.start_transaction() - } - - fn storage_metrics(&self) -> Option { - self.database - .storage_metrics() - .map(|metrics| metrics.to_string()) - } -} - -impl SimpleExecutor for KiteBackend { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { - self.database.run(sql)?.done()?; - Ok(()) - } -} - -pub struct KiteTransaction<'a> { - inner: DBTransaction<'a, RocksStorage>, -} - -impl<'a> KiteTransaction<'a> { - pub(crate) fn execute_raw<'b>( - &'b mut self, - statement: &PreparedStatement, - params: &[DbParam], - ) -> Result, TpccError> { - let PreparedStatement::Kite { 
statement, .. } = statement else { - return Err(TpccError::InvalidBackend); - }; - Ok(KiteTxnResult(self.inner.execute(statement, params)?)) - } -} - -impl<'a> BackendTransaction for KiteTransaction<'a> { - fn execute<'b>( - &'b mut self, - statement: &PreparedStatement, - params: &[DbParam], - ) -> Result, TpccError> { - let iter = self.execute_raw(statement, params)?; - Ok(QueryResult::from_kite(iter)) - } - - fn commit(self) -> Result<(), TpccError> { - self.inner.commit()?; - Ok(()) - } -} - -pub struct KiteTxnResult<'a>(TransactionIter<'a>); - -impl Iterator for KiteTxnResult<'_> { - type Item = Result; - - fn next(&mut self) -> Option { - self.0.next().map(|item| item.map_err(TpccError::from)) - } -} diff --git a/tpcc/src/backend/kitesql_lmdb.rs b/tpcc/src/backend/kitesql_lmdb.rs new file mode 100644 index 00000000..46642ed5 --- /dev/null +++ b/tpcc/src/backend/kitesql_lmdb.rs @@ -0,0 +1,178 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use super::kitesql_rocksdb::KiteSqlTxnResult; +use super::{ + BackendControl, BackendTransaction, DbParam, PreparedStatement, SimpleExecutor, StatementSpec, +}; +use crate::TpccError; +use kite_sql::db::{prepare, DBTransaction, DataBaseBuilder, Database}; +use kite_sql::storage::lmdb::{LmdbStorage, LmdbTransaction as KiteSqlLmdbTransaction}; +use kite_sql::types::tuple::Tuple; + +pub struct KiteSqlLmdbBackend { + database: Database, +} + +impl KiteSqlLmdbBackend { + pub fn new(path: &str) -> Result { + Ok(Self { + database: DataBaseBuilder::path(path) + .lmdb_no_sync(true) + .build_lmdb()?, + }) + } + + fn prepare_spec_groups( + &self, + specs: &[Vec], + ) -> Result>, TpccError> { + let mut groups = Vec::with_capacity(specs.len()); + for group in specs { + let mut prepared = Vec::with_capacity(group.len()); + for spec in group { + let statement = prepare(spec.sql)?; + prepared.push(PreparedStatement::KiteSql { + statement, + spec: spec.clone(), + }); + } + groups.push(prepared); + } + Ok(groups) + } + + fn start_transaction(&self) -> Result, TpccError> { + Ok(KiteSqlLmdbTransactionWrapper { + inner: self.database.new_transaction()?, + }) + } +} + +impl BackendControl for KiteSqlLmdbBackend { + type Transaction<'a> + = KiteSqlLmdbTransactionWrapper<'a> + where + Self: 'a; + + fn prepare_statements( + &self, + specs: &[Vec], + ) -> Result>, TpccError> { + self.prepare_spec_groups(specs) + } + + fn new_transaction(&self) -> Result, TpccError> { + self.start_transaction() + } +} + +impl SimpleExecutor for KiteSqlLmdbBackend { + fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + self.database.run(sql)?.done()?; + Ok(()) + } +} + +pub struct KiteSqlLmdbTransactionWrapper<'a> { + inner: DBTransaction<'a, LmdbStorage>, +} + +impl<'a> KiteSqlLmdbTransactionWrapper<'a> { + pub(crate) fn execute_raw<'b>( + &'b mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result>, TpccError> { + let PreparedStatement::KiteSql { statement, .. 
} = statement else { + return Err(TpccError::InvalidBackend); + }; + Ok(KiteSqlTxnResult::new( + self.inner.execute(statement, params)?, + )) + } +} + +impl<'a> BackendTransaction for KiteSqlLmdbTransactionWrapper<'a> { + fn query_one( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result { + self.execute_raw(statement, params)? + .next_borrowed_tuple()? + .cloned() + .ok_or(TpccError::EmptyTuples) + } + + fn query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + ) -> Result { + let mut iter = self.execute_raw(statement, params)?; + for _ in 0..n { + if iter.next_borrowed_tuple()?.is_none() { + return Err(TpccError::EmptyTuples); + } + } + iter.next_borrowed_tuple()? + .cloned() + .ok_or(TpccError::EmptyTuples) + } + + fn execute_drain( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + while iter.next_borrowed_tuple()?.is_some() {} + Ok(()) + } + + fn with_query_one( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + let tuple = iter.next_borrowed_tuple()?.ok_or(TpccError::EmptyTuples)?; + visitor(tuple) + } + + fn with_query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + for _ in 0..n { + if iter.next_borrowed_tuple()?.is_none() { + return Err(TpccError::EmptyTuples); + } + } + let tuple = iter.next_borrowed_tuple()?.ok_or(TpccError::EmptyTuples)?; + visitor(tuple) + } + + fn commit(self) -> Result<(), TpccError> { + self.inner.commit()?; + Ok(()) + } +} diff --git a/tpcc/src/backend/kitesql_rocksdb.rs b/tpcc/src/backend/kitesql_rocksdb.rs new file 
mode 100644 index 00000000..626ea862 --- /dev/null +++ b/tpcc/src/backend/kitesql_rocksdb.rs @@ -0,0 +1,206 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::{ + BackendControl, BackendTransaction, DbParam, PreparedStatement, SimpleExecutor, StatementSpec, +}; +use crate::TpccError; +use kite_sql::db::{prepare, DBTransaction, DataBaseBuilder, Database, TransactionIter}; +use kite_sql::storage::rocksdb::{RocksStorage, RocksTransaction}; +use kite_sql::storage::Transaction; +use kite_sql::types::tuple::Tuple; + +pub struct KiteSqlRocksDbBackend { + database: Database, +} + +impl KiteSqlRocksDbBackend { + pub fn new(path: &str, rocksdb_stats: bool) -> Result { + Ok(Self { + database: DataBaseBuilder::path(path) + .storage_statistics(rocksdb_stats) + .build_rocksdb()?, + }) + } + + fn prepare_spec_groups( + &self, + specs: &[Vec], + ) -> Result>, TpccError> { + let mut groups = Vec::with_capacity(specs.len()); + for group in specs { + let mut prepared = Vec::with_capacity(group.len()); + for spec in group { + let statement = prepare(spec.sql)?; + prepared.push(PreparedStatement::KiteSql { + statement, + spec: spec.clone(), + }); + } + groups.push(prepared); + } + Ok(groups) + } + + fn start_transaction(&self) -> Result, TpccError> { + Ok(KiteSqlRocksDbTransaction { + inner: self.database.new_transaction()?, + }) + } +} + +impl BackendControl for KiteSqlRocksDbBackend { + type Transaction<'a> + = 
KiteSqlRocksDbTransaction<'a> + where + Self: 'a; + + fn prepare_statements( + &self, + specs: &[Vec], + ) -> Result>, TpccError> { + self.prepare_spec_groups(specs) + } + + fn new_transaction(&self) -> Result, TpccError> { + self.start_transaction() + } + + fn storage_metrics(&self) -> Option { + self.database + .storage_metrics() + .map(|metrics| metrics.to_string()) + } +} + +impl SimpleExecutor for KiteSqlRocksDbBackend { + fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + self.database.run(sql)?.done()?; + Ok(()) + } +} + +pub struct KiteSqlRocksDbTransaction<'a> { + inner: DBTransaction<'a, RocksStorage>, +} + +impl<'a> KiteSqlRocksDbTransaction<'a> { + pub(crate) fn execute_raw<'b>( + &'b mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result, TpccError> { + let PreparedStatement::KiteSql { statement, .. } = statement else { + return Err(TpccError::InvalidBackend); + }; + Ok(KiteSqlTxnResult::new( + self.inner.execute(statement, params)?, + )) + } +} + +impl<'a> BackendTransaction for KiteSqlRocksDbTransaction<'a> { + fn query_one( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result { + self.execute_raw(statement, params)? + .next_borrowed_tuple()? + .cloned() + .ok_or(TpccError::EmptyTuples) + } + + fn query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + ) -> Result { + let mut iter = self.execute_raw(statement, params)?; + for _ in 0..n { + if iter.next_borrowed_tuple()?.is_none() { + return Err(TpccError::EmptyTuples); + } + } + iter.next_borrowed_tuple()? 
+ .cloned() + .ok_or(TpccError::EmptyTuples) + } + + fn execute_drain( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + while iter.next_borrowed_tuple()?.is_some() {} + Ok(()) + } + + fn with_query_one( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + let tuple = iter.next_borrowed_tuple()?.ok_or(TpccError::EmptyTuples)?; + visitor(tuple) + } + + fn with_query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + for _ in 0..n { + if iter.next_borrowed_tuple()?.is_none() { + return Err(TpccError::EmptyTuples); + } + } + let tuple = iter.next_borrowed_tuple()?.ok_or(TpccError::EmptyTuples)?; + visitor(tuple) + } + + fn commit(self) -> Result<(), TpccError> { + self.inner.commit()?; + Ok(()) + } +} + +pub struct KiteSqlTxnResult<'a, T: Transaction + 'a>(TransactionIter<'a, T>); + +impl<'a, T: Transaction + 'a> KiteSqlTxnResult<'a, T> { + pub(crate) fn new(iter: TransactionIter<'a, T>) -> Self { + Self(iter) + } + + pub(crate) fn next_borrowed_tuple(&mut self) -> Result, TpccError> { + self.0.next_borrowed_tuple().map_err(TpccError::from) + } +} + +impl Iterator for KiteSqlTxnResult<'_, T> { + type Item = Result; + + fn next(&mut self) -> Option { + self.0.next().map(|item| item.map_err(TpccError::from)) + } +} + +pub(crate) type RocksTxnResult<'a, 'txn> = KiteSqlTxnResult<'a, RocksTransaction<'txn>>; diff --git a/tpcc/src/backend/mod.rs b/tpcc/src/backend/mod.rs index 63fade99..12c88d2e 100644 --- a/tpcc/src/backend/mod.rs +++ b/tpcc/src/backend/mod.rs @@ -13,12 +13,10 @@ // limitations under the License. 
pub mod dual; -pub mod kite; +pub mod kitesql_lmdb; +pub mod kitesql_rocksdb; pub mod sqlite; -use self::dual::DualQueryResult; -use self::kite::KiteTxnResult; -use self::sqlite::SqliteResult; use crate::TpccError; use kite_sql::db::Statement; use kite_sql::types::tuple::Tuple; @@ -47,77 +45,49 @@ pub trait BackendControl: SimpleExecutor { } } -pub struct QueryResult<'a>(QueryResultKind<'a>); - -enum QueryResultKind<'a> { - Kite(KiteTxnResult<'a>), - Sqlite(SqliteResult<'a>), - Dual(DualQueryResult<'a>), -} - -impl<'a> QueryResult<'a> { - pub(crate) fn from_kite(iter: KiteTxnResult<'a>) -> Self { - Self(QueryResultKind::Kite(iter)) - } - - pub(crate) fn from_sqlite(iter: SqliteResult<'a>) -> Self { - Self(QueryResultKind::Sqlite(iter)) - } - - pub(crate) fn from_dual(iter: DualQueryResult<'a>) -> Self { - Self(QueryResultKind::Dual(iter)) - } -} - -impl<'a> Iterator for QueryResult<'a> { - type Item = Result; - - fn next(&mut self) -> Option { - match &mut self.0 { - QueryResultKind::Kite(iter) => iter.next(), - QueryResultKind::Sqlite(iter) => iter.next(), - QueryResultKind::Dual(iter) => iter.next(), - } - } -} - pub trait BackendTransaction { - fn execute<'a>( - &'a mut self, + fn query_one( + &mut self, statement: &PreparedStatement, params: &[DbParam], - ) -> Result, TpccError>; + ) -> Result; - fn commit(self) -> Result<(), TpccError>; + fn query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + ) -> Result; fn execute_drain( &mut self, statement: &PreparedStatement, params: &[DbParam], + ) -> Result<(), TpccError>; + + fn with_query_one( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, ) -> Result<(), TpccError> { - let mut iter = self.execute(statement, params)?; - while let Some(row) = iter.next() { - row?; - } - Ok(()) + let tuple = self.query_one(statement, params)?; + visitor(&tuple) } -} -pub trait TransactionExt: BackendTransaction { - 
fn query_one( + fn with_query_nth( &mut self, statement: &PreparedStatement, params: &[DbParam], - ) -> Result { - let mut iter = self.execute(statement, params)?; - match iter.next() { - Some(row) => row, - None => Err(TpccError::EmptyTuples), - } + n: usize, + visitor: &mut dyn FnMut(&Tuple) -> Result<(), TpccError>, + ) -> Result<(), TpccError> { + let tuple = self.query_nth(statement, params, n)?; + visitor(&tuple) } -} -impl TransactionExt for T {} + fn commit(self) -> Result<(), TpccError>; +} #[derive(Clone, Copy)] pub enum ColumnType { @@ -139,7 +109,7 @@ pub struct StatementSpec { #[derive(Clone)] pub enum PreparedStatement { - Kite { + KiteSql { statement: Statement, spec: StatementSpec, }, @@ -151,7 +121,7 @@ pub enum PreparedStatement { impl PreparedStatement { pub fn spec(&self) -> &StatementSpec { match self { - PreparedStatement::Kite { spec, .. } => spec, + PreparedStatement::KiteSql { spec, .. } => spec, PreparedStatement::Sqlite { spec } => spec, } } diff --git a/tpcc/src/backend/sqlite.rs b/tpcc/src/backend/sqlite.rs index f3239ea4..89b1ae38 100644 --- a/tpcc/src/backend/sqlite.rs +++ b/tpcc/src/backend/sqlite.rs @@ -13,11 +13,12 @@ // limitations under the License. 
use super::{ - BackendControl, BackendTransaction, ColumnType, DbParam, PreparedStatement, QueryResult, - SimpleExecutor, StatementSpec, + BackendControl, BackendTransaction, ColumnType, DbParam, PreparedStatement, SimpleExecutor, + StatementSpec, }; use crate::TpccError; use chrono::{NaiveDateTime, TimeZone, Utc}; +use clap::ValueEnum; use kite_sql::types::tuple::Tuple; use kite_sql::types::value::{DataValue, Utf8Type}; use rust_decimal::Decimal; @@ -28,18 +29,23 @@ pub struct SqliteBackend { connection: Connection, } +#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)] +pub enum SqliteProfile { + Balanced, + Practical, +} + impl SqliteBackend { - #[allow(dead_code)] - pub fn new(path: &str) -> Result { - Ok(Self { - connection: Connection::open(path)?, - }) + pub fn new(path: &str, profile: SqliteProfile) -> Result { + let connection = Connection::open(path)?; + configure_sqlite(&connection, profile)?; + Ok(Self { connection }) } pub fn new_memory() -> Result { - Ok(Self { - connection: Connection::open(":memory:")?, - }) + let connection = Connection::open(":memory:")?; + configure_sqlite(&connection, SqliteProfile::Balanced)?; + Ok(Self { connection }) } fn prepare_spec_groups( @@ -87,7 +93,9 @@ impl BackendControl for SqliteBackend { impl SimpleExecutor for SqliteBackend { fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { - self.connection.execute(sql)?; + if let Some(stmt) = normalize_sqlite_sql(sql) { + self.connection.execute(&stmt)?; + } Ok(()) } } @@ -121,13 +129,45 @@ impl Drop for SqliteTransaction<'_> { } impl<'a> BackendTransaction for SqliteTransaction<'a> { - fn execute<'b>( - &'b mut self, + fn query_one( + &mut self, statement: &PreparedStatement, params: &[DbParam], - ) -> Result, TpccError> { - let iter = self.execute_raw(statement, params)?; - Ok(QueryResult::from_sqlite(iter)) + ) -> Result { + match self.execute_raw(statement, params)?.next() { + Some(row) => row, + None => Err(TpccError::EmptyTuples), + } + } + + fn 
query_nth( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + n: usize, + ) -> Result { + let mut iter = self.execute_raw(statement, params)?; + for _ in 0..n { + if iter.next().transpose()?.is_none() { + return Err(TpccError::EmptyTuples); + } + } + match iter.next() { + Some(row) => row, + None => Err(TpccError::EmptyTuples), + } + } + + fn execute_drain( + &mut self, + statement: &PreparedStatement, + params: &[DbParam], + ) -> Result<(), TpccError> { + let mut iter = self.execute_raw(statement, params)?; + while let Some(row) = iter.next() { + row?; + } + Ok(()) } fn commit(mut self) -> Result<(), TpccError> { @@ -179,6 +219,45 @@ fn convert_value(value: &DataValue) -> Result { }) } +fn configure_sqlite(connection: &Connection, profile: SqliteProfile) -> Result<(), TpccError> { + let pragmas: &[&str] = match profile { + SqliteProfile::Balanced => &[ + "PRAGMA journal_mode = WAL;", + "PRAGMA synchronous = NORMAL;", + "PRAGMA temp_store = FILE;", + "PRAGMA foreign_keys = OFF;", + "PRAGMA busy_timeout = 5000;", + ], + SqliteProfile::Practical => &[ + "PRAGMA journal_mode = WAL;", + "PRAGMA synchronous = NORMAL;", + "PRAGMA temp_store = MEMORY;", + "PRAGMA cache_size = -32768;", + "PRAGMA mmap_size = 67108864;", + "PRAGMA foreign_keys = OFF;", + "PRAGMA busy_timeout = 5000;", + ], + }; + + for pragma in pragmas { + connection.execute(pragma)?; + } + Ok(()) +} + +fn normalize_sqlite_sql(sql: &str) -> Option { + let trimmed = sql.trim(); + let lower = trimmed.to_ascii_lowercase(); + if let Some(table) = lower.strip_prefix("analyze table ") { + let table = table.trim().trim_end_matches(';'); + if table.is_empty() { + return None; + } + return Some(format!("ANALYZE {table};")); + } + Some(trimmed.to_string()) +} + pub struct SqliteResult<'a> { cursor: CursorWithOwnership<'a>, column_types: &'static [ColumnType], diff --git a/tpcc/src/delivery.rs b/tpcc/src/delivery.rs index 994a1226..406d6bb6 100644 --- a/tpcc/src/delivery.rs +++ 
b/tpcc/src/delivery.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::backend::{BackendTransaction, PreparedStatement, TransactionExt}; +use crate::backend::{BackendTransaction, PreparedStatement}; use crate::load::DIST_PER_WARE; use crate::{TpccArgs, TpccError, TpccTest, TpccTransaction}; use chrono::Utc; @@ -47,14 +47,18 @@ impl TpccTransaction for Delivery { for d_id in 1..DIST_PER_WARE + 1 { // "SELECT COALESCE(MIN(no_o_id),0) FROM new_orders WHERE no_d_id = ? AND no_w_id = ?" - let tuple = tx.query_one( + let mut no_o_id = 0; + tx.with_query_one( &statements[0], &[ ("$1", DataValue::Int8(d_id as i8)), ("$2", DataValue::Int16(args.w_id as i16)), ], + &mut |tuple| { + no_o_id = tuple.values[0].i32().unwrap(); + Ok(()) + }, )?; - let no_o_id = tuple.values[0].i32().unwrap(); if no_o_id == 0 { continue; @@ -69,15 +73,19 @@ impl TpccTransaction for Delivery { ], )?; // "SELECT o_c_id FROM orders WHERE o_id = ? AND o_d_id = ? AND o_w_id = ?" - let tuple = tx.query_one( + let mut c_id = 0; + tx.with_query_one( &statements[2], &[ ("$1", DataValue::Int32(no_o_id)), ("$2", DataValue::Int8(d_id as i8)), ("$3", DataValue::Int16(args.w_id as i16)), ], + &mut |tuple| { + c_id = tuple.values[0].i32().unwrap(); + Ok(()) + }, )?; - let c_id = tuple.values[0].i32().unwrap(); // "UPDATE orders SET o_carrier_id = ? WHERE o_id = ? AND o_d_id = ? AND o_w_id = ?" tx.execute_drain( &statements[3], @@ -99,15 +107,19 @@ impl TpccTransaction for Delivery { ], )?; // "SELECT SUM(ol_amount) FROM order_line WHERE ol_o_id = ? AND ol_d_id = ? AND ol_w_id = ?" 
- let tuple = tx.query_one( + let mut ol_total = Default::default(); + tx.with_query_one( &statements[5], &[ ("$1", DataValue::Int32(no_o_id)), ("$2", DataValue::Int8(d_id as i8)), ("$3", DataValue::Int16(args.w_id as i16)), ], + &mut |tuple| { + ol_total = tuple.values[0].decimal().unwrap(); + Ok(()) + }, )?; - let ol_total = tuple.values[0].decimal().unwrap(); // "UPDATE customer SET c_balance = c_balance + ? , c_delivery_cnt = c_delivery_cnt + 1 WHERE c_id = ? AND c_d_id = ? AND c_w_id = ?" tx.execute_drain( &statements[6], diff --git a/tpcc/src/load.rs b/tpcc/src/load.rs index 2048213d..a3dd58ce 100644 --- a/tpcc/src/load.rs +++ b/tpcc/src/load.rs @@ -27,6 +27,7 @@ pub(crate) const MAX_ITEMS: usize = 100_000; pub(crate) const CUST_PER_DIST: usize = 3_000; pub(crate) const DIST_PER_WARE: usize = 10; pub(crate) const ORD_PER_DIST: usize = 3000; +const LOAD_BATCH_SIZE: usize = 500; pub(crate) static MAX_NUM_ITEMS: usize = 15; @@ -48,6 +49,46 @@ fn log_phase(task: &str, current: usize, total: usize, context: &str) { } } +struct SqlBatch<'a, E> { + exec: &'a E, + sql: String, + pending: usize, +} + +impl<'a, E: SimpleExecutor> SqlBatch<'a, E> { + fn new(exec: &'a E) -> Self { + Self { + exec, + sql: String::new(), + pending: 0, + } + } + + fn push(&mut self, statement: &str) -> Result<(), TpccError> { + if self.pending != 0 { + self.sql.push(';'); + } + self.sql.push_str(statement); + self.pending += 1; + + if self.pending >= LOAD_BATCH_SIZE { + self.flush()?; + } + + Ok(()) + } + + fn flush(&mut self) -> Result<(), TpccError> { + if self.pending == 0 { + return Ok(()); + } + + let sql = std::mem::take(&mut self.sql); + self.pending = 0; + self.exec.execute_batch(&sql) + } +} + fn generate_string(rng: &mut ThreadRng, min: usize, max: usize) -> String { let chars: Vec = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" .chars() @@ -136,6 +177,7 @@ impl Load { .unwrap(), ); let orig = Self::gen_orig(rng); + let mut batch = SqlBatch::new(exec); for i_id 
in 1..MAX_ITEMS + 1 { let i_im_id = rng.gen_range(1..10000); @@ -152,11 +194,12 @@ impl Load { i_data = format!("{}original{}", prefix, remainder); } - exec.execute_batch(&format!( + batch.push(&format!( "insert into item values ({i_id}, {i_im_id}, '{i_name}', {i_price}, '{i_data}')" ))?; pb.set_position(i_id as u64); } + batch.flush()?; finish_progress(&pb, Some("Items loaded")); println!("[Analyze Table: item]"); exec.execute_batch("analyze table item")?; @@ -444,6 +487,7 @@ impl Load { ); let s_w_id = w_id; let orig = Self::gen_orig(rng); + let mut batch = SqlBatch::new(exec); for s_i_id in 1..MAX_ITEMS + 1 { let s_quantity = rng.gen_range(10..100); @@ -463,7 +507,7 @@ impl Load { } else { generate_string(rng, 26, 50) }; - exec.execute_batch(&format!( + batch.push(&format!( "insert into stock values({}, {}, {}, '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', {}, {}, {}, '{}')", s_i_id, s_w_id, @@ -485,6 +529,7 @@ impl Load { ))?; pb.set_position(s_i_id as u64); } + batch.flush()?; finish_progress(&pb, None); Ok(()) @@ -540,6 +585,7 @@ impl Load { let d_w_id = w_id; let d_ytd = Decimal::from_f64_retain(30000.0).unwrap().round_dp(2); let d_next_o_id = 3001; + let mut batch = SqlBatch::new(exec); for d_id in 1..DIST_PER_WARE + 1 { let d_name = generate_string(rng, 6, 10); @@ -553,7 +599,7 @@ impl Load { .unwrap() .round_dp(2); - exec.execute_batch(&format!( + batch.push(&format!( "insert into district values({}, {}, '{}', '{}', '{}', '{}', '{}', '{}', {}, {}, {})", d_id, d_w_id, @@ -569,6 +615,7 @@ impl Load { ))?; pb.set_position(d_id as u64); } + batch.flush()?; finish_progress(&pb, None); Ok(()) @@ -626,6 +673,7 @@ impl Load { .unwrap(), ); let date = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); + let mut batch = SqlBatch::new(exec); for c_id in 1..CUST_PER_DIST + 1 { let c_d_id = d_id; @@ -661,7 +709,7 @@ impl Load { let c_data = generate_string(rng, 300, 500); - exec.execute_batch(&format!( + batch.push(&format!( "insert into customer 
values({}, {}, {}, '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}', {}, {}, {}, {}, {}, {}, '{}')", c_id, c_d_id, @@ -690,12 +738,13 @@ impl Load { let h_amount = Decimal::from_f64_retain(10.0).unwrap().round_dp(2); let h_data = generate_string(rng, 12, 24); - exec.execute_batch(&format!( + batch.push(&format!( "insert into history values({}, {}, {}, {}, {}, '{}', {}, '{}')", c_id, c_d_id, c_w_id, c_d_id, c_w_id, h_date, h_amount, h_data, ))?; pb.set_position(c_id as u64); } + batch.flush()?; finish_progress(&pb, None); Ok(()) @@ -756,6 +805,7 @@ impl Load { let o_w_id = w_id; let nums = init_permutation(rng); + let mut batch = SqlBatch::new(exec); for o_id in 1..ORD_PER_DIST + 1 { let o_c_id = nums[o_id - 1]; @@ -765,7 +815,7 @@ impl Load { let date = format!("'{}'", Utc::now().format("%Y-%m-%d %H:%M:%S")); let o_carrier_id = if o_id > 2100 { - exec.execute_batch(&format!( + batch.push(&format!( "insert into new_orders values({}, {}, {})", o_id, o_d_id, o_w_id, ))?; @@ -773,7 +823,7 @@ impl Load { } else { o_carrier_id.to_string() }; - exec.execute_batch(&format!( + batch.push(&format!( "insert into orders values({}, {}, {}, {}, {}, {}, {}, {})", o_id, o_d_id, o_w_id, o_c_id, date, o_carrier_id, o_ol_cnt, "1", ))?; @@ -789,7 +839,7 @@ impl Load { } else { (date.as_str(), format!("{:.2}", rng.gen_range(0.1..100.0))) }; - exec.execute_batch(&format!( + batch.push(&format!( "insert into order_line values({}, {}, {}, {}, {}, {}, {}, {}, {}, '{}')", o_id, o_d_id, @@ -805,8 +855,60 @@ impl Load { } pb.set_position((o_id - 1) as u64); } + batch.flush()?; finish_progress(&pb, None); Ok(()) } } + +#[cfg(test)] +mod tests { + use super::{SqlBatch, LOAD_BATCH_SIZE}; + use crate::backend::SimpleExecutor; + use crate::TpccError; + use std::cell::RefCell; + + #[derive(Default)] + struct RecordingExecutor { + calls: RefCell>, + } + + impl SimpleExecutor for RecordingExecutor { + fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + 
self.calls.borrow_mut().push(sql.to_string()); + Ok(()) + } + } + + #[test] + fn sql_batch_groups_statements() { + let exec = RecordingExecutor::default(); + let mut batch = SqlBatch::new(&exec); + + batch.push("insert 1").unwrap(); + batch.push("insert 2").unwrap(); + batch.flush().unwrap(); + + assert_eq!( + exec.calls.into_inner(), + vec!["insert 1;insert 2".to_string()] + ); + } + + #[test] + fn sql_batch_flushes_at_batch_size() { + let exec = RecordingExecutor::default(); + let mut batch = SqlBatch::new(&exec); + + for i in 0..=LOAD_BATCH_SIZE { + batch.push(&format!("insert {i}")).unwrap(); + } + batch.flush().unwrap(); + + let calls = exec.calls.into_inner(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].split(';').count(), LOAD_BATCH_SIZE); + assert_eq!(calls[1], format!("insert {}", LOAD_BATCH_SIZE)); + } +} diff --git a/tpcc/src/main.rs b/tpcc/src/main.rs index 4fb59412..9e4fd8d6 100644 --- a/tpcc/src/main.rs +++ b/tpcc/src/main.rs @@ -13,11 +13,13 @@ // limitations under the License. 
use crate::backend::dual::DualBackend; -use crate::backend::kite::KiteBackend; +use crate::backend::kitesql_lmdb::KiteSqlLmdbBackend; +use crate::backend::kitesql_rocksdb::KiteSqlRocksDbBackend; +use crate::backend::sqlite::{SqliteBackend, SqliteProfile}; use crate::backend::{ BackendControl, BackendTransaction, ColumnType, PreparedStatement, StatementSpec, }; -use crate::delivery::DeliveryTest; +use crate::delivery::{Delivery, DeliveryArgs, DeliveryTest}; use crate::load::Load; use crate::new_ord::NewOrdTest; use crate::order_stat::OrderStatTest; @@ -28,10 +30,14 @@ use crate::utils::SeqGen; use clap::{Parser, ValueEnum}; use indicatif::{ProgressBar, ProgressStyle}; use kite_sql::errors::DatabaseError; +use kite_sql::types::value::DataValue; +#[cfg(all(unix, feature = "pprof"))] +use pprof::ProfilerGuard; use rand::prelude::ThreadRng; use rand::Rng; use std::fs; use std::path::Path; +use std::path::PathBuf; use std::time::{Duration, Instant}; mod backend; @@ -97,7 +103,7 @@ struct Args { joins: bool, #[clap(long, default_value = "kite_sql_tpcc")] path: String, - #[clap(long, value_enum, default_value = "kite")] + #[clap(long, value_enum, default_value = "kitesql-lmdb")] backend: BackendKind, #[clap(long, default_value = "5")] max_retry: usize, @@ -107,11 +113,21 @@ struct Args { num_ware: usize, #[clap(long, default_value = "false")] rocksdb_stats: bool, + #[clap(long, value_enum, default_value = "balanced")] + sqlite_profile: SqliteProfile, + #[cfg(feature = "pprof")] + #[clap(long)] + pprof_output: Option, + #[cfg(feature = "pprof")] + #[clap(long, default_value = "100")] + pprof_frequency: i32, } #[derive(Copy, Clone, Debug, ValueEnum)] enum BackendKind { - Kite, + KitesqlRocksdb, + KitesqlLmdb, + Sqlite, Dual, } @@ -121,23 +137,39 @@ fn main() -> Result<(), TpccError> { let mut rng = rand::thread_rng(); match args.backend { - BackendKind::Kite => { - let db_path = Path::new(&args.path); - if db_path.exists() { - fs::remove_dir_all(db_path)?; - } - let backend 
= KiteBackend::new(&args.path, args.rocksdb_stats)?; - run_tpcc(&backend, &args, &mut rng)?; + BackendKind::KitesqlRocksdb => { + reset_db_path(Path::new(&args.path))?; + let backend = KiteSqlRocksDbBackend::new(&args.path, args.rocksdb_stats)?; + run_tpcc(&backend, &args, &mut rng) + } + BackendKind::KitesqlLmdb => { + reset_db_path(Path::new(&args.path))?; + let backend = KiteSqlLmdbBackend::new(&args.path)?; + run_tpcc(&backend, &args, &mut rng) + } + BackendKind::Sqlite => { + reset_db_path(Path::new(&args.path))?; + let backend = SqliteBackend::new(&args.path, args.sqlite_profile)?; + run_tpcc(&backend, &args, &mut rng) } BackendKind::Dual => { - let db_path = Path::new(&args.path); - if db_path.exists() { - fs::remove_dir_all(db_path)?; - } + reset_db_path(Path::new(&args.path))?; let backend = DualBackend::new(&args.path, args.rocksdb_stats)?; - run_tpcc(&backend, &args, &mut rng)?; + run_tpcc(&backend, &args, &mut rng) } } +} + +fn reset_db_path(path: &Path) -> Result<(), TpccError> { + if !path.exists() { + return Ok(()); + } + + if path.is_dir() { + fs::remove_dir_all(path)?; + } else { + fs::remove_file(path)?; + } Ok(()) } @@ -172,6 +204,8 @@ fn run_tpcc( let mut round_count = 0; let mut seq_gen = SeqGen::new(10, 10, 1, 1, 1); let tpcc_start = Instant::now(); + #[cfg(all(unix, feature = "pprof"))] + let pprof = PprofSession::start(args)?; let progress = ProgressBar::new_spinner(); progress.set_style(ProgressStyle::with_template("{spinner:.green} [TPCC] {msg}").unwrap()); progress.enable_steady_tick(Duration::from_millis(120)); @@ -264,6 +298,7 @@ fn run_tpcc( print_constraint_checks(&success, &late); print_response_checks(&success, &late); println!(); + rt_hist.finalize(); rt_hist.hist_report(); println!(""); let tpmc = ((success[0] + late[0]) as f64 / (actual_tpcc_time.as_secs_f64() / 60.0)).round(); @@ -272,10 +307,47 @@ fn run_tpcc( println!(); println!("{metrics}"); } + #[cfg(all(unix, feature = "pprof"))] + if let Some(pprof) = pprof { + 
pprof.finish()?; + } Ok(()) } +#[cfg(all(unix, feature = "pprof"))] +struct PprofSession { + guard: ProfilerGuard<'static>, + output: PathBuf, +} + +#[cfg(all(unix, feature = "pprof"))] +impl PprofSession { + fn start(args: &Args) -> Result, TpccError> { + let Some(output) = args.pprof_output.clone() else { + return Ok(None); + }; + let guard = ProfilerGuard::new(args.pprof_frequency) + .map_err(|err| TpccError::Profile(err.to_string()))?; + Ok(Some(Self { guard, output })) + } + + fn finish(self) -> Result<(), TpccError> { + let report = self + .guard + .report() + .build() + .map_err(|err| TpccError::Profile(err.to_string()))?; + let file = fs::File::create(&self.output)?; + report + .flamegraph(file) + .map_err(|err| TpccError::Profile(err.to_string()))?; + println!(); + println!("[pprof] flamegraph written to {}", self.output.display()); + Ok(()) + } +} + fn statement_specs() -> Vec> { vec![ vec![ @@ -694,6 +766,8 @@ pub enum TpccError { InvalidDateTime, #[error("backend mismatch: {0}")] BackendMismatch(String), + #[error("profile error: {0}")] + Profile(String), } #[ignore] @@ -702,7 +776,7 @@ fn explain_tpcc() -> Result<(), DatabaseError> { use kite_sql::db::DataBaseBuilder; use kite_sql::types::tuple::create_table; - let database = DataBaseBuilder::path("./kite_sql_tpcc").build()?; + let database = DataBaseBuilder::path(tpcc_db_path()).build_lmdb()?; let mut tx = database.new_transaction()?; let customer_tuple = tx @@ -956,3 +1030,11 @@ fn explain_tpcc() -> Result<(), DatabaseError> { Ok(()) } + +#[cfg(test)] +fn tpcc_db_path() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("kite_sql_tpcc") +} diff --git a/tpcc/src/new_ord.rs b/tpcc/src/new_ord.rs index d70b9623..ec3550a3 100644 --- a/tpcc/src/new_ord.rs +++ b/tpcc/src/new_ord.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::backend::{BackendTransaction, PreparedStatement, TransactionExt}; +use crate::backend::{BackendTransaction, PreparedStatement}; use crate::load::{nu_rand, CUST_PER_DIST, DIST_PER_WARE, MAX_ITEMS, MAX_NUM_ITEMS}; use crate::{other_ware, TpccArgs, TpccError, TpccTest, TpccTransaction, ALLOW_MULTI_WAREHOUSE_TX}; use chrono::Utc; @@ -81,7 +81,11 @@ impl TpccTransaction for NewOrd { let (c_discount, _c_last, _c_credit, w_tax) = if args.joins { // "SELECT c_discount, c_last, c_credit, w_tax FROM customer, warehouse WHERE w_id = ? AND c_w_id = w_id AND c_d_id = ? AND c_id = ?" - let tuple = tx.query_one( + let mut c_discount = Decimal::default(); + let mut c_last = String::new(); + let mut c_credit = String::new(); + let mut w_tax = Decimal::default(); + tx.with_query_one( &statements[0], &[ ("$1", DataValue::Int16(args.w_id as i16)), @@ -89,45 +93,63 @@ impl TpccTransaction for NewOrd { ("$3", DataValue::Int8(args.d_id as i8)), ("$4", DataValue::Int64(args.c_id as i64)), ], + &mut |tuple| { + c_discount = tuple.values[0].decimal().unwrap(); + c_last = tuple.values[1].utf8().unwrap().to_string(); + c_credit = tuple.values[2].utf8().unwrap().to_string(); + w_tax = tuple.values[3].decimal().unwrap(); + Ok(()) + }, )?; - let c_discount = tuple.values[0].decimal().unwrap(); - let c_last = tuple.values[1].utf8().unwrap().to_string(); - let c_credit = tuple.values[2].utf8().unwrap().to_string(); - let w_tax = tuple.values[3].decimal().unwrap(); (c_discount, c_last, c_credit, w_tax) } else { // "SELECT c_discount, c_last, c_credit FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_id = ?" 
- let tuple = tx.query_one( + let mut c_discount = Decimal::default(); + let mut c_last = String::new(); + let mut c_credit = String::new(); + tx.with_query_one( &statements[1], &[ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ("$3", DataValue::Int32(args.c_id as i32)), ], + &mut |tuple| { + c_discount = tuple.values[0].decimal().unwrap(); + c_last = tuple.values[1].utf8().unwrap().to_string(); + c_credit = tuple.values[2].utf8().unwrap().to_string(); + Ok(()) + }, )?; - let c_discount = tuple.values[0].decimal().unwrap(); - let c_last = tuple.values[1].utf8().unwrap().to_string(); - let c_credit = tuple.values[2].utf8().unwrap().to_string(); // "SELECT w_tax FROM warehouse WHERE w_id = ?" - let tuple = tx.query_one( + let mut w_tax = Decimal::default(); + tx.with_query_one( &statements[2], &[("$1", DataValue::Int16(args.w_id as i16))], + &mut |tuple| { + w_tax = tuple.values[0].decimal().unwrap(); + Ok(()) + }, )?; - let w_tax = tuple.values[0].decimal().unwrap(); (c_discount, c_last, c_credit, w_tax) }; // "SELECT d_next_o_id, d_tax FROM district WHERE d_id = ? AND d_w_id = ? FOR UPDATE" - let tuple = tx.query_one( + let mut d_next_o_id = 0; + let mut d_tax = Decimal::default(); + tx.with_query_one( &statements[3], &[ ("$1", DataValue::Int8(args.d_id as i8)), ("$2", DataValue::Int16(args.w_id as i16)), ], + &mut |tuple| { + d_next_o_id = tuple.values[0].i32().unwrap(); + d_tax = tuple.values[1].decimal().unwrap(); + Ok(()) + }, )?; - let d_next_o_id = tuple.values[0].i32().unwrap(); - let d_tax = tuple.values[1].decimal().unwrap(); // "UPDATE district SET d_next_o_id = ? + 1 WHERE d_id = ? AND d_w_id = ?" tx.execute_drain( &statements[4], @@ -189,36 +211,55 @@ impl TpccTransaction for NewOrd { let ol_quantity = args.qty[ol_num_seq[ol_number - 1]]; // "SELECT i_price, i_name, i_data FROM item WHERE i_id = ?" 
let params = [("$1", DataValue::Int32(ol_i_id as i32))]; - let tuple = tx.query_one(&statements[7], ¶ms)?; - let i_price = tuple.values[0].decimal().unwrap(); - let i_name = tuple.values[1].utf8().unwrap(); - let i_data = tuple.values[2].utf8().unwrap(); + let mut i_price = Decimal::default(); + let mut i_name = String::new(); + let mut i_data = String::new(); + tx.with_query_one(&statements[7], ¶ms, &mut |tuple| { + i_price = tuple.values[0].decimal().unwrap(); + i_name = tuple.values[1].utf8().unwrap().to_string(); + i_data = tuple.values[2].utf8().unwrap().to_string(); + Ok(()) + })?; price[ol_num_seq[ol_number - 1]] = i_price; - iname[ol_num_seq[ol_number - 1]] = i_name.to_string(); + iname[ol_num_seq[ol_number - 1]] = i_name; // "SELECT s_quantity, s_data, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, s_dist_07, s_dist_08, s_dist_09, s_dist_10 FROM stock WHERE s_i_id = ? AND s_w_id = ? FOR UPDATE" let params = [ ("$1", DataValue::Int32(ol_i_id as i32)), ("$2", DataValue::Int16(ol_supply_w_id as i16)), ]; - let tuple = tx.query_one(&statements[8], ¶ms)?; - let mut s_quantity = tuple.values[0].i16().unwrap(); - let s_data = tuple.values[1].utf8().unwrap(); - let s_dist_01 = tuple.values[2].utf8().unwrap(); - let s_dist_02 = tuple.values[3].utf8().unwrap(); - let s_dist_03 = tuple.values[4].utf8().unwrap(); - let s_dist_04 = tuple.values[5].utf8().unwrap(); - let s_dist_05 = tuple.values[6].utf8().unwrap(); - let s_dist_06 = tuple.values[7].utf8().unwrap(); - let s_dist_07 = tuple.values[8].utf8().unwrap(); - let s_dist_08 = tuple.values[9].utf8().unwrap(); - let s_dist_09 = tuple.values[10].utf8().unwrap(); - let s_dist_10 = tuple.values[11].utf8().unwrap(); + let mut s_quantity = 0; + let mut s_data = String::new(); + let mut s_dist_01 = String::new(); + let mut s_dist_02 = String::new(); + let mut s_dist_03 = String::new(); + let mut s_dist_04 = String::new(); + let mut s_dist_05 = String::new(); + let mut s_dist_06 = String::new(); + let 
mut s_dist_07 = String::new(); + let mut s_dist_08 = String::new(); + let mut s_dist_09 = String::new(); + let mut s_dist_10 = String::new(); + tx.with_query_one(&statements[8], ¶ms, &mut |tuple| { + s_quantity = tuple.values[0].i16().unwrap(); + s_data = tuple.values[1].utf8().unwrap().to_string(); + s_dist_01 = tuple.values[2].utf8().unwrap().to_string(); + s_dist_02 = tuple.values[3].utf8().unwrap().to_string(); + s_dist_03 = tuple.values[4].utf8().unwrap().to_string(); + s_dist_04 = tuple.values[5].utf8().unwrap().to_string(); + s_dist_05 = tuple.values[6].utf8().unwrap().to_string(); + s_dist_06 = tuple.values[7].utf8().unwrap().to_string(); + s_dist_07 = tuple.values[8].utf8().unwrap().to_string(); + s_dist_08 = tuple.values[9].utf8().unwrap().to_string(); + s_dist_09 = tuple.values[10].utf8().unwrap().to_string(); + s_dist_10 = tuple.values[11].utf8().unwrap().to_string(); + Ok(()) + })?; let ol_dist_info = pick_dist_info( - args.d_id, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, - s_dist_07, s_dist_08, s_dist_09, s_dist_10, + args.d_id, &s_dist_01, &s_dist_02, &s_dist_03, &s_dist_04, &s_dist_05, &s_dist_06, + &s_dist_07, &s_dist_08, &s_dist_09, &s_dist_10, ); stock[ol_num_seq[ol_number - 1]] = s_quantity; bg[ol_num_seq[ol_number - 1]] = diff --git a/tpcc/src/order_stat.rs b/tpcc/src/order_stat.rs index 3c7490fa..883bd288 100644 --- a/tpcc/src/order_stat.rs +++ b/tpcc/src/order_stat.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::backend::{BackendTransaction, PreparedStatement, TransactionExt}; +use crate::backend::{BackendTransaction, PreparedStatement}; use crate::load::{last_name, nu_rand, CUST_PER_DIST, DIST_PER_WARE}; use crate::{TpccArgs, TpccError, TpccTest, TpccTransaction}; use kite_sql::types::value::DataValue; use rand::prelude::ThreadRng; use rand::Rng; -use rust_decimal::Decimal; #[derive(Debug)] pub(crate) struct OrderStatArgs { @@ -60,54 +59,62 @@ impl TpccTransaction for OrderStat { ) -> Result<(), TpccError> { let (_c_balance, _c_first, _c_middle, _c_last) = if args.by_name { // "SELECT count(c_id) FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ?" - let tuple = tx.query_one( + let mut name_cnt = 0usize; + tx.with_query_one( &statements[0], &[ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ("$3", DataValue::from(args.c_last.clone())), ], + &mut |tuple| { + name_cnt = tuple.values[0].i32().unwrap() as usize; + Ok(()) + }, )?; - let mut name_cnt = tuple.values[0].i32().unwrap() as usize; // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ? 
ORDER BY c_first" let params = [ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ("$3", DataValue::from(args.c_last.clone())), ]; - let mut tuple_iter = tx.execute(&statements[1], ¶ms)?; - if name_cnt % 2 == 1 { name_cnt += 1; } - let mut c_balance = Decimal::default(); + let target = name_cnt / 2 - 1; + let mut c_balance = Default::default(); let mut c_first = String::new(); let mut c_middle = String::new(); let mut c_last = String::new(); - - for _ in 0..name_cnt / 2 { - let tuple = tuple_iter.next().unwrap()?; - + tx.with_query_nth(&statements[1], ¶ms, target, &mut |tuple| { c_balance = tuple.values[0].decimal().unwrap(); c_first = tuple.values[1].utf8().unwrap().to_string(); c_middle = tuple.values[2].utf8().unwrap().to_string(); c_last = tuple.values[3].utf8().unwrap().to_string(); - } + Ok(()) + })?; (c_balance, c_first, c_middle, c_last) } else { // "SELECT c_balance, c_first, c_middle, c_last FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_id = ?" - let tuple = tx.query_one( + let mut c_balance = Default::default(); + let mut c_first = String::new(); + let mut c_middle = String::new(); + let mut c_last = String::new(); + tx.with_query_one( &statements[2], &[ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ("$3", DataValue::Int32(args.c_id as i32)), ], + &mut |tuple| { + c_balance = tuple.values[0].decimal().unwrap(); + c_first = tuple.values[1].utf8().unwrap().to_string(); + c_middle = tuple.values[2].utf8().unwrap().to_string(); + c_last = tuple.values[3].utf8().unwrap().to_string(); + Ok(()) + }, )?; - let c_balance = tuple.values[0].decimal().unwrap(); - let c_first = tuple.values[1].utf8().unwrap().to_string(); - let c_middle = tuple.values[2].utf8().unwrap().to_string(); - let c_last = tuple.values[3].utf8().unwrap().to_string(); (c_balance, c_first, c_middle, c_last) }; // "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders WHERE o_w_id = ? AND o_d_id = ? 
AND o_c_id = ? AND o_id = (SELECT MAX(o_id) FROM orders WHERE o_w_id = ? AND o_d_id = ? AND o_c_id = ?)" @@ -119,15 +126,18 @@ impl TpccTransaction for OrderStat { ("$5", DataValue::Int8(args.d_id as i8)), ("$6", DataValue::Int32(args.c_id as i32)), ]; - let tuple = tx.query_one(&statements[3], ¶ms)?; - let o_id = tuple.values[0].i32().unwrap(); + let mut o_id = 0; + tx.with_query_one(&statements[3], ¶ms, &mut |tuple| { + o_id = tuple.values[0].i32().unwrap(); + Ok(()) + })?; // "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d FROM order_line WHERE ol_w_id = ? AND ol_d_id = ? AND ol_o_id = ?" let params = [ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ("$3", DataValue::Int32(o_id)), ]; - let _tuple = tx.query_one(&statements[4], ¶ms)?; + tx.with_query_one(&statements[4], ¶ms, &mut |_| Ok(()))?; // let ol_i_id = tuple.values[0].i32(); // let ol_supply_w_id = tuple.values[1].i16(); // let ol_quantity = tuple.values[2].i8(); diff --git a/tpcc/src/payment.rs b/tpcc/src/payment.rs index 9c64e411..1a41e856 100644 --- a/tpcc/src/payment.rs +++ b/tpcc/src/payment.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::backend::{BackendTransaction, PreparedStatement, TransactionExt}; +use crate::backend::{BackendTransaction, PreparedStatement}; use crate::load::{last_name, nu_rand, CUST_PER_DIST, DIST_PER_WARE}; use crate::{other_ware, TpccArgs, TpccError, TpccTest, TpccTransaction, ALLOW_MULTI_WAREHOUSE_TX}; use chrono::Utc; @@ -80,16 +80,25 @@ impl TpccTransaction for Payment { ], )?; // "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name FROM warehouse WHERE w_id = ?" 
- let tuple = tx.query_one( + let mut w_street_1 = String::new(); + let mut w_street_2 = String::new(); + let mut w_city = String::new(); + let mut w_state = String::new(); + let mut w_zip = String::new(); + let mut w_name = String::new(); + tx.with_query_one( &statements[1], &[("$1", DataValue::Int16(args.w_id as i16))], + &mut |tuple| { + w_street_1 = tuple.values[0].utf8().unwrap().to_string(); + w_street_2 = tuple.values[1].utf8().unwrap().to_string(); + w_city = tuple.values[2].utf8().unwrap().to_string(); + w_state = tuple.values[3].utf8().unwrap().to_string(); + w_zip = tuple.values[4].utf8().unwrap().to_string(); + w_name = tuple.values[5].utf8().unwrap().to_string(); + Ok(()) + }, )?; - let w_street_1 = tuple.values[0].utf8().unwrap(); - let w_street_2 = tuple.values[1].utf8().unwrap(); - let w_city = tuple.values[2].utf8().unwrap(); - let w_state = tuple.values[3].utf8().unwrap(); - let w_zip = tuple.values[4].utf8().unwrap(); - let w_name = tuple.values[5].utf8().unwrap(); // "UPDATE district SET d_ytd = d_ytd + ? WHERE d_w_id = ? AND d_id = ?" tx.execute_drain( @@ -102,84 +111,118 @@ impl TpccTransaction for Payment { )?; // "SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name FROM district WHERE d_w_id = ? AND d_id = ?" 
- let tuple = tx.query_one( + let mut d_street_1 = String::new(); + let mut d_street_2 = String::new(); + let mut d_city = String::new(); + let mut d_state = String::new(); + let mut d_zip = String::new(); + let mut d_name = String::new(); + tx.with_query_one( &statements[3], &[ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(args.d_id as i8)), ], + &mut |tuple| { + d_street_1 = tuple.values[0].utf8().unwrap().to_string(); + d_street_2 = tuple.values[1].utf8().unwrap().to_string(); + d_city = tuple.values[2].utf8().unwrap().to_string(); + d_state = tuple.values[3].utf8().unwrap().to_string(); + d_zip = tuple.values[4].utf8().unwrap().to_string(); + d_name = tuple.values[5].utf8().unwrap().to_string(); + Ok(()) + }, )?; - let d_street_1 = tuple.values[0].utf8().unwrap(); - let d_street_2 = tuple.values[1].utf8().unwrap(); - let d_city = tuple.values[2].utf8().unwrap(); - let d_state = tuple.values[3].utf8().unwrap(); - let d_zip = tuple.values[4].utf8().unwrap(); - let d_name = tuple.values[5].utf8().unwrap(); let mut c_id = args.c_id as i32; if args.by_name { // "SELECT count(c_id) FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ?" - let tuple = tx.query_one( + let mut name_cnt = 0; + tx.with_query_one( &statements[4], &[ ("$1", DataValue::Int16(args.c_w_id as i16)), ("$2", DataValue::Int8(args.c_d_id as i8)), ("$3", DataValue::from(args.c_last.clone())), ], + &mut |tuple| { + name_cnt = tuple.values[0].i32().unwrap(); + Ok(()) + }, )?; - let mut name_cnt = tuple.values[0].i32().unwrap(); // "SELECT c_id FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_last = ? 
ORDER BY c_first" let params = [ ("$1", DataValue::Int16(args.c_w_id as i16)), ("$2", DataValue::Int8(args.c_d_id as i8)), ("$3", DataValue::from(args.c_last.clone())), ]; - let mut tuple_iter = tx.execute(&statements[5], ¶ms)?; if name_cnt % 2 == 1 { name_cnt += 1; } - for _ in 0..name_cnt / 2 { - let result = tuple_iter.next().unwrap()?; - c_id = result.values[0].i32().unwrap(); - } + let target = name_cnt as usize / 2 - 1; + tx.with_query_nth(&statements[5], ¶ms, target, &mut |tuple| { + c_id = tuple.values[0].i32().unwrap(); + Ok(()) + })?; } // "SELECT c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_credit, c_credit_lim, c_discount, c_balance, c_since FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_id = ? FOR UPDATE" - let tuple = tx.query_one( + let mut c_first = String::new(); + let mut c_middle = String::new(); + let mut c_last = String::new(); + let mut c_street_1 = String::new(); + let mut c_street_2 = String::new(); + let mut c_city = String::new(); + let mut c_state = String::new(); + let mut c_zip = String::new(); + let mut c_phone = String::new(); + let mut c_credit = None; + let mut c_credit_lim = 0; + let mut c_discount = Decimal::default(); + let mut c_balance = Decimal::default(); + let mut c_since = Default::default(); + tx.with_query_one( &statements[6], &[ ("$1", DataValue::Int16(args.c_w_id as i16)), ("$2", DataValue::Int8(args.c_d_id as i8)), ("$3", DataValue::Int32(c_id)), ], + &mut |tuple| { + c_first = tuple.values[0].utf8().unwrap().to_string(); + c_middle = tuple.values[1].utf8().unwrap().to_string(); + c_last = tuple.values[2].utf8().unwrap().to_string(); + c_street_1 = tuple.values[3].utf8().unwrap().to_string(); + c_street_2 = tuple.values[4].utf8().unwrap().to_string(); + c_city = tuple.values[5].utf8().unwrap().to_string(); + c_state = tuple.values[6].utf8().unwrap().to_string(); + c_zip = tuple.values[7].utf8().unwrap().to_string(); + c_phone = tuple.values[8].utf8().unwrap().to_string(); + 
c_credit = tuple.values[9].utf8().map(ToString::to_string); + c_credit_lim = tuple.values[10].i64().unwrap(); + c_discount = tuple.values[11].decimal().unwrap(); + c_balance = tuple.values[12].decimal().unwrap(); + c_since = tuple.values[13].datetime().unwrap(); + Ok(()) + }, )?; - let c_first = tuple.values[0].utf8().unwrap(); - let c_middle = tuple.values[1].utf8().unwrap(); - let c_last = tuple.values[2].utf8().unwrap(); - let c_street_1 = tuple.values[3].utf8().unwrap(); - let c_street_2 = tuple.values[4].utf8().unwrap(); - let c_city = tuple.values[5].utf8().unwrap(); - let c_state = tuple.values[6].utf8().unwrap(); - let c_zip = tuple.values[7].utf8().unwrap(); - let c_phone = tuple.values[8].utf8().unwrap(); - let c_credit = tuple.values[9].utf8(); - let c_credit_lim = tuple.values[10].i64().unwrap(); - let c_discount = tuple.values[11].decimal().unwrap(); - let mut c_balance = tuple.values[12].decimal().unwrap(); - let c_since = tuple.values[13].datetime().unwrap(); c_balance += args.h_amount; if let Some(c_credit) = c_credit { if c_credit.contains("BC") { // "SELECT c_data FROM customer WHERE c_w_id = ? AND c_d_id = ? AND c_id = ?" 
- let tuple = tx.query_one( + let mut c_data = String::new(); + tx.with_query_one( &statements[7], &[ ("$1", DataValue::Int16(args.c_w_id as i16)), ("$2", DataValue::Int8(args.c_d_id as i8)), ("$3", DataValue::Int32(c_id)), ], + &mut |tuple| { + c_data = tuple.values[0].utf8().unwrap().to_string(); + Ok(()) + }, )?; - let c_data = tuple.values[0].utf8().unwrap(); // https://github.com/AgilData/tpcc/blob/dfbabe1e35cc93b2bf2e107fc699eb29c2097e24/src/main/java/com/codefutures/tpcc/Payment.java#L284 // let c_new_data = format!("| {} {} {} {} {} {} {}", c_id, args.c_d_id, args.c_w_id, args.d_id, args.w_id, args.h_amount, ) @@ -189,7 +232,7 @@ impl TpccTransaction for Payment { &statements[8], &[ ("$1", DataValue::Decimal(c_balance)), - ("$2", DataValue::from(c_data.to_string())), + ("$2", DataValue::from(c_data)), ("$3", DataValue::Int16(args.c_w_id as i16)), ("$4", DataValue::Int8(args.c_d_id as i8)), ("$5", DataValue::Int32(c_id)), diff --git a/tpcc/src/rt_hist.rs b/tpcc/src/rt_hist.rs index 21402115..8ed62166 100644 --- a/tpcc/src/rt_hist.rs +++ b/tpcc/src/rt_hist.rs @@ -88,6 +88,19 @@ impl RtHist { line as f64 / REC_PER_SEC as f64 } + // Fold the current interval's histograms into the totals and reset them + pub fn finalize(&mut self) { + for transaction in 0..NUM_TRANSACTIONS { + for i in 0..(MAX_REC * REC_PER_SEC) { + self.total_hist[transaction][i] += self.cur_hist[transaction][i]; + self.cur_hist[transaction][i] = 0; + } + self.max_rt[transaction] = *OrderedFloat(self.cur_max_rt[transaction]) + .max(OrderedFloat(self.max_rt[transaction])); + self.cur_max_rt[transaction] = 0.0; + } + } + // Report histograms pub fn hist_report(&self) { let mut total = [0; NUM_TRANSACTIONS]; diff --git a/tpcc/src/slev.rs b/tpcc/src/slev.rs index 24c477af..c3bb10e3 100644 --- a/tpcc/src/slev.rs +++ b/tpcc/src/slev.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.
-use crate::backend::{BackendTransaction, PreparedStatement, TransactionExt}; +use crate::backend::{BackendTransaction, PreparedStatement}; use crate::load::DIST_PER_WARE; use crate::{TpccArgs, TpccError, TpccTest, TpccTransaction}; use kite_sql::types::value::DataValue; @@ -44,16 +44,21 @@ impl TpccTransaction for Slev { statements: &[PreparedStatement], ) -> Result<(), TpccError> { // "SELECT d_next_o_id FROM district WHERE d_id = ? AND d_w_id = ?" - let tuple = tx.query_one( + let mut d_next_o_id = 0; + tx.with_query_one( &statements[0], &[ ("$1", DataValue::Int8(args.d_id as i8)), ("$2", DataValue::Int16(args.w_id as i16)), ], + &mut |tuple| { + d_next_o_id = tuple.values[0].i32().unwrap(); + Ok(()) + }, )?; - let d_next_o_id = tuple.values[0].i32().unwrap(); // "SELECT DISTINCT ol_i_id FROM order_line WHERE ol_w_id = ? AND ol_d_id = ? AND ol_o_id < ? AND ol_o_id >= (? - 20)" - let tuple = tx.query_one( + let mut ol_i_id = 0; + tx.with_query_one( &statements[1], &[ ("$1", DataValue::Int16(args.w_id as i16)), @@ -61,16 +66,20 @@ impl TpccTransaction for Slev { ("$3", DataValue::Int32(d_next_o_id)), ("$4", DataValue::Int32(d_next_o_id)), ], + &mut |tuple| { + ol_i_id = tuple.values[0].i32().unwrap(); + Ok(()) + }, )?; - let ol_i_id = tuple.values[0].i32().unwrap(); // "SELECT count(*) FROM stock WHERE s_w_id = ? AND s_i_id = ? AND s_quantity < ?" - let _tuple = tx.query_one( + tx.with_query_one( &statements[2], &[ ("$1", DataValue::Int16(args.w_id as i16)), ("$2", DataValue::Int8(ol_i_id as i8)), ("$3", DataValue::Int16(args.level as i16)), ], + &mut |_| Ok(()), )?; // let i_count = tuple.values[0].i32().unwrap();