From a12e198926191e3e1df6c9a21b33ca2545696309 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 11:58:27 +0200 Subject: [PATCH 01/12] first pass --- src/execute/layer.rs | 13 + src/execute/mod.rs | 22 +- src/plot/layer/geom/area.rs | 20 +- src/plot/layer/geom/arrow.rs | 4 + src/plot/layer/geom/bar.rs | 16 +- src/plot/layer/geom/errorbar.rs | 4 + src/plot/layer/geom/line.rs | 22 +- src/plot/layer/geom/mod.rs | 47 +- src/plot/layer/geom/path.rs | 4 + src/plot/layer/geom/point.rs | 4 + src/plot/layer/geom/polygon.rs | 4 + src/plot/layer/geom/ribbon.rs | 20 +- src/plot/layer/geom/rule.rs | 4 + src/plot/layer/geom/segment.rs | 4 + src/plot/layer/geom/stat_aggregate.rs | 982 ++++++++++++++++++++++++++ src/plot/layer/geom/text.rs | 4 + src/plot/layer/mod.rs | 5 + src/reader/duckdb.rs | 8 + src/reader/mod.rs | 11 + 19 files changed, 1164 insertions(+), 34 deletions(-) create mode 100644 src/plot/layer/geom/stat_aggregate.rs diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 6af5c641..1c9b1b0d 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -584,6 +584,19 @@ where layer.mappings.aesthetics.remove(aes); } + // Auto-remap stat columns whose names are position aesthetics that were + // consumed by the stat (e.g. Aggregate's `pos1`/`pos2` outputs). The geom + // can't list these in `default_remappings` because the set of position + // aesthetics in play is dynamic per layer. 
+ for stat in &stat_columns { + if final_remappings.contains_key(stat) { + continue; + } + if aesthetic::is_position_aesthetic(stat) && consumed_aesthetics.contains(stat) { + final_remappings.insert(stat.clone(), stat.clone()); + } + } + // Apply stat_columns to layer aesthetics using the remappings for stat in &stat_columns { if let Some(aesthetic) = final_remappings.get(stat) { diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 82c9e1c0..063ca1d6 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -116,24 +116,38 @@ fn validate( } } - // Validate remapping source columns are valid stat columns for this geom + // Validate remapping source columns are valid stat columns for this geom. + // Geoms that opt into the Aggregate stat (`supports_aggregate`) also accept + // `aggregate`, `count`, and any position aesthetic name as a stat source. let valid_stat_columns = layer.geom.valid_stat_columns(); + let supports_aggregate = layer.geom.supports_aggregate(); for stat_value in layer.remappings.aesthetics.values() { if let Some(stat_col) = stat_value.column_name() { - if !valid_stat_columns.contains(&stat_col) { - if valid_stat_columns.is_empty() { + let is_aggregate_stat_col = supports_aggregate + && (stat_col == "aggregate" + || stat_col == "count" + || crate::plot::aesthetic::is_position_aesthetic(stat_col)); + if !valid_stat_columns.contains(&stat_col) && !is_aggregate_stat_col { + if valid_stat_columns.is_empty() && !supports_aggregate { return Err(GgsqlError::ValidationError(format!( "Layer {}: REMAPPING not supported for geom '{}' (no stat transform)", idx + 1, layer.geom ))); } else { + let mut valid: Vec = + valid_stat_columns.iter().map(|s| s.to_string()).collect(); + if supports_aggregate { + valid.push("aggregate".to_string()); + valid.push("count".to_string()); + } + let valid_refs: Vec<&str> = valid.iter().map(|s| s.as_str()).collect(); return Err(GgsqlError::ValidationError(format!( "Layer {}: REMAPPING references unknown stat column '{}'. 
Valid stat columns for geom '{}' are: {}", idx + 1, stat_col, layer.geom, - crate::and_list(valid_stat_columns) + crate::and_list(&valid_refs) ))); } } diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index a9df6bff..101806d0 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -5,8 +5,9 @@ use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamDefinition}; use crate::{naming, Mappings}; +use super::stat_aggregate; use super::types::{ParamConstraint, POSITION_VALUES}; -use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; /// Area geom - filled area charts #[derive(Debug, Clone, Copy)] @@ -54,6 +55,10 @@ impl GeomTrait for Area { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -61,13 +66,16 @@ impl GeomTrait for Area { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Area geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 375d9754..2e3369d2 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -39,6 +39,10 @@ impl GeomTrait for Arrow { }]; PARAMS } + + fn supports_aggregate(&self) -> 
bool { + true + } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index d64bce9f..7824f74f 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -3,10 +3,11 @@ use std::collections::HashMap; use std::collections::HashSet; +use super::stat_aggregate; use super::types::{get_column_name, POSITION_VALUES}; use super::{ - DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, - StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, + ParamDefinition, StatResult, }; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; @@ -79,6 +80,10 @@ impl GeomTrait for Bar { &["pos1", "pos2", "weight"] } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true // Bar stat decides COUNT vs identity based on y mapping } @@ -89,10 +94,13 @@ impl GeomTrait for Bar { schema: &Schema, aesthetics: &Mappings, group_by: &[String], - _parameters: &HashMap, + parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, - _dialect: &dyn SqlDialect, + dialect: &dyn SqlDialect, ) -> Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } stat_bar_count(query, schema, aesthetics, group_by) } } diff --git a/src/plot/layer/geom/errorbar.rs b/src/plot/layer/geom/errorbar.rs index 394c81e9..2821d141 100644 --- a/src/plot/layer/geom/errorbar.rs +++ b/src/plot/layer/geom/errorbar.rs @@ -44,6 +44,10 @@ impl GeomTrait for ErrorBar { ]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for ErrorBar { diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index a8ded3b1..20b87228 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -1,8 +1,9 @@ //! 
Line geom implementation +use super::stat_aggregate; use super::{ - DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, - StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, + ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; @@ -39,6 +40,10 @@ impl GeomTrait for Line { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -46,13 +51,16 @@ impl GeomTrait for Line { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Line geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 145f8089..a10e6165 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -43,6 +43,7 @@ mod ribbon; mod rule; mod segment; mod smooth; +pub(crate) mod stat_aggregate; mod text; mod tile; mod violin; @@ -192,20 +193,35 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } + /// Whether this geom accepts the `aggregate` SETTING parameter. 
+ /// + /// Geoms that opt in (the Identity-stat geoms) gain a generic Aggregate stat + /// that groups by discrete mappings + PARTITION BY and emits one row per + /// (group × aggregation function). Statistical geoms (histogram, density, + /// smooth, boxplot, violin) leave this `false` to keep their bespoke stats. + fn supports_aggregate(&self) -> bool { + false + } + /// Apply statistical transformation to the layer query. /// - /// The default implementation returns identity (no transformation). + /// The default implementation dispatches to the Aggregate stat when + /// `supports_aggregate()` is true and the `aggregate` parameter is set; + /// otherwise returns identity (no transformation). #[allow(clippy::too_many_arguments)] fn apply_stat_transform( &self, - _query: &str, - _schema: &Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &HashMap, + query: &str, + schema: &Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, - _dialect: &dyn SqlDialect, + dialect: &dyn SqlDialect, ) -> Result { + if self.supports_aggregate() && has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } Ok(StatResult::Identity) } @@ -248,10 +264,22 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { for param in self.default_params() { valid.push(param.name); } + if self.supports_aggregate() { + valid.push("aggregate"); + } valid } } +/// True when `parameters["aggregate"]` is set to a non-null string or array. 
+pub(crate) fn has_aggregate_param(parameters: &HashMap) -> bool { + match parameters.get("aggregate") { + None | Some(ParameterValue::Null) => false, + Some(ParameterValue::String(_)) | Some(ParameterValue::Array(_)) => true, + _ => false, + } +} + /// Wrapper struct for geom trait objects /// /// This provides a convenient interface for working with geoms while hiding @@ -455,6 +483,11 @@ impl Geom { self.0.valid_settings() } + /// Whether this geom accepts the `aggregate` SETTING parameter. + pub fn supports_aggregate(&self) -> bool { + self.0.supports_aggregate() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index 5e32a3be..062e5f73 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -36,6 +36,10 @@ impl GeomTrait for Path { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 3dafde2a..5101f2f0 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -38,6 +38,10 @@ impl GeomTrait for Point { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index d1ed6841..dee34338 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -37,6 +37,10 @@ impl GeomTrait for Polygon { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Polygon { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 87d4636c..98b60951 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -1,7 +1,8 @@ //! 
Ribbon geom implementation +use super::stat_aggregate; use super::types::POSITION_VALUES; -use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::{naming, Mappings}; @@ -39,6 +40,10 @@ impl GeomTrait for Ribbon { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -46,13 +51,16 @@ impl GeomTrait for Ribbon { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index be434f7a..a495cb48 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -25,6 +25,10 @@ impl GeomTrait for Rule { } } + fn supports_aggregate(&self) -> bool { + true + } + fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { // Rule requires exactly one of pos1 or pos2 (XOR logic) let has_pos1 = mappings.contains_key("pos1"); diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index 2ebfe920..d3fac22d 100644 --- a/src/plot/layer/geom/segment.rs +++ 
b/src/plot/layer/geom/segment.rs @@ -39,6 +39,10 @@ impl GeomTrait for Segment { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs new file mode 100644 index 00000000..54692d88 --- /dev/null +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -0,0 +1,982 @@ +//! Aggregate stat - groups data and applies one or more aggregation functions per group. +//! +//! When a layer's `aggregate` SETTING is set to a function name (or array of names), +//! this stat groups by discrete mappings + PARTITION BY columns and produces one row +//! per (group × function), aggregating numeric position aesthetics. +//! +//! Output columns: +//! - One column per numeric position aesthetic (named `pos1`, `pos2`, etc.) holding the +//! aggregated value. NULL for `count` rows. +//! - `aggregate` - the function name for the row. +//! - `count` (only when `count` is requested) - the row tally for that group. + +use std::collections::HashMap; + +use super::types::StatResult; +use crate::naming; +use crate::plot::aesthetic::is_position_aesthetic; +use crate::plot::types::{ + DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, +}; +use crate::reader::SqlDialect; +use crate::{GgsqlError, Mappings, Result}; + +/// All aggregation function names accepted by the `aggregate` SETTING. +pub const AGG_NAMES: &[&str] = &[ + // Tallies & sums + "count", + "sum", + "prod", + // Extremes + "min", + "max", + "range", + // Central tendency + "mean", + "geomean", + "harmean", + "rms", + "median", + // Spread (standalone) + "sdev", + "var", + "iqr", + // Quantiles + "q05", + "q10", + "q25", + "q50", + "q75", + "q90", + "q95", + // Bands (mean ± spread) + "mean-sdev", + "mean+sdev", + "mean-2sdev", + "mean+2sdev", + "mean-se", + "mean+se", +]; + +/// Returns the `ParamDefinition` for the `aggregate` SETTING parameter. 
+///
+/// Used by `Layer::validate_settings` to check the value against `AGG_NAMES`,
+/// and by geoms that support aggregation.
+pub fn aggregate_param_definition() -> ParamDefinition {
+    ParamDefinition {
+        name: "aggregate",
+        default: DefaultParamValue::Null,
+        constraint: ParamConstraint::string_or_string_array(AGG_NAMES),
+    }
+}
+
+/// Apply the Aggregate stat to a layer query.
+///
+/// Returns `StatResult::Identity` when the `aggregate` parameter is unset or null.
+/// Otherwise, builds a grouped-aggregation query and returns `StatResult::Transformed`.
+///
+/// Strategy:
+/// - **Single-pass** (preferred): one `GROUP BY` produces a wide row per group, then
+///   `CROSS JOIN VALUES(...)` of function names explodes to one row per (group × function).
+///   Used when all requested functions are inline-able.
+/// - **UNION ALL fallback**: when a quantile (or `iqr`, which is derived from two
+///   quantiles) is requested but the dialect doesn't provide `sql_quantile_inline`,
+///   fall back to per-function subqueries using `dialect.sql_percentile`.
+///
+/// # Errors
+///
+/// Returns `GgsqlError::ValidationError` when no numeric position aesthetic is
+/// mapped and `count` is not among the requested functions.
+pub fn apply(
+    query: &str,
+    schema: &Schema,
+    aesthetics: &Mappings,
+    group_by: &[String],
+    parameters: &HashMap<String, ParameterValue>,
+    dialect: &dyn SqlDialect,
+) -> Result<StatResult> {
+    let funcs = match extract_aggregate_param(parameters) {
+        None => return Ok(StatResult::Identity),
+        Some(funcs) => funcs,
+    };
+
+    // Discover position aesthetics on the layer, splitting into numeric (to be
+    // aggregated) and discrete (to be carried through as group columns).
+    let mut numeric_pos: Vec<(String, String)> = Vec::new(); // (aesthetic, prefixed col)
+    let mut discrete_pos_cols: Vec<String> = Vec::new();
+    for (aesthetic, value) in &aesthetics.aesthetics {
+        if !is_position_aesthetic(aesthetic) {
+            continue;
+        }
+        let col = match value.column_name() {
+            Some(c) => c.to_string(),
+            None => continue,
+        };
+        // Columns missing from the schema are treated as numeric.
+        let info = schema.iter().find(|c| c.name == col);
+        let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false);
+        if is_discrete {
+            discrete_pos_cols.push(col);
+        } else {
+            numeric_pos.push((aesthetic.clone(), col));
+        }
+    }
+    // Sort so the generated SQL does not depend on map iteration order.
+    numeric_pos.sort_by(|a, b| a.0.cmp(&b.0));
+    discrete_pos_cols.sort();
+
+    if numeric_pos.is_empty() && !funcs.iter().any(|f| f == "count") {
+        return Err(GgsqlError::ValidationError(
+            "aggregate requires at least one numeric position aesthetic, or the 'count' function"
+                .to_string(),
+        ));
+    }
+
+    // Group columns: PARTITION BY + discrete mappings (already in group_by) + discrete
+    // position aesthetic columns. Deduplicated, preserving order.
+    let mut group_cols: Vec<String> = Vec::new();
+    for g in group_by {
+        if !group_cols.contains(g) {
+            group_cols.push(g.clone());
+        }
+    }
+    for c in &discrete_pos_cols {
+        if !group_cols.contains(c) {
+            group_cols.push(c.clone());
+        }
+    }
+
+    let needs_count_col = funcs.iter().any(|f| f == "count");
+
+    // Decide strategy: single-pass only when every quantile the request needs can
+    // be inlined. `iqr` must be probed as well: the single-pass builder derives it
+    // from the inline 25%/75% quantiles and `expect`s them to be available, so a
+    // dialect without inline quantiles has to take the UNION ALL route instead.
+    // Any numeric column will do for the probe, since we only care whether the
+    // dialect produces Some or None.
+    let probe = numeric_pos
+        .first()
+        .map(|(_, c)| c.as_str())
+        .unwrap_or("__ggsql_probe__");
+    let lacks_inline = |frac: f64| dialect.sql_quantile_inline(probe, frac).is_none();
+    let needs_fallback = funcs.iter().any(|f| match f.as_str() {
+        "iqr" => lacks_inline(0.25) || lacks_inline(0.75),
+        other => quantile_fraction(other).map_or(false, |frac| lacks_inline(frac)),
+    });
+
+    let transformed_query = if needs_fallback {
+        build_union_all_query(query, &funcs, &numeric_pos, &group_cols, dialect)
+    } else {
+        build_single_pass_query(query, &funcs, &numeric_pos, &group_cols, dialect)
+    };
+
+    let mut stat_columns: Vec<String> = numeric_pos.iter().map(|(a, _)| a.clone()).collect();
+    stat_columns.push("aggregate".to_string());
+    if needs_count_col {
+        stat_columns.push("count".to_string());
+    }
+
+    let consumed_aesthetics: Vec<String> = numeric_pos.into_iter().map(|(a, _)| a).collect();
+
+    Ok(StatResult::Transformed {
+        query: transformed_query,
+        stat_columns,
+        dummy_columns: vec![],
+        consumed_aesthetics,
+    })
+}
+
+/// Extract the `aggregate` parameter as a list of function names, or `None` when
+/// the parameter is unset/null.
+fn extract_aggregate_param(parameters: &HashMap<String, ParameterValue>) -> Option<Vec<String>> {
+    use crate::plot::types::ArrayElement;
+    match parameters.get("aggregate") {
+        None | Some(ParameterValue::Null) => None,
+        Some(ParameterValue::String(s)) => Some(vec![s.clone()]),
+        Some(ParameterValue::Array(arr)) => {
+            // Non-string elements are dropped; an array with no strings counts as unset.
+            let names: Vec<String> = arr
+                .iter()
+                .filter_map(|el| match el {
+                    ArrayElement::String(s) => Some(s.clone()),
+                    _ => None,
+                })
+                .collect();
+            if names.is_empty() {
+                None
+            } else {
+                Some(names)
+            }
+        }
+        _ => None,
+    }
+}
+
+/// Map a quantile function name (`q05`..`q95`, `median`) to its fraction.
+fn quantile_fraction(func: &str) -> Option<f64> {
+    match func {
+        "median" | "q50" => Some(0.50),
+        "q05" => Some(0.05),
+        "q10" => Some(0.10),
+        "q25" => Some(0.25),
+        "q75" => Some(0.75),
+        "q90" => Some(0.90),
+        "q95" => Some(0.95),
+        _ => None,
+    }
+}
+
+/// Build the inline SQL fragment for a function applied to a quoted column.
+///
+/// Returns None for `count` (which doesn't take a column) and for quantiles when
+/// the dialect lacks an inline form (caller should switch to UNION ALL strategy).
+fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option<String> {
+    if func == "count" {
+        return None;
+    }
+    if let Some(frac) = quantile_fraction(func) {
+        // Strip the quotes added by `naming::quote_ident` so we can re-quote inside
+        // `sql_quantile_inline` via the same helper. The dialect impl quotes itself.
+        let unquoted = unquote(qcol);
+        return dialect.sql_quantile_inline(&unquoted, frac);
+    }
+    Some(match func {
+        "sum" => format!("SUM({})", qcol),
+        "prod" => format!("EXP(SUM(LN({})))", qcol),
+        "min" => format!("MIN({})", qcol),
+        "max" => format!("MAX({})", qcol),
+        "range" => format!("(MAX({c}) - MIN({c}))", c = qcol),
+        "mean" => format!("AVG({})", qcol),
+        "geomean" => format!("EXP(AVG(LN({})))", qcol),
+        "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol),
+        "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol),
+        "sdev" => format!("STDDEV_POP({})", qcol),
+        "var" => format!("VAR_POP({})", qcol),
+        "mean-sdev" => format!("(AVG({c}) - STDDEV_POP({c}))", c = qcol),
+        "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol),
+        "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol),
+        "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol),
+        "mean-se" => format!(
+            "(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))",
+            c = qcol
+        ),
+        "mean+se" => format!(
+            "(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))",
+            c = qcol
+        ),
+        // `iqr` is computed from quantiles - handled separately.
+        _ => return None,
+    })
+}
+
+/// Strip surrounding double quotes from an identifier, undoing `naming::quote_ident`.
+///
+/// Exactly one delimiter is removed from each side, and only when both are
+/// present; input that is not wrapped in quotes is returned with just the
+/// un-doubling applied. Trimming *all* leading/trailing quotes would also eat
+/// escaped quotes that belong to the name itself (`"a"""` must become `a"`).
+fn unquote(qcol: &str) -> String {
+    let inner = qcol
+        .strip_prefix('"')
+        .and_then(|rest| rest.strip_suffix('"'))
+        .unwrap_or(qcol);
+    // Assumes `quote_ident` escapes an embedded `"` as `""` - TODO confirm against
+    // `naming::quote_ident`.
+    inner.replace("\"\"", "\"")
+}
+
+/// SQL for a function name literal, properly escaped.
+fn func_literal(func: &str) -> String { + format!("'{}'", func.replace('\'', "''")) +} + +// ============================================================================= +// Single-pass strategy: GROUP BY produces a wide CTE, then CROSS JOIN explodes +// rows per requested function. +// ============================================================================= + +fn build_single_pass_query( + query: &str, + funcs: &[String], + numeric_pos: &[(String, String)], + group_cols: &[String], + dialect: &dyn SqlDialect, +) -> String { + let src_alias = "\"__ggsql_stat_src__\""; + let agg_alias = "\"__ggsql_stat_agg__\""; + let funcs_alias = "\"__ggsql_stat_funcs__\""; + + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + // Build the wide aggregation SELECT: one column per (function × position). + let mut wide_select_exprs: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + + // Track the synthetic column names for each (aesthetic, function) pair. 
+ let mut wide_col_for: HashMap<(String, String), String> = HashMap::new(); + + for (aes, col) in numeric_pos { + let qcol = naming::quote_ident(col); + for func in funcs { + if func == "count" { + continue; + } + let key = (aes.clone(), func.clone()); + if wide_col_for.contains_key(&key) { + continue; + } + let wide_name = synthetic_col_name(aes, func); + let expr = match func.as_str() { + "iqr" => { + // q75 - q25 inline if dialect supports it + let q75 = dialect + .sql_quantile_inline(col, 0.75) + .expect("sql_quantile_inline must be Some when single-pass is selected"); + let q25 = dialect + .sql_quantile_inline(col, 0.25) + .expect("sql_quantile_inline must be Some when single-pass is selected"); + format!("({} - {})", q75, q25) + } + _ => function_inline_sql(func, &qcol, dialect) + .expect("function_inline_sql must be Some when single-pass is selected"), + }; + wide_select_exprs.push(format!("{} AS {}", expr, naming::quote_ident(&wide_name))); + wide_col_for.insert(key, wide_name); + } + } + + let needs_count_col = funcs.iter().any(|f| f == "count"); + let count_wide = if needs_count_col { + let c = "__ggsql_stat_cnt__"; + wide_select_exprs.push(format!("COUNT(*) AS {}", naming::quote_ident(c))); + Some(c.to_string()) + } else { + None + }; + + let wide_select = wide_select_exprs.join(", "); + + // Build the CROSS JOIN VALUES table of function names. + let funcs_values: Vec = funcs.iter().map(|f| format!("({})", func_literal(f))).collect(); + let funcs_cte = format!( + "{}(name) AS (VALUES {})", + funcs_alias, + funcs_values.join(", ") + ); + + // Build the outer SELECT: group cols + per-aesthetic CASE + count CASE + name AS aggregate. 
+ let mut outer_exprs: Vec = group_cols + .iter() + .map(|c| format!("{}.{}", agg_alias, naming::quote_ident(c))) + .collect(); + + for (aes, _) in numeric_pos { + let stat_col = naming::stat_column(aes); + let mut whens: Vec = Vec::new(); + for func in funcs { + if let Some(wide_name) = wide_col_for.get(&(aes.clone(), func.clone())) { + whens.push(format!( + "WHEN {} THEN {}.{}", + func_literal(func), + agg_alias, + naming::quote_ident(wide_name) + )); + } + } + let case_expr = if whens.is_empty() { + "NULL".to_string() + } else { + format!( + "CASE {}.name {} ELSE NULL END", + funcs_alias, + whens.join(" ") + ) + }; + outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + } + + if let Some(count_wide) = count_wide { + let stat_col = naming::stat_column("count"); + let case_expr = format!( + "CASE {f}.name WHEN {lit} THEN {a}.{c} ELSE NULL END", + f = funcs_alias, + a = agg_alias, + lit = func_literal("count"), + c = naming::quote_ident(&count_wide) + ); + outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + } + + let stat_aggregate_col = naming::stat_column("aggregate"); + outer_exprs.push(format!( + "{}.name AS {}", + funcs_alias, + naming::quote_ident(&stat_aggregate_col) + )); + + format!( + "WITH {src} AS ({query}), \ + {agg_alias_def} AS (SELECT {wide_select} FROM {src}{group_by}), \ + {funcs_cte} \ + SELECT {outer} FROM {agg} CROSS JOIN {funcs}", + src = src_alias, + query = query, + agg_alias_def = agg_alias, + wide_select = wide_select, + group_by = group_by_clause, + funcs_cte = funcs_cte, + outer = outer_exprs.join(", "), + agg = agg_alias, + funcs = funcs_alias, + ) +} + +/// Synthetic name for a (aesthetic, function) intermediate column in the wide CTE. +/// Includes a sanitized form of the function name to avoid collisions on `+`/`-`. 
+fn synthetic_col_name(aes: &str, func: &str) -> String { + let safe: String = func + .chars() + .map(|c| match c { + '+' => 'p', + '-' => 'm', + _ if c.is_ascii_alphanumeric() => c, + _ => '_', + }) + .collect(); + format!("__ggsql_stat_{}_{}", aes, safe) +} + +// ============================================================================= +// UNION ALL fallback strategy: one SELECT per requested function. +// ============================================================================= + +fn build_union_all_query( + query: &str, + funcs: &[String], + numeric_pos: &[(String, String)], + group_cols: &[String], + dialect: &dyn SqlDialect, +) -> String { + let src_alias = "\"__ggsql_stat_src__\""; + + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + let group_select: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + + let needs_count_col = funcs.iter().any(|f| f == "count"); + let stat_aggregate_col = naming::stat_column("aggregate"); + let stat_count_col = naming::stat_column("count"); + + let branches: Vec = funcs + .iter() + .map(|func| { + let mut select_parts: Vec = group_select.clone(); + + for (aes, col) in numeric_pos { + let stat_col = naming::stat_column(aes); + let value_expr = if func == "count" { + "NULL".to_string() + } else if func == "iqr" { + let q75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); + let q25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); + format!("({} - {})", q75, q25) + } else if let Some(frac) = quantile_fraction(func) { + dialect.sql_percentile(col, frac, src_alias, group_cols) + } else { + let qcol = naming::quote_ident(col); + function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) + }; + select_parts.push(format!("{} AS {}", value_expr, naming::quote_ident(&stat_col))); + } + + if 
needs_count_col { + let value_expr = if func == "count" { + "COUNT(*)".to_string() + } else { + "NULL".to_string() + }; + select_parts.push(format!( + "{} AS {}", + value_expr, + naming::quote_ident(&stat_count_col) + )); + } + + select_parts.push(format!( + "{} AS {}", + func_literal(func), + naming::quote_ident(&stat_aggregate_col) + )); + + // Quantile fallbacks (sql_percentile) need the outer alias `__ggsql_qt__` + // so their correlated WHERE clause can find group columns. + format!( + "SELECT {} FROM {} AS \"__ggsql_qt__\"{}", + select_parts.join(", "), + src_alias, + group_by_clause + ) + }) + .collect(); + + format!( + "WITH {src} AS ({query}) {branches}", + src = src_alias, + query = query, + branches = branches.join(" UNION ALL ") + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::types::{AestheticValue, ColumnInfo}; + use arrow::datatypes::DataType; + + /// A test dialect that mimics DuckDB's native QUANTILE_CONT support. + struct InlineQuantileDialect; + impl SqlDialect for InlineQuantileDialect { + fn sql_quantile_inline(&self, column: &str, fraction: f64) -> Option { + Some(format!( + "QUANTILE_CONT({}, {})", + naming::quote_ident(column), + fraction + )) + } + } + + /// A test dialect with no inline quantile support, exercising the UNION ALL fallback. 
+ struct NoInlineQuantileDialect; + impl SqlDialect for NoInlineQuantileDialect {} + + fn col(name: &str) -> AestheticValue { + AestheticValue::Column { + name: name.to_string(), + original_name: None, + is_dummy: false, + } + } + + fn numeric_schema(cols: &[&str]) -> Schema { + cols.iter() + .map(|c| ColumnInfo { + name: c.to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }) + .collect() + } + + #[test] + fn returns_identity_when_param_unset() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let params = HashMap::new(); + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + assert_eq!(result, StatResult::Identity); + } + + #[test] + fn returns_identity_when_param_null() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let mut params = HashMap::new(); + params.insert("aggregate".to_string(), ParameterValue::Null); + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + assert_eq!(result, StatResult::Identity); + } + + #[test] + fn single_pass_for_mean_emits_avg() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. 
+ } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(query.contains("CROSS JOIN")); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(stat_columns.contains(&"aggregate".to_string())); + assert!(!stat_columns.contains(&"count".to_string())); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn count_emits_count_star_and_keeps_count_column() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("count".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + + match result { + StatResult::Transformed { + query, + stat_columns, + .. + } => { + assert!(query.contains("COUNT(*)")); + assert!(stat_columns.contains(&"count".to_string())); + assert!(stat_columns.contains(&"aggregate".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mixed_count_and_mean_produces_two_rows_per_group() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("count".to_string()), + ArrayElement::String("mean".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + assert!(query.contains("COUNT(*)")); + assert!(query.contains("'count'")); + assert!(query.contains("'mean'")); + // The count CASE must reference the agg CTE for the value column, + // not the funcs CTE (regression: previously emitted funcs.cnt which + // doesn't exist). + assert!( + query.contains("\"__ggsql_stat_agg__\".\"__ggsql_stat_cnt__\""), + "count CASE should reference the agg CTE, query was: {}", + query + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn quantile_uses_dialect_inline_when_available() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("q25".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("QUANTILE_CONT")); + assert!(query.contains("0.25")); + assert!(!query.contains("UNION ALL")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn quantile_falls_back_to_union_all_without_dialect_support() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("q25".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &NoInlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // Fallback dialect uses NTILE-based correlated subquery via UNION ALL. 
+ assert!(query.contains("NTILE(4)")); + assert!(query.contains("UNION ALL") || !query.contains("CROSS JOIN")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mean_sdev_emits_avg_and_stddev() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("STDDEV_POP")); + assert!(query.contains("AVG")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mean_se_includes_sqrt_count() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+se".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("SQRT(COUNT")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn prod_emits_exp_sum_ln() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("prod".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("EXP(SUM(LN")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn iqr_emits_q75_minus_q25() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("iqr".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("0.75")); + assert!(query.contains("0.25")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn discrete_position_aesthetic_becomes_group_column() { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + // pos1 (discrete) is in GROUP BY, not aggregated. + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + // pos2 is aggregated. + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + // Only pos2 is consumed. + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + // Only pos2 (numeric) appears in stat_columns; pos1 stays as-is. 
+ assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn explicit_group_by_columns_appear_in_query() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &["region".to_string()], + &params, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("GROUP BY \"region\"")); + } + _ => panic!("expected Transformed"), + } + } +} diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index 6ceb45f9..5909c34d 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -63,6 +63,10 @@ impl GeomTrait for Text { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn post_process( &self, df: DataFrame, diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 33321a48..e6656590 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -419,6 +419,11 @@ impl Layer { { validate_parameter(param_name, value, &param.constraint)?; } + // Or the shared `aggregate` param for Identity-stat geoms + else if param_name == "aggregate" && self.geom.supports_aggregate() { + let definition = crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); + validate_parameter(param_name, value, &definition.constraint)?; + } // Otherwise it's a valid aesthetic setting (no constraint validation needed) } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index aae89f20..6e7ab0cb 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -41,6 +41,14 @@ impl super::SqlDialect for DuckDbDialect { ) } + fn sql_quantile_inline(&self, column: &str, 
fraction: 
f64) -> Option { + Some(format!( + "QUANTILE_CONT({}, {})", + naming::quote_ident(column), + fraction + )) + } + fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 16a96b66..a02bda3c 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -215,6 +215,17 @@ pub trait SqlDialect { ) } + /// Inline-form quantile aggregate, usable directly in a `SELECT` list. + /// + /// Returns `Some(sql_fragment)` when the dialect supports a native quantile + /// aggregate that can be combined with other aggregates in the same `GROUP BY` + /// query (e.g. DuckDB's `QUANTILE_CONT`). Returns `None` when no native + /// inline form exists; callers should then fall back to [`sql_percentile`], + /// which produces a correlated scalar subquery. + fn sql_quantile_inline(&self, _column: &str, _fraction: f64) -> Option { + None + } + /// SQL literal for a date value (days since Unix epoch). 
fn sql_date_literal(&self, days_since_epoch: i32) -> String { format!( From 778b6acbfe7305403d99543e922cdd266b4376bb Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 13:58:46 +0200 Subject: [PATCH 02/12] support numeric axis geoms --- src/plot/layer/geom/area.rs | 31 +-- src/plot/layer/geom/arrow.rs | 4 + src/plot/layer/geom/bar.rs | 10 +- src/plot/layer/geom/line.rs | 30 ++- src/plot/layer/geom/mod.rs | 28 ++- src/plot/layer/geom/path.rs | 4 + src/plot/layer/geom/point.rs | 4 + src/plot/layer/geom/polygon.rs | 4 + src/plot/layer/geom/ribbon.rs | 31 +-- src/plot/layer/geom/rule.rs | 6 + src/plot/layer/geom/segment.rs | 4 + src/plot/layer/geom/stat_aggregate.rs | 303 +++++++++++++++++++++++++- src/plot/layer/geom/text.rs | 4 + src/plot/layer/geom/types.rs | 88 ++++++++ 14 files changed, 501 insertions(+), 50 deletions(-) diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index 101806d0..f6388032 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -3,10 +3,10 @@ use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamDefinition}; -use crate::{naming, Mappings}; +use crate::Mappings; use super::stat_aggregate; -use super::types::{ParamConstraint, POSITION_VALUES}; +use super::types::{wrap_with_order_by, ParamConstraint, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; /// Area geom - filled area charts @@ -73,17 +73,22 @@ impl GeomTrait for Area { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Area geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - 
query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Area needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 2e3369d2..5737bb95 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -43,6 +43,10 @@ impl GeomTrait for Arrow { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index 7824f74f..b3b82d72 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -99,7 +99,15 @@ impl GeomTrait for Bar { dialect: &dyn SqlDialect, ) -> Result { if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + return stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + ); } stat_bar_count(query, schema, aesthetics, group_by) } diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 20b87228..92f7927c 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -1,13 +1,14 @@ //! 
Line geom implementation use super::stat_aggregate; +use super::types::wrap_with_order_by; use super::{ has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; -use crate::{naming, Mappings}; +use crate::Mappings; /// Line geom - line charts with connected points #[derive(Debug, Clone, Copy)] @@ -58,17 +59,22 @@ impl GeomTrait for Line { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Line geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Line needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index a10e6165..1fd22dbd 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -203,6 +203,19 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } + /// Which numeric position-aesthetic slots the Aggregate stat should reduce. + /// + /// Slot 1 is `pos1`/`pos1min`/`pos1max`/`pos1end` (the independent / domain axis). + /// Slot 2 is `pos2`/`pos2min`/`pos2max`/`pos2end` (the dependent / range axis). 
+ /// + /// Default: `&[2]` — only the dependent axis is reduced; pos1-family stays as a + /// grouping column, so e.g. line geoms produce a summary trace along x. Geoms + /// whose natural Aggregate is centroid-like (point, polygon, segment, arrow, + /// text, path, tile, rule) override to `&[1, 2]`. + fn aggregate_slots(&self) -> &'static [u8] { + &[2] + } + /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -220,7 +233,15 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { dialect: &dyn SqlDialect, ) -> Result { if self.supports_aggregate() && has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + return stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + ); } Ok(StatResult::Identity) } @@ -488,6 +509,11 @@ impl Geom { self.0.supports_aggregate() } + /// Which position-aesthetic slots the Aggregate stat should reduce. 
+ pub fn aggregate_slots(&self) -> &'static [u8] { + self.0.aggregate_slots() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index 062e5f73..c2c8af9f 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -40,6 +40,10 @@ impl GeomTrait for Path { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 5101f2f0..1f60a5f6 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -42,6 +42,10 @@ impl GeomTrait for Point { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index dee34338..efda483e 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -41,6 +41,10 @@ impl GeomTrait for Polygon { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Polygon { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 98b60951..bf1898b9 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -1,11 +1,11 @@ //! 
Ribbon geom implementation use super::stat_aggregate; -use super::types::POSITION_VALUES; +use super::types::{wrap_with_order_by, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; -use crate::{naming, Mappings}; +use crate::Mappings; /// Ribbon geom - confidence bands and ranges #[derive(Debug, Clone, Copy)] @@ -58,17 +58,22 @@ impl GeomTrait for Ribbon { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Ribbon needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index a495cb48..21d7adbe 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -29,6 +29,12 @@ impl GeomTrait for Rule { true } + fn aggregate_slots(&self) -> &'static [u8] { + // Rule maps exactly one of pos1/pos2 (XOR). Allow either to be the reduced + // axis — whichever is mapped wins, and the other slot has nothing to filter. 
+ &[1, 2] + } + fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { // Rule requires exactly one of pos1 or pos2 (XOR logic) let has_pos1 = mappings.contains_key("pos1"); diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index d3fac22d..b229d054 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -43,6 +43,10 @@ impl GeomTrait for Segment { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 54692d88..b446cc0a 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -14,7 +14,7 @@ use std::collections::HashMap; use super::types::StatResult; use crate::naming; -use crate::plot::aesthetic::is_position_aesthetic; +use crate::plot::aesthetic::{is_position_aesthetic, parse_position}; use crate::plot::types::{ DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, }; @@ -89,16 +89,19 @@ pub fn apply( group_by: &[String], parameters: &HashMap, dialect: &dyn SqlDialect, + agg_slots: &[u8], ) -> Result { let funcs = match extract_aggregate_param(parameters) { None => return Ok(StatResult::Identity), Some(funcs) => funcs, }; - // Discover position aesthetics on the layer, splitting into numeric (to be - // aggregated) and discrete (to be carried through as group columns). 
+ // Walk the layer's position aesthetics and route each by (slot, type): + // in-axis slot && numeric → aggregated (numeric_pos) + // in-axis slot && discrete → kept as group column (kept_pos_cols) + // out-of-axis (any type) → kept as group column (kept_pos_cols) let mut numeric_pos: Vec<(String, String)> = Vec::new(); // (aesthetic, prefixed col) - let mut discrete_pos_cols: Vec = Vec::new(); + let mut kept_pos_cols: Vec = Vec::new(); for (aesthetic, value) in &aesthetics.aesthetics { if !is_position_aesthetic(aesthetic) { continue; @@ -107,16 +110,19 @@ pub fn apply( Some(c) => c.to_string(), None => continue, }; + let slot = parse_position(aesthetic).map(|(s, _)| s).unwrap_or(0); + let in_axis = agg_slots.contains(&slot); let info = schema.iter().find(|c| c.name == col); let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); - if is_discrete { - discrete_pos_cols.push(col); + + if !in_axis || is_discrete { + kept_pos_cols.push(col); } else { numeric_pos.push((aesthetic.clone(), col)); } } numeric_pos.sort_by(|a, b| a.0.cmp(&b.0)); - discrete_pos_cols.sort(); + kept_pos_cols.sort(); if numeric_pos.is_empty() && !funcs.iter().any(|f| f == "count") { return Err(GgsqlError::ValidationError( @@ -125,15 +131,16 @@ pub fn apply( )); } - // Group columns: PARTITION BY + discrete mappings (already in group_by) + discrete - // position aesthetic columns. Deduplicated, preserving order. + // Group columns: PARTITION BY + discrete mappings (already in group_by) + any + // position-aesthetic columns we kept (out-of-axis or in-axis-but-discrete). + // Deduplicated, preserving order. 
let mut group_cols: Vec = Vec::new(); for g in group_by { if !group_cols.contains(g) { group_cols.push(g.clone()); } } - for c in &discrete_pos_cols { + for c in &kept_pos_cols { if !group_cols.contains(c) { group_cols.push(c.clone()); } @@ -577,6 +584,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -595,6 +603,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -618,6 +627,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); @@ -657,6 +667,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); @@ -696,6 +707,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -735,6 +747,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -765,6 +778,7 @@ mod tests { &[], &params, &NoInlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -799,6 +813,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -828,6 +843,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -856,6 +872,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -884,6 +901,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -929,6 +947,7 @@ mod tests { &[], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -970,6 +989,7 @@ mod tests { &["region".to_string()], &params, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -979,4 +999,267 @@ mod tests { _ => panic!("expected Transformed"), } } + + #[test] + fn line_style_groups_by_pos1_and_aggregates_pos2() { + // slots=[2]: pos1 stays as group (even though numeric), pos2 gets aggregated. 
+ let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("max".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. + } => { + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn point_style_aggregates_both_slots() { + // slots=[1,2]: both pos1 and pos2 (numeric) get aggregated → centroid. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[1, 2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. 
+ } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "query: {}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(!query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + let mut consumed = consumed_aesthetics.clone(); + consumed.sort(); + assert_eq!(consumed, vec!["pos1".to_string(), "pos2".to_string()]); + assert!(stat_columns.contains(&"pos1".to_string())); + assert!(stat_columns.contains(&"pos2".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn errorbar_aggregates_pos2_minmax() { + // slots=[2]: pos1 fixed (group), pos2min and pos2max both aggregated. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2min", col("__ggsql_aes_pos2min__")); + aes.insert("pos2max", col("__ggsql_aes_pos2max__")); + let schema = numeric_schema(&[ + "__ggsql_aes_pos1__", + "__ggsql_aes_pos2min__", + "__ggsql_aes_pos2max__", + ]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + .. + } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")"), "query: {}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos2max__\")"), "query: {}", query); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + let mut consumed = consumed_aesthetics.clone(); + consumed.sort(); + assert_eq!(consumed, vec!["pos2max".to_string(), "pos2min".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn out_of_axis_numeric_pos_stays_as_group() { + // slots=[2], numeric pos1 → still goes to GROUP BY (not aggregated). + // Same expectation as line_style_groups_by_pos1_and_aggregates_pos2 but + // explicit about the "numeric out-of-axis" path. 
+ let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn discrete_in_axis_pos_stays_as_group_on_centroid_geom() { + // slots=[1,2], pos1 discrete + pos2 numeric → only pos2 aggregated, + // pos1 stays as GROUP BY. Confirms numeric check is preserved on + // slot=[1,2] geoms (e.g. point with category AS x, value AS y). + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[1, 2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. 
+ } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + assert!(!query.contains("AVG(\"__ggsql_aes_pos1__\")")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn count_works_with_no_numeric_pos() { + // slots=[2], only discrete pos1 mapped, aggregate=count → no + // "needs numeric" error; query has COUNT(*) and groups by pos1. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + let schema = vec![ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("count".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + .. 
+ } => { + assert!(query.contains("COUNT(*)")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert!(stat_columns.contains(&"count".to_string())); + } + _ => panic!("expected Transformed"), + } + } } diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index 5909c34d..d9af79ac 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -67,6 +67,10 @@ impl GeomTrait for Text { true } + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } + fn post_process( &self, df: DataFrame, diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index fb0ab5b8..4bdcf897 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -175,6 +175,44 @@ pub use crate::plot::types::ColumnInfo; /// Schema of a data source - list of columns with type info pub use crate::plot::types::Schema; +/// Wrap a stat result with `ORDER BY `. +/// +/// Used by line/area/ribbon to ensure the rendered output is sorted along the +/// domain axis whether or not the layer also goes through the Aggregate stat. +/// +/// - `Identity` → becomes `Transformed` with ` ORDER BY `, +/// empty `stat_columns`/`dummy_columns`/`consumed_aesthetics`. Same shape as +/// the previous inline `ORDER BY` path produced. +/// - `Transformed` → wraps the existing query in +/// `SELECT * FROM () AS "__ggsql_ord__" ORDER BY ` and preserves +/// the stat metadata. 
+pub fn wrap_with_order_by(input_query: &str, result: StatResult, aesthetic: &str) -> StatResult { + let order_col = naming::aesthetic_column(aesthetic); + let order_quoted = naming::quote_ident(&order_col); + match result { + StatResult::Identity => StatResult::Transformed { + query: format!("{} ORDER BY {}", input_query, order_quoted), + stat_columns: vec![], + dummy_columns: vec![], + consumed_aesthetics: vec![], + }, + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => StatResult::Transformed { + query: format!( + "SELECT * FROM ({}) AS \"__ggsql_ord__\" ORDER BY {}", + query, order_quoted + ), + stat_columns, + dummy_columns, + consumed_aesthetics, + }, + } +} + /// Helper to extract column name from aesthetic value pub fn get_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option { use crate::AestheticValue; @@ -260,6 +298,56 @@ mod tests { assert!(!aes.is_required("yend")); } + #[test] + fn wrap_with_order_by_identity_appends_order() { + let result = wrap_with_order_by("SELECT * FROM t", StatResult::Identity, "pos1"); + match result { + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => { + assert_eq!( + query, + "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\"" + ); + assert!(stat_columns.is_empty()); + assert!(dummy_columns.is_empty()); + assert!(consumed_aesthetics.is_empty()); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn wrap_with_order_by_transformed_wraps_query_and_preserves_metadata() { + let inner = StatResult::Transformed { + query: "SELECT * FROM grouped".to_string(), + stat_columns: vec!["pos2".to_string(), "aggregate".to_string()], + dummy_columns: vec!["pos1".to_string()], + consumed_aesthetics: vec!["pos2".to_string()], + }; + let result = wrap_with_order_by("SELECT * FROM raw", inner, "pos1"); + match result { + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => { + 
assert_eq!( + query, + "SELECT * FROM (SELECT * FROM grouped) AS \"__ggsql_ord__\" ORDER BY \"__ggsql_aes_pos1__\"" + ); + assert_eq!(stat_columns, vec!["pos2".to_string(), "aggregate".to_string()]); + assert_eq!(dummy_columns, vec!["pos1".to_string()]); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn test_color_alias_requires_stroke_or_fill() { // Geom with neither stroke nor fill: color alias should NOT be supported From 0a1b214f12a0a389028319ee441311c053b17ee0 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 14:27:16 +0200 Subject: [PATCH 03/12] support range geoms --- src/plot/layer/geom/area.rs | 1 + src/plot/layer/geom/bar.rs | 1 + src/plot/layer/geom/errorbar.rs | 8 + src/plot/layer/geom/line.rs | 1 + src/plot/layer/geom/mod.rs | 19 ++ src/plot/layer/geom/ribbon.rs | 9 + src/plot/layer/geom/stat_aggregate.rs | 414 ++++++++++++++++++++++++++ src/plot/layer/mod.rs | 15 + 8 files changed, 468 insertions(+) diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index f6388032..31617ea6 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -82,6 +82,7 @@ impl GeomTrait for Area { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? 
} else { StatResult::Identity diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index b3b82d72..aebf207e 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -107,6 +107,7 @@ impl GeomTrait for Bar { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), ); } stat_bar_count(query, schema, aesthetics, group_by) diff --git a/src/plot/layer/geom/errorbar.rs b/src/plot/layer/geom/errorbar.rs index 2821d141..5d088224 100644 --- a/src/plot/layer/geom/errorbar.rs +++ b/src/plot/layer/geom/errorbar.rs @@ -21,6 +21,10 @@ impl GeomTrait for ErrorBar { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), + // pos2 is the input column for the Aggregate stat in range mode + // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 + // and produces pos2min/pos2max). Optional otherwise. + ("pos2", DefaultAestheticValue::Null), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(1.0)), ("linewidth", DefaultAestheticValue::Number(1.0)), @@ -48,6 +52,10 @@ impl GeomTrait for ErrorBar { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } } impl std::fmt::Display for ErrorBar { diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 92f7927c..6493fd83 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -68,6 +68,7 @@ impl GeomTrait for Line { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? 
} else { StatResult::Identity diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 1fd22dbd..806a1405 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -216,6 +216,19 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { &[2] } + /// Range pair for range-style Aggregate output. + /// + /// When `Some((lower, upper))`, this geom is a "range geom" that takes exactly + /// two `aggregate` functions and assigns them to the two named aesthetics + /// (e.g. `("pos2min", "pos2max")` for ribbon/errorbar). The user maps `pos2` + /// as the input column; the stat consumes pos2 and produces the range pair. + /// One row per group; no `aggregate` tag column. + /// + /// `None` (default) means standard per-function-rows aggregation. + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + None + } + /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -241,6 +254,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), ); } Ok(StatResult::Identity) @@ -514,6 +528,11 @@ impl Geom { self.0.aggregate_slots() } + /// Range pair for range-style Aggregate output, if any. 
+ pub fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + self.0.aggregate_range_pair() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index bf1898b9..07f005d7 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -22,6 +22,10 @@ impl GeomTrait for Ribbon { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), + // pos2 is the input column for the Aggregate stat in range mode + // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 + // and produces pos2min/pos2max). Optional otherwise. + ("pos2", DefaultAestheticValue::Null), ("fill", DefaultAestheticValue::String("black")), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(0.8)), @@ -44,6 +48,10 @@ impl GeomTrait for Ribbon { true } + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -67,6 +75,7 @@ impl GeomTrait for Ribbon { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? 
} else { StatResult::Identity diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index b446cc0a..2d3ce1a0 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -90,12 +90,17 @@ pub fn apply( parameters: &HashMap, dialect: &dyn SqlDialect, agg_slots: &[u8], + range_pair: Option<(&'static str, &'static str)>, ) -> Result { let funcs = match extract_aggregate_param(parameters) { None => return Ok(StatResult::Identity), Some(funcs) => funcs, }; + if let Some((lo, hi)) = range_pair { + return apply_range_mode(query, schema, aesthetics, group_by, &funcs, dialect, lo, hi); + } + // Walk the layer's position aesthetics and route each by (slot, type): // in-axis slot && numeric → aggregated (numeric_pos) // in-axis slot && discrete → kept as group column (kept_pos_cols) @@ -278,6 +283,143 @@ fn func_literal(func: &str) -> String { format!("'{}'", func.replace('\'', "''")) } +// ============================================================================= +// Range-mode strategy: exactly two functions filling a (lower, upper) aesthetic +// pair on the same row. Used by ribbon/errorbar. +// ============================================================================= + +fn apply_range_mode( + query: &str, + schema: &Schema, + aesthetics: &Mappings, + group_by: &[String], + funcs: &[String], + dialect: &dyn SqlDialect, + lo: &'static str, + hi: &'static str, +) -> Result { + if funcs.len() != 2 { + return Err(GgsqlError::ValidationError(format!( + "aggregate on a range geom must be an array of exactly two functions (lower, upper), got {}", + funcs.len() + ))); + } + + // Range mode requires `pos2` mapped to a numeric input column. The user + // writes `MAPPING value AS y` and the stat consumes it to produce both + // bounds. 
+ let input_col = match aesthetics.get("pos2").and_then(|v| v.column_name()) { + Some(c) => c.to_string(), + None => { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom requires a `y` (pos2) mapping as the input column" + .to_string(), + )); + } + }; + let info = schema.iter().find(|c| c.name == input_col); + if info.map(|c| c.is_discrete).unwrap_or(false) { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom requires a numeric `y` (pos2) input, not a discrete column" + .to_string(), + )); + } + let qcol = naming::quote_ident(&input_col); + + // Group columns: PARTITION BY + discrete mappings (already in group_by) + + // any discrete position aesthetics on the layer (e.g. pos1 if it's a string). + let mut group_cols: Vec = Vec::new(); + for g in group_by { + if !group_cols.contains(g) { + group_cols.push(g.clone()); + } + } + for (aesthetic, value) in &aesthetics.aesthetics { + if !is_position_aesthetic(aesthetic) || aesthetic == "pos2" { + continue; + } + let col = match value.column_name() { + Some(c) => c.to_string(), + None => continue, + }; + if !group_cols.contains(&col) { + group_cols.push(col); + } + } + + let src_alias = "\"__ggsql_stat_src__\""; + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + // Build the two function expressions. Quantiles use the inline form when + // available; otherwise fall back to `sql_percentile` correlated to the + // outer alias used in the FROM (`__ggsql_qt__`, matching boxplot/etc.). 
+ let lo_expr = build_range_function_sql(&funcs[0], &qcol, &input_col, dialect, &group_cols)?; + let hi_expr = build_range_function_sql(&funcs[1], &qcol, &input_col, dialect, &group_cols)?; + + let stat_lo = naming::stat_column(lo); + let stat_hi = naming::stat_column(hi); + + let group_select: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + let mut select_parts = group_select.clone(); + select_parts.push(format!("{} AS {}", lo_expr, naming::quote_ident(&stat_lo))); + select_parts.push(format!("{} AS {}", hi_expr, naming::quote_ident(&stat_hi))); + + let transformed_query = format!( + "WITH {src} AS ({query}) SELECT {sel} FROM {src} AS \"__ggsql_qt__\"{gb}", + src = src_alias, + query = query, + sel = select_parts.join(", "), + gb = group_by_clause, + ); + + // consumed_aesthetics: pos2 carries the original-name capture for axis + // labels; lo/hi flag the auto-rename in execute/layer.rs (their stat-column + // names match the position aesthetics they fill). + Ok(StatResult::Transformed { + query: transformed_query, + stat_columns: vec![lo.to_string(), hi.to_string()], + dummy_columns: vec![], + consumed_aesthetics: vec!["pos2".to_string(), lo.to_string(), hi.to_string()], + }) +} + +/// Build the SQL fragment for one function in range mode. Quantiles get the +/// inline form when the dialect supports it; otherwise the fallback subquery. 
+fn build_range_function_sql( + func: &str, + qcol: &str, + raw_col: &str, + dialect: &dyn SqlDialect, + group_cols: &[String], +) -> Result { + if func == "count" { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom does not support 'count' (it has no range semantics)" + .to_string(), + )); + } + if let Some(frac) = quantile_fraction(func) { + if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { + return Ok(inline); + } + return Ok(dialect.sql_percentile(raw_col, frac, "\"__ggsql_stat_src__\"", group_cols)); + } + function_inline_sql(func, qcol, dialect).ok_or_else(|| { + GgsqlError::ValidationError(format!( + "aggregate on a range geom does not support function '{}' on this dialect", + func + )) + }) +} + // ============================================================================= // Single-pass strategy: GROUP BY produces a wide CTE, then CROSS JOIN explodes // rows per requested function. @@ -585,6 +727,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -604,6 +747,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -628,6 +772,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); @@ -668,6 +813,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); @@ -708,6 +854,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -748,6 +895,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -779,6 +927,7 @@ mod tests { ¶ms, &NoInlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -814,6 +963,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -844,6 +994,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -873,6 +1024,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result 
{ @@ -902,6 +1054,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -948,6 +1101,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -990,6 +1144,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1021,6 +1176,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1062,6 +1218,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[1, 2], + None, ) .unwrap(); match result { @@ -1110,6 +1267,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1152,6 +1310,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1200,6 +1359,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[1, 2], + None, ) .unwrap(); match result { @@ -1247,6 +1407,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1262,4 +1423,257 @@ mod tests { _ => panic!("expected Transformed"), } } + + // ======================================================================== + // Range-mode tests (ribbon / errorbar) + // ======================================================================== + + fn range_pair() -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } + + fn range_input_aes_with_group() -> (Mappings, Schema) { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + (aes, schema) + } + + #[test] + fn range_mode_two_functions_emits_one_row_per_group() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use 
crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "lower bound expr missing: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "upper bound expr missing: {}", + query + ); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert!(!query.contains("UNION ALL")); + assert!(!query.contains("CROSS JOIN")); + // No `aggregate` tag column in range mode. + assert!(!query.contains("__ggsql_stat_aggregate__")); + assert_eq!( + stat_columns, + vec!["pos2min".to_string(), "pos2max".to_string()] + ); + assert!(consumed_aesthetics.contains(&"pos2".to_string())); + assert!(consumed_aesthetics.contains(&"pos2min".to_string())); + assert!(consumed_aesthetics.contains(&"pos2max".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_rejects_single_function() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("exactly two"), + "expected 'exactly two' in error, got: {}", + err + ); + } + + #[test] + fn range_mode_rejects_three_functions() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + 
use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("min".to_string()), + ArrayElement::String("mean".to_string()), + ArrayElement::String("max".to_string()), + ]), + ); + + let err = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("exactly two")); + } + + #[test] + fn range_mode_quantile_uses_inline_when_available() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("q25".to_string()), + ArrayElement::String("q75".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("QUANTILE_CONT")); + assert!(query.contains("0.25")); + assert!(query.contains("0.75")); + assert!(!query.contains("NTILE(4)")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_quantile_falls_back_without_dialect_support() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("q25".to_string()), + ArrayElement::String("q75".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &NoInlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("NTILE(4)")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_requires_pos2_input() { + // Range geom but pos2 not mapped → error. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + let schema = vec![ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }]; + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let err = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap_err() + .to_string(); + assert!( + err.contains("pos2") || err.contains("`y`"), + "expected pos2/y mention in error, got: {}", + err + ); + } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index e6656590..8416c411 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -200,10 +200,25 @@ impl Layer { }; // Check if all required aesthetics exist. + // When `aggregate` is set on a range geom, the (lower, upper) range pair + // is filled by the stat (e.g. pos2min/pos2max for ribbon) and shouldn't + // be required from the user. 
+ let range_pair_skip: Option<(&'static str, &'static str)> = + if crate::plot::layer::geom::has_aggregate_param(&self.parameters) { + self.geom.aggregate_range_pair() + } else { + None + }; + let mut missing = Vec::new(); let mut position_reqs: Vec<(&str, u8, &str)> = Vec::new(); for aesthetic in self.geom.aesthetics().required() { + if let Some((lo, hi)) = range_pair_skip { + if aesthetic == lo || aesthetic == hi { + continue; + } + } if let Some((slot, suffix)) = parse_position(aesthetic) { position_reqs.push((aesthetic, slot, suffix)) } else if !self.mappings.contains_key(aesthetic) { From 218f302aea3b1691f2b3b17503ac3052fc89df19 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 14:32:55 +0200 Subject: [PATCH 04/12] reformat --- src/plot/layer/geom/bar.rs | 4 +- src/plot/layer/geom/line.rs | 4 +- src/plot/layer/geom/stat_aggregate.rs | 93 ++++++++++++++++++--------- src/plot/layer/geom/types.rs | 10 +-- src/plot/layer/mod.rs | 3 +- 5 files changed, 72 insertions(+), 42 deletions(-) diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index aebf207e..e65a0256 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -6,8 +6,8 @@ use std::collections::HashSet; use super::stat_aggregate; use super::types::{get_column_name, POSITION_VALUES}; use super::{ - has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, - ParamDefinition, StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, + ParamConstraint, ParamDefinition, StatResult, }; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 6493fd83..e0600af0 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -3,8 +3,8 @@ use super::stat_aggregate; use super::types::wrap_with_order_by; use super::{ - has_aggregate_param, DefaultAesthetics, 
DefaultParamValue, GeomTrait, GeomType, ParamConstraint, - ParamDefinition, StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, + ParamConstraint, ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 2d3ce1a0..bd8cb0dc 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -259,14 +259,8 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol), "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol), "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol), - "mean-se" => format!( - "(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", - c = qcol - ), - "mean+se" => format!( - "(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", - c = qcol - ), + "mean-se" => format!("(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), + "mean+se" => format!("(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), // `iqr` is computed from quantiles - handled separately. _ => return None, }) @@ -364,10 +358,7 @@ fn apply_range_mode( let stat_lo = naming::stat_column(lo); let stat_hi = naming::stat_column(hi); - let group_select: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let mut select_parts = group_select.clone(); select_parts.push(format!("{} AS {}", lo_expr, naming::quote_ident(&stat_lo))); select_parts.push(format!("{} AS {}", hi_expr, naming::quote_ident(&stat_hi))); @@ -444,10 +435,8 @@ fn build_single_pass_query( }; // Build the wide aggregation SELECT: one column per (function × position). 
- let mut wide_select_exprs: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let mut wide_select_exprs: Vec = + group_cols.iter().map(|c| naming::quote_ident(c)).collect(); // Track the synthetic column names for each (aesthetic, function) pair. let mut wide_col_for: HashMap<(String, String), String> = HashMap::new(); @@ -494,7 +483,10 @@ fn build_single_pass_query( let wide_select = wide_select_exprs.join(", "); // Build the CROSS JOIN VALUES table of function names. - let funcs_values: Vec = funcs.iter().map(|f| format!("({})", func_literal(f))).collect(); + let funcs_values: Vec = funcs + .iter() + .map(|f| format!("({})", func_literal(f))) + .collect(); let funcs_cte = format!( "{}(name) AS (VALUES {})", funcs_alias, @@ -529,7 +521,11 @@ fn build_single_pass_query( whens.join(" ") ) }; - outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + outer_exprs.push(format!( + "{} AS {}", + case_expr, + naming::quote_ident(&stat_col) + )); } if let Some(count_wide) = count_wide { @@ -541,7 +537,11 @@ fn build_single_pass_query( lit = func_literal("count"), c = naming::quote_ident(&count_wide) ); - outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + outer_exprs.push(format!( + "{} AS {}", + case_expr, + naming::quote_ident(&stat_col) + )); } let stat_aggregate_col = naming::stat_column("aggregate"); @@ -603,10 +603,7 @@ fn build_union_all_query( format!(" GROUP BY {}", qcols.join(", ")) }; - let group_select: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let needs_count_col = funcs.iter().any(|f| f == "count"); let stat_aggregate_col = naming::stat_column("aggregate"); @@ -631,7 +628,11 @@ fn build_union_all_query( let qcol = naming::quote_ident(col); function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) }; - select_parts.push(format!("{} 
AS {}", value_expr, naming::quote_ident(&stat_col))); + select_parts.push(format!( + "{} AS {}", + value_expr, + naming::quote_ident(&stat_col) + )); } if needs_count_col { @@ -783,7 +784,11 @@ mod tests { consumed_aesthetics, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(query.contains("CROSS JOIN")); assert!(stat_columns.contains(&"pos2".to_string())); assert!(stat_columns.contains(&"aggregate".to_string())); @@ -1186,7 +1191,11 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("MAX(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); @@ -1228,8 +1237,16 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "query: {}", query); - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos1__\")"), + "query: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(!query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); let mut consumed = consumed_aesthetics.clone(); consumed.sort(); @@ -1276,8 +1293,16 @@ mod tests { consumed_aesthetics, .. 
} => { - assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")"), "query: {}", query); - assert!(query.contains("AVG(\"__ggsql_aes_pos2max__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2min__\")"), + "query: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2max__\")"), + "query: {}", + query + ); assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); let mut consumed = consumed_aesthetics.clone(); consumed.sort(); @@ -1487,12 +1512,16 @@ mod tests { .. } => { assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "lower bound expr missing: {}", query ); assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "upper bound expr missing: {}", query ); diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index 4bdcf897..e9f86bcf 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -308,10 +308,7 @@ mod tests { dummy_columns, consumed_aesthetics, } => { - assert_eq!( - query, - "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\"" - ); + assert_eq!(query, "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\""); assert!(stat_columns.is_empty()); assert!(dummy_columns.is_empty()); assert!(consumed_aesthetics.is_empty()); @@ -340,7 +337,10 @@ mod tests { query, "SELECT * FROM (SELECT * FROM grouped) AS \"__ggsql_ord__\" ORDER BY \"__ggsql_aes_pos1__\"" ); - assert_eq!(stat_columns, vec!["pos2".to_string(), "aggregate".to_string()]); + assert_eq!( + stat_columns, + vec!["pos2".to_string(), "aggregate".to_string()] + ); assert_eq!(dummy_columns, vec!["pos1".to_string()]); assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 8416c411..562a55fc 100644 
--- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -436,7 +436,8 @@ impl Layer { } // Or the shared `aggregate` param for Identity-stat geoms else if param_name == "aggregate" && self.geom.supports_aggregate() { - let definition = crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); + let definition = + crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); validate_parameter(param_name, value, &definition.constraint)?; } // Otherwise it's a valid aesthetic setting (no constraint validation needed) From 8c5845fa9a618c3e89a672ee85dfd3a2eb9978eb Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 11:41:04 +0200 Subject: [PATCH 05/12] support aggregation in segment --- src/plot/layer/geom/segment.rs | 5 +++++ src/plot/layer/mod.rs | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index e0815be7..499ae173 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -44,6 +44,11 @@ impl GeomTrait for Segment { } fn aggregate_slots(&self) -> &'static [u8] { + // Segment is two endpoints connected by a line. Aggregate runs + // independently on each of the four position aesthetics: pos1 and + // pos1end (slot 1), pos2 and pos2end (slot 2). With `aggregate => 'mean'`, + // the segment goes from `(mean(pos1), mean(pos2))` to + // `(mean(pos1end), mean(pos2end))`. &[1, 2] } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 562a55fc..14572919 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -201,8 +201,8 @@ impl Layer { // Check if all required aesthetics exist. // When `aggregate` is set on a range geom, the (lower, upper) range pair - // is filled by the stat (e.g. pos2min/pos2max for ribbon) and shouldn't - // be required from the user. + // is filled by the stat (e.g. pos2min/pos2max for ribbon, pos2/pos2end + // for segment) and shouldn't be required from the user. 
let range_pair_skip: Option<(&'static str, &'static str)> = if crate::plot::layer::geom::has_aggregate_param(&self.parameters) { self.geom.aggregate_range_pair() From 2cb021646acba3741620640d9b9776590b71b96a Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 12:21:11 +0200 Subject: [PATCH 06/12] allow orientation in range and ribbon for aggregation case --- src/plot/layer/geom/range.rs | 8 ++++++++ src/plot/layer/geom/ribbon.rs | 20 +++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/plot/layer/geom/range.rs b/src/plot/layer/geom/range.rs index 12789dd4..d368c4e7 100644 --- a/src/plot/layer/geom/range.rs +++ b/src/plot/layer/geom/range.rs @@ -4,6 +4,7 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; +use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; /// Range geom - intervals along the secondary axis @@ -45,6 +46,13 @@ impl GeomTrait for Range { default: DefaultParamValue::Number(10.0), constraint: ParamConstraint::number_min(0.0), }, + // Default Null → resolve_orientation auto-detects from mappings/scales. + // User can override with `SETTING orientation => 'transposed'`. 
+ ParamDefinition { + name: "orientation", + default: DefaultParamValue::Null, + constraint: ParamConstraint::string_option(ORIENTATION_VALUES), + }, ]; PARAMS } diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 07f005d7..47b58a97 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -3,6 +3,7 @@ use super::stat_aggregate; use super::types::{wrap_with_order_by, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::Mappings; @@ -36,11 +37,20 @@ impl GeomTrait for Ribbon { } fn default_params(&self) -> &'static [ParamDefinition] { - const PARAMS: &[ParamDefinition] = &[ParamDefinition { - name: "position", - default: DefaultParamValue::String("identity"), - constraint: ParamConstraint::string_option(POSITION_VALUES), - }]; + const PARAMS: &[ParamDefinition] = &[ + ParamDefinition { + name: "position", + default: DefaultParamValue::String("identity"), + constraint: ParamConstraint::string_option(POSITION_VALUES), + }, + // Default Null → resolve_orientation auto-detects from mappings/scales. + // User can override with `SETTING orientation => 'transposed'`. 
+ ParamDefinition { + name: "orientation", + default: DefaultParamValue::Null, + constraint: ParamConstraint::string_option(ORIENTATION_VALUES), + }, + ]; PARAMS } From cc390bd6e31230799c1bfb7d5672b04e9bc736a9 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 12:25:57 +0200 Subject: [PATCH 07/12] rename to percentile --- src/plot/layer/geom/stat_aggregate.rs | 70 +++++++++++++-------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 67dc9d38..875f6e29 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -41,14 +41,14 @@ pub const AGG_NAMES: &[&str] = &[ "sdev", "var", "iqr", - // Quantiles - "q05", - "q10", - "q25", - "q50", - "q75", - "q90", - "q95", + // Percentiles + "p05", + "p10", + "p25", + "p50", + "p75", + "p90", + "p95", // Bands (mean ± spread) "mean-sdev", "mean+sdev", @@ -155,7 +155,7 @@ pub fn apply( // Decide strategy: single-pass when every quantile can be inlined. let needs_fallback = funcs.iter().any(|f| { - if let Some(frac) = quantile_fraction(f) { + if let Some(frac) = percentile_fraction(f) { // Use the first numeric column (any will do) for the probe, since we // only care whether the dialect produces Some or None. let probe = numeric_pos @@ -215,16 +215,16 @@ fn extract_aggregate_param(parameters: &HashMap) -> Opti } } -/// Map a quantile function name (`q05`..`q95`, `median`) to its fraction. -fn quantile_fraction(func: &str) -> Option { +/// Map a percentile function name (`p05`..`p95`, `median`) to its fraction. 
+fn percentile_fraction(func: &str) -> Option { match func { - "median" | "q50" => Some(0.50), - "q05" => Some(0.05), - "q10" => Some(0.10), - "q25" => Some(0.25), - "q75" => Some(0.75), - "q90" => Some(0.90), - "q95" => Some(0.95), + "median" | "p50" => Some(0.50), + "p05" => Some(0.05), + "p10" => Some(0.10), + "p25" => Some(0.25), + "p75" => Some(0.75), + "p90" => Some(0.90), + "p95" => Some(0.95), _ => None, } } @@ -237,7 +237,7 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti if func == "count" { return None; } - if let Some(frac) = quantile_fraction(func) { + if let Some(frac) = percentile_fraction(func) { // Strip the quotes added by `naming::quote_ident` so we can re-quote inside // `sql_quantile_inline` via the same helper. The dialect impl quotes itself. let unquoted = unquote(qcol); @@ -397,7 +397,7 @@ fn build_range_function_sql( .to_string(), )); } - if let Some(frac) = quantile_fraction(func) { + if let Some(frac) = percentile_fraction(func) { if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { return Ok(inline); } @@ -454,14 +454,14 @@ fn build_single_pass_query( let wide_name = synthetic_col_name(aes, func); let expr = match func.as_str() { "iqr" => { - // q75 - q25 inline if dialect supports it - let q75 = dialect + // p75 - p25 inline if dialect supports it + let p75 = dialect .sql_quantile_inline(col, 0.75) .expect("sql_quantile_inline must be Some when single-pass is selected"); - let q25 = dialect + let p25 = dialect .sql_quantile_inline(col, 0.25) .expect("sql_quantile_inline must be Some when single-pass is selected"); - format!("({} - {})", q75, q25) + format!("({} - {})", p75, p25) } _ => function_inline_sql(func, &qcol, dialect) .expect("function_inline_sql must be Some when single-pass is selected"), @@ -619,10 +619,10 @@ fn build_union_all_query( let value_expr = if func == "count" { "NULL".to_string() } else if func == "iqr" { - let q75 = dialect.sql_percentile(col, 0.75, src_alias, 
group_cols); - let q25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); - format!("({} - {})", q75, q25) - } else if let Some(frac) = quantile_fraction(func) { + let p75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); + let p25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); + format!("({} - {})", p75, p25) + } else if let Some(frac) = percentile_fraction(func) { dialect.sql_percentile(col, frac, src_alias, group_cols) } else { let qcol = naming::quote_ident(col); @@ -889,7 +889,7 @@ mod tests { let mut params = HashMap::new(); params.insert( "aggregate".to_string(), - ParameterValue::String("q25".to_string()), + ParameterValue::String("p25".to_string()), ); let result = apply( @@ -921,7 +921,7 @@ mod tests { let mut params = HashMap::new(); params.insert( "aggregate".to_string(), - ParameterValue::String("q25".to_string()), + ParameterValue::String("p25".to_string()), ); let result = apply( @@ -1041,7 +1041,7 @@ mod tests { } #[test] - fn iqr_emits_q75_minus_q25() { + fn iqr_emits_p75_minus_p25() { let mut aes = Mappings::new(); aes.insert("pos2", col("__ggsql_aes_pos2__")); let schema = numeric_schema(&["__ggsql_aes_pos2__"]); @@ -1606,8 +1606,8 @@ mod tests { params.insert( "aggregate".to_string(), ParameterValue::Array(vec![ - ArrayElement::String("q25".to_string()), - ArrayElement::String("q75".to_string()), + ArrayElement::String("p25".to_string()), + ArrayElement::String("p75".to_string()), ]), ); @@ -1641,8 +1641,8 @@ mod tests { params.insert( "aggregate".to_string(), ParameterValue::Array(vec![ - ArrayElement::String("q25".to_string()), - ArrayElement::String("q75".to_string()), + ArrayElement::String("p25".to_string()), + ArrayElement::String("p75".to_string()), ]), ); From 447600565ce9c6dfffeb5cdda7221bbea3b9ea4e Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:06:50 +0200 Subject: [PATCH 08/12] make aggregates parametric --- src/plot/layer/geom/stat_aggregate.rs | 749 
+++++++++++++++++++++++--- src/plot/layer/mod.rs | 4 +- 2 files changed, 666 insertions(+), 87 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 875f6e29..1400bc33 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -15,13 +15,15 @@ use std::collections::HashMap; use super::types::StatResult; use crate::naming; use crate::plot::aesthetic::{is_position_aesthetic, parse_position}; -use crate::plot::types::{ - DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, -}; +use crate::plot::types::{ParameterValue, Schema}; use crate::reader::SqlDialect; use crate::{GgsqlError, Mappings, Result}; -/// All aggregation function names accepted by the `aggregate` SETTING. +/// All simple-aggregation function names accepted by the `aggregate` SETTING. +/// +/// Band names (e.g. `mean+sdev`, `median-0.5iqr`) are validated separately by +/// `parse_agg_name`, which checks the offset against `OFFSET_STATS` and the +/// expansion against `EXPANSION_STATS`. pub const AGG_NAMES: &[&str] = &[ // Tallies & sums "count", @@ -49,25 +51,212 @@ pub const AGG_NAMES: &[&str] = &[ "p75", "p90", "p95", - // Bands (mean ± spread) - "mean-sdev", - "mean+sdev", - "mean-2sdev", - "mean+2sdev", - "mean-se", - "mean+se", ]; -/// Returns the `ParamDefinition` for the `aggregate` SETTING parameter. +/// Stats that can appear as the *offset* (left of `±`) in a band name like +/// `mean+sdev`. Single-value central or representative quantities only — +/// counts/spreads are excluded. +pub const OFFSET_STATS: &[&str] = &[ + "mean", + "median", + "geomean", + "harmean", + "rms", + "sum", + "prod", + "min", + "max", + "p05", + "p10", + "p25", + "p50", + "p75", + "p90", + "p95", +]; + +/// Stats that can appear as the *expansion* (right of `±[mod]`) in a band name. +/// Spread / dispersion measures only. 
+pub const EXPANSION_STATS: &[&str] = &["sdev", "se", "var", "iqr", "range"]; + +/// Parsed representation of any aggregate-function name. /// -/// Used by `Layer::validate_settings` to check the value against `AGG_NAMES`, -/// and by geoms that support aggregation. -pub fn aggregate_param_definition() -> ParamDefinition { - ParamDefinition { - name: "aggregate", - default: DefaultParamValue::Null, - constraint: ParamConstraint::string_or_string_array(AGG_NAMES), +/// Simple aggregates (`mean`, `count`, `p25`) have `band == None`. Band names +/// (`mean+sdev`, `median-0.5iqr`) have `band == Some(...)` with the offset +/// stored in `offset` and the spread/multiplier in `band`. +#[derive(Debug, Clone, PartialEq)] +pub struct AggSpec { + pub offset: &'static str, + pub band: Option<Band>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Band { + pub sign: char, + pub mod_value: f64, + pub expansion: &'static str, +} + +/// Resolve a name to its canonical `&'static str` from the given vocabulary, +/// or `None` if the input doesn't match any entry. +fn resolve_static(name: &str, vocab: &'static [&'static str]) -> Option<&'static str> { + vocab.iter().copied().find(|v| *v == name) +} + +/// Parse an aggregate-function name into an `AggSpec`. Returns `None` on +/// invalid input (unknown stat, malformed band, or band with vocabulary +/// violation). +pub fn parse_agg_name(name: &str) -> Option<AggSpec> { + if let Some(spec) = parse_band(name) { + return Some(spec); + } + resolve_static(name, AGG_NAMES).map(|offset| AggSpec { offset, band: None }) +} + +/// Try to parse `name` as a band: `<offset><sign><mod>?<expansion>`. Returns +/// `None` if it doesn't match the band shape OR if either half is outside its +/// allowed vocabulary. +fn parse_band(name: &str) -> Option<AggSpec> { + // Walk offsets longest-first so `median` matches before `mean`.
+ let mut offsets: Vec<&'static str> = OFFSET_STATS.to_vec(); + offsets.sort_by_key(|s| std::cmp::Reverse(s.len())); + + for offset in offsets { + let rest = match name.strip_prefix(offset) { + Some(r) => r, + None => continue, // doesn't start with this offset + }; + let (sign, after_sign) = match rest.chars().next() { + Some('+') => ('+', &rest[1..]), + Some('-') => ('-', &rest[1..]), + _ => continue, // wrong sign char — try next offset + }; + + let (mod_value, expansion_str) = parse_mod_and_remainder(after_sign); + let expansion = match resolve_static(expansion_str, EXPANSION_STATS) { + Some(e) => e, + None => continue, // expansion doesn't match — try next offset + }; + + return Some(AggSpec { + offset, + band: Some(Band { + sign, + mod_value, + expansion, + }), + }); } + None +} + +/// Parse a leading `<digits>(.<digits>)?` modifier from `s`. Returns +/// `(parsed_value, rest_of_string)`. If no leading digits, returns +/// `(1.0, s)` — modifier defaults to 1. +fn parse_mod_and_remainder(s: &str) -> (f64, &str) { + let mut idx = 0; + let bytes = s.as_bytes(); + while idx < bytes.len() && bytes[idx].is_ascii_digit() { + idx += 1; + } + if idx < bytes.len() && bytes[idx] == b'.' { + let mut after_dot = idx + 1; + while after_dot < bytes.len() && bytes[after_dot].is_ascii_digit() { + after_dot += 1; + } + if after_dot > idx + 1 { + // need at least one digit after '.' + idx = after_dot; + } + } + if idx == 0 { + return (1.0, s); + } + let num_str = &s[..idx]; + let value: f64 = num_str.parse().unwrap_or(1.0); + (value, &s[idx..]) +} + +/// Validate the `aggregate` SETTING value: null, a single function name, or +/// an array of function names. Each name must be parseable by `parse_agg_name`.
+pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<(), String> { + use crate::plot::types::ArrayElement; + match value { + ParameterValue::Null => Ok(()), + ParameterValue::String(s) => validate_function_name(s), + ParameterValue::Array(arr) => { + for el in arr { + match el { + ArrayElement::String(s) => validate_function_name(s)?, + ArrayElement::Null => continue, + _ => { + return Err( + "'aggregate' array entries must be strings or null".to_string() + ); + } + } + } + Ok(()) + } + _ => Err("'aggregate' must be a string, array of strings, or null".to_string()), + } +} + +fn validate_function_name(name: &str) -> std::result::Result<(), String> { + match parse_agg_name(name) { + Some(_) => Ok(()), + None => Err(diagnose_invalid_function_name(name)), + } +} + +/// Build a per-role error message for a name that didn't parse. Re-walks the +/// input with looser rules to identify which side (offset / expansion) failed. +fn diagnose_invalid_function_name(name: &str) -> String { + // Look for a sign character. If there is one, examine the offset and + // expansion halves separately. + if let Some(sign_idx) = name.find(|c| c == '+' || c == '-') { + let offset_str = &name[..sign_idx]; + let after_sign = &name[sign_idx + 1..]; + let (_mod_value, expansion_str) = parse_mod_and_remainder(after_sign); + + let offset_known_simple = AGG_NAMES.contains(&offset_str); + let offset_known_band = OFFSET_STATS.contains(&offset_str); + let expansion_known_band = EXPANSION_STATS.contains(&expansion_str); + + if !offset_known_band { + // The offset half is the problem. + if offset_known_simple { + return format!( + "'{}': '{}' is not a valid offset stat. Allowed offsets: {}", + name, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ); + } + return format!( + "'{}': '{}' is not a known stat. 
Allowed offsets: {}", + name, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ); + } + if !expansion_known_band { + return format!( + "'{}': '{}' is not a valid expansion stat. Allowed expansions: {}", + name, + expansion_str, + crate::or_list_quoted(EXPANSION_STATS, '\''), + ); + } + // Both halves are individually valid but band parsing failed for some + // other reason (e.g. malformed modifier). + return format!("'{}' is not a valid aggregate function name", name); + } + format!( + "unknown aggregate function '{}'. Allowed: {} (or use a band like `mean+sdev`)", + name, + crate::or_list_quoted(AGG_NAMES, '\''), + ) } /// Apply the Aggregate stat to a layer query. @@ -153,19 +342,15 @@ pub fn apply( let needs_count_col = funcs.iter().any(|f| f == "count"); - // Decide strategy: single-pass when every quantile can be inlined. + // Decide strategy: single-pass when every percentile component can be inlined. + let probe = numeric_pos + .first() + .map(|(_, c)| c.as_str()) + .unwrap_or("__ggsql_probe__"); let needs_fallback = funcs.iter().any(|f| { - if let Some(frac) = percentile_fraction(f) { - // Use the first numeric column (any will do) for the probe, since we - // only care whether the dialect produces Some or None. - let probe = numeric_pos - .first() - .map(|(_, c)| c.as_str()) - .unwrap_or("__ggsql_probe__"); - dialect.sql_quantile_inline(probe, frac).is_none() - } else { - false - } + parse_agg_name(f) + .map(|spec| needs_quantile_fallback(&spec, probe, dialect)) + .unwrap_or(false) }); let transformed_query = if needs_fallback { @@ -229,21 +414,27 @@ fn percentile_fraction(func: &str) -> Option { } } -/// Build the inline SQL fragment for a function applied to a quoted column. +/// Build the inline SQL fragment for a *simple* stat (no band) applied to a +/// quoted column. /// -/// Returns None for `count` (which doesn't take a column) and for quantiles when -/// the dialect lacks an inline form (caller should switch to UNION ALL strategy). 
-fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { - if func == "count" { +/// Returns `None` for `count` (which doesn't take a column) and for percentile- +/// based stats (`p05..p95`, `median`, `iqr`) when the dialect lacks an inline +/// quantile aggregate (caller should switch to UNION ALL strategy). +fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { + if name == "count" { return None; } - if let Some(frac) = percentile_fraction(func) { - // Strip the quotes added by `naming::quote_ident` so we can re-quote inside - // `sql_quantile_inline` via the same helper. The dialect impl quotes itself. + if let Some(frac) = percentile_fraction(name) { let unquoted = unquote(qcol); return dialect.sql_quantile_inline(&unquoted, frac); } - Some(match func { + if name == "iqr" { + let unquoted = unquote(qcol); + let p75 = dialect.sql_quantile_inline(&unquoted, 0.75)?; + let p25 = dialect.sql_quantile_inline(&unquoted, 0.25)?; + return Some(format!("({} - {})", p75, p25)); + } + Some(match name { "sum" => format!("SUM({})", qcol), "prod" => format!("EXP(SUM(LN({})))", qcol), "min" => format!("MIN({})", qcol), @@ -254,18 +445,103 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol), "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol), "sdev" => format!("STDDEV_POP({})", qcol), + "se" => format!("(STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), "var" => format!("VAR_POP({})", qcol), - "mean-sdev" => format!("(AVG({c}) - STDDEV_POP({c}))", c = qcol), - "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol), - "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol), - "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol), - "mean-se" => format!("(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), - "mean+se" => format!("(AVG({c}) + STDDEV_POP({c}) / 
SQRT(COUNT({c})))", c = qcol), - // `iqr` is computed from quantiles - handled separately. _ => return None, }) } +/// Inline SQL for a parsed `AggSpec`. Combines the offset and (optional) +/// expansion halves with the appropriate sign and modifier. +fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Option { + let offset_sql = simple_stat_sql_inline(spec.offset, qcol, dialect)?; + match &spec.band { + None => Some(offset_sql), + Some(band) => { + let exp_sql = simple_stat_sql_inline(band.expansion, qcol, dialect)?; + Some(format_band(&offset_sql, band.sign, band.mod_value, &exp_sql)) + } + } +} + +/// Build the SQL fragment `(offset ± mod * exp)`, omitting the `mod *` prefix +/// when `mod_value == 1.0`. +fn format_band(offset: &str, sign: char, mod_value: f64, exp: &str) -> String { + if mod_value == 1.0 { + format!("({} {} {})", offset, sign, exp) + } else { + format!("({} {} {} * {})", offset, sign, mod_value, exp) + } +} + +/// Fallback SQL for a simple stat. Used by the UNION-ALL path for percentile +/// components (which need correlated `sql_percentile`) and falls through to +/// the inline form for everything else. +fn simple_stat_sql_fallback( + name: &str, + raw_col: &str, + dialect: &dyn SqlDialect, + src_alias: &str, + group_cols: &[String], +) -> String { + if name == "count" { + return "NULL".to_string(); + } + if let Some(frac) = percentile_fraction(name) { + return dialect.sql_percentile(raw_col, frac, src_alias, group_cols); + } + if name == "iqr" { + let p75 = dialect.sql_percentile(raw_col, 0.75, src_alias, group_cols); + let p25 = dialect.sql_percentile(raw_col, 0.25, src_alias, group_cols); + return format!("({} - {})", p75, p25); + } + let qcol = naming::quote_ident(raw_col); + simple_stat_sql_inline(name, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) +} + +/// Fallback SQL for a parsed `AggSpec` (UNION-ALL path). 
+fn agg_sql_fallback( + spec: &AggSpec, + raw_col: &str, + dialect: &dyn SqlDialect, + src_alias: &str, + group_cols: &[String], +) -> String { + let offset_sql = simple_stat_sql_fallback(spec.offset, raw_col, dialect, src_alias, group_cols); + match &spec.band { + None => offset_sql, + Some(band) => { + let exp_sql = + simple_stat_sql_fallback(band.expansion, raw_col, dialect, src_alias, group_cols); + format_band(&offset_sql, band.sign, band.mod_value, &exp_sql) + } + } +} + +/// Whether this spec has any percentile component that the dialect can't +/// inline (in which case the caller must use the UNION-ALL fallback). +fn needs_quantile_fallback(spec: &AggSpec, probe_col: &str, dialect: &dyn SqlDialect) -> bool { + if simple_needs_fallback(spec.offset, probe_col, dialect) { + return true; + } + if let Some(band) = &spec.band { + if simple_needs_fallback(band.expansion, probe_col, dialect) { + return true; + } + } + false +} + +fn simple_needs_fallback(name: &str, probe_col: &str, dialect: &dyn SqlDialect) -> bool { + if let Some(frac) = percentile_fraction(name) { + return dialect.sql_quantile_inline(probe_col, frac).is_none(); + } + if name == "iqr" { + return dialect.sql_quantile_inline(probe_col, 0.5).is_none(); + } + false +} + /// Strip surrounding double quotes from an identifier, undoing `naming::quote_ident`. fn unquote(qcol: &str) -> String { let trimmed = qcol.trim_start_matches('"').trim_end_matches('"'); @@ -349,9 +625,9 @@ fn apply_range_mode( format!(" GROUP BY {}", qcols.join(", ")) }; - // Build the two function expressions. Quantiles use the inline form when - // available; otherwise fall back to `sql_percentile` correlated to the - // outer alias used in the FROM (`__ggsql_qt__`, matching boxplot/etc.). + // Parse and emit each bound. Use the inline form when the dialect supports + // every percentile component; otherwise fall back to `sql_percentile` + // correlated to the outer alias used in the FROM (`__ggsql_qt__`). 
let lo_expr = build_range_function_sql(&funcs[0], &qcol, &input_col, dialect, &group_cols)?; let hi_expr = build_range_function_sql(&funcs[1], &qcol, &input_col, dialect, &group_cols)?; @@ -382,8 +658,10 @@ fn apply_range_mode( }) } -/// Build the SQL fragment for one function in range mode. Quantiles get the -/// inline form when the dialect supports it; otherwise the fallback subquery. +/// Build the SQL fragment for one function in range mode. Parses the function +/// name into an `AggSpec` (which validates the offset/expansion vocabulary) +/// and emits inline SQL when the dialect supports every percentile component, +/// otherwise the correlated fallback. fn build_range_function_sql( func: &str, qcol: &str, @@ -397,18 +675,28 @@ fn build_range_function_sql( .to_string(), )); } - if let Some(frac) = percentile_fraction(func) { - if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { - return Ok(inline); - } - return Ok(dialect.sql_percentile(raw_col, frac, "\"__ggsql_stat_src__\"", group_cols)); - } - function_inline_sql(func, qcol, dialect).ok_or_else(|| { + let spec = parse_agg_name(func).ok_or_else(|| { GgsqlError::ValidationError(format!( - "aggregate on a range geom does not support function '{}' on this dialect", - func + "aggregate on a range geom: {}", + diagnose_invalid_function_name(func) )) - }) + })?; + if needs_quantile_fallback(&spec, raw_col, dialect) { + Ok(agg_sql_fallback( + &spec, + raw_col, + dialect, + "\"__ggsql_stat_src__\"", + group_cols, + )) + } else { + agg_sql_inline(&spec, qcol, dialect).ok_or_else(|| { + GgsqlError::ValidationError(format!( + "aggregate on a range geom does not support function '{}' on this dialect", + func + )) + }) + } } // ============================================================================= @@ -452,20 +740,10 @@ fn build_single_pass_query( continue; } let wide_name = synthetic_col_name(aes, func); - let expr = match func.as_str() { - "iqr" => { - // p75 - p25 inline if dialect supports it - 
let p75 = dialect - .sql_quantile_inline(col, 0.75) - .expect("sql_quantile_inline must be Some when single-pass is selected"); - let p25 = dialect - .sql_quantile_inline(col, 0.25) - .expect("sql_quantile_inline must be Some when single-pass is selected"); - format!("({} - {})", p75, p25) - } - _ => function_inline_sql(func, &qcol, dialect) - .expect("function_inline_sql must be Some when single-pass is selected"), - }; + let spec = parse_agg_name(func) + .expect("aggregate function names are validated upstream of single-pass"); + let expr = agg_sql_inline(&spec, &qcol, dialect) + .expect("agg_sql_inline must be Some when single-pass is selected"); wide_select_exprs.push(format!("{} AS {}", expr, naming::quote_ident(&wide_name))); wide_col_for.insert(key, wide_name); } @@ -614,19 +892,18 @@ fn build_union_all_query( .map(|func| { let mut select_parts: Vec = group_select.clone(); + // Parse the function name once per branch. Falls through to a + // string-NULL value column if parsing fails (shouldn't happen + // because validation runs upstream, but stay defensive). 
+ let parsed_spec = parse_agg_name(func); for (aes, col) in numeric_pos { let stat_col = naming::stat_column(aes); let value_expr = if func == "count" { "NULL".to_string() - } else if func == "iqr" { - let p75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); - let p25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); - format!("({} - {})", p75, p25) - } else if let Some(frac) = percentile_fraction(func) { - dialect.sql_percentile(col, frac, src_alias, group_cols) + } else if let Some(spec) = &parsed_spec { + agg_sql_fallback(spec, col, dialect, src_alias, group_cols) } else { - let qcol = naming::quote_ident(col); - function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) + "NULL".to_string() }; select_parts.push(format!( "{} AS {}", @@ -1705,4 +1982,308 @@ mod tests { err ); } + + // ======================================================================== + // Parser tests (parse_agg_name) + // ======================================================================== + + #[test] + fn parse_simple_names() { + assert_eq!( + parse_agg_name("mean"), + Some(AggSpec { offset: "mean", band: None }) + ); + assert_eq!( + parse_agg_name("count"), + Some(AggSpec { offset: "count", band: None }) + ); + assert_eq!( + parse_agg_name("p25"), + Some(AggSpec { offset: "p25", band: None }) + ); + } + + #[test] + fn parse_band_default_modifier() { + let spec = parse_agg_name("mean+sdev").unwrap(); + assert_eq!(spec.offset, "mean"); + let band = spec.band.unwrap(); + assert_eq!(band.sign, '+'); + assert_eq!(band.mod_value, 1.0); + assert_eq!(band.expansion, "sdev"); + } + + #[test] + fn parse_band_integer_modifier() { + let spec = parse_agg_name("mean-2sdev").unwrap(); + let band = spec.band.unwrap(); + assert_eq!(band.sign, '-'); + assert_eq!(band.mod_value, 2.0); + assert_eq!(band.expansion, "sdev"); + } + + #[test] + fn parse_band_decimal_modifier() { + let spec = parse_agg_name("mean+1.96sdev").unwrap(); + let band = spec.band.unwrap(); 
+ assert_eq!(band.mod_value, 1.96); + } + + #[test] + fn parse_band_longest_offset_wins() { + // 'median+sdev' must match offset 'median', not 'me' (which isn't an + // offset anyway, but more pertinently the parser must not stop at a + // shorter prefix). + let spec = parse_agg_name("median+sdev").unwrap(); + assert_eq!(spec.offset, "median"); + } + + #[test] + fn parse_band_percentile_offset() { + let spec = parse_agg_name("p25+0.5range").unwrap(); + assert_eq!(spec.offset, "p25"); + let band = spec.band.unwrap(); + assert_eq!(band.mod_value, 0.5); + assert_eq!(band.expansion, "range"); + } + + #[test] + fn parse_band_rejects_invalid_offset() { + assert!(parse_agg_name("count+sdev").is_none()); + assert!(parse_agg_name("iqr+sdev").is_none()); + } + + #[test] + fn parse_band_rejects_invalid_expansion() { + assert!(parse_agg_name("mean+count").is_none()); + assert!(parse_agg_name("mean+median").is_none()); + } + + #[test] + fn parse_rejects_unknown() { + assert!(parse_agg_name("foo").is_none()); + assert!(parse_agg_name("").is_none()); + } + + // ======================================================================== + // Validation tests (validate_aggregate_param) + // ======================================================================== + + #[test] + fn validate_accepts_simple_names_and_bands() { + use crate::plot::types::ArrayElement; + validate_aggregate_param(&ParameterValue::String("mean".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::String("mean+sdev".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::String("median-0.5iqr".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::Array(vec![ + ArrayElement::String("mean".to_string()), + ArrayElement::String("mean+1.96sdev".to_string()), + ])) + .unwrap(); + } + + #[test] + fn validate_diagnostic_for_invalid_offset() { + let err = validate_aggregate_param(&ParameterValue::String("count+sdev".to_string())) + .unwrap_err(); + assert!(err.contains("count"),
"err: {}", err); + assert!(err.contains("offset"), "err: {}", err); + } + + #[test] + fn validate_diagnostic_for_invalid_expansion() { + let err = validate_aggregate_param(&ParameterValue::String("mean+count".to_string())) + .unwrap_err(); + assert!(err.contains("count"), "err: {}", err); + assert!(err.contains("expansion"), "err: {}", err); + } + + #[test] + fn validate_diagnostic_for_unknown() { + let err = + validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); + assert!(err.contains("unknown"), "err: {}", err); + assert!(err.contains("foo"), "err: {}", err); + } + + // ======================================================================== + // SQL emission for parametric bands + // ======================================================================== + + #[test] + fn band_decimal_modifier_emits_in_sql() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+1.96sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &params, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_percentile_offset_inline() { + // median-0.5iqr on a dialect with inline quantile support.
+ let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("median-0.5iqr".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // median uses QUANTILE_CONT(col, 0.5); iqr uses QUANTILE_CONT(.., 0.75) and 0.25. + assert!( + query.contains("QUANTILE_CONT") && query.contains("0.5"), + "query: {}", + query + ); + assert!(query.contains("0.75") && query.contains("0.25")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_percentile_offset_falls_back() { + // median+2sdev on a dialect WITHOUT inline quantile support → UNION-ALL + // path with sql_percentile for median, inline STDDEV_POP for sdev. + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("median+2sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &NoInlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("NTILE(4)")); + assert!(query.contains("STDDEV_POP")); + assert!(query.contains("2 * ")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_default_modifier_omits_one_prefix() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // mod=1 case: (offset + exp), no `1 *` prefix. + assert!( + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), + "expected `(AVG + STDDEV_POP)` form, got: {}", + query + ); + assert!(!query.contains("1 * STDDEV_POP")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_supports_decimal_band() { + // Ribbon range mode + 95% CI band. + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-1.96sdev".to_string()), + ArrayElement::String("mean+1.96sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. 
} => { + assert!(query.contains("- 1.96 * STDDEV_POP")); + assert!(query.contains("+ 1.96 * STDDEV_POP")); + } + _ => panic!("expected Transformed"), + } + } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 14572919..91961156 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -436,9 +436,7 @@ impl Layer { } // Or the shared `aggregate` param for Identity-stat geoms else if param_name == "aggregate" && self.geom.supports_aggregate() { - let definition = - crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); - validate_parameter(param_name, value, &definition.constraint)?; + crate::plot::layer::geom::stat_aggregate::validate_aggregate_param(value)?; } // Otherwise it's a valid aesthetic setting (no constraint validation needed) } From 3f1a4335059e9d6f70ee3d945d3ef6aa90d66022 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:07:03 +0200 Subject: [PATCH 09/12] reformat --- src/plot/layer/geom/stat_aggregate.rs | 81 ++++++++++----------------- 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 1400bc33..98246b8f 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -26,53 +26,19 @@ use crate::{GgsqlError, Mappings, Result}; /// expansion against `EXPANSION_STATS`. 
pub const AGG_NAMES: &[&str] = &[ // Tallies & sums - "count", - "sum", - "prod", - // Extremes - "min", - "max", - "range", - // Central tendency - "mean", - "geomean", - "harmean", - "rms", - "median", - // Spread (standalone) - "sdev", - "var", - "iqr", - // Percentiles - "p05", - "p10", - "p25", - "p50", - "p75", - "p90", - "p95", + "count", "sum", "prod", // Extremes + "min", "max", "range", // Central tendency + "mean", "geomean", "harmean", "rms", "median", // Spread (standalone) + "sdev", "var", "iqr", // Percentiles + "p05", "p10", "p25", "p50", "p75", "p90", "p95", ]; /// Stats that can appear as the *offset* (left of `±`) in a band name like /// `mean+sdev`. Single-value central or representative quantities only — /// counts/spreads are excluded. pub const OFFSET_STATS: &[&str] = &[ - "mean", - "median", - "geomean", - "harmean", - "rms", - "sum", - "prod", - "min", - "max", - "p05", - "p10", - "p25", - "p50", - "p75", - "p90", - "p95", + "mean", "median", "geomean", "harmean", "rms", "sum", "prod", "min", "max", "p05", "p10", + "p25", "p50", "p75", "p90", "p95", ]; /// Stats that can appear as the *expansion* (right of `±[mod]`) in a band name. 
@@ -190,9 +156,7 @@ pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<( ArrayElement::String(s) => validate_function_name(s)?, ArrayElement::Null => continue, _ => { - return Err( - "'aggregate' array entries must be strings or null".to_string() - ); + return Err("'aggregate' array entries must be strings or null".to_string()); } } } @@ -459,7 +423,12 @@ fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Optio None => Some(offset_sql), Some(band) => { let exp_sql = simple_stat_sql_inline(band.expansion, qcol, dialect)?; - Some(format_band(&offset_sql, band.sign, band.mod_value, &exp_sql)) + Some(format_band( + &offset_sql, + band.sign, + band.mod_value, + &exp_sql, + )) } } } @@ -1991,15 +1960,24 @@ mod tests { fn parse_simple_names() { assert_eq!( parse_agg_name("mean"), - Some(AggSpec { offset: "mean", band: None }) + Some(AggSpec { + offset: "mean", + band: None + }) ); assert_eq!( parse_agg_name("count"), - Some(AggSpec { offset: "count", band: None }) + Some(AggSpec { + offset: "count", + band: None + }) ); assert_eq!( parse_agg_name("p25"), - Some(AggSpec { offset: "p25", band: None }) + Some(AggSpec { + offset: "p25", + band: None + }) ); } @@ -2100,8 +2078,7 @@ mod tests { #[test] fn validate_diagnostic_for_unknown() { - let err = - validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); + let err = validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); assert!(err.contains("unknown"), "err: {}", err); assert!(err.contains("foo"), "err: {}", err); } @@ -2135,7 +2112,9 @@ mod tests { match result { StatResult::Transformed { query, .. 
} => { assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "query: {}", query ); From 6147ccc42cf4556bbd47a2b2a60bd2b3153b80e5 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:16:25 +0200 Subject: [PATCH 10/12] clippy be happy --- src/plot/layer/geom/stat_aggregate.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 98246b8f..61ee15cf 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -178,7 +178,7 @@ fn validate_function_name(name: &str) -> std::result::Result<(), String> { fn diagnose_invalid_function_name(name: &str) -> String { // Look for a sign character. If there is one, examine the offset and // expansion halves separately. - if let Some(sign_idx) = name.find(|c| c == '+' || c == '-') { + if let Some(sign_idx) = name.find(['+', '-']) { let offset_str = &name[..sign_idx]; let after_sign = &name[sign_idx + 1..]; let (_mod_value, expansion_str) = parse_mod_and_remainder(after_sign); @@ -235,6 +235,7 @@ fn diagnose_invalid_function_name(name: &str) -> String { /// - **UNION ALL fallback**: when a quantile is requested but the dialect doesn't /// provide `sql_quantile_inline`, fall back to per-function subqueries using /// `dialect.sql_percentile`. +#[allow(clippy::too_many_arguments)] pub fn apply( query: &str, schema: &Schema, @@ -527,6 +528,7 @@ fn func_literal(func: &str) -> String { // pair on the same row. Used by ribbon/range. 
// ============================================================================= +#[allow(clippy::too_many_arguments)] fn apply_range_mode( query: &str, schema: &Schema, From 1c613e48228cd429e260bf1902da0108b1ff331e Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:55:08 +0200 Subject: [PATCH 11/12] ensure multiple aggregates give rise to multiple groups --- src/execute/layer.rs | 45 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 1c9b1b0d..1f0c10d0 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -187,11 +187,16 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result = df .get_column_names() .into_iter() - .filter(|name| naming::is_stat_column(name)) + .filter(|name| { + naming::is_stat_column(name) && !layer.partition_by.contains(&name.to_string()) + }) .collect(); if !stat_cols.is_empty() { df = df.drop_many(&stat_cols)?; @@ -200,6 +205,18 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result) -> usize { + match parameters.get("aggregate") { + Some(ParameterValue::String(_)) => 1, + Some(ParameterValue::Array(arr)) => arr.len(), + _ => 0, + } +} + /// Convert a literal value to an Arrow ArrayRef with constant values. /// /// For string literals, attempts to parse as temporal types (date/datetime/time) @@ -634,6 +651,30 @@ where } } + // The `aggregate` stat column (produced by stat_aggregate when the + // user requests multiple functions) tags each row with its function + // name. For mark types that connect rows within a group (line, area, + // path, polygon), we need to add this column to `layer.partition_by` + // so that e.g. `aggregate => ('min', 'max')` renders as two separate + // lines rather than one zigzag through both. 
Resolves to the + // post-rename data-column name: if the user remapped `aggregate AS + // <aesthetic>`, the prefixed aesthetic column; otherwise the stat column. + // + // Only fires when more than one function is requested — a single + // function produces a constant aggregate column, partitioning by + // which would just add a no-op detail channel. + if stat_columns.iter().any(|s| s == "aggregate") + && aggregate_param_function_count(&layer.parameters) > 1 + { + let partition_col = match final_remappings.get("aggregate") { + Some(aes) => naming::aesthetic_column(aes), + None => naming::stat_column("aggregate"), + }; + if !layer.partition_by.contains(&partition_col) { + layer.partition_by.push(partition_col); + } + } + // Wrap transformed query to rename stat columns to prefixed aesthetic names let stat_rename_exprs: Vec = stat_columns .iter() From f3081a3e4dbdc6ecaceeea700c0053b869218d45 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 15:37:42 +0200 Subject: [PATCH 12/12] begin to document --- doc/syntax/clause/draw.qmd | 19 +++++++++++++++++++ doc/syntax/layer/type/area.qmd | 5 ++++- doc/syntax/layer/type/bar.qmd | 13 +++++++++++++ doc/syntax/layer/type/line.qmd | 18 ++++++++++++++++-- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index 31bf8f3a..f8b8b30e 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -76,6 +76,25 @@ The `SETTING` clause can be used for two different things: #### Position A special setting is `position` which controls how overlapping objects are repositioned to avoid overlapping etc. Position adjustments have special mapping requirements so all position adjustments will not be relevant for all layer types. Different layers have different defaults as detailed in their documentation. You can read about each different position adjustment at [their own documentation sites](../index.qmd#position-adjustments).
+#### Aggregate +Some layers support aggregation of their data through the `aggregate` setting. These layers will state this. `aggregate` allows a single string or an array of strings that specify the aggregation to calculate. The aggregates can be either a simple function or a parameterized band function. + +The simple functions can be one of: + +* `'count'`: Row count +* `'sum'` and `'prod'`: The sum or product +* `'min'`, `'max'`, and `'range'`: Extremes and max - min +* `'mean'` and `'median'`: Central tendency +* `'geomean'`, `'harmean'`, and `'rms'`: Geometric, harmonic, and root-mean-square +* `'sdev'`, `'var'`, `'iqr'`, and `'se'`: Standard deviation, variance, interquartile range, and standard error +* `'p05'`, `'p10'`, `'p25'`, `'p50'`, `'p75'`, `'p90'`, and `'p95'`: Percentiles + +For band functions you combine an offset with an expansion, potentially multiplied. An example could be `'mean-1.96sdev'`, which computes the mean minus 1.96 times the standard deviation. The general form is `<offset>±<multiplier><expansion>` with `<multiplier>` being optional (defaults to `1`). + +Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `'sum'`, `'prod'`, `'min'`, `'max'`, and `'p05'`–`'p95'`. + +Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'`. + ### `FILTER` ```ggsql FILTER diff --git a/doc/syntax/layer/type/area.qmd b/doc/syntax/layer/type/area.qmd index a72b059f..1fa0cdc3 100644 --- a/doc/syntax/layer/type/area.qmd +++ b/doc/syntax/layer/type/area.qmd @@ -25,9 +25,12 @@ The following aesthetics are recognised by the area layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings.
See an overview of aggregation functions in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The area layer sorts the data along its primary axis +This layer supports aggregation through the `aggregate` setting. Within each group (defined by `PARTITION BY`, all discrete mappings, and the primary axis), aggregates are calculated and used as the values to plot. Requesting multiple aggregates gives rise to one separate group per aggregate. These can be distinguished through the added `aggregate` column, which you can remap, e.g. `REMAPPING aggregate AS color`. + +Further, the area layer sorts the data along its primary axis before returning it. ## Orientation Area plots are sorted and connected along their primary axis. Since the primary axis cannot be deduced from the mapping it must be specified using the `orientation` setting. E.g. if you wish to create a vertical area plot you need to set `orientation => 'transposed'` to indicate that the primary layer axis follows the second axis of the coordinate system. diff --git a/doc/syntax/layer/type/bar.qmd b/doc/syntax/layer/type/bar.qmd index d34a4953..f8efc63b 100644 --- a/doc/syntax/layer/type/bar.qmd +++ b/doc/syntax/layer/type/bar.qmd @@ -25,10 +25,13 @@ The bar layer has no required aesthetics ## Settings * `position`: Position adjustment. One of `'identity'`, `'stack'` (default), `'dodge'`, or `'jitter'` * `width`: The width of the bars as a proportion of the available width (0 to 1) +* `aggregate`: Aggregation functions to apply per group if the secondary position has been mapped. Either a single string or an array of strings.
## Data transformation If the secondary axis has not been mapped the layer will calculate counts for you and display these as the secondary axis. +If the secondary axis has been mapped you can apply aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` + ### Properties * `weight`: If mapped, the sum of the weights within each group is calculated instead of the count in each group @@ -116,3 +119,13 @@ DRAW bar MAPPING species AS fill PROJECT TO polar ``` + +Use a different type of aggregation for the bars through the `aggregate` setting: + +```{ggsql} +VISUALISE species AS y, body_mass AS y FROM ggsql:penguins +DRAW bar + SETTING aggregate => 'mean', fill => 'steelblue' +DRAW range + setting aggregate => ('mean-1.96sdev', 'mean+1.96sdev') +``` diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index 3ec9ec21..88bc9034 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -24,11 +24,16 @@ The following aesthetics are recognised by the line layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The line layer sorts the data along its primary axis. 
+This layer supports aggregation through the `aggregate` setting. Within each group (defined by `PARTITION BY`, all discrete mappings, and the primary axis), aggregates are calculated and used as the values to plot. Requesting multiple aggregates gives rise to one separate group per aggregate. These can be distinguished through the added `aggregate` column, which you can remap, e.g. `REMAPPING aggregate AS color`. + +Further, the line layer sorts the data along its primary axis before returning it. + If the line has a variable `stroke` or `opacity` aesthetic within groups, the line is broken into segments. Each segment gets the property of the preceding datapoint, so the last datapoint in a group does not transfer these properties. +This segment-breaking behavior is not compatible with aggregation. ## Orientation Line plots are sorted and connected along their primary axis. Since the primary axis cannot be deduced from the mapping it must be specified using the `orientation` setting. If you wish to create a vertical line plot, you need to set `orientation => 'transposed'` to indicate that the primary layer axis follows the second axis of the coordinate system. @@ -89,4 +94,13 @@ VISUALISE x, y FROM data DRAW line MAPPING z AS linewidth SCALE linewidth TO (0, 30) -``` \ No newline at end of file +``` + +Use aggregation to draw min and max lines from a set of observations: + +```{ggsql} +VISUALISE Day AS x, Temp AS y FROM ggsql:airquality +DRAW line + SETTING aggregate => ('min', 'max') +DRAW point +```