From 27e3d34bc3a38962ab45f9fb9c3079f15d87ed26 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 6 Feb 2026 01:26:09 -0500 Subject: [PATCH 01/14] Add DAG-based kernel typechecker Implement a Lean 4 kernel typechecker using a DAG representation with BUBS (Bottom-Up Beta Substitution) for efficient reduction. The kernel operates on a mutable DAG rather than tree-based expressions, enabling in-place substitution and shared subterm reduction. 12 modules: doubly-linked list, DAG nodes with 10 pointer variants, BUBS upcopy with 12 parent cases, Expr/DAG conversion, universe level operations, WHNF via trail algorithm, definitional equality with lazy delta/proof irrelevance/eta, type inference, and checking for quotients and inductives. --- src/ix.rs | 1 + src/ix/kernel/convert.rs | 813 +++++++++++++++++ src/ix/kernel/dag.rs | 527 +++++++++++ src/ix/kernel/def_eq.rs | 1298 +++++++++++++++++++++++++++ src/ix/kernel/dll.rs | 214 +++++ src/ix/kernel/error.rs | 59 ++ src/ix/kernel/inductive.rs | 772 ++++++++++++++++ src/ix/kernel/level.rs | 393 +++++++++ src/ix/kernel/mod.rs | 11 + src/ix/kernel/quot.rs | 291 +++++++ src/ix/kernel/tc.rs | 1694 ++++++++++++++++++++++++++++++++++++ src/ix/kernel/upcopy.rs | 659 ++++++++++++++ src/ix/kernel/whnf.rs | 1420 ++++++++++++++++++++++++++++++ 13 files changed, 8152 insertions(+) create mode 100644 src/ix/kernel/convert.rs create mode 100644 src/ix/kernel/dag.rs create mode 100644 src/ix/kernel/def_eq.rs create mode 100644 src/ix/kernel/dll.rs create mode 100644 src/ix/kernel/error.rs create mode 100644 src/ix/kernel/inductive.rs create mode 100644 src/ix/kernel/level.rs create mode 100644 src/ix/kernel/mod.rs create mode 100644 src/ix/kernel/quot.rs create mode 100644 src/ix/kernel/tc.rs create mode 100644 src/ix/kernel/upcopy.rs create mode 100644 src/ix/kernel/whnf.rs diff --git a/src/ix.rs b/src/ix.rs index f200d81b..42d298c2 100644 --- a/src/ix.rs +++ b/src/ix.rs @@ -12,6 +12,7 @@ pub mod env; pub mod graph; pub mod ground; 
pub mod ixon; +pub mod kernel; pub mod mutual; pub mod store; pub mod strong_ordering; diff --git a/src/ix/kernel/convert.rs b/src/ix/kernel/convert.rs new file mode 100644 index 00000000..90811948 --- /dev/null +++ b/src/ix/kernel/convert.rs @@ -0,0 +1,813 @@ +use core::ptr::NonNull; +use std::collections::BTreeMap; + +use crate::ix::env::{Expr, ExprData, Level, Name}; +use crate::lean::nat::Nat; + +use super::dag::*; +use super::dll::DLL; + +// ============================================================================ +// Expr -> DAG +// ============================================================================ + +pub fn from_expr(expr: &Expr) -> DAG { + let root_parents = DLL::alloc(ParentPtr::Root); + let head = from_expr_go(expr, 0, &BTreeMap::new(), Some(root_parents)); + DAG { head } +} + +fn from_expr_go( + expr: &Expr, + depth: u64, + ctx: &BTreeMap>, + parents: Option>, +) -> DAGPtr { + match expr.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 < depth { + let level = depth - 1 - idx_u64; + match ctx.get(&level) { + Some(&var_ptr) => { + if let Some(parent_link) = parents { + add_to_parents(DAGPtr::Var(var_ptr), parent_link); + } + DAGPtr::Var(var_ptr) + }, + None => { + let var = alloc_val(Var { + depth: level, + binder: BinderPtr::Free, + parents, + }); + DAGPtr::Var(var) + }, + } + } else { + // Free bound variable (dangling de Bruijn index) + let var = + alloc_val(Var { depth: idx_u64, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + } + }, + + ExprData::Fvar(_name, _) => { + // Encode fvar name into depth as a unique ID. + // We'll recover it during to_expr using a side table. 
+ let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + // Store name→var mapping (caller should manage the side table) + DAGPtr::Var(var) + }, + + ExprData::Sort(level, _) => { + let sort = alloc_val(Sort { level: level.clone(), parents }); + DAGPtr::Sort(sort) + }, + + ExprData::Const(name, levels, _) => { + let cnst = alloc_val(Cnst { + name: name.clone(), + levels: levels.clone(), + parents, + }); + DAGPtr::Cnst(cnst) + }, + + ExprData::Lit(lit, _) => { + let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); + DAGPtr::Lit(lit_node) + }, + + ExprData::App(fun_expr, arg_expr, _) => { + let app_ptr = alloc_app( + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let app = &mut *app_ptr.as_ptr(); + let fun_ref_ptr = + NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); + let arg_ref_ptr = + NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); + app.fun = from_expr_go(fun_expr, depth, ctx, Some(fun_ref_ptr)); + app.arg = from_expr_go(arg_expr, depth, ctx, Some(arg_ref_ptr)); + } + DAGPtr::App(app_ptr) + }, + + ExprData::Lam(name, typ, body, bi, _) => { + // Lean Lam → DAG Fun(dom, Lam(bod, var)) + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let fun_ptr = alloc_fun( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); + fun.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + // Set Lam's parent to FunImg + let img_ref_ptr = + NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut 
Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Fun(fun_ptr) + }, + + ExprData::ForallE(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let pi_ptr = alloc_pi( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); + pi.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + let img_ref_ptr = + NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Pi(pi_ptr) + }, + + ExprData::LetE(name, typ, val, body, non_dep, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let let_ptr = alloc_let( + name.clone(), + *non_dep, + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let typ_ref_ptr = + NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); + let val_ref_ptr = + NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); + let_node.typ = from_expr_go(typ, depth, ctx, Some(typ_ref_ptr)); + let_node.val = from_expr_go(val, depth, ctx, Some(val_ref_ptr)); + + let bod_ref_ptr = + NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let 
inner_bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(inner_bod_ref_ptr)); + } + DAGPtr::Let(let_ptr) + }, + + ExprData::Proj(type_name, idx, structure, _) => { + let proj_ptr = alloc_proj( + type_name.clone(), + idx.clone(), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + let expr_ref_ptr = + NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); + proj.expr = + from_expr_go(structure, depth, ctx, Some(expr_ref_ptr)); + } + DAGPtr::Proj(proj_ptr) + }, + + // Mdata: strip metadata, convert inner expression + ExprData::Mdata(_, inner, _) => from_expr_go(inner, depth, ctx, parents), + + // Mvar: treat as terminal (shouldn't appear in well-typed terms) + ExprData::Mvar(_name, _) => { + let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + }, + } +} + +// ============================================================================ +// Literal clone +// ============================================================================ + +impl Clone for crate::ix::env::Literal { + fn clone(&self) -> Self { + match self { + crate::ix::env::Literal::NatVal(n) => { + crate::ix::env::Literal::NatVal(n.clone()) + }, + crate::ix::env::Literal::StrVal(s) => { + crate::ix::env::Literal::StrVal(s.clone()) + }, + } + } +} + +// ============================================================================ +// DAG -> Expr +// ============================================================================ + +pub fn to_expr(dag: &DAG) -> Expr { + let mut var_map: BTreeMap<*const Var, u64> = BTreeMap::new(); + to_expr_go(dag.head, &mut var_map, 0) +} + +fn to_expr_go( + node: DAGPtr, + var_map: &mut BTreeMap<*const Var, u64>, + depth: u64, +) -> Expr { + unsafe { + match node { + DAGPtr::Var(link) => { + let var = link.as_ptr(); + let var_key = var as *const Var; + if let Some(&bind_depth) = var_map.get(&var_key) { 
+ let idx = depth - bind_depth - 1; + Expr::bvar(Nat::from(idx)) + } else { + // Free variable + Expr::bvar(Nat::from((*var).depth)) + } + }, + + DAGPtr::Sort(link) => { + let sort = &*link.as_ptr(); + Expr::sort(sort.level.clone()) + }, + + DAGPtr::Cnst(link) => { + let cnst = &*link.as_ptr(); + Expr::cnst(cnst.name.clone(), cnst.levels.clone()) + }, + + DAGPtr::Lit(link) => { + let lit = &*link.as_ptr(); + Expr::lit(lit.val.clone()) + }, + + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun = to_expr_go(app.fun, var_map, depth); + let arg = to_expr_go(app.arg, var_map, depth); + Expr::app(fun, arg) + }, + + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let lam = &*fun.img.as_ptr(); + let dom = to_expr_go(fun.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::lam( + fun.binder_name.clone(), + dom, + bod, + fun.binder_info.clone(), + ) + }, + + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let lam = &*pi.img.as_ptr(); + let dom = to_expr_go(pi.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::all( + pi.binder_name.clone(), + dom, + bod, + pi.binder_info.clone(), + ) + }, + + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let lam = &*let_node.bod.as_ptr(); + let typ = to_expr_go(let_node.typ, var_map, depth); + let val = to_expr_go(let_node.val, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::letE( + let_node.binder_name.clone(), + typ, + val, + bod, + let_node.non_dep, + ) + }, + + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let structure = to_expr_go(proj.expr, var_map, depth); + Expr::proj(proj.type_name.clone(), 
proj.idx.clone(), structure) + }, + + DAGPtr::Lam(link) => { + // Standalone Lam shouldn't appear at the top level, + // but handle it gracefully for completeness. + let lam = &*link.as_ptr(); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + // Wrap in a lambda with anonymous name and default binder info + Expr::lam( + Name::anon(), + Expr::sort(Level::zero()), + bod, + crate::ix::env::BinderInfo::Default, + ) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::env::{BinderInfo, Literal}; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + // ========================================================================== + // Terminal roundtrips + // ========================================================================== + + #[test] + fn roundtrip_sort() { + let e = Expr::sort(Level::succ(Level::zero())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_sort_param() { + let e = Expr::sort(Level::param(mk_name("u"))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const() { + let e = Expr::cnst( + mk_name("Foo"), + vec![Level::zero(), Level::succ(Level::zero())], + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_nat_lit() { + let e = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn 
roundtrip_string_lit() { + let e = Expr::lit(Literal::StrVal("hello world".into())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Binder roundtrips + // ========================================================================== + + #[test] + fn roundtrip_identity_lambda() { + // fun (x : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const_lambda() { + // fun (x : Nat) (y : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_pi() { + // (x : Nat) → Nat + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_dependent_pi() { + // (A : Sort 0) → A → A + let sort0 = Expr::sort(Level::zero()); + let e = Expr::all( + mk_name("A"), + sort0, + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), // A + Expr::bvar(Nat::from(1u64)), // A + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // App roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app() { + // f a + let e = Expr::app( + Expr::cnst(mk_name("f"), vec![]), + nat_zero(), + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + 
#[test] + fn roundtrip_nested_app() { + // f a b + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let e = Expr::app(Expr::app(f, a), b); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Let roundtrips + // ========================================================================== + + #[test] + fn roundtrip_let() { + // let x : Nat := Nat.zero in x + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_let_non_dep() { + // let x : Nat := Nat.zero in Nat.zero (non_dep = true) + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + nat_zero(), + true, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Proj roundtrips + // ========================================================================== + + #[test] + fn roundtrip_proj() { + let e = Expr::proj(mk_name("Prod"), Nat::from(0u64), nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Complex roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app_of_lambda() { + // (fun x : Nat => x) Nat.zero + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_lambda_in_lambda() { + // fun (f : Nat → Nat) (x : Nat) => f x + let 
nat_to_nat = Expr::all( + mk_name("_"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e = Expr::lam( + mk_name("f"), + nat_to_nat, + Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(1u64)), // f + Expr::bvar(Nat::from(0u64)), // x + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_bvar_sharing() { + // fun (x : Nat) => App(x, x) + // Both bvar(0) should map to the same Var in DAG + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_free_bvar() { + // Bvar(5) with no enclosing binder — should survive roundtrip + let e = Expr::bvar(Nat::from(5u64)); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_implicit_binder() { + // fun {x : Nat} => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Implicit, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Property tests (quickcheck) + // ========================================================================== + + /// Generate a random well-formed Expr with bound variables properly scoped. + /// `depth` tracks how many binders are in scope (for valid bvar generation). 
+ fn arb_expr(g: &mut Gen, depth: u64, size: usize) -> Expr { + if size == 0 { + // Terminal: pick among Sort, Const, Lit, or Bvar (if depth > 0) + let choices = if depth > 0 { 5 } else { 4 }; + match usize::arbitrary(g) % choices { + 0 => Expr::sort(arb_level(g, 2)), + 1 => { + let names = ["Nat", "Bool", "String", "Unit", "Int"]; + let idx = usize::arbitrary(g) % names.len(); + Expr::cnst(mk_name(names[idx]), vec![]) + }, + 2 => { + let n = u64::arbitrary(g) % 100; + Expr::lit(Literal::NatVal(Nat::from(n))) + }, + 3 => { + let s: String = String::arbitrary(g); + // Truncate at a char boundary to avoid panics + let s: String = s.chars().take(10).collect(); + Expr::lit(Literal::StrVal(s)) + }, + 4 => { + // Bvar within scope + let idx = u64::arbitrary(g) % depth; + Expr::bvar(Nat::from(idx)) + }, + _ => unreachable!(), + } + } else { + let next = size / 2; + match usize::arbitrary(g) % 5 { + 0 => { + // App + let f = arb_expr(g, depth, next); + let a = arb_expr(g, depth, next); + Expr::app(f, a) + }, + 1 => { + // Lam + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::lam(mk_name("x"), dom, bod, BinderInfo::Default) + }, + 2 => { + // Pi + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::all(mk_name("a"), dom, bod, BinderInfo::Default) + }, + 3 => { + // Let + let typ = arb_expr(g, depth, next); + let val = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next / 2); + Expr::letE(mk_name("v"), typ, val, bod, bool::arbitrary(g)) + }, + 4 => { + // Proj + let idx = u64::arbitrary(g) % 4; + let structure = arb_expr(g, depth, next); + Expr::proj(mk_name("S"), Nat::from(idx), structure) + }, + _ => unreachable!(), + } + } + } + + fn arb_level(g: &mut Gen, size: usize) -> Level { + if size == 0 { + match usize::arbitrary(g) % 3 { + 0 => Level::zero(), + 1 => { + let params = ["u", "v", "w"]; + let idx = usize::arbitrary(g) % params.len(); + Level::param(mk_name(params[idx])) + }, + 2 => 
Level::succ(Level::zero()), + _ => unreachable!(), + } + } else { + match usize::arbitrary(g) % 3 { + 0 => Level::succ(arb_level(g, size - 1)), + 1 => Level::max(arb_level(g, size / 2), arb_level(g, size / 2)), + 2 => Level::imax(arb_level(g, size / 2), arb_level(g, size / 2)), + _ => unreachable!(), + } + } + } + + /// Newtype wrapper for quickcheck Arbitrary derivation. + #[derive(Clone, Debug)] + struct ArbExpr(Expr); + + impl Arbitrary for ArbExpr { + fn arbitrary(g: &mut Gen) -> Self { + let size = usize::arbitrary(g) % 5; + ArbExpr(arb_expr(g, 0, size)) + } + } + + #[quickcheck] + fn prop_roundtrip(e: ArbExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } + + /// Same test but with expressions generated inside binders. + #[derive(Clone, Debug)] + struct ArbBinderExpr(Expr); + + impl Arbitrary for ArbBinderExpr { + fn arbitrary(g: &mut Gen) -> Self { + let inner_size = usize::arbitrary(g) % 4; + let body = arb_expr(g, 1, inner_size); + let dom = arb_expr(g, 0, 0); + ArbBinderExpr(Expr::lam( + mk_name("x"), + dom, + body, + BinderInfo::Default, + )) + } + } + + #[quickcheck] + fn prop_roundtrip_binder(e: ArbBinderExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } +} diff --git a/src/ix/kernel/dag.rs b/src/ix/kernel/dag.rs new file mode 100644 index 00000000..9837405f --- /dev/null +++ b/src/ix/kernel/dag.rs @@ -0,0 +1,527 @@ +use core::ptr::NonNull; + +use crate::ix::env::{BinderInfo, Level, Literal, Name}; +use crate::lean::nat::Nat; +use rustc_hash::FxHashSet; + +use super::dll::DLL; + +pub type Parents = DLL; + +// ============================================================================ +// Pointer types +// ============================================================================ + +#[derive(Debug)] +pub enum DAGPtr { + Var(NonNull), + Sort(NonNull), + Cnst(NonNull), + Lit(NonNull), + Lam(NonNull), + Fun(NonNull), + Pi(NonNull), + App(NonNull), + Let(NonNull), + 
Proj(NonNull), +} + +impl Copy for DAGPtr {} +impl Clone for DAGPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for DAGPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (DAGPtr::Var(a), DAGPtr::Var(b)) => a == b, + (DAGPtr::Sort(a), DAGPtr::Sort(b)) => a == b, + (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => a == b, + (DAGPtr::Lit(a), DAGPtr::Lit(b)) => a == b, + (DAGPtr::Lam(a), DAGPtr::Lam(b)) => a == b, + (DAGPtr::Fun(a), DAGPtr::Fun(b)) => a == b, + (DAGPtr::Pi(a), DAGPtr::Pi(b)) => a == b, + (DAGPtr::App(a), DAGPtr::App(b)) => a == b, + (DAGPtr::Let(a), DAGPtr::Let(b)) => a == b, + (DAGPtr::Proj(a), DAGPtr::Proj(b)) => a == b, + _ => false, + } + } +} +impl Eq for DAGPtr {} + +#[derive(Debug)] +pub enum ParentPtr { + Root, + LamBod(NonNull), + FunDom(NonNull), + FunImg(NonNull), + PiDom(NonNull), + PiImg(NonNull), + AppFun(NonNull), + AppArg(NonNull), + LetTyp(NonNull), + LetVal(NonNull), + LetBod(NonNull), + ProjExpr(NonNull), +} + +impl Copy for ParentPtr {} +impl Clone for ParentPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for ParentPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ParentPtr::Root, ParentPtr::Root) => true, + (ParentPtr::LamBod(a), ParentPtr::LamBod(b)) => a == b, + (ParentPtr::FunDom(a), ParentPtr::FunDom(b)) => a == b, + (ParentPtr::FunImg(a), ParentPtr::FunImg(b)) => a == b, + (ParentPtr::PiDom(a), ParentPtr::PiDom(b)) => a == b, + (ParentPtr::PiImg(a), ParentPtr::PiImg(b)) => a == b, + (ParentPtr::AppFun(a), ParentPtr::AppFun(b)) => a == b, + (ParentPtr::AppArg(a), ParentPtr::AppArg(b)) => a == b, + (ParentPtr::LetTyp(a), ParentPtr::LetTyp(b)) => a == b, + (ParentPtr::LetVal(a), ParentPtr::LetVal(b)) => a == b, + (ParentPtr::LetBod(a), ParentPtr::LetBod(b)) => a == b, + (ParentPtr::ProjExpr(a), ParentPtr::ProjExpr(b)) => a == b, + _ => false, + } + } +} +impl Eq for ParentPtr {} + +/// Binder pointer: from a Var to its binding Lam, or Free. 
+#[derive(Debug)] +pub enum BinderPtr { + Free, + Lam(NonNull), +} + +impl Copy for BinderPtr {} +impl Clone for BinderPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for BinderPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (BinderPtr::Free, BinderPtr::Free) => true, + (BinderPtr::Lam(a), BinderPtr::Lam(b)) => a == b, + _ => false, + } + } +} + +// ============================================================================ +// Node structs +// ============================================================================ + +/// Bound or free variable. +#[repr(C)] +pub struct Var { + /// De Bruijn level (used during from_expr/to_expr conversion). + pub depth: u64, + /// Points to the binding Lam, or Free for free variables. + pub binder: BinderPtr, + /// Parent pointers. + pub parents: Option>, +} + +impl Copy for Var {} +impl Clone for Var { + fn clone(&self) -> Self { + *self + } +} + +/// Sort node (universe). +#[repr(C)] +pub struct Sort { + pub level: Level, + pub parents: Option>, +} + +/// Constant reference. +#[repr(C)] +pub struct Cnst { + pub name: Name, + pub levels: Vec, + pub parents: Option>, +} + +/// Literal value (Nat or String). +#[repr(C)] +pub struct LitNode { + pub val: Literal, + pub parents: Option>, +} + +/// Internal binding node (spine). Carries an embedded Var. +/// Always appears as the img/bod of Fun/Pi/Let. +#[repr(C)] +pub struct Lam { + pub bod: DAGPtr, + pub bod_ref: Parents, + pub var: Var, + pub parents: Option>, +} + +/// Lean lambda: `fun (name : dom) => bod`. +/// Branch node wrapping a Lam for the body. +#[repr(C)] +pub struct Fun { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Lean Pi/ForallE: `(name : dom) → bod`. +/// Branch node wrapping a Lam for the body. 
+#[repr(C)] +pub struct Pi { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Application node. +#[repr(C)] +pub struct App { + pub fun: DAGPtr, + pub arg: DAGPtr, + pub fun_ref: Parents, + pub arg_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Let binding: `let name : typ := val in bod`. +#[repr(C)] +pub struct LetNode { + pub binder_name: Name, + pub non_dep: bool, + pub typ: DAGPtr, + pub val: DAGPtr, + pub bod: NonNull, + pub typ_ref: Parents, + pub val_ref: Parents, + pub bod_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Projection from a structure. +#[repr(C)] +pub struct ProjNode { + pub type_name: Name, + pub idx: Nat, + pub expr: DAGPtr, + pub expr_ref: Parents, + pub parents: Option>, +} + +/// A DAG with a head node. +pub struct DAG { + pub head: DAGPtr, +} + +// ============================================================================ +// Allocation helpers +// ============================================================================ + +#[inline] +pub fn alloc_val(val: T) -> NonNull { + NonNull::new(Box::into_raw(Box::new(val))).unwrap() +} + +pub fn alloc_lam( + depth: u64, + bod: DAGPtr, + parents: Option>, +) -> NonNull { + let lam_ptr = alloc_val(Lam { + bod, + bod_ref: DLL::singleton(ParentPtr::Root), + var: Var { depth, binder: BinderPtr::Free, parents: None }, + parents, + }); + unsafe { + let lam = &mut *lam_ptr.as_ptr(); + lam.bod_ref = DLL::singleton(ParentPtr::LamBod(lam_ptr)); + lam.var.binder = BinderPtr::Lam(lam_ptr); + } + lam_ptr +} + +pub fn alloc_app( + fun: DAGPtr, + arg: DAGPtr, + parents: Option>, +) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = 
DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +pub fn alloc_fun( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr)); + fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr)); + } + fun_ptr +} + +pub fn alloc_pi( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let pi_ptr = alloc_val(Pi { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr)); + pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr)); + } + pi_ptr +} + +pub fn alloc_let( + binder_name: Name, + non_dep: bool, + typ: DAGPtr, + val: DAGPtr, + bod: NonNull, + parents: Option>, +) -> NonNull { + let let_ptr = alloc_val(LetNode { + binder_name, + non_dep, + typ, + val, + bod, + typ_ref: DLL::singleton(ParentPtr::Root), + val_ref: DLL::singleton(ParentPtr::Root), + bod_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr)); + let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr)); + let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr)); + } + let_ptr +} + +pub fn alloc_proj( + type_name: Name, + idx: Nat, + expr: DAGPtr, + parents: Option>, +) -> NonNull { + let proj_ptr = alloc_val(ProjNode { + type_name, + idx, + expr, + expr_ref: 
DLL::singleton(ParentPtr::Root), + parents, + }); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr)); + } + proj_ptr +} + +// ============================================================================ +// Parent pointer helpers +// ============================================================================ + +pub fn get_parents(node: DAGPtr) -> Option> { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents, + DAGPtr::App(p) => (*p.as_ptr()).parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents, + } + } +} + +pub fn set_parents(node: DAGPtr, parents: Option>) { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents = parents, + DAGPtr::App(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents = parents, + } + } +} + +pub fn add_to_parents(node: DAGPtr, parent_link: NonNull) { + unsafe { + match get_parents(node) { + None => set_parents(node, Some(parent_link)), + Some(parents) => { + (*parents.as_ptr()).merge(parent_link); + }, + } + } +} + +// ============================================================================ +// DAG-level helpers +// ============================================================================ + +/// Get a unique key for a DAG node pointer (for use in hash 
sets). +pub fn dag_ptr_key(node: DAGPtr) -> usize { + match node { + DAGPtr::Var(p) => p.as_ptr() as usize, + DAGPtr::Sort(p) => p.as_ptr() as usize, + DAGPtr::Cnst(p) => p.as_ptr() as usize, + DAGPtr::Lit(p) => p.as_ptr() as usize, + DAGPtr::Lam(p) => p.as_ptr() as usize, + DAGPtr::Fun(p) => p.as_ptr() as usize, + DAGPtr::Pi(p) => p.as_ptr() as usize, + DAGPtr::App(p) => p.as_ptr() as usize, + DAGPtr::Let(p) => p.as_ptr() as usize, + DAGPtr::Proj(p) => p.as_ptr() as usize, + } +} + +/// Free all DAG nodes reachable from the head. +/// Only frees the node structs themselves; DLL parent entries that are +/// inline in parent structs are freed with those structs. The root_parents +/// DLL node (heap-allocated in from_expr) is a small accepted leak. +pub fn free_dag(dag: DAG) { + let mut visited = FxHashSet::default(); + free_dag_nodes(dag.head, &mut visited); +} + +fn free_dag_nodes(node: DAGPtr, visited: &mut FxHashSet) { + let key = dag_ptr_key(node); + if !visited.insert(key) { + return; + } + unsafe { + match node { + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + // Only free separately-allocated free vars; bound vars are + // embedded in their Lam struct and freed with it. 
+ if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + free_dag_nodes(lam.bod, visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + free_dag_nodes(fun.dom, visited); + free_dag_nodes(DAGPtr::Lam(fun.img), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + free_dag_nodes(pi.dom, visited); + free_dag_nodes(DAGPtr::Lam(pi.img), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + free_dag_nodes(app.fun, visited); + free_dag_nodes(app.arg, visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + free_dag_nodes(let_node.typ, visited); + free_dag_nodes(let_node.val, visited); + free_dag_nodes(DAGPtr::Lam(let_node.bod), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + free_dag_nodes(proj.expr, visited); + drop(Box::from_raw(link.as_ptr())); + }, + } + } +} diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs new file mode 100644 index 00000000..c2110381 --- /dev/null +++ b/src/ix/kernel/def_eq.rs @@ -0,0 +1,1298 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; + +use super::level::{eq_antisymm, eq_antisymm_many}; +use super::tc::TypeChecker; +use super::whnf::*; + +/// Result of lazy delta reduction. +enum DeltaResult { + Found(bool), + Exhausted(Expr, Expr), +} + +/// Check definitional equality of two expressions. 
+pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + if let Some(quick) = def_eq_quick_check(x, y) { + return quick; + } + + let x_n = tc.whnf(x); + let y_n = tc.whnf(y); + + if let Some(quick) = def_eq_quick_check(&x_n, &y_n) { + return quick; + } + + if proof_irrel_eq(&x_n, &y_n, tc) { + return true; + } + + match lazy_delta_step(&x_n, &y_n, tc) { + DeltaResult::Found(result) => result, + DeltaResult::Exhausted(x_e, y_e) => { + def_eq_const(&x_e, &y_e) + || def_eq_proj(&x_e, &y_e, tc) + || def_eq_app(&x_e, &y_e, tc) + || def_eq_binder_full(&x_e, &y_e, tc) + || try_eta_expansion(&x_e, &y_e, tc) + || try_eta_struct(&x_e, &y_e, tc) + || is_def_eq_unit_like(&x_e, &y_e, tc) + }, + } +} + +/// Quick syntactic checks. +fn def_eq_quick_check(x: &Expr, y: &Expr) -> Option { + if x == y { + return Some(true); + } + if let Some(r) = def_eq_sort(x, y) { + return Some(r); + } + if let Some(r) = def_eq_binder(x, y) { + return Some(r); + } + None +} + +fn def_eq_sort(x: &Expr, y: &Expr) -> Option { + match (x.as_data(), y.as_data()) { + (ExprData::Sort(l, _), ExprData::Sort(r, _)) => { + Some(eq_antisymm(l, r)) + }, + _ => None, + } +} + +/// Check if two binder expressions (Pi/Lam) are definitionally equal. +/// Always defers to full checking after WHNF, since binder types could be +/// definitionally equal without being syntactically identical. 
+fn def_eq_binder(_x: &Expr, _y: &Expr) -> Option { + None +} + +fn def_eq_const(x: &Expr, y: &Expr) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Const(xn, xl, _), + ExprData::Const(yn, yl, _), + ) => xn == yn && eq_antisymm_many(xl, yl), + _ => false, + } +} + +fn def_eq_proj(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Proj(_, idx_l, structure_l, _), + ExprData::Proj(_, idx_r, structure_r, _), + ) => idx_l == idx_r && def_eq(structure_l, structure_r, tc), + _ => false, + } +} + +fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let (f1, args1) = unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + + if !def_eq(&f1, &f2, tc) { + return false; + } + args1.iter().zip(args2.iter()).all(|(a, b)| def_eq(a, b, tc)) +} + +/// Full recursive binder comparison: two Pi or two Lam types with +/// definitionally equal domain types and bodies (ignoring binder names). +fn def_eq_binder_full( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::ForallE(_, t1, b1, _, _), + ExprData::ForallE(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + ( + ExprData::Lam(_, t1, b1, _, _), + ExprData::Lam(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + _ => false, + } +} + +/// Proof irrelevance: if both x and y are proofs of the same proposition, +/// they are definitionally equal. 
+fn proof_irrel_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let x_ty = match tc.infer(x) { + Ok(ty) => ty, + Err(_) => return false, + }; + if !is_proposition(&x_ty, tc) { + return false; + } + let y_ty = match tc.infer(y) { + Ok(ty) => ty, + Err(_) => return false, + }; + if !is_proposition(&y_ty, tc) { + return false; + } + def_eq(&x_ty, &y_ty, tc) +} + +/// Check if an expression's type is Prop (Sort 0). +fn is_proposition(ty: &Expr, tc: &mut TypeChecker) -> bool { + let ty_of_ty = match tc.infer(ty) { + Ok(t) => t, + Err(_) => return false, + }; + let whnfd = tc.whnf(&ty_of_ty); + matches!(whnfd.as_data(), ExprData::Sort(l, _) if super::level::is_zero(l)) +} + +/// Eta expansion: `fun x => f x` ≡ `f` when `f : (x : A) → B`. +fn try_eta_expansion(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + try_eta_expansion_aux(x, y, tc) || try_eta_expansion_aux(y, x, tc) +} + +fn try_eta_expansion_aux( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> bool { + if let ExprData::Lam(_, _, _, _, _) = x.as_data() { + let y_ty = match tc.infer(y) { + Ok(t) => t, + Err(_) => return false, + }; + let y_ty_whnf = tc.whnf(&y_ty); + if let ExprData::ForallE(name, binder_type, _, bi, _) = + y_ty_whnf.as_data() + { + // eta-expand y: fun x => y x + let body = Expr::app(y.clone(), Expr::bvar(crate::lean::nat::Nat::from(0))); + let expanded = Expr::lam( + name.clone(), + binder_type.clone(), + body, + bi.clone(), + ); + return def_eq(x, &expanded, tc); + } + } + false +} + +/// Check if a name refers to a structure-like inductive: +/// exactly 1 constructor, not recursive, no indices. +fn is_structure_like(name: &Name, env: &Env) -> bool { + match env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + iv.ctors.len() == 1 && !iv.is_rec && iv.num_indices == Nat::ZERO + }, + _ => false, + } +} + +/// Structure eta: `p =def= S.mk (S.1 p) (S.2 p)` when S is a +/// single-constructor non-recursive inductive with no indices. 
+fn try_eta_struct(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + try_eta_struct_core(x, y, tc) || try_eta_struct_core(y, x, tc) +} + +/// Try to decompose `s` as a constructor application for a structure-like +/// type, then check that each field matches the corresponding projection of `t`. +fn try_eta_struct_core( + t: &Expr, + s: &Expr, + tc: &mut TypeChecker, +) -> bool { + let (head, args) = unfold_apps(s); + let ctor_name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return false, + }; + + let ctor_info = match tc.env.get(ctor_name) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => return false, + }; + + if !is_structure_like(&ctor_info.induct, tc.env) { + return false; + } + + let num_params = ctor_info.num_params.to_u64().unwrap() as usize; + let num_fields = ctor_info.num_fields.to_u64().unwrap() as usize; + + if args.len() != num_params + num_fields { + return false; + } + + for i in 0..num_fields { + let field = &args[num_params + i]; + let proj = Expr::proj( + ctor_info.induct.clone(), + Nat::from(i as u64), + t.clone(), + ); + if !def_eq(field, &proj, tc) { + return false; + } + } + + true +} + +/// Unit-like equality: types with a single zero-field constructor have all +/// inhabitants definitionally equal. 
+fn is_def_eq_unit_like(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let x_ty = match tc.infer(x) { + Ok(ty) => ty, + Err(_) => return false, + }; + let y_ty = match tc.infer(y) { + Ok(ty) => ty, + Err(_) => return false, + }; + // Types must be def-eq + if !def_eq(&x_ty, &y_ty, tc) { + return false; + } + // Check if the type is a unit-like inductive + let whnf_ty = tc.whnf(&x_ty); + let (head, _) = unfold_apps(&whnf_ty); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return false, + }; + match tc.env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + if iv.ctors.len() != 1 { + return false; + } + // Check single constructor has zero fields + if let Some(ConstantInfo::CtorInfo(c)) = tc.env.get(&iv.ctors[0]) { + c.num_fields == Nat::ZERO + } else { + false + } + }, + _ => false, + } +} + +/// Lazy delta reduction: unfold definitions step by step. +fn lazy_delta_step( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> DeltaResult { + let mut x = x.clone(); + let mut y = y.clone(); + + loop { + let x_def = get_applied_def(&x, tc.env); + let y_def = get_applied_def(&y, tc.env); + + match (&x_def, &y_def) { + (None, None) => return DeltaResult::Exhausted(x, y), + (Some(_), None) => { + x = delta(&x, tc); + }, + (None, Some(_)) => { + y = delta(&y, tc); + }, + (Some((x_name, x_hint)), Some((y_name, y_hint))) => { + // Same name and same height: try congruence first + if x_name == y_name && x_hint == y_hint { + if def_eq_app(&x, &y, tc) { + return DeltaResult::Found(true); + } + x = delta(&x, tc); + y = delta(&y, tc); + } else if hint_lt(x_hint, y_hint) { + y = delta(&y, tc); + } else { + x = delta(&x, tc); + } + }, + } + + if let Some(quick) = def_eq_quick_check(&x, &y) { + return DeltaResult::Found(quick); + } + } +} + +/// Get the name and reducibility hint of an applied definition. 
+fn get_applied_def( + e: &Expr, + env: &Env, +) -> Option<(Name, ReducibilityHints)> { + let (head, _) = unfold_apps(e); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + let ci = env.get(name)?; + match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + None + } else { + Some((name.clone(), d.hints)) + } + }, + ConstantInfo::ThmInfo(_) => { + Some((name.clone(), ReducibilityHints::Opaque)) + }, + _ => None, + } +} + +/// Unfold a definition and do cheap WHNF. +fn delta(e: &Expr, tc: &mut TypeChecker) -> Expr { + match try_unfold_def(e, tc.env) { + Some(unfolded) => tc.whnf(&unfolded), + None => e.clone(), + } +} + +/// Compare reducibility hints for ordering. +fn hint_lt(a: &ReducibilityHints, b: &ReducibilityHints) -> bool { + match (a, b) { + (ReducibilityHints::Opaque, _) => true, + (_, ReducibilityHints::Opaque) => false, + (ReducibilityHints::Abbrev, _) => false, + (_, ReducibilityHints::Abbrev) => true, + (ReducibilityHints::Regular(ha), ReducibilityHints::Regular(hb)) => { + ha < hb + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::kernel::tc::TypeChecker; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + /// Minimal env with Nat, Nat.zero, Nat.succ. 
  // Shared fixture used by most tests below.
  fn mk_nat_env() -> Env {
    let mut env = Env::default();
    let nat_name = mk_name("Nat");
    env.insert(
      nat_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: nat_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![nat_name.clone()],
        ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")],
        num_nested: Nat::from(0u64),
        is_rec: true,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    let zero_name = mk_name2("Nat", "zero");
    env.insert(
      zero_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: zero_name.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(0u64),
        is_unsafe: false,
      }),
    );
    let succ_name = mk_name2("Nat", "succ");
    env.insert(
      succ_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: succ_name.clone(),
          level_params: vec![],
          // Nat.succ : Nat -> Nat
          typ: Expr::all(
            mk_name("n"),
            nat_type(),
            nat_type(),
            BinderInfo::Default,
          ),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(1u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(1u64),
        is_unsafe: false,
      }),
    );
    env
  }

  // ==========================================================================
  // Reflexivity
  // ==========================================================================

  #[test]
  fn def_eq_reflexive_sort() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::sort(Level::zero());
    assert!(tc.def_eq(&e, &e));
  }

  #[test]
  fn def_eq_reflexive_const() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = nat_zero();
    assert!(tc.def_eq(&e, &e));
  }

  #[test]
  fn def_eq_reflexive_lambda() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e, &e));
  }

  // ==========================================================================
  // Sort equality
  // ==========================================================================

  #[test]
  fn def_eq_sort_max_comm() {
    // Sort(max u v) =def= Sort(max v u)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let u = Level::param(mk_name("u"));
    let v = Level::param(mk_name("v"));
    let s1 = Expr::sort(Level::max(u.clone(), v.clone()));
    let s2 = Expr::sort(Level::max(v, u));
    assert!(tc.def_eq(&s1, &s2));
  }

  #[test]
  fn def_eq_sort_not_equal() {
    // Sort(0) ≠ Sort(1)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s0 = Expr::sort(Level::zero());
    let s1 = Expr::sort(Level::succ(Level::zero()));
    assert!(!tc.def_eq(&s0, &s1));
  }

  // ==========================================================================
  // Alpha equivalence (same structure, different binder names)
  // ==========================================================================

  #[test]
  fn def_eq_alpha_lambda() {
    // fun (x : Nat) => x =def= fun (y : Nat) => y
    // (de Bruijn indices are the same, so this is syntactic equality)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e1 = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let e2 = Expr::lam(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e1, &e2));
  }

  #[test]
  fn def_eq_alpha_pi() {
    // (x : Nat) → Nat =def= (y : Nat) → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e1 = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    let e2 = Expr::all(
      mk_name("y"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e1, &e2));
  }

  // ==========================================================================
  // Beta equivalence
  // ==========================================================================

  #[test]
  fn def_eq_beta() {
    // (fun x : Nat => x) Nat.zero =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let id_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let lhs = Expr::app(id_fn, nat_zero());
    let rhs = nat_zero();
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_beta_nested() {
    // (fun x y : Nat => x) Nat.zero Nat.zero =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let inner = Expr::lam(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(1u64)), // x
      BinderInfo::Default,
    );
    let k_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      inner,
      BinderInfo::Default,
    );
    let lhs = Expr::app(Expr::app(k_fn, nat_zero()), nat_zero());
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Delta equivalence (definition unfolding)
  // ==========================================================================

  #[test]
  fn def_eq_delta() {
    // def myZero := Nat.zero
    // myZero =def= Nat.zero
    let mut env = mk_nat_env();
    let my_zero = mk_name("myZero");
    env.insert(
      my_zero.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_zero.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_zero.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(my_zero, vec![]);
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  #[test]
  fn def_eq_delta_both_sides() {
    // def a := Nat.zero, def b := Nat.zero
    // a =def= b
    let mut env = mk_nat_env();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(a, vec![]);
    let rhs = Expr::cnst(b, vec![]);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Zeta equivalence (let unfolding)
  // ==========================================================================

  #[test]
  fn def_eq_zeta() {
    // (let x : Nat := Nat.zero in x) =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::letE(
      mk_name("x"),
      nat_type(),
      nat_zero(),
      Expr::bvar(Nat::from(0u64)),
      false,
    );
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Negative tests
  // ==========================================================================

  #[test]
  fn def_eq_different_consts() {
    // Nat ≠ String
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let nat = nat_type();
    let string = Expr::cnst(mk_name("String"), vec![]);
    assert!(!tc.def_eq(&nat, &string));
  }

  #[test]
  fn def_eq_different_nat_levels() {
    // Nat.zero ≠ Nat.succ
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let zero = nat_zero();
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    assert!(!tc.def_eq(&zero, &succ));
  }

  #[test]
  fn def_eq_app_congruence() {
    // f a =def= f a (for same f, same a)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let f = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let a = nat_zero();
    let lhs = Expr::app(f.clone(), a.clone());
    let rhs = Expr::app(f, a);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_app_different_args() {
    // Nat.succ Nat.zero ≠ Nat.succ (Nat.succ Nat.zero)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let lhs = Expr::app(succ.clone(), nat_zero());
    let rhs =
      Expr::app(succ.clone(), Expr::app(succ, nat_zero()));
    assert!(!tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Const-level equality
  // ==========================================================================

  #[test]
  fn def_eq_const_levels() {
    // A.{max u v} =def= A.{max v u}
    let mut env = Env::default();
    let a_name = mk_name("A");
    let u_name = mk_name("u");
    let v_name = mk_name("v");
    env.insert(
      a_name.clone(),
      ConstantInfo::AxiomInfo(AxiomVal {
        cnst: ConstantVal {
          name: a_name.clone(),
          level_params: vec![u_name.clone(), v_name.clone()],
          typ: Expr::sort(Level::max(
            Level::param(u_name.clone()),
            Level::param(v_name.clone()),
          )),
        },
        is_unsafe: false,
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let u = Level::param(mk_name("u"));
    let v = Level::param(mk_name("v"));
    let lhs = Expr::cnst(a_name.clone(), vec![Level::max(u.clone(), v.clone()), Level::zero()]);
    let rhs = Expr::cnst(a_name, vec![Level::max(v, u), Level::zero()]);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Hint ordering
  // ==========================================================================

  #[test]
  fn hint_lt_opaque_less_than_all() {
    assert!(hint_lt(&ReducibilityHints::Opaque, &ReducibilityHints::Abbrev));
    assert!(hint_lt(
      &ReducibilityHints::Opaque,
      &ReducibilityHints::Regular(0)
    ));
  }

  #[test]
  fn hint_lt_abbrev_greatest() {
    assert!(!hint_lt(
      &ReducibilityHints::Abbrev,
      &ReducibilityHints::Opaque
    ));
    assert!(!hint_lt(
      &ReducibilityHints::Abbrev,
      &ReducibilityHints::Regular(100)
    ));
  }

  #[test]
  fn hint_lt_regular_ordering() {
    assert!(hint_lt(
      &ReducibilityHints::Regular(1),
      &ReducibilityHints::Regular(2)
    ));
    assert!(!hint_lt(
      &ReducibilityHints::Regular(2),
      &ReducibilityHints::Regular(1)
    ));
  }

  // ==========================================================================
  // Eta expansion
  // ==========================================================================

  #[test]
  fn def_eq_eta_lam_vs_const() {
    // fun x : Nat => Nat.succ x =def= Nat.succ
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let eta_expanded = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&eta_expanded, &succ));
  }

  #[test]
  fn def_eq_eta_symmetric() {
    // Nat.succ =def= fun x : Nat => Nat.succ x
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let eta_expanded = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&succ, &eta_expanded));
  }

  // ==========================================================================
  // Lazy delta step with different heights
  // ==========================================================================

  #[test]
  fn def_eq_lazy_delta_higher_unfolds_first() {
    // def a := Nat.zero (height 1)
    // def b := a (height 2)
    // b =def= Nat.zero should work by unfolding b first (higher height)
    let mut env = mk_nat_env();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Regular(1),
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: Expr::cnst(a, vec![]),
        hints: ReducibilityHints::Regular(2),
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(b, vec![]);
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Transitivity through delta
  // ==========================================================================

  #[test]
  fn def_eq_transitive_delta() {
    // def a := Nat.zero, def b := Nat.zero
    // def c := Nat.zero
    // a =def= b, a =def= c, b =def= c
    let mut env = mk_nat_env();
    for name_str in &["a", "b", "c"] {
      let n = mk_name(name_str);
      env.insert(
        n.clone(),
        ConstantInfo::DefnInfo(DefinitionVal {
          cnst: ConstantVal {
            name: n.clone(),
            level_params: vec![],
            typ: nat_type(),
          },
          value: nat_zero(),
          hints: ReducibilityHints::Abbrev,
          safety: DefinitionSafety::Safe,
          all: vec![n],
        }),
      );
    }
    let mut tc = TypeChecker::new(&env);
    let a = Expr::cnst(mk_name("a"), vec![]);
    let b = Expr::cnst(mk_name("b"), vec![]);
    let c = Expr::cnst(mk_name("c"), vec![]);
    assert!(tc.def_eq(&a, &b));
    assert!(tc.def_eq(&a, &c));
    assert!(tc.def_eq(&b, &c));
  }

  // ==========================================================================
  // Nat literal equality through WHNF
  // ==========================================================================

  #[test]
  fn def_eq_nat_lit_same() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let a = Expr::lit(Literal::NatVal(Nat::from(42u64)));
    let b = Expr::lit(Literal::NatVal(Nat::from(42u64)));
    assert!(tc.def_eq(&a, &b));
  }

  #[test]
  fn def_eq_nat_lit_different() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let a = Expr::lit(Literal::NatVal(Nat::from(1u64)));
    let b = Expr::lit(Literal::NatVal(Nat::from(2u64)));
    assert!(!tc.def_eq(&a, &b));
  }

  // ==========================================================================
  // Beta-delta combined
  // ==========================================================================

  #[test]
  fn def_eq_beta_delta_combined() {
    // def myId := fun x : Nat => x
    // myId Nat.zero =def= Nat.zero
    let mut env = mk_nat_env();
    let my_id = mk_name("myId");
    let fun_ty = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    env.insert(
      my_id.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_id.clone(),
          level_params: vec![],
          typ: fun_ty,
        },
        value: Expr::lam(
          mk_name("x"),
          nat_type(),
          Expr::bvar(Nat::from(0u64)),
          BinderInfo::Default,
        ),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_id.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::app(Expr::cnst(my_id, vec![]), nat_zero());
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Structure eta
  // ==========================================================================

  /// Build an env with Nat + Prod.{u,v} structure type.
  fn mk_prod_env() -> Env {
    let mut env = mk_nat_env();
    let u_name = mk_name("u");
    let v_name = mk_name("v");
    let prod_name = mk_name("Prod");
    let mk_ctor_name = mk_name2("Prod", "mk");

    // Prod.{u,v} (α : Sort u) (β : Sort v) : Sort (max u v)
    let prod_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::sort(Level::max(
          Level::param(u_name.clone()),
          Level::param(v_name.clone()),
        )),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      prod_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: prod_name.clone(),
          level_params: vec![u_name.clone(), v_name.clone()],
          typ: prod_type,
        },
        num_params: Nat::from(2u64),
        num_indices: Nat::from(0u64),
        all: vec![prod_name.clone()],
        ctors: vec![mk_ctor_name.clone()],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );

    // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod α β
    let ctor_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::all(
          mk_name("fst"),
          Expr::bvar(Nat::from(1u64)), // α
          Expr::all(
            mk_name("snd"),
            Expr::bvar(Nat::from(1u64)), // β
            Expr::app(
              Expr::app(
                Expr::cnst(
                  prod_name.clone(),
                  vec![
                    Level::param(u_name.clone()),
                    Level::param(v_name.clone()),
                  ],
                ),
                Expr::bvar(Nat::from(3u64)), // α
              ),
              Expr::bvar(Nat::from(2u64)), // β
            ),
            BinderInfo::Default,
          ),
          BinderInfo::Default,
        ),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      mk_ctor_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: mk_ctor_name,
          level_params: vec![u_name, v_name],
          typ: ctor_type,
        },
        induct: prod_name,
        cidx: Nat::from(0u64),
        num_params: Nat::from(2u64),
        num_fields: Nat::from(2u64),
        is_unsafe: false,
      }),
    );

    env
  }

  #[test]
  fn eta_struct_ctor_eq_proj() {
    // Prod.mk Nat Nat (Prod.1 p) (Prod.2 p) =def= p
    // where p is a free variable of type Prod Nat Nat
    let env = mk_prod_env();
    let mut tc = TypeChecker::new(&env);

    let one = Level::succ(Level::zero());
    let prod_nat_nat = Expr::app(
      Expr::app(
        Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]),
        nat_type(),
      ),
      nat_type(),
    );
    let p = tc.mk_local(&mk_name("p"), &prod_nat_nat);

    let ctor_app = Expr::app(
      Expr::app(
        Expr::app(
          Expr::app(
            Expr::cnst(
              mk_name2("Prod", "mk"),
              vec![one.clone(), one.clone()],
            ),
            nat_type(),
          ),
          nat_type(),
        ),
        Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()),
      ),
      Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()),
    );

    assert!(tc.def_eq(&ctor_app, &p));
  }

  #[test]
  fn eta_struct_symmetric() {
    // p =def= Prod.mk Nat Nat (Prod.1 p) (Prod.2 p)
    let env = mk_prod_env();
    let mut tc = TypeChecker::new(&env);

    let one = Level::succ(Level::zero());
    let prod_nat_nat = Expr::app(
      Expr::app(
        Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]),
        nat_type(),
      ),
      nat_type(),
    );
    let p = tc.mk_local(&mk_name("p"), &prod_nat_nat);

    let ctor_app = Expr::app(
      Expr::app(
        Expr::app(
          Expr::app(
            Expr::cnst(
              mk_name2("Prod", "mk"),
              vec![one.clone(), one.clone()],
            ),
            nat_type(),
          ),
          nat_type(),
        ),
        Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()),
      ),
      Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()),
    );

    assert!(tc.def_eq(&p, &ctor_app));
  }

  #[test]
  fn eta_struct_nat_not_structure_like() {
    // Nat has 2 constructors, so it is NOT structure-like
    let env = mk_nat_env();
    assert!(!super::is_structure_like(&mk_name("Nat"), &env));
  }

  // ==========================================================================
  // Binder full comparison
  // ==========================================================================

  #[test]
  fn def_eq_binder_full_different_domains() {
    // (x : myNat) → Nat =def= (x : Nat) → Nat
    // where myNat unfolds to Nat
    let mut env = mk_nat_env();
    let my_nat = mk_name("myNat");
    env.insert(
      my_nat.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_nat.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        value: nat_type(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_nat.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::all(
      mk_name("x"),
      Expr::cnst(my_nat, vec![]),
      nat_type(),
      BinderInfo::Default,
    );
    let rhs = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Proj congruence
  // ==========================================================================

  #[test]
  fn def_eq_proj_congruence() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s = nat_zero();
    let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone());
    let rhs = Expr::proj(mk_name("S"), Nat::from(0u64), s);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_proj_different_idx() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s = nat_zero();
    let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone());
    let rhs = Expr::proj(mk_name("S"), Nat::from(1u64), s);
    assert!(!tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Unit-like equality
  // ==========================================================================

  #[test]
  fn def_eq_unit_like() {
    // Unit-type: single ctor, zero fields
    // Any two inhabitants should be def-eq
    let mut env = mk_nat_env();
use core::marker::PhantomData;
use core::ptr::NonNull;

/// A node in an intrusive doubly-linked list.
///
/// Nodes are heap-allocated and linked through raw `NonNull` pointers so
/// that the kernel DAG can splice parent lists in place while other
/// pointers into the same nodes remain live. Callers are responsible for
/// upholding aliasing/liveness invariants at every dereference site.
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
pub struct DLL<T> {
    pub next: Option<NonNull<DLL<T>>>,
    pub prev: Option<NonNull<DLL<T>>>,
    pub elem: T,
}

/// Shared iterator over the elements of a list, starting at a given node.
pub struct Iter<'a, T> {
    next: Option<NonNull<DLL<T>>>,
    marker: PhantomData<&'a mut DLL<T>>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.next.map(|node| {
            // SAFETY: the caller guarantees every node reached through
            // `next` stays allocated and un-mutated for the iteration.
            let deref = unsafe { &*node.as_ptr() };
            self.next = deref.next;
            &deref.elem
        })
    }
}

/// Mutable iterator over the elements of a list.
pub struct IterMut<'a, T> {
    next: Option<NonNull<DLL<T>>>,
    marker: PhantomData<&'a mut DLL<T>>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.next.map(|node| {
            // SAFETY: as for `Iter`, plus the caller guarantees exclusive
            // access to the elements for the iterator's lifetime.
            let deref = unsafe { &mut *node.as_ptr() };
            self.next = deref.next;
            &mut deref.elem
        })
    }
}

impl<T> DLL<T> {
    /// A detached node holding `elem`.
    #[inline]
    pub fn singleton(elem: T) -> Self {
        DLL { next: None, prev: None, elem }
    }

    /// Heap-allocate a detached node and return a raw pointer to it.
    /// The caller owns the allocation and must eventually release it
    /// (e.g. via [`DLL::free_all`]).
    #[inline]
    pub fn alloc(elem: T) -> NonNull<Self> {
        NonNull::new(Box::into_raw(Box::new(Self::singleton(elem)))).unwrap()
    }

    /// True when the list consists of exactly one node.
    #[inline]
    pub fn is_singleton(dll: Option<NonNull<Self>>) -> bool {
        dll.is_some_and(|dll| unsafe {
            let dll = &*dll.as_ptr();
            dll.prev.is_none() && dll.next.is_none()
        })
    }

    /// True when there is no list at all.
    #[inline]
    pub fn is_empty(dll: Option<NonNull<Self>>) -> bool {
        dll.is_none()
    }

    /// Insert `node` into the list immediately before `self`.
    pub fn merge(&mut self, node: NonNull<Self>) {
        unsafe {
            (*node.as_ptr()).prev = self.prev;
            (*node.as_ptr()).next = NonNull::new(self);
            if let Some(ptr) = self.prev {
                (*ptr.as_ptr()).next = Some(node);
            }
            self.prev = Some(node);
        }
    }

    /// Detach `self` from its neighbors, returning a pointer to a
    /// surviving neighbor (`prev` preferred) or `None` if `self` was a
    /// singleton. NOTE: `self`'s own `prev`/`next` fields are left
    /// untouched — the caller must not follow them afterwards.
    pub fn unlink_node(&self) -> Option<NonNull<Self>> {
        unsafe {
            let next = self.next;
            let prev = self.prev;
            if let Some(next) = next {
                (*next.as_ptr()).prev = prev;
            }
            if let Some(prev) = prev {
                (*prev.as_ptr()).next = next;
            }
            prev.or(next)
        }
    }

    /// Walk `prev` pointers to the head of the list containing `node`.
    pub fn first(mut node: NonNull<Self>) -> NonNull<Self> {
        loop {
            let prev = unsafe { (*node.as_ptr()).prev };
            match prev {
                None => break,
                Some(ptr) => node = ptr,
            }
        }
        node
    }

    /// Walk `next` pointers to the tail of the list containing `node`.
    pub fn last(mut node: NonNull<Self>) -> NonNull<Self> {
        loop {
            let next = unsafe { (*node.as_ptr()).next };
            match next {
                None => break,
                Some(ptr) => node = ptr,
            }
        }
        node
    }

    /// Append the list containing `rest` after the list containing `dll`.
    pub fn concat(dll: NonNull<Self>, rest: Option<NonNull<Self>>) {
        let last = DLL::last(dll);
        let first = rest.map(DLL::first);
        unsafe {
            (*last.as_ptr()).next = first;
        }
        if let Some(first) = first {
            unsafe {
                (*first.as_ptr()).prev = Some(last);
            }
        }
    }

    /// Iterate the whole list containing `dll`, starting from its head.
    #[inline]
    pub fn iter_option(dll: Option<NonNull<Self>>) -> Iter<'static, T> {
        Iter { next: dll.map(DLL::first), marker: PhantomData }
    }

    /// Mutable counterpart of [`DLL::iter_option`].
    #[inline]
    #[allow(dead_code)]
    pub fn iter_mut_option(dll: Option<NonNull<Self>>) -> IterMut<'static, T> {
        IterMut { next: dll.map(DLL::first), marker: PhantomData }
    }

    /// Free every node of the list containing `dll`.
    ///
    /// # Safety
    /// No pointer into any node of this list may be used afterwards.
    #[allow(unsafe_op_in_unsafe_fn)]
    pub unsafe fn free_all(dll: Option<NonNull<Self>>) {
        if let Some(start) = dll {
            let first = DLL::first(start);
            let mut current = Some(first);
            while let Some(node) = current {
                let next = (*node.as_ptr()).next;
                drop(Box::from_raw(node.as_ptr()));
                current = next;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn to_vec(dll: Option<NonNull<DLL<u64>>>) -> Vec<u64> {
        DLL::iter_option(dll).copied().collect()
    }

    #[test]
    fn test_singleton() {
        let dll = DLL::alloc(42u64);
        assert!(DLL::is_singleton(Some(dll)));
        unsafe {
            assert_eq!((*dll.as_ptr()).elem, 42);
            drop(Box::from_raw(dll.as_ptr()));
        }
    }

    #[test]
    fn test_is_empty() {
        assert!(DLL::<u64>::is_empty(None));
        let dll = DLL::alloc(1u64);
        assert!(!DLL::is_empty(Some(dll)));
        unsafe { DLL::free_all(Some(dll)) };
    }

    #[test]
    fn test_merge() {
        unsafe {
            let a = DLL::alloc(1u64);
            let b = DLL::alloc(2u64);
            (*a.as_ptr()).merge(b);
            assert_eq!(to_vec(Some(a)), vec![2, 1]);
            DLL::free_all(Some(a));
        }
    }

    #[test]
    fn test_concat() {
        unsafe {
            let a = DLL::alloc(1u64);
            let b = DLL::alloc(2u64);
            DLL::concat(a, Some(b));
            assert_eq!(to_vec(Some(a)), vec![1, 2]);
            DLL::free_all(Some(a));
        }
    }

    #[test]
    fn test_unlink_singleton() {
        unsafe {
            let dll = DLL::alloc(42u64);
            let remaining = (*dll.as_ptr()).unlink_node();
            assert!(remaining.is_none());
            drop(Box::from_raw(dll.as_ptr()));
        }
    }
}
DuplicateUniverse { + name: Name, + }, + FreeBoundVariable { + idx: u64, + }, + KernelException { + msg: String, + }, +} + +impl std::fmt::Display for TcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TcError::TypeExpected { .. } => write!(f, "type expected"), + TcError::FunctionExpected { .. } => write!(f, "function expected"), + TcError::TypeMismatch { .. } => write!(f, "type mismatch"), + TcError::DefEqFailure { .. } => { + write!(f, "definitional equality failure") + }, + TcError::UnknownConst { name } => { + write!(f, "unknown constant: {}", name.pretty()) + }, + TcError::DuplicateUniverse { name } => { + write!(f, "duplicate universe: {}", name.pretty()) + }, + TcError::FreeBoundVariable { idx } => { + write!(f, "free bound variable at index {}", idx) + }, + TcError::KernelException { msg } => write!(f, "{}", msg), + } + } +} + +impl std::error::Error for TcError {} diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs new file mode 100644 index 00000000..a06ed819 --- /dev/null +++ b/src/ix/kernel/inductive.rs @@ -0,0 +1,772 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; + +use super::error::TcError; +use super::level; +use super::tc::TypeChecker; +use super::whnf::{inst, unfold_apps}; + +type TcResult = Result; + +/// Validate an inductive type declaration. +/// Performs structural checks: constructors exist, belong to this inductive, +/// and have well-formed types. Mutual types are verified to exist. 
+pub fn check_inductive( + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + // Verify the type is well-formed + tc.check_declar_info(&ind.cnst)?; + + // Verify all constructors exist and belong to this inductive + for ctor_name in &ind.ctors { + let ctor_ci = tc.env.get(ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + let ctor = match ctor_ci { + ConstantInfo::CtorInfo(c) => c, + _ => { + return Err(TcError::KernelException { + msg: format!( + "{} is not a constructor", + ctor_name.pretty() + ), + }) + }, + }; + // Verify constructor's induct field matches + if ctor.induct != ind.cnst.name { + return Err(TcError::KernelException { + msg: format!( + "constructor {} belongs to {} but expected {}", + ctor_name.pretty(), + ctor.induct.pretty(), + ind.cnst.name.pretty() + ), + }); + } + // Verify constructor type is well-formed + tc.check_declar_info(&ctor.cnst)?; + } + + // Verify constructor return types and positivity + for ctor_name in &ind.ctors { + let ctor = match tc.env.get(ctor_name) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => continue, // already checked above + }; + check_ctor_return_type(ctor, ind, tc)?; + if !ind.is_unsafe { + check_ctor_positivity(ctor, ind, tc)?; + check_field_universe_constraints(ctor, ind, tc)?; + } + } + + // Verify all mutual types exist + for name in &ind.all { + if tc.env.get(name).is_none() { + return Err(TcError::UnknownConst { name: name.clone() }); + } + } + + Ok(()) +} + +/// Validate that a recursor's K flag is consistent with the inductive's structure. +/// K-target requires: non-mutual, in Prop, single constructor, zero fields. +/// If `rec.k == true` but conditions don't hold, reject. 
+pub fn validate_k_flag( + rec: &RecursorVal, + env: &Env, +) -> TcResult<()> { + if !rec.k { + return Ok(()); // conservative false is always fine + } + + // Must be non-mutual: `rec.all` should have exactly 1 inductive + if rec.all.len() != 1 { + return Err(TcError::KernelException { + msg: "recursor claims K but inductive is mutual".into(), + }); + } + + let ind_name = &rec.all[0]; + let ind = match env.get(ind_name) { + Some(ConstantInfo::InductInfo(iv)) => iv, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not an inductive", + ind_name.pretty() + ), + }) + }, + }; + + // Must be in Prop (Sort 0) + // Walk type telescope past all binders to get the sort + let mut ty = ind.cnst.typ.clone(); + loop { + match ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ty = body.clone(); + }, + _ => break, + } + } + let is_prop = match ty.as_data() { + ExprData::Sort(l, _) => level::is_zero(l), + _ => false, + }; + if !is_prop { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not in Prop", + ind_name.pretty() + ), + }); + } + + // Must have single constructor + if ind.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} has {} constructors (need 1)", + ind_name.pretty(), + ind.ctors.len() + ), + }); + } + + // Constructor must have zero fields (all args are params) + let ctor_name = &ind.ctors[0]; + if let Some(ConstantInfo::CtorInfo(c)) = env.get(ctor_name) { + if c.num_fields != Nat::ZERO { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but constructor {} has {} fields (need 0)", + ctor_name.pretty(), + c.num_fields + ), + }); + } + } + + Ok(()) +} + +/// Check if an expression mentions a constant by name. 
+fn expr_mentions_const(e: &Expr, name: &Name) -> bool { + match e.as_data() { + ExprData::Const(n, _, _) => n == name, + ExprData::App(f, a, _) => { + expr_mentions_const(f, name) || expr_mentions_const(a, name) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + expr_mentions_const(t, name) || expr_mentions_const(b, name) + }, + ExprData::LetE(_, t, v, b, _, _) => { + expr_mentions_const(t, name) + || expr_mentions_const(v, name) + || expr_mentions_const(b, name) + }, + ExprData::Proj(_, _, s, _) => expr_mentions_const(s, name), + ExprData::Mdata(_, inner, _) => expr_mentions_const(inner, name), + _ => false, + } +} + +/// Check that no inductive name from `ind.all` appears in a negative position +/// in the constructor's field types. +fn check_ctor_positivity( + ctor: &ConstructorVal, + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + let num_params = ind.num_params.to_u64().unwrap() as usize; + let mut ty = ctor.cnst.typ.clone(); + + // Skip parameter binders + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => return Ok(()), // fewer binders than params — odd but not our problem + } + } + + // For each remaining field, check its domain for positivity + loop { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + // The domain is the field type — check strict positivity + check_strict_positivity(binder_type, &ind.all, tc)?; + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => break, + } + } + + Ok(()) +} + +/// Check strict positivity of a field type w.r.t. a set of inductive names. +/// +/// Strict positivity for `T` w.r.t. `I`: +/// - If `T` doesn't mention `I`, OK. +/// - If `T = I args...`, OK (the inductive itself at the head). 
/// - If `T = (x : A) → B`, then `A` must NOT mention `I` at all,
///   and `B` must satisfy strict positivity w.r.t. `I`.
/// - Otherwise (I appears but not at head and not in Pi), reject.
fn check_strict_positivity(
    ty: &Expr,
    ind_names: &[Name],
    tc: &mut TypeChecker,
) -> TcResult<()> {
    let whnf_ty = tc.whnf(ty);

    // If no inductive name is mentioned, we're fine.
    if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) {
        return Ok(());
    }

    match whnf_ty.as_data() {
        ExprData::ForallE(_, domain, body, _, _) => {
            // Domain must NOT mention any inductive name: an occurrence
            // to the left of an arrow is a negative position.
            for ind_name in ind_names {
                if expr_mentions_const(domain, ind_name) {
                    return Err(TcError::KernelException {
                        msg: format!(
                            "inductive {} occurs in negative position (strict positivity violation)",
                            ind_name.pretty()
                        ),
                    });
                }
            }
            // Recurse into the codomain.
            // NOTE(review): recursing into `body` without instantiating
            // the binder leaves bvar 0 loose; this relies on `whnf` and
            // `expr_mentions_const` tolerating open terms — confirm.
            check_strict_positivity(body, ind_names, tc)
        },
        _ => {
            // The inductive is mentioned and we're not in a Pi — check if
            // it's simply an application `I args...` (which is OK).
            let (head, _) = unfold_apps(&whnf_ty);
            match head.as_data() {
                ExprData::Const(name, _, _)
                    if ind_names.iter().any(|n| n == name) =>
                {
                    Ok(())
                },
                _ => Err(TcError::KernelException {
                    msg: "inductive type occurs in a non-positive position".into(),
                }),
            }
        },
    }
}

/// Check that constructor field types live in universes ≤ the inductive's universe.
fn check_field_universe_constraints(
    ctor: &ConstructorVal,
    ind: &InductiveVal,
    tc: &mut TypeChecker,
) -> TcResult<()> {
    // Walk the inductive type telescope past num_params binders to find the sort level.
    let num_params = ind.num_params.to_u64().unwrap() as usize;
    let mut ind_ty = ind.cnst.typ.clone();
    for _ in 0..num_params {
        let whnf_ty = tc.whnf(&ind_ty);
        match whnf_ty.as_data() {
            ExprData::ForallE(name, binder_type, body, _, _) => {
                let local = tc.mk_local(name, binder_type);
                ind_ty = inst(body, &[local]);
            },
            // Fewer binders than declared params: nothing to check.
            _ => return Ok(()),
        }
    }
    // Skip remaining binders (indices) to get to the target sort.
    loop {
        let whnf_ty = tc.whnf(&ind_ty);
        match whnf_ty.as_data() {
            ExprData::ForallE(name, binder_type, body, _, _) => {
                let local = tc.mk_local(name, binder_type);
                ind_ty = inst(body, &[local]);
            },
            _ => {
                ind_ty = whnf_ty;
                break;
            },
        }
    }
    let ind_level = match ind_ty.as_data() {
        ExprData::Sort(l, _) => l.clone(),
        _ => return Ok(()), // can't extract sort, skip
    };

    // Walk ctor type, skip params, then check each field.
    let mut ctor_ty = ctor.cnst.typ.clone();
    for _ in 0..num_params {
        let whnf_ty = tc.whnf(&ctor_ty);
        match whnf_ty.as_data() {
            ExprData::ForallE(name, binder_type, body, _, _) => {
                let local = tc.mk_local(name, binder_type);
                ctor_ty = inst(body, &[local]);
            },
            _ => return Ok(()),
        }
    }

    // For each remaining field binder, check its sort level ≤ ind_level.
    loop {
        let whnf_ty = tc.whnf(&ctor_ty);
        match whnf_ty.as_data() {
            ExprData::ForallE(name, binder_type, body, _, _) => {
                // Infer the sort of the binder_type; a failed inference
                // is silently skipped (only successfully inferred sorts
                // are constrained here).
                if let Ok(field_level) = tc.infer_sort_of(binder_type) {
                    if !level::leq(&field_level, &ind_level) {
                        return Err(TcError::KernelException {
                            msg: format!(
                                "constructor {} field type lives in a universe larger than the inductive's universe",
                                ctor.cnst.name.pretty()
                            ),
                        });
                    }
                }
                let local = tc.mk_local(name, binder_type);
                ctor_ty = inst(body, &[local]);
            },
            _ => break,
        }
    }

    Ok(())
}

/// Verify that a constructor's return type targets the parent inductive.
/// Walks the constructor type telescope, then checks that the resulting
/// type is an application of the parent inductive with at least `num_params` args.
fn check_ctor_return_type(
    ctor: &ConstructorVal,
    ind: &InductiveVal,
    tc: &mut TypeChecker,
) -> TcResult<()> {
    let mut ty = ctor.cnst.typ.clone();

    // Walk past all Pi binders (params, fields, and indices alike).
    loop {
        let whnf_ty = tc.whnf(&ty);
        match whnf_ty.as_data() {
            ExprData::ForallE(name, binder_type, body, _, _) => {
                let local = tc.mk_local(name, binder_type);
                ty = inst(body, &[local]);
            },
            _ => {
                ty = whnf_ty;
                break;
            },
        }
    }

    // The return type should be `I args...`.
    let (head, args) = unfold_apps(&ty);
    let head_name = match head.as_data() {
        ExprData::Const(name, _, _) => name,
        _ => {
            return Err(TcError::KernelException {
                msg: format!(
                    "constructor {} return type head is not a constant",
                    ctor.cnst.name.pretty()
                ),
            })
        },
    };

    // The head constant must be exactly the parent inductive.
    if head_name != &ind.cnst.name {
        return Err(TcError::KernelException {
            msg: format!(
                "constructor {} returns {} but should return {}",
                ctor.cnst.name.pretty(),
                head_name.pretty(),
                ind.cnst.name.pretty()
            ),
        });
    }

    // At minimum, the inductive's parameters must be applied.
    let num_params = ind.num_params.to_u64().unwrap() as usize;
    if args.len() < num_params {
        return Err(TcError::KernelException {
            msg: format!(
                "constructor {} return type has {} args but inductive has {} params",
                ctor.cnst.name.pretty(),
                args.len(),
                num_params
            ),
        });
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ix::kernel::tc::TypeChecker;
    use crate::lean::nat::Nat;

    // Shorthand for a single-segment name.
    fn mk_name(s: &str) -> Name {
        Name::str(Name::anon(), s.into())
    }

    // Shorthand for a two-segment name, e.g. `Nat.zero`.
    fn mk_name2(a: &str, b: &str) -> Name {
        Name::str(Name::str(Name::anon(), a.into()), b.into())
    }

    fn nat_type() -> Expr {
        Expr::cnst(mk_name("Nat"), vec![])
    }

    // Environment containing `Nat` with its two constructors.
    // (Definition continues below.)
    fn mk_nat_env() -> Env {
        let mut env = Env::default();
        let nat_name = mk_name("Nat");
        env.insert(
            nat_name.clone(),
ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + let zero_name = mk_name2("Nat", "zero"); + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let succ_name = mk_name2("Nat", "succ"); + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn check_nat_inductive_passes() { + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn check_ctor_wrong_return_type() { + let mut env = mk_nat_env(); + let bool_name = mk_name("Bool"); + env.insert( + bool_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bool_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![bool_name.clone()], + ctors: vec![mk_name2("Bool", "bad")], + num_nested: Nat::from(0u64), + is_rec: 
false, + is_unsafe: false, + is_reflexive: false, + }), + ); + // Constructor returns Nat instead of Bool + let bad_ctor_name = mk_name2("Bool", "bad"); + env.insert( + bad_ctor_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: bad_ctor_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: bool_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&bool_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // Positivity checking + // ========================================================================== + + fn bool_type() -> Expr { + Expr::cnst(mk_name("Bool"), vec![]) + } + + /// Helper to make a simple inductive + ctor env for positivity tests. 
+ fn mk_single_ctor_env( + ind_name: &str, + ctor_name: &str, + ctor_typ: Expr, + num_fields: u64, + ) -> Env { + let mut env = mk_nat_env(); + // Bool + let bool_name = mk_name("Bool"); + env.insert( + bool_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bool_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![bool_name], + ctors: vec![mk_name2("Bool", "true"), mk_name2("Bool", "false")], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let iname = mk_name(ind_name); + let cname = mk_name2(ind_name, ctor_name); + env.insert( + iname.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: iname.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![iname.clone()], + ctors: vec![cname.clone()], + num_nested: Nat::from(0u64), + is_rec: num_fields > 0, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + cname.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: cname, + level_params: vec![], + typ: ctor_typ, + }, + induct: iname, + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(num_fields), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn positivity_bad_negative() { + // inductive Bad | mk : (Bad → Bool) → Bad + let bad = mk_name("Bad"); + let ctor_ty = Expr::all( + mk_name("f"), + Expr::all(mk_name("x"), Expr::cnst(bad, vec![]), bool_type(), BinderInfo::Default), + Expr::cnst(mk_name("Bad"), vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Bad", "mk", ctor_ty, 1); + let ind = match env.get(&mk_name("Bad")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + 
assert!(check_inductive(ind, &mut tc).is_err()); + } + + #[test] + fn positivity_nat_succ_ok() { + // Nat.succ : Nat → Nat (positive) + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn positivity_tree_positive_function() { + // inductive Tree | node : (Nat → Tree) → Tree + // Tree appears positive in `Nat → Tree` + let tree = mk_name("Tree"); + let ctor_ty = Expr::all( + mk_name("f"), + Expr::all(mk_name("n"), nat_type(), Expr::cnst(tree.clone(), vec![]), BinderInfo::Default), + Expr::cnst(tree, vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Tree", "node", ctor_ty, 1); + let ind = match env.get(&mk_name("Tree")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn positivity_depth2_negative() { + // inductive Bad2 | mk : ((Bad2 → Nat) → Nat) → Bad2 + // Bad2 appears in negative position at depth 2 + let bad2 = mk_name("Bad2"); + let inner = Expr::all( + mk_name("g"), + Expr::all(mk_name("x"), Expr::cnst(bad2.clone(), vec![]), nat_type(), BinderInfo::Default), + nat_type(), + BinderInfo::Default, + ); + let ctor_ty = Expr::all( + mk_name("f"), + inner, + Expr::cnst(bad2, vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Bad2", "mk", ctor_ty, 1); + let ind = match env.get(&mk_name("Bad2")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // Field universe constraints + // ========================================================================== + + #[test] + fn field_universe_nat_field_in_type1_ok() { + // Nat : 
Sort 1, Nat.succ field is Nat : Sort 1 — leq(1, 1) passes + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn field_universe_prop_inductive_with_type_field_fails() { + // inductive PropBad : Prop | mk : Nat → PropBad + // PropBad lives in Sort 0, Nat lives in Sort 1 — leq(1, 0) fails + let mut env = mk_nat_env(); + let pb_name = mk_name("PropBad"); + let pb_mk = mk_name2("PropBad", "mk"); + env.insert( + pb_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: pb_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), // Prop + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![pb_name.clone()], + ctors: vec![pb_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + pb_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: pb_mk, + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), // Nat : Sort 1 + Expr::cnst(pb_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: pb_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&pb_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } +} diff --git a/src/ix/kernel/level.rs b/src/ix/kernel/level.rs new file mode 100644 index 00000000..90931ca6 --- /dev/null +++ b/src/ix/kernel/level.rs @@ -0,0 +1,393 @@ +use crate::ix::env::{Expr, ExprData, Level, LevelData, Name}; + +/// Simplify a universe level expression. 
+pub fn simplify(l: &Level) -> Level { + match l.as_data() { + LevelData::Zero(_) | LevelData::Param(..) | LevelData::Mvar(..) => { + l.clone() + }, + LevelData::Succ(inner, _) => { + let inner_s = simplify(inner); + Level::succ(inner_s) + }, + LevelData::Max(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + combining(&a_s, &b_s) + }, + LevelData::Imax(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + if is_zero(&a_s) || is_one(&a_s) { + b_s + } else { + match b_s.as_data() { + LevelData::Zero(_) => b_s, + LevelData::Succ(..) => combining(&a_s, &b_s), + _ => Level::imax(a_s, b_s), + } + } + }, + } +} + +/// Combine two levels, simplifying Max(Zero, x) = x and +/// Max(Succ a, Succ b) = Succ(Max(a, b)). +fn combining(l: &Level, r: &Level) -> Level { + match (l.as_data(), r.as_data()) { + (LevelData::Zero(_), _) => r.clone(), + (_, LevelData::Zero(_)) => l.clone(), + (LevelData::Succ(a, _), LevelData::Succ(b, _)) => { + let inner = combining(a, b); + Level::succ(inner) + }, + _ => Level::max(l.clone(), r.clone()), + } +} + +fn is_one(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Succ(inner, _) if is_zero(inner)) +} + +/// Check if a level is definitionally zero: l <= 0. +pub fn is_zero(l: &Level) -> bool { + leq(l, &Level::zero()) +} + +/// Check if `l <= r`. +pub fn leq(l: &Level, r: &Level) -> bool { + let l_s = simplify(l); + let r_s = simplify(r); + leq_core(&l_s, &r_s, 0) +} + +/// Check `l <= r + diff`. 
/// Decides `l <= r + diff` (for negative `diff`, read `l + (-diff) <= r`).
/// The order of the match arms is semantically significant: each guarded
/// arm assumes the arms above it did not fire.
fn leq_core(l: &Level, r: &Level, diff: isize) -> bool {
    match (l.as_data(), r.as_data()) {
        // 0 <= r + diff holds trivially when diff >= 0.
        (LevelData::Zero(_), _) if diff >= 0 => true,
        // l <= 0 + diff is impossible for diff < 0 (levels are >= 0).
        (_, LevelData::Zero(_)) if diff < 0 => false,
        // Two params compare only when they are the same parameter.
        (LevelData::Param(a, _), LevelData::Param(b, _)) => a == b && diff >= 0,
        // A param may be arbitrarily large, so p <= diff never holds for
        // all instantiations (the diff < 0 case was handled above).
        (LevelData::Param(..), LevelData::Zero(_)) => false,
        // Only reached with diff < 0 (the first arm caught diff >= 0),
        // so this evaluates to false here; kept for symmetry.
        (LevelData::Zero(_), LevelData::Param(..)) => diff >= 0,
        // succ s <= r + diff  iff  s <= r + (diff - 1).
        (LevelData::Succ(s, _), _) => leq_core(s, r, diff - 1),
        // l <= succ s + diff  iff  l <= s + (diff + 1).
        (_, LevelData::Succ(s, _)) => leq_core(l, s, diff + 1),
        // max a b <= r  iff  both a <= r and b <= r.
        (LevelData::Max(a, b, _), _) => {
            leq_core(a, r, diff) && leq_core(b, r, diff)
        },
        // For atomic lhs, l <= max x y if l fits under either branch.
        (LevelData::Param(..) | LevelData::Zero(_), LevelData::Max(x, y, _)) => {
            leq_core(l, x, diff) || leq_core(l, y, diff)
        },
        // Syntactically identical imax terms are equal.
        (LevelData::Imax(a, b, _), LevelData::Imax(x, y, _))
            if a == x && b == y =>
        {
            true
        },
        // An imax blocked on a parameter: case-split on that parameter.
        (LevelData::Imax(_, b, _), _) if is_param(b) => {
            leq_imax_by_cases(b, l, r, diff)
        },
        (_, LevelData::Imax(_, y, _)) if is_param(y) => {
            leq_imax_by_cases(y, l, r, diff)
        },
        // Distribute a nested max/imax out of an imax on the left:
        //   imax a (imax x y) = max (imax a y) (imax x y)
        //   imax a (max x y)  = max (imax a x) (imax a y)
        (LevelData::Imax(a, b, _), _) if is_any_max(b) => {
            match b.as_data() {
                LevelData::Imax(x, y, _) => {
                    let new_lhs = Level::imax(a.clone(), y.clone());
                    let new_rhs = Level::imax(x.clone(), y.clone());
                    let new_max = Level::max(new_lhs, new_rhs);
                    leq_core(&new_max, r, diff)
                },
                LevelData::Max(x, y, _) => {
                    let new_lhs = Level::imax(a.clone(), x.clone());
                    let new_rhs = Level::imax(a.clone(), y.clone());
                    let new_max = Level::max(new_lhs, new_rhs);
                    let simplified = simplify(&new_max);
                    leq_core(&simplified, r, diff)
                },
                // The is_any_max guard admits only Max/Imax.
                _ => unreachable!(),
            }
        },
        // Mirror image of the previous arm, for the right-hand side.
        (_, LevelData::Imax(x, y, _)) if is_any_max(y) => {
            match y.as_data() {
                LevelData::Imax(j, k, _) => {
                    let new_lhs = Level::imax(x.clone(), k.clone());
                    let new_rhs = Level::imax(j.clone(), k.clone());
                    let new_max = Level::max(new_lhs, new_rhs);
                    leq_core(l, &new_max, diff)
                },
                LevelData::Max(j, k, _) => {
                    let new_lhs = Level::imax(x.clone(), j.clone());
                    let new_rhs = Level::imax(x.clone(), k.clone());
                    let new_max = Level::max(new_lhs, new_rhs);
                    let simplified = simplify(&new_max);
                    leq_core(l, &simplified, diff)
                },
                _ => unreachable!(),
            }
        },
        _ => false,
    }
}

/// Test l <= r by substituting param with 0 and Succ(param) and checking both.
/// If the inequality holds in both worlds it holds for every instantiation.
fn leq_imax_by_cases(
    param: &Level,
    lhs: &Level,
    rhs: &Level,
    diff: isize,
) -> bool {
    let zero = Level::zero();
    let succ_param = Level::succ(param.clone());

    let lhs_0 = subst_and_simplify(lhs, param, &zero);
    let rhs_0 = subst_and_simplify(rhs, param, &zero);
    let lhs_s = subst_and_simplify(lhs, param, &succ_param);
    let rhs_s = subst_and_simplify(rhs, param, &succ_param);

    leq_core(&lhs_0, &rhs_0, diff) && leq_core(&lhs_s, &rhs_s, diff)
}

/// Substitute `from := to` in `level`, then renormalize.
fn subst_and_simplify(level: &Level, from: &Level, to: &Level) -> Level {
    let substituted = subst_single_level(level, from, to);
    simplify(&substituted)
}

/// Substitute a single level parameter.
fn subst_single_level(level: &Level, from: &Level, to: &Level) -> Level {
    if level == from {
        return to.clone();
    }
    match level.as_data() {
        LevelData::Zero(_) | LevelData::Mvar(..) => level.clone(),
        LevelData::Param(..) => {
            // NOTE(review): this equality test is redundant — the
            // `level == from` check above already returned in that case.
            if level == from {
                to.clone()
            } else {
                level.clone()
            }
        },
        LevelData::Succ(inner, _) => {
            Level::succ(subst_single_level(inner, from, to))
        },
        LevelData::Max(a, b, _) => Level::max(
            subst_single_level(a, from, to),
            subst_single_level(b, from, to),
        ),
        LevelData::Imax(a, b, _) => Level::imax(
            subst_single_level(a, from, to),
            subst_single_level(b, from, to),
        ),
    }
}

/// Is the level a bare universe parameter?
fn is_param(l: &Level) -> bool {
    matches!(l.as_data(), LevelData::Param(..))
}

/// Is the level a `max` or an `imax` node?
fn is_any_max(l: &Level) -> bool {
    matches!(l.as_data(), LevelData::Max(..) | LevelData::Imax(..))
}

/// Check universe level equality via antisymmetry: l == r iff l <= r && r <= l.
pub fn eq_antisymm(l: &Level, r: &Level) -> bool {
    leq(l, r) && leq(r, l)
}

/// Check that two lists of levels are pointwise equal.
+pub fn eq_antisymm_many(ls: &[Level], rs: &[Level]) -> bool {
+  // Length mismatch short-circuits before the pairwise antisymmetry checks.
+  ls.len() == rs.len()
+    && ls.iter().zip(rs.iter()).all(|(l, r)| eq_antisymm(l, r))
+}
+
+/// Substitute universe parameters: `level[params[i] := values[i]]`.
+// NOTE(review): indexes `values[i]` for every match in `params`; assumes
+// `values.len() >= params.len()` — panics otherwise. Confirm callers
+// validate arity (cf. the universe-count check in `infer_const`).
+pub fn subst_level(
+  level: &Level,
+  params: &[Name],
+  values: &[Level],
+) -> Level {
+  match level.as_data() {
+    LevelData::Zero(_) => level.clone(),
+    LevelData::Succ(inner, _) => {
+      Level::succ(subst_level(inner, params, values))
+    },
+    LevelData::Max(a, b, _) => Level::max(
+      subst_level(a, params, values),
+      subst_level(b, params, values),
+    ),
+    LevelData::Imax(a, b, _) => Level::imax(
+      subst_level(a, params, values),
+      subst_level(b, params, values),
+    ),
+    LevelData::Param(name, _) => {
+      // First matching parameter name wins; unmatched params are kept.
+      for (i, p) in params.iter().enumerate() {
+        if name == p {
+          return values[i].clone();
+        }
+      }
+      level.clone()
+    },
+    // Metavariables are left untouched by parameter substitution.
+    LevelData::Mvar(..) => level.clone(),
+  }
+}
+
+/// Check that all universe parameters in `level` are contained in `params`.
+pub fn all_uparams_defined(level: &Level, params: &[Name]) -> bool {
+  match level.as_data() {
+    LevelData::Zero(_) => true,
+    LevelData::Succ(inner, _) => all_uparams_defined(inner, params),
+    LevelData::Max(a, b, _) | LevelData::Imax(a, b, _) => {
+      all_uparams_defined(a, params) && all_uparams_defined(b, params)
+    },
+    LevelData::Param(name, _) => params.iter().any(|p| p == name),
+    // Metavariables carry no parameter names, so they trivially pass.
+    LevelData::Mvar(..) => true,
+  }
+}
+
+/// Check that all universe parameters in an expression are contained in `params`.
+/// Recursively walks the Expr, checking all Levels in Sort and Const nodes.
+pub fn all_expr_uparams_defined(e: &Expr, params: &[Name]) -> bool { + match e.as_data() { + ExprData::Sort(level, _) => all_uparams_defined(level, params), + ExprData::Const(_, levels, _) => { + levels.iter().all(|l| all_uparams_defined(l, params)) + }, + ExprData::App(f, a, _) => { + all_expr_uparams_defined(f, params) + && all_expr_uparams_defined(a, params) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::LetE(_, t, v, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(v, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::Proj(_, _, s, _) => all_expr_uparams_defined(s, params), + ExprData::Mdata(_, inner, _) => all_expr_uparams_defined(inner, params), + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => true, + } +} + +/// Check that a list of levels are all Params with no duplicates. +pub fn no_dupes_all_params(levels: &[Name]) -> bool { + for (i, a) in levels.iter().enumerate() { + for b in &levels[i + 1..] 
{ + if a == b { + return false; + } + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simplify_zero() { + let z = Level::zero(); + assert_eq!(simplify(&z), z); + } + + #[test] + fn test_simplify_max_zero() { + let z = Level::zero(); + let p = Level::param(Name::str(Name::anon(), "u".into())); + let m = Level::max(z, p.clone()); + assert_eq!(simplify(&m), p); + } + + #[test] + fn test_simplify_imax_zero_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let z = Level::zero(); + let im = Level::imax(p, z.clone()); + assert_eq!(simplify(&im), z); + } + + #[test] + fn test_simplify_imax_succ_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let one = Level::succ(Level::zero()); + let im = Level::imax(p.clone(), one.clone()); + let simplified = simplify(&im); + // imax(p, 1) where p is nonzero → combining(p, 1) + // Actually: imax(u, 1) simplifies since a_s = u, b_s = 1 = Succ(0) + // → combining(u, 1) = max(u, 1) since u is Param, 1 is Succ + let expected = Level::max(p, one); + assert_eq!(simplified, expected); + } + + #[test] + fn test_simplify_idempotent() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let l = Level::max( + Level::imax(p.clone(), q.clone()), + Level::succ(Level::zero()), + ); + let s1 = simplify(&l); + let s2 = simplify(&s1); + assert_eq!(s1, s2); + } + + #[test] + fn test_leq_reflexive() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&p, &p)); + assert!(leq(&Level::zero(), &Level::zero())); + } + + #[test] + fn test_leq_zero_anything() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&Level::zero(), &p)); + assert!(leq(&Level::zero(), &Level::succ(Level::zero()))); + } + + #[test] + fn test_leq_succ_not_zero() { + let one = Level::succ(Level::zero()); + assert!(!leq(&one, &Level::zero())); + } + + #[test] + fn 
test_eq_antisymm_identity() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(eq_antisymm(&p, &p)); + } + + #[test] + fn test_eq_antisymm_max_comm() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let m1 = Level::max(p.clone(), q.clone()); + let m2 = Level::max(q, p); + assert!(eq_antisymm(&m1, &m2)); + } + + #[test] + fn test_subst_level() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let one = Level::succ(Level::zero()); + let result = subst_level(&p, &[u_name], &[one.clone()]); + assert_eq!(result, one); + } + + #[test] + fn test_subst_level_nested() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let l = Level::succ(p); + let zero = Level::zero(); + let result = subst_level(&l, &[u_name], &[zero]); + let expected = Level::succ(Level::zero()); + assert_eq!(result, expected); + } +} diff --git a/src/ix/kernel/mod.rs b/src/ix/kernel/mod.rs new file mode 100644 index 00000000..d6a5750e --- /dev/null +++ b/src/ix/kernel/mod.rs @@ -0,0 +1,11 @@ +pub mod convert; +pub mod dag; +pub mod def_eq; +pub mod dll; +pub mod error; +pub mod inductive; +pub mod level; +pub mod quot; +pub mod tc; +pub mod upcopy; +pub mod whnf; diff --git a/src/ix/kernel/quot.rs b/src/ix/kernel/quot.rs new file mode 100644 index 00000000..51a1e070 --- /dev/null +++ b/src/ix/kernel/quot.rs @@ -0,0 +1,291 @@ +use crate::ix::env::*; + +use super::error::TcError; + +type TcResult = Result; + +/// Verify that the quotient declarations are consistent with the environment. +/// Checks that Quot is an inductive, Quot.mk is its constructor, and +/// Quot.lift and Quot.ind exist. 
+pub fn check_quot(env: &Env) -> TcResult<()> { + let quot_name = Name::str(Name::anon(), "Quot".into()); + let quot_mk_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "mk".into()); + let quot_lift_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "lift".into()); + let quot_ind_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "ind".into()); + + // Check Quot exists and is an inductive + let quot = + env.get("_name).ok_or(TcError::UnknownConst { name: quot_name })?; + match quot { + ConstantInfo::InductInfo(_) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot is not an inductive type".into(), + }) + }, + } + + // Check Quot.mk exists and is a constructor of Quot + let mk = env + .get("_mk_name) + .ok_or(TcError::UnknownConst { name: quot_mk_name })?; + match mk { + ConstantInfo::CtorInfo(c) + if c.induct + == Name::str(Name::anon(), "Quot".into()) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot.mk is not a constructor of Quot".into(), + }) + }, + } + + // Check Eq exists as an inductive with exactly 1 universe param and 1 ctor + let eq_name = Name::str(Name::anon(), "Eq".into()); + if let Some(eq_ci) = env.get(&eq_name) { + match eq_ci { + ConstantInfo::InductInfo(iv) => { + if iv.cnst.level_params.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }); + } + if iv.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 constructor, found {}", + iv.ctors.len() + ), + }); + } + }, + _ => { + return Err(TcError::KernelException { + msg: "Eq is not an inductive type".into(), + }) + }, + } + } else { + return Err(TcError::KernelException { + msg: "Eq not found in environment (required for quotient types)".into(), + }); + } + + // Check Quot has exactly 1 level param + match quot { + ConstantInfo::InductInfo(iv) if iv.cnst.level_params.len() != 1 => { + return 
Err(TcError::KernelException { + msg: format!( + "Quot should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.mk has 1 level param + match mk { + ConstantInfo::CtorInfo(c) if c.cnst.level_params.len() != 1 => { + return Err(TcError::KernelException { + msg: format!( + "Quot.mk should have 1 universe parameter, found {}", + c.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.lift exists and has 2 level params + let lift = env + .get("_lift_name) + .ok_or(TcError::UnknownConst { name: quot_lift_name })?; + if lift.get_level_params().len() != 2 { + return Err(TcError::KernelException { + msg: format!( + "Quot.lift should have 2 universe parameters, found {}", + lift.get_level_params().len() + ), + }); + } + + // Check Quot.ind exists and has 1 level param + let ind = env + .get("_ind_name) + .ok_or(TcError::UnknownConst { name: quot_ind_name })?; + if ind.get_level_params().len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Quot.ind should have 1 universe parameter, found {}", + ind.get_level_params().len() + ), + }); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + /// Build a well-formed quotient environment. 
+ fn mk_quot_env() -> Env { + let mut env = Env::default(); + let u = mk_name("u"); + let v = mk_name("v"); + let dummy_ty = Expr::sort(Level::param(u.clone())); + + // Eq.{u} — 1 uparam, 1 ctor + let eq_name = mk_name("Eq"); + let eq_refl = mk_name2("Eq", "refl"); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Eq"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + // Quot.{u} — 1 uparam + let quot_name = mk_name("Quot"); + let quot_mk = mk_name2("Quot", "mk"); + env.insert( + quot_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: quot_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![quot_name], + ctors: vec![quot_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + quot_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: quot_mk, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Quot"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Quot.lift.{u,v} — 2 uparams + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + 
name: quot_lift, + level_params: vec![u.clone(), v.clone()], + typ: dummy_ty.clone(), + }, + is_unsafe: false, + }), + ); + + // Quot.ind.{u} — 1 uparam + let quot_ind = mk_name2("Quot", "ind"); + env.insert( + quot_ind.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_ind, + level_params: vec![u], + typ: dummy_ty, + }, + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn check_quot_well_formed() { + let env = mk_quot_env(); + assert!(check_quot(&env).is_ok()); + } + + #[test] + fn check_quot_missing_eq() { + let mut env = mk_quot_env(); + env.remove(&mk_name("Eq")); + assert!(check_quot(&env).is_err()); + } + + #[test] + fn check_quot_wrong_lift_levels() { + let mut env = mk_quot_env(); + // Replace Quot.lift with 1 level param instead of 2 + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_lift, + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + is_unsafe: false, + }), + ); + assert!(check_quot(&env).is_err()); + } +} diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs new file mode 100644 index 00000000..e80416fd --- /dev/null +++ b/src/ix/kernel/tc.rs @@ -0,0 +1,1694 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; +use rustc_hash::FxHashMap; + +use super::def_eq::def_eq; +use super::error::TcError; +use super::level::{all_expr_uparams_defined, no_dupes_all_params}; +use super::whnf::*; + +type TcResult = Result; + +/// The kernel type checker. 
+pub struct TypeChecker<'env> { + pub env: &'env Env, + pub whnf_cache: FxHashMap, + pub infer_cache: FxHashMap, + pub local_counter: u64, + pub local_types: FxHashMap, +} + +impl<'env> TypeChecker<'env> { + pub fn new(env: &'env Env) -> Self { + TypeChecker { + env, + whnf_cache: FxHashMap::default(), + infer_cache: FxHashMap::default(), + local_counter: 0, + local_types: FxHashMap::default(), + } + } + + // ========================================================================== + // WHNF with caching + // ========================================================================== + + pub fn whnf(&mut self, e: &Expr) -> Expr { + if let Some(cached) = self.whnf_cache.get(e) { + return cached.clone(); + } + let result = whnf(e, self.env); + self.whnf_cache.insert(e.clone(), result.clone()); + result + } + + // ========================================================================== + // Local context management + // ========================================================================== + + /// Create a fresh free variable for entering a binder. + pub fn mk_local(&mut self, name: &Name, ty: &Expr) -> Expr { + let id = self.local_counter; + self.local_counter += 1; + let local_name = Name::num(name.clone(), Nat::from(id)); + self.local_types.insert(local_name.clone(), ty.clone()); + Expr::fvar(local_name) + } + + // ========================================================================== + // Ensure helpers + // ========================================================================== + + pub fn ensure_sort(&mut self, e: &Expr) -> TcResult { + if let ExprData::Sort(level, _) = e.as_data() { + return Ok(level.clone()); + } + let whnfd = self.whnf(e); + match whnfd.as_data() { + ExprData::Sort(level, _) => Ok(level.clone()), + _ => Err(TcError::TypeExpected { + expr: e.clone(), + inferred: whnfd, + }), + } + } + + pub fn ensure_pi(&mut self, e: &Expr) -> TcResult { + if let ExprData::ForallE(..) 
= e.as_data() { + return Ok(e.clone()); + } + let whnfd = self.whnf(e); + match whnfd.as_data() { + ExprData::ForallE(..) => Ok(whnfd), + _ => Err(TcError::FunctionExpected { + expr: e.clone(), + inferred: whnfd, + }), + } + } + + /// Infer the type of `e` and ensure it's a sort; return the universe level. + pub fn infer_sort_of(&mut self, e: &Expr) -> TcResult { + let ty = self.infer(e)?; + let whnfd = self.whnf(&ty); + self.ensure_sort(&whnfd) + } + + // ========================================================================== + // Type inference + // ========================================================================== + + pub fn infer(&mut self, e: &Expr) -> TcResult { + if let Some(cached) = self.infer_cache.get(e) { + return Ok(cached.clone()); + } + let result = self.infer_core(e)?; + self.infer_cache.insert(e.clone(), result.clone()); + Ok(result) + } + + fn infer_core(&mut self, e: &Expr) -> TcResult { + match e.as_data() { + ExprData::Sort(level, _) => self.infer_sort(level), + ExprData::Const(name, levels, _) => self.infer_const(name, levels), + ExprData::App(..) => self.infer_app(e), + ExprData::Lam(..) => self.infer_lambda(e), + ExprData::ForallE(..) => self.infer_pi(e), + ExprData::LetE(_, typ, val, body, _, _) => { + self.infer_let(typ, val, body) + }, + ExprData::Lit(lit, _) => self.infer_lit(lit), + ExprData::Proj(type_name, idx, structure, _) => { + self.infer_proj(type_name, idx, structure) + }, + ExprData::Mdata(_, inner, _) => self.infer(inner), + ExprData::Fvar(name, _) => { + match self.local_types.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context".into(), + }), + } + }, + ExprData::Bvar(idx, _) => Err(TcError::FreeBoundVariable { + idx: idx.to_u64().unwrap_or(u64::MAX), + }), + ExprData::Mvar(..) 
=> Err(TcError::KernelException { + msg: "cannot infer type of metavariable".into(), + }), + } + } + + fn infer_sort(&mut self, level: &Level) -> TcResult { + Ok(Expr::sort(Level::succ(level.clone()))) + } + + fn infer_const( + &mut self, + name: &Name, + levels: &[Level], + ) -> TcResult { + let ci = self + .env + .get(name) + .ok_or_else(|| TcError::UnknownConst { name: name.clone() })?; + + let decl_params = ci.get_level_params(); + if levels.len() != decl_params.len() { + return Err(TcError::KernelException { + msg: format!( + "universe parameter count mismatch for {}", + name.pretty() + ), + }); + } + + let ty = ci.get_type(); + Ok(subst_expr_levels(ty, decl_params, levels)) + } + + fn infer_app(&mut self, e: &Expr) -> TcResult { + let (fun, args) = unfold_apps(e); + let mut fun_ty = self.infer(&fun)?; + + for arg in &args { + let pi = self.ensure_pi(&fun_ty)?; + match pi.as_data() { + ExprData::ForallE(_, binder_type, body, _, _) => { + // Check argument type matches binder + let arg_ty = self.infer(arg)?; + self.assert_def_eq(&arg_ty, binder_type)?; + fun_ty = inst(body, &[arg.clone()]); + }, + _ => unreachable!(), + } + } + + Ok(fun_ty) + } + + fn infer_lambda(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut binder_types = Vec::new(); + let mut binder_infos = Vec::new(); + let mut binder_names = Vec::new(); + + while let ExprData::Lam(name, binder_type, body, bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + self.infer_sort_of(&binder_type_inst)?; + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + binder_types.push(binder_type_inst); + binder_infos.push(bi.clone()); + binder_names.push(name.clone()); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let body_ty = self.infer(&body_inst)?; + + // Abstract back: build Pi telescope + let mut result = abstr(&body_ty, &locals); + for i in (0..locals.len()).rev() { + 
let binder_type_abstrd = abstr(&binder_types[i], &locals[..i]); + result = Expr::all( + binder_names[i].clone(), + binder_type_abstrd, + result, + binder_infos[i].clone(), + ); + } + + Ok(result) + } + + fn infer_pi(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut universes = Vec::new(); + + while let ExprData::ForallE(name, binder_type, body, _bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + let dom_univ = self.infer_sort_of(&binder_type_inst)?; + universes.push(dom_univ); + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let mut result_level = self.infer_sort_of(&body_inst)?; + + for univ in universes.into_iter().rev() { + result_level = Level::imax(univ, result_level); + } + + Ok(Expr::sort(result_level)) + } + + fn infer_let( + &mut self, + typ: &Expr, + val: &Expr, + body: &Expr, + ) -> TcResult { + // Verify value matches declared type + let val_ty = self.infer(val)?; + self.assert_def_eq(&val_ty, typ)?; + let body_inst = inst(body, &[val.clone()]); + self.infer(&body_inst) + } + + fn infer_lit(&mut self, lit: &Literal) -> TcResult { + match lit { + Literal::NatVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "Nat".into()), vec![])) + }, + Literal::StrVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "String".into()), vec![])) + }, + } + } + + fn infer_proj( + &mut self, + type_name: &Name, + idx: &Nat, + structure: &Expr, + ) -> TcResult { + let structure_ty = self.infer(structure)?; + let structure_ty_whnf = self.whnf(&structure_ty); + + let (_, struct_ty_args) = unfold_apps(&structure_ty_whnf); + let struct_ty_head = match unfold_apps(&structure_ty_whnf).0.as_data() { + ExprData::Const(name, levels, _) => (name.clone(), levels.clone()), + _ => { + return Err(TcError::KernelException { + msg: "projection structure type is not a constant".into(), + }) + }, + 
}; + + let ind = self.env.get(&struct_ty_head.0).ok_or_else(|| { + TcError::UnknownConst { name: struct_ty_head.0.clone() } + })?; + + let (num_params, ctor_name) = match ind { + ConstantInfo::InductInfo(iv) => { + let ctor = iv.ctors.first().ok_or_else(|| { + TcError::KernelException { + msg: "inductive has no constructors".into(), + } + })?; + (iv.num_params.to_u64().unwrap(), ctor.clone()) + }, + _ => { + return Err(TcError::KernelException { + msg: "projection type is not an inductive".into(), + }) + }, + }; + + let ctor_ci = self.env.get(&ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + + let mut ctor_ty = subst_expr_levels( + ctor_ci.get_type(), + ctor_ci.get_level_params(), + &struct_ty_head.1, + ); + + // Skip params + for i in 0..num_params as usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ctor_ty = inst(body, &[struct_ty_args[i].clone()]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (params)".into(), + }) + }, + } + } + + // Walk to the idx-th field + let idx_usize = idx.to_u64().unwrap() as usize; + for i in 0..idx_usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + let proj = + Expr::proj(type_name.clone(), Nat::from(i as u64), structure.clone()); + ctor_ty = inst(body, &[proj]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (fields)".into(), + }) + }, + } + } + + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, binder_type, _, _, _) => { + Ok(binder_type.clone()) + }, + _ => Err(TcError::KernelException { + msg: "ran out of constructor telescope (target field)".into(), + }), + } + } + + // ========================================================================== + // Definitional equality (delegated to def_eq module) + // 
========================================================================== + + pub fn def_eq(&mut self, x: &Expr, y: &Expr) -> bool { + def_eq(x, y, self) + } + + pub fn assert_def_eq(&mut self, x: &Expr, y: &Expr) -> TcResult<()> { + if self.def_eq(x, y) { + Ok(()) + } else { + Err(TcError::DefEqFailure { lhs: x.clone(), rhs: y.clone() }) + } + } + + // ========================================================================== + // Declaration checking + // ========================================================================== + + /// Check that a declaration's type is well-formed. + pub fn check_declar_info( + &mut self, + info: &ConstantVal, + ) -> TcResult<()> { + // Check for duplicate universe params + if !no_dupes_all_params(&info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "duplicate universe parameters in {}", + info.name.pretty() + ), + }); + } + + // Check that the type has no loose bound variables + if has_loose_bvars(&info.typ) { + return Err(TcError::KernelException { + msg: format!( + "free bound variables in type of {}", + info.name.pretty() + ), + }); + } + + // Check that all universe parameters in the type are declared + if !all_expr_uparams_defined(&info.typ, &info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in type of {}", + info.name.pretty() + ), + }); + } + + // Check that the type is a type (infers to a Sort) + let inferred = self.infer(&info.typ)?; + self.ensure_sort(&inferred)?; + + Ok(()) + } + + /// Check a single declaration. 
+ pub fn check_declar( + &mut self, + ci: &ConstantInfo, + ) -> TcResult<()> { + match ci { + ConstantInfo::AxiomInfo(v) => { + self.check_declar_info(&v.cnst)?; + }, + ConstantInfo::DefnInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::ThmInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::OpaqueInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::QuotInfo(v) => { + self.check_declar_info(&v.cnst)?; + super::quot::check_quot(self.env)?; + }, + ConstantInfo::InductInfo(v) => { + super::inductive::check_inductive(v, self)?; + }, + ConstantInfo::CtorInfo(v) => { + self.check_declar_info(&v.cnst)?; + // Verify the parent inductive exists + if self.env.get(&v.induct).is_none() { + return Err(TcError::UnknownConst { + name: v.induct.clone(), + }); + } + }, + ConstantInfo::RecInfo(v) => { + self.check_declar_info(&v.cnst)?; + for ind_name in &v.all { + if self.env.get(ind_name).is_none() { + return Err(TcError::UnknownConst { + name: ind_name.clone(), + }); + } + } + super::inductive::validate_k_flag(v, 
self.env)?;
      },
    }
    Ok(())
  }
}

/// Type-check every declaration in `env`.
///
/// Failures do not abort the scan: each declaration gets a fresh
/// `TypeChecker` (so per-declaration caches cannot leak between entries) and
/// every `(name, error)` pair is collected and returned.
pub fn check_env(env: &Env) -> Vec<(Name, TcError)> {
  let mut errors = Vec::new();
  for (name, ci) in env.iter() {
    let mut tc = TypeChecker::new(env);
    if let Err(e) = tc.check_declar(ci) {
      errors.push((name.clone(), e));
    }
  }
  errors
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::lean::nat::Nat;

  // ---- tiny constructors shared by all tests ----

  fn mk_name(s: &str) -> Name {
    Name::str(Name::anon(), s.into())
  }

  fn mk_name2(a: &str, b: &str) -> Name {
    Name::str(Name::str(Name::anon(), a.into()), b.into())
  }

  fn nat_type() -> Expr {
    Expr::cnst(mk_name("Nat"), vec![])
  }

  fn nat_zero() -> Expr {
    Expr::cnst(mk_name2("Nat", "zero"), vec![])
  }

  fn prop() -> Expr {
    Expr::sort(Level::zero())
  }

  fn type_u() -> Expr {
    Expr::sort(Level::param(mk_name("u")))
  }

  /// Build a minimal environment with Nat, Nat.zero, and Nat.succ.
  fn mk_nat_env() -> Env {
    let mut env = Env::default();

    let nat_name = mk_name("Nat");
    // Nat : Sort 1
    let nat = ConstantInfo::InductInfo(InductiveVal {
      cnst: ConstantVal {
        name: nat_name.clone(),
        level_params: vec![],
        typ: Expr::sort(Level::succ(Level::zero())),
      },
      num_params: Nat::from(0u64),
      num_indices: Nat::from(0u64),
      all: vec![nat_name.clone()],
      ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")],
      num_nested: Nat::from(0u64),
      is_rec: true,
      is_unsafe: false,
      is_reflexive: false,
    });
    env.insert(nat_name, nat);

    // Nat.zero : Nat
    let zero_name = mk_name2("Nat", "zero");
    let zero = ConstantInfo::CtorInfo(ConstructorVal {
      cnst: ConstantVal {
        name: zero_name.clone(),
        level_params: vec![],
        typ: nat_type(),
      },
      induct: mk_name("Nat"),
      cidx: Nat::from(0u64),
      num_params: Nat::from(0u64),
      num_fields: Nat::from(0u64),
      is_unsafe: false,
    });
    env.insert(zero_name, zero);

    // Nat.succ : Nat → Nat
    let succ_name = mk_name2("Nat", "succ");
    let succ_ty =
      Expr::all(mk_name("n"), nat_type(), nat_type(), BinderInfo::Default);
    let succ = ConstantInfo::CtorInfo(ConstructorVal {
      cnst: ConstantVal {
        name: succ_name.clone(),
        level_params: vec![],
        typ: succ_ty,
      },
      induct: mk_name("Nat"),
      cidx: Nat::from(1u64),
      num_params: Nat::from(0u64),
      num_fields: Nat::from(1u64),
      is_unsafe: false,
    });
    env.insert(succ_name, succ);

    env
  }

  // ==========================================================================
  // Infer: Sort
  // ==========================================================================

  #[test]
  fn infer_sort_zero() {
    // Sort(0) : Sort(1)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = prop();
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::sort(Level::succ(Level::zero())));
  }

  #[test]
  fn infer_sort_succ() {
    // Sort(1) : Sort(2)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::sort(Level::succ(Level::zero()));
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::sort(Level::succ(Level::succ(Level::zero()))));
  }

  #[test]
  fn infer_sort_param() {
    // Sort(u) : Sort(u+1)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let u = Level::param(mk_name("u"));
    let e = Expr::sort(u.clone());
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::sort(Level::succ(u)));
  }

  // ==========================================================================
  // Infer: Const
  // ==========================================================================

  #[test]
  fn infer_const_nat() {
    // Nat : Sort 1
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::cnst(mk_name("Nat"), vec![]);
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::sort(Level::succ(Level::zero())));
  }

  #[test]
  fn infer_const_nat_zero() {
    // Nat.zero : Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = nat_zero();
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  #[test]
  fn infer_const_nat_succ() {
    // Nat.succ : Nat → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let ty = tc.infer(&e).unwrap();
    let expected =
      Expr::all(mk_name("n"), nat_type(), nat_type(), BinderInfo::Default);
    assert_eq!(ty, expected);
  }

  #[test]
  fn infer_const_unknown() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::cnst(mk_name("NonExistent"), vec![]);
    assert!(tc.infer(&e).is_err());
  }

  #[test]
  fn infer_const_universe_mismatch() {
    // Nat has 0 universe params; passing 1 should fail
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::cnst(mk_name("Nat"), vec![Level::zero()]);
    assert!(tc.infer(&e).is_err());
  }

  // ==========================================================================
  // Infer: Lit
  // ==========================================================================

  #[test]
  fn infer_nat_lit() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::lit(Literal::NatVal(Nat::from(42u64)));
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  #[test]
  fn infer_string_lit() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::lit(Literal::StrVal("hello".into()));
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::cnst(mk_name("String"), vec![]));
  }

  // ==========================================================================
  // Infer: Lambda
  // ==========================================================================

  #[test]
  fn infer_identity_lambda() {
    // fun (x : Nat) => x : Nat → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let id_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let ty = tc.infer(&id_fn).unwrap();
    let expected =
      Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default);
    assert_eq!(ty, expected);
  }

  #[test]
  fn infer_const_lambda() {
    // fun (x : Nat) (y : Nat) => x : Nat → Nat → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let body = Expr::lam(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(1u64)), // x
      BinderInfo::Default,
    );
    let k_fn = Expr::lam(mk_name("x"), nat_type(), body, BinderInfo::Default);
    let ty = tc.infer(&k_fn).unwrap();
    // Nat → Nat → Nat
    let expected = Expr::all(
      mk_name("x"),
      nat_type(),
      Expr::all(mk_name("y"), nat_type(), nat_type(), BinderInfo::Default),
      BinderInfo::Default,
    );
    assert_eq!(ty, expected);
  }

  // ==========================================================================
  // Infer: App
  // ==========================================================================

  #[test]
  fn infer_app_succ_zero() {
    // Nat.succ Nat.zero : Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e =
      Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), nat_zero());
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  #[test]
  fn infer_app_identity() {
    // (fun x : Nat => x) Nat.zero : Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let id_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let e = Expr::app(id_fn, nat_zero());
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  // ==========================================================================
  // Infer: Pi
  // ==========================================================================

  #[test]
  fn infer_pi_nat_to_nat() {
    // (Nat → Nat) : Sort 1
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let pi =
      Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default);
    let ty = tc.infer(&pi).unwrap();
    // Sort(imax(1, 1)) which simplifies to Sort(1)
    if let ExprData::Sort(level, _) = ty.as_data() {
      assert!(
        super::super::level::eq_antisymm(
          level,
          &Level::succ(Level::zero())
        ),
        "Nat → Nat should live in Sort 1, got {:?}",
        level
      );
    } else {
      panic!("Expected Sort, got {:?}", ty);
    }
  }

  #[test]
  fn infer_pi_prop_to_prop() {
    // (Prop → Prop) : Sort 1
    // An axiom P : Prop, then P → P : Sort 1
    let mut env = Env::default();
    let p_name = mk_name("P");
    env.insert(
      p_name.clone(),
      ConstantInfo::AxiomInfo(AxiomVal {
        cnst: ConstantVal {
          name: p_name.clone(),
          level_params: vec![],
          typ: prop(),
        },
        is_unsafe: false,
      }),
    );

    let mut tc = TypeChecker::new(&env);
    let p = Expr::cnst(p_name, vec![]);
    let pi = Expr::all(mk_name("x"), p.clone(), p.clone(), BinderInfo::Default);
    let ty = tc.infer(&pi).unwrap();
    // Sort(imax(0, 0)) = Sort(0) = Prop
    if let ExprData::Sort(level, _) = ty.as_data() {
      assert!(
        super::super::level::is_zero(level),
        "Prop → Prop should live in Prop, got {:?}",
        level
      );
    } else {
      panic!("Expected Sort, got {:?}", ty);
    }
  }

  // ==========================================================================
  // Infer: Let
  // ==========================================================================

  #[test]
  fn infer_let_simple() {
    // let x : Nat := Nat.zero in x : Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::letE(
      mk_name("x"),
      nat_type(),
      nat_zero(),
      Expr::bvar(Nat::from(0u64)),
      false,
    );
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  // ==========================================================================
  // Infer: errors
  // ==========================================================================

  #[test]
  fn infer_free_bvar_fails() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::bvar(Nat::from(0u64));
    assert!(tc.infer(&e).is_err());
  }

  #[test]
  fn infer_fvar_fails() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::fvar(mk_name("x"));
    assert!(tc.infer(&e).is_err());
  }

  #[test]
  fn infer_app_wrong_arg_type() {
    // Nat.succ expects Nat, but we pass Sort(0) — should fail with DefEqFailure
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::app(
      Expr::cnst(mk_name2("Nat", "succ"), vec![]),
      prop(), // Sort(0), not Nat
    );
    assert!(tc.infer(&e).is_err());
  }

  #[test]
  fn infer_let_type_mismatch() {
    // let x : Nat → Nat := Nat.zero in x
    // Nat.zero : Nat, but annotation says Nat → Nat — should fail
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let nat_to_nat =
      Expr::all(mk_name("n"), nat_type(), nat_type(), BinderInfo::Default);
    let e = Expr::letE(
      mk_name("x"),
      nat_to_nat,
      nat_zero(),
      Expr::bvar(Nat::from(0u64)),
      false,
    );
    assert!(tc.infer(&e).is_err());
  }

  // ==========================================================================
  // check_declar
  // ==========================================================================

  #[test]
  fn check_axiom_declar() {
    // axiom myAxiom : Nat → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let ax_ty =
      Expr::all(mk_name("n"), nat_type(), nat_type(), BinderInfo::Default);
    let ax = ConstantInfo::AxiomInfo(AxiomVal {
      cnst: ConstantVal {
        name: mk_name("myAxiom"),
        level_params: vec![],
        typ: ax_ty,
      },
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ax).is_ok());
  }

  #[test]
  fn check_defn_declar() {
    // def myId : Nat → Nat := fun x => x
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let fun_ty =
      Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default);
    let body = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let defn = ConstantInfo::DefnInfo(DefinitionVal {
      cnst: ConstantVal {
        name: mk_name("myId"),
        level_params: vec![],
        typ: fun_ty,
      },
      value: body,
      hints: ReducibilityHints::Abbrev,
      safety: DefinitionSafety::Safe,
      all: vec![mk_name("myId")],
    });
    assert!(tc.check_declar(&defn).is_ok());
  }

  #[test]
  fn check_defn_type_mismatch() {
    // def bad : Nat := Nat.succ (wrong: Nat.succ : Nat → Nat, not Nat)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let defn = ConstantInfo::DefnInfo(DefinitionVal {
      cnst: ConstantVal {
        name: mk_name("bad"),
        level_params: vec![],
        typ: nat_type(),
      },
      value: Expr::cnst(mk_name2("Nat", "succ"), vec![]),
      hints: ReducibilityHints::Abbrev,
      safety: DefinitionSafety::Safe,
      all: vec![mk_name("bad")],
    });
    assert!(tc.check_declar(&defn).is_err());
  }

  #[test]
  fn check_declar_loose_bvar() {
    // Type with a dangling bound variable should fail
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let ax = ConstantInfo::AxiomInfo(AxiomVal {
      cnst: ConstantVal {
        name: mk_name("bad"),
        level_params: vec![],
        typ: Expr::bvar(Nat::from(0u64)),
      },
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ax).is_err());
  }

  #[test]
  fn check_declar_duplicate_uparams() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let u = mk_name("u");
    let ax = ConstantInfo::AxiomInfo(AxiomVal {
      cnst: ConstantVal {
        name: mk_name("bad"),
        level_params: vec![u.clone(), u],
        typ: type_u(),
      },
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ax).is_err());
  }

  // ==========================================================================
  // check_env
  // ==========================================================================

  #[test]
  fn check_nat_env() {
    let env = mk_nat_env();
    let errors = check_env(&env);
    assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors);
  }

  // ==========================================================================
  // Polymorphic constants
  // ==========================================================================

  #[test]
  fn infer_polymorphic_const() {
    // axiom A.{u} : Sort u
    // A.{0} should give Sort(0)
    let mut env = Env::default();
    let a_name = mk_name("A");
    let u_name = mk_name("u");
    env.insert(
      a_name.clone(),
      ConstantInfo::AxiomInfo(AxiomVal {
        cnst: ConstantVal {
          name: a_name.clone(),
          level_params: vec![u_name.clone()],
          typ: Expr::sort(Level::param(u_name)),
        },
        is_unsafe: false,
      }),
    );
    let mut tc = TypeChecker::new(&env);
    // A.{0} : Sort(0)
    let e = Expr::cnst(a_name, vec![Level::zero()]);
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, Expr::sort(Level::zero()));
  }

  // ==========================================================================
  // Infer: whnf caching
  // ==========================================================================

  #[test]
  fn whnf_cache_works() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::sort(Level::zero());
    let r1 = tc.whnf(&e);
    let r2 = tc.whnf(&e);
    assert_eq!(r1, r2);
  }

  // ==========================================================================
  // check_declar: Theorem
  // ==========================================================================

  #[test]
  fn check_theorem_declar() {
    // theorem myThm : Nat → Nat := fun x => x
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let fun_ty =
      Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default);
    let body = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let thm = ConstantInfo::ThmInfo(TheoremVal {
      cnst: ConstantVal {
        name: mk_name("myThm"),
        level_params: vec![],
        typ: fun_ty,
      },
      value: body,
      all: vec![mk_name("myThm")],
    });
    assert!(tc.check_declar(&thm).is_ok());
  }

  #[test]
  fn check_theorem_type_mismatch() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let thm = ConstantInfo::ThmInfo(TheoremVal {
      cnst: ConstantVal {
        name: mk_name("badThm"),
        level_params: vec![],
        typ: nat_type(), // claims : Nat
      },
      value: Expr::cnst(mk_name2("Nat", "succ"), vec![]), // but is : Nat → Nat
      all: vec![mk_name("badThm")],
    });
    assert!(tc.check_declar(&thm).is_err());
  }

  // ==========================================================================
  // check_declar: Opaque
  // ==========================================================================

  #[test]
  fn check_opaque_declar() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let opaque = ConstantInfo::OpaqueInfo(OpaqueVal {
      cnst: ConstantVal {
        name: mk_name("myOpaque"),
        level_params: vec![],
        typ: nat_type(),
      },
      value: nat_zero(),
      is_unsafe: false,
      all: vec![mk_name("myOpaque")],
    });
    assert!(tc.check_declar(&opaque).is_ok());
  }

  // ==========================================================================
  // check_declar: Ctor (parent existence check)
  // ==========================================================================

  #[test]
  fn check_ctor_missing_parent() {
    // A constructor whose parent inductive doesn't exist
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let ctor = ConstantInfo::CtorInfo(ConstructorVal {
      cnst: ConstantVal {
        name: mk_name2("Fake", "mk"),
        level_params: vec![],
        typ: Expr::sort(Level::succ(Level::zero())),
      },
      induct: mk_name("Fake"),
      cidx: Nat::from(0u64),
      num_params: Nat::from(0u64),
      num_fields: Nat::from(0u64),
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ctor).is_err());
  }

  #[test]
  fn check_ctor_with_parent() {
    // Nat.zero : Nat, with Nat in env
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let ctor = ConstantInfo::CtorInfo(ConstructorVal {
      cnst: ConstantVal {
        name: mk_name2("Nat", "zero"),
        level_params: vec![],
        typ: nat_type(),
      },
      induct: mk_name("Nat"),
      cidx: Nat::from(0u64),
      num_params: Nat::from(0u64),
      num_fields: Nat::from(0u64),
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ctor).is_ok());
  }

  // ==========================================================================
  // check_declar: Rec (mutual reference check)
  // ==========================================================================

  #[test]
  fn check_rec_missing_inductive() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let rec = ConstantInfo::RecInfo(RecursorVal {
      cnst: ConstantVal {
        name: mk_name2("Fake", "rec"),
        level_params: vec![],
        typ: Expr::sort(Level::succ(Level::zero())),
      },
      all: vec![mk_name("Fake")],
      num_params: Nat::from(0u64),
      num_indices: Nat::from(0u64),
      num_motives: Nat::from(1u64),
      num_minors: Nat::from(0u64),
      rules: vec![],
      k: false,
      is_unsafe: false,
    });
    assert!(tc.check_declar(&rec).is_err());
  }

  #[test]
  fn check_rec_with_inductive() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let rec = ConstantInfo::RecInfo(RecursorVal {
      cnst: ConstantVal {
        name: mk_name2("Nat", "rec"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::param(mk_name("u"))),
      },
      all: vec![mk_name("Nat")],
      num_params: Nat::from(0u64),
      num_indices: Nat::from(0u64),
      num_motives: Nat::from(1u64),
      num_minors: Nat::from(2u64),
      rules: vec![],
      k: false,
      is_unsafe: false,
    });
    assert!(tc.check_declar(&rec).is_ok());
  }

  // ==========================================================================
  // Infer: App with delta (definition in head)
  // ==========================================================================

  #[test]
  fn infer_app_through_delta() {
    // def myId : Nat → Nat := fun x => x
    // myId Nat.zero : Nat
    let mut env = mk_nat_env();
    let fun_ty =
      Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default);
    let body = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    env.insert(
      mk_name("myId"),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: mk_name("myId"),
          level_params: vec![],
          typ: fun_ty,
        },
        value: body,
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![mk_name("myId")],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let e = Expr::app(Expr::cnst(mk_name("myId"), vec![]), nat_zero());
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  // ==========================================================================
  // Infer: Proj
  // ==========================================================================

  /// Build an env with a simple Prod.{u,v} structure type.
  fn mk_prod_env() -> Env {
    let mut env = mk_nat_env();
    let u_name = mk_name("u");
    let v_name = mk_name("v");
    let prod_name = mk_name("Prod");
    let mk_name_prod = mk_name2("Prod", "mk");

    // Prod.{u,v} : Sort u → Sort v → Sort (max u v)
    // Simplified: Prod (α : Sort u) (β : Sort v) : Sort (max u v)
    let prod_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::sort(Level::max(
          Level::param(u_name.clone()),
          Level::param(v_name.clone()),
        )),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      prod_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: prod_name.clone(),
          level_params: vec![u_name.clone(), v_name.clone()],
          typ: prod_type,
        },
        num_params: Nat::from(2u64),
        num_indices: Nat::from(0u64),
        all: vec![prod_name.clone()],
        ctors: vec![mk_name_prod.clone()],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );

    // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod α β
    // Type: (α : Sort u) → (β : Sort v) → α → β → Prod α β
    let ctor_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::all(
          mk_name("fst"),
          Expr::bvar(Nat::from(1u64)), // α
          Expr::all(
            mk_name("snd"),
            Expr::bvar(Nat::from(1u64)), // β
            Expr::app(
              Expr::app(
                Expr::cnst(
                  prod_name.clone(),
                  vec![
                    Level::param(u_name.clone()),
                    Level::param(v_name.clone()),
                  ],
                ),
                Expr::bvar(Nat::from(3u64)), // α
              ),
              Expr::bvar(Nat::from(2u64)), // β
            ),
            BinderInfo::Default,
          ),
          BinderInfo::Default,
        ),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      mk_name_prod.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: mk_name_prod,
          level_params: vec![u_name, v_name],
          typ: ctor_type,
        },
        induct: prod_name,
        cidx: Nat::from(0u64),
        num_params: Nat::from(2u64),
        num_fields: Nat::from(2u64),
        is_unsafe: false,
      }),
    );

    env
  }

  #[test]
  fn infer_proj_fst() {
    // Given p : Prod Nat Nat, (Prod.1 p) : Nat
    // Build: Prod.mk Nat Nat Nat.zero Nat.zero, then project field 0
    let env = mk_prod_env();
    let mut tc = TypeChecker::new(&env);

    let one = Level::succ(Level::zero());
    let pair = Expr::app(
      Expr::app(
        Expr::app(
          Expr::app(
            Expr::cnst(mk_name2("Prod", "mk"), vec![one.clone(), one.clone()]),
            nat_type(),
          ),
          nat_type(),
        ),
        nat_zero(),
      ),
      nat_zero(),
    );

    let proj = Expr::proj(mk_name("Prod"), Nat::from(0u64), pair);
    let ty = tc.infer(&proj).unwrap();
    assert_eq!(ty, nat_type());
  }

  // ==========================================================================
  // Infer: nested let
  // ==========================================================================

  #[test]
  fn infer_nested_let() {
    // let x := Nat.zero in let y := x in y : Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let inner = Expr::letE(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)), // x
      Expr::bvar(Nat::from(0u64)), // y
      false,
    );
    let e = Expr::letE(mk_name("x"), nat_type(), nat_zero(), inner, false);
    let ty = tc.infer(&e).unwrap();
    assert_eq!(ty, nat_type());
  }

  // ==========================================================================
  // Infer caching
  // ==========================================================================

  #[test]
  fn infer_cache_hit() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = nat_zero();
    let ty1 = tc.infer(&e).unwrap();
    let ty2 = tc.infer(&e).unwrap();
    assert_eq!(ty1, ty2);
    assert_eq!(tc.infer_cache.len(), 1);
  }

  // ==========================================================================
  // Universe parameter validation
  // ==========================================================================

  #[test]
  fn check_axiom_undeclared_uparam_in_type() {
    // axiom bad.{u} : Sort v — v is not declared
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let ax = ConstantInfo::AxiomInfo(AxiomVal {
      cnst: ConstantVal {
        name: mk_name("bad"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::param(mk_name("v"))),
      },
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ax).is_err());
  }

  #[test]
  fn check_axiom_declared_uparam_in_type() {
    // axiom good.{u} : Sort u — u is declared
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let ax = ConstantInfo::AxiomInfo(AxiomVal {
      cnst: ConstantVal {
        name: mk_name("good"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::param(mk_name("u"))),
      },
      is_unsafe: false,
    });
    assert!(tc.check_declar(&ax).is_ok());
  }

  #[test]
  fn check_defn_undeclared_uparam_in_value() {
    // def bad.{u} : Sort 1 := Sort v — v not declared, in value
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let defn = ConstantInfo::DefnInfo(DefinitionVal {
      cnst: ConstantVal {
        name: mk_name("bad"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::succ(Level::zero())),
      },
      value: Expr::sort(Level::param(mk_name("v"))),
      hints: ReducibilityHints::Abbrev,
      safety: DefinitionSafety::Safe,
      all: vec![mk_name("bad")],
    });
    assert!(tc.check_declar(&defn).is_err());
  }

  // ==========================================================================
  // K-flag validation
  // ==========================================================================

  /// Build an env with a Prop inductive + single zero-field ctor (Eq-like).
  fn mk_eq_like_env() -> Env {
    let mut env = mk_nat_env();
    let u = mk_name("u");
    let eq_name = mk_name("MyEq");
    let eq_refl = mk_name2("MyEq", "refl");

    // MyEq.{u} (α : Sort u) (a : α) : α → Prop
    // Simplified: type lives in Prop (Sort 0)
    let eq_ty = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u.clone())),
      Expr::all(
        mk_name("a"),
        Expr::bvar(Nat::from(0u64)),
        Expr::all(
          mk_name("b"),
          Expr::bvar(Nat::from(1u64)),
          Expr::sort(Level::zero()),
          BinderInfo::Default,
        ),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );
    env.insert(
      eq_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: eq_name.clone(),
          level_params: vec![u.clone()],
          typ: eq_ty,
        },
        num_params: Nat::from(2u64),
        num_indices: Nat::from(1u64),
        all: vec![eq_name.clone()],
        ctors: vec![eq_refl.clone()],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: true,
      }),
    );
    // MyEq.refl.{u} (α : Sort u) (a : α) : MyEq α a a
    // zero fields
    let refl_ty = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u.clone())),
      Expr::all(
        mk_name("a"),
        Expr::bvar(Nat::from(0u64)),
        Expr::app(
          Expr::app(
            Expr::app(
              Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]),
              Expr::bvar(Nat::from(1u64)),
            ),
            Expr::bvar(Nat::from(0u64)),
          ),
          Expr::bvar(Nat::from(0u64)),
        ),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );
    env.insert(
      eq_refl.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: eq_refl,
          level_params: vec![u],
          typ: refl_ty,
        },
        induct: eq_name,
        cidx: Nat::from(0u64),
        num_params: Nat::from(2u64),
        num_fields: Nat::from(0u64),
        is_unsafe: false,
      }),
    );

    env
  }

  #[test]
  fn check_rec_k_flag_valid() {
    let env = mk_eq_like_env();
    let mut tc = TypeChecker::new(&env);
    let rec = ConstantInfo::RecInfo(RecursorVal {
      cnst: ConstantVal {
        name: mk_name2("MyEq", "rec"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::param(mk_name("u"))),
      },
      all: vec![mk_name("MyEq")],
      num_params: Nat::from(2u64),
      num_indices: Nat::from(1u64),
      num_motives: Nat::from(1u64),
      num_minors: Nat::from(1u64),
      rules: vec![],
      k: true,
      is_unsafe: false,
    });
    assert!(tc.check_declar(&rec).is_ok());
  }

  #[test]
  fn check_rec_k_flag_invalid_2_ctors() {
    // Nat has 2 constructors — K should fail
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let rec = ConstantInfo::RecInfo(RecursorVal {
      cnst: ConstantVal {
        name: mk_name2("Nat", "rec"),
        level_params: vec![mk_name("u")],
        typ: Expr::sort(Level::param(mk_name("u"))),
      },
      all: vec![mk_name("Nat")],
      num_params: Nat::from(0u64),
      num_indices: Nat::from(0u64),
      num_motives: Nat::from(1u64),
      num_minors: Nat::from(2u64),
      rules: vec![],
      k: true, // invalid: Nat is not in Prop and has 2 ctors
      is_unsafe: false,
    });
    assert!(tc.check_declar(&rec).is_err());
  }
}
diff --git a/src/ix/kernel/upcopy.rs b/src/ix/kernel/upcopy.rs
new file mode 100644
index 00000000..89dae8a0
--- /dev/null
+++ b/src/ix/kernel/upcopy.rs
@@ -0,0 +1,659 @@
use core::ptr::NonNull;

use crate::ix::env::{BinderInfo, Name};

use super::dag::*;
use super::dll::DLL;

// ============================================================================
// Upcopy
//
============================================================================ + +pub fn upcopy(new_child: DAGPtr, cc: ParentPtr) { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + let var = &lam.var; + let new_lam = alloc_lam(var.depth, new_child, None); + let new_lam_ref = &mut *new_lam.as_ptr(); + let bod_ref_ptr = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_child, bod_ref_ptr); + let new_var_ptr = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + for parent in DLL::iter_option(var.parents) { + upcopy(DAGPtr::Var(new_var_ptr), *parent); + } + for parent in DLL::iter_option(lam.parents) { + upcopy(DAGPtr::Lam(new_lam), *parent); + } + }, + ParentPtr::AppFun(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).fun = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(new_child, app.arg); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).arg = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(app.fun, new_child); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::FunDom(link) => { + let fun = &mut *link.as_ptr(); + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_child, + fun.img, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + // new_child must be a Lam + let new_lam = match new_child { 
+ DAGPtr::Lam(p) => p, + _ => panic!("FunImg parent expects Lam child"), + }; + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + new_lam, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::PiDom(link) => { + let pi = &mut *link.as_ptr(); + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_child, + pi.img, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("PiImg parent expects Lam child"), + }; + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + pi.dom, + new_lam, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::LetTyp(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).typ = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + new_child, + let_node.val, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetVal(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).val = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + 
let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + new_child, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("LetBod parent expects Lam child"), + }; + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).bod = new_lam; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + let_node.val, + new_lam, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + let new_proj = alloc_proj_no_uplinks( + proj.type_name.clone(), + proj.idx.clone(), + new_child, + ); + for parent in DLL::iter_option(proj.parents) { + upcopy(DAGPtr::Proj(new_proj), *parent); + } + }, + } + } +} + +// ============================================================================ +// No-uplink allocators for upcopy +// ============================================================================ + +fn alloc_app_no_uplinks(fun: DAGPtr, arg: DAGPtr) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +fn alloc_fun_no_uplinks( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: 
None,
    parents: None,
  });
  unsafe {
    let fun = &mut *fun_ptr.as_ptr();
    fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr));
    fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr));
  }
  fun_ptr
}

/// Allocate a `Pi` node whose `dom_ref`/`img_ref` uplink cells point back at
/// the node itself. The children's parent lists are NOT updated here; callers
/// add the uplinks separately (hence `_no_uplinks`).
fn alloc_pi_no_uplinks(
  binder_name: Name,
  binder_info: BinderInfo,
  dom: DAGPtr,
  img: NonNull<Lam>,
) -> NonNull<Pi> {
  let pi_ptr = alloc_val(Pi {
    binder_name,
    binder_info,
    dom,
    img,
    // Placeholder uplinks; patched below once `pi_ptr` is known.
    dom_ref: DLL::singleton(ParentPtr::Root),
    img_ref: DLL::singleton(ParentPtr::Root),
    copy: None,
    parents: None,
  });
  unsafe {
    let pi = &mut *pi_ptr.as_ptr();
    pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr));
    pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr));
  }
  pi_ptr
}

/// Allocate a `Let` node with self-referential uplink cells (no child uplinks).
fn alloc_let_no_uplinks(
  binder_name: Name,
  non_dep: bool,
  typ: DAGPtr,
  val: DAGPtr,
  bod: NonNull<Lam>,
) -> NonNull<LetNode> {
  let let_ptr = alloc_val(LetNode {
    binder_name,
    non_dep,
    typ,
    val,
    bod,
    // Placeholder uplinks; patched below once `let_ptr` is known.
    typ_ref: DLL::singleton(ParentPtr::Root),
    val_ref: DLL::singleton(ParentPtr::Root),
    bod_ref: DLL::singleton(ParentPtr::Root),
    copy: None,
    parents: None,
  });
  unsafe {
    let let_node = &mut *let_ptr.as_ptr();
    let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr));
    let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr));
    let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr));
  }
  let_ptr
}

/// Allocate a `Proj` node with a self-referential `expr_ref` uplink cell.
fn alloc_proj_no_uplinks(
  type_name: Name,
  idx: crate::lean::nat::Nat,
  expr: DAGPtr,
) -> NonNull<ProjNode> {
  let proj_ptr = alloc_val(ProjNode {
    type_name,
    idx,
    expr,
    expr_ref: DLL::singleton(ParentPtr::Root),
    parents: None,
  });
  unsafe {
    let proj = &mut *proj_ptr.as_ptr();
    proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr));
  }
  proj_ptr
}

// ============================================================================
// Clean up: Clear copy caches after reduction
// ============================================================================

/// Walk upward from a parent pointer, clearing the `copy` caches left behind
/// by upcopy and installing the copies' child uplinks via `add_to_parents`.
///
/// Recursion stops at nodes whose `copy` is already `None` (their ancestors
/// were never copied, so there is nothing above to clean).
pub fn clean_up(cc: &ParentPtr) {
  unsafe {
    match cc {
      ParentPtr::Root => {},
      ParentPtr::LamBod(link) => {
        let lam = &*link.as_ptr();
        // A Lam carries no copy cache of its own: propagate the cleanup
        // through both the bound variable's parents and the Lam's parents.
        for parent in DLL::iter_option(lam.var.parents) {
          clean_up(parent);
        }
        for parent in DLL::iter_option(lam.parents) {
          clean_up(parent);
        }
      },
      ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => {
        let app = &mut *link.as_ptr();
        if let Some(app_copy) = app.copy {
          let App { fun, arg, fun_ref, arg_ref, .. } =
            &mut *app_copy.as_ptr();
          app.copy = None;
          add_to_parents(*fun, NonNull::new(fun_ref).unwrap());
          add_to_parents(*arg, NonNull::new(arg_ref).unwrap());
          for parent in DLL::iter_option(app.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => {
        let fun = &mut *link.as_ptr();
        if let Some(fun_copy) = fun.copy {
          let Fun { dom, img, dom_ref, img_ref, .. } =
            &mut *fun_copy.as_ptr();
          fun.copy = None;
          add_to_parents(*dom, NonNull::new(dom_ref).unwrap());
          add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap());
          for parent in DLL::iter_option(fun.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => {
        let pi = &mut *link.as_ptr();
        if let Some(pi_copy) = pi.copy {
          let Pi { dom, img, dom_ref, img_ref, .. } =
            &mut *pi_copy.as_ptr();
          pi.copy = None;
          add_to_parents(*dom, NonNull::new(dom_ref).unwrap());
          add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap());
          for parent in DLL::iter_option(pi.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::LetTyp(link)
      | ParentPtr::LetVal(link)
      | ParentPtr::LetBod(link) => {
        let let_node = &mut *link.as_ptr();
        if let Some(let_copy) = let_node.copy {
          let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } =
            &mut *let_copy.as_ptr();
          let_node.copy = None;
          add_to_parents(*typ, NonNull::new(typ_ref).unwrap());
          add_to_parents(*val, NonNull::new(val_ref).unwrap());
          add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap());
          for parent in DLL::iter_option(let_node.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::ProjExpr(link) => {
        // Proj has no copy cache; just keep walking upward.
        let proj = &*link.as_ptr();
        for parent in DLL::iter_option(proj.parents) {
          clean_up(parent);
        }
      },
    }
  }
}

// ============================================================================
// Replace child
// ============================================================================

/// Redirect every parent of `old` to point at `new`, then move `old`'s whole
/// parent list onto `new` (concatenating with any parents `new` already had).
/// `old` is left with no parents; it is NOT freed here.
pub fn replace_child(old: DAGPtr, new: DAGPtr) {
  unsafe {
    if let Some(parents) = get_parents(old) {
      for parent in DLL::iter_option(Some(parents)) {
        match parent {
          ParentPtr::Root => {},
          ParentPtr::LamBod(p) => (*p.as_ptr()).bod = new,
          ParentPtr::FunDom(p) => (*p.as_ptr()).dom = new,
          // Fun/Pi/Let bodies are structurally Lam nodes; anything else
          // violates the DAG invariant.
          ParentPtr::FunImg(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam,
            _ => panic!("FunImg expects Lam"),
          },
          ParentPtr::PiDom(p) => (*p.as_ptr()).dom = new,
          ParentPtr::PiImg(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam,
            _ => panic!("PiImg expects Lam"),
          },
          ParentPtr::AppFun(p) => (*p.as_ptr()).fun = new,
          ParentPtr::AppArg(p) => (*p.as_ptr()).arg = new,
          ParentPtr::LetTyp(p) => (*p.as_ptr()).typ = new,
          ParentPtr::LetVal(p) => (*p.as_ptr()).val = new,
          ParentPtr::LetBod(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).bod = lam,
            _ => panic!("LetBod expects Lam"),
          },
          ParentPtr::ProjExpr(p) => (*p.as_ptr()).expr = new,
        }
      }
      set_parents(old, None);
      match get_parents(new) {
        None => set_parents(new, Some(parents)),
        Some(new_parents) => {
          DLL::concat(new_parents, Some(parents));
        },
      }
    }
  }
}

// ============================================================================
// Free dead nodes
// ============================================================================

/// Unlink the uplink cell `slot` from `child`'s parent list. If that was the
/// last uplink, `child` is dead and is freed recursively.
unsafe fn release_child(slot: *const Parents, child: DAGPtr) {
  if let Some(remaining) = (*slot).unlink_node() {
    set_parents(child, Some(remaining));
  } else {
    set_parents(child, None);
    free_dead_node(child);
  }
}

/// Free a node that has no remaining parents, first releasing (and possibly
/// freeing) each of its children.
pub fn free_dead_node(node: DAGPtr) {
  unsafe {
    match node {
      DAGPtr::Lam(link) => {
        let lam = &*link.as_ptr();
        release_child(&lam.bod_ref as *const Parents, lam.bod);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::App(link) => {
        let app = &*link.as_ptr();
        release_child(&app.fun_ref as *const Parents, app.fun);
        release_child(&app.arg_ref as *const Parents, app.arg);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Fun(link) => {
        let fun = &*link.as_ptr();
        release_child(&fun.dom_ref as *const Parents, fun.dom);
        release_child(&fun.img_ref as *const Parents, DAGPtr::Lam(fun.img));
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Pi(link) => {
        let pi = &*link.as_ptr();
        release_child(&pi.dom_ref as *const Parents, pi.dom);
        release_child(&pi.img_ref as *const Parents, DAGPtr::Lam(pi.img));
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Let(link) => {
        let let_node = &*link.as_ptr();
        release_child(&let_node.typ_ref as *const Parents, let_node.typ);
        release_child(&let_node.val_ref as *const Parents, let_node.val);
        release_child(
          &let_node.bod_ref as *const Parents,
          DAGPtr::Lam(let_node.bod),
        );
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Proj(link) => {
        let proj = &*link.as_ptr();
        release_child(&proj.expr_ref as *const Parents, proj.expr);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Var(link) => {
        // Bound variables are owned by their binder and freed with it; only
        // free-standing variables are freed here.
        let var = &*link.as_ptr();
        if let BinderPtr::Free = var.binder {
          drop(Box::from_raw(link.as_ptr()));
        }
      },
      DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())),
    }
  }
}

// ============================================================================
// Lambda reduction
// ============================================================================

/// Contract a lambda redex: (Fun dom (Lam bod var)) arg → [arg/var]bod.
+pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { + unsafe { + let app = &*redex.as_ptr(); + let lambda = &*lam.as_ptr(); + let var = &lambda.var; + let arg = app.arg; + + if DLL::is_singleton(lambda.parents) { + if DLL::is_empty(var.parents) { + return lambda.bod; + } + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + return lambda.bod; + } + + if DLL::is_empty(var.parents) { + return lambda.bod; + } + + // General case: upcopy arg through var's parents + for parent in DLL::iter_option(var.parents) { + upcopy(arg, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); + } + lambda.bod + } +} + +/// Contract a let redex: Let(typ, val, Lam(bod, var)) → [val/var]bod. +pub fn reduce_let(let_node: NonNull) -> DAGPtr { + unsafe { + let ln = &*let_node.as_ptr(); + let lam = &*ln.bod.as_ptr(); + let var = &lam.var; + let val = ln.val; + + if DLL::is_singleton(lam.parents) { + if DLL::is_empty(var.parents) { + return lam.bod; + } + replace_child(DAGPtr::Var(NonNull::from(var)), val); + return lam.bod; + } + + if DLL::is_empty(var.parents) { + return lam.bod; + } + + for parent in DLL::iter_option(var.parents) { + upcopy(val, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); + } + lam.bod + } +} diff --git a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs new file mode 100644 index 00000000..4fdde07a --- /dev/null +++ b/src/ix/kernel/whnf.rs @@ -0,0 +1,1420 @@ +use core::ptr::NonNull; + +use crate::ix::env::*; +use crate::lean::nat::Nat; +use num_bigint::BigUint; + +use super::convert::{from_expr, to_expr}; +use super::dag::*; +use super::level::{simplify, subst_level}; +use super::upcopy::{reduce_lam, reduce_let}; + + +// ============================================================================ +// Expression helpers (inst, unfold_apps, foldl_apps, subst_expr_levels) +// ============================================================================ + +/// Instantiate bound variables: `body[0 := 
substs[0], 1 := substs[1], ...]`. +/// `substs[0]` replaces `Bvar(0)` (innermost). +pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { + if substs.is_empty() { + return body.clone(); + } + inst_aux(body, substs, 0) +} + +fn inst_aux(e: &Expr, substs: &[Expr], offset: u64) -> Expr { + match e.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 >= offset { + let adjusted = (idx_u64 - offset) as usize; + if adjusted < substs.len() { + return substs[adjusted].clone(); + } + } + e.clone() + }, + ExprData::App(f, a, _) => { + let f2 = inst_aux(f, substs, offset); + let a2 = inst_aux(a, substs, offset); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = inst_aux(t, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = inst_aux(t, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = inst_aux(t, substs, offset); + let v2 = inst_aux(v, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = inst_aux(s, substs, offset); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = inst_aux(inner, substs, offset); + Expr::mdata(kvs.clone(), inner2) + }, + // Terminals with no bound vars + ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Abstract: replace free variable `fvar` with `Bvar(offset)` in `e`. +pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { + if fvars.is_empty() { + return e.clone(); + } + abstr_aux(e, fvars, 0) +} + +fn abstr_aux(e: &Expr, fvars: &[Expr], offset: u64) -> Expr { + match e.as_data() { + ExprData::Fvar(..) 
=> { + for (i, fv) in fvars.iter().enumerate().rev() { + if e == fv { + return Expr::bvar(Nat::from(i as u64 + offset)); + } + } + e.clone() + }, + ExprData::App(f, a, _) => { + let f2 = abstr_aux(f, fvars, offset); + let a2 = abstr_aux(a, fvars, offset); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = abstr_aux(t, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = abstr_aux(t, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = abstr_aux(t, fvars, offset); + let v2 = abstr_aux(v, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = abstr_aux(s, fvars, offset); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = abstr_aux(inner, fvars, offset); + Expr::mdata(kvs.clone(), inner2) + }, + ExprData::Bvar(..) + | ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])`. +pub fn unfold_apps(e: &Expr) -> (Expr, Vec) { + let mut args = Vec::new(); + let mut cursor = e.clone(); + loop { + match cursor.as_data() { + ExprData::App(f, a, _) => { + args.push(a.clone()); + cursor = f.clone(); + }, + _ => break, + } + } + args.reverse(); + (cursor, args) +} + +/// Reconstruct `f a1 a2 ... an`. +pub fn foldl_apps(mut fun: Expr, args: impl Iterator) -> Expr { + for arg in args { + fun = Expr::app(fun, arg); + } + fun +} + +/// Substitute universe level parameters in an expression. 
+pub fn subst_expr_levels( + e: &Expr, + params: &[Name], + values: &[Level], +) -> Expr { + if params.is_empty() { + return e.clone(); + } + subst_expr_levels_aux(e, params, values) +} + +fn subst_expr_levels_aux( + e: &Expr, + params: &[Name], + values: &[Level], +) -> Expr { + match e.as_data() { + ExprData::Sort(level, _) => { + Expr::sort(subst_level(level, params, values)) + }, + ExprData::Const(name, levels, _) => { + let new_levels: Vec = + levels.iter().map(|l| subst_level(l, params, values)).collect(); + Expr::cnst(name.clone(), new_levels) + }, + ExprData::App(f, a, _) => { + let f2 = subst_expr_levels_aux(f, params, values); + let a2 = subst_expr_levels_aux(a, params, values); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let v2 = subst_expr_levels_aux(v, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = subst_expr_levels_aux(s, params, values); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = subst_expr_levels_aux(inner, params, values); + Expr::mdata(kvs.clone(), inner2) + }, + // No levels to substitute + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Check if an expression has any loose bound variables above `offset`. 
+pub fn has_loose_bvars(e: &Expr) -> bool { + has_loose_bvars_aux(e, 0) +} + +fn has_loose_bvars_aux(e: &Expr, depth: u64) -> bool { + match e.as_data() { + ExprData::Bvar(idx, _) => idx.to_u64().unwrap_or(u64::MAX) >= depth, + ExprData::App(f, a, _) => { + has_loose_bvars_aux(f, depth) || has_loose_bvars_aux(a, depth) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_loose_bvars_aux(t, depth) || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_loose_bvars_aux(t, depth) + || has_loose_bvars_aux(v, depth) + || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::Proj(_, _, s, _) => has_loose_bvars_aux(s, depth), + ExprData::Mdata(_, inner, _) => has_loose_bvars_aux(inner, depth), + _ => false, + } +} + +/// Check if expression contains any free variables (Fvar). +pub fn has_fvars(e: &Expr) -> bool { + match e.as_data() { + ExprData::Fvar(..) => true, + ExprData::App(f, a, _) => has_fvars(f) || has_fvars(a), + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_fvars(t) || has_fvars(b) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_fvars(t) || has_fvars(v) || has_fvars(b) + }, + ExprData::Proj(_, _, s, _) => has_fvars(s), + ExprData::Mdata(_, inner, _) => has_fvars(inner), + _ => false, + } +} + +// ============================================================================ +// Name helpers +// ============================================================================ + +pub(crate) fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) +} + +// ============================================================================ +// WHNF +// ============================================================================ + +/// Weak head normal form reduction. 
+/// +/// Uses DAG-based reduction internally: converts Expr to DAG, reduces using +/// BUBS (reduce_lam/reduce_let) for beta/zeta, falls back to Expr level for +/// iota/quot/nat/projection, and uses DAG-level splicing for delta. +pub fn whnf(e: &Expr, env: &Env) -> Expr { + let mut dag = from_expr(e); + whnf_dag(&mut dag, env); + let result = to_expr(&dag); + free_dag(dag); + result +} + +/// Trail-based WHNF on DAG. Walks down the App spine collecting a trail, +/// then dispatches on the head node. +fn whnf_dag(dag: &mut DAG, env: &Env) { + loop { + // Build trail of App nodes by walking down the fun chain + let mut trail: Vec> = Vec::new(); + let mut cursor = dag.head; + + loop { + match cursor { + DAGPtr::App(app) => { + trail.push(app); + cursor = unsafe { (*app.as_ptr()).fun }; + }, + _ => break, + } + } + + match cursor { + // Beta: Fun at head with args on trail + DAGPtr::Fun(fun_ptr) if !trail.is_empty() => { + let app = trail.pop().unwrap(); + let lam = unsafe { (*fun_ptr.as_ptr()).img }; + let result = reduce_lam(app, lam); + set_dag_head(dag, result, &trail); + continue; + }, + + // Zeta: Let at head + DAGPtr::Let(let_ptr) => { + let result = reduce_let(let_ptr); + set_dag_head(dag, result, &trail); + continue; + }, + + // Const: try iota, quot, nat, then delta + DAGPtr::Cnst(_) => { + // Try iota, quot, nat at Expr level + if try_expr_reductions(dag, env) { + continue; + } + // Try delta (definition unfolding) on DAG + if try_dag_delta(dag, &trail, env) { + continue; + } + return; // stuck + }, + + // Proj: try projection reduction (Expr-level fallback) + DAGPtr::Proj(_) => { + if try_expr_reductions(dag, env) { + continue; + } + return; // stuck + }, + + // Sort: simplify level in place + DAGPtr::Sort(sort_ptr) => { + unsafe { + let sort = &mut *sort_ptr.as_ptr(); + sort.level = simplify(&sort.level); + } + return; + }, + + // Mdata: strip metadata (Expr-level fallback) + DAGPtr::Lit(_) => { + // Check if this is a Nat literal that could be a 
Nat.succ application + // by trying Expr-level reductions (which handles nat ops) + if !trail.is_empty() { + if try_expr_reductions(dag, env) { + continue; + } + } + return; + }, + + // Everything else (Var, Pi, Lam without args, etc.): already WHNF + _ => return, + } + } +} + +/// Set the DAG head after a reduction step. +/// If trail is empty, the result becomes the new head. +/// If trail is non-empty, splice result into the innermost remaining App. +fn set_dag_head( + dag: &mut DAG, + result: DAGPtr, + trail: &[NonNull], +) { + if trail.is_empty() { + dag.head = result; + } else { + unsafe { + (*trail.last().unwrap().as_ptr()).fun = result; + } + dag.head = DAGPtr::App(trail[0]); + } +} + +/// Try iota/quot/nat/projection reductions at Expr level. +/// Converts current DAG to Expr, attempts reduction, converts back if +/// successful. +fn try_expr_reductions(dag: &mut DAG, env: &Env) -> bool { + let current_expr = to_expr(&DAG { head: dag.head }); + + let (head, args) = unfold_apps(¤t_expr); + + let reduced = match head.as_data() { + ExprData::Const(name, levels, _) => { + // Try iota (recursor) reduction + if let Some(result) = try_reduce_rec(name, levels, &args, env) { + Some(result) + } + // Try quotient reduction + else if let Some(result) = try_reduce_quot(name, &args, env) { + Some(result) + } + // Try nat reduction + else if let Some(result) = + try_reduce_nat(¤t_expr, env) + { + Some(result) + } else { + None + } + }, + ExprData::Proj(type_name, idx, structure, _) => { + reduce_proj(type_name, idx, structure, env) + .map(|result| foldl_apps(result, args.into_iter())) + }, + ExprData::Mdata(_, inner, _) => { + Some(foldl_apps(inner.clone(), args.into_iter())) + }, + _ => None, + }; + + if let Some(result_expr) = reduced { + let result_dag = from_expr(&result_expr); + dag.head = result_dag.head; + true + } else { + false + } +} + +/// Try delta (definition) unfolding on DAG. 
+/// Looks up the constant, substitutes universe levels in the definition body, +/// converts it to a DAG, and splices it into the current DAG. +fn try_dag_delta( + dag: &mut DAG, + trail: &[NonNull], + env: &Env, +) -> bool { + // Extract constant info from head + let cnst_ref = match dag_head_past_trail(dag, trail) { + DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, + _ => return false, + }; + + let ci = match env.get(&cnst_ref.name) { + Some(c) => c, + None => return false, + }; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) + if d.hints != ReducibilityHints::Opaque => + { + (&d.cnst.level_params, &d.value) + }, + _ => return false, + }; + + if cnst_ref.levels.len() != def_params.len() { + return false; + } + + // Substitute levels at Expr level, then convert to DAG + let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); + let body_dag = from_expr(&val); + + // Splice body into the working DAG + set_dag_head(dag, body_dag.head, trail); + true +} + +/// Get the head node past the trail (the non-App node at the bottom). +fn dag_head_past_trail( + dag: &DAG, + trail: &[NonNull], +) -> DAGPtr { + if trail.is_empty() { + dag.head + } else { + unsafe { (*trail.last().unwrap().as_ptr()).fun } + } +} + +/// Try to unfold a definition at the head. +pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { + let (head, args) = unfold_apps(e); + let (name, levels) = match head.as_data() { + ExprData::Const(name, levels, _) => (name, levels), + _ => return None, + }; + + let ci = env.get(name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + _ => return None, + }; + + if levels.len() != def_params.len() { + return None; + } + + let val = subst_expr_levels(def_value, def_params, levels); + Some(foldl_apps(val, args.into_iter())) +} + +/// Try to reduce a recursor application (iota reduction). 
+fn try_reduce_rec( + name: &Name, + levels: &[Level], + args: &[Expr], + env: &Env, +) -> Option { + let ci = env.get(name)?; + let rec = match ci { + ConstantInfo::RecInfo(r) => r, + _ => return None, + }; + + let major_idx = rec.num_params.to_u64().unwrap() as usize + + rec.num_motives.to_u64().unwrap() as usize + + rec.num_minors.to_u64().unwrap() as usize + + rec.num_indices.to_u64().unwrap() as usize; + + let major = args.get(major_idx)?; + + // WHNF the major premise + let major_whnf = whnf(major, env); + + // Handle nat literal → constructor + let major_ctor = match major_whnf.as_data() { + ExprData::Lit(Literal::NatVal(n), _) => nat_lit_to_constructor(n), + _ => major_whnf.clone(), + }; + + let (ctor_head, ctor_args) = unfold_apps(&major_ctor); + + // Find the matching rec rule + let ctor_name = match ctor_head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + let rule = rec.rules.iter().find(|r| &r.ctor == ctor_name)?; + + let n_fields = rule.n_fields.to_u64().unwrap() as usize; + let num_params = rec.num_params.to_u64().unwrap() as usize; + let num_motives = rec.num_motives.to_u64().unwrap() as usize; + let num_minors = rec.num_minors.to_u64().unwrap() as usize; + + // The constructor args may have extra params for nested inductives + let ctor_args_wo_params = + if ctor_args.len() >= n_fields { + &ctor_args[ctor_args.len() - n_fields..] 
+ } else { + return None; + }; + + // Substitute universe levels in the rule's RHS + let rhs = subst_expr_levels( + &rule.rhs, + &rec.cnst.level_params, + levels, + ); + + // Apply: params, motives, minors + let prefix_count = num_params + num_motives + num_minors; + let mut result = rhs; + for arg in args.iter().take(prefix_count) { + result = Expr::app(result, arg.clone()); + } + + // Apply constructor fields + for arg in ctor_args_wo_params { + result = Expr::app(result, arg.clone()); + } + + // Apply remaining args after major + for arg in args.iter().skip(major_idx + 1) { + result = Expr::app(result, arg.clone()); + } + + Some(result) +} + +/// Convert a Nat literal to its constructor form. +fn nat_lit_to_constructor(n: &Nat) -> Expr { + if n.0 == BigUint::ZERO { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } else { + let pred = Nat(n.0.clone() - BigUint::from(1u64)); + let pred_expr = Expr::lit(Literal::NatVal(pred)); + Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), pred_expr) + } +} + +/// Convert a string literal to its constructor form: +/// `"hello"` → `String.mk (List.cons 'h' (List.cons 'e' ... List.nil))` +/// where chars are represented as `Char.ofNat n`. 
+fn string_lit_to_constructor(s: &str) -> Expr { + let list_name = Name::str(Name::anon(), "List".into()); + let char_name = Name::str(Name::anon(), "Char".into()); + let char_type = Expr::cnst(char_name.clone(), vec![]); + + // Build the list from right to left + // List.nil.{0} : List Char + let nil = Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "nil".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ); + + let result = s.chars().rev().fold(nil, |acc, c| { + let char_val = Expr::app( + Expr::cnst(Name::str(char_name.clone(), "ofNat".into()), vec![]), + Expr::lit(Literal::NatVal(Nat::from(c as u64))), + ); + // List.cons.{0} Char char_val acc + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "cons".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ), + char_val, + ), + acc, + ) + }); + + // String.mk list + Expr::app( + Expr::cnst( + Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), + vec![], + ), + result, + ) +} + +/// Try to reduce a projection. +fn reduce_proj( + _type_name: &Name, + idx: &Nat, + structure: &Expr, + env: &Env, +) -> Option { + let structure_whnf = whnf(structure, env); + + // Handle string literal → constructor + let structure_ctor = match structure_whnf.as_data() { + ExprData::Lit(Literal::StrVal(s), _) => { + string_lit_to_constructor(s) + }, + _ => structure_whnf, + }; + + let (ctor_head, ctor_args) = unfold_apps(&structure_ctor); + + let ctor_name = match ctor_head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + // Look up constructor to get num_params + let ci = env.get(ctor_name)?; + let num_params = match ci { + ConstantInfo::CtorInfo(c) => c.num_params.to_u64().unwrap() as usize, + _ => return None, + }; + + let field_idx = num_params + idx.to_u64().unwrap() as usize; + ctor_args.get(field_idx).cloned() +} + +/// Try to reduce a quotient operation. 
+fn try_reduce_quot( + name: &Name, + args: &[Expr], + env: &Env, +) -> Option { + let ci = env.get(name)?; + let kind = match ci { + ConstantInfo::QuotInfo(q) => &q.kind, + _ => return None, + }; + + let (qmk_idx, rest_idx) = match kind { + QuotKind::Lift => (5, 6), + QuotKind::Ind => (4, 5), + _ => return None, + }; + + let qmk = args.get(qmk_idx)?; + let qmk_whnf = whnf(qmk, env); + + // Check that the head is Quot.mk + let (qmk_head, _) = unfold_apps(&qmk_whnf); + match qmk_head.as_data() { + ExprData::Const(n, _, _) if *n == mk_name2("Quot", "mk") => {}, + _ => return None, + } + + let f = args.get(3)?; + + // Extract the argument of Quot.mk + let qmk_arg = match qmk_whnf.as_data() { + ExprData::App(_, arg, _) => arg, + _ => return None, + }; + + let mut result = Expr::app(f.clone(), qmk_arg.clone()); + for arg in args.iter().skip(rest_idx) { + result = Expr::app(result, arg.clone()); + } + + Some(result) +} + +/// Try to reduce nat operations. +fn try_reduce_nat(e: &Expr, env: &Env) -> Option { + if has_fvars(e) { + return None; + } + + let (head, args) = unfold_apps(e); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + match args.len() { + 1 => { + if *name == mk_name2("Nat", "succ") { + let arg_whnf = whnf(&args[0], env); + let n = get_nat_value(&arg_whnf)?; + Some(Expr::lit(Literal::NatVal(Nat(n + BigUint::from(1u64))))) + } else { + None + } + }, + 2 => { + let a_whnf = whnf(&args[0], env); + let b_whnf = whnf(&args[1], env); + let a = get_nat_value(&a_whnf)?; + let b = get_nat_value(&b_whnf)?; + + let result = if *name == mk_name2("Nat", "add") { + Some(Expr::lit(Literal::NatVal(Nat(a + b)))) + } else if *name == mk_name2("Nat", "sub") { + Some(Expr::lit(Literal::NatVal(Nat(if a >= b { + a - b + } else { + BigUint::ZERO + })))) + } else if *name == mk_name2("Nat", "mul") { + Some(Expr::lit(Literal::NatVal(Nat(a * b)))) + } else if *name == mk_name2("Nat", "div") { + 
Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + BigUint::ZERO + } else { + a / b + })))) + } else if *name == mk_name2("Nat", "mod") { + Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + a + } else { + a % b + })))) + } else if *name == mk_name2("Nat", "beq") { + bool_to_expr(a == b) + } else if *name == mk_name2("Nat", "ble") { + bool_to_expr(a <= b) + } else if *name == mk_name2("Nat", "pow") { + let exp = u32::try_from(&b).unwrap_or(u32::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a.pow(exp))))) + } else if *name == mk_name2("Nat", "land") { + Some(Expr::lit(Literal::NatVal(Nat(a & b)))) + } else if *name == mk_name2("Nat", "lor") { + Some(Expr::lit(Literal::NatVal(Nat(a | b)))) + } else if *name == mk_name2("Nat", "xor") { + Some(Expr::lit(Literal::NatVal(Nat(a ^ b)))) + } else if *name == mk_name2("Nat", "shiftLeft") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a << shift)))) + } else if *name == mk_name2("Nat", "shiftRight") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a >> shift)))) + } else if *name == mk_name2("Nat", "blt") { + bool_to_expr(a < b) + } else { + None + }; + result + }, + _ => None, + } +} + +fn get_nat_value(e: &Expr) -> Option { + match e.as_data() { + ExprData::Lit(Literal::NatVal(n), _) => Some(n.0.clone()), + ExprData::Const(name, _, _) if *name == mk_name2("Nat", "zero") => { + Some(BigUint::ZERO) + }, + _ => None, + } +} + +fn bool_to_expr(b: bool) -> Option { + let name = if b { + mk_name2("Bool", "true") + } else { + mk_name2("Bool", "false") + }; + Some(Expr::cnst(name, vec![])) +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn nat_type() -> Expr { + 
Expr::cnst(mk_name("Nat"), vec![])
  }

  fn nat_zero() -> Expr {
    Expr::cnst(mk_name2("Nat", "zero"), vec![])
  }

  #[test]
  fn test_inst_bvar() {
    // inst replaces Bvar(0) with substs[0].
    let body = Expr::bvar(Nat::from(0));
    let arg = nat_zero();
    let result = inst(&body, &[arg.clone()]);
    assert_eq!(result, arg);
  }

  #[test]
  fn test_inst_nested() {
    // body = Lam(_, Nat, Bvar(1)) — references outer binder
    // After inst with [zero], should become Lam(_, Nat, zero)
    let body = Expr::lam(
      Name::anon(),
      nat_type(),
      Expr::bvar(Nat::from(1)),
      BinderInfo::Default,
    );
    let result = inst(&body, &[nat_zero()]);
    let expected = Expr::lam(
      Name::anon(),
      nat_type(),
      nat_zero(),
      BinderInfo::Default,
    );
    assert_eq!(result, expected);
  }

  #[test]
  fn test_unfold_apps() {
    // (f a b) decomposes into head f and args [a, b] in application order.
    let f = Expr::cnst(mk_name("f"), vec![]);
    let a = nat_zero();
    let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone());
    let (head, args) = unfold_apps(&e);
    assert_eq!(head, f);
    assert_eq!(args.len(), 2);
    assert_eq!(args[0], a);
    assert_eq!(args[1], b);
  }

  #[test]
  fn test_beta_reduce_identity() {
    // (fun x : Nat => x) Nat.zero
    let id = Expr::lam(
      Name::str(Name::anon(), "x".into()),
      nat_type(),
      Expr::bvar(Nat::from(0)),
      BinderInfo::Default,
    );
    let e = Expr::app(id, nat_zero());
    let env = Env::default();
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  #[test]
  fn test_zeta_reduce() {
    // let x : Nat := Nat.zero in x
    let e = Expr::letE(
      Name::str(Name::anon(), "x".into()),
      nat_type(),
      nat_zero(),
      Expr::bvar(Nat::from(0)),
      false,
    );
    let env = Env::default();
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  // ==========================================================================
  // Delta reduction
  // ==========================================================================

  // Helper: environment with a single level-monomorphic Abbrev definition.
  fn mk_defn_env(name: &str, value: Expr, typ: Expr) -> Env {
    let mut env = Env::default();
    let n = mk_name(name);
    env.insert(
      n.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: n.clone(),
          level_params: vec![],
          typ,
        },
        value,
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![n],
      }),
    );
    env
  }

  #[test]
  fn test_delta_unfold() {
    // def myZero := Nat.zero
    // whnf(myZero) = Nat.zero
    let env = mk_defn_env("myZero", nat_zero(), nat_type());
    let e = Expr::cnst(mk_name("myZero"), vec![]);
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  #[test]
  fn test_delta_opaque_no_unfold() {
    // An opaque definition should NOT unfold
    let mut env = Env::default();
    let n = mk_name("opaqueVal");
    env.insert(
      n.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: n.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Opaque,
        safety: DefinitionSafety::Safe,
        all: vec![n.clone()],
      }),
    );
    let e = Expr::cnst(n.clone(), vec![]);
    let result = whnf(&e, &env);
    // Should still be the constant, not unfolded
    assert_eq!(result, e);
  }

  #[test]
  fn test_delta_chained() {
    // def a := Nat.zero, def b := a => whnf(b) = Nat.zero
    let mut env = Env::default();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: Expr::cnst(a, vec![]),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let e = Expr::cnst(b, vec![]);
    let result = whnf(&e, &env);
assert_eq!(result, nat_zero()); + } + + // ========================================================================== + // Nat arithmetic reduction + // ========================================================================== + + fn nat_lit(n: u64) -> Expr { + Expr::lit(Literal::NatVal(Nat::from(n))) + } + + #[test] + fn test_nat_add() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "add"), vec![]), nat_lit(3)), + nat_lit(4), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub_underflow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(3)), + nat_lit(10), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mul() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mul"), vec![]), nat_lit(6)), + nat_lit(7), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(42)); + } + + #[test] + fn test_nat_div() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(3)); + } + + #[test] + fn test_nat_div_by_zero() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(0), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mod() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mod"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, 
nat_lit(1)); + } + + #[test] + fn test_nat_beq_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_beq_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_ble_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(3)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_ble_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_pow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "pow"), vec![]), nat_lit(2)), + nat_lit(10), + ); + assert_eq!(whnf(&e, &env), nat_lit(1024)); + } + + #[test] + fn test_nat_land() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "land"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1000)); + } + + #[test] + fn test_nat_lor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "lor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1110)); + } + + #[test] + fn test_nat_xor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "xor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), 
+ ); + assert_eq!(whnf(&e, &env), nat_lit(0b0110)); + } + + #[test] + fn test_nat_shift_left() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftLeft"), vec![]), nat_lit(1)), + nat_lit(8), + ); + assert_eq!(whnf(&e, &env), nat_lit(256)); + } + + #[test] + fn test_nat_shift_right() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), nat_lit(256)), + nat_lit(4), + ); + assert_eq!(whnf(&e, &env), nat_lit(16)); + } + + #[test] + fn test_nat_blt_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(3)), + nat_lit(5), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_blt_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(5)), + nat_lit(3), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + // ========================================================================== + // Sort simplification in WHNF + // ========================================================================== + + #[test] + fn test_string_lit_proj_reduces() { + // Build an env with String, String.mk ctor, List, List.cons, List.nil, Char + let mut env = Env::default(); + let string_name = mk_name("String"); + let string_mk = mk_name2("String", "mk"); + let list_name = mk_name("List"); + let char_name = mk_name("Char"); + + // String : Sort 1 + env.insert( + string_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: string_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![string_name.clone()], + ctors: vec![string_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + 
}), + ); + // String.mk : List Char → String (1 field, 0 params) + let list_char = Expr::app( + Expr::cnst(list_name, vec![Level::succ(Level::zero())]), + Expr::cnst(char_name, vec![]), + ); + env.insert( + string_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: string_mk, + level_params: vec![], + typ: Expr::all( + mk_name("data"), + list_char, + Expr::cnst(string_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: string_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Proj String 0 "hi" should reduce (not return None) + let proj = Expr::proj( + string_name, + Nat::from(0u64), + Expr::lit(Literal::StrVal("hi".into())), + ); + let result = whnf(&proj, &env); + // The result should NOT be a Proj anymore (it should have reduced) + assert!( + !matches!(result.as_data(), ExprData::Proj(..)), + "String projection should reduce, got: {:?}", + result + ); + } + + #[test] + fn test_whnf_sort_simplifies() { + // Sort(max 0 u) should simplify to Sort(u) + let env = Env::default(); + let u = Level::param(mk_name("u")); + let e = Expr::sort(Level::max(Level::zero(), u.clone())); + let result = whnf(&e, &env); + assert_eq!(result, Expr::sort(u)); + } + + // ========================================================================== + // Already-WHNF terms + // ========================================================================== + + #[test] + fn test_whnf_sort_unchanged() { + let env = Env::default(); + let e = Expr::sort(Level::zero()); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_lambda_unchanged() { + // A lambda without applied arguments is already WHNF + let env = Env::default(); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_pi_unchanged() { + 
let env = Env::default(); + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + // ========================================================================== + // Helper function tests + // ========================================================================== + + #[test] + fn test_has_loose_bvars_true() { + assert!(has_loose_bvars(&Expr::bvar(Nat::from(0)))); + } + + #[test] + fn test_has_loose_bvars_false_under_binder() { + // fun x : Nat => x — bvar(0) is bound, not loose + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + assert!(!has_loose_bvars(&e)); + } + + #[test] + fn test_has_loose_bvars_const() { + assert!(!has_loose_bvars(&nat_zero())); + } + + #[test] + fn test_has_fvars_true() { + assert!(has_fvars(&Expr::fvar(mk_name("x")))); + } + + #[test] + fn test_has_fvars_false() { + assert!(!has_fvars(&nat_zero())); + } + + #[test] + fn test_foldl_apps_roundtrip() { + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = nat_type(); + let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone()); + let (head, args) = unfold_apps(&e); + let rebuilt = foldl_apps(head, args.into_iter()); + assert_eq!(rebuilt, e); + } + + #[test] + fn test_abstr_simple() { + // abstr(fvar("x"), [fvar("x")]) = bvar(0) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&x, &[x.clone()]); + assert_eq!(result, Expr::bvar(Nat::from(0))); + } + + #[test] + fn test_abstr_not_found() { + // abstr(Nat.zero, [fvar("x")]) = Nat.zero (unchanged) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&nat_zero(), &[x]); + assert_eq!(result, nat_zero()); + } + + #[test] + fn test_subst_expr_levels_simple() { + // Sort(u) with u := 0 => Sort(0) + let u_name = mk_name("u"); + let e = Expr::sort(Level::param(u_name.clone())); + let result = subst_expr_levels(&e, &[u_name], &[Level::zero()]); + 
assert_eq!(result, Expr::sort(Level::zero())); + } +} From 13da42f245f403af3588018ab89cdadce4e1763f Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 07:45:54 -0500 Subject: [PATCH 02/14] WIP kernel --- Ix/Address.lean | 1 + Ix/Cli/CheckCmd.lean | 122 ++ Ix/CompileM.lean | 16 +- Ix/DecompileM.lean | 24 +- Ix/Ixon.lean | 16 +- Ix/Kernel.lean | 44 + Ix/Kernel/Convert.lean | 841 +++++++++++ Ix/Kernel/Datatypes.lean | 181 +++ Ix/Kernel/DecompileM.lean | 254 ++++ Ix/Kernel/Equal.lean | 168 +++ Ix/Kernel/Eval.lean | 530 +++++++ Ix/Kernel/Infer.lean | 406 +++++ Ix/Kernel/Level.lean | 131 ++ Ix/Kernel/TypecheckM.lean | 180 +++ Ix/Kernel/Types.lean | 569 +++++++ Main.lean | 2 + Tests/Ix/Check.lean | 107 ++ Tests/Ix/Compile.lean | 73 +- Tests/Ix/KernelTests.lean | 761 ++++++++++ Tests/Ix/PP.lean | 333 +++++ Tests/Main.lean | 17 + docs/Ixon.md | 5 +- src/ix/decompile.rs | 57 +- src/ix/ixon/env.rs | 47 +- src/ix/ixon/serialize.rs | 2 - src/ix/kernel/convert.rs | 835 +++++++---- src/ix/kernel/dag.rs | 645 +++++++- src/ix/kernel/dag_tc.rs | 2857 ++++++++++++++++++++++++++++++++++++ src/ix/kernel/def_eq.rs | 480 +++++- src/ix/kernel/inductive.rs | 121 +- src/ix/kernel/level.rs | 58 +- src/ix/kernel/mod.rs | 1 + src/ix/kernel/tc.rs | 663 +++++++-- src/ix/kernel/upcopy.rs | 872 +++++------ src/ix/kernel/whnf.rs | 1674 +++++++++++++++------ src/lean/ffi.rs | 1 + src/lean/ffi/check.rs | 182 +++ src/lean/ffi/lean_env.rs | 6 +- 38 files changed, 11748 insertions(+), 1534 deletions(-) create mode 100644 Ix/Cli/CheckCmd.lean create mode 100644 Ix/Kernel.lean create mode 100644 Ix/Kernel/Convert.lean create mode 100644 Ix/Kernel/Datatypes.lean create mode 100644 Ix/Kernel/DecompileM.lean create mode 100644 Ix/Kernel/Equal.lean create mode 100644 Ix/Kernel/Eval.lean create mode 100644 Ix/Kernel/Infer.lean create mode 100644 Ix/Kernel/Level.lean create mode 100644 Ix/Kernel/TypecheckM.lean create mode 100644 Ix/Kernel/Types.lean create mode 100644 Tests/Ix/Check.lean 
create mode 100644 Tests/Ix/KernelTests.lean create mode 100644 Tests/Ix/PP.lean create mode 100644 src/ix/kernel/dag_tc.rs create mode 100644 src/lean/ffi/check.rs diff --git a/Ix/Address.lean b/Ix/Address.lean index ee11eb85..562dd028 100644 --- a/Ix/Address.lean +++ b/Ix/Address.lean @@ -14,6 +14,7 @@ structure Address where /-- Compute the Blake3 hash of a `ByteArray`, returning an `Address`. -/ def Address.blake3 (x: ByteArray) : Address := ⟨(Blake3.hash x).val⟩ + /-- Convert a nibble (0--15) to its lowercase hexadecimal character. -/ def hexOfNat : Nat -> Option Char | 0 => .some '0' diff --git a/Ix/Cli/CheckCmd.lean b/Ix/Cli/CheckCmd.lean new file mode 100644 index 00000000..f8e388f0 --- /dev/null +++ b/Ix/Cli/CheckCmd.lean @@ -0,0 +1,122 @@ +import Cli +import Ix.Common +import Ix.Kernel +import Ix.Meta +import Ix.CompileM +import Lean + +open System (FilePath) + +/-- If the project depends on Mathlib, download the Mathlib cache. -/ +private def fetchMathlibCache (cwd : Option FilePath) : IO Unit := do + let root := cwd.getD "." + let manifest := root / "lake-manifest.json" + let contents ← IO.FS.readFile manifest + if contents.containsSubstr "leanprover-community/mathlib4" then + let mathlibBuild := root / ".lake" / "packages" / "mathlib" / ".lake" / "build" + if ← mathlibBuild.pathExists then + println! "Mathlib cache already present, skipping fetch." + return + println! "Detected Mathlib dependency. Fetching Mathlib cache..." + let child ← IO.Process.spawn { + cmd := "lake" + args := #["exe", "cache", "get"] + cwd := cwd + stdout := .inherit + stderr := .inherit + } + let exitCode ← child.wait + if exitCode != 0 then + throw $ IO.userError "lake exe cache get failed" + +/-- Build the Lean module at the given file path using Lake. 
-/ +private def buildFile (path : FilePath) : IO Unit := do + let path ← IO.FS.realPath path + let some moduleName := path.fileStem + | throw $ IO.userError s!"cannot determine module name from {path}" + fetchMathlibCache path.parent + let child ← IO.Process.spawn { + cmd := "lake" + args := #["build", moduleName] + cwd := path.parent + stdout := .inherit + stderr := .inherit + } + let exitCode ← child.wait + if exitCode != 0 then + throw $ IO.userError "lake build failed" + +/-- Run the Lean NbE kernel checker. -/ +private def runLeanCheck (leanEnv : Lean.Environment) : IO UInt32 := do + println! "Compiling to Ixon..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileElapsed := (← IO.monoMsNow) - compileStart + let numConsts := ixonEnv.consts.size + println! "Compiled {numConsts} constants in {compileElapsed.formatMs}" + + println! "Converting Ixon → Kernel..." + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + println! "Conversion error: {e}" + return 1 + | .ok (kenv, prims, quotInit) => + let convertElapsed := (← IO.monoMsNow) - convertStart + println! "Converted {kenv.size} constants in {convertElapsed.formatMs}" + + println! "Typechecking..." + let checkStart ← IO.monoMsNow + match Ix.Kernel.typecheckAll kenv prims quotInit with + | .error e => + let elapsed := (← IO.monoMsNow) - checkStart + println! "Kernel check failed in {elapsed.formatMs}: {e}" + return 1 + | .ok () => + let elapsed := (← IO.monoMsNow) - checkStart + println! "Checked {kenv.size} constants in {elapsed.formatMs}" + return 0 + +/-- Run the Rust kernel checker. -/ +private def runRustCheck (leanEnv : Lean.Environment) : IO UInt32 := do + let totalConsts := leanEnv.constants.toList.length + println! "Total constants: {totalConsts}" + + let start ← IO.monoMsNow + let errors ← Ix.Kernel.rsCheckEnv leanEnv + let elapsed := (← IO.monoMsNow) - start + + if errors.isEmpty then + println! 
"Checked {totalConsts} constants in {elapsed.formatMs}" + return 0 + else + println! "Kernel check failed with {errors.size} error(s) in {elapsed.formatMs}:" + for (name, err) in errors[:min 50 errors.size] do + println! " {repr name}: {repr err}" + return 1 + +def runCheckCmd (p : Cli.Parsed) : IO UInt32 := do + let some path := p.flag? "path" + | p.printError "error: must specify --path" + return 1 + let pathStr := path.as! String + let useLean := p.hasFlag "lean" + + buildFile pathStr + let leanEnv ← getFileEnv pathStr + + if useLean then + println! "Running Lean NbE kernel checker on {pathStr}" + runLeanCheck leanEnv + else + println! "Running Rust kernel checker on {pathStr}" + runRustCheck leanEnv + +def checkCmd : Cli.Cmd := `[Cli| + check VIA runCheckCmd; + "Type-check Lean file with kernel" + + FLAGS: + path : String; "Path to file to check" + lean; "Use Lean NbE kernel instead of Rust kernel" +] diff --git a/Ix/CompileM.lean b/Ix/CompileM.lean index e527f62c..efd8abd2 100644 --- a/Ix/CompileM.lean +++ b/Ix/CompileM.lean @@ -1604,11 +1604,10 @@ def compileEnv (env : Ix.Environment) (blocks : Ix.CondensedBlocks) (dbg : Bool -- Build reverse index and names map, storing name string components as blobs -- Seed with blockNames collected during compilation (binder names, level params, etc.) 
- let (addrToNameMap, namesMap, nameBlobs) := - compileEnv.nameToNamed.fold (init := ({}, blockNames, {})) fun (addrMap, namesMap, blobs) name named => - let addrMap := addrMap.insert named.addr name + let (namesMap, nameBlobs) := + compileEnv.nameToNamed.fold (init := (blockNames, {})) fun (namesMap, blobs) name _named => let (namesMap, blobs) := Ixon.RawEnv.addNameComponentsWithBlobs namesMap blobs name - (addrMap, namesMap, blobs) + (namesMap, blobs) -- Merge name string blobs into the main blobs map let allBlobs := nameBlobs.fold (fun m k v => m.insert k v) compileEnv.blobs @@ -1619,7 +1618,6 @@ def compileEnv (env : Ix.Environment) (blocks : Ix.CondensedBlocks) (dbg : Bool blobs := allBlobs names := namesMap comms := {} - addrToName := addrToNameMap } return .ok (ixonEnv, compileEnv.totalBytes) @@ -1890,11 +1888,10 @@ def compileEnvParallel (env : Ix.Environment) (blocks : Ix.CondensedBlocks) -- Build reverse index and names map, storing name string components as blobs -- Seed with blockNames collected during compilation (binder names, level params, etc.) 
- let (addrToNameMap, namesMap, nameBlobs) := - nameToNamed.fold (init := ({}, blockNames, {})) fun (addrMap, namesMap, nameBlobs) name named => - let addrMap := addrMap.insert named.addr name + let (namesMap, nameBlobs) := + nameToNamed.fold (init := (blockNames, {})) fun (namesMap, nameBlobs) name _named => let (namesMap, nameBlobs) := Ixon.RawEnv.addNameComponentsWithBlobs namesMap nameBlobs name - (addrMap, namesMap, nameBlobs) + (namesMap, nameBlobs) -- Merge name string blobs into the main blobs map let blockBlobCount := blobs.size @@ -1912,7 +1909,6 @@ def compileEnvParallel (env : Ix.Environment) (blocks : Ix.CondensedBlocks) blobs := allBlobs names := namesMap comms := {} - addrToName := addrToNameMap } return .ok (ixonEnv, totalBytes) diff --git a/Ix/DecompileM.lean b/Ix/DecompileM.lean index d22fb8f7..e1e8050b 100644 --- a/Ix/DecompileM.lean +++ b/Ix/DecompileM.lean @@ -117,12 +117,6 @@ def lookupNameAddrOrAnon (addr : Address) : DecompileM Ix.Name := do | some n => pure n | none => pure Ix.Name.mkAnon -/-- Resolve constant Address → Ix.Name via addrToName. -/ -def lookupConstName (addr : Address) : DecompileM Ix.Name := do - match (← getEnv).ixonEnv.addrToName.get? addr with - | some n => pure n - | none => throw (.missingAddress addr) - def lookupBlob (addr : Address) : DecompileM ByteArray := do match (← getEnv).ixonEnv.blobs.get? addr with | some blob => pure blob @@ -390,18 +384,14 @@ partial def decompileExpr (e : Ixon.Expr) (arenaIdx : UInt64) : DecompileM Ix.Ex pure (applyMdata (Ix.Expr.mkLit (.strVal s)) mdataLayers) -- Ref with arena metadata - | .ref nameAddr, .ref refIdx univIndices => do - let name ← match (← getEnv).ixonEnv.names.get? 
nameAddr with - | some n => pure n - | none => getRef refIdx >>= lookupConstName + | .ref nameAddr, .ref _refIdx univIndices => do + let name ← lookupNameAddr nameAddr let lvls ← decompileUnivIndices univIndices pure (applyMdata (Ix.Expr.mkConst name lvls) mdataLayers) -- Ref without arena metadata - | _, .ref refIdx univIndices => do - let name ← getRef refIdx >>= lookupConstName - let lvls ← decompileUnivIndices univIndices - pure (applyMdata (Ix.Expr.mkConst name lvls) mdataLayers) + | _, .ref _refIdx _univIndices => do + throw (.badConstantFormat "ref without arena metadata") -- Rec with arena metadata | .ref nameAddr, .recur recIdx univIndices => do @@ -472,10 +462,8 @@ partial def decompileExpr (e : Ixon.Expr) (arenaIdx : UInt64) : DecompileM Ix.Ex let valExpr ← decompileExpr val child pure (applyMdata (Ix.Expr.mkProj typeName fieldIdx.toNat valExpr) mdataLayers) - | _, .prj typeRefIdx fieldIdx val => do - let typeName ← getRef typeRefIdx >>= lookupConstName - let valExpr ← decompileExpr val UInt64.MAX - pure (applyMdata (Ix.Expr.mkProj typeName fieldIdx.toNat valExpr) mdataLayers) + | _, .prj _typeRefIdx _fieldIdx _val => do + throw (.badConstantFormat "prj without arena metadata") | _, .share _ => throw (.badConstantFormat "unexpected Share in decompileExpr") diff --git a/Ix/Ixon.lean b/Ix/Ixon.lean index 5432d12c..cc4d1d11 100644 --- a/Ix/Ixon.lean +++ b/Ix/Ixon.lean @@ -1380,12 +1380,10 @@ structure Env where named : Std.HashMap Ix.Name Named := {} /-- Raw data blobs: Address → bytes -/ blobs : Std.HashMap Address ByteArray := {} - /-- Hash-consed name components: Address → Ix.Name -/ - names : Std.HashMap Address Ix.Name := {} /-- Cryptographic commitments: Address → Comm -/ comms : Std.HashMap Address Comm := {} - /-- Reverse index: constant Address → Ix.Name -/ - addrToName : Std.HashMap Address Ix.Name := {} + /-- Hash-consed name components: Address → Ix.Name -/ + names : Std.HashMap Address Ix.Name := {} deriving Inhabited namespace Env @@ -1401,8 
+1399,7 @@ def getConst? (env : Env) (addr : Address) : Option Constant := /-- Register a name with full Named metadata. -/ def registerName (env : Env) (name : Ix.Name) (named : Named) : Env := { env with - named := env.named.insert name named - addrToName := env.addrToName.insert named.addr name } + named := env.named.insert name named } /-- Register a name with just an address (empty metadata). -/ def registerNameAddr (env : Env) (name : Ix.Name) (addr : Address) : Env := @@ -1416,10 +1413,6 @@ def getAddr? (env : Env) (name : Ix.Name) : Option Address := def getNamed? (env : Env) (name : Ix.Name) : Option Named := env.named.get? name -/-- Look up an address's name. -/ -def getName? (env : Env) (addr : Address) : Option Ix.Name := - env.addrToName.get? addr - /-- Store a blob and return its content address. -/ def storeBlob (env : Env) (bytes : ByteArray) : Env × Address := let addr := Address.blake3 bytes @@ -1742,8 +1735,7 @@ def getEnv : GetM Env := do | some name => let namedEntry : Named := ⟨constAddr, constMeta⟩ env := { env with - named := env.named.insert name namedEntry - addrToName := env.addrToName.insert constAddr name } + named := env.named.insert name namedEntry } | none => throw s!"getEnv: named entry references unknown name address {reprStr (toString nameAddr)}" diff --git a/Ix/Kernel.lean b/Ix/Kernel.lean new file mode 100644 index 00000000..cbb6c467 --- /dev/null +++ b/Ix/Kernel.lean @@ -0,0 +1,44 @@ +import Lean +import Ix.Environment +import Ix.Kernel.Types +import Ix.Kernel.Datatypes +import Ix.Kernel.Level +import Ix.Kernel.TypecheckM +import Ix.Kernel.Eval +import Ix.Kernel.Equal +import Ix.Kernel.Infer +import Ix.Kernel.Convert + +namespace Ix.Kernel + +/-- Type-checking errors from the Rust kernel, mirroring `TcError` in Rust. 
-/ +inductive CheckError where + | typeExpected (expr : Ix.Expr) (inferred : Ix.Expr) + | functionExpected (expr : Ix.Expr) (inferred : Ix.Expr) + | typeMismatch (expected : Ix.Expr) (found : Ix.Expr) (expr : Ix.Expr) + | defEqFailure (lhs : Ix.Expr) (rhs : Ix.Expr) + | unknownConst (name : Ix.Name) + | duplicateUniverse (name : Ix.Name) + | freeBoundVariable (idx : UInt64) + | kernelException (msg : String) + deriving Repr + +/-- FFI: Run Rust kernel type-checker over all declarations in a Lean environment. -/ +@[extern "rs_check_env"] +opaque rsCheckEnvFFI : @& List (Lean.Name × Lean.ConstantInfo) → IO (Array (Ix.Name × CheckError)) + +/-- Check all declarations in a Lean environment using the Rust kernel. + Returns an array of (name, error) pairs for any declarations that fail. -/ +def rsCheckEnv (leanEnv : Lean.Environment) : IO (Array (Ix.Name × CheckError)) := + rsCheckEnvFFI leanEnv.constants.toList + +/-- FFI: Type-check a single constant by dotted name string. -/ +@[extern "rs_check_const"] +opaque rsCheckConstFFI : @& List (Lean.Name × Lean.ConstantInfo) → @& String → IO (Option CheckError) + +/-- Check a single constant by name using the Rust kernel. + Returns `none` on success, `some err` on failure. -/ +def rsCheckConst (leanEnv : Lean.Environment) (name : String) : IO (Option CheckError) := + rsCheckConstFFI leanEnv.constants.toList name + +end Ix.Kernel diff --git a/Ix/Kernel/Convert.lean b/Ix/Kernel/Convert.lean new file mode 100644 index 00000000..369ffca2 --- /dev/null +++ b/Ix/Kernel/Convert.lean @@ -0,0 +1,841 @@ +/- + Kernel Convert: Ixon.Env → Kernel.Env conversion. + + Two modes: + - `convert` produces `Kernel.Env .meta` with full names and binder info + - `convertAnon` produces `Kernel.Env .anon` with all metadata as () + + Much simpler than DecompileM: no Blake3 hash computation, no mdata reconstruction. 
+-/ +import Ix.Kernel.Types +import Ix.Ixon + +namespace Ix.Kernel.Convert + +open Ix (Name) +open Ixon (Constant ConstantInfo ConstantMeta MutConst Named) + +/-! ## Universe conversion -/ + +partial def convertUniv (m : MetaMode) (levelParamNames : Array (MetaField m Ix.Name) := #[]) + : Ixon.Univ → Level m + | .zero => .zero + | .succ l => .succ (convertUniv m levelParamNames l) + | .max l₁ l₂ => .max (convertUniv m levelParamNames l₁) (convertUniv m levelParamNames l₂) + | .imax l₁ l₂ => .imax (convertUniv m levelParamNames l₁) (convertUniv m levelParamNames l₂) + | .var idx => + let name := if h : idx.toNat < levelParamNames.size then levelParamNames[idx.toNat] else default + .param idx.toNat name + +/-! ## Expression conversion monad -/ + +structure ConvertEnv (m : MetaMode) where + sharing : Array Ixon.Expr + refs : Array Address + univs : Array Ixon.Univ + blobs : Std.HashMap Address ByteArray + recurAddrs : Array Address := #[] + arena : Ixon.ExprMetaArena := {} + names : Std.HashMap Address Ix.Name := {} + levelParamNames : Array (MetaField m Ix.Name) := #[] + binderNames : List (MetaField m Ix.Name) := [] + +structure ConvertState (m : MetaMode) where + exprCache : Std.HashMap (UInt64 × UInt64) (Expr m) := {} + +inductive ConvertError where + | refOutOfBounds (refIdx : UInt64) (refsSize : Nat) + | recurOutOfBounds (recIdx : UInt64) (recurAddrsSize : Nat) + | prjRefOutOfBounds (typeRefIdx : UInt64) (refsSize : Nat) + | missingMemberAddr (memberIdx : Nat) (numMembers : Nat) + | unresolvableCtxAddr (addr : Address) + | missingName (nameAddr : Address) + +instance : ToString ConvertError where + toString + | .refOutOfBounds idx sz => s!"ref index {idx} out of bounds (refs.size={sz})" + | .recurOutOfBounds idx sz => s!"recur index {idx} out of bounds (recurAddrs.size={sz})" + | .prjRefOutOfBounds idx sz => s!"proj type ref index {idx} out of bounds (refs.size={sz})" + | .missingMemberAddr idx n => s!"no address for member {idx} (numMembers={n})" + | 
.unresolvableCtxAddr addr => s!"unresolvable ctx address {addr}" + | .missingName addr => s!"missing name for address {addr}" + +abbrev ConvertM (m : MetaMode) := ReaderT (ConvertEnv m) (StateT (ConvertState m) (ExceptT ConvertError Id)) + +def ConvertState.init (_ : ConvertEnv m) : ConvertState m := {} + +def ConvertM.run (env : ConvertEnv m) (x : ConvertM m α) : Except ConvertError α := + match x env |>.run (ConvertState.init env) with + | .ok (a, _) => .ok a + | .error e => .error e + +/-- Run a ConvertM computation with existing state, return result and final state. -/ +def ConvertM.runWith (env : ConvertEnv m) (st : ConvertState m) (x : ConvertM m α) + : Except ConvertError (α × ConvertState m) := + x env |>.run st + +/-! ## Expression conversion -/ + +def resolveUnivs (m : MetaMode) (idxs : Array UInt64) : ConvertM m (Array (Level m)) := do + let ctx ← read + return idxs.map fun i => + if h : i.toNat < ctx.univs.size + then convertUniv m ctx.levelParamNames ctx.univs[i.toNat] + else .zero + +def decodeBlobNat (bytes : ByteArray) : Nat := Id.run do + let mut acc := 0 + for i in [:bytes.size] do + acc := acc + bytes[i]!.toNat * 256 ^ i + return acc + +def decodeBlobStr (bytes : ByteArray) : String := + String.fromUTF8! bytes + +/-- Look up an arena node by index, automatically unwrapping `.mdata` wrappers. -/ +partial def getArenaNode (idx : Option UInt64) : ConvertM m (Option Ixon.ExprMetaData) := do + match idx with + | none => return none + | some i => + let ctx ← read + if h : i.toNat < ctx.arena.nodes.size + then match ctx.arena.nodes[i.toNat] with + | .mdata _ child => getArenaNode (some child) + | node => return some node + else return none + +def mkMetaName (m : MetaMode) (name? : Option Ix.Name) : MetaField m Ix.Name := + match m with + | .meta => name?.getD default + | .anon => () + +/-- Resolve a name hash Address to a MetaField name via the names table. 
-/ +def resolveName (nameAddr : Address) : ConvertM m (MetaField m Ix.Name) := do + let ctx ← read + match ctx.names.get? nameAddr with + | some name => return (mkMetaName m (some name)) + | none => throw (.missingName nameAddr) + +partial def convertExpr (m : MetaMode) (expr : Ixon.Expr) (metaIdx : Option UInt64 := none) + : ConvertM m (Expr m) := do + -- 1. Expand share transparently, passing arena index through (same as DecompileM) + match expr with + | .share idx => + let ctx ← read + if h : idx.toNat < ctx.sharing.size then + convertExpr m ctx.sharing[idx.toNat] metaIdx + else return default + | _ => + + -- 1b. Handle .var before cache (binder names are context-dependent) + if let .var idx := expr then + let name := match (← read).binderNames[idx.toNat]? with + | some n => n | none => default + return (.bvar idx.toNat name) + + -- 2. Check cache (keyed on expression hash + arena index) + let cacheKey := (hash expr, metaIdx.getD UInt64.MAX) + if let some cached := (← get).exprCache.get? cacheKey then return cached + + -- 3. Resolve arena node + let node ← getArenaNode metaIdx + + -- 4. Convert expression + let result ← match expr with + | .sort idx => do + let ctx ← read + if h : idx.toNat < ctx.univs.size + then pure (.sort (convertUniv m ctx.levelParamNames ctx.univs[idx.toNat])) + else pure (.sort .zero) + | .var _ => pure default -- unreachable, handled above + | .ref refIdx univIdxs => do + let ctx ← read + let levels ← resolveUnivs m univIdxs + let addr ← match ctx.refs[refIdx.toNat]? with + | some a => pure a + | none => throw (.refOutOfBounds refIdx ctx.refs.size) + let name ← match node with + | some (.ref nameAddr) => resolveName nameAddr + | _ => pure default + pure (.const addr levels name) + | .recur recIdx univIdxs => do + let ctx ← read + let levels ← resolveUnivs m univIdxs + let addr ← match ctx.recurAddrs[recIdx.toNat]? 
with + | some a => pure a + | none => throw (.recurOutOfBounds recIdx ctx.recurAddrs.size) + let name ← match node with + | some (.ref nameAddr) => resolveName nameAddr + | _ => pure default + pure (.const addr levels name) + | .prj typeRefIdx fieldIdx struct => do + let ctx ← read + let typeAddr ← match ctx.refs[typeRefIdx.toNat]? with + | some a => pure a + | none => throw (.prjRefOutOfBounds typeRefIdx ctx.refs.size) + let (structChild, typeName) ← match node with + | some (.prj structNameAddr child) => do + let n ← resolveName structNameAddr + pure (some child, n) + | _ => pure (none, default) + let s ← convertExpr m struct structChild + pure (.proj typeAddr fieldIdx.toNat s typeName) + | .str blobRefIdx => do + let ctx ← read + if h : blobRefIdx.toNat < ctx.refs.size then + let blobAddr := ctx.refs[blobRefIdx.toNat] + match ctx.blobs.get? blobAddr with + | some bytes => pure (.lit (.strVal (decodeBlobStr bytes))) + | none => pure (.lit (.strVal "")) + else pure (.lit (.strVal "")) + | .nat blobRefIdx => do + let ctx ← read + if h : blobRefIdx.toNat < ctx.refs.size then + let blobAddr := ctx.refs[blobRefIdx.toNat] + match ctx.blobs.get? 
blobAddr with + | some bytes => pure (.lit (.natVal (decodeBlobNat bytes))) + | none => pure (.lit (.natVal 0)) + else pure (.lit (.natVal 0)) + | .app fn arg => do + let (fnChild, argChild) := match node with + | some (.app f a) => (some f, some a) + | _ => (none, none) + let f ← convertExpr m fn fnChild + let a ← convertExpr m arg argChild + pure (.app f a) + | .lam ty body => do + let (name, bi, tyChild, bodyChild) ← match node with + | some (.binder nameAddr info tyC bodyC) => do + let n ← resolveName nameAddr + let i : MetaField m Lean.BinderInfo := match m with | .meta => info | .anon => () + pure (n, i, some tyC, some bodyC) + | _ => pure (default, default, none, none) + let t ← convertExpr m ty tyChild + let b ← withReader (fun env => { env with binderNames := name :: env.binderNames }) + (convertExpr m body bodyChild) + pure (.lam t b name bi) + | .all ty body => do + let (name, bi, tyChild, bodyChild) ← match node with + | some (.binder nameAddr info tyC bodyC) => do + let n ← resolveName nameAddr + let i : MetaField m Lean.BinderInfo := match m with | .meta => info | .anon => () + pure (n, i, some tyC, some bodyC) + | _ => pure (default, default, none, none) + let t ← convertExpr m ty tyChild + let b ← withReader (fun env => { env with binderNames := name :: env.binderNames }) + (convertExpr m body bodyChild) + pure (.forallE t b name bi) + | .letE _nonDep ty val body => do + let (name, tyChild, valChild, bodyChild) ← match node with + | some (.letBinder nameAddr tyC valC bodyC) => do + let n ← resolveName nameAddr + pure (n, some tyC, some valC, some bodyC) + | _ => pure (default, none, none, none) + let t ← convertExpr m ty tyChild + let v ← convertExpr m val valChild + let b ← withReader (fun env => { env with binderNames := name :: env.binderNames }) + (convertExpr m body bodyChild) + pure (.letE t v b name) + | .share _ => pure default -- unreachable, handled above + + -- 5. 
Cache and return + modify fun s => { s with exprCache := s.exprCache.insert cacheKey result } + pure result + +/-! ## Enum conversions -/ + +def convertHints : Lean.ReducibilityHints → ReducibilityHints + | .opaque => .opaque + | .abbrev => .abbrev + | .regular h => .regular h + +def convertSafety : Ix.DefinitionSafety → DefinitionSafety + | .unsaf => .unsafe + | .safe => .safe + | .part => .partial + +def convertQuotKind : Ix.QuotKind → QuotKind + | .type => .type + | .ctor => .ctor + | .lift => .lift + | .ind => .ind + +/-! ## Constant conversion helpers -/ + +def mkConvertEnv (m : MetaMode) (c : Constant) (blobs : Std.HashMap Address ByteArray) + (recurAddrs : Array Address := #[]) + (arena : Ixon.ExprMetaArena := {}) + (names : Std.HashMap Address Ix.Name := {}) + (levelParamNames : Array (MetaField m Ix.Name) := #[]) : ConvertEnv m := + { sharing := c.sharing, refs := c.refs, univs := c.univs, blobs, recurAddrs, arena, names, + levelParamNames } + +def mkConstantVal (m : MetaMode) (numLvls : UInt64) (typ : Expr m) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) : ConstantVal m := + { numLevels := numLvls.toNat, type := typ, name, levelParams } + +/-! ## Factored constant conversion helpers -/ + +/-- Extract arena from ConstantMeta. -/ +def metaArena : ConstantMeta → Ixon.ExprMetaArena + | .defn _ _ _ _ _ a _ _ => a + | .axio _ _ a _ => a + | .quot _ _ a _ => a + | .indc _ _ _ _ _ a _ => a + | .ctor _ _ _ a _ => a + | .recr _ _ _ _ _ a _ _ => a + | .empty => {} + +/-- Extract type root index from ConstantMeta. -/ +def metaTypeRoot? : ConstantMeta → Option UInt64 + | .defn _ _ _ _ _ _ r _ => some r + | .axio _ _ _ r => some r + | .quot _ _ _ r => some r + | .indc _ _ _ _ _ _ r => some r + | .ctor _ _ _ _ r => some r + | .recr _ _ _ _ _ _ r _ => some r + | .empty => none + +/-- Extract value root index from ConstantMeta (defn only). -/ +def metaValueRoot? 
: ConstantMeta → Option UInt64 + | .defn _ _ _ _ _ _ _ r => some r + | .empty => none + | _ => none + +/-- Extract level param name addresses from ConstantMeta. -/ +def metaLvlAddrs : ConstantMeta → Array Address + | .defn _ lvls _ _ _ _ _ _ => lvls + | .axio _ lvls _ _ => lvls + | .quot _ lvls _ _ => lvls + | .indc _ lvls _ _ _ _ _ => lvls + | .ctor _ lvls _ _ _ => lvls + | .recr _ lvls _ _ _ _ _ _ => lvls + | .empty => #[] + +/-- Resolve level param addresses to MetaField names via the names table. -/ +def resolveLevelParams (m : MetaMode) (names : Std.HashMap Address Ix.Name) + (lvlAddrs : Array Address) : Array (MetaField m Ix.Name) := + match m with + | .anon => lvlAddrs.map fun _ => () + | .meta => lvlAddrs.map fun addr => names.getD addr default + +/-- Build the MetaField levelParams value from resolved names. -/ +def mkLevelParams (m : MetaMode) (names : Std.HashMap Address Ix.Name) + (lvlAddrs : Array Address) : MetaField m (Array Ix.Name) := + match m with + | .anon => () + | .meta => lvlAddrs.map fun addr => names.getD addr default + +/-- Resolve an array of name-hash addresses to a MetaField array of names. -/ +def resolveMetaNames (m : MetaMode) (names : Std.HashMap Address Ix.Name) + (addrs : Array Address) : MetaField m (Array Ix.Name) := + match m with | .anon => () | .meta => addrs.map fun a => names.getD a default + +/-- Resolve a single name-hash address to a MetaField name. -/ +def resolveMetaName (m : MetaMode) (names : Std.HashMap Address Ix.Name) + (addr : Address) : MetaField m Ix.Name := + match m with | .anon => () | .meta => names.getD addr default + +/-- Extract rule root indices from ConstantMeta (recr only). 
-/ +def metaRuleRoots : ConstantMeta → Array UInt64 + | .recr _ _ _ _ _ _ _ rs => rs + | _ => #[] + +def convertRule (m : MetaMode) (rule : Ixon.RecursorRule) (ctorAddr : Address) + (ctorName : MetaField m Ix.Name := default) + (ruleRoot : Option UInt64 := none) : + ConvertM m (Ix.Kernel.RecursorRule m) := do + let rhs ← convertExpr m rule.rhs ruleRoot + return { ctor := ctorAddr, ctorName, nfields := rule.fields.toNat, rhs } + +def convertDefinition (m : MetaMode) (d : Ixon.Definition) + (hints : ReducibilityHints) (all : Array Address) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) + (allNames : MetaField m (Array Ix.Name) := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m d.typ (metaTypeRoot? cMeta) + let value ← convertExpr m d.value (metaValueRoot? cMeta) + let cv := mkConstantVal m d.lvls typ name levelParams + match d.kind with + | .defn => return .defnInfo { toConstantVal := cv, value, hints, safety := convertSafety d.safety, all, allNames } + | .opaq => return .opaqueInfo { toConstantVal := cv, value, isUnsafe := d.safety == .unsaf, all, allNames } + | .thm => return .thmInfo { toConstantVal := cv, value, all, allNames } + +def convertAxiom (m : MetaMode) (a : Ixon.Axiom) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m a.typ (metaTypeRoot? cMeta) + let cv := mkConstantVal m a.lvls typ name levelParams + return .axiomInfo { toConstantVal := cv, isUnsafe := a.isUnsafe } + +def convertQuotient (m : MetaMode) (q : Ixon.Quotient) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m q.typ (metaTypeRoot? 
cMeta) + let cv := mkConstantVal m q.lvls typ name levelParams + return .quotInfo { toConstantVal := cv, kind := convertQuotKind q.kind } + +def convertInductive (m : MetaMode) (ind : Ixon.Inductive) + (ctorAddrs all : Array Address) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) + (allNames : MetaField m (Array Ix.Name) := default) + (ctorNames : MetaField m (Array Ix.Name) := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m ind.typ (metaTypeRoot? cMeta) + let cv := mkConstantVal m ind.lvls typ name levelParams + let v : Ix.Kernel.InductiveVal m := + { toConstantVal := cv, numParams := ind.params.toNat, + numIndices := ind.indices.toNat, all, ctors := ctorAddrs, allNames, ctorNames, + numNested := ind.nested.toNat, isRec := ind.recr, isUnsafe := ind.isUnsafe, + isReflexive := ind.refl } + return .inductInfo v + +def convertConstructor (m : MetaMode) (c : Ixon.Constructor) + (inductAddr : Address) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) + (inductName : MetaField m Ix.Name := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m c.typ (metaTypeRoot? 
cMeta) + let cv := mkConstantVal m c.lvls typ name levelParams + let v : Ix.Kernel.ConstructorVal m := + { toConstantVal := cv, induct := inductAddr, inductName, + cidx := c.cidx.toNat, numParams := c.params.toNat, numFields := c.fields.toNat, + isUnsafe := c.isUnsafe } + return .ctorInfo v + +def convertRecursor (m : MetaMode) (r : Ixon.Recursor) + (all ruleCtorAddrs : Array Address) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) + (allNames : MetaField m (Array Ix.Name) := default) + (ruleCtorNames : Array (MetaField m Ix.Name) := #[]) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + let typ ← convertExpr m r.typ (metaTypeRoot? cMeta) + let cv := mkConstantVal m r.lvls typ name levelParams + let ruleRoots := (metaRuleRoots cMeta) + let mut rules : Array (Ix.Kernel.RecursorRule m) := #[] + for i in [:r.rules.size] do + let ctorAddr := if h : i < ruleCtorAddrs.size then ruleCtorAddrs[i] else default + let ctorName := if h : i < ruleCtorNames.size then ruleCtorNames[i] else default + let ruleRoot := if h : i < ruleRoots.size then some ruleRoots[i] else none + rules := rules.push (← convertRule m r.rules[i]! ctorAddr ctorName ruleRoot) + let v : Ix.Kernel.RecursorVal m := + { toConstantVal := cv, all, allNames, + numParams := r.params.toNat, numIndices := r.indices.toNat, + numMotives := r.motives.toNat, numMinors := r.minors.toNat, + rules, k := r.k, isUnsafe := r.isUnsafe } + return .recInfo v + +/-! ## Metadata helpers -/ + +/-- Build a direct name-hash Address → constant Address lookup table. -/ +def buildHashToAddr (ixonEnv : Ixon.Env) : Std.HashMap Address Address := Id.run do + let mut acc : Std.HashMap Address Address := {} + for (nameHash, name) in ixonEnv.names do + match ixonEnv.named.get? name with + | some entry => acc := acc.insert nameHash entry.addr + | none => pure () + return acc + +/-- Extract block address from a projection constant, if it is one. 
-/ +def projBlockAddr : Ixon.ConstantInfo → Option Address + | .iPrj prj => some prj.block + | .cPrj prj => some prj.block + | .rPrj prj => some prj.block + | .dPrj prj => some prj.block + | _ => none + +/-! ## BlockIndex -/ + +/-- Cross-reference index for projections within a single muts block. + Built from the block group before conversion so we can derive addresses + without relying on metadata. -/ +structure BlockIndex where + /-- memberIdx → iPrj address (inductive type address) -/ + inductAddrs : Std.HashMap UInt64 Address := {} + /-- memberIdx → Array of cPrj addresses, ordered by cidx -/ + ctorAddrs : Std.HashMap UInt64 (Array Address) := {} + /-- All iPrj addresses in the block (the `all` array for inductives/recursors) -/ + allInductAddrs : Array Address := #[] + /-- memberIdx → primary projection address (for .recur resolution). + iPrj for inductives, dPrj for definitions. -/ + memberAddrs : Std.HashMap UInt64 Address := {} + +/-- Build a BlockIndex from a group of projections. 
-/ +def buildBlockIndex (projections : Array (Address × Constant)) : BlockIndex := Id.run do + let mut inductAddrs : Std.HashMap UInt64 Address := {} + let mut ctorEntries : Std.HashMap UInt64 (Array (UInt64 × Address)) := {} + let mut allInductAddrs : Array Address := #[] + let mut memberAddrs : Std.HashMap UInt64 Address := {} + for (addr, projConst) in projections do + match projConst.info with + | .iPrj prj => + inductAddrs := inductAddrs.insert prj.idx addr + allInductAddrs := allInductAddrs.push addr + memberAddrs := memberAddrs.insert prj.idx addr + | .cPrj prj => + let entries := ctorEntries.getD prj.idx #[] + ctorEntries := ctorEntries.insert prj.idx (entries.push (prj.cidx, addr)) + | .dPrj prj => + memberAddrs := memberAddrs.insert prj.idx addr + | .rPrj prj => + -- Only set if no iPrj/dPrj already set for this member + if !memberAddrs.contains prj.idx then + memberAddrs := memberAddrs.insert prj.idx addr + | _ => pure () + -- Sort constructor entries by cidx and extract just addresses + let mut ctorAddrs : Std.HashMap UInt64 (Array Address) := {} + for (idx, entries) in ctorEntries do + let sorted := entries.insertionSort (fun a b => a.1 < b.1) + ctorAddrs := ctorAddrs.insert idx (sorted.map (·.2)) + { inductAddrs, ctorAddrs, allInductAddrs, memberAddrs } + +/-- All constructor addresses in declaration order (by inductive member index, then cidx). + This matches the order of RecursorVal.rules in the Lean kernel. -/ +def BlockIndex.allCtorAddrsInOrder (bIdx : BlockIndex) : Array Address := Id.run do + let sorted := bIdx.inductAddrs.toArray.insertionSort (fun a b => a.1 < b.1) + let mut result : Array Address := #[] + for (idx, _) in sorted do + result := result ++ (bIdx.ctorAddrs.getD idx #[]) + result + +/-- Build recurAddrs array from BlockIndex. Maps member index → projection address. 
-/ +def buildRecurAddrs (bIdx : BlockIndex) (numMembers : Nat) : Except ConvertError (Array Address) := do + let mut addrs : Array Address := #[] + for i in [:numMembers] do + match bIdx.memberAddrs.get? i.toUInt64 with + | some addr => addrs := addrs.push addr + | none => throw (.missingMemberAddr i numMembers) + return addrs + +/-! ## Projection conversion -/ + +/-- Convert a single projection constant as a ConvertM action. + Uses BlockIndex for cross-references instead of metadata. -/ +def convertProjAction (m : MetaMode) + (addr : Address) (c : Constant) + (blockConst : Constant) (bIdx : BlockIndex) + (name : MetaField m Ix.Name := default) + (levelParams : MetaField m (Array Ix.Name) := default) + (cMeta : ConstantMeta := .empty) + (names : Std.HashMap Address Ix.Name := {}) + : Except String (ConvertM m (Ix.Kernel.ConstantInfo m)) := do + let .muts members := blockConst.info + | .error s!"projection block is not a muts at {addr}" + match c.info with + | .iPrj prj => + if h : prj.idx.toNat < members.size then + match members[prj.idx.toNat] with + | .indc ind => + let ctorAs := bIdx.ctorAddrs.getD prj.idx #[] + let allNs := resolveMetaNames m names (match cMeta with | .indc _ _ _ a _ _ _ => a | _ => #[]) + let ctorNs := resolveMetaNames m names (match cMeta with | .indc _ _ c _ _ _ _ => c | _ => #[]) + .ok (convertInductive m ind ctorAs bIdx.allInductAddrs name levelParams cMeta allNs ctorNs) + | _ => .error s!"iPrj at {addr} does not point to an inductive" + else .error s!"iPrj index out of bounds at {addr}" + | .cPrj prj => + if h : prj.idx.toNat < members.size then + match members[prj.idx.toNat] with + | .indc ind => + if h2 : prj.cidx.toNat < ind.ctors.size then + let ctor := ind.ctors[prj.cidx.toNat] + let inductAddr := bIdx.inductAddrs.getD prj.idx default + let inductNm := resolveMetaName m names (match cMeta with | .ctor _ _ i _ _ => i | _ => default) + .ok (convertConstructor m ctor inductAddr name levelParams cMeta inductNm) + else .error s!"cPrj cidx 
out of bounds at {addr}" + | _ => .error s!"cPrj at {addr} does not point to an inductive" + else .error s!"cPrj index out of bounds at {addr}" + | .rPrj prj => + if h : prj.idx.toNat < members.size then + match members[prj.idx.toNat] with + | .recr r => + let ruleCtorAs := bIdx.allCtorAddrsInOrder + let allNs := resolveMetaNames m names (match cMeta with | .recr _ _ _ a _ _ _ _ => a | _ => #[]) + let metaRules := match cMeta with | .recr _ _ rules _ _ _ _ _ => rules | _ => #[] + let ruleCtorNs := metaRules.map fun x => resolveMetaName m names x + .ok (convertRecursor m r bIdx.allInductAddrs ruleCtorAs name levelParams cMeta allNs ruleCtorNs) + | _ => .error s!"rPrj at {addr} does not point to a recursor" + else .error s!"rPrj index out of bounds at {addr}" + | .dPrj prj => + if h : prj.idx.toNat < members.size then + match members[prj.idx.toNat] with + | .defn d => + let hints := match cMeta with + | .defn _ _ h _ _ _ _ _ => convertHints h + | _ => .opaque + let allNs := resolveMetaNames m names (match cMeta with | .defn _ _ _ a _ _ _ _ => a | _ => #[]) + .ok (convertDefinition m d hints bIdx.allInductAddrs name levelParams cMeta allNs) + | _ => .error s!"dPrj at {addr} does not point to a definition" + else .error s!"dPrj index out of bounds at {addr}" + | _ => .error s!"not a projection at {addr}" + +/-! ## Work items -/ + +/-- An entry to convert: address, constant, name, and metadata. -/ +structure ConvertEntry (m : MetaMode) where + addr : Address + const : Constant + name : MetaField m Ix.Name + constMeta : ConstantMeta + +/-- A work item: either a standalone constant or a complete block group. -/ +inductive WorkItem (m : MetaMode) where + | standalone (entry : ConvertEntry m) + | block (blockAddr : Address) (entries : Array (ConvertEntry m)) + +/-- Extract ctx addresses from ConstantMeta (mutual context for .recur resolution). -/ +def metaCtxAddrs : ConstantMeta → Array Address + | .defn _ _ _ _ ctx .. => ctx + | .indc _ _ _ _ ctx .. 
=> ctx + | .recr _ _ _ _ ctx .. => ctx + | _ => #[] + +/-- Extract parent inductive name-hash address from ConstantMeta (ctor only). -/ +def metaInductAddr : ConstantMeta → Address + | .ctor _ _ induct _ _ => induct + | _ => default + +/-- Resolve ctx name-hash addresses to constant addresses for recurAddrs. -/ +def resolveCtxAddrs (hashToAddr : Std.HashMap Address Address) (ctx : Array Address) + : Except ConvertError (Array Address) := + ctx.mapM fun x => + match hashToAddr.get? x with + | some addr => .ok addr + | none => .error (.unresolvableCtxAddr x) + +/-- Convert a standalone (non-projection) constant. -/ +def convertStandalone (m : MetaMode) (hashToAddr : Std.HashMap Address Address) + (ixonEnv : Ixon.Env) (entry : ConvertEntry m) : + Except String (Option (Ix.Kernel.ConstantInfo m)) := do + let cMeta := entry.constMeta + let recurAddrs ← (resolveCtxAddrs hashToAddr (metaCtxAddrs cMeta)).mapError toString + let lvlNames := resolveLevelParams m ixonEnv.names (metaLvlAddrs cMeta) + let lps := mkLevelParams m ixonEnv.names (metaLvlAddrs cMeta) + let cEnv := mkConvertEnv m entry.const ixonEnv.blobs + (recurAddrs := recurAddrs) (arena := (metaArena cMeta)) (names := ixonEnv.names) + (levelParamNames := lvlNames) + match entry.const.info with + | .defn d => + let hints := match cMeta with + | .defn _ _ h _ _ _ _ _ => convertHints h + | _ => .opaque + let allHashAddrs := match cMeta with + | .defn _ _ _ a _ _ _ _ => a + | _ => #[] + let all := allHashAddrs.map fun x => hashToAddr.getD x x + let allNames := resolveMetaNames m ixonEnv.names allHashAddrs + let ci ← (ConvertM.run cEnv (convertDefinition m d hints all entry.name lps cMeta allNames)).mapError toString + return some ci + | .axio a => + let ci ← (ConvertM.run cEnv (convertAxiom m a entry.name lps cMeta)).mapError toString + return some ci + | .quot q => + let ci ← (ConvertM.run cEnv (convertQuotient m q entry.name lps cMeta)).mapError toString + return some ci + | .recr r => + let pair : Array Address × 
Array Address := match cMeta with + | .recr _ _ rules all _ _ _ _ => (all, rules) + | _ => (#[entry.addr], #[]) + let (metaAll, metaRules) := pair + let all := metaAll.map fun x => hashToAddr.getD x x + let ruleCtorAddrs := metaRules.map fun x => hashToAddr.getD x x + let allNames := resolveMetaNames m ixonEnv.names metaAll + let ruleCtorNames := metaRules.map fun x => resolveMetaName m ixonEnv.names x + let ci ← (ConvertM.run cEnv (convertRecursor m r all ruleCtorAddrs entry.name lps cMeta allNames ruleCtorNames)).mapError toString + return some ci + | .muts _ => return none + | _ => return none -- projections handled separately + +/-- Convert a complete block group (all projections share cache + recurAddrs). -/ +def convertWorkBlock (m : MetaMode) + (ixonEnv : Ixon.Env) (blockAddr : Address) + (entries : Array (ConvertEntry m)) + (results : Array (Address × Ix.Kernel.ConstantInfo m)) (errors : Array (Address × String)) + : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do + let mut results := results + let mut errors := errors + match ixonEnv.getConst? 
blockAddr with + | some blockConst => + -- Dedup projections by address for buildBlockIndex (avoid duplicate allInductAddrs) + let mut canonicalProjs : Array (Address × Constant) := #[] + let mut seenAddrs : Std.HashSet Address := {} + for e in entries do + if !seenAddrs.contains e.addr then + canonicalProjs := canonicalProjs.push (e.addr, e.const) + seenAddrs := seenAddrs.insert e.addr + let bIdx := buildBlockIndex canonicalProjs + let numMembers := match blockConst.info with + | .muts members => members.size + | _ => 0 + let recurAddrs ← match buildRecurAddrs bIdx numMembers with + | .ok addrs => pure addrs + | .error e => + for entry in entries do + errors := errors.push (entry.addr, toString e) + return (results, errors) + -- Base env (no arena/levelParamNames — each projection sets its own) + let baseEnv := mkConvertEnv m blockConst ixonEnv.blobs recurAddrs (names := ixonEnv.names) + let mut state := ConvertState.init baseEnv + let shareCache := match m with | .anon => true | .meta => false + for entry in entries do + if !shareCache then + state := ConvertState.init baseEnv + let cMeta := entry.constMeta + let lvlNames := resolveLevelParams m ixonEnv.names (metaLvlAddrs cMeta) + let lps := mkLevelParams m ixonEnv.names (metaLvlAddrs cMeta) + let cEnv := { baseEnv with arena := (metaArena cMeta), levelParamNames := lvlNames } + match convertProjAction m entry.addr entry.const blockConst bIdx entry.name lps cMeta ixonEnv.names with + | .ok action => + match ConvertM.runWith cEnv state action with + | .ok (ci, state') => + state := state' + results := results.push (entry.addr, ci) + | .error e => + errors := errors.push (entry.addr, toString e) + | .error e => errors := errors.push (entry.addr, e) + | none => + for entry in entries do + errors := errors.push (entry.addr, s!"block not found: {blockAddr}") + (results, errors) + +/-- Convert a chunk of work items. 
-/ +def convertChunk (m : MetaMode) (hashToAddr : Std.HashMap Address Address) + (ixonEnv : Ixon.Env) (chunk : Array (WorkItem m)) + : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do + let mut results : Array (Address × Ix.Kernel.ConstantInfo m) := #[] + let mut errors : Array (Address × String) := #[] + for item in chunk do + match item with + | .standalone entry => + match convertStandalone m hashToAddr ixonEnv entry with + | .ok (some ci) => results := results.push (entry.addr, ci) + | .ok none => pure () + | .error e => errors := errors.push (entry.addr, e) + | .block blockAddr entries => + (results, errors) := convertWorkBlock m ixonEnv blockAddr entries results errors + (results, errors) + +/-! ## Top-level conversion -/ + +/-- Convert an entire Ixon.Env to a Kernel.Env with primitives and quotInit flag. + Iterates named constants first (with full metadata), then picks up anonymous + constants not in named. Groups projections by block and parallelizes. 
-/ +def convertEnv (m : MetaMode) (ixonEnv : Ixon.Env) (numWorkers : Nat := 32) + : Except String (Ix.Kernel.Env m × Primitives × Bool) := + -- Build primitives with quot addresses + let prims : Primitives := Id.run do + let mut p := buildPrimitives + for (addr, c) in ixonEnv.consts do + match c.info with + | .quot q => match q.kind with + | .type => p := { p with quotType := addr } + | .ctor => p := { p with quotCtor := addr } + | .lift => p := { p with quotLift := addr } + | .ind => p := { p with quotInd := addr } + | _ => pure () + return p + let quotInit := Id.run do + for (_, c) in ixonEnv.consts do + if let .quot _ := c.info then return true + return false + let hashToAddr := buildHashToAddr ixonEnv + let (constants, allErrors) := Id.run do + -- Phase 1: Build entries from named constants (have names + metadata) + let mut entries : Array (ConvertEntry m) := #[] + let mut seen : Std.HashSet Address := {} + for (ixName, named) in ixonEnv.named do + let addr := named.addr + match ixonEnv.consts.get? addr with + | some c => + let name := mkMetaName m (some ixName) + entries := entries.push { addr, const := c, name, constMeta := named.constMeta } + seen := seen.insert addr + | none => pure () + -- Phase 2: Pick up anonymous constants not covered by named + for (addr, c) in ixonEnv.consts do + if !seen.contains addr then + entries := entries.push { addr, const := c, name := default, constMeta := .empty } + -- Phase 2.5: In .anon mode, dedup all entries by address (copies identical). + -- In .meta mode, keep all entries (named variants have distinct metadata). 
+ let shouldDedup := match m with | .anon => true | .meta => false + if shouldDedup then + let mut dedupedEntries : Array (ConvertEntry m) := #[] + let mut seenDedup : Std.HashSet Address := {} + for entry in entries do + if !seenDedup.contains entry.addr then + dedupedEntries := dedupedEntries.push entry + seenDedup := seenDedup.insert entry.addr + entries := dedupedEntries + -- Phase 3: Group into standalones and block groups + -- Use (blockAddr, ctxKey) to disambiguate colliding block addresses + let mut standalones : Array (ConvertEntry m) := #[] + -- Pass 1: Build nameHash → ctx map from entries with ctx + let mut nameHashToCtx : Std.HashMap Address (Array Address) := {} + let mut projEntries : Array (Address × ConvertEntry m) := #[] + for entry in entries do + match projBlockAddr entry.const.info with + | some blockAddr => + projEntries := projEntries.push (blockAddr, entry) + let ctx := metaCtxAddrs entry.constMeta + if ctx.size > 0 then + for nameHash in ctx do + nameHashToCtx := nameHashToCtx.insert nameHash ctx + | none => standalones := standalones.push entry + -- Pass 2: Group by (blockAddr, ctxKey) to avoid collisions + let mut blockGroups : Std.HashMap (Address × UInt64) (Array (ConvertEntry m)) := {} + for (blockAddr, entry) in projEntries do + let ctx0 := metaCtxAddrs entry.constMeta + let ctx := if ctx0.size > 0 then ctx0 + else nameHashToCtx.getD (metaInductAddr entry.constMeta) #[] + let ctxKey := hash ctx + let key := (blockAddr, ctxKey) + blockGroups := blockGroups.insert key + ((blockGroups.getD key #[]).push entry) + -- Phase 4: Build work items + let mut workItems : Array (WorkItem m) := #[] + for entry in standalones do + workItems := workItems.push (.standalone entry) + for ((blockAddr, _), blockEntries) in blockGroups do + workItems := workItems.push (.block blockAddr blockEntries) + -- Phase 5: Chunk work items and parallelize + let total := workItems.size + let chunkSize := (total + numWorkers - 1) / numWorkers + let mut tasks : Array 
(Task (Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String))) := #[] + let mut offset := 0 + while offset < total do + let endIdx := min (offset + chunkSize) total + let chunk := workItems[offset:endIdx] + let task := Task.spawn (prio := .dedicated) fun () => + convertChunk m hashToAddr ixonEnv chunk.toArray + tasks := tasks.push task + offset := endIdx + -- Phase 6: Collect results + let mut constants : Ix.Kernel.Env m := default + let mut allErrors : Array (Address × String) := #[] + for task in tasks do + let (chunkResults, chunkErrors) := task.get + for (addr, ci) in chunkResults do + constants := constants.insert addr ci + allErrors := allErrors ++ chunkErrors + (constants, allErrors) + if !allErrors.isEmpty then + let msgs := allErrors[:min 10 allErrors.size].toArray.map fun (addr, e) => s!" {addr}: {e}" + .error s!"conversion errors ({allErrors.size}):\n{"\n".intercalate msgs.toList}" + else + .ok (constants, prims, quotInit) + +/-- Convert an Ixon.Env to a Kernel.Env with full metadata. -/ +def convert (ixonEnv : Ixon.Env) : Except String (Ix.Kernel.Env .meta × Primitives × Bool) := + convertEnv .meta ixonEnv + +/-- Convert an Ixon.Env to a Kernel.Env without metadata. -/ +def convertAnon (ixonEnv : Ixon.Env) : Except String (Ix.Kernel.Env .anon × Primitives × Bool) := + convertEnv .anon ixonEnv + +end Ix.Kernel.Convert diff --git a/Ix/Kernel/Datatypes.lean b/Ix/Kernel/Datatypes.lean new file mode 100644 index 00000000..d94d8701 --- /dev/null +++ b/Ix/Kernel/Datatypes.lean @@ -0,0 +1,181 @@ +/- + Kernel Datatypes: Value, Neutral, SusValue, TypedExpr, Env, TypedConst. + + Closure-based semantic domain for NbE typechecking. + Parameterized over MetaMode for compile-time metadata erasure. +-/ +import Ix.Kernel.Types + +namespace Ix.Kernel + +/-! ## TypeInfo -/ + +inductive TypeInfo (m : MetaMode) where + | unit | proof | none + | sort : Level m → TypeInfo m + deriving Inhabited + +/-! 
## AddInfo -/ + +structure AddInfo (Info Body : Type) where + info : Info + body : Body + deriving Inhabited + +/-! ## Forward declarations for mutual types -/ + +abbrev TypedExpr (m : MetaMode) := AddInfo (TypeInfo m) (Expr m) + +/-! ## Value / Neutral / SusValue -/ + +mutual + inductive Value (m : MetaMode) where + | sort : Level m → Value m + | app : Neutral m → List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (TypeInfo m) → Value m + | lam : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m + → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m + | pi : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m + → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m + | lit : Lean.Literal → Value m + | exception : String → Value m + + inductive Neutral (m : MetaMode) where + | fvar : Nat → MetaField m Ix.Name → Neutral m + | const : Address → Array (Level m) → MetaField m Ix.Name → Neutral m + | proj : Address → Nat → AddInfo (TypeInfo m) (Value m) → MetaField m Ix.Name → Neutral m + + inductive ValEnv (m : MetaMode) where + | mk : List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (Level m) → ValEnv m +end + +instance : Inhabited (Value m) where default := .exception "uninit" +instance : Inhabited (Neutral m) where default := .fvar 0 default +instance : Inhabited (ValEnv m) where default := .mk [] [] + +abbrev SusValue (m : MetaMode) := AddInfo (TypeInfo m) (Thunk (Value m)) + +instance : Inhabited (SusValue m) where + default := .mk default { fn := fun _ => default } + +/-! 
## TypedConst -/ + +inductive TypedConst (m : MetaMode) where + | «axiom» : (type : TypedExpr m) → TypedConst m + | «theorem» : (type value : TypedExpr m) → TypedConst m + | «inductive» : (type : TypedExpr m) → (struct : Bool) → TypedConst m + | «opaque» : (type value : TypedExpr m) → TypedConst m + | definition : (type value : TypedExpr m) → (part : Bool) → TypedConst m + | constructor : (type : TypedExpr m) → (idx fields : Nat) → TypedConst m + | recursor : (type : TypedExpr m) → (params motives minors indices : Nat) → (k : Bool) + → (indAddr : Address) → (rules : Array (Nat × TypedExpr m)) → TypedConst m + | quotient : (type : TypedExpr m) → (kind : QuotKind) → TypedConst m + deriving Inhabited + +def TypedConst.type : TypedConst m → TypedExpr m + | «axiom» type .. + | «theorem» type .. + | «inductive» type .. + | «opaque» type .. + | definition type .. + | constructor type .. + | recursor type .. + | quotient type .. => type + +/-! ## Accessors -/ + +namespace AddInfo + +def expr (t : TypedExpr m) : Expr m := t.body +def thunk (sus : SusValue m) : Thunk (Value m) := sus.body +def get (sus : SusValue m) : Value m := sus.body.get +def getTyped (sus : SusValue m) : AddInfo (TypeInfo m) (Value m) := ⟨sus.info, sus.body.get⟩ +def value (val : AddInfo (TypeInfo m) (Value m)) : Value m := val.body +def sus (val : AddInfo (TypeInfo m) (Value m)) : SusValue m := ⟨val.info, val.body⟩ + +end AddInfo + +/-! ## TypedExpr helpers -/ + +partial def TypedExpr.toImplicitLambda : TypedExpr m → TypedExpr m + | .mk _ (.lam _ body _ _) => toImplicitLambda ⟨default, body⟩ + | x => x + +/-! ## Value helpers -/ + +def Value.neu (n : Neutral m) : Value m := .app n [] [] + +def Value.ctorName : Value m → String + | .sort .. => "sort" + | .app .. => "app" + | .lam .. => "lam" + | .pi .. => "pi" + | .lit .. => "lit" + | .exception .. 
=> "exception" + +def Neutral.summary : Neutral m → String + | .fvar idx name => s!"fvar({name}, {idx})" + | .const addr _ name => s!"const({name}, {addr})" + | .proj _ idx _ name => s!"proj({name}, {idx})" + +def Value.summary : Value m → String + | .sort _ => "Sort" + | .app neu args _ => s!"{neu.summary} applied to {args.length} args" + | .lam .. => "lam" + | .pi .. => "Pi" + | .lit (.natVal n) => s!"natLit({n})" + | .lit (.strVal s) => s!"strLit(\"{s}\")" + | .exception e => s!"exception({e})" + +def TypeInfo.pp : TypeInfo m → String + | .unit => ".unit" + | .proof => ".proof" + | .none => ".none" + | .sort _ => ".sort" + +private def listGetOpt (l : List α) (i : Nat) : Option α := + match l, i with + | [], _ => none + | x :: _, 0 => some x + | _ :: xs, n+1 => listGetOpt xs n + +/-- Deep structural dump (one level into args) for debugging stuck terms. -/ +def Value.dump : Value m → String + | .sort _ => "Sort" + | .app neu args infos => + let argStrs := args.zipIdx.map fun (a, i) => + let info := match listGetOpt infos i with | some ti => TypeInfo.pp ti | none => "?" + s!" [{i}] info={info} val={a.get.summary}" + s!"{neu.summary} applied to {args.length} args:\n" ++ String.intercalate "\n" argStrs + | .lam dom _ _ _ _ => s!"lam(dom={dom.get.summary}, info={dom.info.pp})" + | .pi dom _ _ _ _ => s!"Pi(dom={dom.get.summary}, info={dom.info.pp})" + | .lit (.natVal n) => s!"natLit({n})" + | .lit (.strVal s) => s!"strLit(\"{s}\")" + | .exception e => s!"exception({e})" + +/-! ## ValEnv helpers -/ + +namespace ValEnv + +def exprs : ValEnv m → List (SusValue m) + | .mk es _ => es + +def univs : ValEnv m → List (Level m) + | .mk _ us => us + +def extendWith (env : ValEnv m) (thunk : SusValue m) : ValEnv m := + .mk (thunk :: env.exprs) env.univs + +def withExprs (env : ValEnv m) (exprs : List (SusValue m)) : ValEnv m := + .mk exprs env.univs + +end ValEnv + +/-! 
## Smart constructors -/
+
+def mkConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : Value m := -- constant head applied to zero arguments (a stuck neutral)
+  .neu (.const addr univs name)
+
+def mkSusVar (info : TypeInfo m) (idx : Nat) (name : MetaField m Ix.Name := default) : SusValue m := -- suspended variable: forcing the thunk yields a neutral fvar at `idx`
+  .mk info (.mk fun _ => .neu (.fvar idx name))
+
+end Ix.Kernel
diff --git a/Ix/Kernel/DecompileM.lean b/Ix/Kernel/DecompileM.lean
new file mode 100644
index 00000000..d52bda4a
--- /dev/null
+++ b/Ix/Kernel/DecompileM.lean
@@ -0,0 +1,254 @@
+/-
+  Kernel DecompileM: Kernel.Expr/ConstantInfo → Lean.Expr/ConstantInfo decompilation.
+
+  Used for roundtrip validation: Lean.Environment → Ixon.Env → Kernel.Env → Lean.ConstantInfo.
+  Comparing the roundtripped Lean.ConstantInfo against the original catches conversion bugs.
+-/
+import Ix.Kernel.Types
+
+namespace Ix.Kernel.Decompile
+
+/-! ## Name conversion -/
+
+/-- Convert Ix.Name to Lean.Name by stripping embedded hashes. -/
+def ixNameToLean : Ix.Name → Lean.Name
+  | .anonymous _ => .anonymous
+  | .str parent s _ => .str (ixNameToLean parent) s -- trailing hash component is dropped
+  | .num parent n _ => .num (ixNameToLean parent) n
+
+/-! ## Level conversion -/
+
+/-- Convert a Kernel.Level back to Lean.Level.
+    Level param names are synthetic (`u_0`, `u_1`, ...) since Convert.lean
+    stores `default` for both param names and levelParams. -/
+partial def decompileLevel (levelParams : Array Ix.Name) : Level .meta → Lean.Level
+  | .zero => .zero
+  | .succ l => .succ (decompileLevel levelParams l)
+  | .max l₁ l₂ => .max (decompileLevel levelParams l₁) (decompileLevel levelParams l₂)
+  | .imax l₁ l₂ => .imax (decompileLevel levelParams l₁) (decompileLevel levelParams l₂)
+  | .param idx name =>
+    let ixName := if name != default then name -- prefer a recorded meta name when present
+      else if h : idx < levelParams.size then levelParams[idx]
+      else Ix.Name.mkStr Ix.Name.mkAnon s!"u_{idx}" -- fall back to synthetic `u_{idx}`
+    .param (ixNameToLean ixName)
+
+/-!
## Expression conversion -/ + +@[inline] def kernelExprPtr (e : Expr .meta) : USize := unsafe ptrAddrUnsafe e + +/-- Convert a Kernel.Expr back to Lean.Expr with pointer-based caching. + Known lossy fields: + - `letE.nonDep` is always `true` (lost in Kernel conversion) + - Binder names/info come from metadata (may be `default` if missing) -/ +partial def decompileExprCached (levelParams : Array Ix.Name) (e : Expr .meta) + : StateM (Std.HashMap USize Lean.Expr) Lean.Expr := do + let ptr := kernelExprPtr e + if let some cached := (← get).get? ptr then return cached + let result ← match e with + | .bvar idx _ => pure (.bvar idx) + | .sort lvl => pure (.sort (decompileLevel levelParams lvl)) + | .const _addr levels name => + pure (.const (ixNameToLean name) (levels.toList.map (decompileLevel levelParams))) + | .app fn arg => do + let f ← decompileExprCached levelParams fn + let a ← decompileExprCached levelParams arg + pure (.app f a) + | .lam ty body name bi => do + let t ← decompileExprCached levelParams ty + let b ← decompileExprCached levelParams body + pure (.lam (ixNameToLean name) t b bi) + | .forallE ty body name bi => do + let t ← decompileExprCached levelParams ty + let b ← decompileExprCached levelParams body + pure (.forallE (ixNameToLean name) t b bi) + | .letE ty val body name => do + let t ← decompileExprCached levelParams ty + let v ← decompileExprCached levelParams val + let b ← decompileExprCached levelParams body + pure (.letE (ixNameToLean name) t v b true) + | .lit lit => pure (.lit lit) + | .proj _typeAddr idx struct typeName => do + let s ← decompileExprCached levelParams struct + pure (.proj (ixNameToLean typeName) idx s) + modify (·.insert ptr result) + pure result + +def decompileExpr (levelParams : Array Ix.Name) (e : Expr .meta) : Lean.Expr := + (decompileExprCached levelParams e |>.run {}).1 + +/-! ## ConstantInfo conversion -/ + +/-- Convert Kernel.DefinitionSafety to Lean.DefinitionSafety. 
-/
+def decompileSafety : DefinitionSafety → Lean.DefinitionSafety
+  | .safe => .safe
+  | .unsafe => .unsafe
+  | .partial => .partial
+
+/-- Convert Kernel.ReducibilityHints to Lean.ReducibilityHints. -/
+def decompileHints : ReducibilityHints → Lean.ReducibilityHints
+  | .opaque => .opaque
+  | .abbrev => .abbrev
+  | .regular h => .regular h
+
+/-- Synthetic level params: `[u_0, u_1, ..., u_{n-1}]`. -/
+def syntheticLevelParams (n : Nat) : List Lean.Name :=
+  (List.range n).map fun i => .str .anonymous s!"u_{i}"
+
+/-- Convert a Kernel.ConstantInfo (.meta) back to Lean.ConstantInfo.
+    Name fields are resolved directly from the MetaField name fields
+    on the sub-structures (allNames, ctorNames, inductName, ctorName). -/
+def decompileConstantInfo (ci : ConstantInfo .meta) : Lean.ConstantInfo :=
+  let cv := ci.cv
+  let lps := syntheticLevelParams cv.numLevels -- NOTE(review): synthetic `u_i` names, not the originals
+  let lpArr := cv.levelParams -- Array Ix.Name
+  let decompTy := decompileExpr lpArr cv.type
+  let decompVal (e : Expr .meta) := decompileExpr lpArr e
+  let name := ixNameToLean cv.name
+  match ci with
+  | .axiomInfo v =>
+    .axiomInfo {
+      name, levelParams := lps, type := decompTy, isUnsafe := v.isUnsafe
+    }
+  | .defnInfo v =>
+    .defnInfo {
+      name, levelParams := lps, type := decompTy
+      value := decompVal v.value
+      hints := decompileHints v.hints
+      safety := decompileSafety v.safety
+    }
+  | .thmInfo v =>
+    .thmInfo {
+      name, levelParams := lps, type := decompTy
+      value := decompVal v.value
+    }
+  | .opaqueInfo v =>
+    .opaqueInfo {
+      name, levelParams := lps, type := decompTy
+      value := decompVal v.value, isUnsafe := v.isUnsafe
+    }
+  | .quotInfo v =>
+    let leanKind : Lean.QuotKind := match v.kind with
+      | .type => .type | .ctor => .ctor | .lift => .lift | .ind => .ind
+    .quotInfo {
+      name, levelParams := lps, type := decompTy, kind := leanKind
+    }
+  | .inductInfo v =>
+    .inductInfo {
+      name, levelParams := lps, type := decompTy
+      numParams := v.numParams, numIndices := v.numIndices
+      isRec := v.isRec, isUnsafe := v.isUnsafe, isReflexive := v.isReflexive
+      all := v.allNames.toList.map ixNameToLean
+      ctors := v.ctorNames.toList.map ixNameToLean
+      numNested := v.numNested
+    }
+  | .ctorInfo v =>
+    .ctorInfo {
+      name, levelParams := lps, type := decompTy
+      induct := ixNameToLean v.inductName
+      cidx := v.cidx, numParams := v.numParams, numFields := v.numFields
+      isUnsafe := v.isUnsafe
+    }
+  | .recInfo v =>
+    .recInfo {
+      name, levelParams := lps, type := decompTy
+      all := v.allNames.toList.map ixNameToLean
+      numParams := v.numParams, numIndices := v.numIndices
+      numMotives := v.numMotives, numMinors := v.numMinors
+      k := v.k, isUnsafe := v.isUnsafe
+      rules := v.rules.toList.map fun r => { -- rebuild Lean recursor rules from kernel rules
+        ctor := ixNameToLean r.ctorName
+        nfields := r.nfields
+        rhs := decompVal r.rhs
+      }
+    }
+
+/-! ## Structural comparison -/
+
+@[inline] def leanExprPtr (e : Lean.Expr) : USize := unsafe ptrAddrUnsafe e
+
+structure ExprPtrPair where
+  a : USize
+  b : USize
+  deriving Hashable, BEq
+
+/-- Compare two Lean.Exprs structurally, ignoring binder names and binder info.
+    Uses pointer-pair caching to avoid exponential blowup on shared subexpressions.
+    Returns `none` if structurally equal, `some (path, lhs, rhs)` on first mismatch.
-/ +partial def exprStructEq (a b : Lean.Expr) (path : String := "") + : StateM (Std.HashSet ExprPtrPair) (Option (String × String × String)) := do + let ptrA := leanExprPtr a + let ptrB := leanExprPtr b + if ptrA == ptrB then return none + let pair := ExprPtrPair.mk ptrA ptrB + if (← get).contains pair then return none + let result ← match a, b with + | .bvar i, .bvar j => + pure (if i == j then none else some (path, s!"bvar({i})", s!"bvar({j})")) + | .sort l₁, .sort l₂ => + pure (if Lean.Level.isEquiv l₁ l₂ then none else some (path, s!"sort", s!"sort")) + | .const n₁ ls₁, .const n₂ ls₂ => + pure (if n₁ != n₂ then some (path, s!"const({n₁})", s!"const({n₂})") + else if ls₁.length != ls₂.length then + some (path, s!"const({n₁}) {ls₁.length} lvls", s!"const({n₂}) {ls₂.length} lvls") + else none) + | .app f₁ a₁, .app f₂ a₂ => do + match ← exprStructEq f₁ f₂ (path ++ ".app.fn") with + | some m => pure (some m) + | none => exprStructEq a₁ a₂ (path ++ ".app.arg") + | .lam _ t₁ b₁ _, .lam _ t₂ b₂ _ => do + match ← exprStructEq t₁ t₂ (path ++ ".lam.ty") with + | some m => pure (some m) + | none => exprStructEq b₁ b₂ (path ++ ".lam.body") + | .forallE _ t₁ b₁ _, .forallE _ t₂ b₂ _ => do + match ← exprStructEq t₁ t₂ (path ++ ".pi.ty") with + | some m => pure (some m) + | none => exprStructEq b₁ b₂ (path ++ ".pi.body") + | .letE _ t₁ v₁ b₁ _, .letE _ t₂ v₂ b₂ _ => do + match ← exprStructEq t₁ t₂ (path ++ ".let.ty") with + | some m => pure (some m) + | none => match ← exprStructEq v₁ v₂ (path ++ ".let.val") with + | some m => pure (some m) + | none => exprStructEq b₁ b₂ (path ++ ".let.body") + | .lit l₁, .lit l₂ => + pure (if l₁ == l₂ then none + else + let showLit : Lean.Literal → String + | .natVal n => s!"natLit({n})" + | .strVal s => s!"strLit({s})" + some (path, showLit l₁, showLit l₂)) + | .proj t₁ i₁ s₁, .proj t₂ i₂ s₂ => + if t₁ != t₂ then pure (some (path, s!"proj({t₁}.{i₁})", s!"proj({t₂}.{i₂})")) + else if i₁ != i₂ then pure (some (path, s!"proj.idx({i₁})", 
s!"proj.idx({i₂})")) + else exprStructEq s₁ s₂ (path ++ ".proj.struct") + | .mdata _ e₁, _ => exprStructEq e₁ b path + | _, .mdata _ e₂ => exprStructEq a e₂ path + | _, _ => + let tag (e : Lean.Expr) : String := match e with + | .bvar _ => "bvar" | .sort _ => "sort" | .const .. => "const" + | .app .. => "app" | .lam .. => "lam" | .forallE .. => "forallE" + | .letE .. => "letE" | .lit .. => "lit" | .proj .. => "proj" + | .fvar .. => "fvar" | .mvar .. => "mvar" | .mdata .. => "mdata" + pure (some (path, tag a, tag b)) + if result.isNone then modify (·.insert pair) + pure result + +/-- Compare two Lean.ConstantInfos structurally. Returns list of mismatches. -/ +def constInfoStructEq (a b : Lean.ConstantInfo) + : Array (String × String × String) := + let check : StateM (Std.HashSet ExprPtrPair) (Array (String × String × String)) := do + let mut mismatches : Array (String × String × String) := #[] + -- Compare types + if let some m ← exprStructEq a.type b.type "type" then + mismatches := mismatches.push m + -- Compare values if both have them + match a.value?, b.value? with + | some va, some vb => + if let some m ← exprStructEq va vb "value" then + mismatches := mismatches.push m + | none, some _ => mismatches := mismatches.push ("value", "none", "some") + | some _, none => mismatches := mismatches.push ("value", "some", "none") + | none, none => pure () + return mismatches + (check.run {}).1 + +end Ix.Kernel.Decompile diff --git a/Ix/Kernel/Equal.lean b/Ix/Kernel/Equal.lean new file mode 100644 index 00000000..4f219b7c --- /dev/null +++ b/Ix/Kernel/Equal.lean @@ -0,0 +1,168 @@ +/- + Kernel Equal: Definitional equality checking. + + Handles proof irrelevance, unit types, eta expansion. + In NbE, all non-partial definitions are eagerly unfolded by `eval`, so there + is no lazy delta reduction here — different const-headed values are genuinely + unequal (they are stuck constructors, recursors, axioms, or partial defs). 
+ Adapted from Yatima.Typechecker.Equal, parameterized over MetaMode. +-/ +import Ix.Kernel.Eval + +namespace Ix.Kernel + +/-- Pointer equality on thunks: if two thunks share the same pointer, they must + produce the same value. Returns false conservatively when pointers differ. -/ +@[inline] private def susValuePtrEq (a b : SusValue m) : Bool := + unsafe ptrAddrUnsafe a.body == ptrAddrUnsafe b.body + +/-- Compare two arrays of levels for equality. -/ +private def equalUnivArrays (us us' : Array (Level m)) : Bool := + us.size == us'.size && Id.run do + let mut i := 0 + while i < us.size do + if !Level.equalLevel us[i]! us'[i]! then return false + i := i + 1 + return true + +/-- Construct a canonicalized cache key for two SusValues using their pointer addresses. + The smaller pointer always comes first, making the key symmetric: key(a,b) == key(b,a). -/ +@[inline] private def susValueCacheKey (a b : SusValue m) : USize × USize := + let pa := unsafe ptrAddrUnsafe a.body + let pb := unsafe ptrAddrUnsafe b.body + if pa ≤ pb then (pa, pb) else (pb, pa) + +mutual + /-- Try eta expansion for structure-like types. -/ + partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := do + match term'.get with + | .app (.const k _ _) args _ => + match (← get).typedConsts.get? k with + | some (.constructor type ..) => + match ← applyType (← eval type) args with + | .app (.const tk _ _) targs _ => + match (← get).typedConsts.get? tk with + | some (.inductive _ struct ..) 
=> + -- Skip struct eta for Prop types (proof irrelevance handles them) + let isProp := match term'.info with | .proof => true | _ => false + if struct && !isProp then + targs.zipIdx.foldlM (init := true) fun acc (arg, i) => do + match arg.get with + | .app (.proj _ idx val _) _ _ => + pure (acc && i == idx && (← equal lvl term val.sus)) + | _ => pure false + else pure false + | _ => pure false + | _ => pure false + | _ => pure false + | _ => pure false + + /-- Check if two suspended values are definitionally equal at the given level. + Assumes both have the same type and live in the same context. -/ + partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := + match term.info, term'.info with + | .unit, .unit => pure true + | .proof, .proof => pure true + | _, _ => withFuelCheck do + if (← read).trace then dbg_trace s!"equal: {term.get.ctorName} vs {term'.get.ctorName}" + -- Fast path: pointer equality on thunks + if susValuePtrEq term term' then return true + -- Check equality cache + let key := susValueCacheKey term term' + if let some true := (← get).equalCache.get? 
key then return true
+      let tv := term.get
+      let tv' := term'.get
+      let result ← match tv, tv' with
+        | .lit lit, .lit lit' => pure (lit == lit')
+        | .sort u, .sort u' => pure (Level.equalLevel u u')
+        | .pi dom img env _ _, .pi dom' img' env' _ _ => do
+          let res ← equal lvl dom dom'
+          let ctx ← read
+          let stt ← get
+          let img := suspend img { ctx with env := env.extendWith (mkSusVar dom.info lvl) } stt
+          let img' := suspend img' { ctx with env := env'.extendWith (mkSusVar dom'.info lvl) } stt
+          let res' ← equal (lvl + 1) img img'
+          if !res' && (← read).trace then -- debug output gated on the trace flag, like every other dbg_trace here
+            dbg_trace s!"equal Pi images FAILED at lvl={lvl}: lhs={img.get.dump} rhs={img'.get.dump}"
+          pure (res && res')
+        | .lam dom bod env _ _, .lam dom' bod' env' _ _ => do
+          let res ← equal lvl dom dom'
+          let ctx ← read
+          let stt ← get
+          let bod := suspend bod { ctx with env := env.extendWith (mkSusVar dom.info lvl) } stt
+          let bod' := suspend bod' { ctx with env := env'.extendWith (mkSusVar dom'.info lvl) } stt
+          let res' ← equal (lvl + 1) bod bod'
+          pure (res && res')
+        | .lam dom bod env _ _, .app neu' args' infos' => do -- eta: λ-body vs neutral applied to a fresh var
+          let var := mkSusVar dom.info lvl
+          let ctx ← read
+          let stt ← get
+          let bod := suspend bod { ctx with env := env.extendWith var } stt
+          let app := Value.app neu' (var :: args') (term'.info :: infos')
+          equal (lvl + 1) bod (.mk bod.info app)
+        | .app neu args infos, .lam dom bod env _ _ => do -- eta, symmetric case
+          let var := mkSusVar dom.info lvl
+          let ctx ← read
+          let stt ← get
+          let bod := suspend bod { ctx with env := env.extendWith var } stt
+          let app := Value.app neu (var :: args) (term.info :: infos)
+          equal (lvl + 1) (.mk bod.info app) bod
+        | .app (.fvar idx _) args _, .app (.fvar idx' _) args' _ =>
+          if idx == idx' then equalThunks lvl args args'
+          else pure false
+        | .app (.const k us _) args _, .app (.const k' us' _) args' _ =>
+          if k == k' && equalUnivArrays us us' then
+            equalThunks lvl args args'
+          else
+            -- In NbE, eval eagerly unfolds all non-partial definitions.
+            -- Different const heads here are stuck terms that can't reduce further.
+            pure false
+        -- Nat literal vs constructor expansion
+        | .lit (.natVal _), .app (.const _ _ _) _ _ => do
+          let prims := (← read).prims
+          let expanded ← toCtorIfLit prims tv
+          equal lvl (.mk term.info (.mk fun _ => expanded)) term'
+        | .app (.const _ _ _) _ _, .lit (.natVal _) => do
+          let prims := (← read).prims
+          let expanded ← toCtorIfLit prims tv'
+          equal lvl term (.mk term'.info (.mk fun _ => expanded))
+        -- String literal vs constructor expansion
+        | .lit (.strVal _), .app (.const _ _ _) _ _ => do
+          let prims := (← read).prims
+          let expanded ← strLitToCtorVal prims (match tv with | .lit (.strVal s) => s | _ => "")
+          equal lvl (.mk term.info (.mk fun _ => expanded)) term'
+        | .app (.const _ _ _) _ _, .lit (.strVal _) => do
+          let prims := (← read).prims
+          let expanded ← strLitToCtorVal prims (match tv' with | .lit (.strVal s) => s | _ => "")
+          equal lvl term (.mk term'.info (.mk fun _ => expanded))
+        | _, .app (.const _ _ _) _ _ => -- struct eta: rhs may be a constructor application
+          tryEtaStruct lvl term term'
+        | .app (.const _ _ _) _ _, _ => -- struct eta, symmetric case
+          tryEtaStruct lvl term' term
+        | .app (.proj ind idx val _) args _, .app (.proj ind' idx' val' _) args' _ =>
+          if ind == ind' && idx == idx' then do
+            let eqVal ← equal lvl val.sus val'.sus
+            let eqThunks ← equalThunks lvl args args'
+            pure (eqVal && eqThunks)
+          else pure false
+        | .exception e, _ | _, .exception e =>
+          throw s!"exception in equal: {e}"
+        | _, _ => do -- structurally different heads: definitely unequal
+          if (← read).trace then dbg_trace s!"equal FALLTHROUGH at lvl={lvl}: lhs={tv.dump} rhs={tv'.dump}"
+          pure false
+      if result then
+        modify fun stt => { stt with equalCache := stt.equalCache.insert key true }
+      return result
+
+  /-- Check if two lists of suspended values are pointwise equal.
-/ + partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m Bool := + match vals, vals' with + | val :: vals, val' :: vals' => do + let eq ← equal lvl val val' + let eq' ← equalThunks lvl vals vals' + pure (eq && eq') + | [], [] => pure true + | _, _ => pure false +end + +end Ix.Kernel diff --git a/Ix/Kernel/Eval.lean b/Ix/Kernel/Eval.lean new file mode 100644 index 00000000..9fa74125 --- /dev/null +++ b/Ix/Kernel/Eval.lean @@ -0,0 +1,530 @@ +/- + Kernel Eval: Expression evaluation, constant/recursor/quot/nat reduction. + + Adapted from Yatima.Typechecker.Eval, parameterized over MetaMode. +-/ +import Ix.Kernel.TypecheckM + +namespace Ix.Kernel + +open Level (instBulkReduce reduceIMax) + +def TypeInfo.update (univs : Array (Level m)) : TypeInfo m → TypeInfo m + | .sort lvl => .sort (instBulkReduce univs lvl) + | .unit => .unit + | .proof => .proof + | .none => .none + +/-! ## Helpers (needed by mutual block) -/ + +/-- Check if an address is a primitive operation that takes arguments. -/ +private def isPrimOp (prims : Primitives) (addr : Address) : Bool := + addr == prims.natAdd || addr == prims.natSub || addr == prims.natMul || + addr == prims.natPow || addr == prims.natGcd || addr == prims.natMod || + addr == prims.natDiv || addr == prims.natBeq || addr == prims.natBle || + addr == prims.natLand || addr == prims.natLor || addr == prims.natXor || + addr == prims.natShiftLeft || addr == prims.natShiftRight || + addr == prims.natSucc + +/-- Look up element in a list by index. -/ +def listGet? (l : List α) (n : Nat) : Option α := + match l, n with + | [], _ => none + | a :: _, 0 => some a + | _ :: l, n+1 => listGet? l n + +/-- Try to reduce a primitive operation if all arguments are available. 
-/
+private def tryPrimOp (prims : Primitives) (addr : Address)
+    (args : List (SusValue m)) : TypecheckM m (Option (Value m)) := do
+  -- Nat.succ: 1 arg
+  if addr == prims.natSucc then
+    if args.length >= 1 then
+      match args.head!.get with
+      | .lit (.natVal n) => return some (.lit (.natVal (n + 1)))
+      | _ => return none
+    else return none
+  -- Binary nat operations: 2 args
+  else if args.length >= 2 then
+    let a := args[1]!.get -- first operand: args are newest-first (see `applyConst`'s `arg :: args`), so `args[1]` is `x` in `f x y`
+    let b := args[0]!.get -- second operand (`y`); order matters for sub/div/mod/pow/ble/shifts
+    match a, b with
+    | .lit (.natVal x), .lit (.natVal y) =>
+      if addr == prims.natAdd then return some (.lit (.natVal (x + y)))
+      else if addr == prims.natSub then return some (.lit (.natVal (x - y))) -- Nat subtraction truncates at zero
+      else if addr == prims.natMul then return some (.lit (.natVal (x * y)))
+      else if addr == prims.natPow then
+        if y > 16777216 then return none -- refuse huge exponents: leave the term stuck instead of hanging
+        return some (.lit (.natVal (Nat.pow x y)))
+      else if addr == prims.natMod then return some (.lit (.natVal (x % y)))
+      else if addr == prims.natDiv then return some (.lit (.natVal (x / y)))
+      else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y)))
+      else if addr == prims.natBeq then
+        let boolAddr := if x == y then prims.boolTrue else prims.boolFalse
+        let boolName ← lookupName boolAddr
+        return some (mkConst boolAddr #[] boolName)
+      else if addr == prims.natBle then
+        let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse
+        let boolName ← lookupName boolAddr
+        return some (mkConst boolAddr #[] boolName)
+      else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y)))
+      else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y)))
+      else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y)))
+      else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y)))
+      else if addr == prims.natShiftRight then return some (.lit (.natVal (Nat.shiftRight x y)))
+      else return none
+    | _, _ => return none
+  else return none
+
+/-- Expand a string literal to its constructor form: String.mk
(list-of-chars). + Each character is represented as Char.ofNat n, and the list uses + List.cons/List.nil at universe level 0. -/ +def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) := do + let charMkName ← lookupName prims.charMk + let charName ← lookupName prims.char + let listNilName ← lookupName prims.listNil + let listConsName ← lookupName prims.listCons + let stringMkName ← lookupName prims.stringMk + let mkCharOfNat (c : Char) : SusValue m := + ⟨.none, .mk fun _ => + Value.app (.const prims.charMk #[] charMkName) + [⟨.none, .mk fun _ => .lit (.natVal c.toNat)⟩] [.none]⟩ + let charType : SusValue m := + ⟨.none, .mk fun _ => Value.neu (.const prims.char #[] charName)⟩ + let nilVal : Value m := + Value.app (.const prims.listNil #[.zero] listNilName) [charType] [.none] + let listVal := s.toList.foldr (fun c acc => + let tail : SusValue m := ⟨.none, .mk fun _ => acc⟩ + let head := mkCharOfNat c + Value.app (.const prims.listCons #[.zero] listConsName) + [tail, head, charType] [.none, .none, .none] + ) nilVal + let data : SusValue m := ⟨.none, .mk fun _ => listVal⟩ + pure (Value.app (.const prims.stringMk #[] stringMkName) [data] [.none]) + +/-! ## Eval / Apply mutual block -/ + +mutual + /-- Evaluate a typed expression to a value. -/ + partial def eval (t : TypedExpr m) : TypecheckM m (Value m) := withFuelCheck do + if (← read).trace then dbg_trace s!"eval: {t.body.tag}" + match t.body with + | .app fnc arg => do + let ctx ← read + let stt ← get + let argThunk := suspend ⟨default, arg⟩ ctx stt + let fnc ← evalTyped ⟨default, fnc⟩ + try apply fnc argThunk + catch e => + throw s!"{e}\n in app: ({fnc.body.summary}) applied to ({arg.pp})" + | .lam ty body name bi => do + let ctx ← read + let stt ← get + let dom := suspend ⟨default, ty⟩ ctx stt + pure (.lam dom ⟨default, body⟩ ctx.env name bi) + | .bvar idx _ => do + let some thunk := listGet? 
(← read).env.exprs idx + | throw s!"Index {idx} is out of range for expression environment" + pure thunk.get + | .const addr levels name => do + let env := (← read).env + let levels := levels.map (instBulkReduce env.univs.toArray) + try evalConst addr levels name + catch e => + let nameStr := match (← read).kenv.find? addr with + | some c => s!"{c.cv.name}" | none => s!"{addr}" + throw s!"{e}\n in eval const {nameStr}" + | .letE _ val body _ => do + let ctx ← read + let stt ← get + let thunk := suspend ⟨default, val⟩ ctx stt + withExtendedEnv thunk (eval ⟨default, body⟩) + | .forallE ty body name bi => do + let ctx ← read + let stt ← get + let dom := suspend ⟨default, ty⟩ ctx stt + pure (.pi dom ⟨default, body⟩ ctx.env name bi) + | .sort univ => do + let env := (← read).env + pure (.sort (instBulkReduce env.univs.toArray univ)) + | .lit lit => + pure (.lit lit) + | .proj typeAddr idx struct typeName => do + let raw ← eval ⟨default, struct⟩ + -- Expand string literals to constructor form before projecting + let val ← match raw with + | .lit (.strVal s) => strLitToCtorVal (← read).prims s + | v => pure v + match val with + | .app (.const ctorAddr _ _) args _ => + let ctx ← read + match ctx.kenv.find? ctorAddr with + | some (.ctorInfo v) => + let idx := v.numParams + idx + let some arg := listGet? 
args.reverse idx + | throw s!"Invalid projection of index {idx} but constructor has only {args.length} arguments" + pure arg.get + | _ => do + let ti := TypeInfo.update (← read).env.univs.toArray (default : TypeInfo m) + pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] []) + | .app _ _ _ => do + let ti := TypeInfo.update (← read).env.univs.toArray (default : TypeInfo m) + pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] []) + | e => throw s!"Value is impossible to project: {e.ctorName}" + + partial def evalTyped (t : TypedExpr m) : TypecheckM m (AddInfo (TypeInfo m) (Value m)) := do + let reducedInfo := t.info.update (← read).env.univs.toArray + let value ← eval t + pure ⟨reducedInfo, value⟩ + + /-- Evaluate a constant that is not a primitive. + Theorems are treated as opaque (not unfolded) — proof irrelevance handles + equality of proof terms, and this avoids deep recursion through proof bodies. + Caches evaluated definition bodies to avoid redundant evaluation. -/ + partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + match (← read).kenv.find? addr with + | some (.defnInfo _) => + -- Check eval cache (must also match universe parameters) + if let some (cachedUnivs, cachedVal) := (← get).evalCache.get? addr then + if cachedUnivs == univs then return cachedVal + ensureTypedConst addr + match (← get).typedConsts.get? addr with + | some (.definition _ deref part) => + if part then pure (mkConst addr univs name) + else + let val ← withEnv (.mk [] univs.toList) (eval deref) + modify fun stt => { stt with evalCache := stt.evalCache.insert addr (univs, val) } + pure val + | _ => throw "Invalid const kind for evaluation" + | _ => pure (mkConst addr univs name) + + /-- Evaluate a constant: check if it's Nat.zero, a primitive op, or unfold it. 
-/ + partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + let prims := (← read).prims + if addr == prims.natZero then pure (.lit (.natVal 0)) + else if isPrimOp prims addr then pure (mkConst addr univs name) + else evalConst' addr univs name + + /-- Create a suspended value from a typed expression, capturing context. -/ + partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m) (stt : TypecheckState m) : SusValue m := + let thunk : Thunk (Value m) := .mk fun _ => + match TypecheckM.run ctx stt (eval expr) with + | .ok a => a + | .error e => .exception e + let reducedInfo := expr.info.update ctx.env.univs.toArray + ⟨reducedInfo, thunk⟩ + + /-- Apply a value to an argument. -/ + partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m (Value m) := do + if (← read).trace then dbg_trace s!"apply: {val.body.ctorName}" + match val.body with + | .lam _ bod lamEnv _ _ => + withNewExtendedEnv lamEnv arg (eval bod) + | .pi dom img piEnv _ _ => + -- Propagate TypeInfo: if domain is Prop, argument is a proof + let enrichedArg : SusValue m := match arg.info, dom.info with + | .none, .sort (.zero) => ⟨.proof, arg.body⟩ + | _, _ => arg + withNewExtendedEnv piEnv enrichedArg (eval img) + | .app (.const addr univs name) args infos => applyConst addr univs arg args val.info infos name + | .app neu args infos => pure (.app neu (arg :: args) (val.info :: infos)) + | v => + throw s!"Invalid case for apply: got {v.ctorName} ({v.summary})" + + /-- Apply a named constant to arguments, handling recursors, quotients, and primitives. 
-/ + partial def applyConst (addr : Address) (univs : Array (Level m)) (arg : SusValue m) + (args : List (SusValue m)) (info : TypeInfo m) (infos : List (TypeInfo m)) + (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + let prims := (← read).prims + -- Try primitive operations + if let some result ← tryPrimOp prims addr (arg :: args) then + return result + + ---- Try recursor/quotient (ensure provisional entry exists for eval-time lookups) + ensureTypedConst addr + match (← get).typedConsts.get? addr with + | some (.recursor _ params motives minors indices isK indAddr rules) => + let majorIdx := params + motives + minors + indices + if args.length != majorIdx then + pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + else if isK then + -- K-reduce when major is a constructor, or shortcut via proof irrelevance + let isKCtor ← match ← toCtorIfLit prims (arg.get) with + | .app (.const ctorAddr _ _) _ _ => + match (← get).typedConsts.get? ctorAddr with + | some (.constructor ..) => pure true + | _ => match (← read).kenv.find? ctorAddr with + | some (.ctorInfo _) => pure true + | _ => pure false + | _ => pure false + -- Also check if the inductive lives in Prop, since eval doesn't track TypeInfo + let isPropInd := match (← read).kenv.find? indAddr with + | some (.inductInfo v) => + let rec getSort : Expr m → Bool + | .forallE _ body _ _ => getSort body + | .sort (.zero) => true + | _ => false + getSort v.type + | _ => false + if isKCtor || isPropInd || (match arg.info with | .proof => true | _ => false) then + let nArgs := args.length + let nDrop := params + motives + 1 + if nArgs < nDrop then throw s!"Too few arguments ({nArgs}). At least {nDrop} needed" + let minorIdx := nArgs - nDrop + let some minor := listGet? 
args minorIdx | throw s!"Index {minorIdx} is out of range" + pure minor.get + else + pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + else + -- Skip Nat.rec reduction on large literals to avoid O(n) eval overhead + let skipLargeNat := match arg.get with + | .lit (.natVal n) => indAddr == prims.nat && n > 256 + | _ => false + if skipLargeNat then + pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + else + match ← toCtorIfLit prims (arg.get) with + | .app (.const ctorAddr _ _) ctorArgs _ => + let st ← get + let ctx ← read + let ctorInfo? := match st.typedConsts.get? ctorAddr with + | some (.constructor _ ctorIdx numFields) => some (ctorIdx, numFields) + | _ => match ctx.kenv.find? ctorAddr with + | some (.ctorInfo cv) => some (cv.cidx, cv.numFields) + | _ => none + match ctorInfo? with + | some (ctorIdx, _) => + match rules[ctorIdx]? with + | some (fields, rhs) => + let exprs := (ctorArgs.take fields) ++ (args.drop indices) + withEnv (.mk exprs univs.toList) (eval rhs.toImplicitLambda) + | none => throw s!"Constructor has no associated recursion rule" + | none => pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + | _ => + -- Structure eta: expand struct-like major via projections + let kenv := (← read).kenv + let doStructEta := match arg.info with + | .proof => false + | _ => kenv.isStructureLike indAddr + if doStructEta then + match rules[0]? 
with + | some (fields, rhs) => + let mut projArgs : List (SusValue m) := [] + for i in [:fields] do + let proj : SusValue m := ⟨.none, .mk fun _ => + Value.app (.proj indAddr i ⟨.none, arg.get⟩ default) [] []⟩ + projArgs := proj :: projArgs + let exprs := projArgs ++ (args.drop indices) + withEnv (.mk exprs univs.toList) (eval rhs.toImplicitLambda) + | none => pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + else + pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + | some (.quotient _ kind) => match kind with + | .lift => applyQuot prims arg args 6 1 (.app (.const addr univs name) (arg :: args) (info :: infos)) + | .ind => applyQuot prims arg args 5 0 (.app (.const addr univs name) (arg :: args) (info :: infos)) + | _ => pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + | _ => pure (.app (.const addr univs name) (arg :: args) (info :: infos)) + + /-- Apply a quotient to a value. -/ + partial def applyQuot (_prims : Primitives) (major : SusValue m) (args : List (SusValue m)) + (reduceSize argPos : Nat) (default : Value m) : TypecheckM m (Value m) := + let argsLength := args.length + 1 + if argsLength == reduceSize then + match major.get with + | .app (.const majorFn _ _) majorArgs _ => do + match (← get).typedConsts.get? majorFn with + | some (.quotient _ .ctor) => + if majorArgs.length != 3 then throw "majorArgs should have size 3" + let some majorArg := majorArgs.head? | throw "majorArgs can't be empty" + let some head := listGet? args argPos | throw s!"{argPos} is an invalid index for args" + apply head.getTyped majorArg + | _ => pure default + | _ => pure default + else if argsLength < reduceSize then pure default + else throw s!"argsLength {argsLength} can't be greater than reduceSize {reduceSize}" + + /-- Convert a nat literal to Nat.succ/Nat.zero constructors. 
-/ + partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m (Value m) + | .lit (.natVal 0) => do + let name ← lookupName prims.natZero + pure (Value.neu (.const prims.natZero #[] name)) + | .lit (.natVal (n+1)) => do + let name ← lookupName prims.natSucc + let thunk : SusValue m := ⟨.none, Thunk.mk fun _ => .lit (.natVal n)⟩ + pure (.app (.const prims.natSucc #[] name) [thunk] [.none]) + | v => pure v +end + +/-! ## Quoting (read-back from Value to Expr) -/ + +mutual + partial def quote (lvl : Nat) : Value m → TypecheckM m (Expr m) + | .sort univ => do + let env := (← read).env + pure (.sort (instBulkReduce env.univs.toArray univ)) + | .app neu args infos => do + let argsInfos := args.zip infos + argsInfos.foldrM (init := ← quoteNeutral lvl neu) fun (arg, _info) acc => do + let argExpr ← quoteTyped lvl arg.getTyped + pure (.app acc argExpr.body) + | .lam dom bod env name bi => do + let dom ← quoteTyped lvl dom.getTyped + let var := mkSusVar (default : TypeInfo m) lvl name + let bod ← quoteTypedExpr (lvl+1) bod (env.extendWith var) + pure (.lam dom.body bod.body name bi) + | .pi dom img env name bi => do + let dom ← quoteTyped lvl dom.getTyped + let var := mkSusVar (default : TypeInfo m) lvl name + let img ← quoteTypedExpr (lvl+1) img (env.extendWith var) + pure (.forallE dom.body img.body name bi) + | .lit lit => pure (.lit lit) + | .exception e => throw e + + partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m (TypedExpr m) := do + pure ⟨val.info, ← quote lvl val.body⟩ + + partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m (TypedExpr m) := do + let e ← quoteExpr lvl t.body env + pure ⟨t.info, e⟩ + + partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m (Expr m) := + match expr with + | .bvar idx _ => do + match listGet? 
env.exprs idx with + | some val => quote lvl val.get + | none => throw s!"Unbound variable _@{idx}" + | .app fnc arg => do + let fnc ← quoteExpr lvl fnc env + let arg ← quoteExpr lvl arg env + pure (.app fnc arg) + | .lam ty body n bi => do + let ty ← quoteExpr lvl ty env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.lam ty body n bi) + | .forallE ty body n bi => do + let ty ← quoteExpr lvl ty env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.forallE ty body n bi) + | .letE ty val body n => do + let ty ← quoteExpr lvl ty env + let val ← quoteExpr lvl val env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.letE ty val body n) + | .const addr levels name => + pure (.const addr (levels.map (instBulkReduce env.univs.toArray)) name) + | .sort univ => + pure (.sort (instBulkReduce env.univs.toArray univ)) + | .proj typeAddr idx struct name => do + let struct ← quoteExpr lvl struct env + pure (.proj typeAddr idx struct name) + | .lit .. => pure expr + + partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m (Expr m) + | .fvar idx name => do + pure (.bvar (lvl - idx - 1) name) + | .const addr univs name => do + let env := (← read).env + pure (.const addr (univs.map (instBulkReduce env.univs.toArray)) name) + | .proj typeAddr idx val name => do + let te ← quoteTyped lvl val + pure (.proj typeAddr idx te.body name) +end + +/-! ## Literal folding for pretty printing -/ + +/-- Try to extract a Char from a Char.ofNat application in an Expr. -/ +private partial def tryFoldChar (prims : Primitives) (e : Expr m) : Option Char := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.charMk then + let args := e.getAppArgs + if args.size == 1 then + match args[0]! 
with + | .lit (.natVal n) => some (Char.ofNat n) + | _ => none + else none + else none + | _ => none + +/-- Try to extract a List Char from a List.cons/List.nil chain in an Expr. -/ +private partial def tryFoldCharList (prims : Primitives) (e : Expr m) : Option (List Char) := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.listNil then some [] + else if addr == prims.listCons then + let args := e.getAppArgs + -- args = [type, head, tail] + if args.size == 3 then + match tryFoldChar prims args[1]!, tryFoldCharList prims args[2]! with + | some c, some cs => some (c :: cs) + | _, _ => none + else none + else none + | _ => none + +/-- Walk an Expr and fold Nat.zero/Nat.succ chains to nat literals, + and String.mk (char list) to string literals. -/ +partial def foldLiterals (prims : Primitives) : Expr m → Expr m + | .const addr lvls name => + if addr == prims.natZero then .lit (.natVal 0) + else .const addr lvls name + | .app fn arg => + let fn' := foldLiterals prims fn + let arg' := foldLiterals prims arg + let e := Expr.app fn' arg' + -- Try folding the fully-reconstructed app + match e.getAppFn with + | .const addr _ _ => + if addr == prims.natSucc && e.getAppNumArgs == 1 then + match e.appArg! with + | .lit (.natVal n) => .lit (.natVal (n + 1)) + | _ => e + else if addr == prims.stringMk && e.getAppNumArgs == 1 then + match tryFoldCharList prims e.appArg! with + | some cs => .lit (.strVal (String.ofList cs)) + | none => e + else e + | _ => e + | .lam ty body n bi => + .lam (foldLiterals prims ty) (foldLiterals prims body) n bi + | .forallE ty body n bi => + .forallE (foldLiterals prims ty) (foldLiterals prims body) n bi + | .letE ty val body n => + .letE (foldLiterals prims ty) (foldLiterals prims val) (foldLiterals prims body) n + | .proj ta idx s tn => + .proj ta idx (foldLiterals prims s) tn + | e => e + +/-! ## Value pretty printing -/ + +/-- Pretty-print a value by quoting it back to an Expr, then using Expr.pp. 
+  Folds Nat/String constructor chains back to literals for readability. -/
+partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m String := do
+  let expr ← quote lvl v
+  let expr := foldLiterals (← read).prims expr
+  return expr.pp
+
+/-- Pretty-print a suspended value. -/
+partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m String :=
+  ppValue lvl sv.get
+
+/-- Pretty-print a value, falling back to the shallow summary on error. -/
+partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m String := do
+  try ppValue lvl v
+  catch _ => return v.summary
+
+/-- Apply a value to a list of arguments. -/
+def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m (Value m) :=
+  match args with
+  | [] => pure v
+  | arg :: rest => do
+    let info : TypeInfo m := .none
+    let v' ← try apply ⟨info, v⟩ arg
+      catch e =>
+        let ppV ← tryPpValue (← read).lvl v
+        throw s!"{e}\n in applyType: {ppV} with {args.length} remaining args"
+    applyType v' rest
+
+end Ix.Kernel
diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean
new file mode 100644
index 00000000..1d0b0159
--- /dev/null
+++ b/Ix/Kernel/Infer.lean
@@ -0,0 +1,406 @@
+/-
+  Kernel Infer: Type inference and declaration checking.
+
+  Adapted from Yatima.Typechecker.Infer, parameterized over MetaMode.
+-/
+import Ix.Kernel.Equal
+
+namespace Ix.Kernel
+
+/-!
## Type info helpers -/
+
+def lamInfo : TypeInfo m → TypeInfo m
+  | .proof => .proof
+  | _ => .none
+
+def piInfo (dom img : TypeInfo m) : TypecheckM m (TypeInfo m) := match dom, img with
+  | .sort lvl, .sort lvl' => pure (.sort (Level.reduceIMax lvl lvl'))
+  | _, _ => pure .none
+
+def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m Bool := do
+  match inferType.info, expectType.info with
+  | .sort lvl, .sort lvl' => pure (Level.equalLevel lvl lvl')
+  | _, _ => pure true -- info unavailable; defer to structural equality
+
+def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) :=
+  match typ.info with
+  | .sort (.zero) => pure .proof
+  | _ =>
+    match typ.get with
+    | .app (.const addr _ _) _ _ => do
+      match (← read).kenv.find? addr with
+      | some (.inductInfo v) =>
+        -- Check if it's unit-like: one constructor with zero fields
+        if v.ctors.size == 1 then
+          match (← read).kenv.find? v.ctors[0]! with
+          | some (.ctorInfo cv) =>
+            if cv.numFields == 0 then pure .unit else pure .none
+          | _ => pure .none
+        else pure .none
+      | _ => pure .none
+    | .sort lvl => pure (.sort lvl)
+    | _ => pure .none
+
+/-! ## Inference / Checking -/
+
+mutual
+  /-- Check that a term has a given type.
-/ + partial def check (term : Expr m) (type : SusValue m) : TypecheckM m (TypedExpr m) := do + if (← read).trace then dbg_trace s!"check: {term.tag}" + let (te, inferType) ← infer term + if !(← eqSortInfo inferType type) then + throw s!"Info mismatch on {term.tag}" + if !(← equal (← read).lvl type inferType) then + let lvl := (← read).lvl + let ppInferred ← tryPpValue lvl inferType.get + let ppExpected ← tryPpValue lvl type.get + let dumpInferred := inferType.get.dump + let dumpExpected := type.get.dump + throw s!"Type mismatch on {term.tag}\n inferred: {ppInferred}\n expected: {ppExpected}\n inferred dump: {dumpInferred}\n expected dump: {dumpExpected}\n inferred info: {inferType.info.pp}\n expected info: {type.info.pp}" + pure te + + /-- Infer the type of an expression, returning the typed expression and its type. -/ + partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × SusValue m) := withFuelCheck do + if (← read).trace then dbg_trace s!"infer: {term.tag}" + match term with + | .bvar idx bvarName => do + let ctx ← read + if idx < ctx.lvl then + let some type := listGet? ctx.types idx + | throw s!"var@{idx} out of environment range (size {ctx.types.length})" + let te : TypedExpr m := ⟨← infoFromType type, .bvar idx bvarName⟩ + pure (te, type) + else + -- Mutual reference + match ctx.mutTypes.get? (idx - ctx.lvl) with + | some (addr, typeValFn) => + if some addr == ctx.recAddr? 
then + throw s!"Invalid recursion" + let univs := ctx.env.univs.toArray + let type := typeValFn univs + let name ← lookupName addr + let te : TypedExpr m := ⟨← infoFromType type, .const addr univs name⟩ + pure (te, type) + | none => + throw s!"var@{idx} out of environment range and does not represent a mutual constant" + | .sort lvl => do + let univs := (← read).env.univs.toArray + let lvl := Level.instBulkReduce univs lvl + let lvl' := Level.succ lvl + let typ : SusValue m := .mk (.sort (Level.succ lvl')) (.mk fun _ => .sort lvl') + let te : TypedExpr m := ⟨.sort lvl', .sort lvl⟩ + pure (te, typ) + | .app fnc arg => do + let (fnTe, fncType) ← infer fnc + match fncType.get with + | .pi dom img piEnv _ _ => do + let argTe ← check arg dom + let ctx ← read + let stt ← get + let typ := suspend img { ctx with env := piEnv.extendWith (suspend argTe ctx stt) } stt + let te : TypedExpr m := ⟨← infoFromType typ, .app fnTe.body argTe.body⟩ + pure (te, typ) + | v => + let ppV ← tryPpValue (← read).lvl v + throw s!"Expected a pi type, got {ppV}\n dump: {v.dump}\n fncType info: {fncType.info.pp}\n function: {fnc.pp}\n argument: {arg.pp}" + | .lam ty body lamName lamBi => do + let (domTe, _) ← isSort ty + let ctx ← read + let stt ← get + let domVal := suspend domTe ctx stt + let var := mkSusVar (← infoFromType domVal) ctx.lvl lamName + let (bodTe, imgVal) ← withExtendedCtx var domVal (infer body) + let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩ + let imgTE ← quoteTyped (ctx.lvl+1) imgVal.getTyped + let typ : SusValue m := ⟨← piInfo domVal.info imgVal.info, + Thunk.mk fun _ => Value.pi domVal imgTE ctx.env lamName lamBi⟩ + pure (te, typ) + | .forallE ty body piName _ => do + let (domTe, domLvl) ← isSort ty + let ctx ← read + let stt ← get + let domVal := suspend domTe ctx stt + let domSusVal := mkSusVar (← infoFromType domVal) ctx.lvl piName + withExtendedCtx domSusVal domVal do + let (imgTe, imgLvl) ← isSort body + let sortLvl := 
Level.reduceIMax domLvl imgLvl + let typ : SusValue m := .mk (.sort (Level.succ sortLvl)) (.mk fun _ => .sort sortLvl) + let te : TypedExpr m := ⟨← infoFromType typ, .forallE domTe.body imgTe.body piName default⟩ + pure (te, typ) + | .letE ty val body letName => do + let (tyTe, _) ← isSort ty + let ctx ← read + let stt ← get + let tyVal := suspend tyTe ctx stt + let valTe ← check val tyVal + let valVal := suspend valTe ctx stt + let (bodTe, typ) ← withExtendedCtx valVal tyVal (infer body) + let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ + pure (te, typ) + | .lit (.natVal _) => do + let prims := (← read).prims + let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.nat #[]) + let te : TypedExpr m := ⟨.none, term⟩ + pure (te, typ) + | .lit (.strVal _) => do + let prims := (← read).prims + let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.string #[]) + let te : TypedExpr m := ⟨.none, term⟩ + pure (te, typ) + | .const addr constUnivs _ => do + ensureTypedConst addr + let ctx ← read + let univs := ctx.env.univs.toArray + let reducedUnivs := constUnivs.toList.map (Level.instBulkReduce univs) + -- Check const type cache (must also match universe parameters) + match (← get).constTypeCache.get? 
addr with + | some (cachedUnivs, cachedTyp) => + if cachedUnivs == reducedUnivs then + let te : TypedExpr m := ⟨← infoFromType cachedTyp, term⟩ + pure (te, cachedTyp) + else + let tconst ← derefTypedConst addr + let env : ValEnv m := .mk [] reducedUnivs + let stt ← get + let typ := suspend tconst.type { ctx with env := env } stt + modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) } + let te : TypedExpr m := ⟨← infoFromType typ, term⟩ + pure (te, typ) + | none => + let tconst ← derefTypedConst addr + let env : ValEnv m := .mk [] reducedUnivs + let stt ← get + let typ := suspend tconst.type { ctx with env := env } stt + modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) } + let te : TypedExpr m := ⟨← infoFromType typ, term⟩ + pure (te, typ) + | .proj typeAddr idx struct _ => do + let (structTe, structType) ← infer struct + let (ctorType, univs, params) ← getStructInfo structType.get + let mut ct ← applyType (← withEnv (.mk [] univs) (eval ctorType)) params.reverse + for i in [:idx] do + match ct with + | .pi dom img piEnv _ _ => do + let info ← infoFromType dom + let ctx ← read + let stt ← get + let proj := suspend ⟨info, .proj typeAddr i structTe.body default⟩ ctx stt + ct ← withNewExtendedEnv piEnv proj (eval img) + | _ => pure () + match ct with + | .pi dom _ _ _ _ => + let te : TypedExpr m := ⟨← infoFromType dom, .proj typeAddr idx structTe.body default⟩ + pure (te, dom) + | _ => throw "Impossible case: structure type does not have enough fields" + + /-- Check if an expression is a Sort, returning the typed expr and the universe level. 
-/ + partial def isSort (expr : Expr m) : TypecheckM m (TypedExpr m × Level m) := do + let (te, typ) ← infer expr + match typ.get with + | .sort u => pure (te, u) + | v => + let ppV ← tryPpValue (← read).lvl v + throw s!"Expected a sort type, got {ppV}\n expr: {expr.pp}" + + /-- Get structure info from a value that should be a structure type. -/ + partial def getStructInfo (v : Value m) : + TypecheckM m (TypedExpr m × List (Level m) × List (SusValue m)) := do + match v with + | .app (.const indAddr univs _) params _ => + match (← read).kenv.find? indAddr with + | some (.inductInfo v) => + if v.ctors.size != 1 || params.length != v.numParams then + throw s!"Expected a structure type, but {v.name} ({indAddr}) has {v.ctors.size} ctors and {params.length}/{v.numParams} params" + ensureTypedConst indAddr + let ctorAddr := v.ctors[0]! + ensureTypedConst ctorAddr + match (← get).typedConsts.get? ctorAddr with + | some (.constructor type _ _) => + return (type, univs.toList, params) + | _ => throw s!"Constructor {ctorAddr} is not in typed consts" + | some ci => throw s!"Expected a structure type, but {indAddr} is a {ci.kindName}" + | none => throw s!"Expected a structure type, but {indAddr} not found in env" + | _ => + let ppV ← tryPpValue (← read).lvl v + throw s!"Expected a structure type, got {ppV}" + + /-- Typecheck a constant. With fresh state per declaration, dependencies get + provisional entries via `ensureTypedConst` and are assumed well-typed. -/ + partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do + -- Reset fuel and per-constant caches + modify fun stt => { stt with + fuel := defaultFuel + evalCache := {} + equalCache := {} + constTypeCache := {} } + -- Skip if already in typedConsts (provisional entry is fine — dependency assumed well-typed) + if (← get).typedConsts.get? 
addr |>.isSome then + return () + let ci ← derefConst addr + let univs := ci.cv.mkUnivParams + withEnv (.mk [] univs.toList) do + let newConst ← match ci with + | .axiomInfo _ => + let (type, _) ← isSort ci.type + pure (TypedConst.axiom type) + | .opaqueInfo _ => + let (type, _) ← isSort ci.type + let typeSus := suspend type (← read) (← get) + let value ← withRecAddr addr (check ci.value?.get! typeSus) + pure (TypedConst.opaque type value) + | .thmInfo _ => + let (type, lvl) ← isSort ci.type + if !Level.isZero lvl then + throw s!"theorem type must be a proposition (Sort 0)" + let typeSus := suspend type (← read) (← get) + let value ← withRecAddr addr (check ci.value?.get! typeSus) + pure (TypedConst.theorem type value) + | .defnInfo v => + let (type, _) ← isSort ci.type + let ctx ← read + let stt ← get + let typeSus := suspend type ctx stt + let part := v.safety == .partial + let value ← + if part then + let typeSusFn := suspend type { ctx with env := ValEnv.mk ctx.env.exprs ctx.env.univs } stt + let mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare := + (Std.TreeMap.empty).insert 0 (addr, fun _ => typeSusFn) + withMutTypes mutTypes (withRecAddr addr (check v.value typeSus)) + else withRecAddr addr (check v.value typeSus) + pure (TypedConst.definition type value part) + | .quotInfo v => + let (type, _) ← isSort ci.type + pure (TypedConst.quotient type v.kind) + | .inductInfo _ => + checkIndBlock addr + return () + | .ctorInfo v => + checkIndBlock v.induct + return () + | .recInfo v => do + -- Extract the major premise's inductive from the recursor type + let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices + |>.getD default + -- Ensure the inductive has a provisional entry (assumed well-typed with fresh state per decl) + ensureTypedConst indAddr + -- Check recursor type + let (type, _) ← isSort ci.type + -- Check recursor rules + let typedRules ← v.rules.mapM fun rule => do + let (rhs, _) ← infer 
rule.rhs + pure (rule.nfields, rhs) + pure (TypedConst.recursor type v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules) + modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr newConst } + + /-- Walk a Pi chain to extract the return sort level (the universe of the result type). + Assumes the expression ends in `Sort u` after `numBinders` forall binders. -/ + partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m (Level m) := + match numBinders, expr with + | 0, .sort u => do + let univs := (← read).env.univs.toArray + pure (Level.instBulkReduce univs u) + | 0, _ => do + -- Not syntactically a sort; try to infer + let (_, typ) ← infer expr + match typ.get with + | .sort u => pure u + | _ => throw "inductive return type is not a sort" + | n+1, .forallE dom body _ _ => do + let (domTe, _) ← isSort dom + let ctx ← read + let stt ← get + let domVal := suspend domTe ctx stt + let var := mkSusVar (← infoFromType domVal) ctx.lvl + withExtendedCtx var domVal (getReturnSort body n) + | _, _ => throw "inductive type has fewer binders than expected" + + /-- Typecheck a mutual inductive block starting from one of its addresses. -/ + partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do + let ci ← derefConst addr + -- Find the inductive info + let indInfo ← match ci with + | .inductInfo _ => pure ci + | .ctorInfo v => + match (← read).kenv.find? v.induct with + | some ind@(.inductInfo ..) => pure ind + | _ => throw "Constructor's inductive not found" + | _ => throw "Expected an inductive" + let .inductInfo iv := indInfo | throw "unreachable" + -- Check if already done + if (← get).typedConsts.get? addr |>.isSome then return () + -- Check the inductive type + let univs := iv.toConstantVal.mkUnivParams + let (type, _) ← withEnv (.mk [] univs.toList) (isSort iv.type) + let isStruct := !iv.isRec && iv.numIndices == 0 && iv.ctors.size == 1 && + match (← read).kenv.find? iv.ctors[0]! 
with + | some (.ctorInfo cv) => cv.numFields > 0 + | _ => false + modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr (TypedConst.inductive type isStruct) } + -- Check constructors + for (ctorAddr, cidx) in iv.ctors.toList.zipIdx do + match (← read).kenv.find? ctorAddr with + | some (.ctorInfo cv) => do + let ctorUnivs := cv.toConstantVal.mkUnivParams + let (ctorType, _) ← withEnv (.mk [] ctorUnivs.toList) (isSort cv.type) + modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cidx cv.numFields) } + | _ => throw s!"Constructor {ctorAddr} not found" + -- Note: recursors are checked individually via checkConst's .recInfo branch, + -- which calls checkConst on the inductives first then checks rules. +end -- mutual + +/-! ## Top-level entry points -/ + +/-- Typecheck a single constant by address. -/ +def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address) + (quotInit : Bool := true) : Except String Unit := do + let ctx : TypecheckCtx m := { + lvl := 0, env := default, types := [], kenv := kenv, + prims := prims, safety := .safe, quotInit := quotInit, + mutTypes := default, recAddr? := none + } + let stt : TypecheckState m := { typedConsts := default } + TypecheckM.run ctx stt (checkConst addr) + +/-- Typecheck all constants in a kernel environment. + Uses fresh state per declaration — dependencies are assumed well-typed. -/ +def typecheckAll (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) + : Except String Unit := do + for (addr, ci) in kenv do + match typecheckConst kenv prims addr quotInit with + | .ok () => pure () + | .error e => + let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})" + let typ := ci.type.pp + let val := match ci.value? with + | some v => s!"\n value: {v.pp}" + | none => "" + throw s!"{header}: {e}\n type: {typ}{val}" + +/-- Typecheck all constants with IO progress reporting. 
+  Uses fresh state per declaration — dependencies are assumed well-typed. -/
+def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true)
+    : IO (Except String Unit) := do
+  let mut items : Array (Address × ConstantInfo m) := #[]
+  for (addr, ci) in kenv do
+    items := items.push (addr, ci)
+  let total := items.size
+  for h : idx in [:total] do
+    let (addr, ci) := items[idx]
+    --let typ := ci.type.pp
+    --let val := match ci.value? with
+    --  | some v => s!"\n value: {v.pp}"
+    --  | none => ""
+    let (typ, val) := ("_", "_")
+    (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})\n type: {typ}{val}"
+    (← IO.getStdout).flush
+    match typecheckConst kenv prims addr quotInit with
+    | .ok () =>
+      (← IO.getStdout).putStrLn s!" ✓ {ci.cv.name}"
+      (← IO.getStdout).flush
+    | .error e =>
+      let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})"
+      return .error s!"{header}: {e}\n type: {typ}{val}"
+  return .ok ()
+
+end Ix.Kernel
diff --git a/Ix/Kernel/Level.lean b/Ix/Kernel/Level.lean
new file mode 100644
index 00000000..f22bcb53
--- /dev/null
+++ b/Ix/Kernel/Level.lean
@@ -0,0 +1,131 @@
+/-
+  Level normalization and comparison for `Level m`.
+
+  Generic over MetaMode — metadata on `.param` is ignored.
+  Adapted from Yatima.Datatypes.Univ + Ix.IxVM.Level.
+-/
+import Init.Data.Int
+import Ix.Kernel.Types
+
+namespace Ix.Kernel
+
+namespace Level
+
+/-! ## Reduction -/
+
+/-- Reduce `max a b` assuming `a` and `b` are already reduced. -/
+def reduceMax (a b : Level m) : Level m :=
+  match a, b with
+  | .zero, _ => b
+  | _, .zero => a
+  | .succ a, .succ b => .succ (reduceMax a b)
+  | .param idx _, .param idx' _ => if idx == idx' then a else .max a b
+  | _, _ => .max a b
+
+/-- Reduce `imax a b` assuming `a` and `b` are already reduced.
-/
+def reduceIMax (a b : Level m) : Level m :=
+  match b with
+  | .zero => .zero
+  | .succ _ => reduceMax a b
+  | .param idx _ => match a with
+    | .param idx' _ => if idx == idx' then a else .imax a b
+    | _ => .imax a b
+  | _ => .imax a b
+
+/-- Reduce a level to normal form. -/
+def reduce : Level m → Level m
+  | .succ u => .succ (reduce u)
+  | .max a b => reduceMax (reduce a) (reduce b)
+  | .imax a b =>
+    let b' := reduce b
+    match b' with
+    | .zero => .zero
+    | .succ _ => reduceMax (reduce a) b'
+    | _ => .imax (reduce a) b'
+  | u => u
+
+/-! ## Instantiation -/
+
+/-- Instantiate a single variable and reduce. Assumes `subst` is already reduced.
+  Does not shift variables (used only in comparison algorithm). -/
+def instReduce (u : Level m) (idx : Nat) (subst : Level m) : Level m :=
+  match u with
+  | .succ u => .succ (instReduce u idx subst)
+  | .max a b => reduceMax (instReduce a idx subst) (instReduce b idx subst)
+  | .imax a b =>
+    let a' := instReduce a idx subst
+    let b' := instReduce b idx subst
+    match b' with
+    | .zero => .zero
+    | .succ _ => reduceMax a' b'
+    | _ => .imax a' b'
+  | .param idx' _ => if idx' == idx then subst else u
+  | .zero => u
+
+/-- Instantiate multiple variables at once and reduce. Substitutes `.param idx` by `substs[idx]`.
+  Assumes already reduced `substs`. -/
+def instBulkReduce (substs : Array (Level m)) : Level m → Level m
+  | z@(.zero ..) => z
+  | .succ u => .succ (instBulkReduce substs u)
+  | .max a b => reduceMax (instBulkReduce substs a) (instBulkReduce substs b)
+  | .imax a b =>
+    let b' := instBulkReduce substs b
+    match b' with
+    | .zero => .zero
+    | .succ _ => reduceMax (instBulkReduce substs a) b'
+    | _ => .imax (instBulkReduce substs a) b'
+  | .param idx name =>
+    if h : idx < substs.size then substs[idx]
+    else .param (idx - substs.size) name
+
+/-! ## Comparison -/
+
+/-- Comparison algorithm: `a <= b + diff`. Assumes `a` and `b` are already reduced.
-/
+partial def leq (a b : Level m) (diff : _root_.Int) : Bool :=
+  if diff >= 0 && match a with | .zero => true | _ => false then true
+  else match a, b with
+  | .zero, .zero => diff >= 0
+  -- Succ cases
+  | .succ a, _ => leq a b (diff - 1)
+  | _, .succ b => leq a b (diff + 1)
+  | .param .., .zero => false
+  | .zero, .param .. => diff >= 0
+  | .param x _, .param y _ => x == y && diff >= 0
+  -- IMax cases
+  | .imax _ (.param idx _), _ =>
+    leq .zero (instReduce b idx .zero) diff &&
+    let s := .succ (.param idx default)
+    leq (instReduce a idx s) (instReduce b idx s) diff
+  | .imax c (.max e f), _ =>
+    let newMax := reduceMax (reduceIMax c e) (reduceIMax c f)
+    leq newMax b diff
+  | .imax c (.imax e f), _ =>
+    let newMax := reduceMax (reduceIMax c f) (.imax e f)
+    leq newMax b diff
+  | _, .imax _ (.param idx _) =>
+    leq (instReduce a idx .zero) .zero diff &&
+    let s := .succ (.param idx default)
+    leq (instReduce a idx s) (instReduce b idx s) diff
+  | _, .imax c (.max e f) =>
+    let newMax := reduceMax (reduceIMax c e) (reduceIMax c f)
+    leq a newMax diff
+  | _, .imax c (.imax e f) =>
+    let newMax := reduceMax (reduceIMax c f) (.imax e f)
+    leq a newMax diff
+  -- Max cases
+  | .max c d, _ => leq c b diff && leq d b diff
+  | _, .max c d => leq a c diff || leq a d diff
+  | _, _ => false
+
+/-- Semantic equality of levels. Assumes `a` and `b` are already reduced. -/
+def equalLevel (a b : Level m) : Bool :=
+  leq a b 0 && leq b a 0
+
+/-- Faster equality for zero, assumes input is already reduced. -/
+def isZero : Level m → Bool
+  | .zero => true
+  | _ => false
+
+end Level
+
+end Ix.Kernel
diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean
new file mode 100644
index 00000000..8b1a93ba
--- /dev/null
+++ b/Ix/Kernel/TypecheckM.lean
@@ -0,0 +1,180 @@
+/-
+  TypecheckM: Monad stack, context, state, and utilities for the kernel typechecker.
+-/
+import Ix.Kernel.Datatypes
+import Ix.Kernel.Level
+
+namespace Ix.Kernel
+
+/-!
## Typechecker Context -/
+
+structure TypecheckCtx (m : MetaMode) where
+  lvl : Nat
+  env : ValEnv m
+  types : List (SusValue m)
+  kenv : Env m
+  prims : Primitives
+  safety : DefinitionSafety
+  quotInit : Bool
+  /-- Maps a variable index (mutual reference) to (address, type-value function). -/
+  mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare
+  /-- Tracks the address of the constant currently being checked, for recursion detection. -/
+  recAddr? : Option Address
+  /-- Depth fuel: bounds the call-stack depth to prevent native stack overflow.
+    Decremented via the reader on each entry to eval/equal/infer.
+    Thunks inherit the depth from their capture point. -/
+  depth : Nat := 3000
+  /-- Enable dbg_trace on major entry points for debugging. -/
+  trace : Bool := false
+  deriving Inhabited
+
+/-! ## Typechecker State -/
+
+/-- Default fuel for bounding total recursive work per constant. -/
+def defaultFuel : Nat := 100000
+
+structure TypecheckState (m : MetaMode) where
+  typedConsts : Std.TreeMap Address (TypedConst m) Address.compare
+  /-- Fuel counter for bounding total recursive work. Decremented on each entry to
+    eval/equal/infer. Reset at the start of each `checkConst` call. -/
+  fuel : Nat := defaultFuel
+  /-- Cache for evaluated constant definitions. Maps an address to its universe
+    parameters and evaluated value. Universe-polymorphic constants produce different
+    values for different universe instantiations, so we store and check univs. -/
+  evalCache : Std.HashMap Address (Array (Level m) × Value m) := {}
+  /-- Cache for definitional equality results. Maps `(ptrAddrUnsafe a, ptrAddrUnsafe b)`
+    (canonicalized so smaller pointer comes first) to `Bool`. Only `true` results are
+    cached (monotone under state growth). -/
+  equalCache : Std.HashMap (USize × USize) Bool := {}
+  /-- Cache for constant type SusValues.
When `infer (.const addr _)` computes a
+    suspended type, it is cached here so repeated references to the same constant
+    share the same SusValue pointer, enabling fast-path pointer equality in `equal`.
+    Stores universe parameters alongside the value for correctness with polymorphic constants. -/
+  constTypeCache : Std.HashMap Address (List (Level m) × SusValue m) := {}
+  deriving Inhabited
+
+/-! ## TypecheckM monad -/
+
+abbrev TypecheckM (m : MetaMode) := ReaderT (TypecheckCtx m) (StateT (TypecheckState m) (Except String))
+
+def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) : Except String α :=
+  match (StateT.run (ReaderT.run x ctx) stt) with
+  | .error e => .error e
+  | .ok (a, _) => .ok a
+
+def TypecheckM.runState (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α)
+    : Except String (α × TypecheckState m) :=
+  StateT.run (ReaderT.run x ctx) stt
+
+/-! ## Context modifiers -/
+
+def withEnv (env : ValEnv m) : TypecheckM m α → TypecheckM m α :=
+  withReader fun ctx => { ctx with env := env }
+
+def withResetCtx : TypecheckM m α → TypecheckM m α :=
+  withReader fun ctx => { ctx with
+    lvl := 0, env := default, types := default, mutTypes := default, recAddr?
:= none } + +def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare) : + TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with mutTypes := mutTypes } + +def withExtendedCtx (val typ : SusValue m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with + lvl := ctx.lvl + 1, + types := typ :: ctx.types, + env := ctx.env.extendWith val } + +def withExtendedEnv (thunk : SusValue m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with env := ctx.env.extendWith thunk } + +def withNewExtendedEnv (env : ValEnv m) (thunk : SusValue m) : + TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with env := env.extendWith thunk } + +def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with recAddr? := some addr } + +/-- Check both fuel counters, decrement them, and run the action. + - State fuel bounds total work (prevents exponential blowup / hanging). + - Reader depth bounds call-stack depth (prevents native stack overflow). -/ +def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do + let ctx ← read + if ctx.depth == 0 then + throw "deep recursion depth limit reached" + let stt ← get + if stt.fuel == 0 then throw "deep recursion work limit reached" + set { stt with fuel := stt.fuel - 1 } + withReader (fun ctx => { ctx with depth := ctx.depth - 1 }) action + +/-! ## Name lookup -/ + +/-- Look up the MetaField name for a constant address from the kernel environment. -/ +def lookupName (addr : Address) : TypecheckM m (MetaField m Ix.Name) := do + match (← read).kenv.find? addr with + | some ci => pure ci.cv.name + | none => pure default + +/-! ## Const dereferencing -/ + +def derefConst (addr : Address) : TypecheckM m (ConstantInfo m) := do + let ctx ← read + match ctx.kenv.find? 
addr with + | some ci => pure ci + | none => throw s!"unknown constant {addr}" + +def derefTypedConst (addr : Address) : TypecheckM m (TypedConst m) := do + match (← get).typedConsts.get? addr with + | some tc => pure tc + | none => throw s!"typed constant not found: {addr}" + +/-! ## Provisional TypedConst -/ + +/-- Extract the major premise's inductive address from a recursor type. + Skips numParams + numMotives + numMinors + numIndices foralls, + then the next forall's domain's app head is the inductive const. -/ +def getMajorInduct (type : Expr m) (numParams numMotives numMinors numIndices : Nat) : Option Address := + go (numParams + numMotives + numMinors + numIndices) type +where + go : Nat → Expr m → Option Address + | 0, e => match e with + | .forallE dom _ _ _ => some dom.getAppFn.constAddr! + | _ => none + | n+1, e => match e with + | .forallE _ body _ _ => go n body + | _ => none + +/-- Build a provisional TypedConst entry from raw ConstantInfo. + Used when `infer` encounters a `.const` reference before the constant + has been fully typechecked. The entry uses default TypeInfo and raw + expressions directly from the kernel environment. 
-/ +def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := + let rawType : TypedExpr m := ⟨default, ci.type⟩ -- provisional entries use default TypeInfo + match ci with + | .axiomInfo _ => .axiom rawType -- axioms carry only a type, no value + | .thmInfo v => .theorem rawType ⟨default, v.value⟩ + | .defnInfo v => + .definition rawType ⟨default, v.value⟩ (v.safety == .partial) -- Bool flag marks partial definitions + | .opaqueInfo v => .opaque rawType ⟨default, v.value⟩ + | .quotInfo v => .quotient rawType v.kind + | .inductInfo v => + let isStruct := v.ctors.size == 1 -- approximate; refined by checkIndBlock + .inductive rawType isStruct + | .ctorInfo v => .constructor rawType v.cidx v.numFields + | .recInfo v => + let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices + |>.getD default -- NOTE(review): falls back to the default Address when the major premise's inductive head cannot be read off the type — confirm downstream handles this + let typedRules := v.rules.map fun r => (r.nfields, (⟨default, r.rhs⟩ : TypedExpr m)) + .recursor rawType v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules + +/-- Ensure a constant has a TypedConst entry. If not already present, build a + provisional one from raw ConstantInfo. This avoids the deep recursion of + `checkConst` when called from `infer`. -/ +def ensureTypedConst (addr : Address) : TypecheckM m Unit := do + if (← get).typedConsts.get? addr |>.isSome then return () -- idempotent: existing entries are kept as-is + let ci ← derefConst addr + let tc := provisionalTypedConst ci + modify fun stt => { stt with + typedConsts := stt.typedConsts.insert addr tc } + +end Ix.Kernel diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean new file mode 100644 index 00000000..fba45b00 --- /dev/null +++ b/Ix/Kernel/Types.lean @@ -0,0 +1,569 @@ +/- + Kernel Types: Closure-based typechecker types with compile-time metadata erasure. + + The MetaMode flag controls whether name/binder metadata is present: + - `Expr .meta` carries full names and binder info (for debugging) + - `Expr .anon` has Unit fields (proven no metadata leakage) +-/ +import Ix.Address +import Ix.Environment + +namespace Ix.Kernel + +/-!
## MetaMode and MetaField -/ + +inductive MetaMode where | «meta» | anon -- compile-time switch: retain (.meta) or erase (.anon) metadata fields + +def MetaField (m : MetaMode) (α : Type) : Type := + match m with + | .meta => α + | .anon => Unit -- erased: a MetaField is Unit, so no metadata can leak into .anon terms + +instance {m : MetaMode} {α : Type} [Inhabited α] : Inhabited (MetaField m α) := + match m with + | .meta => inferInstanceAs (Inhabited α) + | .anon => ⟨()⟩ + +instance {m : MetaMode} {α : Type} [BEq α] : BEq (MetaField m α) := + match m with + | .meta => inferInstanceAs (BEq α) + | .anon => ⟨fun _ _ => true⟩ -- erased metadata is always equal: equality ignores names/binder info + +instance {m : MetaMode} {α : Type} [Repr α] : Repr (MetaField m α) := + match m with + | .meta => inferInstanceAs (Repr α) + | .anon => ⟨fun _ _ => "()".toFormat⟩ + +instance {m : MetaMode} {α : Type} [ToString α] : ToString (MetaField m α) := + match m with + | .meta => inferInstanceAs (ToString α) + | .anon => ⟨fun _ => "()"⟩ + +instance {m : MetaMode} {α : Type} [Ord α] : Ord (MetaField m α) := + match m with + | .meta => inferInstanceAs (Ord α) + | .anon => ⟨fun _ _ => .eq⟩ -- erased metadata never affects ordering + +/-! ## Level -/ + +inductive Level (m : MetaMode) where + | zero + | succ (l : Level m) + | max (l₁ l₂ : Level m) + | imax (l₁ l₂ : Level m) + | param (idx : Nat) (name : MetaField m Ix.Name) -- universe parameter, identified by index; name is metadata only + deriving Inhabited, BEq + +/-! ## Expr -/ + +inductive Expr (m : MetaMode) where + | bvar (idx : Nat) (name : MetaField m Ix.Name) + | sort (level : Level m) + | const (addr : Address) (levels : Array (Level m)) + (name : MetaField m Ix.Name) + | app (fn arg : Expr m) + | lam (ty body : Expr m) + (name : MetaField m Ix.Name) (bi : MetaField m Lean.BinderInfo) + | forallE (ty body : Expr m) + (name : MetaField m Ix.Name) (bi : MetaField m Lean.BinderInfo) + | letE (ty val body : Expr m) + (name : MetaField m Ix.Name) + | lit (l : Lean.Literal) + | proj (typeAddr : Address) (idx : Nat) (struct : Expr m) + (typeName : MetaField m Ix.Name) + deriving Inhabited, BEq + +/-!
## Pretty printing helpers -/ + +private def succCount : Level m → Nat → Nat × Level m + | .succ l, n => succCount l (n + 1) + | l, n => (n, l) + +private partial def ppLevel : Level m → String + | .zero => "0" + | .succ l => + let (n, base) := succCount l 1 + match base with + | .zero => toString n + | _ => s!"{ppLevel base} + {n}" + | .max l₁ l₂ => s!"max ({ppLevel l₁}) ({ppLevel l₂})" + | .imax l₁ l₂ => s!"imax ({ppLevel l₁}) ({ppLevel l₂})" + | .param idx name => + let s := s!"{name}" + if s == "()" then s!"u_{idx}" else s + +private def ppSort (l : Level m) : String := + match l with + | .zero => "Prop" + | .succ .zero => "Type" + | .succ l' => + let s := ppLevel l' + if s.any (· == ' ') then s!"Type ({s})" else s!"Type {s}" + | _ => + let s := ppLevel l + if s.any (· == ' ') then s!"Sort ({s})" else s!"Sort {s}" + +private def ppBinderName (name : MetaField m Ix.Name) : String := + let s := s!"{name}" + if s == "()" then "_" + else if s.isEmpty then "???" + else s + +private def ppVarName (name : MetaField m Ix.Name) (idx : Nat) : String := + let s := s!"{name}" + if s == "()" then s!"^{idx}" + else if s.isEmpty then "???" + else s + +private def ppConstName (name : MetaField m Ix.Name) (addr : Address) : String := + let s := s!"{name}" + if s == "()" then s!"#{String.ofList ((toString addr).toList.take 8)}" + else if s.isEmpty then s!"{addr}" + else s + +/-! 
## Expr smart constructors -/ + +namespace Expr + +def mkBVar (idx : Nat) : Expr m := .bvar idx default +def mkSort (level : Level m) : Expr m := .sort level +def mkConst (addr : Address) (levels : Array (Level m)) : Expr m := + .const addr levels default +def mkApp (fn arg : Expr m) : Expr m := .app fn arg +def mkLam (ty body : Expr m) : Expr m := .lam ty body default default +def mkForallE (ty body : Expr m) : Expr m := .forallE ty body default default +def mkLetE (ty val body : Expr m) : Expr m := .letE ty val body default +def mkLit (l : Lean.Literal) : Expr m := .lit l +def mkProj (typeAddr : Address) (idx : Nat) (struct : Expr m) : Expr m := + .proj typeAddr idx struct default + +/-! ### Predicates -/ + +def isSort : Expr m → Bool | sort .. => true | _ => false +def isForall : Expr m → Bool | forallE .. => true | _ => false +def isLambda : Expr m → Bool | lam .. => true | _ => false +def isApp : Expr m → Bool | app .. => true | _ => false +def isLit : Expr m → Bool | lit .. => true | _ => false +def isConst : Expr m → Bool | const .. => true | _ => false +def isBVar : Expr m → Bool | bvar .. => true | _ => false + +def isConstOf (e : Expr m) (addr : Address) : Bool := + match e with | const a _ _ => a == addr | _ => false + +/-! ### Accessors -/ + +def bvarIdx! : Expr m → Nat | bvar i _ => i | _ => panic! "bvarIdx!" +def sortLevel! : Expr m → Level m | sort l => l | _ => panic! "sortLevel!" +def bindingDomain! : Expr m → Expr m + | forallE ty _ _ _ => ty | lam ty _ _ _ => ty | _ => panic! "bindingDomain!" +def bindingBody! : Expr m → Expr m + | forallE _ b _ _ => b | lam _ b _ _ => b | _ => panic! "bindingBody!" +def appFn! : Expr m → Expr m | app f _ => f | _ => panic! "appFn!" +def appArg! : Expr m → Expr m | app _ a => a | _ => panic! "appArg!" +def constAddr! : Expr m → Address | const a _ _ => a | _ => panic! "constAddr!" +def constLevels! : Expr m → Array (Level m) | const _ ls _ => ls | _ => panic! "constLevels!" +def litValue! 
: Expr m → Lean.Literal | lit l => l | _ => panic! "litValue!" +def projIdx! : Expr m → Nat | proj _ i _ _ => i | _ => panic! "projIdx!" +def projStruct! : Expr m → Expr m | proj _ _ s _ => s | _ => panic! "projStruct!" +def projTypeAddr! : Expr m → Address | proj a _ _ _ => a | _ => panic! "projTypeAddr!" + +/-! ### App Spine -/ + +def getAppFn : Expr m → Expr m + | app f _ => getAppFn f + | e => e + +def getAppNumArgs : Expr m → Nat + | app f _ => getAppNumArgs f + 1 + | _ => 0 + +partial def getAppRevArgs (e : Expr m) : Array (Expr m) := + go e #[] +where + go : Expr m → Array (Expr m) → Array (Expr m) + | app f a, acc => go f (acc.push a) + | _, acc => acc + +def getAppArgs (e : Expr m) : Array (Expr m) := + e.getAppRevArgs.reverse + +def mkAppN (fn : Expr m) (args : Array (Expr m)) : Expr m := + args.foldl (fun acc a => mkApp acc a) fn + +def mkAppRange (fn : Expr m) (start stop : Nat) (args : Array (Expr m)) : Expr m := Id.run do + let mut r := fn + for i in [start:stop] do + r := mkApp r args[i]! 
+ return r + +def prop : Expr m := mkSort .zero + +partial def pp (atom : Bool := false) : Expr m → String + | .bvar idx name => ppVarName name idx + | .sort level => ppSort level + | .const addr _ name => ppConstName name addr + | .app fn arg => + let s := s!"{pp false fn} {pp true arg}" + if atom then s!"({s})" else s + | .lam ty body name _ => + let s := ppLam s!"({ppBinderName name} : {pp false ty})" body + if atom then s!"({s})" else s + | .forallE ty body name _ => + let s := ppPi s!"({ppBinderName name} : {pp false ty})" body + if atom then s!"({s})" else s + | .letE ty val body name => + let s := s!"let {ppBinderName name} : {pp false ty} := {pp false val}; {pp false body}" + if atom then s!"({s})" else s + | .lit (.natVal n) => toString n + | .lit (.strVal s) => s!"\"{s}\"" + | .proj _ idx struct _ => s!"{pp true struct}.{idx}" +where + ppLam (acc : String) : Expr m → String + | .lam ty body name _ => + ppLam s!"{acc} ({ppBinderName name} : {pp false ty})" body + | e => s!"λ {acc} => {pp false e}" + ppPi (acc : String) : Expr m → String + | .forallE ty body name _ => + ppPi s!"{acc} ({ppBinderName name} : {pp false ty})" body + | e => s!"∀ {acc}, {pp false e}" + +/-- Short constructor tag for tracing (no recursion into subterms). -/ +def tag : Expr m → String + | .bvar idx _ => s!"bvar({idx})" + | .sort _ => "sort" + | .const _ _ name => s!"const({name})" + | .app .. => "app" + | .lam .. => "lam" + | .forallE .. => "forallE" + | .letE .. => "letE" + | .lit (.natVal n) => s!"natLit({n})" + | .lit (.strVal s) => s!"strLit({s})" + | .proj _ idx _ _ => s!"proj({idx})" + +end Expr + +/-! 
## Enums -/ + +inductive DefinitionSafety where + | safe | «unsafe» | «partial» + deriving BEq, Repr, Inhabited + +inductive ReducibilityHints where + | opaque | abbrev | regular (height : UInt32) + deriving BEq, Repr, Inhabited + +namespace ReducibilityHints + +def lt' : ReducibilityHints → ReducibilityHints → Bool + | .regular d₁, .regular d₂ => d₁ < d₂ + | .regular _, .opaque => true + | .abbrev, .opaque => true + | _, _ => false + +def isRegular : ReducibilityHints → Bool + | .regular _ => true + | _ => false + +end ReducibilityHints + +inductive QuotKind where + | type | ctor | lift | ind + deriving BEq, Repr, Inhabited + +/-! ## ConstantInfo -/ + +structure ConstantVal (m : MetaMode) where + numLevels : Nat + type : Expr m + name : MetaField m Ix.Name + levelParams : MetaField m (Array Ix.Name) + deriving Inhabited + +def ConstantVal.mkUnivParams (cv : ConstantVal m) : Array (Level m) := + match m with + | .meta => + let lps : Array Ix.Name := cv.levelParams + Array.ofFn (n := cv.numLevels) fun i => + .param i.val (if h : i.val < lps.size then lps[i.val] else default) + | .anon => Array.ofFn (n := cv.numLevels) fun i => .param i.val () + +structure AxiomVal (m : MetaMode) extends ConstantVal m where + isUnsafe : Bool + +structure DefinitionVal (m : MetaMode) extends ConstantVal m where + value : Expr m + hints : ReducibilityHints + safety : DefinitionSafety + all : Array Address + allNames : MetaField m (Array Ix.Name) := default + +structure TheoremVal (m : MetaMode) extends ConstantVal m where + value : Expr m + all : Array Address + allNames : MetaField m (Array Ix.Name) := default + +structure OpaqueVal (m : MetaMode) extends ConstantVal m where + value : Expr m + isUnsafe : Bool + all : Array Address + allNames : MetaField m (Array Ix.Name) := default + +structure QuotVal (m : MetaMode) extends ConstantVal m where + kind : QuotKind + +structure InductiveVal (m : MetaMode) extends ConstantVal m where + numParams : Nat + numIndices : Nat + all : Array 
Address + ctors : Array Address + allNames : MetaField m (Array Ix.Name) := default + ctorNames : MetaField m (Array Ix.Name) := default + numNested : Nat + isRec : Bool + isUnsafe : Bool + isReflexive : Bool + +structure ConstructorVal (m : MetaMode) extends ConstantVal m where + induct : Address + inductName : MetaField m Ix.Name := default + cidx : Nat + numParams : Nat + numFields : Nat + isUnsafe : Bool + +structure RecursorRule (m : MetaMode) where + ctor : Address + ctorName : MetaField m Ix.Name := default + nfields : Nat + rhs : Expr m + +structure RecursorVal (m : MetaMode) extends ConstantVal m where + all : Array Address + allNames : MetaField m (Array Ix.Name) := default + numParams : Nat + numIndices : Nat + numMotives : Nat + numMinors : Nat + rules : Array (RecursorRule m) + k : Bool + isUnsafe : Bool + +inductive ConstantInfo (m : MetaMode) where + | axiomInfo (val : AxiomVal m) + | defnInfo (val : DefinitionVal m) + | thmInfo (val : TheoremVal m) + | opaqueInfo (val : OpaqueVal m) + | quotInfo (val : QuotVal m) + | inductInfo (val : InductiveVal m) + | ctorInfo (val : ConstructorVal m) + | recInfo (val : RecursorVal m) + +namespace ConstantInfo + +def cv : ConstantInfo m → ConstantVal m + | axiomInfo v => v.toConstantVal + | defnInfo v => v.toConstantVal + | thmInfo v => v.toConstantVal + | opaqueInfo v => v.toConstantVal + | quotInfo v => v.toConstantVal + | inductInfo v => v.toConstantVal + | ctorInfo v => v.toConstantVal + | recInfo v => v.toConstantVal + +def numLevels (ci : ConstantInfo m) : Nat := ci.cv.numLevels +def type (ci : ConstantInfo m) : Expr m := ci.cv.type + +def isUnsafe : ConstantInfo m → Bool + | axiomInfo v => v.isUnsafe + | defnInfo v => v.safety == .unsafe + | thmInfo _ => false + | opaqueInfo v => v.isUnsafe + | quotInfo _ => false + | inductInfo v => v.isUnsafe + | ctorInfo v => v.isUnsafe + | recInfo v => v.isUnsafe + +def hasValue : ConstantInfo m → Bool + | defnInfo .. | thmInfo .. | opaqueInfo .. 
=> true + | _ => false + +def value? : ConstantInfo m → Option (Expr m) + | defnInfo v => some v.value + | thmInfo v => some v.value + | opaqueInfo v => some v.value + | _ => none + +def hints : ConstantInfo m → ReducibilityHints + | defnInfo v => v.hints + | _ => .opaque + +def safety : ConstantInfo m → DefinitionSafety + | defnInfo v => v.safety + | _ => .safe + +def all? : ConstantInfo m → Option (Array Address) + | defnInfo v => some v.all + | thmInfo v => some v.all + | opaqueInfo v => some v.all + | inductInfo v => some v.all + | recInfo v => some v.all + | _ => none + +def kindName : ConstantInfo m → String + | axiomInfo .. => "axiom" + | defnInfo .. => "definition" + | thmInfo .. => "theorem" + | opaqueInfo .. => "opaque" + | quotInfo .. => "quotient" + | inductInfo .. => "inductive" + | ctorInfo .. => "constructor" + | recInfo .. => "recursor" + +end ConstantInfo + +/-! ## Kernel.Env -/ + +def Address.compare (a b : Address) : Ordering := Ord.compare a b + +structure EnvId (m : MetaMode) where + addr : Address + name : MetaField m Ix.Name + +instance : Inhabited (EnvId m) where + default := ⟨default, default⟩ + +instance : BEq (EnvId m) where + beq a b := a.addr == b.addr && a.name == b.name + +def EnvId.compare (a b : EnvId m) : Ordering := + match Address.compare a.addr b.addr with + | .eq => Ord.compare a.name b.name + | ord => ord + +structure Env (m : MetaMode) where + entries : Std.TreeMap (EnvId m) (ConstantInfo m) EnvId.compare + addrIndex : Std.TreeMap Address (EnvId m) Address.compare + +instance : Inhabited (Env m) where + default := { entries := .empty, addrIndex := .empty } + +instance : ForIn n (Env m) (Address × ConstantInfo m) where + forIn env init f := + ForIn.forIn env.entries init fun p acc => f (p.1.addr, p.2) acc + +namespace Env + +def find? (env : Env m) (addr : Address) : Option (ConstantInfo m) := + match env.addrIndex.get? addr with + | some id => env.entries.get? 
id + | none => none + +def findByEnvId (env : Env m) (id : EnvId m) : Option (ConstantInfo m) := + env.entries.get? id + +def get (env : Env m) (addr : Address) : Except String (ConstantInfo m) := + match env.find? addr with + | some ci => .ok ci + | none => .error s!"unknown constant {addr}" + +def insert (env : Env m) (addr : Address) (ci : ConstantInfo m) : Env m := + let id : EnvId m := ⟨addr, ci.cv.name⟩ + let entries := env.entries.insert id ci + let addrIndex := match env.addrIndex.get? addr with + | some _ => env.addrIndex + | none => env.addrIndex.insert addr id -- first insert wins: addrIndex keeps the earliest EnvId seen for an address + { entries, addrIndex } + +def add (env : Env m) (addr : Address) (ci : ConstantInfo m) : Env m := + env.insert addr ci + +def size (env : Env m) : Nat := + env.addrIndex.size -- counts distinct addresses, not entries (same addr under different names dedups) + +def contains (env : Env m) (addr : Address) : Bool := + env.addrIndex.get? addr |>.isSome + +def isStructureLike (env : Env m) (addr : Address) : Bool := + match env.find? addr with + | some (.inductInfo v) => + !v.isRec && v.numIndices == 0 && v.ctors.size == 1 && -- non-recursive, no indices, single constructor + match env.find? v.ctors[0]! with + | some (.ctorInfo cv) => cv.numFields > 0 -- NOTE(review): requires at least one field, so field-less single-ctor inductives are excluded — confirm intended + | _ => false + | _ => false + +end Env + +/-! ## Primitives -/ + +private def addr! (s : String) : Address := + match Address.fromString s with + | some a => a + | none => panic!
s!"invalid hex address: {s}" + +structure Primitives where + nat : Address := default + natZero : Address := default + natSucc : Address := default + natAdd : Address := default + natSub : Address := default + natMul : Address := default + natPow : Address := default + natGcd : Address := default + natMod : Address := default + natDiv : Address := default + natBeq : Address := default + natBle : Address := default + natLand : Address := default + natLor : Address := default + natXor : Address := default + natShiftLeft : Address := default + natShiftRight : Address := default + bool : Address := default + boolTrue : Address := default + boolFalse : Address := default + string : Address := default + stringMk : Address := default + char : Address := default + charMk : Address := default + list : Address := default + listNil : Address := default + listCons : Address := default + quotType : Address := default + quotCtor : Address := default + quotLift : Address := default + quotInd : Address := default + deriving Repr, Inhabited + +def buildPrimitives : Primitives := + { nat := addr! "fc0e1e912f2d7f12049a5b315d76eec29562e34dc39ebca25287ae58807db137" + natZero := addr! "fac82f0d2555d6a63e1b8a1fe8d86bd293197f39c396fdc23c1275c60f182b37" + natSucc := addr! "7190ce56f6a2a847b944a355e3ec595a4036fb07e3c3db9d9064fc041be72b64" + natAdd := addr! "dcc96f3f914e363d1e906a8be4c8f49b994137bfdb077d07b6c8a4cf88a4f7bf" + natSub := addr! "6903e9bbd169b6c5515b27b3fc0c289ba2ff8e7e0c7f984747d572de4e6a7853" + natMul := addr! "8e641c3df8fe3878e5a219c888552802743b9251c3c37c32795f5b9b9e0818a5" + natPow := addr! "d9be78292bb4e79c03daaaad82e756c5eb4dd5535d33b155ea69e5cbce6bc056" + natGcd := addr! "e8a3be39063744a43812e1f7b8785e3f5a4d5d1a408515903aa05d1724aeb465" + natMod := addr! "14031083457b8411f655765167b1a57fcd542c621e0c391b15ff5ee716c22a67" + natDiv := addr! "863c18d3a5b100a5a5e423c20439d8ab4941818421a6bcf673445335cc559e55" + natBeq := addr! 
"127a9d47a15fc2bf91a36f7c2182028857133b881554ece4df63344ec93eb2ce" + natBle := addr! "6e4c17dc72819954d6d6afc412a3639a07aff6676b0813cdc419809cc4513df5" + natLand := addr! "e1425deee6279e2db2ff649964b1a66d4013cc08f9e968fb22cc0a64560e181a" + natLor := addr! "3649a28f945b281bd8657e55f93ae0b8f8313488fb8669992a1ba1373cbff8f6" + natXor := addr! "a711ef2cb4fa8221bebaa17ef8f4a965cf30678a89bc45ff18a13c902e683cc5" + natShiftLeft := addr! "16e4558f51891516843a5b30ddd9d9b405ec096d3e1c728d09ff152b345dd607" + natShiftRight := addr! "b9515e6c2c6b18635b1c65ebca18b5616483ebd53936f78e4ae123f6a27a089e" + bool := addr! "6405a455ba70c2b2179c7966c6f610bf3417bd0f3dd2ba7a522533c2cd9e1d0b" + boolTrue := addr! "420dead2168abd16a7050edfd8e17d45155237d3118782d0e68b6de87742cb8d" + boolFalse := addr! "c127f89f92e0481f7a3e0631c5615fe7f6cbbf439d5fd7eba400fb0603aedf2f" + string := addr! "591cf1c489d505d4082f2767500f123e29db5227eb1bae4721eeedd672f36190" + stringMk := addr! "f055b87da4265d980cdede04ce5c7d986866e55816dc94d32a5d90e805101230" + char := addr! "563b426b73cdf1538b767308d12d10d746e1f0b3b55047085bf690319a86f893" + charMk := addr! "7156fef44bc309789375d784e5c36e387f7119363dd9cd349226c52df43d2075" + list := addr! "abed9ff1aba4634abc0bd3af76ca544285a32dcfe43dc27b129aea8867457620" + listNil := addr! "0ebe345dc46917c824b6c3f6c42b101f2ac8c0e2c99f033a0ee3c60acb9cd84d" + listCons := addr! 
"f79842f10206598929e6ba60ce3ebaa00d11f201c99e80285f46cc0e90932832" + -- Quot primitives need to be computed; use default until wired up + } + +end Ix.Kernel diff --git a/Main.lean b/Main.lean index 3d111f56..d775bf88 100644 --- a/Main.lean +++ b/Main.lean @@ -1,5 +1,6 @@ --import Ix.Cli.ProveCmd --import Ix.Cli.StoreCmd +import Ix.Cli.CheckCmd import Ix.Cli.CompileCmd import Ix.Cli.ServeCmd import Ix.Cli.ConnectCmd @@ -15,6 +16,7 @@ def ixCmd : Cli.Cmd := `[Cli| SUBCOMMANDS: --proveCmd; --storeCmd; + checkCmd; compileCmd; serveCmd; connectCmd diff --git a/Tests/Ix/Check.lean b/Tests/Ix/Check.lean new file mode 100644 index 00000000..404b478d --- /dev/null +++ b/Tests/Ix/Check.lean @@ -0,0 +1,107 @@ +/- + Kernel type-checker integration tests. + Tests both the Rust kernel (via FFI) and the Lean NbE kernel. +-/ + +import Ix.Kernel +import Ix.Common +import Ix.Meta +import Ix.CompileM +import Lean +import LSpec + +open LSpec + +namespace Tests.Check + +/-! ## Rust kernel tests -/ + +def testCheckEnv : TestSeq := + .individualIO "Rust kernel check_env" (do + let leanEnv ← get_env! + let totalConsts := leanEnv.constants.toList.length + + IO.println s!"[Check] Environment has {totalConsts} constants" + + let start ← IO.monoMsNow + let errors ← Ix.Kernel.rsCheckEnv leanEnv + let elapsed := (← IO.monoMsNow) - start + + IO.println s!"[Check] Rust kernel checked {totalConsts} constants in {elapsed.formatMs}" + + if errors.isEmpty then + IO.println s!"[Check] All constants passed" + return (true, none) + else + IO.println s!"[Check] {errors.size} error(s):" + for (name, err) in errors[:min 20 errors.size] do + IO.println s!" {repr name}: {repr err}" + return (false, some s!"Kernel check failed with {errors.size} error(s)") + ) .done + +def testCheckConst (name : String) : TestSeq := + .individualIO s!"check {name}" (do + let leanEnv ← get_env! 
+ let start ← IO.monoMsNow + let result ← Ix.Kernel.rsCheckConst leanEnv name + let elapsed := (← IO.monoMsNow) - start + match result with + | none => + IO.println s!" [ok] {name} ({elapsed.formatMs})" + return (true, none) + | some err => + IO.println s!" [fail] {name}: {repr err} ({elapsed.formatMs})" + return (false, some s!"{name} failed: {repr err}") + ) .done + +/-! ## Lean NbE kernel tests -/ + +def testKernelCheckEnv : TestSeq := + .individualIO "Lean NbE kernel check_env" (do + let leanEnv ← get_env! + + IO.println s!"[Kernel-NbE] Compiling to Ixon..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileElapsed := (← IO.monoMsNow) - compileStart + let numConsts := ixonEnv.consts.size + IO.println s!"[Kernel-NbE] Compiled {numConsts} constants in {compileElapsed.formatMs}" + + IO.println s!"[Kernel-NbE] Converting..." + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + IO.println s!"[Kernel-NbE] convertEnv error: {e}" + return (false, some e) + | .ok (kenv, prims, quotInit) => + let convertElapsed := (← IO.monoMsNow) - convertStart + IO.println s!"[Kernel-NbE] Converted {kenv.size} constants in {convertElapsed.formatMs}" + + IO.println s!"[Kernel-NbE] Typechecking {kenv.size} constants..." + let checkStart ← IO.monoMsNow + match ← Ix.Kernel.typecheckAllIO kenv prims quotInit with + | .error e => + let elapsed := (← IO.monoMsNow) - checkStart + IO.println s!"[Kernel-NbE] typecheckAll error in {elapsed.formatMs}: {e}" + return (false, some s!"Kernel NbE check failed: {e}") + | .ok () => + let elapsed := (← IO.monoMsNow) - checkStart + IO.println s!"[Kernel-NbE] All constants passed in {elapsed.formatMs}" + return (true, none) + ) .done + +/-! 
## Test suites -/ + +def checkSuiteIO : List TestSeq := [ + testCheckConst "Nat.add", +] + +def checkAllSuiteIO : List TestSeq := [ + testCheckEnv, +] + +def kernelSuiteIO : List TestSeq := [ + testKernelCheckEnv, +] + +end Tests.Check diff --git a/Tests/Ix/Compile.lean b/Tests/Ix/Compile.lean index fa6dadff..af14f820 100644 --- a/Tests/Ix/Compile.lean +++ b/Tests/Ix/Compile.lean @@ -9,6 +9,8 @@ import Ix.Address import Ix.Common import Ix.Meta import Ix.CompileM +import Ix.DecompileM +import Ix.CanonM import Ix.CondenseM import Ix.GraphM import Ix.Sharing @@ -458,10 +460,79 @@ def testCrossImpl : TestSeq := return (false, some s!"Found {result.mismatchedConstants.size} mismatches") ) .done -/-! ## Test Suite -/ +/-! ## Lean → Ixon → Ix → Lean full roundtrip -/ + +/-- Full roundtrip: Rust-compile Lean env to Ixon, decompile back to Ix, uncanon back to Lean, + then structurally compare every constant against the original. -/ +def testIxonFullRoundtrip : TestSeq := + .individualIO "Lean→Ixon→Ix→Lean full roundtrip" (do + let leanEnv ← get_env! + let totalConsts := leanEnv.constants.toList.length + IO.println s!"[ixon-roundtrip] Lean env: {totalConsts} constants" + + -- Step 1: Rust compile to Ixon.Env + IO.println s!"[ixon-roundtrip] Step 1: Rust compile..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileMs := (← IO.monoMsNow) - compileStart + IO.println s!"[ixon-roundtrip] {ixonEnv.named.size} named, {ixonEnv.consts.size} consts in {compileMs}ms" + + -- Step 2: Decompile Ixon → Ix + IO.println s!"[ixon-roundtrip] Step 2: Decompile Ixon→Ix (parallel)..." 
+ let decompStart ← IO.monoMsNow + let (ixConsts, decompErrors) := Ix.DecompileM.decompileAllParallel ixonEnv + let decompMs := (← IO.monoMsNow) - decompStart + IO.println s!"[ixon-roundtrip] {ixConsts.size} ok, {decompErrors.size} errors in {decompMs}ms" + if !decompErrors.isEmpty then + IO.println s!"[ixon-roundtrip] First errors:" + for (name, err) in decompErrors.toList.take 5 do + IO.println s!" {name}: {err}" + + -- Step 3: Uncanon Ix → Lean + IO.println s!"[ixon-roundtrip] Step 3: Uncanon Ix→Lean (parallel)..." + let uncanonStart ← IO.monoMsNow + let roundtripped := Ix.CanonM.uncanonEnvParallel ixConsts + let uncanonMs := (← IO.monoMsNow) - uncanonStart + IO.println s!"[ixon-roundtrip] {roundtripped.size} constants in {uncanonMs}ms" + + -- Step 4: Compare roundtripped Lean constants against originals + IO.println s!"[ixon-roundtrip] Step 4: Comparing against original..." + let compareStart ← IO.monoMsNow + let origMap : Std.HashMap Lean.Name Lean.ConstantInfo := + leanEnv.constants.fold (init := {}) fun acc name const => acc.insert name const + let (nMismatches, nMissing, mismatchNames, missingNames) := + Ix.CanonM.compareEnvsParallel origMap roundtripped + let compareMs := (← IO.monoMsNow) - compareStart + IO.println s!"[ixon-roundtrip] {nMissing} missing, {nMismatches} mismatches in {compareMs}ms" + + if !missingNames.isEmpty then + IO.println s!"[ixon-roundtrip] First missing:" + for name in missingNames.toList.take 10 do + IO.println s!" {name}" + + if !mismatchNames.isEmpty then + IO.println s!"[ixon-roundtrip] First mismatches:" + for name in mismatchNames.toList.take 20 do + IO.println s!" 
{name}" + + let totalMs := compileMs + decompMs + uncanonMs + compareMs + IO.println s!"[ixon-roundtrip] Total: {totalMs}ms" + + let success := decompErrors.size == 0 && nMismatches == 0 && nMissing == 0 + if success then + return (true, none) + else + return (false, some s!"{decompErrors.size} decompile errors, {nMismatches} mismatches, {nMissing} missing") + ) .done + +/-! ## Test Suites -/ def compileSuiteIO : List TestSeq := [ testCrossImpl, ] +def ixonRoundtripSuiteIO : List TestSeq := [ + testIxonFullRoundtrip, +] + end Tests.Compile diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean new file mode 100644 index 00000000..f1ed3c55 --- /dev/null +++ b/Tests/Ix/KernelTests.lean @@ -0,0 +1,761 @@ +/- + Kernel test suite. + - Unit tests for Kernel types, expression operations, and level operations + - Convert tests (Ixon.Env → Kernel.Env) + - Targeted constant-checking tests (individual constants through the full pipeline) +-/ +import Ix.Kernel +import Ix.Kernel.DecompileM +import Ix.CompileM +import Ix.Common +import Ix.Meta +import LSpec + +open LSpec +open Ix.Kernel + +namespace Tests.KernelTests + +/-! 
## Unit tests: Expression equality -/ + +def testExprHashEq : TestSeq := + let bv0 : Expr .anon := Expr.mkBVar 0 + let bv0' : Expr .anon := Expr.mkBVar 0 + let bv1 : Expr .anon := Expr.mkBVar 1 + test "mkBVar 0 == mkBVar 0" (bv0 == bv0') ++ + test "mkBVar 0 != mkBVar 1" (bv0 != bv1) ++ + -- Sort equality + let s0 : Expr .anon := Expr.mkSort Level.zero + let s0' : Expr .anon := Expr.mkSort Level.zero + let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) + test "mkSort 0 == mkSort 0" (s0 == s0') ++ + test "mkSort 0 != mkSort 1" (s0 != s1) ++ + -- App equality + let app1 := Expr.mkApp bv0 bv1 + let app1' := Expr.mkApp bv0 bv1 + let app2 := Expr.mkApp bv1 bv0 + test "mkApp bv0 bv1 == mkApp bv0 bv1" (app1 == app1') ++ + test "mkApp bv0 bv1 != mkApp bv1 bv0" (app1 != app2) ++ + -- Lambda equality + let lam1 := Expr.mkLam s0 bv0 + let lam1' := Expr.mkLam s0 bv0 + let lam2 := Expr.mkLam s1 bv0 + test "mkLam s0 bv0 == mkLam s0 bv0" (lam1 == lam1') ++ + test "mkLam s0 bv0 != mkLam s1 bv0" (lam1 != lam2) ++ + -- Forall equality + let pi1 := Expr.mkForallE s0 s1 + let pi1' := Expr.mkForallE s0 s1 + test "mkForallE s0 s1 == mkForallE s0 s1" (pi1 == pi1') ++ + -- Const equality + let addr1 := Address.blake3 (ByteArray.mk #[1]) + let addr2 := Address.blake3 (ByteArray.mk #[2]) + let c1 : Expr .anon := Expr.mkConst addr1 #[] + let c1' : Expr .anon := Expr.mkConst addr1 #[] + let c2 : Expr .anon := Expr.mkConst addr2 #[] + test "mkConst addr1 == mkConst addr1" (c1 == c1') ++ + test "mkConst addr1 != mkConst addr2" (c1 != c2) ++ + -- Const with levels + let c1l : Expr .anon := Expr.mkConst addr1 #[Level.zero] + let c1l' : Expr .anon := Expr.mkConst addr1 #[Level.zero] + let c1l2 : Expr .anon := Expr.mkConst addr1 #[Level.succ Level.zero] + test "mkConst addr1 [0] == mkConst addr1 [0]" (c1l == c1l') ++ + test "mkConst addr1 [0] != mkConst addr1 [1]" (c1l != c1l2) ++ + -- Literal equality + let nat0 : Expr .anon := Expr.mkLit (.natVal 0) + let nat0' : Expr .anon := Expr.mkLit 
(.natVal 0) + let nat1 : Expr .anon := Expr.mkLit (.natVal 1) + let str1 : Expr .anon := Expr.mkLit (.strVal "hello") + let str1' : Expr .anon := Expr.mkLit (.strVal "hello") + let str2 : Expr .anon := Expr.mkLit (.strVal "world") + test "lit nat 0 == lit nat 0" (nat0 == nat0') ++ + test "lit nat 0 != lit nat 1" (nat0 != nat1) ++ + test "lit str hello == lit str hello" (str1 == str1') ++ + test "lit str hello != lit str world" (str1 != str2) ++ + .done + +/-! ## Unit tests: Expression operations -/ + +def testExprOps : TestSeq := + -- getAppFn / getAppArgs + let bv0 : Expr .anon := Expr.mkBVar 0 + let bv1 : Expr .anon := Expr.mkBVar 1 + let bv2 : Expr .anon := Expr.mkBVar 2 + let app := Expr.mkApp (Expr.mkApp bv0 bv1) bv2 + test "getAppFn (app (app bv0 bv1) bv2) == bv0" (app.getAppFn == bv0) ++ + test "getAppNumArgs == 2" (app.getAppNumArgs == 2) ++ + test "getAppArgs[0] == bv1" (app.getAppArgs[0]! == bv1) ++ + test "getAppArgs[1] == bv2" (app.getAppArgs[1]! == bv2) ++ + -- mkAppN round-trips + let rebuilt := Expr.mkAppN bv0 #[bv1, bv2] + test "mkAppN round-trips" (rebuilt == app) ++ + -- Predicates + test "isApp" app.isApp ++ + test "isSort" (Expr.mkSort (Level.zero : Level .anon)).isSort ++ + test "isLambda" (Expr.mkLam bv0 bv1).isLambda ++ + test "isForall" (Expr.mkForallE bv0 bv1).isForall ++ + test "isLit" (Expr.mkLit (.natVal 42) : Expr .anon).isLit ++ + test "isBVar" bv0.isBVar ++ + test "isConst" (Expr.mkConst (m := .anon) default #[]).isConst ++ + -- Accessors + let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) + test "sortLevel!" (s1.sortLevel! == Level.succ Level.zero) ++ + test "bvarIdx!" (bv1.bvarIdx! == 1) ++ + .done + +/-! 
## Unit tests: Level operations -/ + +def testLevelOps : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- reduce + test "reduce zero" (Level.reduce l0 == l0) ++ + test "reduce (succ zero)" (Level.reduce l1 == l1) ++ + -- equalLevel + test "zero equiv zero" (Level.equalLevel l0 l0) ++ + test "succ zero equiv succ zero" (Level.equalLevel l1 l1) ++ + test "max a b equiv max b a" + (Level.equalLevel (Level.max p0 p1) (Level.max p1 p0)) ++ + test "zero not equiv succ zero" (!Level.equalLevel l0 l1) ++ + -- leq + test "zero <= zero" (Level.leq l0 l0 0) ++ + test "succ zero <= zero + 1" (Level.leq l1 l0 1) ++ + test "not (succ zero <= zero)" (!Level.leq l1 l0 0) ++ + test "param 0 <= param 0" (Level.leq p0 p0 0) ++ + test "succ (param 0) <= param 0 + 1" + (Level.leq (Level.succ p0) p0 1) ++ + test "not (succ (param 0) <= param 0)" + (!Level.leq (Level.succ p0) p0 0) ++ + .done + +/-! ## Integration tests: Const pipeline -/ + +/-- Parse a dotted name string like "Nat.add" into an Ix.Name. -/ +private def parseIxName (s : String) : Ix.Name := + let parts := s.splitOn "." + parts.foldl (fun acc part => Ix.Name.mkStr acc part) Ix.Name.mkAnon + +/-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/ +private partial def leanNameToIx : Lean.Name → Ix.Name + | .anonymous => Ix.Name.mkAnon + | .str pre s => Ix.Name.mkStr (leanNameToIx pre) s + | .num pre n => Ix.Name.mkNat (leanNameToIx pre) n + +def testConvertEnv : TestSeq := + .individualIO "rsCompileEnv + convertEnv" (do + let leanEnv ← get_env! 
+ let leanCount := leanEnv.constants.toList.length + IO.println s!"[kernel] Lean env: {leanCount} constants" + let start ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileMs := (← IO.monoMsNow) - start + let ixonCount := ixonEnv.consts.size + let namedCount := ixonEnv.named.size + IO.println s!"[kernel] rsCompileEnv: {ixonCount} consts, {namedCount} named in {compileMs.formatMs}" + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + IO.println s!"[kernel] convertEnv error: {e}" + return (false, some e) + | .ok (kenv, _, _) => + let convertMs := (← IO.monoMsNow) - convertStart + let kenvCount := kenv.size + IO.println s!"[kernel] convertEnv: {kenvCount} consts in {convertMs.formatMs} ({ixonCount - kenvCount} muts blocks)" + -- Verify every Lean constant is present in the Kernel.Env + let mut missing : Array String := #[] + let mut notCompiled : Array String := #[] + let mut checked := 0 + for (leanName, _) in leanEnv.constants.toList do + let ixName := leanNameToIx leanName + match ixonEnv.named.get? ixName with + | none => notCompiled := notCompiled.push (toString leanName) + | some named => + checked := checked + 1 + if !kenv.contains named.addr then + missing := missing.push (toString leanName) + if !notCompiled.isEmpty then + IO.println s!"[kernel] {notCompiled.size} Lean constants not in ixonEnv.named (unexpected)" + for n in notCompiled[:min 10 notCompiled.size] do + IO.println s!" not compiled: {n}" + if missing.isEmpty then + IO.println s!"[kernel] All {checked} named constants found in Kernel.Env" + return (true, none) + else + IO.println s!"[kernel] {missing.size}/{checked} named constants missing from Kernel.Env" + for n in missing[:min 20 missing.size] do + IO.println s!" missing: {n}" + return (false, some s!"{missing.size} constants missing from Kernel.Env") + ) .done + +/-- Const pipeline: compile, convert, typecheck specific constants. 
-/ +def testConstPipeline : TestSeq := + .individualIO "kernel const pipeline" (do + let leanEnv ← get_env! + let start ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileMs := (← IO.monoMsNow) - start + IO.println s!"[kernel] rsCompileEnv: {ixonEnv.consts.size} consts in {compileMs.formatMs}" + + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + IO.println s!"[kernel] convertEnv error: {e}" + return (false, some e) + | .ok (kenv, prims, quotInit) => + let convertMs := (← IO.monoMsNow) - convertStart + IO.println s!"[kernel] convertEnv: {kenv.size} consts in {convertMs.formatMs}" + + -- Check specific constants + let constNames := #[ + "Nat", "Nat.zero", "Nat.succ", "Nat.rec", + "Bool", "Bool.true", "Bool.false", "Bool.rec", + "Eq", "Eq.refl", + "List", "List.nil", "List.cons", + "Nat.below" + ] + let checkStart ← IO.monoMsNow + let mut passed := 0 + let mut failures : Array String := #[] + for name in constNames do + let ixName := parseIxName name + let some cNamed := ixonEnv.named.get? ixName + | do failures := failures.push s!"{name}: not found"; continue + let addr := cNamed.addr + match Ix.Kernel.typecheckConst kenv prims addr quotInit with + | .ok () => passed := passed + 1 + | .error e => failures := failures.push s!"{name}: {e}" + let checkMs := (← IO.monoMsNow) - checkStart + IO.println s!"[kernel] {passed}/{constNames.size} passed in {checkMs.formatMs}" + if failures.isEmpty then + return (true, none) + else + for f in failures do IO.println s!" [fail] {f}" + return (false, some s!"{failures.size} failure(s)") + ) .done + +/-! ## Primitive address verification -/ + +/-- Look up a primitive address by name (for verification only). -/ +private def lookupPrim (ixonEnv : Ixon.Env) (name : String) : Address := + let ixName := parseIxName name + match ixonEnv.named.get? 
ixName with + | some n => n.addr + | none => default + +/-- Verify hardcoded primitive addresses match actual compiled addresses. -/ +def testVerifyPrimAddrs : TestSeq := + .individualIO "verify primitive addresses" (do + let leanEnv ← get_env! + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let hardcoded := Ix.Kernel.buildPrimitives + let mut failures : Array String := #[] + let checks : Array (String × String × Address) := #[ + ("nat", "Nat", hardcoded.nat), + ("natZero", "Nat.zero", hardcoded.natZero), + ("natSucc", "Nat.succ", hardcoded.natSucc), + ("natAdd", "Nat.add", hardcoded.natAdd), + ("natSub", "Nat.sub", hardcoded.natSub), + ("natMul", "Nat.mul", hardcoded.natMul), + ("natPow", "Nat.pow", hardcoded.natPow), + ("natGcd", "Nat.gcd", hardcoded.natGcd), + ("natMod", "Nat.mod", hardcoded.natMod), + ("natDiv", "Nat.div", hardcoded.natDiv), + ("natBeq", "Nat.beq", hardcoded.natBeq), + ("natBle", "Nat.ble", hardcoded.natBle), + ("natLand", "Nat.land", hardcoded.natLand), + ("natLor", "Nat.lor", hardcoded.natLor), + ("natXor", "Nat.xor", hardcoded.natXor), + ("natShiftLeft", "Nat.shiftLeft", hardcoded.natShiftLeft), + ("natShiftRight", "Nat.shiftRight", hardcoded.natShiftRight), + ("bool", "Bool", hardcoded.bool), + ("boolTrue", "Bool.true", hardcoded.boolTrue), + ("boolFalse", "Bool.false", hardcoded.boolFalse), + ("string", "String", hardcoded.string), + ("stringMk", "String.mk", hardcoded.stringMk), + ("char", "Char", hardcoded.char), + ("charMk", "Char.ofNat", hardcoded.charMk), + ("list", "List", hardcoded.list), + ("listNil", "List.nil", hardcoded.listNil), + ("listCons", "List.cons", hardcoded.listCons) + ] + for (field, name, expected) in checks do + let actual := lookupPrim ixonEnv name + if actual != expected then + failures := failures.push s!"{field}: expected {expected}, got {actual}" + IO.println s!" 
[MISMATCH] {field} ({name}): {actual} != {expected}" + if failures.isEmpty then + IO.println s!"[prims] All {checks.size} primitive addresses verified" + return (true, none) + else + return (false, some s!"{failures.size} primitive address mismatch(es). Run `lake test -- kernel-dump-prims` to update.") + ) .done + +/-- Dump all primitive addresses for hardcoding. Use this to update buildPrimitives. -/ +def testDumpPrimAddrs : TestSeq := + .individualIO "dump primitive addresses" (do + let leanEnv ← get_env! + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let names := #[ + ("nat", "Nat"), ("natZero", "Nat.zero"), ("natSucc", "Nat.succ"), + ("natAdd", "Nat.add"), ("natSub", "Nat.sub"), ("natMul", "Nat.mul"), + ("natPow", "Nat.pow"), ("natGcd", "Nat.gcd"), ("natMod", "Nat.mod"), + ("natDiv", "Nat.div"), ("natBeq", "Nat.beq"), ("natBle", "Nat.ble"), + ("natLand", "Nat.land"), ("natLor", "Nat.lor"), ("natXor", "Nat.xor"), + ("natShiftLeft", "Nat.shiftLeft"), ("natShiftRight", "Nat.shiftRight"), + ("bool", "Bool"), ("boolTrue", "Bool.true"), ("boolFalse", "Bool.false"), + ("string", "String"), ("stringMk", "String.mk"), + ("char", "Char"), ("charMk", "Char.ofNat"), + ("list", "List"), ("listNil", "List.nil"), ("listCons", "List.cons") + ] + for (field, name) in names do + IO.println s!"{field} := \"{lookupPrim ixonEnv name}\"" + return (true, none) + ) .done + +/-! 
## Unit tests: Level reduce/imax edge cases -/ + +def testLevelReduceIMax : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- imax u 0 = 0 + test "imax u 0 = 0" (Level.reduceIMax p0 l0 == l0) ++ + -- imax u (succ v) = max u (succ v) + test "imax u (succ v) = max u (succ v)" + (Level.equalLevel (Level.reduceIMax p0 l1) (Level.reduceMax p0 l1)) ++ + -- imax u u = u (same param) + test "imax u u = u" (Level.reduceIMax p0 p0 == p0) ++ + -- imax u v stays imax (different params) + test "imax u v stays imax" + (Level.reduceIMax p0 p1 == Level.imax p0 p1) ++ + -- nested: imax u (imax v 0) — reduce inner first, then outer + let inner := Level.reduceIMax p1 l0 -- = 0 + test "imax u (imax v 0) = imax u 0 = 0" + (Level.reduceIMax p0 inner == l0) ++ + .done + +def testLevelReduceMax : TestSeq := + let l0 : Level .anon := Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- max 0 u = u + test "max 0 u = u" (Level.reduceMax l0 p0 == p0) ++ + -- max u 0 = u + test "max u 0 = u" (Level.reduceMax p0 l0 == p0) ++ + -- max (succ u) (succ v) = succ (max u v) + test "max (succ u) (succ v) = succ (max u v)" + (Level.reduceMax (Level.succ p0) (Level.succ p1) + == Level.succ (Level.reduceMax p0 p1)) ++ + -- max p0 p0 = p0 + test "max p0 p0 = p0" (Level.reduceMax p0 p0 == p0) ++ + .done + +def testLevelLeqComplex : TestSeq := + let l0 : Level .anon := Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- max u v <= max v u (symmetry) + test "max u v <= max v u" + (Level.leq (Level.max p0 p1) (Level.max p1 p0) 0) ++ + -- u <= max u v + test "u <= max u v" + (Level.leq p0 (Level.max p0 p1) 0) ++ + -- imax u (succ v) <= max u (succ v) — after reduce they're equal + let lhs := Level.reduce (Level.imax p0 (.succ p1)) + let rhs := 
Level.reduce (Level.max p0 (.succ p1)) + test "imax u (succ v) <= max u (succ v)" + (Level.leq lhs rhs 0) ++ + -- imax u 0 <= 0 + test "imax u 0 <= 0" + (Level.leq (Level.reduce (.imax p0 l0)) l0 0) ++ + -- not (succ (max u v) <= max u v) + test "not (succ (max u v) <= max u v)" + (!Level.leq (Level.succ (Level.max p0 p1)) (Level.max p0 p1) 0) ++ + -- imax u u <= u + test "imax u u <= u" + (Level.leq (Level.reduce (Level.imax p0 p0)) p0 0) ++ + -- imax 1 (imax 1 u) = u (nested imax decomposition) + let l1 : Level .anon := Level.succ Level.zero + let nested := Level.reduce (Level.imax l1 (Level.imax l1 p0)) + test "imax 1 (imax 1 u) <= u" + (Level.leq nested p0 0) ++ + test "u <= imax 1 (imax 1 u)" + (Level.leq p0 nested 0) ++ + .done + +def testLevelInstBulkReduce : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- Basic: param 0 with [zero] = zero + test "param 0 with [zero] = zero" + (Level.instBulkReduce #[l0] p0 == l0) ++ + -- Multi: param 1 with [zero, succ zero] = succ zero + test "param 1 with [zero, succ zero] = succ zero" + (Level.instBulkReduce #[l0, l1] p1 == l1) ++ + -- Out-of-bounds: param 2 with 2-element array shifts + let p2 : Level .anon := Level.param 2 default + test "param 2 with 2-elem array shifts to param 0" + (Level.instBulkReduce #[l0, l1] p2 == Level.param 0 default) ++ + -- Compound: imax (param 0) (param 1) with [zero, succ zero] + let compound := Level.imax p0 p1 + let result := Level.instBulkReduce #[l0, l1] compound + -- imax 0 (succ 0) = max 0 (succ 0) = succ 0 + test "imax (param 0) (param 1) subst [zero, succ zero]" + (Level.equalLevel result l1) ++ + .done + +def testReducibilityHintsLt : TestSeq := + test "regular 1 < regular 2" (ReducibilityHints.lt' (.regular 1) (.regular 2)) ++ + test "not (regular 2 < regular 1)" (!ReducibilityHints.lt' (.regular 2) (.regular 1)) ++ + test "regular 
_ < opaque" (ReducibilityHints.lt' (.regular 5) .opaque) ++ + test "abbrev < opaque" (ReducibilityHints.lt' .abbrev .opaque) ++ + test "not (opaque < opaque)" (!ReducibilityHints.lt' .opaque .opaque) ++ + test "not (regular 5 < regular 5)" (!ReducibilityHints.lt' (.regular 5) (.regular 5)) ++ + .done + +/-! ## Expanded integration tests -/ + +/-- Expanded constant pipeline: more constants including quotients, recursors, projections. -/ +def testMoreConstants : TestSeq := + .individualIO "expanded kernel const pipeline" (do + let leanEnv ← get_env! + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => return (false, some e) + | .ok (kenv, prims, quotInit) => + let constNames := #[ + -- Quotient types + "Quot", "Quot.mk", "Quot.lift", "Quot.ind", + -- K-reduction exercisers + "Eq.rec", "Eq.subst", "Eq.symm", "Eq.trans", + -- Proof irrelevance + "And.intro", "Or.inl", "Or.inr", + -- K-like reduction with congr + "congr", "congrArg", "congrFun", + -- Structure projections + eta + "Prod.fst", "Prod.snd", "Prod.mk", "Sigma.mk", "Subtype.mk", + -- Nat primitives + "Nat.add", "Nat.sub", "Nat.mul", "Nat.div", "Nat.mod", + "Nat.gcd", "Nat.beq", "Nat.ble", + "Nat.land", "Nat.lor", "Nat.xor", + "Nat.shiftLeft", "Nat.shiftRight", "Nat.pow", + -- Recursors + "Bool.rec", "List.rec", + -- Delta unfolding + "id", "Function.comp", + -- Various inductives + "Empty", "PUnit", "Fin", "Sigma", "Prod", + -- Proofs / proof irrelevance + "True", "False", "And", "Or", + -- Mutual/nested inductives + "List.map", "List.foldl", "List.append", + -- Universe polymorphism + "ULift", "PLift", + -- More complex + "Option", "Option.some", "Option.none", + "String", "String.mk", "Char", + -- Partial definitions + "WellFounded.fix" + ] + let mut passed := 0 + let mut failures : Array String := #[] + for name in constNames do + let ixName := parseIxName name + let some cNamed := ixonEnv.named.get? 
ixName + | do failures := failures.push s!"{name}: not found"; continue + let addr := cNamed.addr + match Ix.Kernel.typecheckConst kenv prims addr quotInit with + | .ok () => passed := passed + 1 + | .error e => failures := failures.push s!"{name}: {e}" + IO.println s!"[kernel-expanded] {passed}/{constNames.size} passed" + if failures.isEmpty then + return (true, none) + else + for f in failures do IO.println s!" [fail] {f}" + return (false, some s!"{failures.size} failure(s)") + ) .done + +/-! ## Anon mode conversion test -/ + +/-- Test that convertEnv in .anon mode produces the same number of constants. -/ +def testAnonConvert : TestSeq := + .individualIO "anon mode conversion" (do + let leanEnv ← get_env! + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let metaResult := Ix.Kernel.Convert.convertEnv .meta ixonEnv + let anonResult := Ix.Kernel.Convert.convertEnv .anon ixonEnv + match metaResult, anonResult with + | .ok (metaEnv, _, _), .ok (anonEnv, _, _) => + let metaCount := metaEnv.size + let anonCount := anonEnv.size + IO.println s!"[kernel-anon] meta: {metaCount}, anon: {anonCount}" + if metaCount == anonCount then + return (true, none) + else + return (false, some s!"meta ({metaCount}) != anon ({anonCount})") + | .error e, _ => return (false, some s!"meta conversion failed: {e}") + | _, .error e => return (false, some s!"anon conversion failed: {e}") + ) .done + +/-! ## Negative tests -/ + +/-- Negative test suite: verify that the typechecker rejects malformed declarations. 
-/ +def negativeTests : TestSeq := + .individualIO "kernel negative tests" (do + let testAddr := Address.blake3 (ByteArray.mk #[1, 0, 42]) + let badAddr := Address.blake3 (ByteArray.mk #[99, 0, 42]) + let prims := buildPrimitives + let mut passed := 0 + let mut failures : Array String := #[] + + -- Test 1: Theorem not in Prop (type = Sort 1, which is Type 0 not Prop) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () } + let ci : ConstantInfo .anon := .thmInfo { toConstantVal := cv, value := .sort .zero, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "theorem-not-prop: expected error" + + -- Test 2: Type mismatch (definition type = Sort 0, value = Sort 1) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .sort .zero, name := (), levelParams := () } + let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort (.succ .zero), hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "type-mismatch: expected error" + + -- Test 3: Unknown constant reference (type references non-existent address) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .const badAddr #[] (), name := (), levelParams := () } + let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort .zero, hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "unknown-const: expected error" + + -- Test 4: Variable out of range (type = bvar 0 in empty context) + do + let cv : ConstantVal .anon := + { 
numLevels := 0, type := .bvar 0 (), name := (), levelParams := () } + let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort .zero, hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "var-out-of-range: expected error" + + -- Test 5: Application of non-function (Sort 0 is not a function) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .sort (.succ (.succ .zero)), name := (), levelParams := () } + let ci : ConstantInfo .anon := .defnInfo + { toConstantVal := cv, value := .app (.sort .zero) (.sort .zero), hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "app-non-function: expected error" + + -- Test 6: Let value type doesn't match annotation (Sort 1 : Sort 2, not Sort 0) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .sort (.succ (.succ (.succ .zero))), name := (), levelParams := () } + let letVal : Expr .anon := .letE (.sort .zero) (.sort (.succ .zero)) (.bvar 0 ()) () + let ci : ConstantInfo .anon := .defnInfo + { toConstantVal := cv, value := letVal, hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "let-type-mismatch: expected error" + + -- Test 7: Lambda applied to wrong type (domain expects Prop, given Type 0) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .sort (.succ (.succ .zero)), name := (), levelParams := () } + let lam : Expr .anon := .lam (.sort .zero) (.bvar 0 ()) () () + let ci : ConstantInfo .anon := .defnInfo + { toConstantVal := cv, value := .app lam 
(.sort (.succ .zero)), hints := .opaque, safety := .safe, all := #[] } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "app-wrong-type: expected error" + + -- Test 8: Axiom with non-sort type (type = App (Sort 0) (Sort 0), not a sort) + do + let cv : ConstantVal .anon := + { numLevels := 0, type := .app (.sort .zero) (.sort .zero), name := (), levelParams := () } + let ci : ConstantInfo .anon := .axiomInfo { toConstantVal := cv, isUnsafe := false } + let env := (default : Env .anon).insert testAddr ci + match typecheckConst env prims testAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "axiom-non-sort-type: expected error" + + IO.println s!"[kernel-negative] {passed}/8 passed" + if failures.isEmpty then + return (true, none) + else + for f in failures do IO.println s!" [fail] {f}" + return (false, some s!"{failures.size} failure(s)") + ) .done + +/-! ## Focused NbE constant tests -/ + +/-- Test individual constants through the NbE kernel to isolate failures. -/ +def testNbeConsts : TestSeq := + .individualIO "nbe focused const checks" (do + let leanEnv ← get_env! 
+ let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => return (false, some s!"convertEnv: {e}") + | .ok (kenv, prims, quotInit) => + let constNames := #[ + -- Nat basics + "Nat", "Nat.zero", "Nat.succ", "Nat.rec", + -- Below / brecOn (well-founded recursion scaffolding) + "Nat.below", "Nat.brecOn", + -- PProd (used by Nat.below) + "PProd", "PProd.mk", "PProd.fst", "PProd.snd", + "PUnit", "PUnit.unit", + -- noConfusion (stuck neutral in fresh-state mode) + "Lean.Meta.Grind.Origin.noConfusionType", + "Lean.Meta.Grind.Origin.noConfusion", + "Lean.Meta.Grind.Origin.stx.noConfusion", + -- The previously-hanging constant + "Nat.Linear.Poly.of_denote_eq_cancel", + -- String theorem (fuel-sensitive) + "String.length_empty", + ] + let mut passed := 0 + let mut failures : Array String := #[] + for name in constNames do + let ixName := parseIxName name + let some cNamed := ixonEnv.named.get? ixName + | do failures := failures.push s!"{name}: not found"; continue + let addr := cNamed.addr + IO.println s!" checking {name} ..." + (← IO.getStdout).flush + let start ← IO.monoMsNow + match Ix.Kernel.typecheckConst kenv prims addr quotInit with + | .ok () => + let ms := (← IO.monoMsNow) - start + IO.println s!" ✓ {name} ({ms.formatMs})" + passed := passed + 1 + | .error e => + let ms := (← IO.monoMsNow) - start + IO.println s!" ✗ {name} ({ms.formatMs}): {e}" + failures := failures.push s!"{name}: {e}" + IO.println s!"[nbe-focus] {passed}/{constNames.size} passed" + if failures.isEmpty then + return (true, none) + else + return (false, some s!"{failures.size} failure(s)") + ) .done + +def nbeFocusSuite : List TestSeq := [ + testNbeConsts, +] + +/-! 
## Test suites -/ + +def unitSuite : List TestSeq := [ + testExprHashEq, + testExprOps, + testLevelOps, + testLevelReduceIMax, + testLevelReduceMax, + testLevelLeqComplex, + testLevelInstBulkReduce, + testReducibilityHintsLt, +] + +def convertSuite : List TestSeq := [ + testConvertEnv, +] + +def constSuite : List TestSeq := [ + testConstPipeline, + testMoreConstants, +] + +def negativeSuite : List TestSeq := [ + negativeTests, +] + +def anonConvertSuite : List TestSeq := [ + testAnonConvert, +] + +/-! ## Roundtrip test: Lean → Ixon → Kernel → Lean -/ + +/-- Roundtrip test: compile Lean env to Ixon, convert to Kernel, decompile back to Lean, + and structurally compare against the original. -/ +def testRoundtrip : TestSeq := + .individualIO "kernel roundtrip Lean→Ixon→Kernel→Lean" (do + let leanEnv ← get_env! + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + IO.println s!"[roundtrip] convertEnv error: {e}" + return (false, some e) + | .ok (kenv, _, _) => + -- Build Lean.Name → EnvId map from ixonEnv.named (name-aware lookup) + let mut nameToEnvId : Std.HashMap Lean.Name (Ix.Kernel.EnvId .meta) := {} + for (ixName, named) in ixonEnv.named do + nameToEnvId := nameToEnvId.insert (Ix.Kernel.Decompile.ixNameToLean ixName) ⟨named.addr, ixName⟩ + -- Build work items (filter to constants we can check) + let mut workItems : Array (Lean.Name × Lean.ConstantInfo × Ix.Kernel.ConstantInfo .meta) := #[] + let mut notFound := 0 + for (leanName, origCI) in leanEnv.constants.toList do + let some envId := nameToEnvId.get? 
leanName + | do notFound := notFound + 1; continue + let some kernelCI := kenv.findByEnvId envId + | continue + workItems := workItems.push (leanName, origCI, kernelCI) + -- Chunked parallel comparison + let numWorkers := 32 + let total := workItems.size + let chunkSize := (total + numWorkers - 1) / numWorkers + let mut tasks : Array (Task (Array (Lean.Name × Array (String × String × String)))) := #[] + let mut offset := 0 + while offset < total do + let endIdx := min (offset + chunkSize) total + let chunk := workItems[offset:endIdx] + let task := Task.spawn (prio := .dedicated) fun () => Id.run do + let mut results : Array (Lean.Name × Array (String × String × String)) := #[] + for (leanName, origCI, kernelCI) in chunk.toArray do + let roundtrippedCI := Ix.Kernel.Decompile.decompileConstantInfo kernelCI + let diffs := Ix.Kernel.Decompile.constInfoStructEq origCI roundtrippedCI + if !diffs.isEmpty then + results := results.push (leanName, diffs) + results + tasks := tasks.push task + offset := endIdx + -- Collect results + let checked := total + let mut mismatches := 0 + for task in tasks do + for (leanName, diffs) in task.get do + mismatches := mismatches + 1 + let diffMsgs := diffs.toList.map fun (path, lhs, rhs) => + s!" {path}: {lhs} ≠ {rhs}" + IO.println s!"[roundtrip] MISMATCH {leanName}:" + for msg in diffMsgs do IO.println msg + IO.println s!"[roundtrip] checked {checked}, mismatches {mismatches}, not found {notFound}" + if mismatches == 0 then + return (true, none) + else + return (false, some s!"{mismatches}/{checked} constants have structural mismatches") + ) .done + +def roundtripSuite : List TestSeq := [ + testRoundtrip, +] + +end Tests.KernelTests diff --git a/Tests/Ix/PP.lean b/Tests/Ix/PP.lean new file mode 100644 index 00000000..d96bd0f1 --- /dev/null +++ b/Tests/Ix/PP.lean @@ -0,0 +1,333 @@ +/- + Pretty printer test suite. 
+ + Tests Expr.pp in both .meta and .anon modes, covering: + - Level/Sort display + - Binder/Var/Const name formatting + - App parenthesization + - Pi and Lambda chain collapsing + - Let expressions + - Literals and projections +-/ +import Ix.Kernel +import LSpec + +open LSpec +open Ix.Kernel + +namespace Tests.PP + +/-! ## Helpers -/ + +private def mkName (s : String) : Ix.Name := + Ix.Name.mkStr Ix.Name.mkAnon s + +private def mkDottedName (a b : String) : Ix.Name := + Ix.Name.mkStr (Ix.Name.mkStr Ix.Name.mkAnon a) b + +private def testAddr : Address := Address.blake3 (ByteArray.mk #[1, 2, 3]) +private def testAddr2 : Address := Address.blake3 (ByteArray.mk #[4, 5, 6]) + +/-- First 8 hex chars of testAddr, for anon mode assertions. -/ +private def testAddrShort : String := + String.ofList ((toString testAddr).toList.take 8) + +/-! ## Meta mode: Level / Sort display -/ + +def testPpSortMeta : TestSeq := + -- Sort display + let prop : Expr .meta := .sort .zero + let type : Expr .meta := .sort (.succ .zero) + let type1 : Expr .meta := .sort (.succ (.succ .zero)) + let type2 : Expr .meta := .sort (.succ (.succ (.succ .zero))) + -- Universe params + let uName := mkName "u" + let vName := mkName "v" + let sortU : Expr .meta := .sort (.param 0 uName) + let typeU : Expr .meta := .sort (.succ (.param 0 uName)) + let sortMax : Expr .meta := .sort (.max (.param 0 uName) (.param 1 vName)) + let sortIMax : Expr .meta := .sort (.imax (.param 0 uName) (.param 1 vName)) + -- Succ offset on param: Type (u + 1), Type (u + 2) + let typeU1 : Expr .meta := .sort (.succ (.succ (.param 0 uName))) + let typeU2 : Expr .meta := .sort (.succ (.succ (.succ (.param 0 uName)))) + test "sort zero → Prop" (prop.pp == "Prop") ++ + test "sort 1 → Type" (type.pp == "Type") ++ + test "sort 2 → Type 1" (type1.pp == "Type 1") ++ + test "sort 3 → Type 2" (type2.pp == "Type 2") ++ + test "sort (param u) → Sort u" (sortU.pp == "Sort u") ++ + test "sort (succ (param u)) → Type u" (typeU.pp == "Type u") 
++ + test "sort (succ^2 (param u)) → Type (u + 1)" (typeU1.pp == "Type (u + 1)") ++ + test "sort (succ^3 (param u)) → Type (u + 2)" (typeU2.pp == "Type (u + 2)") ++ + test "sort (max u v) → Sort (max (u) (v))" (sortMax.pp == "Sort (max (u) (v))") ++ + test "sort (imax u v) → Sort (imax (u) (v))" (sortIMax.pp == "Sort (imax (u) (v))") ++ + .done + +/-! ## Meta mode: Atoms (bvar, const, lit) -/ + +def testPpAtomsMeta : TestSeq := + let x := mkName "x" + let natAdd := mkDottedName "Nat" "add" + -- bvar with name + let bv : Expr .meta := .bvar 0 x + test "bvar with name → x" (bv.pp == "x") ++ + -- const with name + let c : Expr .meta := .const testAddr #[] natAdd + test "const Nat.add → Nat.add" (c.pp == "Nat.add") ++ + -- nat literal + let n : Expr .meta := .lit (.natVal 42) + test "natLit 42 → 42" (n.pp == "42") ++ + -- string literal + let s : Expr .meta := .lit (.strVal "hello") + test "strLit hello → \"hello\"" (s.pp == "\"hello\"") ++ + .done + +/-! ## Meta mode: App parenthesization -/ + +def testPpAppMeta : TestSeq := + let f : Expr .meta := .const testAddr #[] (mkName "f") + let g : Expr .meta := .const testAddr2 #[] (mkName "g") + let a : Expr .meta := .bvar 0 (mkName "a") + let b : Expr .meta := .bvar 1 (mkName "b") + -- Simple application: no parens at top level + let fa := Expr.app f a + test "f a (no parens)" (fa.pp == "f a") ++ + -- Nested left-assoc: f a b + let fab := Expr.app (Expr.app f a) b + test "f a b (left-assoc, no parens)" (fab.pp == "f a b") ++ + -- Nested arg: f (g a) — arg needs parens + let fga := Expr.app f (Expr.app g a) + test "f (g a) (arg parens)" (fga.pp == "f (g a)") ++ + -- Atom mode: (f a) + test "f a atom → (f a)" (Expr.pp true fa == "(f a)") ++ + -- Deep nesting: f a (g b) + let fagb := Expr.app (Expr.app f a) (Expr.app g b) + test "f a (g b)" (fagb.pp == "f a (g b)") ++ + .done + +/-! 
## Meta mode: Lambda and Pi -/ + +def testPpBindersMeta : TestSeq := + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + let bool : Expr .meta := .const testAddr2 #[] (mkName "Bool") + let body : Expr .meta := .bvar 0 (mkName "x") + let body2 : Expr .meta := .bvar 1 (mkName "y") + -- Single lambda + let lam1 : Expr .meta := .lam nat body (mkName "x") .default + test "λ (x : Nat) => x" (lam1.pp == "λ (x : Nat) => x") ++ + -- Single forall + let pi1 : Expr .meta := .forallE nat body (mkName "x") .default + test "∀ (x : Nat), x" (pi1.pp == "∀ (x : Nat), x") ++ + -- Chained lambdas + let lam2 : Expr .meta := .lam nat (.lam bool body2 (mkName "y") .default) (mkName "x") .default + test "λ (x : Nat) (y : Bool) => y" (lam2.pp == "λ (x : Nat) (y : Bool) => y") ++ + -- Chained foralls + let pi2 : Expr .meta := .forallE nat (.forallE bool body2 (mkName "y") .default) (mkName "x") .default + test "∀ (x : Nat) (y : Bool), y" (pi2.pp == "∀ (x : Nat) (y : Bool), y") ++ + -- Lambda in atom position + test "lambda atom → (λ ...)" (Expr.pp true lam1 == "(λ (x : Nat) => x)") ++ + -- Forall in atom position + test "forall atom → (∀ ...)" (Expr.pp true pi1 == "(∀ (x : Nat), x)") ++ + .done + +/-! ## Meta mode: Let -/ + +def testPpLetMeta : TestSeq := + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + let zero : Expr .meta := .lit (.natVal 0) + let body : Expr .meta := .bvar 0 (mkName "x") + let letE : Expr .meta := .letE nat zero body (mkName "x") + test "let x : Nat := 0; x" (letE.pp == "let x : Nat := 0; x") ++ + -- Let in atom position + test "let atom → (let ...)" (Expr.pp true letE == "(let x : Nat := 0; x)") ++ + .done + +/-! 
## Meta mode: Projection -/ + +def testPpProjMeta : TestSeq := + let struct : Expr .meta := .bvar 0 (mkName "s") + let proj0 : Expr .meta := .proj testAddr 0 struct (mkName "Prod") + test "s.0" (proj0.pp == "s.0") ++ + -- Projection of app (needs parens around struct) + let f : Expr .meta := .const testAddr #[] (mkName "f") + let a : Expr .meta := .bvar 0 (mkName "a") + let projApp : Expr .meta := .proj testAddr 1 (.app f a) (mkName "Prod") + test "(f a).1" (projApp.pp == "(f a).1") ++ + .done + +/-! ## Anon mode -/ + +def testPpAnon : TestSeq := + -- bvar: ^idx + let bv : Expr .anon := .bvar 3 () + test "anon bvar 3 → ^3" (bv.pp == "^3") ++ + -- const: #hash + let c : Expr .anon := .const testAddr #[] () + test "anon const → #hash" (c.pp == s!"#{testAddrShort}") ++ + -- sort + let prop : Expr .anon := .sort .zero + test "anon sort zero → Prop" (prop.pp == "Prop") ++ + -- level param: u_idx + let sortU : Expr .anon := .sort (.param 0 ()) + test "anon sort (param 0) → Sort u_0" (sortU.pp == "Sort u_0") ++ + -- lambda: binder name = _ + let lam : Expr .anon := .lam (.sort .zero) (.bvar 0 ()) () () + test "anon lam → λ (_ : ...) => ..." (lam.pp == "λ (_ : Prop) => ^0") ++ + -- forall: binder name = _ + let pi : Expr .anon := .forallE (.sort .zero) (.bvar 0 ()) () () + test "anon forall → ∀ (_ : ...), ..." (pi.pp == "∀ (_ : Prop), ^0") ++ + -- let: binder name = _ + let letE : Expr .anon := .letE (.sort .zero) (.lit (.natVal 0)) (.bvar 0 ()) () + test "anon let → let _ : ..." (letE.pp == "let _ : Prop := 0; ^0") ++ + -- chained anon lambdas + let lam2 : Expr .anon := .lam (.sort .zero) (.lam (.sort (.succ .zero)) (.bvar 0 ()) () ()) () () + test "anon chained lam" (lam2.pp == "λ (_ : Prop) (_ : Type) => ^0") ++ + .done + +/-! ## Meta mode: ??? detection (flags naming bugs) -/ + +/-- In .meta mode, default/anonymous names produce "???" in binder positions + and full address hashes in const positions. 
These indicate naming info was + never present in the source expression (e.g., anonymous Ix.Name). + + Binder names survive the eval/quote round-trip: Value.lam and Value.pi + carry MetaField name and binder info, which quote extracts. + + Remaining const-name loss: `strLitToCtorVal`/`toCtorIfLit` create + Neutral.const with default names for synthetic primitive constructors. +-/ +def testPpMetaDefaultNames : TestSeq := + let anonName := Ix.Name.mkAnon + -- bvar with anonymous name shows ??? + let bv : Expr .meta := .bvar 0 anonName + test "meta bvar with anonymous name → ???" (bv.pp == "???") ++ + -- const with anonymous name shows full hash + let c : Expr .meta := .const testAddr #[] anonName + test "meta const with anonymous name → full hash" (c.pp == s!"{testAddr}") ++ + -- lambda with anonymous binder name shows ??? + let lam : Expr .meta := .lam (.sort .zero) (.bvar 0 anonName) anonName .default + test "meta lam with anonymous binder → λ (??? : Prop) => ???" (lam.pp == "λ (??? : Prop) => ???") ++ + -- forall with anonymous binder name shows ??? + let pi : Expr .meta := .forallE (.sort .zero) (.bvar 0 anonName) anonName .default + test "meta forall with anonymous binder → ∀ (??? : Prop), ???" (pi.pp == "∀ (??? : Prop), ???") ++ + .done + +/-! 
## Complex expressions -/ + +def testPpComplex : TestSeq := + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + let bool : Expr .meta := .const testAddr2 #[] (mkName "Bool") + -- ∀ (n : Nat), Nat → Nat (arrow sugar approximation) + -- This is: forallE Nat (forallE Nat Nat) + let arrow : Expr .meta := .forallE nat (.forallE nat nat (mkName "m") .default) (mkName "n") .default + test "∀ (n : Nat) (m : Nat), Nat" (arrow.pp == "∀ (n : Nat) (m : Nat), Nat") ++ + -- fun (f : Nat → Bool) (x : Nat) => f x + let fType : Expr .meta := .forallE nat bool (mkName "a") .default + let fApp : Expr .meta := .app (.bvar 1 (mkName "f")) (.bvar 0 (mkName "x")) + let expr : Expr .meta := .lam fType (.lam nat fApp (mkName "x") .default) (mkName "f") .default + test "λ (f : ∀ ...) (x : Nat) => f x" + (expr.pp == "λ (f : ∀ (a : Nat), Bool) (x : Nat) => f x") ++ + -- Nested let: let x : Nat := 0; let y : Nat := x; y + let innerLet : Expr .meta := .letE nat (.bvar 0 (mkName "x")) (.bvar 0 (mkName "y")) (mkName "y") + let outerLet : Expr .meta := .letE nat (.lit (.natVal 0)) innerLet (mkName "x") + test "nested let" (outerLet.pp == "let x : Nat := 0; let y : Nat := x; y") ++ + .done + +/-! ## Quote round-trip: names survive eval → quote → pp -/ + +/-- Build a Value with named binders and verify names survive through quote → pp. + Uses a minimal TypecheckM context. 
-/ +def testQuoteRoundtrip : TestSeq := + .individualIO "quote round-trip preserves names" (do + let xName : MetaField .meta Ix.Name := mkName "x" + let yName : MetaField .meta Ix.Name := mkName "y" + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + -- Build Value.pi: ∀ (x : Nat), Nat + let domVal : SusValue .meta := ⟨.none, Thunk.mk fun _ => Value.neu (.const testAddr #[] (mkName "Nat"))⟩ + let imgTE : TypedExpr .meta := ⟨.none, nat⟩ + let piVal : Value .meta := .pi domVal imgTE (.mk [] []) xName .default + -- Build Value.lam: fun (y : Nat) => y + let bodyTE : TypedExpr .meta := ⟨.none, .bvar 0 yName⟩ + let lamVal : Value .meta := .lam domVal bodyTE (.mk [] []) yName .default + -- Quote and pp in a minimal TypecheckM context + let ctx : TypecheckCtx .meta := { + lvl := 0, env := .mk [] [], types := [], + kenv := default, prims := buildPrimitives, + safety := .safe, quotInit := true, mutTypes := default, recAddr? := none + } + let stt : TypecheckState .meta := { typedConsts := default } + -- Test pi + match TypecheckM.run ctx stt (ppValue 0 piVal) with + | .ok s => + if s != "∀ (x : Nat), Nat" then + return (false, some s!"pi round-trip: expected '∀ (x : Nat), Nat', got '{s}'") + else pure () + | .error e => return (false, some s!"pi round-trip error: {e}") + -- Test lam + match TypecheckM.run ctx stt (ppValue 0 lamVal) with + | .ok s => + if s != "λ (y : Nat) => y" then + return (false, some s!"lam round-trip: expected 'λ (y : Nat) => y', got '{s}'") + else pure () + | .error e => return (false, some s!"lam round-trip error: {e}") + return (true, none) + ) .done + +/-! 
## Literal folding: Nat/String constructor chains → literals in ppValue -/ + +def testFoldLiterals : TestSeq := + let prims := buildPrimitives + -- Nat.zero → 0 + let natZero : Expr .meta := .const prims.natZero #[] (mkName "Nat.zero") + let folded := foldLiterals prims natZero + test "fold Nat.zero → 0" (folded.pp == "0") ++ + -- Nat.succ Nat.zero → 1 + let natOne : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) natZero + let folded := foldLiterals prims natOne + test "fold Nat.succ Nat.zero → 1" (folded.pp == "1") ++ + -- Nat.succ (Nat.succ Nat.zero) → 2 + let natTwo : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) natOne + let folded := foldLiterals prims natTwo + test "fold Nat.succ^2 Nat.zero → 2" (folded.pp == "2") ++ + -- Nats inside types get folded: ∀ (n : Nat), Eq Nat n Nat.zero + let natType : Expr .meta := .const prims.nat #[] (mkName "Nat") + let eqAddr := Address.blake3 (ByteArray.mk #[99]) + let eq3 : Expr .meta := + .app (.app (.app (.const eqAddr #[] (mkName "Eq")) natType) (.bvar 0 (mkName "n"))) natZero + let piExpr : Expr .meta := .forallE natType eq3 (mkName "n") .default + let folded := foldLiterals prims piExpr + test "fold nat inside forall" (folded.pp == "∀ (n : Nat), Eq Nat n 0") ++ + -- String.mk (List.cons (Char.ofNat 104) (List.cons (Char.ofNat 105) List.nil)) → "hi" + let charH : Expr .meta := .app (.const prims.charMk #[] (mkName "Char.ofNat")) (.lit (.natVal 104)) + let charI : Expr .meta := .app (.const prims.charMk #[] (mkName "Char.ofNat")) (.lit (.natVal 105)) + let charType : Expr .meta := .const prims.char #[] (mkName "Char") + let nilExpr : Expr .meta := .app (.const prims.listNil #[.zero] (mkName "List.nil")) charType + let consI : Expr .meta := + .app (.app (.app (.const prims.listCons #[.zero] (mkName "List.cons")) charType) charI) nilExpr + let consH : Expr .meta := + .app (.app (.app (.const prims.listCons #[.zero] (mkName "List.cons")) charType) charH) consI + let strExpr : Expr 
.meta := .app (.const prims.stringMk #[] (mkName "String.mk")) consH + let folded := foldLiterals prims strExpr + test "fold String.mk char list → \"hi\"" (folded.pp == "\"hi\"") ++ + -- Nat.succ applied to a non-literal arg stays unfolded + let succX : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) (.bvar 0 (mkName "x")) + let folded := foldLiterals prims succX + test "fold Nat.succ x → Nat.succ x (no fold)" (folded.pp == "Nat.succ x") ++ + .done + +/-! ## Suites -/ + +def suite : List TestSeq := [ + testPpSortMeta, + testPpAtomsMeta, + testPpAppMeta, + testPpBindersMeta, + testPpLetMeta, + testPpProjMeta, + testPpAnon, + testPpMetaDefaultNames, + testPpComplex, + testQuoteRoundtrip, + testFoldLiterals, +] + +end Tests.PP diff --git a/Tests/Main.lean b/Tests/Main.lean index e25300a8..e7ca61c2 100644 --- a/Tests/Main.lean +++ b/Tests/Main.lean @@ -9,6 +9,9 @@ import Tests.Ix.RustDecompile import Tests.Ix.Sharing import Tests.Ix.CanonM import Tests.Ix.GraphM +import Tests.Ix.Check +import Tests.Ix.KernelTests +import Tests.Ix.PP import Tests.Ix.CondenseM import Tests.FFI import Tests.Keccak @@ -32,6 +35,10 @@ def primarySuites : Std.HashMap String (List LSpec.TestSeq) := .ofList [ ("sharing", Tests.Sharing.suite), ("graph-unit", Tests.Ix.GraphM.suite), ("condense-unit", Tests.Ix.CondenseM.suite), + --("check", Tests.Check.checkSuiteIO), -- disable until rust kernel works + ("kernel-unit", Tests.KernelTests.unitSuite), + ("kernel-negative", Tests.KernelTests.negativeSuite), + ("pp", Tests.PP.suite), ] /-- Ignored test suites - expensive, run only when explicitly requested. 
These require significant RAM -/ @@ -47,6 +54,16 @@ def ignoredSuites : Std.HashMap String (List LSpec.TestSeq) := .ofList [ ("rust-serialize", Tests.RustSerialize.rustSerializeSuiteIO), ("rust-decompile", Tests.RustDecompile.rustDecompileSuiteIO), ("commit-io", Tests.Commit.suiteIO), + --("check-all", Tests.Check.checkAllSuiteIO), + ("kernel-check-env", Tests.Check.kernelSuiteIO), + ("kernel-convert", Tests.KernelTests.convertSuite), + ("kernel-anon-convert", Tests.KernelTests.anonConvertSuite), + ("kernel-const", Tests.KernelTests.constSuite), + ("kernel-verify-prims", [Tests.KernelTests.testVerifyPrimAddrs]), + ("kernel-dump-prims", [Tests.KernelTests.testDumpPrimAddrs]), + ("nbe-focus", Tests.KernelTests.nbeFocusSuite), + ("kernel-roundtrip", Tests.KernelTests.roundtripSuite), + ("ixon-full-roundtrip", Tests.Compile.ixonRoundtripSuiteIO), ] def main (args : List String) : IO UInt32 := do diff --git a/docs/Ixon.md b/docs/Ixon.md index 655f06d8..74509dfd 100644 --- a/docs/Ixon.md +++ b/docs/Ixon.md @@ -736,7 +736,6 @@ pub struct Env { pub blobs: DashMap>, // Raw data (strings, nats) pub names: DashMap, // Hash-consed Name components pub comms: DashMap, // Cryptographic commitments - pub addr_to_name: DashMap, // Reverse index } pub struct Named { @@ -1001,7 +1000,7 @@ Decompilation reconstructs Lean constants from Ixon format. 2. **Initialize tables** from `sharing`, `refs`, `univs` 3. **Load metadata** from `env.named` 4. **Reconstruct expressions** with names and binder info from metadata -5. **Resolve references**: `Ref(idx, _)` → lookup `refs[idx]`, get name from `addr_to_name` +5. **Resolve references**: `Ref(idx, _)` → lookup name from arena metadata via `names` table 6. **Expand shares**: `Share(idx)` → inline `sharing[idx]` (or cache result) ### Roundtrip Verification @@ -1145,7 +1144,7 @@ To reconstruct the Lean constant: 1. Load `Constant` from `consts[address]` 2. Load `Named` from `named["double"]` -3. 
Resolve `Ref(0, [])` → `refs[0]` → `Nat` (via `addr_to_name`) +3. Resolve `Ref(0, [])` → name from arena metadata → `Nat` (via `names` table) 4. Resolve `Ref(1, [])` → `refs[1]` → `Nat.add` 5. Attach names from metadata: the binder gets name "n" from `type_meta[0]` diff --git a/src/ix/decompile.rs b/src/ix/decompile.rs index 88082135..26bd3dc7 100644 --- a/src/ix/decompile.rs +++ b/src/ix/decompile.rs @@ -565,39 +565,19 @@ pub fn decompile_expr( // Ref: resolve name from arena Ref node or fallback ( ExprMetaData::Ref { name: name_addr }, - Expr::Ref(ref_idx, univ_indices), + Expr::Ref(_ref_idx, univ_indices), ) => { - let name = decompile_name(name_addr, stt).unwrap_or_else(|_| { - // Fallback: resolve from refs table - cache - .refs - .get(*ref_idx as usize) - .and_then(|addr| stt.env.get_name_by_addr(addr)) - .unwrap_or_else(Name::anon) - }); + let name = decompile_name(name_addr, stt)?; let levels = decompile_univ_indices(univ_indices, lvl_names, cache)?; let expr = apply_mdata(LeanExpr::cnst(name, levels), mdata_layers); results.push(expr); }, - (_, Expr::Ref(ref_idx, univ_indices)) => { - // No Ref metadata — resolve from refs table - let addr = cache.refs.get(*ref_idx as usize).ok_or_else(|| { - DecompileError::InvalidRefIndex { - idx: *ref_idx, - refs_len: cache.refs.len(), - constant: cache.current_const.clone(), - } - })?; - let name = stt - .env - .get_name_by_addr(addr) - .ok_or(DecompileError::MissingAddress(addr.clone()))?; - let levels = - decompile_univ_indices(univ_indices, lvl_names, cache)?; - let expr = apply_mdata(LeanExpr::cnst(name, levels), mdata_layers); - results.push(expr); + (_, Expr::Ref(_ref_idx, _univ_indices)) => { + return Err(DecompileError::BadConstantFormat { + msg: "ref without arena metadata".to_string(), + }); }, // Rec: resolve name from arena Ref node or fallback @@ -735,27 +715,10 @@ pub fn decompile_expr( stack.push(Frame::Decompile(struct_val.clone(), *child)); }, - (_, Expr::Prj(type_ref_idx, field_idx, struct_val)) => { 
- // Fallback: look up from refs table - let addr = - cache.refs.get(*type_ref_idx as usize).ok_or_else(|| { - DecompileError::InvalidRefIndex { - idx: *type_ref_idx, - refs_len: cache.refs.len(), - constant: cache.current_const.clone(), - } - })?; - let named = stt - .env - .get_named_by_addr(addr) - .ok_or(DecompileError::MissingAddress(addr.clone()))?; - let type_name = decompile_name_from_meta(&named.meta, stt)?; - stack.push(Frame::BuildProj( - type_name, - Nat::from(*field_idx), - mdata_layers, - )); - stack.push(Frame::Decompile(struct_val.clone(), u64::MAX)); + (_, Expr::Prj(_type_ref_idx, _field_idx, _struct_val)) => { + return Err(DecompileError::BadConstantFormat { + msg: "prj without arena metadata".to_string(), + }); }, (_, Expr::Share(_)) => unreachable!("Share handled above"), diff --git a/src/ix/ixon/env.rs b/src/ix/ixon/env.rs index b13ce571..80b4349c 100644 --- a/src/ix/ixon/env.rs +++ b/src/ix/ixon/env.rs @@ -36,7 +36,6 @@ impl Named { /// - `blobs`: Raw data (strings, nats, files) /// - `names`: Hash-consed Lean.Name components (Address -> Name) /// - `comms`: Cryptographic commitments (secrets) -/// - `addr_to_name`: Reverse index from constant address to name (for O(1) lookup) #[derive(Debug, Default)] pub struct Env { /// Alpha-invariant constants: Address -> Constant @@ -49,8 +48,6 @@ pub struct Env { pub names: DashMap, /// Cryptographic commitments: commitment Address -> Comm pub comms: DashMap, - /// Reverse index: constant Address -> Name (for fast lookup during decompile) - pub addr_to_name: DashMap, } impl Env { @@ -61,7 +58,6 @@ impl Env { blobs: DashMap::new(), names: DashMap::new(), comms: DashMap::new(), - addr_to_name: DashMap::new(), } } @@ -90,8 +86,6 @@ impl Env { /// Register a named constant. 
pub fn register_name(&self, name: Name, named: Named) { - // Also insert into reverse index for O(1) lookup by address - self.addr_to_name.insert(named.addr.clone(), name.clone()); self.named.insert(name, named); } @@ -100,16 +94,6 @@ impl Env { self.named.get(name).map(|r| r.clone()) } - /// Look up name by constant address (O(1) using reverse index). - pub fn get_name_by_addr(&self, addr: &Address) -> Option { - self.addr_to_name.get(addr).map(|r| r.clone()) - } - - /// Look up named entry by constant address (O(1) using reverse index). - pub fn get_named_by_addr(&self, addr: &Address) -> Option { - self.get_name_by_addr(addr).and_then(|name| self.lookup_name(&name)) - } - /// Store a hash-consed name component. pub fn store_name(&self, addr: Address, name: Name) { self.names.insert(addr, name); @@ -183,12 +167,7 @@ impl Clone for Env { comms.insert(entry.key().clone(), entry.value().clone()); } - let addr_to_name = DashMap::new(); - for entry in self.addr_to_name.iter() { - addr_to_name.insert(entry.key().clone(), entry.value().clone()); - } - - Env { consts, named, blobs, names, comms, addr_to_name } + Env { consts, named, blobs, names, comms } } } @@ -244,28 +223,6 @@ mod tests { assert_eq!(got.addr, addr); } - #[test] - fn get_name_by_addr_reverse_index() { - let env = Env::new(); - let name = n("Reverse"); - let addr = Address::hash(b"reverse-addr"); - let named = Named::with_addr(addr.clone()); - env.register_name(name.clone(), named); - let got_name = env.get_name_by_addr(&addr).unwrap(); - assert_eq!(got_name, name); - } - - #[test] - fn get_named_by_addr_resolves_through_reverse_index() { - let env = Env::new(); - let name = n("Through"); - let addr = Address::hash(b"through-addr"); - let named = Named::with_addr(addr.clone()); - env.register_name(name.clone(), named); - let got = env.get_named_by_addr(&addr).unwrap(); - assert_eq!(got.addr, addr); - } - #[test] fn store_and_get_name_component() { let env = Env::new(); @@ -322,8 +279,6 @@ mod tests { 
assert!(env.get_blob(&missing).is_none()); assert!(env.get_const(&missing).is_none()); assert!(env.lookup_name(&n("missing")).is_none()); - assert!(env.get_name_by_addr(&missing).is_none()); - assert!(env.get_named_by_addr(&missing).is_none()); assert!(env.get_name(&missing).is_none()); assert!(env.get_comm(&missing).is_none()); } diff --git a/src/ix/ixon/serialize.rs b/src/ix/ixon/serialize.rs index c0572160..aa56d9a2 100644 --- a/src/ix/ixon/serialize.rs +++ b/src/ix/ixon/serialize.rs @@ -1186,7 +1186,6 @@ impl Env { let name = names_lookup.get(&name_addr).cloned().ok_or_else(|| { format!("Env::get: missing name for addr {:?}", name_addr) })?; - env.addr_to_name.insert(named.addr.clone(), name.clone()); env.named.insert(name, named); } @@ -1456,7 +1455,6 @@ mod tests { let name = names[i % names.len()].clone(); let meta = ConstantMeta::default(); let named = Named { addr: addr.clone(), meta }; - env.addr_to_name.insert(addr, name.clone()); env.named.insert(name, named); } } diff --git a/src/ix/kernel/convert.rs b/src/ix/kernel/convert.rs index 90811948..c6f5af2c 100644 --- a/src/ix/kernel/convert.rs +++ b/src/ix/kernel/convert.rs @@ -1,7 +1,10 @@ use core::ptr::NonNull; use std::collections::BTreeMap; +use std::sync::Arc; -use crate::ix::env::{Expr, ExprData, Level, Name}; +use rustc_hash::FxHashMap; + +use crate::ix::env::{BinderInfo, Expr, ExprData, Level, Name}; use crate::lean::nat::Nat; use super::dag::*; @@ -23,208 +26,427 @@ fn from_expr_go( ctx: &BTreeMap>, parents: Option>, ) -> DAGPtr { - match expr.as_data() { - ExprData::Bvar(idx, _) => { - let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); - if idx_u64 < depth { - let level = depth - 1 - idx_u64; - match ctx.get(&level) { - Some(&var_ptr) => { - if let Some(parent_link) = parents { - add_to_parents(DAGPtr::Var(var_ptr), parent_link); + // Frame-based iterative Expr → DAG conversion. 
+ // + // For compound nodes, we pre-allocate the DAG node with dangling child + // pointers, then push frames to fill in children after they're converted. + // + // The ctx is cloned at binder boundaries (Fun, Pi, Let) to track + // bound variable bindings. + enum Frame<'a> { + Visit { + expr: &'a Expr, + depth: u64, + ctx: BTreeMap>, + parents: Option>, + }, + SetAppFun(NonNull), + SetAppArg(NonNull), + SetFunDom(NonNull), + SetPiDom(NonNull), + SetLetTyp(NonNull), + SetLetVal(NonNull), + SetProjExpr(NonNull), + // After domain is set, wire up binder body with new ctx + FunBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + PiBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + LetBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + SetLamBod(NonNull), + } + + let mut work: Vec> = vec![Frame::Visit { + expr, + depth, + ctx: ctx.clone(), + parents, + }]; + // Results stack holds DAGPtr for each completed subtree + let mut results: Vec = Vec::new(); + let mut visit_count: u64 = 0; + // Cache for context-independent leaf nodes (Cnst, Sort, Lit). + // Keyed by Arc pointer identity. Enables DAG sharing so the infer cache + // (keyed by DAGPtr address) can dedup repeated references to the same constant. 
+ let mut leaf_cache: FxHashMap<*const ExprData, DAGPtr> = FxHashMap::default(); + + while let Some(frame) = work.pop() { + visit_count += 1; + if visit_count % 100_000 == 0 { + eprintln!("[from_expr_go] visit_count={visit_count} work_len={}", work.len()); + } + match frame { + Frame::Visit { expr, depth, ctx, parents } => { + match expr.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 < depth { + let level = depth - 1 - idx_u64; + match ctx.get(&level) { + Some(&var_ptr) => { + if let Some(parent_link) = parents { + add_to_parents(DAGPtr::Var(var_ptr), parent_link); + } + results.push(DAGPtr::Var(var_ptr)); + }, + None => { + let var = alloc_val(Var { + depth: level, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); + }, + } + } else { + let var = alloc_val(Var { + depth: idx_u64, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); } - DAGPtr::Var(var_ptr) }, - None => { + + ExprData::Fvar(name, _) => { let var = alloc_val(Var { - depth: level, + depth: 0, binder: BinderPtr::Free, + fvar_name: Some(name.clone()), parents, }); - DAGPtr::Var(var) + results.push(DAGPtr::Var(var)); }, - } - } else { - // Free bound variable (dangling de Bruijn index) - let var = - alloc_val(Var { depth: idx_u64, binder: BinderPtr::Free, parents }); - DAGPtr::Var(var) - } - }, - ExprData::Fvar(_name, _) => { - // Encode fvar name into depth as a unique ID. - // We'll recover it during to_expr using a side table. 
- let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); - // Store name→var mapping (caller should manage the side table) - DAGPtr::Var(var) - }, + ExprData::Sort(level, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let sort = alloc_val(Sort { level: level.clone(), parents }); + let ptr = DAGPtr::Sort(sort); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Sort(level, _) => { - let sort = alloc_val(Sort { level: level.clone(), parents }); - DAGPtr::Sort(sort) - }, + ExprData::Const(name, levels, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let cnst = alloc_val(Cnst { + name: name.clone(), + levels: levels.clone(), + parents, + }); + let ptr = DAGPtr::Cnst(cnst); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Const(name, levels, _) => { - let cnst = alloc_val(Cnst { - name: name.clone(), - levels: levels.clone(), - parents, - }); - DAGPtr::Cnst(cnst) - }, + ExprData::Lit(lit, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); + let ptr = DAGPtr::Lit(lit_node); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Lit(lit, _) => { - let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); - DAGPtr::Lit(lit_node) - }, + ExprData::App(fun_expr, arg_expr, _) => { + let app_ptr = alloc_app( + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let app = &mut *app_ptr.as_ptr(); + let fun_ref 
= + NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); + let arg_ref = + NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); + // Process arg first (pushed last = processed first after fun) + work.push(Frame::SetAppArg(app_ptr)); + work.push(Frame::Visit { + expr: arg_expr, + depth, + ctx: ctx.clone(), + parents: Some(arg_ref), + }); + work.push(Frame::SetAppFun(app_ptr)); + work.push(Frame::Visit { + expr: fun_expr, + depth, + ctx, + parents: Some(fun_ref), + }); + } + results.push(DAGPtr::App(app_ptr)); + }, - ExprData::App(fun_expr, arg_expr, _) => { - let app_ptr = alloc_app( - DAGPtr::Var(NonNull::dangling()), - DAGPtr::Var(NonNull::dangling()), - parents, - ); - unsafe { - let app = &mut *app_ptr.as_ptr(); - let fun_ref_ptr = - NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); - let arg_ref_ptr = - NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); - app.fun = from_expr_go(fun_expr, depth, ctx, Some(fun_ref_ptr)); - app.arg = from_expr_go(arg_expr, depth, ctx, Some(arg_ref_ptr)); - } - DAGPtr::App(app_ptr) - }, + ExprData::Lam(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let fun_ptr = alloc_fun( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + let dom_ref = + NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); + let img_ref = + NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref); + + let dom_ctx = ctx.clone(); + work.push(Frame::FunBody { + lam_ptr, + body, + depth, + ctx, + }); + work.push(Frame::SetFunDom(fun_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx: dom_ctx, + parents: Some(dom_ref), + }); + } + results.push(DAGPtr::Fun(fun_ptr)); + }, - ExprData::Lam(name, typ, body, bi, _) => { - // Lean Lam → DAG Fun(dom, Lam(bod, var)) - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let 
fun_ptr = alloc_fun( - name.clone(), - bi.clone(), - DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let fun = &mut *fun_ptr.as_ptr(); - let dom_ref_ptr = - NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); - fun.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); - - // Set Lam's parent to FunImg - let img_ref_ptr = - NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + ExprData::ForallE(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let pi_ptr = alloc_pi( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + let dom_ref = + NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); + let img_ref = + NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref); + + let dom_ctx = ctx.clone(); + work.push(Frame::PiBody { + lam_ptr, + body, + depth, + ctx, + }); + work.push(Frame::SetPiDom(pi_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx: dom_ctx, + parents: Some(dom_ref), + }); + } + results.push(DAGPtr::Pi(pi_ptr)); + }, + + ExprData::LetE(name, typ, val, body, non_dep, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let let_ptr = alloc_let( + name.clone(), + *non_dep, + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let typ_ref = + NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); + let val_ref = + NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); + let bod_ref = + NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref); + + work.push(Frame::LetBody { + lam_ptr, + body, + depth, + ctx: ctx.clone(), + }); + work.push(Frame::SetLetVal(let_ptr)); + 
work.push(Frame::Visit { + expr: val, + depth, + ctx: ctx.clone(), + parents: Some(val_ref), + }); + work.push(Frame::SetLetTyp(let_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx, + parents: Some(typ_ref), + }); + } + results.push(DAGPtr::Let(let_ptr)); + }, + ExprData::Proj(type_name, idx, structure, _) => { + let proj_ptr = alloc_proj( + type_name.clone(), + idx.clone(), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + let expr_ref = + NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); + work.push(Frame::SetProjExpr(proj_ptr)); + work.push(Frame::Visit { + expr: structure, + depth, + ctx, + parents: Some(expr_ref), + }); + } + results.push(DAGPtr::Proj(proj_ptr)); + }, + + ExprData::Mdata(_, inner, _) => { + // Strip metadata, convert inner + work.push(Frame::Visit { expr: inner, depth, ctx, parents }); + }, + + ExprData::Mvar(_name, _) => { + let var = alloc_val(Var { + depth: 0, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); + }, + } + }, + Frame::SetAppFun(app_ptr) => unsafe { + let result = results.pop().unwrap(); + (*app_ptr.as_ptr()).fun = result; + }, + Frame::SetAppArg(app_ptr) => unsafe { + let result = results.pop().unwrap(); + (*app_ptr.as_ptr()).arg = result; + }, + Frame::SetFunDom(fun_ptr) => unsafe { + let result = results.pop().unwrap(); + (*fun_ptr.as_ptr()).dom = result; + }, + Frame::SetPiDom(pi_ptr) => unsafe { + let result = results.pop().unwrap(); + (*pi_ptr.as_ptr()).dom = result; + }, + Frame::SetLetTyp(let_ptr) => unsafe { + let result = results.pop().unwrap(); + (*let_ptr.as_ptr()).typ = result; + }, + Frame::SetLetVal(let_ptr) => unsafe { + let result = results.pop().unwrap(); + (*let_ptr.as_ptr()).val = result; + }, + Frame::SetProjExpr(proj_ptr) => unsafe { + let result = results.pop().unwrap(); + (*proj_ptr.as_ptr()).expr = result; + }, + Frame::SetLamBod(lam_ptr) => unsafe { + let result = 
results.pop().unwrap(); + (*lam_ptr.as_ptr()).bod = result; + }, + Frame::FunBody { lam_ptr, body, depth, mut ctx } => unsafe { + // Domain has been set; now set up body with var binding let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); - } - DAGPtr::Fun(fun_ptr) - }, - - ExprData::ForallE(name, typ, body, bi, _) => { - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let pi_ptr = alloc_pi( - name.clone(), - bi.clone(), - DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let pi = &mut *pi_ptr.as_ptr(); - let dom_ref_ptr = - NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); - pi.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); - - let img_ref_ptr = - NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); - + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: depth + 1, + ctx, + parents: Some(bod_ref), + }); + }, + Frame::PiBody { lam_ptr, body, depth, mut ctx } => unsafe { let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); - } - DAGPtr::Pi(pi_ptr) - }, - - ExprData::LetE(name, typ, val, body, non_dep, _) => { - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let let_ptr = alloc_let( - name.clone(), - *non_dep, - DAGPtr::Var(NonNull::dangling()), - 
DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let let_node = &mut *let_ptr.as_ptr(); - let typ_ref_ptr = - NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); - let val_ref_ptr = - NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); - let_node.typ = from_expr_go(typ, depth, ctx, Some(typ_ref_ptr)); - let_node.val = from_expr_go(val, depth, ctx, Some(val_ref_ptr)); - - let bod_ref_ptr = - NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref_ptr); - + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: depth + 1, + ctx, + parents: Some(bod_ref), + }); + }, + Frame::LetBody { lam_ptr, body, depth, mut ctx } => unsafe { let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let inner_bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(inner_bod_ref_ptr)); - } - DAGPtr::Let(let_ptr) - }, - - ExprData::Proj(type_name, idx, structure, _) => { - let proj_ptr = alloc_proj( - type_name.clone(), - idx.clone(), - DAGPtr::Var(NonNull::dangling()), - parents, - ); - unsafe { - let proj = &mut *proj_ptr.as_ptr(); - let expr_ref_ptr = - NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); - proj.expr = - from_expr_go(structure, depth, ctx, Some(expr_ref_ptr)); - } - DAGPtr::Proj(proj_ptr) - }, - - // Mdata: strip metadata, convert inner expression - ExprData::Mdata(_, inner, _) => from_expr_go(inner, depth, ctx, parents), - - // Mvar: treat as terminal (shouldn't appear in well-typed terms) - ExprData::Mvar(_name, _) => { - let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); - DAGPtr::Var(var) - }, + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: 
/// Convert a DAG back into a tree-shaped `Expr`.
///
/// `var_map` records, for each bound `Var` currently in scope, the binder
/// depth at which it was introduced, so occurrences can be rendered as de
/// Bruijn indices. `cache` memoizes already-converted shared subterms
/// (keyed by node address and depth) so DAG sharing does not cause
/// exponential re-expansion when read back.
pub fn to_expr(dag: &DAG) -> Expr {
  let mut var_map: BTreeMap<*const Var, u64> = BTreeMap::new();
  let mut cache: rustc_hash::FxHashMap<(usize, u64), Expr> =
    rustc_hash::FxHashMap::default();
  to_expr_go(dag.head, &mut var_map, 0, &mut cache)
}

/// Iterative worker for [`to_expr`].
///
/// Parameters:
/// - `node`: DAG node to convert.
/// - `var_map`: bound-var pointer → binder depth, maintained by the
///   `BinderBody` / `*Build` frames below.
/// - `depth`: current binder depth (number of enclosing binders).
/// - `cache`: memo table keyed on `(dag_ptr_key(node), depth)`.
fn to_expr_go(
  node: DAGPtr,
  var_map: &mut BTreeMap<*const Var, u64>,
  depth: u64,
  cache: &mut rustc_hash::FxHashMap<(usize, u64), Expr>,
) -> Expr {
  // Frame-based iterative conversion from DAG to Expr.
  //
  // Uses a cache keyed on (dag_ptr_key, depth) to avoid exponential
  // blowup when the DAG has sharing (e.g., after beta reduction).
  //
  // For binder nodes (Fun, Pi, Let, Lam), the pattern is:
  //   1. Visit domain/type/value children
  //   2. BinderBody: register var in var_map, push Visit for body
  //   3. *Build: pop results, unregister var, build Expr
  //   4. CacheStore: cache the built result
  enum Frame {
    // Convert a node at a given binder depth.
    Visit(DAGPtr, u64),
    // Pop arg then fun from `results`, push `Expr::app`.
    App,
    // Register the bound var at this depth, then visit the body at depth+1.
    BinderBody(*const Var, DAGPtr, u64),
    // Binder finishers: unregister the var and assemble the Expr.
    FunBuild(Name, BinderInfo, *const Var),
    PiBuild(Name, BinderInfo, *const Var),
    LetBuild(Name, bool, *const Var),
    Proj(Name, Nat),
    LamBuild(*const Var),
    // Record the just-built result (left on `results`) in the memo table.
    CacheStore(usize, u64),
  }

  let mut work: Vec<Frame> = vec![Frame::Visit(node, depth)];
  let mut results: Vec<Expr> = Vec::new();

  while let Some(frame) = work.pop() {
    match frame {
      Frame::Visit(node, depth) => unsafe {
        // Check cache first for non-Var nodes
        match node {
          DAGPtr::Var(_) => {}, // Vars depend on var_map, skip cache
          _ => {
            let key = (dag_ptr_key(node), depth);
            if let Some(cached) = cache.get(&key) {
              results.push(cached.clone());
              continue;
            }
          },
        }
        match node {
          DAGPtr::Var(link) => {
            let var = link.as_ptr();
            let var_key = var as *const Var;
            if let Some(&bind_depth) = var_map.get(&var_key) {
              // Bound occurrence: de Bruijn index relative to its binder.
              results.push(Expr::bvar(Nat::from(depth - bind_depth - 1)));
            } else if let Some(name) = &(*var).fvar_name {
              // Var created from an Fvar: restore the original name.
              results.push(Expr::fvar(name.clone()));
            } else {
              // Free var without a name: fall back to its stored depth.
              results.push(Expr::bvar(Nat::from((*var).depth)));
            }
          },
          DAGPtr::Sort(link) => {
            let sort = &*link.as_ptr();
            results.push(Expr::sort(sort.level.clone()));
          },
          DAGPtr::Cnst(link) => {
            let cnst = &*link.as_ptr();
            results.push(Expr::cnst(cnst.name.clone(), cnst.levels.clone()));
          },
          DAGPtr::Lit(link) => {
            let lit = &*link.as_ptr();
            results.push(Expr::lit(lit.val.clone()));
          },
          DAGPtr::App(link) => {
            let app = &*link.as_ptr();
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::App);
            // Stack order: fun is visited first, then arg.
            work.push(Frame::Visit(app.arg, depth));
            work.push(Frame::Visit(app.fun, depth));
          },
          DAGPtr::Fun(link) => {
            let fun = &*link.as_ptr();
            let lam = &*fun.img.as_ptr();
            let var_ptr = &lam.var as *const Var;
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::FunBuild(
              fun.binder_name.clone(),
              fun.binder_info.clone(),
              var_ptr,
            ));
            work.push(Frame::BinderBody(var_ptr, lam.bod, depth));
            work.push(Frame::Visit(fun.dom, depth));
          },
          DAGPtr::Pi(link) => {
            let pi = &*link.as_ptr();
            let lam = &*pi.img.as_ptr();
            let var_ptr = &lam.var as *const Var;
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::PiBuild(
              pi.binder_name.clone(),
              pi.binder_info.clone(),
              var_ptr,
            ));
            work.push(Frame::BinderBody(var_ptr, lam.bod, depth));
            work.push(Frame::Visit(pi.dom, depth));
          },
          DAGPtr::Let(link) => {
            let let_node = &*link.as_ptr();
            let lam = &*let_node.bod.as_ptr();
            let var_ptr = &lam.var as *const Var;
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::LetBuild(
              let_node.binder_name.clone(),
              let_node.non_dep,
              var_ptr,
            ));
            work.push(Frame::BinderBody(var_ptr, lam.bod, depth));
            // Stack order: typ first, then val, then body (via BinderBody).
            work.push(Frame::Visit(let_node.val, depth));
            work.push(Frame::Visit(let_node.typ, depth));
          },
          DAGPtr::Proj(link) => {
            let proj = &*link.as_ptr();
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::Proj(proj.type_name.clone(), proj.idx.clone()));
            work.push(Frame::Visit(proj.expr, depth));
          },
          DAGPtr::Lam(link) => {
            // Standalone Lam: no domain to visit, just body
            let lam = &*link.as_ptr();
            let var_ptr = &lam.var as *const Var;
            work.push(Frame::CacheStore(dag_ptr_key(node), depth));
            work.push(Frame::LamBuild(var_ptr));
            work.push(Frame::BinderBody(var_ptr, lam.bod, depth));
          },
        }
      },
      Frame::App => {
        let arg = results.pop().unwrap();
        let fun = results.pop().unwrap();
        results.push(Expr::app(fun, arg));
      },
      Frame::BinderBody(var_ptr, body, depth) => {
        // Bring the bound var into scope, then convert the body one
        // binder deeper.
        var_map.insert(var_ptr, depth);
        work.push(Frame::Visit(body, depth + 1));
      },
      Frame::FunBuild(name, bi, var_ptr) => {
        var_map.remove(&var_ptr);
        let bod = results.pop().unwrap();
        let dom = results.pop().unwrap();
        results.push(Expr::lam(name, dom, bod, bi));
      },
      Frame::PiBuild(name, bi, var_ptr) => {
        var_map.remove(&var_ptr);
        let bod = results.pop().unwrap();
        let dom = results.pop().unwrap();
        results.push(Expr::all(name, dom, bod, bi));
      },
      Frame::LetBuild(name, non_dep, var_ptr) => {
        var_map.remove(&var_ptr);
        let bod = results.pop().unwrap();
        let val = results.pop().unwrap();
        let typ = results.pop().unwrap();
        results.push(Expr::letE(name, typ, val, bod, non_dep));
      },
      Frame::Proj(name, idx) => {
        let structure = results.pop().unwrap();
        results.push(Expr::proj(name, idx, structure));
      },
      Frame::LamBuild(var_ptr) => {
        var_map.remove(&var_ptr);
        let bod = results.pop().unwrap();
        // Standalone Lam carries no binder metadata; wrap with an
        // anonymous name and a placeholder domain.
        results.push(Expr::lam(
          Name::anon(),
          Expr::sort(Level::zero()),
          bod,
          BinderInfo::Default,
        ));
      },
      Frame::CacheStore(key, depth) => {
        // The finished Expr is on top of `results`; memoize it.
        let result = results.last().unwrap().clone();
        cache.insert((key, depth), result);
      },
    }
  }

  results.pop().unwrap()
}
#[repr(C)] pub struct Sort { @@ -260,7 +257,7 @@ pub fn alloc_lam( let lam_ptr = alloc_val(Lam { bod, bod_ref: DLL::singleton(ParentPtr::Root), - var: Var { depth, binder: BinderPtr::Free, parents: None }, + var: Var { depth, binder: BinderPtr::Free, fvar_name: None, parents: None }, parents, }); unsafe { @@ -469,59 +466,587 @@ pub fn free_dag(dag: DAG) { free_dag_nodes(dag.head, &mut visited); } -fn free_dag_nodes(node: DAGPtr, visited: &mut FxHashSet) { - let key = dag_ptr_key(node); - if !visited.insert(key) { - return; - } - unsafe { - match node { - DAGPtr::Var(link) => { - let var = &*link.as_ptr(); - // Only free separately-allocated free vars; bound vars are - // embedded in their Lam struct and freed with it. - if let BinderPtr::Free = var.binder { +fn free_dag_nodes(root: DAGPtr, visited: &mut FxHashSet) { + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + let key = dag_ptr_key(node); + if !visited.insert(key) { + continue; + } + unsafe { + match node { + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + stack.push(lam.bod); drop(Box::from_raw(link.as_ptr())); - } - }, - DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lam(link) => { - let lam = &*link.as_ptr(); - free_dag_nodes(lam.bod, visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Fun(link) => { - let fun = &*link.as_ptr(); - free_dag_nodes(fun.dom, visited); - free_dag_nodes(DAGPtr::Lam(fun.img), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Pi(link) => { - let pi = &*link.as_ptr(); - 
free_dag_nodes(pi.dom, visited); - free_dag_nodes(DAGPtr::Lam(pi.img), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::App(link) => { - let app = &*link.as_ptr(); - free_dag_nodes(app.fun, visited); - free_dag_nodes(app.arg, visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Let(link) => { - let let_node = &*link.as_ptr(); - free_dag_nodes(let_node.typ, visited); - free_dag_nodes(let_node.val, visited); - free_dag_nodes(DAGPtr::Lam(let_node.bod), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Proj(link) => { - let proj = &*link.as_ptr(); - free_dag_nodes(proj.expr, visited); - drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + stack.push(fun.dom); + stack.push(DAGPtr::Lam(fun.img)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + stack.push(pi.dom); + stack.push(DAGPtr::Lam(pi.img)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + stack.push(app.fun); + stack.push(app.arg); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + stack.push(let_node.typ); + stack.push(let_node.val); + stack.push(DAGPtr::Lam(let_node.bod)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + stack.push(proj.expr); + drop(Box::from_raw(link.as_ptr())); + }, + } + } + } +} + +// ============================================================================ +// DAG utilities for typechecker +// ============================================================================ + +/// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])` at the DAG level. 
+pub fn dag_unfold_apps(dag: DAGPtr) -> (DAGPtr, Vec) { + let mut args = Vec::new(); + let mut cursor = dag; + loop { + match cursor { + DAGPtr::App(app) => unsafe { + let app_ref = &*app.as_ptr(); + args.push(app_ref.arg); + cursor = app_ref.fun; }, + _ => break, } } + args.reverse(); + (cursor, args) +} + +/// Reconstruct `f a1 a2 ... an` from a head and arguments at the DAG level. +pub fn dag_foldl_apps(fun: DAGPtr, args: &[DAGPtr]) -> DAGPtr { + let mut result = fun; + for &arg in args { + let app = alloc_app(result, arg, None); + result = DAGPtr::App(app); + } + result +} + +/// Substitute universe level parameters in-place throughout a DAG. +/// +/// Replaces `Level::param(params[i])` with `values[i]` in all Sort and Cnst +/// nodes reachable from `root`. Uses a visited set to handle DAG sharing. +/// +/// The DAG must not be shared with other live structures, since this mutates +/// nodes in place (intended for freshly `from_expr`'d DAGs). +pub fn subst_dag_levels( + root: DAGPtr, + params: &[Name], + values: &[Level], +) -> DAGPtr { + if params.is_empty() { + return root; + } + let mut visited = FxHashSet::default(); + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + let key = dag_ptr_key(node); + if !visited.insert(key) { + continue; + } + unsafe { + match node { + DAGPtr::Sort(p) => { + let sort = &mut *p.as_ptr(); + sort.level = subst_level(&sort.level, params, values); + }, + DAGPtr::Cnst(p) => { + let cnst = &mut *p.as_ptr(); + cnst.levels = + cnst.levels.iter().map(|l| subst_level(l, params, values)).collect(); + }, + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + stack.push(app.fun); + stack.push(app.arg); + }, + DAGPtr::Fun(p) => { + let fun = &*p.as_ptr(); + stack.push(fun.dom); + stack.push(DAGPtr::Lam(fun.img)); + }, + DAGPtr::Pi(p) => { + let pi = &*p.as_ptr(); + stack.push(pi.dom); + stack.push(DAGPtr::Lam(pi.img)); + }, + DAGPtr::Lam(p) => { + let lam = &*p.as_ptr(); + stack.push(lam.bod); + }, + DAGPtr::Let(p) => { + 
let let_node = &*p.as_ptr(); + stack.push(let_node.typ); + stack.push(let_node.val); + stack.push(DAGPtr::Lam(let_node.bod)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + stack.push(proj.expr); + }, + DAGPtr::Var(_) | DAGPtr::Lit(_) => {}, + } + } + } + root +} + +// ============================================================================ +// Deep-copy substitution for typechecker +// ============================================================================ + +/// Deep-copy a Lam body, substituting `replacement` for the Lam's bound variable. +/// +/// Unlike `subst_pi_body` (which mutates nodes in place via BUBS), this creates +/// a completely fresh DAG. This prevents the type DAG from sharing mutable nodes +/// with the term DAG, avoiding corruption when WHNF later beta-reduces in the +/// type DAG. +/// +/// The `replacement` is also deep-copied to prevent WHNF's `reduce_lam` from +/// modifying the original term DAG when it beta-reduces through substituted +/// Fun/Lam nodes. Vars not bound within the copy scope (outer-binder vars and +/// free vars) are preserved by pointer to maintain identity for `def_eq`. +/// +/// Deep-copy the Lam body with substitution. Used when the Lam is from +/// the TERM DAG (e.g., `infer_lambda`, `infer_pi`, `infer_let`) to +/// protect the term from destructive in-place modification. +/// +/// The replacement is also deep-copied to isolate the term DAG from +/// WHNF mutations. Vars not bound within the copy scope are preserved +/// by pointer to maintain identity for `def_eq`. 
+pub fn dag_copy_subst(lam: NonNull, replacement: DAGPtr) -> DAGPtr { + use std::sync::atomic::{AtomicU64, Ordering}; + static COPY_SUBST_CALLS: AtomicU64 = AtomicU64::new(0); + static COPY_SUBST_NODES: AtomicU64 = AtomicU64::new(0); + let call_num = COPY_SUBST_CALLS.fetch_add(1, Ordering::Relaxed); + + let mut cache: FxHashMap = FxHashMap::default(); + unsafe { + let lambda = &*lam.as_ptr(); + let var_ptr = + NonNull::new(&lambda.var as *const Var as *mut Var).unwrap(); + let var_key = dag_ptr_key(DAGPtr::Var(var_ptr)); + // Deep-copy the replacement (isolates from term DAG mutations) + let copied_replacement = dag_copy_node(replacement, &mut cache); + let repl_nodes = cache.len(); + // Clear cache: body and replacement are separate DAGs, no shared nodes. + cache.clear(); + // Map the target var to the copied replacement + cache.insert(var_key, copied_replacement); + // Deep copy the body + let result = dag_copy_node(lambda.bod, &mut cache); + let body_nodes = cache.len(); + let total = COPY_SUBST_NODES.fetch_add(body_nodes as u64, Ordering::Relaxed) + body_nodes as u64; + if call_num % 10 == 0 || body_nodes > 1000 { + eprintln!("[dag_copy_subst] call={call_num} repl={repl_nodes} body={body_nodes} total_nodes={total}"); + } + result + } +} + +/// Lightweight substitution for TYPE DAG Lams (from `from_expr` or derived). +/// Only the replacement is deep-copied; the body is modified in-place via +/// BUBS `subst_pi_body`, preserving DAG sharing and avoiding exponential +/// blowup. +pub fn dag_type_subst(lam: NonNull, replacement: DAGPtr) -> DAGPtr { + use super::upcopy::subst_pi_body; + let mut cache: FxHashMap = FxHashMap::default(); + let copied_replacement = dag_copy_node(replacement, &mut cache); + subst_pi_body(lam, copied_replacement) +} + +/// Iteratively copy a DAG node, using `cache` for sharing and var substitution. 
+/// +/// Uses an explicit work stack to avoid stack overflow on deeply nested DAGs +/// (e.g., 40000+ left-nested App chains from unfolded definitions). +fn dag_copy_node( + root: DAGPtr, + cache: &mut FxHashMap, +) -> DAGPtr { + // Stack frames for the iterative traversal. + // Compound nodes use a two-phase approach: + // Visit → push children + Finish frame → children processed → Finish builds node + // Binder nodes (Fun/Pi/Let/Lam) use three phases: + // Visit → push dom/typ/val + CreateLam → CreateLam inserts var mapping + pushes body + Finish + enum Frame { + Visit(DAGPtr), + FinishApp(usize, NonNull), + FinishProj(usize, NonNull), + CreateFunLam(usize, NonNull), + FinishFun(usize, NonNull, NonNull), + CreatePiLam(usize, NonNull), + FinishPi(usize, NonNull, NonNull), + CreateLamBody(usize, NonNull), + // FinishLam(key, new_lam, old_lam) — old_lam needed to look up body key + FinishLam(usize, NonNull, NonNull), + CreateLetLam(usize, NonNull), + FinishLet(usize, NonNull, NonNull), + } + + let mut stack: Vec = vec![Frame::Visit(root)]; + // Track nodes that have been visited (started processing) to prevent + // exponential blowup when copying DAGs with shared compound nodes. + // Without this, a shared node visited from two parents would be + // processed twice, leading to 2^depth duplication. + let mut visited: FxHashSet = FxHashSet::default(); + // Deferred back-edge patches: (key_of_placeholder, original_node) + // WHNF iota reduction can create cyclic DAGs (e.g., Nat.rec step + // function body → recursive Nat.rec result → step function). + // When we encounter a back-edge during copy, we allocate a placeholder + // and record it here. After the main traversal completes, we patch + // each placeholder's children to point to the cached (copied) versions. 
+ let mut deferred: Vec<(usize, DAGPtr)> = Vec::new(); + + while let Some(frame) = stack.pop() { + unsafe { + match frame { + Frame::Visit(node) => { + let key = dag_ptr_key(node); + if cache.contains_key(&key) { + continue; + } + if visited.contains(&key) { + // Cycle back-edge: allocate placeholder, defer patching + match node { + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + let placeholder = alloc_app(app.fun, app.arg, None); + cache.insert(key, DAGPtr::App(placeholder)); + deferred.push((key, node)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + let placeholder = alloc_proj( + proj.type_name.clone(), proj.idx.clone(), proj.expr, None, + ); + cache.insert(key, DAGPtr::Proj(placeholder)); + deferred.push((key, node)); + }, + // Leaf-like nodes shouldn't cycle; handle just in case + _ => { + cache.insert(key, node); + }, + } + continue; + } + visited.insert(key); + match node { + DAGPtr::Var(_) => { + // Not in cache: outer-binder or free var. Preserve original. + cache.insert(key, node); + }, + DAGPtr::Sort(p) => { + let sort = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Sort(alloc_val(Sort { + level: sort.level.clone(), + parents: None, + })), + ); + }, + DAGPtr::Cnst(p) => { + let cnst = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Cnst(alloc_val(Cnst { + name: cnst.name.clone(), + levels: cnst.levels.clone(), + parents: None, + })), + ); + }, + DAGPtr::Lit(p) => { + let lit = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Lit(alloc_val(LitNode { + val: lit.val.clone(), + parents: None, + })), + ); + }, + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + // Finish after children; visit fun then arg + stack.push(Frame::FinishApp(key, p)); + stack.push(Frame::Visit(app.arg)); + stack.push(Frame::Visit(app.fun)); + }, + DAGPtr::Fun(p) => { + let fun = &*p.as_ptr(); + // Phase 1: visit dom, then create Lam + stack.push(Frame::CreateFunLam(key, p)); + stack.push(Frame::Visit(fun.dom)); + }, + DAGPtr::Pi(p) => { + let pi = &*p.as_ptr(); + 
stack.push(Frame::CreatePiLam(key, p)); + stack.push(Frame::Visit(pi.dom)); + }, + DAGPtr::Lam(p) => { + // Standalone Lam: create Lam, then visit body + stack.push(Frame::CreateLamBody(key, p)); + }, + DAGPtr::Let(p) => { + let let_node = &*p.as_ptr(); + // Visit typ and val, then create Lam + stack.push(Frame::CreateLetLam(key, p)); + stack.push(Frame::Visit(let_node.val)); + stack.push(Frame::Visit(let_node.typ)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + stack.push(Frame::FinishProj(key, p)); + stack.push(Frame::Visit(proj.expr)); + }, + } + }, + + Frame::FinishApp(key, app_ptr) => { + let app = &*app_ptr.as_ptr(); + let new_fun = cache[&dag_ptr_key(app.fun)]; + let new_arg = cache[&dag_ptr_key(app.arg)]; + let new_app = alloc_app(new_fun, new_arg, None); + let app_ref = &mut *new_app.as_ptr(); + let fun_ref = + NonNull::new(&mut app_ref.fun_ref as *mut Parents).unwrap(); + add_to_parents(new_fun, fun_ref); + let arg_ref = + NonNull::new(&mut app_ref.arg_ref as *mut Parents).unwrap(); + add_to_parents(new_arg, arg_ref); + cache.insert(key, DAGPtr::App(new_app)); + }, + + Frame::FinishProj(key, proj_ptr) => { + let proj = &*proj_ptr.as_ptr(); + let new_expr = cache[&dag_ptr_key(proj.expr)]; + let new_proj = alloc_proj( + proj.type_name.clone(), + proj.idx.clone(), + new_expr, + None, + ); + let proj_ref = &mut *new_proj.as_ptr(); + let expr_ref = + NonNull::new(&mut proj_ref.expr_ref as *mut Parents).unwrap(); + add_to_parents(new_expr, expr_ref); + cache.insert(key, DAGPtr::Proj(new_proj)); + }, + + // --- Fun binder: dom visited, create Lam, visit body --- + Frame::CreateFunLam(key, fun_ptr) => { + let fun = &*fun_ptr.as_ptr(); + let old_lam = &*fun.img.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + 
let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + // Phase 2: visit body, then finish + stack.push(Frame::FinishFun(key, fun_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishFun(key, fun_ptr, new_lam) => { + let fun = &*fun_ptr.as_ptr(); + let old_lam = &*fun.img.as_ptr(); + let new_dom = cache[&dag_ptr_key(fun.dom)]; + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_fun_node = alloc_fun( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_dom, + new_lam, + None, + ); + let fun_ref = &mut *new_fun_node.as_ptr(); + let dom_ref = + NonNull::new(&mut fun_ref.dom_ref as *mut Parents).unwrap(); + add_to_parents(new_dom, dom_ref); + let img_ref = + NonNull::new(&mut fun_ref.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), img_ref); + cache.insert(key, DAGPtr::Fun(new_fun_node)); + }, + + // --- Pi binder: dom visited, create Lam, visit body --- + Frame::CreatePiLam(key, pi_ptr) => { + let pi = &*pi_ptr.as_ptr(); + let old_lam = &*pi.img.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishPi(key, pi_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishPi(key, pi_ptr, new_lam) => { + let pi = &*pi_ptr.as_ptr(); + let old_lam = &*pi.img.as_ptr(); + let new_dom = cache[&dag_ptr_key(pi.dom)]; + let new_bod = 
cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_pi = alloc_pi( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_dom, + new_lam, + None, + ); + let pi_ref = &mut *new_pi.as_ptr(); + let dom_ref = + NonNull::new(&mut pi_ref.dom_ref as *mut Parents).unwrap(); + add_to_parents(new_dom, dom_ref); + let img_ref = + NonNull::new(&mut pi_ref.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), img_ref); + cache.insert(key, DAGPtr::Pi(new_pi)); + }, + + // --- Standalone Lam: create Lam, visit body --- + Frame::CreateLamBody(key, old_lam_ptr) => { + let old_lam = &*old_lam_ptr.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishLam(key, new_lam, old_lam_ptr)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishLam(key, new_lam, old_lam_ptr) => { + let old_lam = &*old_lam_ptr.as_ptr(); + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + cache.insert(key, DAGPtr::Lam(new_lam)); + }, + + // --- Let binder: typ+val visited, create Lam, visit body --- + Frame::CreateLetLam(key, let_ptr) => { + let let_node = &*let_ptr.as_ptr(); + let old_lam = &*let_node.bod.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = 
dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishLet(key, let_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishLet(key, let_ptr, new_lam) => { + let let_node = &*let_ptr.as_ptr(); + let old_lam = &*let_node.bod.as_ptr(); + let new_typ = cache[&dag_ptr_key(let_node.typ)]; + let new_val = cache[&dag_ptr_key(let_node.val)]; + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_let = alloc_let( + let_node.binder_name.clone(), + let_node.non_dep, + new_typ, + new_val, + new_lam, + None, + ); + let let_ref = &mut *new_let.as_ptr(); + let typ_ref = + NonNull::new(&mut let_ref.typ_ref as *mut Parents).unwrap(); + add_to_parents(new_typ, typ_ref); + let val_ref = + NonNull::new(&mut let_ref.val_ref as *mut Parents).unwrap(); + add_to_parents(new_val, val_ref); + let bod_ref2 = + NonNull::new(&mut let_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), bod_ref2); + cache.insert(key, DAGPtr::Let(new_let)); + }, + } + } + } + + cache[&dag_ptr_key(root)] } diff --git a/src/ix/kernel/dag_tc.rs b/src/ix/kernel/dag_tc.rs new file mode 100644 index 00000000..3b70d03d --- /dev/null +++ b/src/ix/kernel/dag_tc.rs @@ -0,0 +1,2857 @@ +use core::ptr::NonNull; + +use num_bigint::BigUint; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rustc_hash::FxHashMap; + +use crate::ix::env::{ + BinderInfo, ConstantInfo, Env, Level, Literal, Name, ReducibilityHints, +}; +use crate::lean::nat::Nat; + +use super::convert::{from_expr, to_expr}; +use 
super::dag::*; +use super::error::TcError; +use super::level::{ + all_expr_uparams_defined, eq_antisymm, eq_antisymm_many, is_zero, + no_dupes_all_params, +}; +use super::upcopy::replace_child; +use super::whnf::{ + has_loose_bvars, mk_name2, nat_lit_dag, subst_expr_levels, + try_reduce_native_dag, try_reduce_nat_dag, whnf_dag, +}; + +type TcResult = Result; + +/// DAG-native type checker. +/// +/// Operates directly on `DAGPtr` nodes, avoiding Expr↔DAG round-trips. +/// Caches are keyed by `dag_ptr_key` (raw pointer address), which is safe +/// because DAG nodes are never freed during a single `check_declar` call. +pub struct DagTypeChecker<'env> { + pub env: &'env Env, + pub whnf_cache: FxHashMap, + pub whnf_no_delta_cache: FxHashMap, + pub infer_cache: FxHashMap, + /// Cache for `infer_const` results, keyed by the Blake3 hash of the + /// Cnst node's Expr representation (name + levels). Avoids repeated + /// `from_expr` calls for the same constant at the same universe levels. + pub const_type_cache: FxHashMap, + pub local_counter: u64, + pub local_types: FxHashMap, + /// Stack of corresponding bound variable pairs for binder comparison. + /// Each entry `(key_x, key_y)` means `Var_x` and `Var_y` should be + /// treated as equal when comparing under their respective binders. 
+ binder_eq_map: Vec<(usize, usize)>, + // Debug counters + whnf_calls: u64, + def_eq_calls: u64, + infer_calls: u64, + infer_depth: u64, + infer_max_depth: u64, +} + +impl<'env> DagTypeChecker<'env> { + pub fn new(env: &'env Env) -> Self { + DagTypeChecker { + env, + whnf_cache: FxHashMap::default(), + whnf_no_delta_cache: FxHashMap::default(), + infer_cache: FxHashMap::default(), + const_type_cache: FxHashMap::default(), + local_counter: 0, + local_types: FxHashMap::default(), + binder_eq_map: Vec::new(), + whnf_calls: 0, + def_eq_calls: 0, + infer_calls: 0, + infer_depth: 0, + infer_max_depth: 0, + } + } + + // ========================================================================== + // WHNF with caching + // ========================================================================== + + /// Reduce a DAG node to weak head normal form. + /// + /// Checks the cache first, then calls `whnf_dag` and caches the result. + pub fn whnf(&mut self, ptr: DAGPtr) -> DAGPtr { + self.whnf_calls += 1; + let key = dag_ptr_key(ptr); + if let Some(&cached) = self.whnf_cache.get(&key) { + return cached; + } + let t0 = std::time::Instant::now(); + let mut dag = DAG { head: ptr }; + whnf_dag(&mut dag, self.env, false); + let result = dag.head; + let ms = t0.elapsed().as_millis(); + if ms > 100 { + eprintln!("[whnf SLOW] {}ms whnf_calls={}", ms, self.whnf_calls); + } + self.whnf_cache.insert(key, result); + result + } + + /// Reduce to WHNF without delta (definition) unfolding. + /// + /// Used in definitional equality to try structural comparison before + /// committing to delta reduction. 
+  pub fn whnf_no_delta(&mut self, ptr: DAGPtr) -> DAGPtr {
+    self.whnf_calls += 1;
+    let key = dag_ptr_key(ptr);
+    if let Some(&cached) = self.whnf_no_delta_cache.get(&key) {
+      return cached;
+    }
+    let mut dag = DAG { head: ptr };
+    // `true` = skip delta: do not unfold definitions.
+    whnf_dag(&mut dag, self.env, true);
+    let result = dag.head;
+    self.whnf_no_delta_cache.insert(key, result);
+    result
+  }
+
+  // ==========================================================================
+  // Ensure helpers
+  // ==========================================================================
+
+  /// If `ptr` is already a Sort, return its level. Otherwise WHNF and check.
+  pub fn ensure_sort(&mut self, ptr: DAGPtr) -> TcResult<Level> {
+    if let DAGPtr::Sort(p) = ptr {
+      let level = unsafe { &(*p.as_ptr()).level };
+      return Ok(level.clone());
+    }
+    let whnfd = self.whnf(ptr);
+    match whnfd {
+      DAGPtr::Sort(p) => {
+        let level = unsafe { &(*p.as_ptr()).level };
+        Ok(level.clone())
+      },
+      _ => Err(TcError::TypeExpected {
+        expr: dag_to_expr(ptr),
+        inferred: dag_to_expr(whnfd),
+      }),
+    }
+  }
+
+  /// If `ptr` is already a Pi, return it. Otherwise WHNF and check.
+  pub fn ensure_pi(&mut self, ptr: DAGPtr) -> TcResult<DAGPtr> {
+    if let DAGPtr::Pi(_) = ptr {
+      return Ok(ptr);
+    }
+    let whnfd = self.whnf(ptr);
+    match whnfd {
+      DAGPtr::Pi(_) => Ok(whnfd),
+      _ => Err(TcError::FunctionExpected {
+        expr: dag_to_expr(ptr),
+        inferred: dag_to_expr(whnfd),
+      }),
+    }
+  }
+
+  /// Infer the type of `ptr` and ensure it's a Sort; return the universe level.
+  pub fn infer_sort_of(&mut self, ptr: DAGPtr) -> TcResult<Level> {
+    let ty = self.infer(ptr)?;
+    let whnfd = self.whnf(ty);
+    self.ensure_sort(whnfd)
+  }
+
+  // ==========================================================================
+  // Definitional equality
+  // ==========================================================================
+
+  /// Check definitional equality of two DAG nodes.
+  ///
+  /// Uses a conjunction work stack: processes pairs iteratively, all must
+  /// be equal. Binder comparison uses recursive calls with a binder
+  /// correspondence map rather than pushing raw bodies.
+  pub fn def_eq(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    self.def_eq_calls += 1;
+    // Hard cap on work-stack steps: treat pathological blowups as
+    // "not definitionally equal" rather than diverging.
+    const STEP_LIMIT: u64 = 1_000_000;
+    let mut work: Vec<(DAGPtr, DAGPtr)> = vec![(x, y)];
+    let mut steps: u64 = 0;
+    while let Some((x, y)) = work.pop() {
+      steps += 1;
+      if steps > STEP_LIMIT {
+        return false;
+      }
+      if !self.def_eq_step(x, y, &mut work) {
+        return false;
+      }
+    }
+    true
+  }
+
+  /// Quick syntactic checks at DAG level.
+  // Returns Some(true/false) when equality is decided syntactically,
+  // None when the check is inconclusive and full reduction is needed.
+  fn def_eq_quick_check(&self, x: DAGPtr, y: DAGPtr) -> Option<bool> {
+    // Pointer identity: shared subterms are equal by construction.
+    if dag_ptr_key(x) == dag_ptr_key(y) {
+      return Some(true);
+    }
+    unsafe {
+      match (x, y) {
+        (DAGPtr::Sort(a), DAGPtr::Sort(b)) => {
+          Some(eq_antisymm(&(*a.as_ptr()).level, &(*b.as_ptr()).level))
+        },
+        (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => {
+          let ca = &*a.as_ptr();
+          let cb = &*b.as_ptr();
+          if ca.name == cb.name && eq_antisymm_many(&ca.levels, &cb.levels) {
+            Some(true)
+          } else {
+            None // different names may still be delta-equal
+          }
+        },
+        (DAGPtr::Lit(a), DAGPtr::Lit(b)) => {
+          Some((*a.as_ptr()).val == (*b.as_ptr()).val)
+        },
+        (DAGPtr::Var(a), DAGPtr::Var(b)) => {
+          let va = &*a.as_ptr();
+          let vb = &*b.as_ptr();
+          match (&va.fvar_name, &vb.fvar_name) {
+            (Some(na), Some(nb)) => {
+              // Free variables: equal names decide; distinct names are
+              // inconclusive (fall through to the full algorithm).
+              if na == nb { Some(true) } else { None }
+            },
+            (None, None) => {
+              // Bound variables: equal iff currently paired in the
+              // binder correspondence (see def_eq_binder_full).
+              let ka = dag_ptr_key(x);
+              let kb = dag_ptr_key(y);
+              Some(
+                self
+                  .binder_eq_map
+                  .iter()
+                  .any(|&(ma, mb)| ma == ka && mb == kb),
+              )
+            },
+            // Free vs bound variable: definitely unequal.
+            _ => Some(false),
+          }
+        },
+        _ => None,
+      }
+    }
+  }
+
+  /// Process one def_eq pair.
+  // One step of the work-stack algorithm: decide the pair (x, y) or push
+  // subproblems onto `work`. The rule ORDER below is significant.
+  fn def_eq_step(
+    &mut self,
+    x: DAGPtr,
+    y: DAGPtr,
+    work: &mut Vec<(DAGPtr, DAGPtr)>,
+  ) -> bool {
+    // 1. Cheap syntactic checks before any reduction.
+    if let Some(quick) = self.def_eq_quick_check(x, y) {
+      return quick;
+    }
+    // 2. Head-normalize without unfolding definitions, retry the checks.
+    let x_n = self.whnf_no_delta(x);
+    let y_n = self.whnf_no_delta(y);
+    if let Some(quick) = self.def_eq_quick_check(x_n, y_n) {
+      return quick;
+    }
+    // 3. Proofs of the same proposition are equal regardless of shape.
+    if self.proof_irrel_eq(x_n, y_n) {
+      return true;
+    }
+    // 4. Interleave delta unfolding with cheap checks; if inconclusive,
+    //    fall through the structural/eta rules in order.
+    match self.lazy_delta_step(x_n, y_n) {
+      DagDeltaResult::Found(result) => result,
+      DagDeltaResult::Exhausted(x_e, y_e) => {
+        if self.def_eq_const(x_e, y_e) { return true; }
+        if self.def_eq_proj_push(x_e, y_e, work) { return true; }
+        if self.def_eq_app_push(x_e, y_e, work) { return true; }
+        if self.def_eq_binder_full(x_e, y_e) { return true; }
+        if self.try_eta_expansion(x_e, y_e) { return true; }
+        if self.try_eta_struct(x_e, y_e) { return true; }
+        if self.is_def_eq_unit_like(x_e, y_e) { return true; }
+        false
+      },
+    }
+  }
+
+  // --- Proof irrelevance ---
+
+  /// If both x and y are proofs of the same proposition, they are def-eq.
+  fn proof_irrel_eq(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    // Skip for binder types: inferring Fun/Pi/Lam would recurse into
+    // binder bodies. Kept as a conservative guard for def_eq_binder_full.
+    if matches!(x, DAGPtr::Fun(_) | DAGPtr::Pi(_) | DAGPtr::Lam(_)) {
+      return false;
+    }
+    if matches!(y, DAGPtr::Fun(_) | DAGPtr::Pi(_) | DAGPtr::Lam(_)) {
+      return false;
+    }
+    // Inference failures are treated as "rule does not apply", not errors.
+    let x_ty = match self.infer(x) {
+      Ok(ty) => ty,
+      Err(_) => return false,
+    };
+    if !self.is_proposition(x_ty) {
+      return false;
+    }
+    let y_ty = match self.infer(y) {
+      Ok(ty) => ty,
+      Err(_) => return false,
+    };
+    if !self.is_proposition(y_ty) {
+      return false;
+    }
+    // Both are proofs; equal iff they prove the same proposition.
+    self.def_eq(x_ty, y_ty)
+  }
+
+  /// Check if a type lives in Prop (Sort 0).
+  fn is_proposition(&mut self, ty: DAGPtr) -> bool {
+    let whnfd = self.whnf(ty);
+    match whnfd {
+      DAGPtr::Sort(s) => unsafe { is_zero(&(*s.as_ptr()).level) },
+      _ => false,
+    }
+  }
+
+  // --- Lazy delta ---
+
+  /// Interleave definition unfolding on both sides with quick checks,
+  /// unfolding the side with the larger reducibility hint first.
+  fn lazy_delta_step(
+    &mut self,
+    x: DAGPtr,
+    y: DAGPtr,
+  ) -> DagDeltaResult {
+    let mut x = x;
+    let mut y = y;
+    let mut iters: u32 = 0;
+    // Give up (inconclusive) rather than loop forever on hostile input.
+    const MAX_DELTA_ITERS: u32 = 10_000;
+    loop {
+      iters += 1;
+      if iters > MAX_DELTA_ITERS {
+        return DagDeltaResult::Exhausted(x, y);
+      }
+
+      // Nat zero/succ short-circuit before any unfolding.
+      if let Some(quick) = self.def_eq_nat_offset(x, y) {
+        return DagDeltaResult::Found(quick);
+      }
+
+      // Native Nat reductions take priority over generic delta.
+      if let Some(x_r) = try_lazy_delta_nat_native(x, self.env) {
+        let x_r = self.whnf_no_delta(x_r);
+        if let Some(quick) = self.def_eq_quick_check(x_r, y) {
+          return DagDeltaResult::Found(quick);
+        }
+        x = x_r;
+        continue;
+      }
+      if let Some(y_r) = try_lazy_delta_nat_native(y, self.env) {
+        let y_r = self.whnf_no_delta(y_r);
+        if let Some(quick) = self.def_eq_quick_check(x, y_r) {
+          return DagDeltaResult::Found(quick);
+        }
+        y = y_r;
+        continue;
+      }
+
+      // Which sides have an unfoldable definition at the head?
+      let x_def = dag_get_applied_def(x, self.env);
+      let y_def = dag_get_applied_def(y, self.env);
+      match (&x_def, &y_def) {
+        (None, None) => return DagDeltaResult::Exhausted(x, y),
+        (Some(_), None) => {
+          x = self.dag_delta(x);
+        },
+        (None, Some(_)) => {
+          y = self.dag_delta(y);
+        },
+        (Some((x_name, x_hint)), Some((y_name, y_hint))) => {
+          if x_name == y_name && x_hint == y_hint {
+            // Same head: try argument-wise congruence before unfolding.
+            if self.def_eq_app_eager(x, y) {
+              return DagDeltaResult::Found(true);
+            }
+            x = self.dag_delta(x);
+            y = self.dag_delta(y);
+          } else if hint_lt(x_hint, y_hint) {
+            y = self.dag_delta(y);
+          } else {
+            x = self.dag_delta(x);
+          }
+        },
+      }
+
+      if let Some(quick) = self.def_eq_quick_check(x, y) {
+        return DagDeltaResult::Found(quick);
+      }
+    }
+  }
+
+  /// Unfold a definition and do cheap WHNF (no delta).
+  fn dag_delta(&mut self, ptr: DAGPtr) -> DAGPtr {
+    match dag_try_unfold_def(ptr, self.env) {
+      Some(unfolded) => self.whnf_no_delta(unfolded),
+      // Not an unfoldable definition: return unchanged.
+      None => ptr,
+    }
+  }
+
+  // --- Nat offset equality ---
+
+  /// `Some(b)` when both sides are Nat zero/succ forms; `None` when the
+  /// rule does not apply (caller falls through to other rules).
+  fn def_eq_nat_offset(
+    &mut self,
+    x: DAGPtr,
+    y: DAGPtr,
+  ) -> Option<bool> {
+    if is_nat_zero_dag(x) && is_nat_zero_dag(y) {
+      return Some(true);
+    }
+    // succ n ≡ succ m iff n ≡ m.
+    match (is_nat_succ_dag(x), is_nat_succ_dag(y)) {
+      (Some(x_pred), Some(y_pred)) => Some(self.def_eq(x_pred, y_pred)),
+      _ => None,
+    }
+  }
+
+  // --- Congruence ---
+
+  /// Constants are equal when names match and universe levels are
+  /// antisymmetrically equal.
+  fn def_eq_const(&self, x: DAGPtr, y: DAGPtr) -> bool {
+    unsafe {
+      match (x, y) {
+        (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => {
+          let ca = &*a.as_ptr();
+          let cb = &*b.as_ptr();
+          ca.name == cb.name && eq_antisymm_many(&ca.levels, &cb.levels)
+        },
+        _ => false,
+      }
+    }
+  }
+
+  /// Projection congruence: equal indices push the projected exprs onto
+  /// the work stack. `false` means the rule did not decide anything.
+  fn def_eq_proj_push(
+    &self,
+    x: DAGPtr,
+    y: DAGPtr,
+    work: &mut Vec<(DAGPtr, DAGPtr)>,
+  ) -> bool {
+    unsafe {
+      match (x, y) {
+        (DAGPtr::Proj(a), DAGPtr::Proj(b)) => {
+          let pa = &*a.as_ptr();
+          let pb = &*b.as_ptr();
+          if pa.idx == pb.idx {
+            work.push((pa.expr, pb.expr));
+            true
+          } else {
+            false
+          }
+        },
+        _ => false,
+      }
+    }
+  }
+
+  /// Application congruence: equal spine lengths push the head pair and
+  /// each argument pair onto the work stack.
+  fn def_eq_app_push(
+    &self,
+    x: DAGPtr,
+    y: DAGPtr,
+    work: &mut Vec<(DAGPtr, DAGPtr)>,
+  ) -> bool {
+    let (f1, args1) = dag_unfold_apps(x);
+    if args1.is_empty() {
+      return false;
+    }
+    let (f2, args2) = dag_unfold_apps(y);
+    if args2.is_empty() {
+      return false;
+    }
+    if args1.len() != args2.len() {
+      return false;
+    }
+    work.push((f1, f2));
+    for (&a, &b) in args1.iter().zip(args2.iter()) {
+      work.push((a, b));
+    }
+    true
+  }
+
+  /// Eager app congruence (used by lazy_delta_step).
+  // Like def_eq_app_push but compares immediately instead of deferring
+  // to the work stack (lazy_delta_step needs a definite answer).
+  fn def_eq_app_eager(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    let (f1, args1) = dag_unfold_apps(x);
+    if args1.is_empty() {
+      return false;
+    }
+    let (f2, args2) = dag_unfold_apps(y);
+    if args2.is_empty() {
+      return false;
+    }
+    if args1.len() != args2.len() {
+      return false;
+    }
+    if !self.def_eq(f1, f2) {
+      return false;
+    }
+    args1.iter().zip(args2.iter()).all(|(&a, &b)| self.def_eq(a, b))
+  }
+
+  // --- Binder full ---
+
+  /// Compare Pi/Fun binders: peel matching layers, push var correspondence
+  /// into `binder_eq_map`, and compare bodies recursively.
+  fn def_eq_binder_full(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    let mut cx = x;
+    let mut cy = y;
+    let mut matched = false;
+    // Correspondences pushed so far; must be popped on EVERY exit path.
+    let mut n_pushed: usize = 0;
+    loop {
+      unsafe {
+        match (cx, cy) {
+          (DAGPtr::Pi(px), DAGPtr::Pi(py)) => {
+            let pi_x = &*px.as_ptr();
+            let pi_y = &*py.as_ptr();
+            // Domains must match before descending into bodies.
+            if !self.def_eq(pi_x.dom, pi_y.dom) {
+              for _ in 0..n_pushed {
+                self.binder_eq_map.pop();
+              }
+              return false;
+            }
+            let lam_x = &*pi_x.img.as_ptr();
+            let lam_y = &*pi_y.img.as_ptr();
+            let var_x_ptr = NonNull::new(
+              &lam_x.var as *const Var as *mut Var,
+            )
+            .unwrap();
+            let var_y_ptr = NonNull::new(
+              &lam_y.var as *const Var as *mut Var,
+            )
+            .unwrap();
+            // Record that these two bound vars correspond under the
+            // current pair of binders (consulted by def_eq_quick_check).
+            self.binder_eq_map.push((
+              dag_ptr_key(DAGPtr::Var(var_x_ptr)),
+              dag_ptr_key(DAGPtr::Var(var_y_ptr)),
+            ));
+            n_pushed += 1;
+            cx = lam_x.bod;
+            cy = lam_y.bod;
+            matched = true;
+          },
+          (DAGPtr::Fun(fx), DAGPtr::Fun(fy)) => {
+            let fun_x = &*fx.as_ptr();
+            let fun_y = &*fy.as_ptr();
+            if !self.def_eq(fun_x.dom, fun_y.dom) {
+              for _ in 0..n_pushed {
+                self.binder_eq_map.pop();
+              }
+              return false;
+            }
+            let lam_x = &*fun_x.img.as_ptr();
+            let lam_y = &*fun_y.img.as_ptr();
+            let var_x_ptr = NonNull::new(
+              &lam_x.var as *const Var as *mut Var,
+            )
+            .unwrap();
+            let var_y_ptr = NonNull::new(
+              &lam_y.var as *const Var as *mut Var,
+            )
+            .unwrap();
+            self.binder_eq_map.push((
+              dag_ptr_key(DAGPtr::Var(var_x_ptr)),
+              dag_ptr_key(DAGPtr::Var(var_y_ptr)),
+            ));
+            n_pushed += 1;
+            cx = lam_x.bod;
+            cy = lam_y.bod;
+            matched = true;
+          },
+          _ => break,
+        }
+      }
+    }
+    // No binder layer matched: this rule does not apply.
+    if !matched {
+      return false;
+    }
+    let result = self.def_eq(cx, cy);
+    for _ in 0..n_pushed {
+      self.binder_eq_map.pop();
+    }
+    result
+  }
+
+  // --- Eta expansion ---
+
+  /// Try eta in both directions.
+  fn try_eta_expansion(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    self.try_eta_expansion_aux(x, y)
+      || self.try_eta_expansion_aux(y, x)
+  }
+
+  /// Eta: `fun x => f x` ≡ `f` when `f : (x : A) → B`.
+  fn try_eta_expansion_aux(
+    &mut self,
+    x: DAGPtr,
+    y: DAGPtr,
+  ) -> bool {
+    // x must be a lambda (Fun); y must have a Pi type.
+    let fx = match x {
+      DAGPtr::Fun(f) => f,
+      _ => return false,
+    };
+    let y_ty = match self.infer(y) {
+      Ok(t) => t,
+      Err(_) => return false,
+    };
+    let y_ty_whnf = self.whnf(y_ty);
+    if !matches!(y_ty_whnf, DAGPtr::Pi(_)) {
+      return false;
+    }
+    unsafe {
+      let fun_x = &*fx.as_ptr();
+      let lam_x = &*fun_x.img.as_ptr();
+      let var_x_ptr =
+        NonNull::new(&lam_x.var as *const Var as *mut Var).unwrap();
+      let var_x = DAGPtr::Var(var_x_ptr);
+      // Build eta body: App(y, var_x)
+      // Using the SAME var_x on both sides, so pointer identity
+      // handles bound variable matching without binder_eq_map.
+      let eta_body = DAGPtr::App(alloc_app(y, var_x, None));
+      self.def_eq(lam_x.bod, eta_body)
+    }
+  }
+
+  // --- Struct eta ---
+
+  /// Try structure eta in both directions.
+  fn try_eta_struct(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    self.try_eta_struct_core(x, y)
+      || self.try_eta_struct_core(y, x)
+  }
+
+  /// Structure eta: `p =def= S.mk (S.1 p) (S.2 p)` when S is a
+  /// single-constructor non-recursive inductive with no indices.
+  // `s` must be a fully-applied constructor of a structure-like type;
+  // each field of `s` is then compared against the matching projection
+  // of `t`.
+  fn try_eta_struct_core(&mut self, t: DAGPtr, s: DAGPtr) -> bool {
+    let (head, args) = dag_unfold_apps(s);
+    let ctor_name = match head {
+      DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() },
+      _ => return false,
+    };
+    let ctor_info = match self.env.get(&ctor_name) {
+      Some(ConstantInfo::CtorInfo(c)) => c,
+      _ => return false,
+    };
+    if !is_structure_like(&ctor_info.induct, self.env) {
+      return false;
+    }
+    let num_params = ctor_info.num_params.to_u64().unwrap() as usize;
+    let num_fields = ctor_info.num_fields.to_u64().unwrap() as usize;
+    // Constructor must be fully applied (all params + all fields).
+    if args.len() != num_params + num_fields {
+      return false;
+    }
+    // Field i of s must equal Proj(induct, i, t).
+    for i in 0..num_fields {
+      let field = args[num_params + i];
+      let proj = alloc_proj(
+        ctor_info.induct.clone(),
+        Nat::from(i as u64),
+        t,
+        None,
+      );
+      if !self.def_eq(field, DAGPtr::Proj(proj)) {
+        return false;
+      }
+    }
+    true
+  }
+
+  // --- Unit-like equality ---
+
+  /// Types with a single zero-field constructor have all inhabitants def-eq.
+  fn is_def_eq_unit_like(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
+    // Both sides must have the same (def-eq) type...
+    let x_ty = match self.infer(x) {
+      Ok(ty) => ty,
+      Err(_) => return false,
+    };
+    let y_ty = match self.infer(y) {
+      Ok(ty) => ty,
+      Err(_) => return false,
+    };
+    if !self.def_eq(x_ty, y_ty) {
+      return false;
+    }
+    // ...and that type's head must be a single-constructor inductive
+    // whose constructor takes no fields.
+    let whnf_ty = self.whnf(x_ty);
+    let (head, _) = dag_unfold_apps(whnf_ty);
+    let name = match head {
+      DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() },
+      _ => return false,
+    };
+    match self.env.get(&name) {
+      Some(ConstantInfo::InductInfo(iv)) => {
+        if iv.ctors.len() != 1 {
+          return false;
+        }
+        if let Some(ConstantInfo::CtorInfo(c)) =
+          self.env.get(&iv.ctors[0])
+        {
+          c.num_fields == Nat::ZERO
+        } else {
+          false
+        }
+      },
+      _ => false,
+    }
+  }
+
+  /// Assert that two DAG nodes are definitionally equal; return TcError if not.
+  pub fn assert_def_eq(
+    &mut self,
+    x: DAGPtr,
+    y: DAGPtr,
+  ) -> TcResult<()> {
+    if self.def_eq(x, y) {
+      Ok(())
+    } else {
+      Err(TcError::DefEqFailure {
+        lhs: dag_to_expr(x),
+        rhs: dag_to_expr(y),
+      })
+    }
+  }
+
+  // ==========================================================================
+  // Local context management
+  // ==========================================================================
+
+  /// Create a fresh free variable for entering a binder.
+  ///
+  /// Returns a `DAGPtr::Var` with a unique `fvar_name` (derived from the
+  /// binder name and a monotonic counter) and records `ty` as its type
+  /// in `local_types`.
+  pub fn mk_dag_local(&mut self, name: &Name, ty: DAGPtr) -> DAGPtr {
+    let id = self.local_counter;
+    self.local_counter += 1;
+    let local_name = Name::num(name.clone(), Nat::from(id));
+    let var = alloc_val(Var {
+      depth: 0,
+      binder: BinderPtr::Free,
+      fvar_name: Some(local_name.clone()),
+      parents: None,
+    });
+    self.local_types.insert(local_name, ty);
+    DAGPtr::Var(var)
+  }
+
+  // ==========================================================================
+  // Type inference
+  // ==========================================================================
+
+  /// Infer the type of a DAG node.
+  ///
+  /// Results are cached by node address; depth counters are maintained
+  /// for diagnostics only.
+  pub fn infer(&mut self, ptr: DAGPtr) -> TcResult<DAGPtr> {
+    self.infer_calls += 1;
+    let key = dag_ptr_key(ptr);
+    if let Some(&cached) = self.infer_cache.get(&key) {
+      return Ok(cached);
+    }
+    self.infer_depth += 1;
+    if self.infer_depth > self.infer_max_depth {
+      self.infer_max_depth = self.infer_depth;
+    }
+    let result = self.infer_core(ptr);
+    // Decrement on BOTH paths: the original leaked the depth counter
+    // when infer_core returned an error.
+    self.infer_depth -= 1;
+    let result = result?;
+    self.infer_cache.insert(key, result);
+    Ok(result)
+  }
+
+  // Dispatch on node kind; each case returns the node's type as a DAG.
+  fn infer_core(&mut self, ptr: DAGPtr) -> TcResult<DAGPtr> {
+    match ptr {
+      DAGPtr::Var(p) => unsafe {
+        let var = &*p.as_ptr();
+        match &var.fvar_name {
+          // Free variable: type must have been recorded by mk_dag_local.
+          Some(name) => match self.local_types.get(name) {
+            Some(&ty) => Ok(ty),
+            None => Err(TcError::KernelException {
+              msg: "cannot infer type of free variable without context"
+                .into(),
+            }),
+          },
+          None => match var.binder {
+            BinderPtr::Free => Err(TcError::FreeBoundVariable {
+              idx: var.depth,
+            }),
+            BinderPtr::Lam(_) => Err(TcError::KernelException {
+              msg: "unexpected bound variable during inference".into(),
+            }),
+          },
+        }
+      },
+      DAGPtr::Sort(p) => {
+        // Sort u : Sort (u+1)
+        let level = unsafe { &(*p.as_ptr()).level };
+        let result = alloc_val(Sort {
+          level: Level::succ(level.clone()),
+          parents: None,
+        });
+        Ok(DAGPtr::Sort(result))
+      },
+      DAGPtr::Cnst(p) => {
+        let (name, levels) = unsafe {
+          let cnst = &*p.as_ptr();
+          (cnst.name.clone(), cnst.levels.clone())
+        };
+        self.infer_const(&name, &levels)
+      },
+      DAGPtr::App(_) => self.infer_app(ptr),
+      DAGPtr::Fun(_) => self.infer_lambda(ptr),
+      DAGPtr::Pi(_) => self.infer_pi(ptr),
+      DAGPtr::Let(p) => {
+        let (typ, val, bod_lam) = unsafe {
+          let let_node = &*p.as_ptr();
+          (let_node.typ, let_node.val, let_node.bod)
+        };
+        // Check the annotation, then infer the body with val substituted.
+        let val_ty = self.infer(val)?;
+        self.assert_def_eq(val_ty, typ)?;
+        let body = dag_copy_subst(bod_lam, val);
+        self.infer(body)
+      },
+      DAGPtr::Lit(p) => {
+        let val = unsafe { &(*p.as_ptr()).val };
+        self.infer_lit(val)
+      },
+      DAGPtr::Proj(p) => {
+        let (type_name, idx, structure) = unsafe {
+          let proj = &*p.as_ptr();
+          (proj.type_name.clone(), proj.idx.clone(), proj.expr)
+        };
+        self.infer_proj(&type_name, &idx, structure, ptr)
+      },
+      // Lam only appears as the img of Fun/Pi or bod of Let, never bare.
+      DAGPtr::Lam(_) => Err(TcError::KernelException {
+        msg: "unexpected standalone Lam during inference".into(),
+      }),
+    }
+  }
+
+  fn infer_const(
+    &mut self,
+    name: &Name,
+    levels: &[Level],
+  ) -> TcResult<DAGPtr> {
+    // Build a cache key from the constant's name + universe level hashes.
+    let cache_key = {
+      let mut hasher = blake3::Hasher::new();
+      hasher.update(name.get_hash().as_bytes());
+      for l in levels {
+        hasher.update(l.get_hash().as_bytes());
+      }
+      hasher.finalize()
+    };
+    if let Some(&cached) = self.const_type_cache.get(&cache_key) {
+      return Ok(cached);
+    }
+
+    let ci = self
+      .env
+      .get(name)
+      .ok_or_else(|| TcError::UnknownConst { name: name.clone() })?;
+
+    // Universe arguments must match the declaration's parameter count.
+    let decl_params = ci.get_level_params();
+    if levels.len() != decl_params.len() {
+      return Err(TcError::KernelException {
+        msg: format!(
+          "universe parameter count mismatch for {}",
+          name.pretty()
+        ),
+      });
+    }
+
+    // Instantiate the declared type at the given universe levels.
+    let ty = ci.get_type();
+    let dag = from_expr(ty);
+    let result = subst_dag_levels(dag.head, decl_params, levels);
+    self.const_type_cache.insert(cache_key, result);
+    Ok(result)
+  }
+
+  /// Infer an application spine: check each argument against the domain
+  /// of the (whnf'd) function type and instantiate the codomain.
+  fn infer_app(&mut self, e: DAGPtr) -> TcResult<DAGPtr> {
+    let (fun, args) = dag_unfold_apps(e);
+    let mut fun_ty = self.infer(fun)?;
+
+    for &arg in args.iter() {
+      let pi = self.ensure_pi(fun_ty)?;
+      let (dom, img) = unsafe {
+        match pi {
+          DAGPtr::Pi(p) => {
+            let pi_ref = &*p.as_ptr();
+            (pi_ref.dom, pi_ref.img)
+          },
+          // ensure_pi only returns Pi on success.
+          _ => unreachable!(),
+        }
+      };
+      let arg_ty = self.infer(arg)?;
+      if !self.def_eq(arg_ty, dom) {
+        return Err(TcError::DefEqFailure {
+          lhs: dag_to_expr(arg_ty),
+          rhs: dag_to_expr(dom),
+        });
+      }
+      // Instantiate the codomain with the actual argument.
+      fun_ty = dag_copy_subst(img, arg);
+    }
+
+    Ok(fun_ty)
+  }
+
+  /// Infer a lambda telescope: check each domain is a type, enter each
+  /// binder with a fresh local, infer the body, then rebuild a Pi.
+  fn infer_lambda(&mut self, e: DAGPtr) -> TcResult<DAGPtr> {
+    let mut cursor = e;
+    let mut locals: Vec<DAGPtr> = Vec::new();
+    let mut binder_doms: Vec<DAGPtr> = Vec::new();
+    let mut binder_infos: Vec<BinderInfo> = Vec::new();
+    let mut binder_names: Vec<Name> = Vec::new();
+
+    // Peel Fun layers
+    while let DAGPtr::Fun(fun_ptr) = cursor {
+      let (name, bi, dom, img) = unsafe {
+        let fun = &*fun_ptr.as_ptr();
+        (
+          fun.binder_name.clone(),
+          fun.binder_info.clone(),
+          fun.dom,
+          fun.img,
+        )
+      };
+
+      // The binder's domain must itself be a type.
+      self.infer_sort_of(dom)?;
+
+      let local = self.mk_dag_local(&name, dom);
+      locals.push(local);
+      binder_doms.push(dom);
+      binder_infos.push(bi);
+      binder_names.push(name);
+
+      // Enter the binder: deep copy because img is from the TERM DAG
+      cursor = dag_copy_subst(img, local);
+    }
+
+    // Infer the body type
+    let body_ty = self.infer(cursor)?;
+
+    // Abstract back: build Pi telescope over the locals
+    Ok(build_pi_over_locals(
+      body_ty,
+      &locals,
+      &binder_names,
+      &binder_infos,
+      &binder_doms,
+    ))
+  }
+
+  /// Infer a Pi telescope: `(x : A) -> B : Sort (imax u v)`.
+  fn infer_pi(&mut self, e: DAGPtr) -> TcResult<DAGPtr> {
+    let mut cursor = e;
+    let mut universes: Vec<Level> = Vec::new();
+
+    // Peel Pi layers
+    while let DAGPtr::Pi(pi_ptr) = cursor {
+      let (name, dom, img) = unsafe {
+        let pi = &*pi_ptr.as_ptr();
+        (pi.binder_name.clone(), pi.dom, pi.img)
+      };
+
+      let dom_univ = self.infer_sort_of(dom)?;
+      universes.push(dom_univ);
+
+      // mk_dag_local registers the local's type in local_types; the
+      // locals themselves need not be retained here.
+      let local = self.mk_dag_local(&name, dom);
+
+      // Enter the binder: deep copy because img is from the TERM DAG
+      cursor = dag_copy_subst(img, local);
+    }
+
+    // The body must also be a type
+    let mut result_level = self.infer_sort_of(cursor)?;
+
+    // Compute imax of all levels (innermost first)
+    for univ in universes.into_iter().rev() {
+      result_level = Level::imax(univ, result_level);
+    }
+
+    let result = alloc_val(Sort {
+      level: result_level,
+      parents: None,
+    });
+    Ok(DAGPtr::Sort(result))
+  }
+
+  /// Nat literals have type `Nat`; string literals have type `String`.
+  fn infer_lit(&mut self, lit: &Literal) -> TcResult<DAGPtr> {
+    let name = match lit {
+      Literal::NatVal(_) => Name::str(Name::anon(), "Nat".into()),
+      Literal::StrVal(_) => Name::str(Name::anon(), "String".into()),
+    };
+    let cnst = alloc_val(Cnst { name, levels: vec![], parents: None });
+    Ok(DAGPtr::Cnst(cnst))
+  }
+
+  /// Infer a projection's type by walking the constructor telescope:
+  /// instantiate params from the structure's type, substitute earlier
+  /// fields with projections, and return the idx-th field's domain.
+  fn infer_proj(
+    &mut self,
+    type_name: &Name,
+    idx: &Nat,
+    structure: DAGPtr,
+    _proj_expr: DAGPtr,
+  ) -> TcResult<DAGPtr> {
+    let structure_ty = self.infer(structure)?;
+    let structure_ty_whnf = self.whnf(structure_ty);
+
+    let (head, struct_ty_args) = dag_unfold_apps(structure_ty_whnf);
+    let (head_name, head_levels) = unsafe {
+      match head {
+        DAGPtr::Cnst(p) => {
+          let cnst = &*p.as_ptr();
+          (cnst.name.clone(), cnst.levels.clone())
+        },
+        _ => {
+          return Err(TcError::KernelException {
+            msg: "projection structure type is not a constant".into(),
+          })
+        },
+      }
+    };
+
+    let ind = self.env.get(&head_name).ok_or_else(|| {
+      TcError::UnknownConst { name: head_name.clone() }
+    })?;
+
+    let (num_params, ctor_name) = match ind {
+      ConstantInfo::InductInfo(iv) => {
+        let ctor = iv.ctors.first().ok_or_else(|| {
+          TcError::KernelException {
+            msg: "inductive has no constructors".into(),
+          }
+        })?;
+        (iv.num_params.to_u64().unwrap(), ctor.clone())
+      },
+      _ => {
+        return Err(TcError::KernelException {
+          msg: "projection type is not an inductive".into(),
+        })
+      },
+    };
+
+    let ctor_ci = self.env.get(&ctor_name).ok_or_else(|| {
+      TcError::UnknownConst { name: ctor_name.clone() }
+    })?;
+
+    let ctor_ty_dag = from_expr(ctor_ci.get_type());
+    let mut ctor_ty = subst_dag_levels(
+      ctor_ty_dag.head,
+      ctor_ci.get_level_params(),
+      &head_levels,
+    );
+
+    // Skip params: instantiate with the actual type arguments
+    for i in 0..num_params as usize {
+      let whnf_ty = self.whnf(ctor_ty);
+      match whnf_ty {
+        DAGPtr::Pi(p) => {
+          let img = unsafe { (*p.as_ptr()).img };
+          ctor_ty = dag_copy_subst(img, struct_ty_args[i]);
+        },
+        _ => {
+          return Err(TcError::KernelException {
+            msg: "ran out of constructor telescope (params)".into(),
+          })
+        },
+      }
+    }
+
+    // Walk to the idx-th field, substituting projections
+    let idx_usize = idx.to_u64().unwrap() as usize;
+    for i in 0..idx_usize {
+      let whnf_ty = self.whnf(ctor_ty);
+      match whnf_ty {
+        DAGPtr::Pi(p) => {
+          let img = unsafe { (*p.as_ptr()).img };
+          let proj = alloc_proj(
+            type_name.clone(),
+            Nat::from(i as u64),
+            structure,
+            None,
+          );
+          ctor_ty = dag_copy_subst(img, DAGPtr::Proj(proj));
+        },
+        _ => {
+          return Err(TcError::KernelException {
+            msg: "ran out of constructor telescope (fields)".into(),
+          })
+        },
+      }
+    }
+
+    // Extract the target field's type (the domain of the next Pi)
+    let whnf_ty = self.whnf(ctor_ty);
+    match whnf_ty {
+      DAGPtr::Pi(p) => {
+        let dom = unsafe { (*p.as_ptr()).dom };
+        Ok(dom)
+      },
+      _ => Err(TcError::KernelException {
+        msg: "ran out of constructor telescope (target field)".into(),
+      }),
+    }
+  }
+
+  // ==========================================================================
+  // Declaration checking
+  // ==========================================================================
+
+  /// Validate a declaration's type: no duplicate uparams, no loose bvars,
+  /// all uparams defined, and type infers to a Sort.
+ pub fn check_declar_info( + &mut self, + info: &crate::ix::env::ConstantVal, + ) -> TcResult<()> { + if !no_dupes_all_params(&info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "duplicate universe parameters in {}", + info.name.pretty() + ), + }); + } + if has_loose_bvars(&info.typ) { + return Err(TcError::KernelException { + msg: format!( + "free bound variables in type of {}", + info.name.pretty() + ), + }); + } + if !all_expr_uparams_defined(&info.typ, &info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in type of {}", + info.name.pretty() + ), + }); + } + let ty_dag = from_expr(&info.typ).head; + self.infer_sort_of(ty_dag)?; + Ok(()) + } + + /// Check a declaration with both type and value (DefnInfo, ThmInfo, OpaqueInfo). + fn check_value_declar( + &mut self, + cnst: &crate::ix::env::ConstantVal, + value: &crate::ix::env::Expr, + ) -> TcResult<()> { + let t_start = std::time::Instant::now(); + self.check_declar_info(cnst)?; + eprintln!("[cvd @{}ms] check_declar_info done", t_start.elapsed().as_millis()); + if !all_expr_uparams_defined(value, &cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + cnst.name.pretty() + ), + }); + } + let t1 = std::time::Instant::now(); + let val_dag = from_expr(value).head; + eprintln!("[check_value_declar] {} from_expr(value): {}ms", cnst.name.pretty(), t1.elapsed().as_millis()); + let t2 = std::time::Instant::now(); + let inferred_type = self.infer(val_dag)?; + eprintln!("[check_value_declar] {} infer: {}ms", cnst.name.pretty(), t2.elapsed().as_millis()); + let t3 = std::time::Instant::now(); + let ty_dag = from_expr(&cnst.typ).head; + eprintln!("[check_value_declar] {} from_expr(type): {}ms", cnst.name.pretty(), t3.elapsed().as_millis()); + if !self.def_eq(inferred_type, ty_dag) { + let lhs_expr = dag_to_expr(inferred_type); + let rhs_expr = dag_to_expr(ty_dag); + 
return Err(TcError::DefEqFailure { + lhs: lhs_expr, + rhs: rhs_expr, + }); + } + Ok(()) + } + + /// Check a single declaration. + pub fn check_declar( + &mut self, + ci: &ConstantInfo, + ) -> TcResult<()> { + match ci { + ConstantInfo::AxiomInfo(v) => { + self.check_declar_info(&v.cnst)?; + }, + ConstantInfo::DefnInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::ThmInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::OpaqueInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::QuotInfo(v) => { + self.check_declar_info(&v.cnst)?; + super::quot::check_quot(self.env)?; + }, + ConstantInfo::InductInfo(v) => { + // Use Expr-level TypeChecker for structural inductive validation + // (positivity, return types, field universes). These checks aren't + // performance-critical and work on small type telescopes. + let mut expr_tc = super::tc::TypeChecker::new(self.env); + super::inductive::check_inductive(v, &mut expr_tc)?; + }, + ConstantInfo::CtorInfo(v) => { + self.check_declar_info(&v.cnst)?; + if self.env.get(&v.induct).is_none() { + return Err(TcError::UnknownConst { + name: v.induct.clone(), + }); + } + }, + ConstantInfo::RecInfo(v) => { + self.check_declar_info(&v.cnst)?; + for ind_name in &v.all { + if self.env.get(ind_name).is_none() { + return Err(TcError::UnknownConst { + name: ind_name.clone(), + }); + } + } + super::inductive::validate_k_flag(v, self.env)?; + }, + } + Ok(()) + } +} + + +/// Convert a DAGPtr to an Expr. Used only when constructing TcError values. +fn dag_to_expr(ptr: DAGPtr) -> crate::ix::env::Expr { + let dag = DAG { head: ptr }; + to_expr(&dag) +} + +/// Check all declarations in an environment in parallel using the DAG TC. 
pub fn dag_check_env(env: &Env) -> Vec<(Name, TcError)> {
  use std::collections::BTreeSet;
  use std::io::Write;
  use std::sync::Mutex;
  use std::sync::atomic::{AtomicUsize, Ordering};

  let total = env.len();
  // Declarations fully checked so far, shared across rayon workers.
  let checked = AtomicUsize::new(0);

  // Mutable state of the live progress display on stderr.
  struct Display {
    // Pretty names of declarations currently being checked.
    active: BTreeSet<String>,
    // Terminal lines written by the previous refresh, so the cursor can be
    // moved back up to overwrite them in place.
    prev_lines: usize,
  }
  let display =
    Mutex::new(Display { active: BTreeSet::new(), prev_lines: 0 });

  // Redraw the progress display in place using ANSI sequences:
  // `ESC[nA` moves the cursor up n lines, `ESC[2K` erases the current line.
  // Write errors are deliberately ignored (`.ok()`): progress output is
  // best-effort and must never fail the check.
  let refresh = |d: &mut Display, checked: usize| {
    let mut stderr = std::io::stderr().lock();
    if d.prev_lines > 0 {
      write!(stderr, "\x1b[{}A", d.prev_lines).ok();
    }
    // Header line: overall progress plus number of in-flight declarations.
    write!(
      stderr,
      "\x1b[2K[dag_check_env] {}/{} — {} active\n",
      checked,
      total,
      d.active.len()
    )
    .ok();
    let mut new_lines = 1;
    // One line per in-flight declaration.
    for name in &d.active {
      write!(stderr, "\x1b[2K {}\n", name).ok();
      new_lines += 1;
    }
    // If the previous frame was taller, blank the leftover lines and move
    // the cursor back up over them so the next frame starts aligned.
    let extra = d.prev_lines.saturating_sub(new_lines);
    for _ in 0..extra {
      write!(stderr, "\x1b[2K\n").ok();
    }
    if extra > 0 {
      write!(stderr, "\x1b[{}A", extra).ok();
    }
    d.prev_lines = new_lines;
    stderr.flush().ok();
  };

  env
    .par_iter()
    .filter_map(|(name, ci): (&Name, &ConstantInfo)| {
      let pretty = name.pretty();
      // Register this declaration as active and redraw (lock scope kept
      // minimal so display contention doesn't serialize the checkers).
      {
        let mut d = display.lock().unwrap();
        d.active.insert(pretty.clone());
        refresh(&mut d, checked.load(Ordering::Relaxed));
      }

      // Each worker gets its own checker over the shared environment.
      let mut tc = DagTypeChecker::new(env);
      let result = tc.check_declar(ci);

      let n = checked.fetch_add(1, Ordering::Relaxed) + 1;
      // Deregister and redraw with the post-increment count.
      {
        let mut d = display.lock().unwrap();
        d.active.remove(&pretty);
        refresh(&mut d, n);
      }

      // Keep only failures: (declaration name, error).
      match result {
        Ok(()) => None,
        Err(e) => Some((name.clone(), e)),
      }
    })
    .collect()
}

// ============================================================================
// build_pi_over_locals
// ============================================================================

/// Abstract free variables back into a Pi telescope.
///
/// Given a `body` type (DAGPtr containing free Vars created by `mk_dag_local`)
/// and corresponding binder information, builds a Pi telescope at the DAG level.
///
/// Processes binders from innermost (last) to outermost (first). For each:
/// 1. Allocates a `Lam` with `bod = current_result`
/// 2. Calls `replace_child(free_var, lam.var)` to redirect all references
/// 3. Allocates `Pi(name, bi, dom, lam)` and wires parent pointers
pub fn build_pi_over_locals(
  body: DAGPtr,
  locals: &[DAGPtr],
  names: &[Name],
  bis: &[BinderInfo],
  doms: &[DAGPtr],
) -> DAGPtr {
  // NOTE(review): `locals`, `names`, `bis`, and `doms` are assumed to be
  // parallel slices of equal length — confirm at call sites.
  let mut result = body;
  // Process from innermost (last) to outermost (first)
  for i in (0..locals.len()).rev() {
    // 1. Allocate Lam wrapping the current result
    let lam = alloc_lam(0, result, None);
    // SAFETY: `lam` was just returned by `alloc_lam`, so it is a valid,
    // uniquely-referenced allocation; taking raw pointers into its fields
    // and mutating through them is sound here.
    unsafe {
      let lam_ref = &mut *lam.as_ptr();
      // Wire bod_ref as parent of result
      let bod_ref =
        NonNull::new(&mut lam_ref.bod_ref as *mut Parents).unwrap();
      add_to_parents(result, bod_ref);
      // 2. Redirect all references from the free var to the bound var
      let new_var = NonNull::new(&mut lam_ref.var as *mut Var).unwrap();
      replace_child(locals[i], DAGPtr::Var(new_var));
    }
    // 3. Allocate Pi
    let pi = alloc_pi(names[i].clone(), bis[i].clone(), doms[i], lam, None);
    // SAFETY: `pi` was just returned by `alloc_pi`; same reasoning as for
    // `lam` above.
    unsafe {
      let pi_ref = &mut *pi.as_ptr();
      // Wire dom_ref as parent of doms[i]
      let dom_ref =
        NonNull::new(&mut pi_ref.dom_ref as *mut Parents).unwrap();
      add_to_parents(doms[i], dom_ref);
      // Wire img_ref as parent of Lam
      let img_ref =
        NonNull::new(&mut pi_ref.img_ref as *mut Parents).unwrap();
      add_to_parents(DAGPtr::Lam(lam), img_ref);
    }
    result = DAGPtr::Pi(pi);
  }
  result
}

// ============================================================================
// Definitional equality helpers (free functions)
// ============================================================================

/// Result of lazy delta reduction at DAG level.
+enum DagDeltaResult { + Found(bool), + Exhausted(DAGPtr, DAGPtr), +} + +/// Get the name and reducibility hint of an applied definition. +fn dag_get_applied_def( + ptr: DAGPtr, + env: &Env, +) -> Option<(Name, ReducibilityHints)> { + let (head, _) = dag_unfold_apps(ptr); + let name = match head { + DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() }, + _ => return None, + }; + let ci = env.get(&name)?; + match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + None + } else { + Some((name, d.hints)) + } + }, + ConstantInfo::ThmInfo(_) => { + Some((name, ReducibilityHints::Opaque)) + }, + _ => None, + } +} + +/// Try to unfold a definition at DAG level. +fn dag_try_unfold_def(ptr: DAGPtr, env: &Env) -> Option { + let (head, args) = dag_unfold_apps(ptr); + let (name, levels) = match head { + DAGPtr::Cnst(c) => unsafe { + let cr = &*c.as_ptr(); + (cr.name.clone(), cr.levels.clone()) + }, + _ => return None, + }; + let ci = env.get(&name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + _ => return None, + }; + if levels.len() != def_params.len() { + return None; + } + let val = subst_expr_levels(def_value, def_params, &levels); + let val_dag = from_expr(&val); + Some(dag_foldl_apps(val_dag.head, &args)) +} + +/// Try nat/native reduction before delta. +fn try_lazy_delta_nat_native(ptr: DAGPtr, env: &Env) -> Option { + let (head, args) = dag_unfold_apps(ptr); + match head { + DAGPtr::Cnst(c) => unsafe { + let name = &(*c.as_ptr()).name; + if let Some(r) = try_reduce_native_dag(name, &args) { + return Some(r); + } + if let Some(r) = try_reduce_nat_dag(name, &args, env) { + return Some(r); + } + None + }, + _ => None, + } +} + +/// Check if a DAGPtr is Nat.zero (either constructor or literal 0). 
+fn is_nat_zero_dag(ptr: DAGPtr) -> bool { + unsafe { + match ptr { + DAGPtr::Cnst(c) => (*c.as_ptr()).name == mk_name2("Nat", "zero"), + DAGPtr::Lit(l) => { + matches!(&(*l.as_ptr()).val, Literal::NatVal(n) if n.0 == BigUint::ZERO) + }, + _ => false, + } + } +} + +/// If expression is `Nat.succ arg` or `lit (n+1)`, return the predecessor. +fn is_nat_succ_dag(ptr: DAGPtr) -> Option { + unsafe { + match ptr { + DAGPtr::App(app) => { + let a = &*app.as_ptr(); + match a.fun { + DAGPtr::Cnst(c) + if (*c.as_ptr()).name == mk_name2("Nat", "succ") => + { + Some(a.arg) + }, + _ => None, + } + }, + DAGPtr::Lit(l) => match &(*l.as_ptr()).val { + Literal::NatVal(n) if n.0 > BigUint::ZERO => { + Some(nat_lit_dag(Nat(n.0.clone() - BigUint::from(1u64)))) + }, + _ => None, + }, + _ => None, + } + } +} + +/// Check if a name refers to a structure-like inductive: +/// exactly 1 constructor, not recursive, no indices. +fn is_structure_like(name: &Name, env: &Env) -> bool { + match env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + iv.ctors.len() == 1 && !iv.is_rec && iv.num_indices == Nat::ZERO + }, + _ => false, + } +} + +/// Compare reducibility hints for ordering. 
+fn hint_lt(a: &ReducibilityHints, b: &ReducibilityHints) -> bool { + match (a, b) { + (ReducibilityHints::Opaque, _) => true, + (_, ReducibilityHints::Opaque) => false, + (ReducibilityHints::Abbrev, _) => false, + (_, ReducibilityHints::Abbrev) => true, + (ReducibilityHints::Regular(ha), ReducibilityHints::Regular(hb)) => { + ha < hb + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::env::{BinderInfo, Expr, Level, Literal}; + use crate::ix::kernel::convert::from_expr; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + // ======================================================================== + // subst_dag_levels tests + // ======================================================================== + + #[test] + fn subst_dag_levels_empty_params() { + let e = Expr::sort(Level::param(mk_name("u"))); + let dag = from_expr(&e); + let result = subst_dag_levels(dag.head, &[], &[]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + assert_eq!(result_expr, e); + } + + #[test] + fn subst_dag_levels_sort() { + let u_name = mk_name("u"); + let e = Expr::sort(Level::param(u_name.clone())); + let dag = from_expr(&e); + let result = subst_dag_levels(dag.head, &[u_name], &[Level::zero()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + assert_eq!(result_expr, Expr::sort(Level::zero())); + } + + #[test] + fn subst_dag_levels_cnst() { + let u_name = mk_name("u"); + let e = Expr::cnst(mk_name("List"), vec![Level::param(u_name.clone())]); + let dag = from_expr(&e); + let one = Level::succ(Level::zero()); + let result = subst_dag_levels(dag.head, &[u_name], &[one.clone()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + assert_eq!(result_expr, Expr::cnst(mk_name("List"), vec![one])); + } + + #[test] + fn subst_dag_levels_nested() { + // Pi (A : Sort u) → Sort 
u with u := 1 + let u_name = mk_name("u"); + let sort_u = Expr::sort(Level::param(u_name.clone())); + let e = Expr::all( + mk_name("A"), + sort_u.clone(), + sort_u, + BinderInfo::Default, + ); + let dag = from_expr(&e); + let one = Level::succ(Level::zero()); + let result = subst_dag_levels(dag.head, &[u_name], &[one.clone()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let sort_1 = Expr::sort(one); + let expected = Expr::all( + mk_name("A"), + sort_1.clone(), + sort_1, + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn subst_dag_levels_no_levels_unchanged() { + // Expression with no Sort or Cnst nodes — pure lambda + let e = Expr::lam( + mk_name("x"), + Expr::lit(Literal::NatVal(Nat::from(0u64))), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let u_name = mk_name("u"); + let result = + subst_dag_levels(dag.head, &[u_name], &[Level::zero()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + assert_eq!(result_expr, e); + } + + // ======================================================================== + // mk_dag_local tests + // ======================================================================== + + #[test] + fn mk_dag_local_creates_free_var() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let name = mk_name("x"); + let ty = from_expr(&nat_type()).head; + let local = tc.mk_dag_local(&name, ty); + match local { + DAGPtr::Var(p) => unsafe { + let var = &*p.as_ptr(); + assert!(matches!(var.binder, BinderPtr::Free)); + assert!(var.fvar_name.is_some()); + }, + _ => panic!("Expected Var"), + } + assert_eq!(tc.local_counter, 1); + assert_eq!(tc.local_types.len(), 1); + } + + #[test] + fn mk_dag_local_unique_names() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let name = mk_name("x"); + let ty = from_expr(&nat_type()).head; + let l1 = tc.mk_dag_local(&name, 
ty); + let ty2 = from_expr(&nat_type()).head; + let l2 = tc.mk_dag_local(&name, ty2); + // Different pointer identities + assert_ne!(dag_ptr_key(l1), dag_ptr_key(l2)); + // Different fvar names + unsafe { + let n1 = match l1 { + DAGPtr::Var(p) => (*p.as_ptr()).fvar_name.clone().unwrap(), + _ => panic!(), + }; + let n2 = match l2 { + DAGPtr::Var(p) => (*p.as_ptr()).fvar_name.clone().unwrap(), + _ => panic!(), + }; + assert_ne!(n1, n2); + } + } + + // ======================================================================== + // build_pi_over_locals tests + // ======================================================================== + + #[test] + fn build_pi_single_binder() { + // Build: Pi (x : Nat) → Nat + // body = Nat (doesn't reference x), locals = [x_free] + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let nat_dag = from_expr(&nat_type()).head; + let x_local = tc.mk_dag_local(&mk_name("x"), nat_dag); + // Body doesn't use x + let body = from_expr(&nat_type()).head; + let result = build_pi_over_locals( + body, + &[x_local], + &[mk_name("x")], + &[BinderInfo::Default], + &[nat_dag], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn build_pi_dependent() { + // Build: Pi (A : Sort 0) → A + // body = A_local (references A), locals = [A_local] + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort0 = from_expr(&Expr::sort(Level::zero())).head; + let a_local = tc.mk_dag_local(&mk_name("A"), sort0); + // Body IS the local variable + let result = build_pi_over_locals( + a_local, + &[a_local], + &[mk_name("A")], + &[BinderInfo::Default], + &[sort0], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + 
Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn build_pi_two_binders() { + // Build: Pi (A : Sort 0) (x : A) → A + // Should produce: ForallE A (Sort 0) (ForallE x (bvar 0) (bvar 1)) + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort0 = from_expr(&Expr::sort(Level::zero())).head; + let a_local = tc.mk_dag_local(&mk_name("A"), sort0); + let x_local = tc.mk_dag_local(&mk_name("x"), a_local); + // Body is a_local (the type A) + let result = build_pi_over_locals( + a_local, + &[a_local, x_local], + &[mk_name("A"), mk_name("x")], + &[BinderInfo::Default, BinderInfo::Default], + &[sort0, a_local], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + // ======================================================================== + // DagTypeChecker core method tests + // ======================================================================== + + #[test] + fn whnf_sort_is_identity() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let result = tc.whnf(ptr); + assert_eq!(dag_ptr_key(result), dag_ptr_key(ptr)); + } + + #[test] + fn whnf_caches_result() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.whnf(ptr); + let r2 = tc.whnf(ptr); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.whnf_cache.len(), 1); + } + + #[test] + fn whnf_no_delta_caches_result() { + let env = Env::default(); + let mut tc = 
DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.whnf_no_delta(ptr); + let r2 = tc.whnf_no_delta(ptr); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.whnf_no_delta_cache.len(), 1); + } + + #[test] + fn ensure_sort_on_sort() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let result = tc.ensure_sort(DAGPtr::Sort(sort)); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Level::zero()); + } + + #[test] + fn ensure_sort_on_non_sort() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let lit = alloc_val(LitNode { + val: Literal::NatVal(Nat::from(42u64)), + parents: None, + }); + let result = tc.ensure_sort(DAGPtr::Lit(lit)); + assert!(result.is_err()); + } + + #[test] + fn ensure_pi_on_pi() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let lam = alloc_lam(0, DAGPtr::Sort(sort), None); + let pi = alloc_pi( + mk_name("x"), + BinderInfo::Default, + DAGPtr::Sort(sort), + lam, + None, + ); + let result = tc.ensure_pi(DAGPtr::Pi(pi)); + assert!(result.is_ok()); + } + + #[test] + fn ensure_pi_on_non_pi() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let lit = alloc_val(LitNode { + val: Literal::NatVal(Nat::from(42u64)), + parents: None, + }); + let result = tc.ensure_pi(DAGPtr::Lit(lit)); + assert!(result.is_err()); + } + + #[test] + fn infer_sort_zero() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let result = tc.infer(DAGPtr::Sort(sort)).unwrap(); + match result { + DAGPtr::Sort(p) => unsafe { + assert_eq!((*p.as_ptr()).level, Level::succ(Level::zero())); + }, + _ => panic!("Expected Sort"), + } + } + + #[test] + fn 
infer_fvar() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let nat_dag = from_expr(&nat_type()).head; + let local = tc.mk_dag_local(&mk_name("x"), nat_dag); + let result = tc.infer(local).unwrap(); + assert_eq!(dag_ptr_key(result), dag_ptr_key(nat_dag)); + } + + #[test] + fn infer_caches_result() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.infer(ptr).unwrap(); + let r2 = tc.infer(ptr).unwrap(); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.infer_cache.len(), 1); + } + + #[test] + fn def_eq_pointer_identity() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + assert!(tc.def_eq(ptr, ptr)); + } + + #[test] + fn def_eq_sort_structural() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { level: Level::zero(), parents: None }); + // Same level, different pointers — structurally equal + assert!(tc.def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2))); + } + + #[test] + fn def_eq_sort_different_levels() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { + level: Level::succ(Level::zero()), + parents: None, + }); + assert!(!tc.def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2))); + } + + #[test] + fn assert_def_eq_ok() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + assert!(tc.assert_def_eq(ptr, ptr).is_ok()); + } + + #[test] + fn assert_def_eq_err() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let 
s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { + level: Level::succ(Level::zero()), + parents: None, + }); + assert!(tc.assert_def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2)).is_err()); + } + + // ======================================================================== + // Type inference tests (Step 3) + // ======================================================================== + + use crate::ix::env::{ + AxiomVal, ConstantVal, ConstructorVal, InductiveVal, + }; + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + fn prop() -> Expr { + Expr::sort(Level::zero()) + } + + /// Build a minimal environment with Nat, Nat.zero, Nat.succ. + fn mk_nat_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + let zero_name = mk_name2("Nat", "zero"); + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let succ_name = mk_name2("Nat", "succ"); + let succ_ty = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + 
level_params: vec![], + typ: succ_ty, + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + env + } + + /// Helper: infer the type of an Expr via the DAG TC, return as Expr. + fn dag_infer(env: &Env, e: &Expr) -> Result { + let mut tc = DagTypeChecker::new(env); + let dag = from_expr(e); + let result = tc.infer(dag.head)?; + Ok(dag_to_expr(result)) + } + + // -- Const inference -- + + #[test] + fn dag_infer_const_nat() { + let env = mk_nat_env(); + let ty = dag_infer(&env, &Expr::cnst(mk_name("Nat"), vec![])).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn dag_infer_const_nat_zero() { + let env = mk_nat_env(); + let ty = dag_infer(&env, &nat_zero()).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_const_nat_succ() { + let env = mk_nat_env(); + let ty = + dag_infer(&env, &Expr::cnst(mk_name2("Nat", "succ"), vec![])).unwrap(); + let expected = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn dag_infer_const_unknown() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::cnst(mk_name("Nope"), vec![])).is_err()); + } + + #[test] + fn dag_infer_const_universe_mismatch() { + let env = mk_nat_env(); + assert!( + dag_infer(&env, &Expr::cnst(mk_name("Nat"), vec![Level::zero()])) + .is_err() + ); + } + + // -- Lit inference -- + + #[test] + fn dag_infer_nat_lit() { + let env = Env::default(); + let ty = + dag_infer(&env, &Expr::lit(Literal::NatVal(Nat::from(42u64)))).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_string_lit() { + let env = Env::default(); + let ty = + dag_infer(&env, &Expr::lit(Literal::StrVal("hello".into()))).unwrap(); + assert_eq!(ty, Expr::cnst(mk_name("String"), vec![])); + } + + // -- App inference -- + + #[test] + fn dag_infer_app_succ_zero() { + // Nat.succ Nat.zero : Nat + let 
env = mk_nat_env(); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_app_identity() { + // (fun x : Nat => x) Nat.zero : Nat + let env = mk_nat_env(); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // -- Lambda inference -- + + #[test] + fn dag_infer_identity_lambda() { + // fun (x : Nat) => x : Nat → Nat + let env = mk_nat_env(); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &e).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn dag_infer_const_lambda() { + // fun (x : Nat) (y : Nat) => x : Nat → Nat → Nat + let env = mk_nat_env(); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &k_fn).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + // -- Pi inference -- + + #[test] + fn dag_infer_pi_nat_to_nat() { + // (Nat → Nat) : Sort 1 + let env = mk_nat_env(); + let pi = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &pi).unwrap(); + if let crate::ix::env::ExprData::Sort(level, _) = ty.as_data() { + assert!( + crate::ix::kernel::level::eq_antisymm( + level, + &Level::succ(Level::zero()) + ), + "Nat → Nat should live in Sort 1, got {:?}", + level + ); + } else { + 
panic!("Expected Sort, got {:?}", ty); + } + } + + #[test] + fn dag_infer_pi_prop_to_prop() { + // P → P : Prop (where P : Prop) + let mut env = Env::default(); + let p_name = mk_name("P"); + env.insert( + p_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: prop(), + }, + is_unsafe: false, + }), + ); + let p = Expr::cnst(p_name, vec![]); + let pi = + Expr::all(mk_name("x"), p.clone(), p.clone(), BinderInfo::Default); + let ty = dag_infer(&env, &pi).unwrap(); + if let crate::ix::env::ExprData::Sort(level, _) = ty.as_data() { + assert!( + crate::ix::kernel::level::is_zero(level), + "Prop → Prop should live in Prop, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + // -- Let inference -- + + #[test] + fn dag_infer_let_simple() { + // let x : Nat := Nat.zero in x : Nat + let env = mk_nat_env(); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // -- Error cases -- + + #[test] + fn dag_infer_free_bvar_fails() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::bvar(Nat::from(0u64))).is_err()); + } + + #[test] + fn dag_infer_fvar_unknown_fails() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::fvar(mk_name("x"))).is_err()); + } + + // ======================================================================== + // Definitional equality tests (Step 4) + // ======================================================================== + + use crate::ix::env::{ + DefinitionSafety, DefinitionVal, ReducibilityHints, TheoremVal, + }; + + /// Helper: check def_eq of two Expr via the DAG TC. 
+ fn dag_def_eq(env: &Env, x: &Expr, y: &Expr) -> bool { + let mut tc = DagTypeChecker::new(env); + let dx = from_expr(x); + let dy = from_expr(y); + tc.def_eq(dx.head, dy.head) + } + + // -- Reflexivity -- + + #[test] + fn dag_def_eq_reflexive_sort() { + let env = Env::default(); + let e = Expr::sort(Level::zero()); + assert!(dag_def_eq(&env, &e, &e)); + } + + #[test] + fn dag_def_eq_reflexive_const() { + let env = mk_nat_env(); + let e = nat_zero(); + assert!(dag_def_eq(&env, &e, &e)); + } + + // -- Sort equality -- + + #[test] + fn dag_def_eq_sort_max_comm() { + let env = Env::default(); + let u = Level::param(mk_name("u")); + let v = Level::param(mk_name("v")); + let s1 = Expr::sort(Level::max(u.clone(), v.clone())); + let s2 = Expr::sort(Level::max(v, u)); + assert!(dag_def_eq(&env, &s1, &s2)); + } + + #[test] + fn dag_def_eq_sort_not_equal() { + let env = Env::default(); + let s0 = Expr::sort(Level::zero()); + let s1 = Expr::sort(Level::succ(Level::zero())); + assert!(!dag_def_eq(&env, &s0, &s1)); + } + + // -- Alpha equivalence -- + + #[test] + fn dag_def_eq_alpha_lambda() { + let env = mk_nat_env(); + let e1 = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e2 = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &e1, &e2)); + } + + #[test] + fn dag_def_eq_alpha_pi() { + let env = mk_nat_env(); + let e1 = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e2 = Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &e1, &e2)); + } + + // -- Beta equivalence -- + + #[test] + fn dag_def_eq_beta() { + let env = mk_nat_env(); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let lhs = Expr::app(id_fn, nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + 
#[test] + fn dag_def_eq_beta_nested() { + let env = mk_nat_env(); + let inner = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + inner, + BinderInfo::Default, + ); + let lhs = Expr::app(Expr::app(k_fn, nat_zero()), nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Delta equivalence -- + + #[test] + fn dag_def_eq_delta() { + let mut env = mk_nat_env(); + let my_zero = mk_name("myZero"); + env.insert( + my_zero.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_zero.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_zero.clone()], + }), + ); + let lhs = Expr::cnst(my_zero, vec![]); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + #[test] + fn dag_def_eq_delta_both_sides() { + let mut env = mk_nat_env(); + for name_str in &["a", "b"] { + let n = mk_name(name_str); + env.insert( + n.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: n.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![n], + }), + ); + } + let a = Expr::cnst(mk_name("a"), vec![]); + let b = Expr::cnst(mk_name("b"), vec![]); + assert!(dag_def_eq(&env, &a, &b)); + } + + // -- Zeta equivalence -- + + #[test] + fn dag_def_eq_zeta() { + let env = mk_nat_env(); + let lhs = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Negative tests -- + + #[test] + fn dag_def_eq_different_consts() { + let env = Env::default(); + let nat = nat_type(); + let string = Expr::cnst(mk_name("String"), vec![]); + assert!(!dag_def_eq(&env, &nat, &string)); + } + + // -- App congruence -- + + #[test] + fn 
dag_def_eq_app_congruence() { + let env = mk_nat_env(); + let f = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let a = nat_zero(); + let lhs = Expr::app(f.clone(), a.clone()); + let rhs = Expr::app(f, a); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_app_different_args() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let lhs = Expr::app(succ.clone(), nat_zero()); + let rhs = Expr::app(succ.clone(), Expr::app(succ, nat_zero())); + assert!(!dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Eta expansion -- + + #[test] + fn dag_def_eq_eta_lam_vs_const() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &eta_expanded, &succ)); + } + + #[test] + fn dag_def_eq_eta_symmetric() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &succ, &eta_expanded)); + } + + // -- Binder full comparison -- + + #[test] + fn dag_def_eq_binder_full_different_domains() { + // (x : myNat) → Nat =def= (x : Nat) → Nat + let mut env = mk_nat_env(); + let my_nat = mk_name("myNat"); + env.insert( + my_nat.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_nat.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + value: nat_type(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_nat.clone()], + }), + ); + let lhs = Expr::all( + mk_name("x"), + Expr::cnst(my_nat, vec![]), + nat_type(), + BinderInfo::Default, + ); + let rhs = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + 
assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_binder_dependent() { + // Pi (A : Sort 0) (x : A) → A =def= Pi (B : Sort 0) (y : B) → B + let env = Env::default(); + let lhs = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let rhs = Expr::all( + mk_name("B"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("y"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Nat offset equality -- + + #[test] + fn dag_def_eq_nat_zero_ctor_vs_lit() { + let env = mk_nat_env(); + let lit0 = Expr::lit(Literal::NatVal(Nat::from(0u64))); + assert!(dag_def_eq(&env, &nat_zero(), &lit0)); + } + + #[test] + fn dag_def_eq_nat_lit_vs_succ_lit() { + let env = mk_nat_env(); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::lit(Literal::NatVal(Nat::from(4u64))), + ); + let lit5 = Expr::lit(Literal::NatVal(Nat::from(5u64))); + assert!(dag_def_eq(&env, &lit5, &succ_4)); + } + + #[test] + fn dag_def_eq_nat_lit_not_equal() { + let env = Env::default(); + let a = Expr::lit(Literal::NatVal(Nat::from(1u64))); + let b = Expr::lit(Literal::NatVal(Nat::from(2u64))); + assert!(!dag_def_eq(&env, &a, &b)); + } + + // -- Lazy delta with hints -- + + #[test] + fn dag_def_eq_lazy_delta_higher_unfolds_first() { + let mut env = mk_nat_env(); + let a = mk_name("a"); + let b = mk_name("b"); + env.insert( + a.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: a.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Regular(1), + safety: DefinitionSafety::Safe, + all: vec![a.clone()], + }), + ); + env.insert( + b.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: b.clone(), + 
level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(a, vec![]), + hints: ReducibilityHints::Regular(2), + safety: DefinitionSafety::Safe, + all: vec![b.clone()], + }), + ); + let lhs = Expr::cnst(b, vec![]); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Proof irrelevance -- + + #[test] + fn dag_def_eq_proof_irrel() { + let mut env = mk_nat_env(); + let true_name = mk_name("True"); + let intro_name = mk_name2("True", "intro"); + env.insert( + true_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: true_name.clone(), + level_params: vec![], + typ: prop(), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![true_name.clone()], + ctors: vec![intro_name.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + intro_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: intro_name.clone(), + level_params: vec![], + typ: Expr::cnst(true_name.clone(), vec![]), + }, + induct: true_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let true_ty = Expr::cnst(true_name, vec![]); + let thm_a = mk_name("thmA"); + let thm_b = mk_name("thmB"); + env.insert( + thm_a.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_a.clone(), + level_params: vec![], + typ: true_ty.clone(), + }, + value: Expr::cnst(intro_name.clone(), vec![]), + all: vec![thm_a.clone()], + }), + ); + env.insert( + thm_b.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_b.clone(), + level_params: vec![], + typ: true_ty, + }, + value: Expr::cnst(intro_name, vec![]), + all: vec![thm_b.clone()], + }), + ); + let a = Expr::cnst(thm_a, vec![]); + let b = Expr::cnst(thm_b, vec![]); + assert!(dag_def_eq(&env, &a, &b)); + } + + // -- Proj congruence -- + + #[test] + fn 
dag_def_eq_proj_congruence() { + let env = Env::default(); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(0u64), s); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_proj_different_idx() { + let env = Env::default(); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(1u64), s); + assert!(!dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Beta-delta combined -- + + #[test] + fn dag_def_eq_beta_delta_combined() { + let mut env = mk_nat_env(); + let my_id = mk_name("myId"); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + env.insert( + my_id.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_id.clone(), + level_params: vec![], + typ: fun_ty, + }, + value: Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_id.clone()], + }), + ); + let lhs = Expr::app(Expr::cnst(my_id, vec![]), nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Unit-like equality -- + + #[test] + fn dag_def_eq_unit_like() { + let mut env = mk_nat_env(); + let unit_name = mk_name("Unit"); + let unit_star = mk_name2("Unit", "star"); + env.insert( + unit_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: unit_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![unit_name.clone()], + ctors: vec![unit_star.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + unit_star.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: unit_star.clone(), 
+ level_params: vec![], + typ: Expr::cnst(unit_name.clone(), vec![]), + }, + induct: unit_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + // Two distinct fvars of type Unit should be def-eq + let unit_ty = Expr::cnst(unit_name, vec![]); + let mut tc = DagTypeChecker::new(&env); + let x_ty = from_expr(&unit_ty).head; + let x = tc.mk_dag_local(&mk_name("x"), x_ty); + let y_ty = from_expr(&unit_ty).head; + let y = tc.mk_dag_local(&mk_name("y"), y_ty); + assert!(tc.def_eq(x, y)); + } + + // -- Nat add through def_eq -- + + #[test] + fn dag_def_eq_nat_add_result_vs_lit() { + let env = mk_nat_env(); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + Expr::lit(Literal::NatVal(Nat::from(3u64))), + ), + Expr::lit(Literal::NatVal(Nat::from(4u64))), + ); + let lit7 = Expr::lit(Literal::NatVal(Nat::from(7u64))); + assert!(dag_def_eq(&env, &add_3_4, &lit7)); + } +} diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs index c2110381..ada12904 100644 --- a/src/ix/kernel/def_eq.rs +++ b/src/ix/kernel/def_eq.rs @@ -1,5 +1,6 @@ use crate::ix::env::*; use crate::lean::nat::Nat; +use num_bigint::BigUint; use super::level::{eq_antisymm, eq_antisymm_many}; use super::tc::TypeChecker; @@ -12,13 +13,40 @@ enum DeltaResult { } /// Check definitional equality of two expressions. +/// +/// Uses a conjunction work stack: processes pairs iteratively, all must be equal. pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + const DEF_EQ_STEP_LIMIT: u64 = 1_000_000; + let mut work: Vec<(Expr, Expr)> = vec![(x.clone(), y.clone())]; + let mut steps: u64 = 0; + + while let Some((x, y)) = work.pop() { + steps += 1; + if steps > DEF_EQ_STEP_LIMIT { + eprintln!("[def_eq] step limit exceeded ({steps} steps)"); + return false; + } + if !def_eq_step(&x, &y, &mut work, tc) { + return false; + } + } + true +} + +/// Process one def_eq pair. 
Returns false if definitely not equal. +/// May push additional pairs onto `work` that must all be equal. +fn def_eq_step( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, + tc: &mut TypeChecker, +) -> bool { if let Some(quick) = def_eq_quick_check(x, y) { return quick; } - let x_n = tc.whnf(x); - let y_n = tc.whnf(y); + let x_n = tc.whnf_no_delta(x); + let y_n = tc.whnf_no_delta(y); if let Some(quick) = def_eq_quick_check(&x_n, &y_n) { return quick; @@ -32,9 +60,9 @@ pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { DeltaResult::Found(result) => result, DeltaResult::Exhausted(x_e, y_e) => { def_eq_const(&x_e, &y_e) - || def_eq_proj(&x_e, &y_e, tc) - || def_eq_app(&x_e, &y_e, tc) - || def_eq_binder_full(&x_e, &y_e, tc) + || def_eq_proj_push(&x_e, &y_e, work) + || def_eq_app_push(&x_e, &y_e, work) + || def_eq_binder_full_push(&x_e, &y_e, work) || try_eta_expansion(&x_e, &y_e, tc) || try_eta_struct(&x_e, &y_e, tc) || is_def_eq_unit_like(&x_e, &y_e, tc) @@ -82,16 +110,50 @@ fn def_eq_const(x: &Expr, y: &Expr) -> bool { } } -fn def_eq_proj(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { +/// Proj congruence: push structure pair onto work stack. +fn def_eq_proj_push( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, +) -> bool { match (x.as_data(), y.as_data()) { ( ExprData::Proj(_, idx_l, structure_l, _), ExprData::Proj(_, idx_r, structure_r, _), - ) => idx_l == idx_r && def_eq(structure_l, structure_r, tc), + ) if idx_l == idx_r => { + work.push((structure_l.clone(), structure_r.clone())); + true + }, _ => false, } } +/// App congruence: push head + arg pairs onto work stack. 
+fn def_eq_app_push( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, +) -> bool { + let (f1, args1) = unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + + work.push((f1, f2)); + for (a, b) in args1.into_iter().zip(args2.into_iter()) { + work.push((a, b)); + } + true +} + +/// Eager app congruence (used by lazy_delta_step where we need a definitive answer). fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { let (f1, args1) = unfold_apps(x); if args1.is_empty() { @@ -111,24 +173,47 @@ fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { args1.iter().zip(args2.iter()).all(|(a, b)| def_eq(a, b, tc)) } -/// Full recursive binder comparison: two Pi or two Lam types with -/// definitionally equal domain types and bodies (ignoring binder names). -fn def_eq_binder_full( +/// Iterative binder comparison: peel matching Pi/Lam layers, pushing +/// domain pairs and the final body pair onto the work stack. 
+fn def_eq_binder_full_push( x: &Expr, y: &Expr, - tc: &mut TypeChecker, + work: &mut Vec<(Expr, Expr)>, ) -> bool { - match (x.as_data(), y.as_data()) { - ( - ExprData::ForallE(_, t1, b1, _, _), - ExprData::ForallE(_, t2, b2, _, _), - ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), - ( - ExprData::Lam(_, t1, b1, _, _), - ExprData::Lam(_, t2, b2, _, _), - ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), - _ => false, + let mut cx = x.clone(); + let mut cy = y.clone(); + let mut matched = false; + + loop { + match (cx.as_data(), cy.as_data()) { + ( + ExprData::ForallE(_, t1, b1, _, _), + ExprData::ForallE(_, t2, b2, _, _), + ) => { + work.push((t1.clone(), t2.clone())); + cx = b1.clone(); + cy = b2.clone(); + matched = true; + }, + ( + ExprData::Lam(_, t1, b1, _, _), + ExprData::Lam(_, t2, b2, _, _), + ) => { + work.push((t1.clone(), t2.clone())); + cx = b1.clone(); + cy = b2.clone(); + matched = true; + }, + _ => break, + } + } + + if !matched { + return false; } + // Push the final body pair + work.push((cx, cy)); + true } /// Proof irrelevance: if both x and y are proofs of the same proposition, @@ -293,6 +378,66 @@ fn is_def_eq_unit_like(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { } } +/// Check if expression is Nat zero (either `Nat.zero` or `lit 0`). +/// Matches Lean 4's `is_nat_zero`. +fn is_nat_zero(e: &Expr) -> bool { + match e.as_data() { + ExprData::Const(name, _, _) => *name == mk_name2("Nat", "zero"), + ExprData::Lit(Literal::NatVal(n), _) => n.0 == BigUint::ZERO, + _ => false, + } +} + +/// If expression is `Nat.succ arg` or `lit (n+1)`, return the predecessor. +/// Matches Lean 4's `is_nat_succ` / lean4lean's `isNatSuccOf?`. 
+fn is_nat_succ(e: &Expr) -> Option { + match e.as_data() { + ExprData::App(f, arg, _) => match f.as_data() { + ExprData::Const(name, _, _) if *name == mk_name2("Nat", "succ") => { + Some(arg.clone()) + }, + _ => None, + }, + ExprData::Lit(Literal::NatVal(n), _) if n.0 > BigUint::ZERO => { + Some(Expr::lit(Literal::NatVal(Nat( + n.0.clone() - BigUint::from(1u64), + )))) + }, + _ => None, + } +} + +/// Nat offset equality: `Nat.zero =?= Nat.zero` → true, +/// `Nat.succ n =?= Nat.succ m` → `n =?= m` (recursively via def_eq). +/// Also handles nat literals: `lit 5 =?= Nat.succ (lit 4)` → true. +/// Matches Lean 4's `is_def_eq_offset`. +fn def_eq_nat_offset(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> Option { + if is_nat_zero(x) && is_nat_zero(y) { + return Some(true); + } + match (is_nat_succ(x), is_nat_succ(y)) { + (Some(x_pred), Some(y_pred)) => Some(def_eq(&x_pred, &y_pred, tc)), + _ => None, + } +} + +/// Try to reduce via nat operations or native reductions, returning the reduced form if successful. +fn try_lazy_delta_nat_native(e: &Expr, env: &Env) -> Option { + let (head, args) = unfold_apps(e); + match head.as_data() { + ExprData::Const(name, _, _) => { + if let Some(r) = try_reduce_native(name, &args) { + return Some(r); + } + if let Some(r) = try_reduce_nat(e, env) { + return Some(r); + } + None + }, + _ => None, + } +} + /// Lazy delta reduction: unfold definitions step by step. 
fn lazy_delta_step( x: &Expr, @@ -301,8 +446,38 @@ fn lazy_delta_step( ) -> DeltaResult { let mut x = x.clone(); let mut y = y.clone(); + let mut iters: u32 = 0; + const MAX_DELTA_ITERS: u32 = 10_000; loop { + iters += 1; + if iters > MAX_DELTA_ITERS { + return DeltaResult::Exhausted(x, y); + } + + // Nat offset comparison (Lean 4: isDefEqOffset) + if let Some(quick) = def_eq_nat_offset(&x, &y, tc) { + return DeltaResult::Found(quick); + } + + // Try nat/native reduction on each side before delta + if let Some(x_r) = try_lazy_delta_nat_native(&x, tc.env) { + let x_r = tc.whnf_no_delta(&x_r); + if let Some(quick) = def_eq_quick_check(&x_r, &y) { + return DeltaResult::Found(quick); + } + x = x_r; + continue; + } + if let Some(y_r) = try_lazy_delta_nat_native(&y, tc.env) { + let y_r = tc.whnf_no_delta(&y_r); + if let Some(quick) = def_eq_quick_check(&x, &y_r) { + return DeltaResult::Found(quick); + } + y = y_r; + continue; + } + let x_def = get_applied_def(&x, tc.env); let y_def = get_applied_def(&y, tc.env); @@ -362,10 +537,11 @@ fn get_applied_def( } } -/// Unfold a definition and do cheap WHNF. +/// Unfold a definition and do cheap WHNF (no delta). +/// Matches lean4lean: `let delta e := whnfCore (unfoldDefinition env e).get!`. fn delta(e: &Expr, tc: &mut TypeChecker) -> Expr { match try_unfold_def(e, tc.env) { - Some(unfolded) => tc.whnf(&unfolded), + Some(unfolded) => tc.whnf_no_delta(&unfolded), None => e.clone(), } } @@ -1295,4 +1471,262 @@ mod tests { let y = tc.mk_local(&mk_name("y"), &unit_ty); assert!(tc.def_eq(&x, &y)); } + + // ========================================================================== + // ThmInfo fix: theorems must not enter lazy_delta_step + // ========================================================================== + + /// Build an env with Nat + two ThmInfo constants. 
+ fn mk_thm_env() -> Env { + let mut env = mk_nat_env(); + let thm_a = mk_name("thmA"); + let thm_b = mk_name("thmB"); + let prop = Expr::sort(Level::zero()); + // Two theorems with the same type (True : Prop) + let true_name = mk_name("True"); + env.insert( + true_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: true_name.clone(), + level_params: vec![], + typ: prop.clone(), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![true_name.clone()], + ctors: vec![mk_name2("True", "intro")], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let intro_name = mk_name2("True", "intro"); + env.insert( + intro_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: intro_name.clone(), + level_params: vec![], + typ: Expr::cnst(true_name.clone(), vec![]), + }, + induct: true_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let true_ty = Expr::cnst(true_name, vec![]); + env.insert( + thm_a.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_a.clone(), + level_params: vec![], + typ: true_ty.clone(), + }, + value: Expr::cnst(intro_name.clone(), vec![]), + all: vec![thm_a.clone()], + }), + ); + env.insert( + thm_b.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_b.clone(), + level_params: vec![], + typ: true_ty, + }, + value: Expr::cnst(intro_name, vec![]), + all: vec![thm_b.clone()], + }), + ); + env + } + + #[test] + fn test_def_eq_theorem_vs_theorem_terminates() { + // Two theorem constants of the same Prop type should be def-eq + // via proof irrelevance (not via delta). Before the fix, this + // would infinite loop because get_applied_def returned Some for ThmInfo. 
+ let env = mk_thm_env(); + let mut tc = TypeChecker::new(&env); + let a = Expr::cnst(mk_name("thmA"), vec![]); + let b = Expr::cnst(mk_name("thmB"), vec![]); + assert!(tc.def_eq(&a, &b)); + } + + #[test] + fn test_def_eq_theorem_vs_constructor_terminates() { + // A theorem constant vs a constructor of the same type must terminate. + let env = mk_thm_env(); + let mut tc = TypeChecker::new(&env); + let thm = Expr::cnst(mk_name("thmA"), vec![]); + let ctor = Expr::cnst(mk_name2("True", "intro"), vec![]); + // Both have type True (a Prop), so proof irrelevance should make them def-eq + assert!(tc.def_eq(&thm, &ctor)); + } + + #[test] + fn test_get_applied_def_includes_theorems_as_opaque() { + let env = mk_thm_env(); + let thm = Expr::cnst(mk_name("thmA"), vec![]); + let result = get_applied_def(&thm, &env); + assert!(result.is_some()); + let (_, hints) = result.unwrap(); + assert_eq!(hints, ReducibilityHints::Opaque); + } + + // ========================================================================== + // Nat offset equality (is_nat_zero, is_nat_succ, def_eq_nat_offset) + // ========================================================================== + + fn nat_lit(n: u64) -> Expr { + Expr::lit(Literal::NatVal(Nat::from(n))) + } + + #[test] + fn test_is_nat_zero_ctor() { + assert!(super::is_nat_zero(&nat_zero())); + } + + #[test] + fn test_is_nat_zero_lit() { + assert!(super::is_nat_zero(&nat_lit(0))); + } + + #[test] + fn test_is_nat_zero_nonzero_lit() { + assert!(!super::is_nat_zero(&nat_lit(5))); + } + + #[test] + fn test_is_nat_succ_ctor() { + let succ_zero = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + let pred = super::is_nat_succ(&succ_zero); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(4)); + } + + #[test] + fn test_is_nat_succ_lit() { + // lit 5 should decompose to lit 4 (Lean 4: isNatSuccOf?) 
+ let pred = super::is_nat_succ(&nat_lit(5)); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(4)); + } + + #[test] + fn test_is_nat_succ_lit_one() { + // lit 1 should decompose to lit 0 + let pred = super::is_nat_succ(&nat_lit(1)); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(0)); + } + + #[test] + fn test_is_nat_succ_lit_zero() { + // lit 0 should NOT decompose (it's zero, not succ of anything) + assert!(super::is_nat_succ(&nat_lit(0)).is_none()); + } + + #[test] + fn test_is_nat_succ_nat_zero_ctor() { + assert!(super::is_nat_succ(&nat_zero()).is_none()); + } + + #[test] + fn def_eq_nat_zero_ctor_vs_lit() { + // Nat.zero =def= lit 0 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + assert!(tc.def_eq(&nat_zero(), &nat_lit(0))); + } + + #[test] + fn def_eq_nat_lit_vs_succ_lit() { + // lit 5 =def= Nat.succ (lit 4) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + assert!(tc.def_eq(&nat_lit(5), &succ_4)); + } + + #[test] + fn def_eq_nat_succ_lit_vs_lit() { + // Nat.succ (lit 4) =def= lit 5 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + assert!(tc.def_eq(&succ_4, &nat_lit(5))); + } + + #[test] + fn def_eq_nat_lit_one_vs_succ_zero() { + // lit 1 =def= Nat.succ Nat.zero + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_zero = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + assert!(tc.def_eq(&nat_lit(1), &succ_zero)); + } + + #[test] + fn def_eq_nat_lit_not_equal_succ() { + // lit 5 ≠ Nat.succ (lit 5) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_5 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(5), + ); + assert!(!tc.def_eq(&nat_lit(5), &succ_5)); + } + + #[test] + fn def_eq_nat_add_result_vs_lit() 
{ + // Nat.add (lit 3) (lit 4) =def= lit 7 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_lit(3), + ), + nat_lit(4), + ); + assert!(tc.def_eq(&add_3_4, &nat_lit(7))); + } + + #[test] + fn def_eq_nat_add_vs_succ() { + // Nat.add (lit 3) (lit 4) =def= Nat.succ (lit 6) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_lit(3), + ), + nat_lit(4), + ); + let succ_6 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(6), + ); + assert!(tc.def_eq(&add_3_4, &succ_6)); + } } diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs index a06ed819..4cf79d45 100644 --- a/src/ix/kernel/inductive.rs +++ b/src/ix/kernel/inductive.rs @@ -157,23 +157,33 @@ pub fn validate_k_flag( /// Check if an expression mentions a constant by name. fn expr_mentions_const(e: &Expr, name: &Name) -> bool { - match e.as_data() { - ExprData::Const(n, _, _) => n == name, - ExprData::App(f, a, _) => { - expr_mentions_const(f, name) || expr_mentions_const(a, name) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - expr_mentions_const(t, name) || expr_mentions_const(b, name) - }, - ExprData::LetE(_, t, v, b, _, _) => { - expr_mentions_const(t, name) - || expr_mentions_const(v, name) - || expr_mentions_const(b, name) - }, - ExprData::Proj(_, _, s, _) => expr_mentions_const(s, name), - ExprData::Mdata(_, inner, _) => expr_mentions_const(inner, name), - _ => false, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Const(n, _, _) => { + if n == name { + return true; + } + }, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, 
_) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + _ => {}, + } } + false } /// Check that no inductive name from `ind.all` appears in a negative position @@ -228,44 +238,49 @@ fn check_strict_positivity( ind_names: &[Name], tc: &mut TypeChecker, ) -> TcResult<()> { - let whnf_ty = tc.whnf(ty); - - // If no inductive name is mentioned, we're fine - if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) { - return Ok(()); - } - - match whnf_ty.as_data() { - ExprData::ForallE(_, domain, body, _, _) => { - // Domain must NOT mention any inductive name - for ind_name in ind_names { - if expr_mentions_const(domain, ind_name) { - return Err(TcError::KernelException { - msg: format!( - "inductive {} occurs in negative position (strict positivity violation)", - ind_name.pretty() - ), - }); + let mut current = ty.clone(); + loop { + let whnf_ty = tc.whnf(¤t); + + // If no inductive name is mentioned, we're fine + if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) { + return Ok(()); + } + + match whnf_ty.as_data() { + ExprData::ForallE(_, domain, body, _, _) => { + // Domain must NOT mention any inductive name + for ind_name in ind_names { + if expr_mentions_const(domain, ind_name) { + return Err(TcError::KernelException { + msg: format!( + "inductive {} occurs in negative position (strict positivity violation)", + ind_name.pretty() + ), + }); + } } - } - // Recurse into body - check_strict_positivity(body, ind_names, tc) - }, - _ => { - // The inductive is mentioned and we're not in a Pi — check if - // it's simply an application `I args...` (which is OK). 
- let (head, _) = unfold_apps(&whnf_ty); - match head.as_data() { - ExprData::Const(name, _, _) - if ind_names.iter().any(|n| n == name) => - { - Ok(()) - }, - _ => Err(TcError::KernelException { - msg: "inductive type occurs in a non-positive position".into(), - }), - } - }, + // Continue with body (was tail-recursive) + current = body.clone(); + }, + _ => { + // The inductive is mentioned and we're not in a Pi — check if + // it's simply an application `I args...` (which is OK). + let (head, _) = unfold_apps(&whnf_ty); + match head.as_data() { + ExprData::Const(name, _, _) + if ind_names.iter().any(|n| n == name) => + { + return Ok(()); + }, + _ => { + return Err(TcError::KernelException { + msg: "inductive type occurs in a non-positive position".into(), + }); + }, + } + }, + } } } diff --git a/src/ix/kernel/level.rs b/src/ix/kernel/level.rs index 90931ca6..80195e35 100644 --- a/src/ix/kernel/level.rs +++ b/src/ix/kernel/level.rs @@ -245,31 +245,41 @@ pub fn all_uparams_defined(level: &Level, params: &[Name]) -> bool { /// Check that all universe parameters in an expression are contained in `params`. /// Recursively walks the Expr, checking all Levels in Sort and Const nodes. 
pub fn all_expr_uparams_defined(e: &Expr, params: &[Name]) -> bool { - match e.as_data() { - ExprData::Sort(level, _) => all_uparams_defined(level, params), - ExprData::Const(_, levels, _) => { - levels.iter().all(|l| all_uparams_defined(l, params)) - }, - ExprData::App(f, a, _) => { - all_expr_uparams_defined(f, params) - && all_expr_uparams_defined(a, params) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - all_expr_uparams_defined(t, params) - && all_expr_uparams_defined(b, params) - }, - ExprData::LetE(_, t, v, b, _, _) => { - all_expr_uparams_defined(t, params) - && all_expr_uparams_defined(v, params) - && all_expr_uparams_defined(b, params) - }, - ExprData::Proj(_, _, s, _) => all_expr_uparams_defined(s, params), - ExprData::Mdata(_, inner, _) => all_expr_uparams_defined(inner, params), - ExprData::Bvar(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => true, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Sort(level, _) => { + if !all_uparams_defined(level, params) { + return false; + } + }, + ExprData::Const(_, levels, _) => { + if !levels.iter().all(|l| all_uparams_defined(l, params)) { + return false; + } + }, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => {}, + } } + true } /// Check that a list of levels are all Params with no duplicates. 
diff --git a/src/ix/kernel/mod.rs b/src/ix/kernel/mod.rs index d6a5750e..23aea4f6 100644 --- a/src/ix/kernel/mod.rs +++ b/src/ix/kernel/mod.rs @@ -1,5 +1,6 @@ pub mod convert; pub mod dag; +pub mod dag_tc; pub mod def_eq; pub mod dll; pub mod error; diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs index e80416fd..604fbf02 100644 --- a/src/ix/kernel/tc.rs +++ b/src/ix/kernel/tc.rs @@ -1,5 +1,6 @@ use crate::ix::env::*; use crate::lean::nat::Nat; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rustc_hash::FxHashMap; use super::def_eq::def_eq; @@ -13,9 +14,13 @@ type TcResult = Result; pub struct TypeChecker<'env> { pub env: &'env Env, pub whnf_cache: FxHashMap, + pub whnf_no_delta_cache: FxHashMap, pub infer_cache: FxHashMap, pub local_counter: u64, pub local_types: FxHashMap, + pub def_eq_calls: u64, + pub whnf_calls: u64, + pub infer_calls: u64, } impl<'env> TypeChecker<'env> { @@ -23,9 +28,13 @@ impl<'env> TypeChecker<'env> { TypeChecker { env, whnf_cache: FxHashMap::default(), + whnf_no_delta_cache: FxHashMap::default(), infer_cache: FxHashMap::default(), local_counter: 0, local_types: FxHashMap::default(), + def_eq_calls: 0, + whnf_calls: 0, + infer_calls: 0, } } @@ -37,8 +46,33 @@ impl<'env> TypeChecker<'env> { if let Some(cached) = self.whnf_cache.get(e) { return cached.clone(); } + self.whnf_calls += 1; + let tag = match e.as_data() { + ExprData::Sort(..) => "Sort", + ExprData::Const(_, _, _) => "Const", + ExprData::App(..) => "App", + ExprData::Lam(..) => "Lam", + ExprData::ForallE(..) => "Pi", + ExprData::LetE(..) => "Let", + ExprData::Lit(..) => "Lit", + ExprData::Proj(..) => "Proj", + ExprData::Fvar(..) => "Fvar", + ExprData::Bvar(..) => "Bvar", + ExprData::Mvar(..) => "Mvar", + ExprData::Mdata(..) 
=> "Mdata", + }; + eprintln!("[tc.whnf] #{} {tag}", self.whnf_calls); let result = whnf(e, self.env); - self.whnf_cache.insert(e.clone(), result.clone()); + eprintln!("[tc.whnf] #{} {tag} done", self.whnf_calls); + result + } + + pub fn whnf_no_delta(&mut self, e: &Expr) -> Expr { + if let Some(cached) = self.whnf_no_delta_cache.get(e) { + return cached.clone(); + } + let result = whnf_no_delta(e, self.env); + self.whnf_no_delta_cache.insert(e.clone(), result.clone()); result } @@ -102,40 +136,87 @@ impl<'env> TypeChecker<'env> { if let Some(cached) = self.infer_cache.get(e) { return Ok(cached.clone()); } + self.infer_calls += 1; + let tag = match e.as_data() { + ExprData::Sort(..) => "Sort".to_string(), + ExprData::Const(n, _, _) => format!("Const({})", n.pretty()), + ExprData::App(..) => "App".to_string(), + ExprData::Lam(..) => "Lam".to_string(), + ExprData::ForallE(..) => "Pi".to_string(), + ExprData::LetE(..) => "Let".to_string(), + ExprData::Lit(..) => "Lit".to_string(), + ExprData::Proj(..) => "Proj".to_string(), + ExprData::Fvar(n, _) => format!("Fvar({})", n.pretty()), + ExprData::Bvar(..) => "Bvar".to_string(), + ExprData::Mvar(..) => "Mvar".to_string(), + ExprData::Mdata(..) => "Mdata".to_string(), + }; + eprintln!("[tc.infer] #{} {tag}", self.infer_calls); let result = self.infer_core(e)?; self.infer_cache.insert(e.clone(), result.clone()); Ok(result) } fn infer_core(&mut self, e: &Expr) -> TcResult { - match e.as_data() { - ExprData::Sort(level, _) => self.infer_sort(level), - ExprData::Const(name, levels, _) => self.infer_const(name, levels), - ExprData::App(..) => self.infer_app(e), - ExprData::Lam(..) => self.infer_lambda(e), - ExprData::ForallE(..) 
=> self.infer_pi(e), - ExprData::LetE(_, typ, val, body, _, _) => { - self.infer_let(typ, val, body) - }, - ExprData::Lit(lit, _) => self.infer_lit(lit), - ExprData::Proj(type_name, idx, structure, _) => { - self.infer_proj(type_name, idx, structure) - }, - ExprData::Mdata(_, inner, _) => self.infer(inner), - ExprData::Fvar(name, _) => { - match self.local_types.get(name) { - Some(ty) => Ok(ty.clone()), - None => Err(TcError::KernelException { - msg: "cannot infer type of free variable without context".into(), - }), - } - }, - ExprData::Bvar(idx, _) => Err(TcError::FreeBoundVariable { - idx: idx.to_u64().unwrap_or(u64::MAX), - }), - ExprData::Mvar(..) => Err(TcError::KernelException { - msg: "cannot infer type of metavariable".into(), - }), + // Peel Mdata and Let layers iteratively to avoid stack depth + let mut cursor = e.clone(); + loop { + match cursor.as_data() { + ExprData::Mdata(_, inner, _) => { + // Check cache for inner before recursing + if let Some(cached) = self.infer_cache.get(inner) { + return Ok(cached.clone()); + } + cursor = inner.clone(); + continue; + }, + ExprData::LetE(_, typ, val, body, _, _) => { + let val_ty = self.infer(val)?; + self.assert_def_eq(&val_ty, typ)?; + let body_inst = inst(body, &[val.clone()]); + // Check cache for body_inst before looping + if let Some(cached) = self.infer_cache.get(&body_inst) { + return Ok(cached.clone()); + } + // Cache the current let expression's result once we compute it + let orig = cursor.clone(); + cursor = body_inst; + // We need to compute the result and cache it for `orig` + let result = self.infer(&cursor)?; + self.infer_cache.insert(orig, result.clone()); + return Ok(result); + }, + ExprData::Sort(level, _) => return self.infer_sort(level), + ExprData::Const(name, levels, _) => { + return self.infer_const(name, levels) + }, + ExprData::App(..) => return self.infer_app(&cursor), + ExprData::Lam(..) => return self.infer_lambda(&cursor), + ExprData::ForallE(..) 
=> return self.infer_pi(&cursor), + ExprData::Lit(lit, _) => return self.infer_lit(lit), + ExprData::Proj(type_name, idx, structure, _) => { + return self.infer_proj(type_name, idx, structure) + }, + ExprData::Fvar(name, _) => { + return match self.local_types.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context" + .into(), + }), + } + }, + ExprData::Bvar(idx, _) => { + return Err(TcError::FreeBoundVariable { + idx: idx.to_u64().unwrap_or(u64::MAX), + }) + }, + ExprData::Mvar(..) => { + return Err(TcError::KernelException { + msg: "cannot infer type of metavariable".into(), + }) + }, + } } } @@ -253,19 +334,6 @@ impl<'env> TypeChecker<'env> { Ok(Expr::sort(result_level)) } - fn infer_let( - &mut self, - typ: &Expr, - val: &Expr, - body: &Expr, - ) -> TcResult { - // Verify value matches declared type - let val_ty = self.infer(val)?; - self.assert_def_eq(&val_ty, typ)?; - let body_inst = inst(body, &[val.clone()]); - self.infer(&body_inst) - } - fn infer_lit(&mut self, lit: &Literal) -> TcResult { match lit { Literal::NatVal(_) => { @@ -375,7 +443,11 @@ impl<'env> TypeChecker<'env> { // ========================================================================== pub fn def_eq(&mut self, x: &Expr, y: &Expr) -> bool { - def_eq(x, y, self) + self.def_eq_calls += 1; + eprintln!("[tc.def_eq] #{}", self.def_eq_calls); + let result = def_eq(x, y, self); + eprintln!("[tc.def_eq] #{} done => {result}", self.def_eq_calls); + result } pub fn assert_def_eq(&mut self, x: &Expr, y: &Expr) -> TcResult<()> { @@ -432,6 +504,31 @@ impl<'env> TypeChecker<'env> { Ok(()) } + /// Check a declaration that has both a type and a value (DefnInfo, ThmInfo, OpaqueInfo). 
+ fn check_value_declar( + &mut self, + cnst: &ConstantVal, + value: &Expr, + ) -> TcResult<()> { + eprintln!("[check_value_declar] checking type for {}", cnst.name.pretty()); + self.check_declar_info(cnst)?; + eprintln!("[check_value_declar] type OK, checking value uparams"); + if !all_expr_uparams_defined(value, &cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + cnst.name.pretty() + ), + }); + } + eprintln!("[check_value_declar] inferring value type"); + let inferred_type = self.infer(value)?; + eprintln!("[check_value_declar] inferred, checking def_eq"); + self.assert_def_eq(&inferred_type, &cnst.typ)?; + eprintln!("[check_value_declar] done"); + Ok(()) + } + /// Check a single declaration. pub fn check_declar( &mut self, @@ -442,43 +539,13 @@ impl<'env> TypeChecker<'env> { self.check_declar_info(&v.cnst)?; }, ConstantInfo::DefnInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::ThmInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::OpaqueInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - 
v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::QuotInfo(v) => { self.check_declar_info(&v.cnst)?; @@ -512,16 +579,77 @@ impl<'env> TypeChecker<'env> { } } -/// Check all declarations in an environment. +/// Check all declarations in an environment in parallel. pub fn check_env(env: &Env) -> Vec<(Name, TcError)> { - let mut errors = Vec::new(); - for (name, ci) in env.iter() { - let mut tc = TypeChecker::new(env); - if let Err(e) = tc.check_declar(ci) { - errors.push((name.clone(), e)); - } + use std::collections::BTreeSet; + use std::io::Write; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Mutex; + + let total = env.len(); + let checked = AtomicUsize::new(0); + + struct Display { + active: BTreeSet, + prev_lines: usize, } - errors + let display = Mutex::new(Display { active: BTreeSet::new(), prev_lines: 0 }); + + let refresh = |d: &mut Display, checked: usize| { + let mut stderr = std::io::stderr().lock(); + if d.prev_lines > 0 { + write!(stderr, "\x1b[{}A", d.prev_lines).ok(); + } + write!( + stderr, + "\x1b[2K[check_env] {}/{} — {} active\n", + checked, + total, + d.active.len() + ) + .ok(); + let mut new_lines = 1; + for name in &d.active { + write!(stderr, "\x1b[2K {}\n", name).ok(); + new_lines += 1; + } + let extra = d.prev_lines.saturating_sub(new_lines); + for _ in 0..extra { + write!(stderr, "\x1b[2K\n").ok(); + } + if extra > 0 { + write!(stderr, "\x1b[{}A", extra).ok(); + } + d.prev_lines = new_lines; + stderr.flush().ok(); + }; + + env + .par_iter() + .filter_map(|(name, ci)| { + let pretty = name.pretty(); + { + let mut d = display.lock().unwrap(); + d.active.insert(pretty.clone()); + refresh(&mut d, checked.load(Ordering::Relaxed)); + } + + let mut tc = TypeChecker::new(env); + let result = tc.check_declar(ci); + + let n = checked.fetch_add(1, Ordering::Relaxed) + 1; + { + let mut d 
= display.lock().unwrap(); + d.active.remove(&pretty); + refresh(&mut d, n); + } + + match result { + Ok(()) => None, + Err(e) => Some((name.clone(), e)), + } + }) + .collect() } #[cfg(test)] @@ -553,9 +681,18 @@ mod tests { Expr::sort(Level::param(mk_name("u"))) } - /// Build a minimal environment with Nat, Nat.zero, and Nat.succ. + fn bvar(n: u64) -> Expr { + Expr::bvar(Nat::from(n)) + } + + fn nat_succ_expr() -> Expr { + Expr::cnst(mk_name2("Nat", "succ"), vec![]) + } + + /// Build a minimal environment with Nat, Nat.zero, Nat.succ, and Nat.rec. fn mk_nat_env() -> Env { let mut env = Env::default(); + let u = mk_name("u"); let nat_name = mk_name("Nat"); // Nat : Sort 1 @@ -614,6 +751,147 @@ mod tests { }); env.insert(succ_name, succ); + // Nat.rec.{u} : + // {motive : Nat → Sort u} → + // motive Nat.zero → + // ((n : Nat) → motive n → motive (Nat.succ n)) → + // (t : Nat) → motive t + let rec_name = mk_name2("Nat", "rec"); + + // Build the type with de Bruijn indices. + // Binder stack (from outermost): motive(3), z(2), s(1), t(0) + // At the innermost body: motive=bvar(3), z=bvar(2), s=bvar(1), t=bvar(0) + let motive_type = Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ); // Nat → Sort u + + // s type: (n : Nat) → motive n → motive (Nat.succ n) + // At s's position: motive=bvar(1), z=bvar(0) + // Inside forallE "n": motive=bvar(2), z=bvar(1), n=bvar(0) + // Inside forallE "_": motive=bvar(3), z=bvar(2), n=bvar(1), _=bvar(0) + let s_type = Expr::all( + mk_name("n"), + nat_type(), + Expr::all( + mk_name("_"), + Expr::app(bvar(2), bvar(0)), // motive n + Expr::app(bvar(3), Expr::app(nat_succ_expr(), bvar(1))), // motive (Nat.succ n) + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let rec_type = Expr::all( + mk_name("motive"), + motive_type.clone(), + Expr::all( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), // motive Nat.zero + Expr::all( + mk_name("s"), + s_type, + Expr::all( + 
mk_name("t"), + nat_type(), + Expr::app(bvar(3), bvar(0)), // motive t + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Implicit, + ); + + // Zero rule RHS: fun (motive) (z) (s) => z + // Inside: motive=bvar(2), z=bvar(1), s=bvar(0) + let zero_rhs = Expr::lam( + mk_name("motive"), + motive_type.clone(), + Expr::lam( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), + Expr::lam( + mk_name("s"), + nat_type(), // placeholder type for s (not checked) + bvar(1), // z + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + // Succ rule RHS: fun (motive) (z) (s) (n) => s n (Nat.rec.{u} motive z s n) + // Inside: motive=bvar(3), z=bvar(2), s=bvar(1), n=bvar(0) + let nat_rec_u = + Expr::cnst(rec_name.clone(), vec![Level::param(u.clone())]); + let recursive_call = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_u, bvar(3)), // Nat.rec motive + bvar(2), // z + ), + bvar(1), // s + ), + bvar(0), // n + ); + let succ_rhs = Expr::lam( + mk_name("motive"), + motive_type, + Expr::lam( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), + Expr::lam( + mk_name("s"), + nat_type(), // placeholder + Expr::lam( + mk_name("n"), + nat_type(), + Expr::app( + Expr::app(bvar(1), bvar(0)), // s n + recursive_call, // (Nat.rec motive z s n) + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: rec_name.clone(), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: zero_rhs, + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: succ_rhs, + }, + ], + k: false, + is_unsafe: false, + }); + 
env.insert(rec_name, rec); + env } @@ -1691,4 +1969,219 @@ mod tests { }); assert!(tc.check_declar(&rec).is_err()); } + + // ========================================================================== + // check_declar: Nat.add via Nat.rec + // ========================================================================== + + #[test] + fn check_nat_add_via_rec() { + // Nat.add : Nat → Nat → Nat := + // fun (n m : Nat) => @Nat.rec.{1} (fun _ => Nat) n (fun _ ih => Nat.succ ih) m + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + + let nat = nat_type(); + let nat_rec_1 = Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ); + + // motive: fun (_ : Nat) => Nat + let motive = Expr::lam( + mk_name("_"), + nat.clone(), + nat.clone(), + BinderInfo::Default, + ); + + // step: fun (_ : Nat) (ih : Nat) => Nat.succ ih + let step = Expr::lam( + mk_name("_"), + nat.clone(), + Expr::lam( + mk_name("ih"), + nat.clone(), + Expr::app(nat_succ_expr(), bvar(0)), // Nat.succ ih + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + // value: fun (n m : Nat) => @Nat.rec.{1} (fun _ => Nat) n (fun _ ih => Nat.succ ih) m + // = fun n m => Nat.rec motive n step m + let body = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_1, motive), + bvar(1), // n + ), + step, + ), + bvar(0), // m + ); + let value = Expr::lam( + mk_name("n"), + nat.clone(), + Expr::lam( + mk_name("m"), + nat.clone(), + body, + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let typ = Expr::all( + mk_name("n"), + nat.clone(), + Expr::all(mk_name("m"), nat.clone(), nat, BinderInfo::Default), + BinderInfo::Default, + ); + + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name2("Nat", "add"), + level_params: vec![], + typ, + }, + value, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name2("Nat", "add")], + }); + assert!(tc.check_declar(&defn).is_ok()); + } + + /// Build mk_nat_env + Nat.add 
definition in the env. + fn mk_nat_add_env() -> Env { + let mut env = mk_nat_env(); + let nat = nat_type(); + + let nat_rec_1 = Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ); + + let motive = Expr::lam( + mk_name("_"), + nat.clone(), + nat.clone(), + BinderInfo::Default, + ); + + let step = Expr::lam( + mk_name("_"), + nat.clone(), + Expr::lam( + mk_name("ih"), + nat.clone(), + Expr::app(nat_succ_expr(), bvar(0)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let body = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_1, motive), + bvar(1), // n + ), + step, + ), + bvar(0), // m + ); + let value = Expr::lam( + mk_name("n"), + nat.clone(), + Expr::lam( + mk_name("m"), + nat.clone(), + body, + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let typ = Expr::all( + mk_name("n"), + nat.clone(), + Expr::all(mk_name("m"), nat.clone(), nat, BinderInfo::Default), + BinderInfo::Default, + ); + + env.insert( + mk_name2("Nat", "add"), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name2("Nat", "add"), + level_params: vec![], + typ, + }, + value, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name2("Nat", "add")], + }), + ); + + env + } + + #[test] + fn check_nat_add_env() { + // Verify that the full Nat + Nat.add environment typechecks + let env = mk_nat_add_env(); + let errors = check_env(&env); + assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors); + } + + #[test] + fn whnf_nat_add_zero_zero() { + // Nat.add Nat.zero Nat.zero should WHNF to 0 (as nat literal) + let env = mk_nat_add_env(); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_zero(), + ), + nat_zero(), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::lit(Literal::NatVal(Nat::from(0u64)))); + } + + #[test] + fn whnf_nat_add_lit() { + // Nat.add 2 3 should WHNF to 5 + let env = mk_nat_add_env(); + let two = 
Expr::lit(Literal::NatVal(Nat::from(2u64))); + let three = Expr::lit(Literal::NatVal(Nat::from(3u64))); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + two, + ), + three, + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::lit(Literal::NatVal(Nat::from(5u64)))); + } + + #[test] + fn infer_nat_add_applied() { + // Nat.add Nat.zero Nat.zero : Nat + let env = mk_nat_add_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_zero(), + ), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } } diff --git a/src/ix/kernel/upcopy.rs b/src/ix/kernel/upcopy.rs index 89dae8a0..a3657ac4 100644 --- a/src/ix/kernel/upcopy.rs +++ b/src/ix/kernel/upcopy.rs @@ -10,223 +10,225 @@ use super::dll::DLL; // ============================================================================ pub fn upcopy(new_child: DAGPtr, cc: ParentPtr) { - unsafe { - match cc { - ParentPtr::Root => {}, - ParentPtr::LamBod(link) => { - let lam = &*link.as_ptr(); - let var = &lam.var; - let new_lam = alloc_lam(var.depth, new_child, None); - let new_lam_ref = &mut *new_lam.as_ptr(); - let bod_ref_ptr = - NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); - add_to_parents(new_child, bod_ref_ptr); - let new_var_ptr = - NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); - for parent in DLL::iter_option(var.parents) { - upcopy(DAGPtr::Var(new_var_ptr), *parent); - } - for parent in DLL::iter_option(lam.parents) { - upcopy(DAGPtr::Lam(new_lam), *parent); - } - }, - ParentPtr::AppFun(link) => { - let app = &mut *link.as_ptr(); - match app.copy { - Some(cache) => { - (*cache.as_ptr()).fun = new_child; - }, - None => { - let new_app = alloc_app_no_uplinks(new_child, app.arg); - app.copy = Some(new_app); - for parent in DLL::iter_option(app.parents) { - upcopy(DAGPtr::App(new_app), *parent); - } - }, - } - }, - ParentPtr::AppArg(link) => { - let app 
= &mut *link.as_ptr(); - match app.copy { - Some(cache) => { - (*cache.as_ptr()).arg = new_child; - }, - None => { - let new_app = alloc_app_no_uplinks(app.fun, new_child); - app.copy = Some(new_app); - for parent in DLL::iter_option(app.parents) { - upcopy(DAGPtr::App(new_app), *parent); - } - }, - } - }, - ParentPtr::FunDom(link) => { - let fun = &mut *link.as_ptr(); - match fun.copy { - Some(cache) => { - (*cache.as_ptr()).dom = new_child; - }, - None => { - let new_fun = alloc_fun_no_uplinks( - fun.binder_name.clone(), - fun.binder_info.clone(), - new_child, - fun.img, - ); - fun.copy = Some(new_fun); - for parent in DLL::iter_option(fun.parents) { - upcopy(DAGPtr::Fun(new_fun), *parent); - } - }, - } - }, - ParentPtr::FunImg(link) => { - let fun = &mut *link.as_ptr(); - // new_child must be a Lam - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("FunImg parent expects Lam child"), - }; - match fun.copy { - Some(cache) => { - (*cache.as_ptr()).img = new_lam; - }, - None => { - let new_fun = alloc_fun_no_uplinks( - fun.binder_name.clone(), - fun.binder_info.clone(), - fun.dom, - new_lam, - ); - fun.copy = Some(new_fun); - for parent in DLL::iter_option(fun.parents) { - upcopy(DAGPtr::Fun(new_fun), *parent); - } - }, - } - }, - ParentPtr::PiDom(link) => { - let pi = &mut *link.as_ptr(); - match pi.copy { - Some(cache) => { - (*cache.as_ptr()).dom = new_child; - }, - None => { - let new_pi = alloc_pi_no_uplinks( - pi.binder_name.clone(), - pi.binder_info.clone(), - new_child, - pi.img, - ); - pi.copy = Some(new_pi); - for parent in DLL::iter_option(pi.parents) { - upcopy(DAGPtr::Pi(new_pi), *parent); - } - }, - } - }, - ParentPtr::PiImg(link) => { - let pi = &mut *link.as_ptr(); - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("PiImg parent expects Lam child"), - }; - match pi.copy { - Some(cache) => { - (*cache.as_ptr()).img = new_lam; - }, - None => { - let new_pi = alloc_pi_no_uplinks( - pi.binder_name.clone(), - 
pi.binder_info.clone(), - pi.dom, - new_lam, - ); - pi.copy = Some(new_pi); - for parent in DLL::iter_option(pi.parents) { - upcopy(DAGPtr::Pi(new_pi), *parent); - } - }, - } - }, - ParentPtr::LetTyp(link) => { - let let_node = &mut *link.as_ptr(); - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).typ = new_child; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - new_child, - let_node.val, - let_node.bod, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::LetVal(link) => { - let let_node = &mut *link.as_ptr(); - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).val = new_child; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - let_node.typ, - new_child, - let_node.bod, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::LetBod(link) => { - let let_node = &mut *link.as_ptr(); - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("LetBod parent expects Lam child"), - }; - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).bod = new_lam; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - let_node.typ, - let_node.val, - new_lam, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::ProjExpr(link) => { - let proj = &*link.as_ptr(); - let new_proj = alloc_proj_no_uplinks( - proj.type_name.clone(), - proj.idx.clone(), - new_child, - ); - for parent in DLL::iter_option(proj.parents) { - upcopy(DAGPtr::Proj(new_proj), *parent); - } - }, + let mut stack: Vec<(DAGPtr, ParentPtr)> = vec![(new_child, cc)]; + while let 
Some((new_child, cc)) = stack.pop() { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + let var = &lam.var; + let new_lam = alloc_lam(var.depth, new_child, None); + let new_lam_ref = &mut *new_lam.as_ptr(); + let bod_ref_ptr = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_child, bod_ref_ptr); + let new_var_ptr = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + for parent in DLL::iter_option(var.parents) { + stack.push((DAGPtr::Var(new_var_ptr), *parent)); + } + for parent in DLL::iter_option(lam.parents) { + stack.push((DAGPtr::Lam(new_lam), *parent)); + } + }, + ParentPtr::AppFun(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).fun = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(new_child, app.arg); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + stack.push((DAGPtr::App(new_app), *parent)); + } + }, + } + }, + ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).arg = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(app.fun, new_child); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + stack.push((DAGPtr::App(new_app), *parent)); + } + }, + } + }, + ParentPtr::FunDom(link) => { + let fun = &mut *link.as_ptr(); + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_child, + fun.img, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + stack.push((DAGPtr::Fun(new_fun), *parent)); + } + }, + } + }, + ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("FunImg parent expects Lam child"), + }; + match 
fun.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + new_lam, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + stack.push((DAGPtr::Fun(new_fun), *parent)); + } + }, + } + }, + ParentPtr::PiDom(link) => { + let pi = &mut *link.as_ptr(); + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_child, + pi.img, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + stack.push((DAGPtr::Pi(new_pi), *parent)); + } + }, + } + }, + ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("PiImg parent expects Lam child"), + }; + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + pi.dom, + new_lam, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + stack.push((DAGPtr::Pi(new_pi), *parent)); + } + }, + } + }, + ParentPtr::LetTyp(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).typ = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + new_child, + let_node.val, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::LetVal(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).val = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + 
new_child, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("LetBod parent expects Lam child"), + }; + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).bod = new_lam; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + let_node.val, + new_lam, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + let new_proj = alloc_proj_no_uplinks( + proj.type_name.clone(), + proj.idx.clone(), + new_child, + ); + for parent in DLL::iter_option(proj.parents) { + stack.push((DAGPtr::Proj(new_proj), *parent)); + } + }, + } } } } @@ -352,79 +354,82 @@ fn alloc_proj_no_uplinks( // ============================================================================ pub fn clean_up(cc: &ParentPtr) { - unsafe { - match cc { - ParentPtr::Root => {}, - ParentPtr::LamBod(link) => { - let lam = &*link.as_ptr(); - for parent in DLL::iter_option(lam.var.parents) { - clean_up(parent); - } - for parent in DLL::iter_option(lam.parents) { - clean_up(parent); - } - }, - ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => { - let app = &mut *link.as_ptr(); - if let Some(app_copy) = app.copy { - let App { fun, arg, fun_ref, arg_ref, .. 
} = - &mut *app_copy.as_ptr(); - app.copy = None; - add_to_parents(*fun, NonNull::new(fun_ref).unwrap()); - add_to_parents(*arg, NonNull::new(arg_ref).unwrap()); - for parent in DLL::iter_option(app.parents) { - clean_up(parent); + let mut stack: Vec = vec![*cc]; + while let Some(cc) = stack.pop() { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + for parent in DLL::iter_option(lam.var.parents) { + stack.push(*parent); } - } - }, - ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => { - let fun = &mut *link.as_ptr(); - if let Some(fun_copy) = fun.copy { - let Fun { dom, img, dom_ref, img_ref, .. } = - &mut *fun_copy.as_ptr(); - fun.copy = None; - add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); - for parent in DLL::iter_option(fun.parents) { - clean_up(parent); + for parent in DLL::iter_option(lam.parents) { + stack.push(*parent); } - } - }, - ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => { - let pi = &mut *link.as_ptr(); - if let Some(pi_copy) = pi.copy { - let Pi { dom, img, dom_ref, img_ref, .. } = - &mut *pi_copy.as_ptr(); - pi.copy = None; - add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); - for parent in DLL::iter_option(pi.parents) { - clean_up(parent); + }, + ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + if let Some(app_copy) = app.copy { + let App { fun, arg, fun_ref, arg_ref, .. 
} = + &mut *app_copy.as_ptr(); + app.copy = None; + add_to_parents(*fun, NonNull::new(fun_ref).unwrap()); + add_to_parents(*arg, NonNull::new(arg_ref).unwrap()); + for parent in DLL::iter_option(app.parents) { + stack.push(*parent); + } } - } - }, - ParentPtr::LetTyp(link) - | ParentPtr::LetVal(link) - | ParentPtr::LetBod(link) => { - let let_node = &mut *link.as_ptr(); - if let Some(let_copy) = let_node.copy { - let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } = - &mut *let_copy.as_ptr(); - let_node.copy = None; - add_to_parents(*typ, NonNull::new(typ_ref).unwrap()); - add_to_parents(*val, NonNull::new(val_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap()); - for parent in DLL::iter_option(let_node.parents) { - clean_up(parent); + }, + ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + if let Some(fun_copy) = fun.copy { + let Fun { dom, img, dom_ref, img_ref, .. } = + &mut *fun_copy.as_ptr(); + fun.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(fun.parents) { + stack.push(*parent); + } } - } - }, - ParentPtr::ProjExpr(link) => { - let proj = &*link.as_ptr(); - for parent in DLL::iter_option(proj.parents) { - clean_up(parent); - } - }, + }, + ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + if let Some(pi_copy) = pi.copy { + let Pi { dom, img, dom_ref, img_ref, .. 
} = + &mut *pi_copy.as_ptr(); + pi.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(pi.parents) { + stack.push(*parent); + } + } + }, + ParentPtr::LetTyp(link) + | ParentPtr::LetVal(link) + | ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + if let Some(let_copy) = let_node.copy { + let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } = + &mut *let_copy.as_ptr(); + let_node.copy = None; + add_to_parents(*typ, NonNull::new(typ_ref).unwrap()); + add_to_parents(*val, NonNull::new(val_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap()); + for parent in DLL::iter_option(let_node.parents) { + stack.push(*parent); + } + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + for parent in DLL::iter_option(proj.parents) { + stack.push(*parent); + } + }, + } } } } @@ -476,119 +481,122 @@ pub fn replace_child(old: DAGPtr, new: DAGPtr) { // Free dead nodes // ============================================================================ -pub fn free_dead_node(node: DAGPtr) { - unsafe { - match node { - DAGPtr::Lam(link) => { - let lam = &*link.as_ptr(); - let bod_ref_ptr = &lam.bod_ref as *const Parents; - if let Some(remaining) = (*bod_ref_ptr).unlink_node() { - set_parents(lam.bod, Some(remaining)); - } else { - set_parents(lam.bod, None); - free_dead_node(lam.bod); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::App(link) => { - let app = &*link.as_ptr(); - let fun_ref_ptr = &app.fun_ref as *const Parents; - if let Some(remaining) = (*fun_ref_ptr).unlink_node() { - set_parents(app.fun, Some(remaining)); - } else { - set_parents(app.fun, None); - free_dead_node(app.fun); - } - let arg_ref_ptr = &app.arg_ref as *const Parents; - if let Some(remaining) = (*arg_ref_ptr).unlink_node() { - set_parents(app.arg, Some(remaining)); - } else { - set_parents(app.arg, None); - 
free_dead_node(app.arg); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Fun(link) => { - let fun = &*link.as_ptr(); - let dom_ref_ptr = &fun.dom_ref as *const Parents; - if let Some(remaining) = (*dom_ref_ptr).unlink_node() { - set_parents(fun.dom, Some(remaining)); - } else { - set_parents(fun.dom, None); - free_dead_node(fun.dom); - } - let img_ref_ptr = &fun.img_ref as *const Parents; - if let Some(remaining) = (*img_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(fun.img), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(fun.img), None); - free_dead_node(DAGPtr::Lam(fun.img)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Pi(link) => { - let pi = &*link.as_ptr(); - let dom_ref_ptr = &pi.dom_ref as *const Parents; - if let Some(remaining) = (*dom_ref_ptr).unlink_node() { - set_parents(pi.dom, Some(remaining)); - } else { - set_parents(pi.dom, None); - free_dead_node(pi.dom); - } - let img_ref_ptr = &pi.img_ref as *const Parents; - if let Some(remaining) = (*img_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(pi.img), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(pi.img), None); - free_dead_node(DAGPtr::Lam(pi.img)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Let(link) => { - let let_node = &*link.as_ptr(); - let typ_ref_ptr = &let_node.typ_ref as *const Parents; - if let Some(remaining) = (*typ_ref_ptr).unlink_node() { - set_parents(let_node.typ, Some(remaining)); - } else { - set_parents(let_node.typ, None); - free_dead_node(let_node.typ); - } - let val_ref_ptr = &let_node.val_ref as *const Parents; - if let Some(remaining) = (*val_ref_ptr).unlink_node() { - set_parents(let_node.val, Some(remaining)); - } else { - set_parents(let_node.val, None); - free_dead_node(let_node.val); - } - let bod_ref_ptr = &let_node.bod_ref as *const Parents; - if let Some(remaining) = (*bod_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(let_node.bod), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(let_node.bod), 
None); - free_dead_node(DAGPtr::Lam(let_node.bod)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Proj(link) => { - let proj = &*link.as_ptr(); - let expr_ref_ptr = &proj.expr_ref as *const Parents; - if let Some(remaining) = (*expr_ref_ptr).unlink_node() { - set_parents(proj.expr, Some(remaining)); - } else { - set_parents(proj.expr, None); - free_dead_node(proj.expr); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Var(link) => { - let var = &*link.as_ptr(); - if let BinderPtr::Free = var.binder { +pub fn free_dead_node(root: DAGPtr) { + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + unsafe { + match node { + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + let bod_ref_ptr = &lam.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(lam.bod, Some(remaining)); + } else { + set_parents(lam.bod, None); + stack.push(lam.bod); + } drop(Box::from_raw(link.as_ptr())); - } - }, - DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun_ref_ptr = &app.fun_ref as *const Parents; + if let Some(remaining) = (*fun_ref_ptr).unlink_node() { + set_parents(app.fun, Some(remaining)); + } else { + set_parents(app.fun, None); + stack.push(app.fun); + } + let arg_ref_ptr = &app.arg_ref as *const Parents; + if let Some(remaining) = (*arg_ref_ptr).unlink_node() { + set_parents(app.arg, Some(remaining)); + } else { + set_parents(app.arg, None); + stack.push(app.arg); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let dom_ref_ptr = &fun.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(fun.dom, Some(remaining)); + } else { + set_parents(fun.dom, None); + stack.push(fun.dom); + } + let img_ref_ptr = 
&fun.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(fun.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(fun.img), None); + stack.push(DAGPtr::Lam(fun.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let dom_ref_ptr = &pi.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(pi.dom, Some(remaining)); + } else { + set_parents(pi.dom, None); + stack.push(pi.dom); + } + let img_ref_ptr = &pi.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(pi.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(pi.img), None); + stack.push(DAGPtr::Lam(pi.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let typ_ref_ptr = &let_node.typ_ref as *const Parents; + if let Some(remaining) = (*typ_ref_ptr).unlink_node() { + set_parents(let_node.typ, Some(remaining)); + } else { + set_parents(let_node.typ, None); + stack.push(let_node.typ); + } + let val_ref_ptr = &let_node.val_ref as *const Parents; + if let Some(remaining) = (*val_ref_ptr).unlink_node() { + set_parents(let_node.val, Some(remaining)); + } else { + set_parents(let_node.val, None); + stack.push(let_node.val); + } + let bod_ref_ptr = &let_node.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(let_node.bod), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(let_node.bod), None); + stack.push(DAGPtr::Lam(let_node.bod)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let expr_ref_ptr = &proj.expr_ref as *const Parents; + if let Some(remaining) = (*expr_ref_ptr).unlink_node() { + set_parents(proj.expr, Some(remaining)); + } else { + set_parents(proj.expr, None); + stack.push(proj.expr); + } + 
drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + } } } } @@ -598,6 +606,11 @@ pub fn free_dead_node(node: DAGPtr) { // ============================================================================ /// Contract a lambda redex: (Fun dom (Lam bod var)) arg → [arg/var]bod. +/// +/// After substitution, propagates the result through the redex App's parent +/// pointers (via `replace_child`) and frees the dead App/Fun/Lam nodes. +/// This ensures that enclosing DAG structures are properly updated, enabling +/// DAG-native sub-term WHNF without Expr roundtrips. pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { unsafe { let app = &*redex.as_ptr(); @@ -605,18 +618,46 @@ pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { let var = &lambda.var; let arg = app.arg; + // Perform substitution if DLL::is_singleton(lambda.parents) { - if DLL::is_empty(var.parents) { - return lambda.bod; + if !DLL::is_empty(var.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + } + } else if !DLL::is_empty(var.parents) { + // General case: upcopy arg through var's parents + for parent in DLL::iter_option(var.parents) { + upcopy(arg, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); } - replace_child(DAGPtr::Var(NonNull::from(var)), arg); - return lambda.bod; } + lambda.bod + } +} + +/// Substitute an argument into a Pi's body: given `Pi(dom, Lam(var, body))` +/// and `arg`, produce `[arg/var]body`. Used for computing the result type +/// of function application during type inference. +/// +/// Unlike `reduce_lam`, this does NOT consume the enclosing App/Fun — it +/// works directly on the Pi's Lam node. 
The Lam should typically be +/// singly-parented (freshly inferred types are not shared). +pub fn subst_pi_body(lam: NonNull, arg: DAGPtr) -> DAGPtr { + unsafe { + let lambda = &*lam.as_ptr(); + let var = &lambda.var; + if DLL::is_empty(var.parents) { return lambda.bod; } + if DLL::is_singleton(lambda.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + return lambda.bod; + } + // General case: upcopy arg through var's parents for parent in DLL::iter_option(var.parents) { upcopy(arg, *parent); @@ -629,6 +670,9 @@ pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { } /// Contract a let redex: Let(typ, val, Lam(bod, var)) → [val/var]bod. +/// +/// After substitution, propagates the result through the Let node's parent +/// pointers (via `replace_child`) and frees the dead Let/Lam nodes. pub fn reduce_let(let_node: NonNull) -> DAGPtr { unsafe { let ln = &*let_node.as_ptr(); @@ -636,24 +680,20 @@ pub fn reduce_let(let_node: NonNull) -> DAGPtr { let var = &lam.var; let val = ln.val; + // Perform substitution if DLL::is_singleton(lam.parents) { - if DLL::is_empty(var.parents) { - return lam.bod; + if !DLL::is_empty(var.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), val); + } + } else if !DLL::is_empty(var.parents) { + for parent in DLL::iter_option(var.parents) { + upcopy(val, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); } - replace_child(DAGPtr::Var(NonNull::from(var)), val); - return lam.bod; - } - - if DLL::is_empty(var.parents) { - return lam.bod; } - for parent in DLL::iter_option(var.parents) { - upcopy(val, *parent); - } - for parent in DLL::iter_option(var.parents) { - clean_up(parent); - } lam.bod } } diff --git a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs index 4fdde07a..d7cef49a 100644 --- a/src/ix/kernel/whnf.rs +++ b/src/ix/kernel/whnf.rs @@ -8,14 +8,16 @@ use super::convert::{from_expr, to_expr}; use super::dag::*; use super::level::{simplify, subst_level}; use 
super::upcopy::{reduce_lam, reduce_let}; - +use crate::ix::env::Literal; // ============================================================================ // Expression helpers (inst, unfold_apps, foldl_apps, subst_expr_levels) // ============================================================================ -/// Instantiate bound variables: `body[0 := substs[0], 1 := substs[1], ...]`. -/// `substs[0]` replaces `Bvar(0)` (innermost). +/// Instantiate bound variables: `body[0 := substs[n-1], 1 := substs[n-2], ...]`. +/// Follows Lean 4's `instantiate` convention: `substs[0]` is the outermost +/// variable and replaces `Bvar(n-1)`, while `substs[n-1]` is the innermost +/// and replaces `Bvar(0)`. pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { if substs.is_empty() { return body.clone(); @@ -24,56 +26,108 @@ pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { } fn inst_aux(e: &Expr, substs: &[Expr], offset: u64) -> Expr { - match e.as_data() { - ExprData::Bvar(idx, _) => { - let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); - if idx_u64 >= offset { - let adjusted = (idx_u64 - offset) as usize; - if adjusted < substs.len() { - return substs[adjusted].clone(); - } - } - e.clone() - }, - ExprData::App(f, a, _) => { - let f2 = inst_aux(f, substs, offset); - let a2 = inst_aux(a, substs, offset); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = inst_aux(t, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = inst_aux(t, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = inst_aux(t, substs, offset); - let v2 = inst_aux(v, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = inst_aux(s, substs, offset); - Expr::proj(n.clone(), i.clone(), s2) - 
}, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = inst_aux(inner, substs, offset); - Expr::mdata(kvs.clone(), inner2) - }, - // Terminals with no bound vars - ExprData::Sort(..) - | ExprData::Const(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => e.clone(), + enum Frame<'a> { + Visit(&'a Expr, u64), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut work: Vec> = vec![Frame::Visit(e, offset)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e, offset) => match e.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 >= offset { + let adjusted = (idx_u64 - offset) as usize; + if adjusted < substs.len() { + // Lean 4 convention: substs[0] = outermost, substs[n-1] = innermost + // bvar(0) = innermost → substs[n-1], bvar(n-1) = outermost → substs[0] + results.push(substs[substs.len() - 1 - adjusted].clone()); + continue; + } + } + results.push(e.clone()); + }, + ExprData::App(f, a, _) => { + work.push(Frame::App); + work.push(Frame::Visit(a, offset)); + work.push(Frame::Visit(f, offset)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(v, offset)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::Proj(n.clone(), i.clone())); + work.push(Frame::Visit(s, offset)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::Mdata(kvs.clone())); + 
work.push(Frame::Visit(inner, offset)); + }, + ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => results.push(e.clone()), + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } -/// Abstract: replace free variable `fvar` with `Bvar(offset)` in `e`. +/// Abstract: replace free variables with bound variables. +/// Follows Lean 4 convention: `fvars[0]` (outermost) maps to `Bvar(n-1+offset)`, +/// `fvars[n-1]` (innermost) maps to `Bvar(0+offset)`. pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { if fvars.is_empty() { return e.clone(); @@ -82,50 +136,107 @@ pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { } fn abstr_aux(e: &Expr, fvars: &[Expr], offset: u64) -> Expr { - match e.as_data() { - ExprData::Fvar(..) 
=> { - for (i, fv) in fvars.iter().enumerate().rev() { - if e == fv { - return Expr::bvar(Nat::from(i as u64 + offset)); - } - } - e.clone() - }, - ExprData::App(f, a, _) => { - let f2 = abstr_aux(f, fvars, offset); - let a2 = abstr_aux(a, fvars, offset); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = abstr_aux(t, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = abstr_aux(t, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = abstr_aux(t, fvars, offset); - let v2 = abstr_aux(v, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = abstr_aux(s, fvars, offset); - Expr::proj(n.clone(), i.clone(), s2) - }, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = abstr_aux(inner, fvars, offset); - Expr::mdata(kvs.clone(), inner2) - }, - ExprData::Bvar(..) - | ExprData::Sort(..) - | ExprData::Const(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => e.clone(), + enum Frame<'a> { + Visit(&'a Expr, u64), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut work: Vec> = vec![Frame::Visit(e, offset)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e, offset) => match e.as_data() { + ExprData::Fvar(..) 
=> { + let n = fvars.len(); + let mut found = false; + for (i, fv) in fvars.iter().enumerate() { + if e == fv { + // fvars[0] (outermost) → Bvar(n-1+offset) + // fvars[n-1] (innermost) → Bvar(0+offset) + let bvar_idx = (n - 1 - i) as u64 + offset; + results.push(Expr::bvar(Nat::from(bvar_idx))); + found = true; + break; + } + } + if !found { + results.push(e.clone()); + } + }, + ExprData::App(f, a, _) => { + work.push(Frame::App); + work.push(Frame::Visit(a, offset)); + work.push(Frame::Visit(f, offset)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(v, offset)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::Proj(n.clone(), i.clone())); + work.push(Frame::Visit(s, offset)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::Mdata(kvs.clone())); + work.push(Frame::Visit(inner, offset)); + }, + ExprData::Bvar(..) + | ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) 
=> results.push(e.clone()), + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } /// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])`. @@ -154,66 +265,134 @@ pub fn foldl_apps(mut fun: Expr, args: impl Iterator) -> Expr { } /// Substitute universe level parameters in an expression. 
-pub fn subst_expr_levels( - e: &Expr, - params: &[Name], - values: &[Level], -) -> Expr { +pub fn subst_expr_levels(e: &Expr, params: &[Name], values: &[Level]) -> Expr { if params.is_empty() { return e.clone(); } subst_expr_levels_aux(e, params, values) } -fn subst_expr_levels_aux( - e: &Expr, - params: &[Name], - values: &[Level], -) -> Expr { - match e.as_data() { - ExprData::Sort(level, _) => { - Expr::sort(subst_level(level, params, values)) - }, - ExprData::Const(name, levels, _) => { - let new_levels: Vec = - levels.iter().map(|l| subst_level(l, params, values)).collect(); - Expr::cnst(name.clone(), new_levels) - }, - ExprData::App(f, a, _) => { - let f2 = subst_expr_levels_aux(f, params, values); - let a2 = subst_expr_levels_aux(a, params, values); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let v2 = subst_expr_levels_aux(v, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = subst_expr_levels_aux(s, params, values); - Expr::proj(n.clone(), i.clone(), s2) - }, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = subst_expr_levels_aux(inner, params, values); - Expr::mdata(kvs.clone(), inner2) - }, - // No levels to substitute - ExprData::Bvar(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) 
=> e.clone(), +fn subst_expr_levels_aux(e: &Expr, params: &[Name], values: &[Level]) -> Expr { + use rustc_hash::FxHashMap; + use std::sync::Arc; + + enum Frame<'a> { + Visit(&'a Expr), + CacheResult(*const ExprData), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut cache: FxHashMap<*const ExprData, Expr> = FxHashMap::default(); + let mut work: Vec> = vec![Frame::Visit(e)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e) => { + let key = Arc::as_ptr(&e.0); + if let Some(cached) = cache.get(&key) { + results.push(cached.clone()); + continue; + } + match e.as_data() { + ExprData::Sort(level, _) => { + let r = Expr::sort(subst_level(level, params, values)); + cache.insert(key, r.clone()); + results.push(r); + }, + ExprData::Const(name, levels, _) => { + let new_levels: Vec = + levels.iter().map(|l| subst_level(l, params, values)).collect(); + let r = Expr::cnst(name.clone(), new_levels); + cache.insert(key, r.clone()); + results.push(r); + }, + ExprData::App(f, a, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::App); + work.push(Frame::Visit(a)); + work.push(Frame::Visit(f)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(t)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(t)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(v)); + work.push(Frame::Visit(t)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Proj(n.clone(), i.clone())); + 
work.push(Frame::Visit(s)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Mdata(kvs.clone())); + work.push(Frame::Visit(inner)); + }, + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => { + cache.insert(key, e.clone()); + results.push(e.clone()); + }, + } + }, + Frame::CacheResult(key) => { + let result = results.last().unwrap().clone(); + cache.insert(key, result); + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } /// Check if an expression has any loose bound variables above `offset`. 
@@ -222,40 +401,60 @@ pub fn has_loose_bvars(e: &Expr) -> bool { } fn has_loose_bvars_aux(e: &Expr, depth: u64) -> bool { - match e.as_data() { - ExprData::Bvar(idx, _) => idx.to_u64().unwrap_or(u64::MAX) >= depth, - ExprData::App(f, a, _) => { - has_loose_bvars_aux(f, depth) || has_loose_bvars_aux(a, depth) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - has_loose_bvars_aux(t, depth) || has_loose_bvars_aux(b, depth + 1) - }, - ExprData::LetE(_, t, v, b, _, _) => { - has_loose_bvars_aux(t, depth) - || has_loose_bvars_aux(v, depth) - || has_loose_bvars_aux(b, depth + 1) - }, - ExprData::Proj(_, _, s, _) => has_loose_bvars_aux(s, depth), - ExprData::Mdata(_, inner, _) => has_loose_bvars_aux(inner, depth), - _ => false, + let mut stack: Vec<(&Expr, u64)> = vec![(e, depth)]; + while let Some((e, depth)) = stack.pop() { + match e.as_data() { + ExprData::Bvar(idx, _) => { + if idx.to_u64().unwrap_or(u64::MAX) >= depth { + return true; + } + }, + ExprData::App(f, a, _) => { + stack.push((f, depth)); + stack.push((a, depth)); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push((t, depth)); + stack.push((b, depth + 1)); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push((t, depth)); + stack.push((v, depth)); + stack.push((b, depth + 1)); + }, + ExprData::Proj(_, _, s, _) => stack.push((s, depth)), + ExprData::Mdata(_, inner, _) => stack.push((inner, depth)), + _ => {}, + } } + false } /// Check if expression contains any free variables (Fvar). pub fn has_fvars(e: &Expr) -> bool { - match e.as_data() { - ExprData::Fvar(..) 
=> true, - ExprData::App(f, a, _) => has_fvars(f) || has_fvars(a), - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - has_fvars(t) || has_fvars(b) - }, - ExprData::LetE(_, t, v, b, _, _) => { - has_fvars(t) || has_fvars(v) || has_fvars(b) - }, - ExprData::Proj(_, _, s, _) => has_fvars(s), - ExprData::Mdata(_, inner, _) => has_fvars(inner), - _ => false, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Fvar(..) => return true, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + _ => {}, + } } + false } // ============================================================================ @@ -277,16 +476,63 @@ pub(crate) fn mk_name2(a: &str, b: &str) -> Name { /// iota/quot/nat/projection, and uses DAG-level splicing for delta. pub fn whnf(e: &Expr, env: &Env) -> Expr { let mut dag = from_expr(e); - whnf_dag(&mut dag, env); + whnf_dag(&mut dag, env, false); + let result = to_expr(&dag); + free_dag(dag); + result +} + + + +/// WHNF without delta reduction (beta/zeta/iota/quot/nat/proj only). +/// Matches Lean 4's `whnf_core` used in `is_def_eq_core`. +pub fn whnf_no_delta(e: &Expr, env: &Env) -> Expr { + let mut dag = from_expr(e); + whnf_dag(&mut dag, env, true); let result = to_expr(&dag); free_dag(dag); result } + /// Trail-based WHNF on DAG. Walks down the App spine collecting a trail, /// then dispatches on the head node. -fn whnf_dag(dag: &mut DAG, env: &Env) { +/// When `no_delta` is true, skips delta (definition) unfolding. 
+pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { + use std::sync::atomic::{AtomicU64, Ordering}; + static WHNF_DEPTH: AtomicU64 = AtomicU64::new(0); + static WHNF_TOTAL: AtomicU64 = AtomicU64::new(0); + + let depth = WHNF_DEPTH.fetch_add(1, Ordering::Relaxed); + let total = WHNF_TOTAL.fetch_add(1, Ordering::Relaxed); + if depth > 50 || total % 10_000 == 0 { + eprintln!("[whnf_dag] depth={depth} total={total} no_delta={no_delta}"); + } + if depth > 200 { + eprintln!("[whnf_dag] DEPTH LIMIT depth={depth}, bailing"); + WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); + return; + } + + const WHNF_STEP_LIMIT: u64 = 100_000; + let mut steps: u64 = 0; + let whnf_done = |depth| { WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); }; loop { + steps += 1; + if steps > WHNF_STEP_LIMIT { + eprintln!("[whnf_dag] step limit exceeded ({steps}) depth={depth}"); + whnf_done(depth); + return; + } + if steps <= 5 || steps % 10_000 == 0 { + let head_variant = match dag.head { + DAGPtr::Var(_) => "Var", DAGPtr::Sort(_) => "Sort", DAGPtr::Cnst(_) => "Cnst", + DAGPtr::App(_) => "App", DAGPtr::Fun(_) => "Fun", DAGPtr::Pi(_) => "Pi", + DAGPtr::Let(_) => "Let", DAGPtr::Lit(_) => "Lit", DAGPtr::Proj(_) => "Proj", + DAGPtr::Lam(_) => "Lam", + }; + eprintln!("[whnf_dag] step={steps} head={head_variant} trail_build_start"); + } // Build trail of App nodes by walking down the fun chain let mut trail: Vec> = Vec::new(); let mut cursor = dag.head; @@ -295,12 +541,26 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { match cursor { DAGPtr::App(app) => { trail.push(app); + if trail.len() > 100_000 { + eprintln!("[whnf_dag] TRAIL OVERFLOW: trail.len()={} — possible App cycle!", trail.len()); + whnf_done(depth); return; + } cursor = unsafe { (*app.as_ptr()).fun }; }, _ => break, } } + if steps <= 5 || steps % 10_000 == 0 { + let cursor_variant = match cursor { + DAGPtr::Var(_) => "Var", DAGPtr::Sort(_) => "Sort", DAGPtr::Cnst(_) => "Cnst", + DAGPtr::App(_) => "App", DAGPtr::Fun(_) => "Fun", 
DAGPtr::Pi(_) => "Pi", + DAGPtr::Let(_) => "Let", DAGPtr::Lit(_) => "Lit", DAGPtr::Proj(_) => "Proj", + DAGPtr::Lam(_) => "Lam", + }; + eprintln!("[whnf_dag] step={steps} trail_len={} cursor={cursor_variant}", trail.len()); + } + match cursor { // Beta: Fun at head with args on trail DAGPtr::Fun(fun_ptr) if !trail.is_empty() => { @@ -320,23 +580,23 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { // Const: try iota, quot, nat, then delta DAGPtr::Cnst(_) => { - // Try iota, quot, nat at Expr level - if try_expr_reductions(dag, env) { + // Try iota, quot, nat + if try_dag_reductions(dag, env) { continue; } - // Try delta (definition unfolding) on DAG - if try_dag_delta(dag, &trail, env) { + // Try delta (definition unfolding) on DAG, unless no_delta + if !no_delta && try_dag_delta(dag, &trail, env) { continue; } - return; // stuck + whnf_done(depth); return; // stuck }, // Proj: try projection reduction (Expr-level fallback) DAGPtr::Proj(_) => { - if try_expr_reductions(dag, env) { + if try_dag_reductions(dag, env) { continue; } - return; // stuck + whnf_done(depth); return; // stuck }, // Sort: simplify level in place @@ -345,7 +605,7 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { let sort = &mut *sort_ptr.as_ptr(); sort.level = simplify(&sort.level); } - return; + whnf_done(depth); return; }, // Mdata: strip metadata (Expr-level fallback) @@ -353,15 +613,15 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { // Check if this is a Nat literal that could be a Nat.succ application // by trying Expr-level reductions (which handles nat ops) if !trail.is_empty() { - if try_expr_reductions(dag, env) { + if try_dag_reductions(dag, env) { continue; } } - return; + whnf_done(depth); return; }, // Everything else (Var, Pi, Lam without args, etc.): already WHNF - _ => return, + _ => { whnf_done(depth); return; }, } } } @@ -369,11 +629,7 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { /// Set the DAG head after a reduction step. /// If trail is empty, the result becomes the new head. 
/// If trail is non-empty, splice result into the innermost remaining App. -fn set_dag_head( - dag: &mut DAG, - result: DAGPtr, - trail: &[NonNull], -) { +fn set_dag_head(dag: &mut DAG, result: DAGPtr, trail: &[NonNull]) { if trail.is_empty() { dag.head = result; } else { @@ -384,138 +640,56 @@ fn set_dag_head( } } -/// Try iota/quot/nat/projection reductions at Expr level. -/// Converts current DAG to Expr, attempts reduction, converts back if -/// successful. -fn try_expr_reductions(dag: &mut DAG, env: &Env) -> bool { - let current_expr = to_expr(&DAG { head: dag.head }); - - let (head, args) = unfold_apps(¤t_expr); +/// Try iota/quot/nat/projection reductions directly on DAG. +fn try_dag_reductions(dag: &mut DAG, env: &Env) -> bool { + let (head, args) = dag_unfold_apps(dag.head); - let reduced = match head.as_data() { - ExprData::Const(name, levels, _) => { - // Try iota (recursor) reduction - if let Some(result) = try_reduce_rec(name, levels, &args, env) { + let reduced = match head { + DAGPtr::Cnst(cnst) => unsafe { + let cnst_ref = &*cnst.as_ptr(); + if let Some(result) = + try_reduce_rec_dag(&cnst_ref.name, &cnst_ref.levels, &args, env) + { Some(result) - } - // Try quotient reduction - else if let Some(result) = try_reduce_quot(name, &args, env) { + } else if let Some(result) = + try_reduce_quot_dag(&cnst_ref.name, &args, env) + { Some(result) - } - // Try nat reduction - else if let Some(result) = - try_reduce_nat(¤t_expr, env) + } else if let Some(result) = + try_reduce_native_dag(&cnst_ref.name, &args) + { + Some(result) + } else if let Some(result) = + try_reduce_nat_dag(&cnst_ref.name, &args, env) { Some(result) } else { None } }, - ExprData::Proj(type_name, idx, structure, _) => { - reduce_proj(type_name, idx, structure, env) - .map(|result| foldl_apps(result, args.into_iter())) - }, - ExprData::Mdata(_, inner, _) => { - Some(foldl_apps(inner.clone(), args.into_iter())) + DAGPtr::Proj(proj) => unsafe { + let proj_ref = &*proj.as_ptr(); + 
reduce_proj_dag(&proj_ref.type_name, &proj_ref.idx, proj_ref.expr, env) + .map(|result| dag_foldl_apps(result, &args)) }, _ => None, }; - if let Some(result_expr) = reduced { - let result_dag = from_expr(&result_expr); - dag.head = result_dag.head; + if let Some(result) = reduced { + dag.head = result; true } else { false } } -/// Try delta (definition) unfolding on DAG. -/// Looks up the constant, substitutes universe levels in the definition body, -/// converts it to a DAG, and splices it into the current DAG. -fn try_dag_delta( - dag: &mut DAG, - trail: &[NonNull], - env: &Env, -) -> bool { - // Extract constant info from head - let cnst_ref = match dag_head_past_trail(dag, trail) { - DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, - _ => return false, - }; - - let ci = match env.get(&cnst_ref.name) { - Some(c) => c, - None => return false, - }; - let (def_params, def_value) = match ci { - ConstantInfo::DefnInfo(d) - if d.hints != ReducibilityHints::Opaque => - { - (&d.cnst.level_params, &d.value) - }, - _ => return false, - }; - - if cnst_ref.levels.len() != def_params.len() { - return false; - } - - // Substitute levels at Expr level, then convert to DAG - let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); - let body_dag = from_expr(&val); - - // Splice body into the working DAG - set_dag_head(dag, body_dag.head, trail); - true -} - -/// Get the head node past the trail (the non-App node at the bottom). -fn dag_head_past_trail( - dag: &DAG, - trail: &[NonNull], -) -> DAGPtr { - if trail.is_empty() { - dag.head - } else { - unsafe { (*trail.last().unwrap().as_ptr()).fun } - } -} - -/// Try to unfold a definition at the head. 
-pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { - let (head, args) = unfold_apps(e); - let (name, levels) = match head.as_data() { - ExprData::Const(name, levels, _) => (name, levels), - _ => return None, - }; - - let ci = env.get(name)?; - let (def_params, def_value) = match ci { - ConstantInfo::DefnInfo(d) => { - if d.hints == ReducibilityHints::Opaque { - return None; - } - (&d.cnst.level_params, &d.value) - }, - _ => return None, - }; - - if levels.len() != def_params.len() { - return None; - } - - let val = subst_expr_levels(def_value, def_params, levels); - Some(foldl_apps(val, args.into_iter())) -} - -/// Try to reduce a recursor application (iota reduction). -fn try_reduce_rec( +/// Try to reduce a recursor application (iota reduction) on DAG. +fn try_reduce_rec_dag( name: &Name, levels: &[Level], - args: &[Expr], + args: &[DAGPtr], env: &Env, -) -> Option { +) -> Option { let ci = env.get(name)?; let rec = match ci { ConstantInfo::RecInfo(r) => r, @@ -529,150 +703,104 @@ fn try_reduce_rec( let major = args.get(major_idx)?; - // WHNF the major premise - let major_whnf = whnf(major, env); - - // Handle nat literal → constructor - let major_ctor = match major_whnf.as_data() { - ExprData::Lit(Literal::NatVal(n), _) => nat_lit_to_constructor(n), - _ => major_whnf.clone(), + // WHNF the major premise directly on the DAG + let mut major_dag = DAG { head: *major }; + whnf_dag(&mut major_dag, env, false); + + // Decompose the major premise into (ctor_head, ctor_args) at DAG level. + // Handle nat literal → constructor form as DAG nodes directly. 
+ let (ctor_head, ctor_args) = match major_dag.head { + DAGPtr::Lit(lit) => unsafe { + match &(*lit.as_ptr()).val { + Literal::NatVal(n) => { + if n.0 == BigUint::ZERO { + let zero = DAGPtr::Cnst(alloc_val(Cnst { + name: mk_name2("Nat", "zero"), + levels: vec![], + parents: None, + })); + (zero, vec![]) + } else { + let pred = Nat(n.0.clone() - BigUint::from(1u64)); + let succ = DAGPtr::Cnst(alloc_val(Cnst { + name: mk_name2("Nat", "succ"), + levels: vec![], + parents: None, + })); + let pred_lit = nat_lit_dag(pred); + (succ, vec![pred_lit]) + } + }, + _ => return None, + } + }, + _ => dag_unfold_apps(major_dag.head), }; - let (ctor_head, ctor_args) = unfold_apps(&major_ctor); - - // Find the matching rec rule - let ctor_name = match ctor_head.as_data() { - ExprData::Const(name, _, _) => name, + // Find the matching rec rule by reading ctor name from DAG head + let ctor_name = match ctor_head { + DAGPtr::Cnst(cnst) => unsafe { &(*cnst.as_ptr()).name }, _ => return None, }; - let rule = rec.rules.iter().find(|r| &r.ctor == ctor_name)?; + let rule = rec.rules.iter().find(|r| r.ctor == *ctor_name)?; let n_fields = rule.n_fields.to_u64().unwrap() as usize; let num_params = rec.num_params.to_u64().unwrap() as usize; let num_motives = rec.num_motives.to_u64().unwrap() as usize; let num_minors = rec.num_minors.to_u64().unwrap() as usize; - // The constructor args may have extra params for nested inductives - let ctor_args_wo_params = - if ctor_args.len() >= n_fields { - &ctor_args[ctor_args.len() - n_fields..] 
- } else { - return None; - }; - - // Substitute universe levels in the rule's RHS - let rhs = subst_expr_levels( - &rule.rhs, - &rec.cnst.level_params, - levels, - ); - - // Apply: params, motives, minors - let prefix_count = num_params + num_motives + num_minors; - let mut result = rhs; - for arg in args.iter().take(prefix_count) { - result = Expr::app(result, arg.clone()); - } - - // Apply constructor fields - for arg in ctor_args_wo_params { - result = Expr::app(result, arg.clone()); - } - - // Apply remaining args after major - for arg in args.iter().skip(major_idx + 1) { - result = Expr::app(result, arg.clone()); + if ctor_args.len() < n_fields { + return None; } + let ctor_fields = &ctor_args[ctor_args.len() - n_fields..]; - Some(result) -} - -/// Convert a Nat literal to its constructor form. -fn nat_lit_to_constructor(n: &Nat) -> Expr { - if n.0 == BigUint::ZERO { - Expr::cnst(mk_name2("Nat", "zero"), vec![]) - } else { - let pred = Nat(n.0.clone() - BigUint::from(1u64)); - let pred_expr = Expr::lit(Literal::NatVal(pred)); - Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), pred_expr) - } -} + // Build RHS as DAG: from_expr(subst_expr_levels(rule.rhs, ...)) once + // (unavoidable — rule RHS is stored as Expr in Env) + let rhs_expr = subst_expr_levels(&rule.rhs, &rec.cnst.level_params, levels); + let rhs_dag = from_expr(&rhs_expr); -/// Convert a string literal to its constructor form: -/// `"hello"` → `String.mk (List.cons 'h' (List.cons 'e' ... List.nil))` -/// where chars are represented as `Char.ofNat n`. 
-fn string_lit_to_constructor(s: &str) -> Expr { - let list_name = Name::str(Name::anon(), "List".into()); - let char_name = Name::str(Name::anon(), "Char".into()); - let char_type = Expr::cnst(char_name.clone(), vec![]); - - // Build the list from right to left - // List.nil.{0} : List Char - let nil = Expr::app( - Expr::cnst( - Name::str(list_name.clone(), "nil".into()), - vec![Level::succ(Level::zero())], - ), - char_type.clone(), - ); - - let result = s.chars().rev().fold(nil, |acc, c| { - let char_val = Expr::app( - Expr::cnst(Name::str(char_name.clone(), "ofNat".into()), vec![]), - Expr::lit(Literal::NatVal(Nat::from(c as u64))), - ); - // List.cons.{0} Char char_val acc - Expr::app( - Expr::app( - Expr::app( - Expr::cnst( - Name::str(list_name.clone(), "cons".into()), - vec![Level::succ(Level::zero())], - ), - char_type.clone(), - ), - char_val, - ), - acc, - ) - }); + // Collect all args at DAG level: params+motives+minors, ctor_fields, rest + let prefix_count = num_params + num_motives + num_minors; + let mut all_args: Vec = + Vec::with_capacity(prefix_count + n_fields + args.len() - major_idx - 1); + all_args.extend_from_slice(&args[..prefix_count]); + all_args.extend_from_slice(ctor_fields); + all_args.extend_from_slice(&args[major_idx + 1..]); - // String.mk list - Expr::app( - Expr::cnst( - Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), - vec![], - ), - result, - ) + Some(dag_foldl_apps(rhs_dag.head, &all_args)) } -/// Try to reduce a projection. -fn reduce_proj( +/// Try to reduce a projection on DAG. 
+fn reduce_proj_dag( _type_name: &Name, idx: &Nat, - structure: &Expr, + structure: DAGPtr, env: &Env, -) -> Option { - let structure_whnf = whnf(structure, env); - - // Handle string literal → constructor - let structure_ctor = match structure_whnf.as_data() { - ExprData::Lit(Literal::StrVal(s), _) => { - string_lit_to_constructor(s) +) -> Option { + // WHNF the structure directly on the DAG + let mut struct_dag = DAG { head: structure }; + whnf_dag(&mut struct_dag, env, false); + + // Handle string literal → constructor form at DAG level + let struct_whnf = match struct_dag.head { + DAGPtr::Lit(lit) => unsafe { + match &(*lit.as_ptr()).val { + Literal::StrVal(s) => string_lit_to_dag_ctor(s), + _ => struct_dag.head, + } }, - _ => structure_whnf, + _ => struct_dag.head, }; - let (ctor_head, ctor_args) = unfold_apps(&structure_ctor); + // Decompose at DAG level + let (ctor_head, ctor_args) = dag_unfold_apps(struct_whnf); - let ctor_name = match ctor_head.as_data() { - ExprData::Const(name, _, _) => name, + let ctor_name = match ctor_head { + DAGPtr::Cnst(cnst) => unsafe { &(*cnst.as_ptr()).name }, _ => return None, }; - // Look up constructor to get num_params let ci = env.get(ctor_name)?; let num_params = match ci { ConstantInfo::CtorInfo(c) => c.num_params.to_u64().unwrap() as usize, @@ -680,15 +808,15 @@ fn reduce_proj( }; let field_idx = num_params + idx.to_u64().unwrap() as usize; - ctor_args.get(field_idx).cloned() + ctor_args.get(field_idx).copied() } -/// Try to reduce a quotient operation. -fn try_reduce_quot( +/// Try to reduce a quotient operation on DAG. 
+fn try_reduce_quot_dag( name: &Name, - args: &[Expr], + args: &[DAGPtr], env: &Env, -) -> Option { +) -> Option { let ci = env.get(name)?; let kind = match ci { ConstantInfo::QuotInfo(q) => &q.kind, @@ -702,33 +830,304 @@ fn try_reduce_quot( }; let qmk = args.get(qmk_idx)?; - let qmk_whnf = whnf(qmk, env); - // Check that the head is Quot.mk - let (qmk_head, _) = unfold_apps(&qmk_whnf); - match qmk_head.as_data() { - ExprData::Const(n, _, _) if *n == mk_name2("Quot", "mk") => {}, + // WHNF the Quot.mk arg directly on the DAG + let mut qmk_dag = DAG { head: *qmk }; + whnf_dag(&mut qmk_dag, env, false); + + // Check that the head is Quot.mk at DAG level + let (qmk_head, _) = dag_unfold_apps(qmk_dag.head); + match qmk_head { + DAGPtr::Cnst(cnst) => unsafe { + if (*cnst.as_ptr()).name != mk_name2("Quot", "mk") { + return None; + } + }, _ => return None, } let f = args.get(3)?; - // Extract the argument of Quot.mk - let qmk_arg = match qmk_whnf.as_data() { - ExprData::App(_, arg, _) => arg, + // Extract the argument of Quot.mk (the outermost App's arg) + let qmk_arg = match qmk_dag.head { + DAGPtr::App(app) => unsafe { (*app.as_ptr()).arg }, _ => return None, }; - let mut result = Expr::app(f.clone(), qmk_arg.clone()); - for arg in args.iter().skip(rest_idx) { - result = Expr::app(result, arg.clone()); + // Build result directly at DAG level: f qmk_arg rest_args... + let mut result_args = Vec::with_capacity(1 + args.len() - rest_idx); + result_args.push(qmk_arg); + result_args.extend_from_slice(&args[rest_idx..]); + Some(dag_foldl_apps(*f, &result_args)) +} + +/// Try to reduce `Lean.reduceBool` / `Lean.reduceNat` on DAG. 
+pub(crate) fn try_reduce_native_dag(name: &Name, args: &[DAGPtr]) -> Option { + if args.len() != 1 { + return None; + } + let reduce_bool = mk_name2("Lean", "reduceBool"); + let reduce_nat = mk_name2("Lean", "reduceNat"); + if *name == reduce_bool || *name == reduce_nat { + Some(args[0]) + } else { + None } +} - Some(result) +/// Try to reduce nat operations on DAG. +pub(crate) fn try_reduce_nat_dag( + name: &Name, + args: &[DAGPtr], + env: &Env, +) -> Option { + match args.len() { + 1 => { + if *name == mk_name2("Nat", "succ") { + // WHNF the arg directly on the DAG + let mut arg_dag = DAG { head: args[0] }; + whnf_dag(&mut arg_dag, env, false); + let n = get_nat_value_dag(arg_dag.head)?; + let result = alloc_val(LitNode { + val: Literal::NatVal(Nat(n + BigUint::from(1u64))), + parents: None, + }); + Some(DAGPtr::Lit(result)) + } else { + None + } + }, + 2 => { + // WHNF both args directly on the DAG + let mut a_dag = DAG { head: args[0] }; + whnf_dag(&mut a_dag, env, false); + let mut b_dag = DAG { head: args[1] }; + whnf_dag(&mut b_dag, env, false); + let a = get_nat_value_dag(a_dag.head)?; + let b = get_nat_value_dag(b_dag.head)?; + + if *name == mk_name2("Nat", "add") { + Some(nat_lit_dag(Nat(a + b))) + } else if *name == mk_name2("Nat", "sub") { + Some(nat_lit_dag(Nat(if a >= b { a - b } else { BigUint::ZERO }))) + } else if *name == mk_name2("Nat", "mul") { + Some(nat_lit_dag(Nat(a * b))) + } else if *name == mk_name2("Nat", "div") { + Some(nat_lit_dag(Nat(if b == BigUint::ZERO { + BigUint::ZERO + } else { + a / b + }))) + } else if *name == mk_name2("Nat", "mod") { + Some(nat_lit_dag(Nat(if b == BigUint::ZERO { a } else { a % b }))) + } else if *name == mk_name2("Nat", "beq") { + Some(bool_to_dag(a == b)) + } else if *name == mk_name2("Nat", "ble") { + Some(bool_to_dag(a <= b)) + } else if *name == mk_name2("Nat", "pow") { + let exp = u32::try_from(&b).unwrap_or(u32::MAX); + Some(nat_lit_dag(Nat(a.pow(exp)))) + } else if *name == mk_name2("Nat", "land") { 
+ Some(nat_lit_dag(Nat(a & b))) + } else if *name == mk_name2("Nat", "lor") { + Some(nat_lit_dag(Nat(a | b))) + } else if *name == mk_name2("Nat", "xor") { + Some(nat_lit_dag(Nat(a ^ b))) + } else if *name == mk_name2("Nat", "shiftLeft") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(nat_lit_dag(Nat(a << shift))) + } else if *name == mk_name2("Nat", "shiftRight") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(nat_lit_dag(Nat(a >> shift))) + } else if *name == mk_name2("Nat", "blt") { + Some(bool_to_dag(a < b)) + } else { + None + } + }, + _ => None, + } +} + +/// Extract a nat value from a DAGPtr (analog of get_nat_value_expr). +fn get_nat_value_dag(ptr: DAGPtr) -> Option { + unsafe { + match ptr { + DAGPtr::Lit(lit) => match &(*lit.as_ptr()).val { + Literal::NatVal(n) => Some(n.0.clone()), + _ => None, + }, + DAGPtr::Cnst(cnst) => { + if (*cnst.as_ptr()).name == mk_name2("Nat", "zero") { + Some(BigUint::ZERO) + } else { + None + } + }, + _ => None, + } + } +} + +/// Allocate a Nat literal DAG node. +pub(crate) fn nat_lit_dag(n: Nat) -> DAGPtr { + DAGPtr::Lit(alloc_val(LitNode { val: Literal::NatVal(n), parents: None })) +} + +/// Convert a bool to a DAG constant (Bool.true / Bool.false). +fn bool_to_dag(b: bool) -> DAGPtr { + let name = + if b { mk_name2("Bool", "true") } else { mk_name2("Bool", "false") }; + DAGPtr::Cnst(alloc_val(Cnst { name, levels: vec![], parents: None })) +} + +/// Build `String.mk (List.cons (Char.ofNat n1) (List.cons ... List.nil))` +/// entirely at the DAG level (no Expr round-trip). 
+fn string_lit_to_dag_ctor(s: &str) -> DAGPtr { + let list_name = Name::str(Name::anon(), "List".into()); + let char_name = Name::str(Name::anon(), "Char".into()); + let char_type = DAGPtr::Cnst(alloc_val(Cnst { + name: char_name.clone(), + levels: vec![], + parents: None, + })); + let nil = DAGPtr::App(alloc_app( + DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(list_name.clone(), "nil".into()), + levels: vec![Level::succ(Level::zero())], + parents: None, + })), + char_type, + None, + )); + let list = s.chars().rev().fold(nil, |acc, c| { + let of_nat = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(char_name.clone(), "ofNat".into()), + levels: vec![], + parents: None, + })); + let char_val = + DAGPtr::App(alloc_app(of_nat, nat_lit_dag(Nat::from(c as u64)), None)); + let char_type_copy = DAGPtr::Cnst(alloc_val(Cnst { + name: char_name.clone(), + levels: vec![], + parents: None, + })); + let cons = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(list_name.clone(), "cons".into()), + levels: vec![Level::succ(Level::zero())], + parents: None, + })); + let c1 = DAGPtr::App(alloc_app(cons, char_type_copy, None)); + let c2 = DAGPtr::App(alloc_app(c1, char_val, None)); + DAGPtr::App(alloc_app(c2, acc, None)) + }); + let string_mk = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), + levels: vec![], + parents: None, + })); + DAGPtr::App(alloc_app(string_mk, list, None)) +} + +/// Try delta (definition) unfolding on DAG. +/// Looks up the constant, substitutes universe levels in the definition body, +/// converts it to a DAG, and splices it into the current DAG. 
+fn try_dag_delta(dag: &mut DAG, trail: &[NonNull], env: &Env) -> bool { + // Extract constant info from head + let cnst_ref = match dag_head_past_trail(dag, trail) { + DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, + _ => return false, + }; + + let ci = match env.get(&cnst_ref.name) { + Some(c) => c, + None => return false, + }; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) if d.hints != ReducibilityHints::Opaque => { + (&d.cnst.level_params, &d.value) + }, + _ => return false, + }; + + if cnst_ref.levels.len() != def_params.len() { + return false; + } + + eprintln!("[try_dag_delta] unfolding: {}", cnst_ref.name.pretty()); + + // Substitute levels at Expr level, then convert to DAG + let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); + eprintln!("[try_dag_delta] subst done, calling from_expr"); + let body_dag = from_expr(&val); + eprintln!("[try_dag_delta] from_expr done, calling set_dag_head"); + + // Splice body into the working DAG + set_dag_head(dag, body_dag.head, trail); + eprintln!("[try_dag_delta] set_dag_head done"); + true +} + +/// Get the head node past the trail (the non-App node at the bottom). +fn dag_head_past_trail(dag: &DAG, trail: &[NonNull]) -> DAGPtr { + if trail.is_empty() { + dag.head + } else { + unsafe { (*trail.last().unwrap().as_ptr()).fun } + } +} + +/// Try to unfold a definition at the head. 
+pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { + let (head, args) = unfold_apps(e); + let (name, levels) = match head.as_data() { + ExprData::Const(name, levels, _) => (name, levels), + _ => return None, + }; + + let ci = env.get(name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + _ => return None, + }; + + if levels.len() != def_params.len() { + return None; + } + + let val = subst_expr_levels(def_value, def_params, levels); + Some(foldl_apps(val, args.into_iter())) +} + +/// Try to reduce `Lean.reduceBool` / `Lean.reduceNat`. +/// +/// These are opaque constants with special kernel reduction rules. In the Lean 4 +/// kernel they evaluate their argument using compiled native code. Since both are +/// semantically identity functions (`fun b => b` / `fun n => n`), we simply +/// return the argument and let the WHNF loop continue reducing it via our +/// existing efficient paths (e.g. `try_reduce_nat` handles `Nat.ble` etc. in O(1)). +pub(crate) fn try_reduce_native(name: &Name, args: &[Expr]) -> Option { + if args.len() != 1 { + return None; + } + let reduce_bool = mk_name2("Lean", "reduceBool"); + let reduce_nat = mk_name2("Lean", "reduceNat"); + if *name == reduce_bool || *name == reduce_nat { + Some(args[0].clone()) + } else { + None + } } /// Try to reduce nat operations. 
-fn try_reduce_nat(e: &Expr, env: &Env) -> Option { +pub(crate) fn try_reduce_nat(e: &Expr, env: &Env) -> Option { if has_fvars(e) { return None; } @@ -818,11 +1217,8 @@ fn get_nat_value(e: &Expr) -> Option { } fn bool_to_expr(b: bool) -> Option { - let name = if b { - mk_name2("Bool", "true") - } else { - mk_name2("Bool", "false") - }; + let name = + if b { mk_name2("Bool", "true") } else { mk_name2("Bool", "false") }; Some(Expr::cnst(name, vec![])) } @@ -865,12 +1261,8 @@ mod tests { BinderInfo::Default, ); let result = inst(&body, &[nat_zero()]); - let expected = Expr::lam( - Name::anon(), - nat_type(), - nat_zero(), - BinderInfo::Default, - ); + let expected = + Expr::lam(Name::anon(), nat_type(), nat_zero(), BinderInfo::Default); assert_eq!(result, expected); } @@ -927,11 +1319,7 @@ mod tests { env.insert( n.clone(), ConstantInfo::DefnInfo(DefinitionVal { - cnst: ConstantVal { - name: n.clone(), - level_params: vec![], - typ, - }, + cnst: ConstantVal { name: n.clone(), level_params: vec![], typ }, value, hints: ReducibilityHints::Abbrev, safety: DefinitionSafety::Safe, @@ -1198,7 +1586,10 @@ mod tests { fn test_nat_shift_right() { let env = Env::default(); let e = Expr::app( - Expr::app(Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), nat_lit(256)), + Expr::app( + Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), + nat_lit(256), + ), nat_lit(4), ); assert_eq!(whnf(&e, &env), nat_lit(16)); @@ -1336,12 +1727,8 @@ mod tests { #[test] fn test_whnf_pi_unchanged() { let env = Env::default(); - let e = Expr::all( - mk_name("x"), - nat_type(), - nat_type(), - BinderInfo::Default, - ); + let e = + Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default); let result = whnf(&e, &env); assert_eq!(result, e); } @@ -1417,4 +1804,371 @@ mod tests { let result = subst_expr_levels(&e, &[u_name], &[Level::zero()]); assert_eq!(result, Expr::sort(Level::zero())); } + + // ========================================================================== + // Nat.rec on 
large literals — reproduces the hang + // ========================================================================== + + /// Build a minimal env with Nat, Nat.zero, Nat.succ, and Nat.rec. + fn mk_nat_rec_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + let zero_name = mk_name2("Nat", "zero"); + let succ_name = mk_name2("Nat", "succ"); + let rec_name = mk_name2("Nat", "rec"); + + // Nat : Sort 1 + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![zero_name.clone(), succ_name.clone()], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // Nat.zero : Nat + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: nat_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + // Nat.succ : Nat → Nat + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + }, + induct: nat_name.clone(), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Nat.rec.{u} : (motive : Nat → Sort u) → motive Nat.zero → + // ((n : Nat) → motive n → motive (Nat.succ n)) → (t : Nat) → motive t + // Rules: + // Nat.rec m z s Nat.zero => z + // Nat.rec m z s (Nat.succ n) => s n (Nat.rec m z s n) + let u = mk_name("u"); + env.insert( + rec_name.clone(), + ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + 
name: rec_name.clone(), + level_params: vec![u.clone()], + typ: Expr::sort(Level::param(u.clone())), // placeholder + }, + all: vec![nat_name], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + // Nat.rec m z s Nat.zero => z + RecursorRule { + ctor: zero_name, + n_fields: Nat::from(0u64), + // RHS is just bvar(1) = z (the zero minor) + // After substitution: Nat.rec m z s Nat.zero + // => rule.rhs applied to [m, z, s] + // => z + rhs: Expr::bvar(Nat::from(1u64)), + }, + // Nat.rec m z s (Nat.succ n) => s n (Nat.rec m z s n) + RecursorRule { + ctor: succ_name, + n_fields: Nat::from(1u64), + // RHS = fun n => s n (Nat.rec m z s n) + // But actually the rule rhs receives [m, z, s] then [n] as args + // rhs = bvar(0) = s, applied to the field n + // Actually the recursor rule rhs is applied as: + // rhs m z s + // For Nat.succ with 1 field (the predecessor n): + // rhs m z s n => s n (Nat.rec.{u} m z s n) + // So rhs = lam receiving params+minors then fields: + // Actually, rhs is an expression that gets applied to + // [params..., motives..., minors..., fields...] + // For Nat.rec: 0 params, 1 motive, 2 minors, 1 field + // So rhs gets applied to: m z s n + // We want: s n (Nat.rec.{u} m z s n) + // As a closed term using bvars after inst: + // After being applied to m z s n: + // bvar(3) = m, bvar(2) = z, bvar(1) = s, bvar(0) = n + // We want: s n (Nat.rec.{u} m z s n) + // = app(app(bvar(1), bvar(0)), + // app(app(app(app(Nat.rec.{u}, bvar(3)), bvar(2)), bvar(1)), bvar(0))) + // But wait, rhs is not a lambda - it gets args applied directly. + // The rhs just receives the args via Expr::app in try_reduce_rec. + // So rhs should be a term that, after being applied to m, z, s, n, + // produces s n (Nat.rec m z s n). 
+ // + // Simplest: rhs is a 4-arg lambda + rhs: Expr::lam( + mk_name("m"), + Expr::sort(Level::zero()), // placeholder type + Expr::lam( + mk_name("z"), + Expr::sort(Level::zero()), + Expr::lam( + mk_name("s"), + Expr::sort(Level::zero()), + Expr::lam( + mk_name("n"), + nat_type(), + // body: s n (Nat.rec.{u} m z s n) + // bvar(3)=m, bvar(2)=z, bvar(1)=s, bvar(0)=n + Expr::app( + Expr::app( + Expr::bvar(Nat::from(1u64)), // s + Expr::bvar(Nat::from(0u64)), // n + ), + Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + rec_name.clone(), + vec![Level::param(u.clone())], + ), + Expr::bvar(Nat::from(3u64)), // m + ), + Expr::bvar(Nat::from(2u64)), // z + ), + Expr::bvar(Nat::from(1u64)), // s + ), + Expr::bvar(Nat::from(0u64)), // n + ), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + }, + ], + k: false, + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn test_nat_rec_small_literal() { + // Nat.rec (fun _ => Nat) 0 (fun n _ => Nat.succ n) 3 + // Should reduce to 3 (identity via recursion) + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive, + ), + zero_case, + ), + succ_case, + ), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(3)); + } + + #[test] + fn test_nat_rec_large_literal_hangs() { + // This test demonstrates the O(n) recursor peeling issue. + // Nat.rec on 65536 (2^16) — would take 65536 recursive steps. 
+ // We use a timeout-style approach: just verify it works for small n + // and document that large n hangs. + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + // Test with 100 — should be fast enough + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive.clone(), + ), + zero_case.clone(), + ), + succ_case.clone(), + ), + nat_lit(100), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(100)); + + // nat_lit(65536) would hang here — that's the bug to fix + } + + // ========================================================================== + // try_reduce_native tests (Lean.reduceBool / Lean.reduceNat) + // ========================================================================== + + #[test] + fn test_reduce_bool_true() { + // Lean.reduceBool Bool.true → Bool.true + let args = vec![Expr::cnst(mk_name2("Bool", "true"), vec![])]; + let result = try_reduce_native(&mk_name2("Lean", "reduceBool"), &args); + assert_eq!(result, Some(Expr::cnst(mk_name2("Bool", "true"), vec![]))); + } + + #[test] + fn test_reduce_nat_literal() { + // Lean.reduceNat (lit 42) → lit 42 + let args = vec![nat_lit(42)]; + let result = try_reduce_native(&mk_name2("Lean", "reduceNat"), &args); + assert_eq!(result, Some(nat_lit(42))); + } + + #[test] + fn test_reduce_bool_with_nat_ble() { + // Lean.reduceBool (Nat.ble 3 5) → passes through the arg + // WHNF will then reduce Nat.ble 3 5 → Bool.true + let ble_expr = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(3)), + nat_lit(5), + ); + let args = vec![ble_expr.clone()]; 
+ let result = try_reduce_native(&mk_name2("Lean", "reduceBool"), &args); + assert_eq!(result, Some(ble_expr)); + + // Verify WHNF continues reducing the returned argument + let env = Env::default(); + let full_result = whnf(&result.unwrap(), &env); + assert_eq!(full_result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_reduce_native_wrong_name() { + let args = vec![nat_lit(1)]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "other"), &args), None); + } + + #[test] + fn test_reduce_native_wrong_arity() { + // 0 args + let empty: Vec = vec![]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "reduceBool"), &empty), None); + // 2 args + let two = vec![nat_lit(1), nat_lit(2)]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "reduceBool"), &two), None); + } + + #[test] + fn test_nat_rec_65536() { + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive, + ), + zero_case, + ), + succ_case, + ), + nat_lit(65536), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(65536)); + } } diff --git a/src/lean/ffi.rs b/src/lean/ffi.rs index 07003a57..40553a06 100644 --- a/src/lean/ffi.rs +++ b/src/lean/ffi.rs @@ -6,6 +6,7 @@ pub mod lean_env; // Modular FFI structure pub mod builder; // IxEnvBuilder struct +pub mod check; // Kernel type-checking: rs_check_env pub mod compile; // Compilation: rs_compile_env_full, rs_compile_phases, etc. 
pub mod graph; // Graph/SCC: rs_build_ref_graph, rs_compute_sccs pub mod ix; // Ix types: Name, Level, Expr, ConstantInfo, Environment diff --git a/src/lean/ffi/check.rs b/src/lean/ffi/check.rs new file mode 100644 index 00000000..01e69cc7 --- /dev/null +++ b/src/lean/ffi/check.rs @@ -0,0 +1,182 @@ +//! FFI bridge for the Rust kernel type-checker. +//! +//! Provides `extern "C"` function callable from Lean via `@[extern]`: +//! - `rs_check_env`: type-check all declarations in a Lean environment + +use std::ffi::{CString, c_void}; + +use super::builder::LeanBuildCache; +use super::ffi_io_guard; +use super::ix::expr::build_expr; +use super::ix::name::build_name; +use super::lean_env::lean_ptr_to_env; +use crate::ix::env::{ConstantInfo, Name}; +use crate::ix::kernel::dag_tc::{DagTypeChecker, dag_check_env}; +use crate::ix::kernel::error::TcError; +use crate::lean::string::LeanStringObject; +use crate::lean::{ + as_ref_unsafe, lean_alloc_array, lean_alloc_ctor, lean_array_set_core, + lean_ctor_set, lean_ctor_set_uint64, lean_io_result_mk_ok, lean_mk_string, +}; + +/// Build a Lean `Ix.Kernel.CheckError` constructor from a Rust `TcError`. 
+/// +/// Constructor tags (must match the Lean `inductive CheckError`): +/// - 0: typeExpected (2 obj: expr, inferred) +/// - 1: functionExpected (2 obj: expr, inferred) +/// - 2: typeMismatch (3 obj: expected, found, expr) +/// - 3: defEqFailure (2 obj: lhs, rhs) +/// - 4: unknownConst (1 obj: name) +/// - 5: duplicateUniverse (1 obj: name) +/// - 6: freeBoundVariable (0 obj + 8 byte scalar: idx) +/// - 7: kernelException (1 obj: msg) +unsafe fn build_check_error( + cache: &mut LeanBuildCache, + err: &TcError, +) -> *mut c_void { + unsafe { + match err { + TcError::TypeExpected { expr, inferred } => { + let obj = lean_alloc_ctor(0, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, expr)); + lean_ctor_set(obj, 1, build_expr(cache, inferred)); + obj + }, + TcError::FunctionExpected { expr, inferred } => { + let obj = lean_alloc_ctor(1, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, expr)); + lean_ctor_set(obj, 1, build_expr(cache, inferred)); + obj + }, + TcError::TypeMismatch { expected, found, expr } => { + let obj = lean_alloc_ctor(2, 3, 0); + lean_ctor_set(obj, 0, build_expr(cache, expected)); + lean_ctor_set(obj, 1, build_expr(cache, found)); + lean_ctor_set(obj, 2, build_expr(cache, expr)); + obj + }, + TcError::DefEqFailure { lhs, rhs } => { + let obj = lean_alloc_ctor(3, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, lhs)); + lean_ctor_set(obj, 1, build_expr(cache, rhs)); + obj + }, + TcError::UnknownConst { name } => { + let obj = lean_alloc_ctor(4, 1, 0); + lean_ctor_set(obj, 0, build_name(cache, name)); + obj + }, + TcError::DuplicateUniverse { name } => { + let obj = lean_alloc_ctor(5, 1, 0); + lean_ctor_set(obj, 0, build_name(cache, name)); + obj + }, + TcError::FreeBoundVariable { idx } => { + let obj = lean_alloc_ctor(6, 0, 8); + lean_ctor_set_uint64(obj, 0, *idx); + obj + }, + TcError::KernelException { msg } => { + let c_msg = CString::new(msg.as_str()) + .unwrap_or_else(|_| CString::new("kernel exception").unwrap()); + let obj = 
lean_alloc_ctor(7, 1, 0); + lean_ctor_set(obj, 0, lean_mk_string(c_msg.as_ptr())); + obj + }, + } + } +} + +/// FFI function to type-check all declarations in a Lean environment using the +/// Rust kernel. Returns `IO (Array (Ix.Name × CheckError))`. +#[unsafe(no_mangle)] +pub extern "C" fn rs_check_env(env_consts_ptr: *const c_void) -> *mut c_void { + ffi_io_guard(std::panic::AssertUnwindSafe(|| { + let rust_env = lean_ptr_to_env(env_consts_ptr); + let errors = dag_check_env(&rust_env); + let mut cache = LeanBuildCache::new(); + unsafe { + let arr = lean_alloc_array(errors.len(), errors.len()); + for (i, (name, tc_err)) in errors.iter().enumerate() { + let name_obj = build_name(&mut cache, name); + let err_obj = build_check_error(&mut cache, tc_err); + let pair = lean_alloc_ctor(0, 2, 0); // Prod.mk + lean_ctor_set(pair, 0, name_obj); + lean_ctor_set(pair, 1, err_obj); + lean_array_set_core(arr, i, pair); + } + lean_io_result_mk_ok(arr) + } + })) +} + +/// Parse a dotted name string (e.g. "ISize.toInt16_ofIntLE") into a `Name`. +fn parse_name(s: &str) -> Name { + let mut name = Name::anon(); + for part in s.split('.') { + name = Name::str(name, part.to_string()); + } + name +} + +/// FFI function to type-check a single constant by name. +/// Takes the environment and a dotted name string. +/// Returns `IO (Option CheckError)` — `none` on success, `some err` on failure. 
+#[unsafe(no_mangle)] +pub extern "C" fn rs_check_const( + env_consts_ptr: *const c_void, + name_ptr: *const c_void, +) -> *mut c_void { + ffi_io_guard(std::panic::AssertUnwindSafe(|| { + eprintln!("[rs_check_const] entered FFI"); + let rust_env = lean_ptr_to_env(env_consts_ptr); + let name_str: &LeanStringObject = as_ref_unsafe(name_ptr.cast()); + let name = parse_name(&name_str.as_string()); + eprintln!("[rs_check_const] checking: {}", name.pretty()); + + let ci = match rust_env.get(&name) { + Some(ci) => { + match ci { + ConstantInfo::DefnInfo(d) => { + eprintln!("[rs_check_const] type: {:#?}", d.cnst.typ); + eprintln!("[rs_check_const] value: {:#?}", d.value); + eprintln!("[rs_check_const] hints: {:?}", d.hints); + }, + _ => {}, + } + ci + }, + None => { + // Return some (kernelException "not found") + let err = TcError::KernelException { + msg: format!("constant not found: {}", name.pretty()), + }; + let mut cache = LeanBuildCache::new(); + unsafe { + let err_obj = build_check_error(&mut cache, &err); + let some = lean_alloc_ctor(1, 1, 0); // Option.some + lean_ctor_set(some, 0, err_obj); + return lean_io_result_mk_ok(some); + } + }, + }; + + let mut tc = DagTypeChecker::new(&rust_env); + match tc.check_declar(ci) { + Ok(()) => unsafe { + // Option.none = ctor tag 0, 0 fields + let none = lean_alloc_ctor(0, 0, 0); + lean_io_result_mk_ok(none) + }, + Err(e) => { + let mut cache = LeanBuildCache::new(); + unsafe { + let err_obj = build_check_error(&mut cache, &e); + let some = lean_alloc_ctor(1, 1, 0); // Option.some + lean_ctor_set(some, 0, err_obj); + lean_io_result_mk_ok(some) + } + }, + } + })) +} diff --git a/src/lean/ffi/lean_env.rs b/src/lean/ffi/lean_env.rs index 3817e0e4..2562cd94 100644 --- a/src/lean/ffi/lean_env.rs +++ b/src/lean/ffi/lean_env.rs @@ -852,8 +852,10 @@ fn analyze_const_size(stt: &crate::ix::compile::CompileState, name_str: &str) { // BFS through all transitive dependencies while let Some(dep_addr) = queue.pop_front() { if let 
Some(dep_const) = stt.env.consts.get(&dep_addr) { - // Get the name for this dependency - let dep_name_opt = stt.env.get_name_by_addr(&dep_addr); + // Get the name for this dependency (linear scan through named entries) + let dep_name_opt = stt.env.named.iter() + .find(|entry| entry.value().addr == dep_addr) + .map(|entry| entry.key().clone()); let dep_name_str = dep_name_opt .as_ref() .map_or_else(|| format!("{:?}", dep_addr), |n| n.pretty()); From ff923998e917f5d12c67f892c133a27ae3a2d875 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 08:13:47 -0500 Subject: [PATCH 03/14] reenable printing type of erroring constants --- Ix/Kernel/Infer.lean | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 1d0b0159..0c161539 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -387,12 +387,7 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let total := items.size for h : idx in [:total] do let (addr, ci) := items[idx] - --let typ := ci.type.pp - --let val := match ci.value? with - -- | some v => s!"\n value: {v.pp}" - -- | none => "" - let (typ, val) := ("_", "_") - (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})\n type: {typ}{val}" + (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})" (← IO.getStdout).flush match typecheckConst kenv prims addr quotInit with | .ok () => @@ -400,6 +395,10 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) (← IO.getStdout).flush | .error e => let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})" + let typ := ci.type.pp + let val := match ci.value? with + | some v => s!"\n value: {v.pp}" + | none => "" return .error s!"{header}: {e}\n type: {typ}{val}" return .ok () From 14380d835eed56f0622b1ad324ee119462fe4800 Mon Sep 17 00:00:00 2001 From: "John C. 
Burnham" Date: Fri, 20 Feb 2026 08:20:21 -0500 Subject: [PATCH 04/14] move error printing to end to unhide if types are long --- Ix/Kernel/Infer.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 0c161539..cc2d89e5 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -399,7 +399,8 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let val := match ci.value? with | some v => s!"\n value: {v.pp}" | none => "" - return .error s!"{header}: {e}\n type: {typ}{val}" + IO.println s!"type: {typ}{val}" + return .error s!"{header}: {e}" return .ok () end Ix.Kernel From c77d3096feceef297e4d438b94006d67d4e7495b Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 11:41:48 -0500 Subject: [PATCH 05/14] correctness improvements and ST caching --- Ix/Kernel/Equal.lean | 13 +- Ix/Kernel/Eval.lean | 49 +- Ix/Kernel/Infer.lean | 266 ++++++++- Ix/Kernel/TypecheckM.lean | 82 +-- Tests/Ix/KernelTests.lean | 494 ++++++++++++++++- Tests/Ix/PP.lean | 26 +- src/ix/kernel/def_eq.rs | 12 +- src/ix/kernel/inductive.rs | 1041 +++++++++++++++++++++++++++++++++++- src/ix/kernel/level.rs | 85 +++ src/ix/kernel/tc.rs | 318 ++++++++++- src/ix/kernel/whnf.rs | 13 +- 11 files changed, 2275 insertions(+), 124 deletions(-) diff --git a/Ix/Kernel/Equal.lean b/Ix/Kernel/Equal.lean index 4f219b7c..a2e8db92 100644 --- a/Ix/Kernel/Equal.lean +++ b/Ix/Kernel/Equal.lean @@ -34,7 +34,7 @@ private def equalUnivArrays (us us' : Array (Level m)) : Bool := mutual /-- Try eta expansion for structure-like types. -/ - partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := do + partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m σ Bool := do match term'.get with | .app (.const k _ _) args _ => match (← get).typedConsts.get? 
k with @@ -59,7 +59,7 @@ mutual /-- Check if two suspended values are definitionally equal at the given level. Assumes both have the same type and live in the same context. -/ - partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := + partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m σ Bool := match term.info, term'.info with | .unit, .unit => pure true | .proof, .proof => pure true @@ -67,9 +67,10 @@ mutual if (← read).trace then dbg_trace s!"equal: {term.get.ctorName} vs {term'.get.ctorName}" -- Fast path: pointer equality on thunks if susValuePtrEq term term' then return true - -- Check equality cache + -- Check equality cache via ST.Ref let key := susValueCacheKey term term' - if let some true := (← get).equalCache.get? key then return true + let eqCache ← (← read).equalCacheRef.get + if let some true := eqCache.get? key then return true let tv := term.get let tv' := term'.get let result ← match tv, tv' with @@ -151,11 +152,11 @@ mutual dbg_trace s!"equal FALLTHROUGH at lvl={lvl}: lhs={tv.dump} rhs={tv'.dump}" pure false if result then - modify fun stt => { stt with equalCache := stt.equalCache.insert key true } + let _ ← (← read).equalCacheRef.modify fun c => c.insert key true return result /-- Check if two lists of suspended values are pointwise equal. -/ - partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m Bool := + partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m σ Bool := match vals, vals' with | val :: vals, val' :: vals' => do let eq ← equal lvl val val' diff --git a/Ix/Kernel/Eval.lean b/Ix/Kernel/Eval.lean index 9fa74125..eed16e52 100644 --- a/Ix/Kernel/Eval.lean +++ b/Ix/Kernel/Eval.lean @@ -35,7 +35,7 @@ def listGet? (l : List α) (n : Nat) : Option α := /-- Try to reduce a primitive operation if all arguments are available. 
-/ private def tryPrimOp (prims : Primitives) (addr : Address) - (args : List (SusValue m)) : TypecheckM m (Option (Value m)) := do + (args : List (SusValue m)) : TypecheckM m σ (Option (Value m)) := do -- Nat.succ: 1 arg if addr == prims.natSucc then if args.length >= 1 then @@ -78,7 +78,7 @@ private def tryPrimOp (prims : Primitives) (addr : Address) /-- Expand a string literal to its constructor form: String.mk (list-of-chars). Each character is represented as Char.ofNat n, and the list uses List.cons/List.nil at universe level 0. -/ -def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) := do +def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m σ (Value m) := do let charMkName ← lookupName prims.charMk let charName ← lookupName prims.char let listNilName ← lookupName prims.listNil @@ -105,7 +105,7 @@ def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) : mutual /-- Evaluate a typed expression to a value. -/ - partial def eval (t : TypedExpr m) : TypecheckM m (Value m) := withFuelCheck do + partial def eval (t : TypedExpr m) : TypecheckM m σ (Value m) := withFuelCheck do if (← read).trace then dbg_trace s!"eval: {t.body.tag}" match t.body with | .app fnc arg => do @@ -171,7 +171,7 @@ mutual pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] []) | e => throw s!"Value is impossible to project: {e.ctorName}" - partial def evalTyped (t : TypedExpr m) : TypecheckM m (AddInfo (TypeInfo m) (Value m)) := do + partial def evalTyped (t : TypedExpr m) : TypecheckM m σ (AddInfo (TypeInfo m) (Value m)) := do let reducedInfo := t.info.update (← read).env.univs.toArray let value ← eval t pure ⟨reducedInfo, value⟩ @@ -180,11 +180,12 @@ mutual Theorems are treated as opaque (not unfolded) — proof irrelevance handles equality of proof terms, and this avoids deep recursion through proof bodies. Caches evaluated definition bodies to avoid redundant evaluation. 
-/ - partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do match (← read).kenv.find? addr with | some (.defnInfo _) => - -- Check eval cache (must also match universe parameters) - if let some (cachedUnivs, cachedVal) := (← get).evalCache.get? addr then + -- Check eval cache via ST.Ref (persists across thunks) + let cache ← (← read).evalCacheRef.get + if let some (cachedUnivs, cachedVal) := cache.get? addr then if cachedUnivs == univs then return cachedVal ensureTypedConst addr match (← get).typedConsts.get? addr with @@ -192,29 +193,29 @@ mutual if part then pure (mkConst addr univs name) else let val ← withEnv (.mk [] univs.toList) (eval deref) - modify fun stt => { stt with evalCache := stt.evalCache.insert addr (univs, val) } + let _ ← (← read).evalCacheRef.modify fun c => c.insert addr (univs, val) pure val | _ => throw "Invalid const kind for evaluation" | _ => pure (mkConst addr univs name) /-- Evaluate a constant: check if it's Nat.zero, a primitive op, or unfold it. -/ - partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do let prims := (← read).prims if addr == prims.natZero then pure (.lit (.natVal 0)) else if isPrimOp prims addr then pure (mkConst addr univs name) else evalConst' addr univs name /-- Create a suspended value from a typed expression, capturing context. 
-/ - partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m) (stt : TypecheckState m) : SusValue m := + partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m σ) (stt : TypecheckState m) : SusValue m := let thunk : Thunk (Value m) := .mk fun _ => - match TypecheckM.run ctx stt (eval expr) with + match pureRunST (TypecheckM.run ctx stt (eval expr)) with | .ok a => a | .error e => .exception e let reducedInfo := expr.info.update ctx.env.univs.toArray ⟨reducedInfo, thunk⟩ /-- Apply a value to an argument. -/ - partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m (Value m) := do + partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m σ (Value m) := do if (← read).trace then dbg_trace s!"apply: {val.body.ctorName}" match val.body with | .lam _ bod lamEnv _ _ => @@ -233,7 +234,7 @@ mutual /-- Apply a named constant to arguments, handling recursors, quotients, and primitives. -/ partial def applyConst (addr : Address) (univs : Array (Level m)) (arg : SusValue m) (args : List (SusValue m)) (info : TypeInfo m) (infos : List (TypeInfo m)) - (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do let prims := (← read).prims -- Try primitive operations if let some result ← tryPrimOp prims addr (arg :: args) then @@ -326,7 +327,7 @@ mutual /-- Apply a quotient to a value. -/ partial def applyQuot (_prims : Primitives) (major : SusValue m) (args : List (SusValue m)) - (reduceSize argPos : Nat) (default : Value m) : TypecheckM m (Value m) := + (reduceSize argPos : Nat) (default : Value m) : TypecheckM m σ (Value m) := let argsLength := args.length + 1 if argsLength == reduceSize then match major.get with @@ -343,7 +344,7 @@ mutual else throw s!"argsLength {argsLength} can't be greater than reduceSize {reduceSize}" /-- Convert a nat literal to Nat.succ/Nat.zero constructors. 
-/ - partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m (Value m) + partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m σ (Value m) | .lit (.natVal 0) => do let name ← lookupName prims.natZero pure (Value.neu (.const prims.natZero #[] name)) @@ -357,7 +358,7 @@ end /-! ## Quoting (read-back from Value to Expr) -/ mutual - partial def quote (lvl : Nat) : Value m → TypecheckM m (Expr m) + partial def quote (lvl : Nat) : Value m → TypecheckM m σ (Expr m) | .sort univ => do let env := (← read).env pure (.sort (instBulkReduce env.univs.toArray univ)) @@ -379,14 +380,14 @@ mutual | .lit lit => pure (.lit lit) | .exception e => throw e - partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m (TypedExpr m) := do + partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m σ (TypedExpr m) := do pure ⟨val.info, ← quote lvl val.body⟩ - partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m (TypedExpr m) := do + partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m σ (TypedExpr m) := do let e ← quoteExpr lvl t.body env pure ⟨t.info, e⟩ - partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m (Expr m) := + partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m σ (Expr m) := match expr with | .bvar idx _ => do match listGet? env.exprs idx with @@ -421,7 +422,7 @@ mutual pure (.proj typeAddr idx struct name) | .lit .. => pure expr - partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m (Expr m) + partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m σ (Expr m) | .fvar idx name => do pure (.bvar (lvl - idx - 1) name) | .const addr univs name => do @@ -501,22 +502,22 @@ partial def foldLiterals (prims : Primitives) : Expr m → Expr m /-- Pretty-print a value by quoting it back to an Expr, then using Expr.pp. 
Folds Nat/String constructor chains back to literals for readability. -/ -partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m String := do +partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m σ String := do let expr ← quote lvl v let expr := foldLiterals (← read).prims expr return expr.pp /-- Pretty-print a suspended value. -/ -partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m String := +partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m σ String := ppValue lvl sv.get /-- Pretty-print a value, falling back to the shallow summary on error. -/ -partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m String := do +partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m σ String := do try ppValue lvl v catch _ => return v.summary /-- Apply a value to a list of arguments. -/ -def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m (Value m) := +def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m σ (Value m) := match args with | [] => pure v | arg :: rest => do diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index cc2d89e5..0dacf465 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -7,22 +7,102 @@ import Ix.Kernel.Equal namespace Ix.Kernel +/-! ## Inductive validation helpers -/ + +/-- Check if an expression mentions a constant at the given address. -/ +partial def exprMentionsConst (e : Expr m) (addr : Address) : Bool := + match e with + | .const a _ _ => a == addr + | .app fn arg => exprMentionsConst fn addr || exprMentionsConst arg addr + | .lam ty body _ _ => exprMentionsConst ty addr || exprMentionsConst body addr + | .forallE ty body _ _ => exprMentionsConst ty addr || exprMentionsConst body addr + | .letE ty val body _ => exprMentionsConst ty addr || exprMentionsConst val addr || exprMentionsConst body addr + | .proj _ _ s _ => exprMentionsConst s addr + | _ => false + +/-- Check strict positivity of a field type w.r.t. 
a set of inductive addresses. + Returns true if positive, false if negative occurrence found. -/ +partial def checkStrictPositivity (ty : Expr m) (indAddrs : Array Address) : Bool := + -- If no inductive is mentioned, we're fine + if !indAddrs.any (exprMentionsConst ty ·) then true + else match ty with + | .forallE domain body _ _ => + -- Domain must NOT mention any inductive + if indAddrs.any (exprMentionsConst domain ·) then false + -- Continue checking body + else checkStrictPositivity body indAddrs + | e => + -- Not a forall — must be the inductive at the head + let fn := e.getAppFn + match fn with + | .const addr _ _ => indAddrs.any (· == addr) + | _ => false + +/-- Walk a Pi chain, skip numParams binders, then check positivity of each field. + Returns an error message or none on success. -/ +partial def checkCtorPositivity (ctorType : Expr m) (numParams : Nat) (indAddrs : Array Address) + : Option String := + go ctorType numParams +where + go (ty : Expr m) (remainingParams : Nat) : Option String := + match ty with + | .forallE _domain body _name _bi => + if remainingParams > 0 then + go body (remainingParams - 1) + else + -- This is a field — check positivity of its domain + let domain := ty.bindingDomain! + if !checkStrictPositivity domain indAddrs then + some "inductive occurs in negative position (strict positivity violation)" + else + go body 0 + | _ => none + +/-- Walk a Pi chain past numParams + numFields binders to get the return type. + Returns the return type expression (with bvars). -/ +def getCtorReturnType (ctorType : Expr m) (numParams numFields : Nat) : Expr m := + go ctorType (numParams + numFields) +where + go (ty : Expr m) (n : Nat) : Expr m := + match n, ty with + | 0, e => e + | n+1, .forallE _ body _ _ => go body n + | _, e => e + +/-- Extract result universe level from an inductive type expression. + Walks past all forall binders to find the final Sort. 
-/ +def getIndResultLevel (indType : Expr m) : Option (Level m) := + go indType +where + go : Expr m → Option (Level m) + | .forallE _ body _ _ => go body + | .sort lvl => some lvl + | _ => none + +/-- Check if a level is definitively non-zero (always ≥ 1). -/ +partial def levelIsNonZero : Level m → Bool + | .succ _ => true + | .zero => false + | .param .. => false -- could be zero + | .max a b => levelIsNonZero a || levelIsNonZero b + | .imax _ b => levelIsNonZero b + /-! ## Type info helpers -/ def lamInfo : TypeInfo m → TypeInfo m | .proof => .proof | _ => .none -def piInfo (dom img : TypeInfo m) : TypecheckM m (TypeInfo m) := match dom, img with +def piInfo (dom img : TypeInfo m) : TypecheckM m σ (TypeInfo m) := match dom, img with | .sort lvl, .sort lvl' => pure (.sort (Level.reduceIMax lvl lvl')) | _, _ => pure .none -def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m Bool := do +def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m σ Bool := do match inferType.info, expectType.info with | .sort lvl, .sort lvl' => pure (Level.equalLevel lvl lvl') | _, _ => pure true -- info unavailable; defer to structural equality -def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) := +def infoFromType (typ : SusValue m) : TypecheckM m σ (TypeInfo m) := match typ.info with | .sort (.zero) => pure .proof | _ => @@ -45,7 +125,7 @@ def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) := mutual /-- Check that a term has a given type. -/ - partial def check (term : Expr m) (type : SusValue m) : TypecheckM m (TypedExpr m) := do + partial def check (term : Expr m) (type : SusValue m) : TypecheckM m σ (TypedExpr m) := do if (← read).trace then dbg_trace s!"check: {term.tag}" let (te, inferType) ← infer term if !(← eqSortInfo inferType type) then @@ -60,7 +140,7 @@ mutual pure te /-- Infer the type of an expression, returning the typed expression and its type. 
-/ - partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × SusValue m) := withFuelCheck do + partial def infer (term : Expr m) : TypecheckM m σ (TypedExpr m × SusValue m) := withFuelCheck do if (← read).trace then dbg_trace s!"infer: {term.tag}" match term with | .bvar idx bvarName => do @@ -194,7 +274,7 @@ mutual | _ => throw "Impossible case: structure type does not have enough fields" /-- Check if an expression is a Sort, returning the typed expr and the universe level. -/ - partial def isSort (expr : Expr m) : TypecheckM m (TypedExpr m × Level m) := do + partial def isSort (expr : Expr m) : TypecheckM m σ (TypedExpr m × Level m) := do let (te, typ) ← infer expr match typ.get with | .sort u => pure (te, u) @@ -204,7 +284,7 @@ mutual /-- Get structure info from a value that should be a structure type. -/ partial def getStructInfo (v : Value m) : - TypecheckM m (TypedExpr m × List (Level m) × List (SusValue m)) := do + TypecheckM m σ (TypedExpr m × List (Level m) × List (SusValue m)) := do match v with | .app (.const indAddr univs _) params _ => match (← read).kenv.find? indAddr with @@ -226,13 +306,13 @@ mutual /-- Typecheck a constant. With fresh state per declaration, dependencies get provisional entries via `ensureTypedConst` and are assumed well-typed. -/ - partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do + partial def checkConst (addr : Address) : TypecheckM m σ Unit := withResetCtx do -- Reset fuel and per-constant caches - modify fun stt => { stt with - fuel := defaultFuel - evalCache := {} - equalCache := {} - constTypeCache := {} } + modify fun stt => { stt with constTypeCache := {} } + let ctx ← read + let _ ← ctx.fuelRef.set defaultFuel + let _ ← ctx.evalCacheRef.set {} + let _ ← ctx.equalCacheRef.set {} -- Skip if already in typedConsts (provisional entry is fine — dependency assumed well-typed) if (← get).typedConsts.get? 
addr |>.isSome then return () @@ -286,7 +366,12 @@ mutual ensureTypedConst indAddr -- Check recursor type let (type, _) ← isSort ci.type - -- Check recursor rules + -- (#3) Validate K-flag instead of trusting the environment + if v.k then + validateKFlag v indAddr + -- (#4) Validate recursor rules + validateRecursorRules v indAddr + -- Check recursor rules (type-check RHS) let typedRules ← v.rules.mapM fun rule => do let (rhs, _) ← infer rule.rhs pure (rule.nfields, rhs) @@ -295,7 +380,7 @@ mutual /-- Walk a Pi chain to extract the return sort level (the universe of the result type). Assumes the expression ends in `Sort u` after `numBinders` forall binders. -/ - partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m (Level m) := + partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m σ (Level m) := match numBinders, expr with | 0, .sort u => do let univs := (← read).env.univs.toArray @@ -316,7 +401,7 @@ mutual | _, _ => throw "inductive type has fewer binders than expected" /-- Typecheck a mutual inductive block starting from one of its addresses. -/ - partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do + partial def checkIndBlock (addr : Address) : TypecheckM m σ Unit := do let ci ← derefConst addr -- Find the inductive info let indInfo ← match ci with @@ -337,6 +422,13 @@ mutual | some (.ctorInfo cv) => cv.numFields > 0 | _ => false modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr (TypedConst.inductive type isStruct) } + + -- Collect all inductive addresses in this mutual block + let indAddrs := iv.all + + -- Get the inductive's result universe level + let indResultLevel := getIndResultLevel iv.type + -- Check constructors for (ctorAddr, cidx) in iv.ctors.toList.zipIdx do match (← read).kenv.find? 
ctorAddr with @@ -344,23 +436,146 @@ mutual let ctorUnivs := cv.toConstantVal.mkUnivParams let (ctorType, _) ← withEnv (.mk [] ctorUnivs.toList) (isSort cv.type) modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cidx cv.numFields) } + + -- (#5) Check constructor parameter count matches inductive + if cv.numParams != iv.numParams then + throw s!"Constructor {ctorAddr} has {cv.numParams} params but inductive has {iv.numParams}" + + -- (#1) Positivity checking (skip for unsafe inductives) + if !iv.isUnsafe then + match checkCtorPositivity cv.type cv.numParams indAddrs with + | some msg => throw s!"Constructor {ctorAddr}: {msg}" + | none => pure () + + -- (#2) Universe constraint checking on constructor fields + -- Each non-parameter field's sort must be ≤ the inductive's result sort. + -- We check this by inferring the sort of each field type and comparing levels. + if !iv.isUnsafe then + if let some indLvl := indResultLevel then + let indLvlReduced := Level.instBulkReduce univs indLvl + checkFieldUniverses cv.type cv.numParams ctorAddr indLvlReduced + + -- (#6) Check indices in ctor return type don't mention the inductive + if !iv.isUnsafe then + let retType := getCtorReturnType cv.type cv.numParams cv.numFields + let args := retType.getAppArgs + -- Index arguments are those after numParams + for i in [iv.numParams:args.size] do + for indAddr in indAddrs do + if exprMentionsConst args[i]! indAddr then + throw s!"Constructor {ctorAddr} index argument mentions the inductive (unsound)" + | _ => throw s!"Constructor {ctorAddr} not found" -- Note: recursors are checked individually via checkConst's .recInfo branch, -- which calls checkConst on the inductives first then checks rules. + + /-- Check that constructor field types have sorts ≤ the inductive's result sort. 
    -/
  partial def checkFieldUniverses (ctorType : Expr m) (numParams : Nat)
      (ctorAddr : Address) (indLvl : Level m) : TypecheckM m σ Unit :=
    go ctorType numParams
  where
    -- Walks the constructor's Pi chain: the first `numParams` binders are
    -- parameters (only extended into the context), the rest are fields
    -- (level-checked, then extended into the context).
    go (ty : Expr m) (remainingParams : Nat) : TypecheckM m σ Unit :=
      match ty with
      | .forallE dom body piName _ =>
        if remainingParams > 0 then do
          -- Parameter binder: no universe check, but the context must still
          -- be extended so deeper binders see the right de Bruijn level.
          let (domTe, _) ← isSort dom
          let ctx ← read
          let stt ← get
          let domVal := suspend domTe ctx stt
          let var := mkSusVar (← infoFromType domVal) ctx.lvl piName
          withExtendedCtx var domVal (go body (remainingParams - 1))
        else do
          -- This is a field — infer its sort level and check ≤ indLvl
          let (domTe, fieldSortLvl) ← isSort dom
          let fieldReduced := Level.reduce fieldSortLvl
          let indReduced := Level.reduce indLvl
          -- Allow if field ≤ ind, OR if ind is Prop (is_zero allows any field)
          if !Level.leq fieldReduced indReduced 0 && !Level.isZero indReduced then
            throw s!"Constructor {ctorAddr} field type lives in a universe larger than the inductive's universe"
          let ctx ← read
          let stt ← get
          let domVal := suspend domTe ctx stt
          let var := mkSusVar (← infoFromType domVal) ctx.lvl piName
          withExtendedCtx var domVal (go body 0)
      | _ => pure ()

  /-- (#3) Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/
  partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do
    -- Must be non-mutual
    if rec.all.size != 1 then
      throw "recursor claims K but inductive is mutual"
    -- Look up the inductive
    match (← read).kenv.find? indAddr with
    | some (.inductInfo iv) =>
      -- Must be in Prop (result level definitively zero)
      match getIndResultLevel iv.type with
      | some lvl =>
        if levelIsNonZero lvl then
          throw s!"recursor claims K but inductive is not in Prop"
      | none => throw "recursor claims K but cannot determine inductive's result sort"
      -- Must have single constructor
      if iv.ctors.size != 1 then
        throw s!"recursor claims K but inductive has {iv.ctors.size} constructors (need 1)"
      -- Constructor must have zero fields
      match (← read).kenv.find? iv.ctors[0]! with
      | some (.ctorInfo cv) =>
        if cv.numFields != 0 then
          throw s!"recursor claims K but constructor has {cv.numFields} fields (need 0)"
      | _ => throw "recursor claims K but constructor not found"
    | _ => throw s!"recursor claims K but {indAddr} is not an inductive"

  /-- (#4) Validate recursor rules: check rule count, ctor membership, field counts. -/
  partial def validateRecursorRules (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do
    -- Collect all constructors from the mutual block, in declaration order.
    let mut allCtors : Array Address := #[]
    for iAddr in rec.all do
      match (← read).kenv.find? iAddr with
      | some (.inductInfo iv) =>
        allCtors := allCtors ++ iv.ctors
      | _ => throw s!"recursor references {iAddr} which is not an inductive"
    -- Check rule count: exactly one rule per constructor.
    if rec.rules.size != allCtors.size then
      throw s!"recursor has {rec.rules.size} rules but inductive(s) have {allCtors.size} constructors"
    -- Check each rule
    for h : i in [:rec.rules.size] do
      let rule := rec.rules[i]
      -- Rule's constructor must match the expected constructor, in order.
      if rule.ctor != allCtors[i]! then
        throw s!"recursor rule {i} has constructor {rule.ctor} but expected {allCtors[i]!}"
      -- Look up the constructor and validate nfields
      match (← read).kenv.find? rule.ctor with
      | some (.ctorInfo cv) =>
        if rule.nfields != cv.numFields then
          throw s!"recursor rule for {rule.ctor} has nfields={rule.nfields} but constructor has {cv.numFields} fields"
      | _ => throw s!"recursor rule constructor {rule.ctor} not found"
    -- Validate structural counts against the inductive
    match (← read).kenv.find? indAddr with
    | some (.inductInfo iv) =>
      if rec.numParams != iv.numParams then
        throw s!"recursor numParams={rec.numParams} but inductive has {iv.numParams}"
      if rec.numIndices != iv.numIndices then
        throw s!"recursor numIndices={rec.numIndices} but inductive has {iv.numIndices}"
    | _ => pure ()

end -- mutual

/-! ## Top-level entry points -/

/-- Typecheck a single constant by address. -/
def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address)
    (quotInit : Bool := true) : Except String Unit :=
  -- All mutable caches live in ST refs scoped to this runST region, so they
  -- persist across thunk evaluations but cannot escape the call.
  runST fun σ => do
    let fuelRef ← ST.mkRef defaultFuel
    let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level m) × Value m))
    let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool)
    let ctx : TypecheckCtx m σ := {
      lvl := 0, env := default, types := [], kenv := kenv,
      prims := prims, safety := .safe, quotInit := quotInit,
      mutTypes := default, recAddr? := none,
      fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef
    }
    let stt : TypecheckState m := { typedConsts := default }
    TypecheckM.run ctx stt (checkConst addr)

/-- Typecheck all constants in a kernel environment.
    Uses fresh state per declaration — dependencies are assumed well-typed.
-/ @@ -399,7 +614,8 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let val := match ci.value? with | some v => s!"\n value: {v.pp}" | none => "" - IO.println s!"type: {typ}{val}" + IO.println s!"type: {typ}" + IO.println s!"val: {val}" return .error s!"{header}: {e}" return .ok () diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 8b1a93ba..9fb0d2cd 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -8,7 +8,7 @@ namespace Ix.Kernel /-! ## Typechecker Context -/ -structure TypecheckCtx (m : MetaMode) where +structure TypecheckCtx (m : MetaMode) (σ : Type) where lvl : Nat env : ValEnv m types : List (SusValue m) @@ -23,29 +23,23 @@ structure TypecheckCtx (m : MetaMode) where /-- Depth fuel: bounds the call-stack depth to prevent native stack overflow. Decremented via the reader on each entry to eval/equal/infer. Thunks inherit the depth from their capture point. -/ - depth : Nat := 3000 + depth : Nat := 10000 /-- Enable dbg_trace on major entry points for debugging. -/ trace : Bool := false - deriving Inhabited + /-- Global fuel counter: bounds total recursive work across all thunks via ST.Ref. -/ + fuelRef : ST.Ref σ Nat + /-- Mutable eval cache: persists across thunk evaluations via ST.Ref. -/ + evalCacheRef : ST.Ref σ (Std.HashMap Address (Array (Level m) × Value m)) + /-- Mutable equality cache: persists across thunk evaluations via ST.Ref. -/ + equalCacheRef : ST.Ref σ (Std.HashMap (USize × USize) Bool) /-! ## Typechecker State -/ /-- Default fuel for bounding total recursive work per constant. -/ -def defaultFuel : Nat := 100000 +def defaultFuel : Nat := 200000 structure TypecheckState (m : MetaMode) where typedConsts : Std.TreeMap Address (TypedConst m) Address.compare - /-- Fuel counter for bounding total recursive work. Decremented on each entry to - eval/equal/infer. Reset at the start of each `checkConst` call. 
-/ - fuel : Nat := defaultFuel - /-- Cache for evaluated constant definitions. Maps an address to its universe - parameters and evaluated value. Universe-polymorphic constants produce different - values for different universe instantiations, so we store and check univs. -/ - evalCache : Std.HashMap Address (Array (Level m) × Value m) := {} - /-- Cache for definitional equality results. Maps `(ptrAddrUnsafe a, ptrAddrUnsafe b)` - (canonicalized so smaller pointer comes first) to `Bool`. Only `true` results are - cached (monotone under state growth). -/ - equalCache : Std.HashMap (USize × USize) Bool := {} /-- Cache for constant type SusValues. When `infer (.const addr _)` computes a suspended type, it is cached here so repeated references to the same constant share the same SusValue pointer, enabling fast-path pointer equality in `equal`. @@ -55,75 +49,87 @@ structure TypecheckState (m : MetaMode) where /-! ## TypecheckM monad -/ -abbrev TypecheckM (m : MetaMode) := ReaderT (TypecheckCtx m) (StateT (TypecheckState m) (Except String)) +abbrev TypecheckM (m : MetaMode) (σ : Type) := + ReaderT (TypecheckCtx m σ) (ExceptT String (StateT (TypecheckState m) (ST σ))) + +def TypecheckM.run (ctx : TypecheckCtx m σ) (stt : TypecheckState m) + (x : TypecheckM m σ α) : ST σ (Except String α) := do + let (result, _) ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt + pure result + +def TypecheckM.runState (ctx : TypecheckCtx m σ) (stt : TypecheckState m) (x : TypecheckM m σ α) + : ST σ (Except String (α × TypecheckState m)) := do + let (result, stt') ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt + pure (match result with | .ok a => .ok (a, stt') | .error e => .error e) + +/-! ## pureRunST -/ -def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) : Except String α := - match (StateT.run (ReaderT.run x ctx) stt) with - | .error e => .error e - | .ok (a, _) => .ok a +/-- Unsafe bridge: run ST σ from pure code (for Thunk bodies). 
    Safe because the only side effects are append-only cache mutations. -/
@[inline] unsafe def pureRunSTImpl {σ α : Type} [Inhabited α] (x : ST σ α) : α :=
  -- Forge the ST world token with unsafeCast; soundness relies on the
  -- append-only-cache discipline documented above.
  (x (unsafeCast ())).val

@[implemented_by pureRunSTImpl]
opaque pureRunST {σ α : Type} [Inhabited α] : ST σ α → α

/-! ## Context modifiers -/

def withEnv (env : ValEnv m) : TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with env := env }

-- Clears all per-declaration context (level, env, types, mutTypes, recAddr?).
def withResetCtx : TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx =>
    { ctx with lvl := 0, env := default, types := default, mutTypes := default, recAddr? := none }

def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare) :
    TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with mutTypes := mutTypes }

-- Enters a binder: bumps the de Bruijn level and pushes both the type
-- (for `infer` of bvars) and the value (for evaluation).
def withExtendedCtx (val typ : SusValue m) : TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with
    lvl := ctx.lvl + 1, types := typ :: ctx.types, env := ctx.env.extendWith val }

def withExtendedEnv (thunk : SusValue m) : TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with env := ctx.env.extendWith thunk }

def withNewExtendedEnv (env : ValEnv m) (thunk : SusValue m) :
    TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with env := env.extendWith thunk }

def withRecAddr (addr : Address) : TypecheckM m σ α → TypecheckM m σ α :=
  withReader fun ctx => { ctx with recAddr? := some addr }

/-- Check both fuel counters, decrement them, and run the action.
    - ST-ref fuel bounds total work across all thunks (prevents exponential
      blowup / hanging).
    - Reader depth bounds call-stack depth (prevents native stack overflow). -/
def withFuelCheck (action : TypecheckM m σ α) : TypecheckM m σ α := do
  let ctx ← read
  if ctx.depth == 0 then throw "deep recursion depth limit reached"
  let fuel ← ctx.fuelRef.get
  if fuel == 0 then throw "deep recursion fuel limit reached"
  let _ ← ctx.fuelRef.set (fuel - 1)
  withReader (fun ctx => { ctx with depth := ctx.depth - 1 }) action

/-! ## Name lookup -/

/-- Look up the MetaField name for a constant address from the kernel environment. -/
def lookupName (addr : Address) : TypecheckM m σ (MetaField m Ix.Name) := do
  match (← read).kenv.find? addr with
  | some ci => pure ci.cv.name
  | none => pure default

/-! ## Const dereferencing -/

def derefConst (addr : Address) : TypecheckM m σ (ConstantInfo m) := do
  let ctx ← read
  match ctx.kenv.find? addr with
  | some ci => pure ci
  | none => throw s!"unknown constant {addr}"

def derefTypedConst (addr : Address) : TypecheckM m σ (TypedConst m) := do
  match (← get).typedConsts.get? addr with
  | some tc => pure tc
  | none => throw s!"typed constant not found: {addr}"

-- NOTE(review): a diff-hunk boundary here elided `provisionalTypedConst`
-- from view; its body is unchanged by this patch.

/-- Ensure a constant has a TypedConst entry. If not already present, build a
    provisional one from raw ConstantInfo. This avoids the deep recursion of
    `checkConst` when called from `infer`.
-/ -def ensureTypedConst (addr : Address) : TypecheckM m Unit := do +def ensureTypedConst (addr : Address) : TypecheckM m σ Unit := do if (← get).typedConsts.get? addr |>.isSome then return () let ci ← derefConst addr let tc := provisionalTypedConst ci diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index f1ed3c55..b14dbff4 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -131,10 +131,38 @@ def testLevelOps : TestSeq := /-! ## Integration tests: Const pipeline -/ -/-- Parse a dotted name string like "Nat.add" into an Ix.Name. -/ -private def parseIxName (s : String) : Ix.Name := - let parts := s.splitOn "." - parts.foldl (fun acc part => Ix.Name.mkStr acc part) Ix.Name.mkAnon +/-- Parse a dotted name string like "Nat.add" into an Ix.Name. + Handles `«...»` quoted name components (e.g. `Foo.«0».Bar`). -/ +private partial def parseIxName (s : String) : Ix.Name := + let parts := splitParts s.toList [] + parts.foldl (fun acc part => + match part with + | .inl str => Ix.Name.mkStr acc str + | .inr nat => Ix.Name.mkNat acc nat + ) Ix.Name.mkAnon +where + /-- Split a dotted name into parts: .inl for string components, .inr for numeric (guillemet). -/ + splitParts : List Char → List (String ⊕ Nat) → List (String ⊕ Nat) + | [], acc => acc + | '.' :: rest, acc => splitParts rest acc + | '«' :: rest, acc => + let (inside, rest') := collectUntilClose rest "" + let part := match inside.toNat? with + | some n => .inr n + | none => .inl inside + splitParts rest' (acc ++ [part]) + | cs, acc => + let (word, rest) := collectUntilDot cs "" + splitParts rest (if word.isEmpty then acc else acc ++ [.inl word]) + collectUntilClose : List Char → String → String × List Char + | [], s => (s, []) + | '»' :: rest, s => (s, rest) + | c :: rest, s => collectUntilClose rest (s.push c) + collectUntilDot : List Char → String → String × List Char + | [], s => (s, []) + | '.' :: rest, s => (s, '.' 
:: rest) + | '«' :: rest, s => (s, '«' :: rest) + | c :: rest, s => collectUntilDot rest (s.push c) /-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/ private partial def leanNameToIx : Lean.Name → Ix.Name @@ -605,6 +633,461 @@ def negativeTests : TestSeq := return (false, some s!"{failures.size} failure(s)") ) .done +/-! ## Soundness negative tests (inductive validation) -/ + +/-- Helper: make unique addresses from a seed byte. -/ +private def mkAddr (seed : UInt8) : Address := + Address.blake3 (ByteArray.mk #[seed, 0xAA, 0xBB]) + +/-- Soundness negative test suite: verify that the typechecker rejects unsound + inductive declarations (positivity, universe constraints, K-flag, recursor rules). -/ +def soundnessNegativeTests : TestSeq := + .individualIO "kernel soundness negative tests" (do + let prims := buildPrimitives + let mut passed := 0 + let mut failures : Array String := #[] + + -- ======================================================================== + -- Test 1: Positivity violation — Bad | mk : (Bad → Bad) → Bad + -- The inductive appears in negative position (Pi domain). 
+ -- ======================================================================== + do + let badAddr := mkAddr 10 + let badMkAddr := mkAddr 11 + let badType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let badCv : ConstantVal .anon := + { numLevels := 0, type := badType, name := (), levelParams := () } + let badInd : ConstantInfo .anon := .inductInfo { + toConstantVal := badCv, numParams := 0, numIndices := 0, + all := #[badAddr], ctors := #[badMkAddr], numNested := 0, + isRec := true, isUnsafe := false, isReflexive := false + } + -- mk : (Bad → Bad) → Bad + -- The domain (Bad → Bad) has Bad in negative position + let mkType : Expr .anon := + .forallE + (.forallE (.const badAddr #[] ()) (.const badAddr #[] ()) () ()) + (.const badAddr #[] ()) + () () + let mkCv : ConstantVal .anon := + { numLevels := 0, type := mkType, name := (), levelParams := () } + let mkCtor : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := badAddr, cidx := 0, + numParams := 0, numFields := 1, isUnsafe := false + } + let env := ((default : Env .anon).insert badAddr badInd).insert badMkAddr mkCtor + match typecheckConst env prims badAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "positivity-violation: expected error (Bad → Bad in domain)" + + -- ======================================================================== + -- Test 2: Universe constraint violation — Uni1Bad : Sort 1 | mk : Sort 2 → Uni1Bad + -- Field lives in Sort 3 (Sort 2 : Sort 3) but inductive is in Sort 1. + -- (Note: Prop inductives have special exception allowing any field universe, + -- so we test with a Sort 1 inductive instead.) 
+ -- ======================================================================== + do + let ubAddr := mkAddr 20 + let ubMkAddr := mkAddr 21 + let ubType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let ubCv : ConstantVal .anon := + { numLevels := 0, type := ubType, name := (), levelParams := () } + let ubInd : ConstantInfo .anon := .inductInfo { + toConstantVal := ubCv, numParams := 0, numIndices := 0, + all := #[ubAddr], ctors := #[ubMkAddr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + -- mk : Sort 2 → Uni1Bad + -- Sort 2 : Sort 3, so field sort = 3. Inductive sort = 1. 3 ≤ 1 fails. + let mkType : Expr .anon := + .forallE (.sort (.succ (.succ .zero))) (.const ubAddr #[] ()) () () + let mkCv : ConstantVal .anon := + { numLevels := 0, type := mkType, name := (), levelParams := () } + let mkCtor : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := ubAddr, cidx := 0, + numParams := 0, numFields := 1, isUnsafe := false + } + let env := ((default : Env .anon).insert ubAddr ubInd).insert ubMkAddr mkCtor + match typecheckConst env prims ubAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "universe-constraint: expected error (Sort 2 field in Sort 1 inductive)" + + -- ======================================================================== + -- Test 3: K-flag invalid — K=true on non-Prop inductive (Sort 1, 2 ctors) + -- ======================================================================== + do + let indAddr := mkAddr 30 + let mk1Addr := mkAddr 31 + let mk2Addr := mkAddr 32 + let recAddr := mkAddr 33 + let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 (not Prop) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, 
isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + -- Recursor with k=true on a non-Prop inductive + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } + ], + k := true, -- INVALID: not Prop + isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "k-flag-not-prop: expected error" + + -- ======================================================================== + -- Test 4: Recursor wrong rule count — 1 rule for 2-ctor inductive + -- ======================================================================== + do + let indAddr := mkAddr 40 + let mk1Addr := mkAddr 41 + let mk2Addr := mkAddr 42 + let recAddr := mkAddr 43 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := 
#[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + -- Recursor with only 1 rule (should be 2) + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[{ ctor := mk1Addr, nfields := 0, rhs := .sort .zero }], -- only 1! 
+ k := false, isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-rule-count: expected error" + + -- ======================================================================== + -- Test 5: Recursor wrong nfields — ctor has 0 fields but rule claims 5 + -- ======================================================================== + do + let indAddr := mkAddr 50 + let mkAddr' := mkAddr 51 + let recAddr := mkAddr 52 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 1, + rules := #[{ ctor := mkAddr', nfields := 5, rhs := .sort .zero }], -- wrong nfields + k := false, isUnsafe := false + } + let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-nfields: expected error" + + -- 
======================================================================== + -- Test 6: Recursor wrong num_params — rec claims 5 params, inductive has 0 + -- ======================================================================== + do + let indAddr := mkAddr 60 + let mkAddr' := mkAddr 61 + let recAddr := mkAddr 62 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 5, -- wrong: inductive has 0 + numIndices := 0, numMotives := 1, numMinors := 1, + rules := #[{ ctor := mkAddr', nfields := 0, rhs := .sort .zero }], + k := false, isUnsafe := false + } + let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-num-params: expected error" + + -- ======================================================================== + -- Test 7: Constructor param count mismatch — ctor claims 3 params, ind has 0 + -- ======================================================================== + do + let indAddr := mkAddr 70 + let mkAddr' := mkAddr 71 + let indType : Expr .anon := .sort (.succ .zero) + let 
indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 3, -- wrong: inductive has 0 + numFields := 0, isUnsafe := false + } + let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI + match typecheckConst env prims indAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "ctor-param-mismatch: expected error" + + -- ======================================================================== + -- Test 8: K-flag invalid — K=true on Prop inductive with 2 ctors + -- ======================================================================== + do + let indAddr := mkAddr 80 + let mk1Addr := mkAddr 81 + let mk2Addr := mkAddr 82 + let recAddr := mkAddr 83 + let indType : Expr .anon := .sort .zero -- Prop + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + 
let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 0, type := .sort .zero, name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } + ], + k := true, -- INVALID: 2 ctors + isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "k-flag-two-ctors: expected error" + + -- ======================================================================== + -- Test 9: Recursor wrong ctor order — rules in wrong order + -- ======================================================================== + do + let indAddr := mkAddr 90 + let mk1Addr := mkAddr 91 + let mk2Addr := mkAddr 92 + let recAddr := mkAddr 93 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), 
levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero }, -- wrong order! + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero } + ], + k := false, isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-ctor-order: expected error" + + -- ======================================================================== + -- Test 10: Valid single-ctor inductive passes (sanity check) + -- ======================================================================== + do + let indAddr := mkAddr 100 + let mkAddr' := mkAddr 101 + let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI + match typecheckConst env prims indAddr 
with + | .ok () => passed := passed + 1 + | .error e => failures := failures.push s!"valid-inductive: unexpected error: {e}" + + let totalTests := 10 + IO.println s!"[kernel-soundness] {passed}/{totalTests} passed" + if failures.isEmpty then + return (true, none) + else + for f in failures do IO.println s!" [fail] {f}" + return (false, some s!"{failures.size} failure(s)") + ) .done + +/-! ## Unit tests: helper functions -/ + +def testHelperFunctions : TestSeq := + -- exprMentionsConst + let addr1 := mkAddr 200 + let addr2 := mkAddr 201 + let c1 : Expr .anon := .const addr1 #[] () + let c2 : Expr .anon := .const addr2 #[] () + test "exprMentionsConst: direct match" + (exprMentionsConst c1 addr1) ++ + test "exprMentionsConst: no match" + (!exprMentionsConst c2 addr1) ++ + test "exprMentionsConst: in app fn" + (exprMentionsConst (.app c1 c2) addr1) ++ + test "exprMentionsConst: in app arg" + (exprMentionsConst (.app c2 c1) addr1) ++ + test "exprMentionsConst: in forallE domain" + (exprMentionsConst (.forallE c1 c2 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: in forallE body" + (exprMentionsConst (.forallE c2 c1 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: in lam" + (exprMentionsConst (.lam c1 c2 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: absent in sort" + (!exprMentionsConst (.sort .zero : Expr .anon) addr1) ++ + test "exprMentionsConst: absent in bvar" + (!exprMentionsConst (.bvar 0 () : Expr .anon) addr1) ++ + -- checkStrictPositivity + let indAddrs := #[addr1] + test "checkStrictPositivity: no mention is positive" + (checkStrictPositivity c2 indAddrs) ++ + test "checkStrictPositivity: head occurrence is positive" + (checkStrictPositivity c1 indAddrs) ++ + test "checkStrictPositivity: in Pi domain is negative" + (!checkStrictPositivity (.forallE c1 c2 () () : Expr .anon) indAddrs) ++ + test "checkStrictPositivity: in Pi codomain positive" + (checkStrictPositivity (.forallE c2 c1 () () : Expr .anon) indAddrs) ++ + -- 
getIndResultLevel + test "getIndResultLevel: sort zero" + (getIndResultLevel (.sort .zero : Expr .anon) == some .zero) ++ + test "getIndResultLevel: sort (succ zero)" + (getIndResultLevel (.sort (.succ .zero) : Expr .anon) == some (.succ .zero)) ++ + test "getIndResultLevel: forallE _ (sort zero)" + (getIndResultLevel (.forallE (.sort .zero) (.sort (.succ .zero)) () () : Expr .anon) == some (.succ .zero)) ++ + test "getIndResultLevel: bvar (no sort)" + (getIndResultLevel (.bvar 0 () : Expr .anon) == none) ++ + -- levelIsNonZero + test "levelIsNonZero: zero is false" + (!levelIsNonZero (.zero : Level .anon)) ++ + test "levelIsNonZero: succ zero is true" + (levelIsNonZero (.succ .zero : Level .anon)) ++ + test "levelIsNonZero: param is false" + (!levelIsNonZero (.param 0 () : Level .anon)) ++ + test "levelIsNonZero: max(succ 0, param) is true" + (levelIsNonZero (.max (.succ .zero) (.param 0 ()) : Level .anon)) ++ + test "levelIsNonZero: imax(param, succ 0) is true" + (levelIsNonZero (.imax (.param 0 ()) (.succ .zero) : Level .anon)) ++ + test "levelIsNonZero: imax(succ, param) depends on second" + (!levelIsNonZero (.imax (.succ .zero) (.param 0 ()) : Level .anon)) ++ + -- checkCtorPositivity + test "checkCtorPositivity: no inductive mention is ok" + (checkCtorPositivity c2 0 indAddrs == none) ++ + test "checkCtorPositivity: negative occurrence" + (checkCtorPositivity (.forallE (.forallE c1 c2 () ()) (.const addr1 #[] ()) () () : Expr .anon) 0 indAddrs != none) ++ + -- getCtorReturnType + test "getCtorReturnType: no binders returns expr" + (getCtorReturnType c1 0 0 == c1) ++ + test "getCtorReturnType: skips foralls" + (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) ++ + .done + /-! ## Focused NbE constant tests -/ /-- Test individual constants through the NbE kernel to isolate failures. 
-/ @@ -631,6 +1114,7 @@ def testNbeConsts : TestSeq := "Nat.Linear.Poly.of_denote_eq_cancel", -- String theorem (fuel-sensitive) "String.length_empty", + "_private.Init.Grind.Ring.Basic.«0».Lean.Grind.IsCharP.mk'_aux._proof_1_5", ] let mut passed := 0 let mut failures : Array String := #[] @@ -673,6 +1157,7 @@ def unitSuite : List TestSeq := [ testLevelLeqComplex, testLevelInstBulkReduce, testReducibilityHintsLt, + testHelperFunctions, ] def convertSuite : List TestSeq := [ @@ -686,6 +1171,7 @@ def constSuite : List TestSeq := [ def negativeSuite : List TestSeq := [ negativeTests, + soundnessNegativeTests, ] def anonConvertSuite : List TestSeq := [ diff --git a/Tests/Ix/PP.lean b/Tests/Ix/PP.lean index d96bd0f1..ab52ea3e 100644 --- a/Tests/Ix/PP.lean +++ b/Tests/Ix/PP.lean @@ -248,22 +248,30 @@ def testQuoteRoundtrip : TestSeq := -- Build Value.lam: fun (y : Nat) => y let bodyTE : TypedExpr .meta := ⟨.none, .bvar 0 yName⟩ let lamVal : Value .meta := .lam domVal bodyTE (.mk [] []) yName .default - -- Quote and pp in a minimal TypecheckM context - let ctx : TypecheckCtx .meta := { - lvl := 0, env := .mk [] [], types := [], - kenv := default, prims := buildPrimitives, - safety := .safe, quotInit := true, mutTypes := default, recAddr? := none - } - let stt : TypecheckState .meta := { typedConsts := default } + -- Quote and pp in a minimal TypecheckM context (wrapped in runST for ST.Ref allocation) + let result := runST fun σ => do + let fuelRef ← ST.mkRef Ix.Kernel.defaultFuel + let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level .meta) × Value .meta)) + let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool) + let ctx : TypecheckCtx .meta σ := { + lvl := 0, env := .mk [] [], types := [], + kenv := default, prims := buildPrimitives, + safety := .safe, quotInit := true, mutTypes := default, recAddr? 
:= none, + fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef + } + let stt : TypecheckState .meta := { typedConsts := default } + let piResult ← TypecheckM.run ctx stt (ppValue 0 piVal) + let lamResult ← TypecheckM.run ctx stt (ppValue 0 lamVal) + pure (piResult, lamResult) -- Test pi - match TypecheckM.run ctx stt (ppValue 0 piVal) with + match result.1 with | .ok s => if s != "∀ (x : Nat), Nat" then return (false, some s!"pi round-trip: expected '∀ (x : Nat), Nat', got '{s}'") else pure () | .error e => return (false, some s!"pi round-trip error: {e}") -- Test lam - match TypecheckM.run ctx stt (ppValue 0 lamVal) with + match result.2 with | .ok s => if s != "λ (y : Nat) => y" then return (false, some s!"lam round-trip: expected 'λ (y : Nat) => y', got '{s}'") diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs index ada12904..0cc24620 100644 --- a/src/ix/kernel/def_eq.rs +++ b/src/ix/kernel/def_eq.rs @@ -530,9 +530,8 @@ fn get_applied_def( Some((name.clone(), d.hints)) } }, - ConstantInfo::ThmInfo(_) => { - Some((name.clone(), ReducibilityHints::Opaque)) - }, + // Theorems are never unfolded — proof irrelevance handles them. + // ConstantInfo::ThmInfo(_) => return None, _ => None, } } @@ -1570,13 +1569,12 @@ mod tests { } #[test] - fn test_get_applied_def_includes_theorems_as_opaque() { + fn test_get_applied_def_excludes_theorems() { + // Theorems should never be unfolded — proof irrelevance handles them. 
let env = mk_thm_env(); let thm = Expr::cnst(mk_name("thmA"), vec![]); let result = get_applied_def(&thm, &env); - assert!(result.is_some()); - let (_, hints) = result.unwrap(); - assert_eq!(hints, ReducibilityHints::Opaque); + assert!(result.is_none()); } // ========================================================================== diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs index 4cf79d45..90da54ba 100644 --- a/src/ix/kernel/inductive.rs +++ b/src/ix/kernel/inductive.rs @@ -155,6 +155,216 @@ pub fn validate_k_flag( Ok(()) } +/// Validate recursor rules against the inductive's constructors. +/// Checks: +/// - One rule per constructor +/// - Each rule's constructor exists and belongs to the inductive +/// - Each rule's n_fields matches the constructor's actual field count +/// - Rules are in constructor order +pub fn validate_recursor_rules( + rec: &RecursorVal, + env: &Env, +) -> TcResult<()> { + // Find the primary inductive + if rec.all.is_empty() { + return Err(TcError::KernelException { + msg: "recursor has no associated inductives".into(), + }); + } + let ind_name = &rec.all[0]; + let ind = match env.get(ind_name) { + Some(ConstantInfo::InductInfo(iv)) => iv, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor's inductive {} is not an inductive type", + ind_name.pretty() + ), + }) + }, + }; + + // For mutual inductives, collect all constructors in order + let mut all_ctors: Vec = Vec::new(); + for iname in &rec.all { + if let Some(ConstantInfo::InductInfo(iv)) = env.get(iname) { + all_ctors.extend(iv.ctors.iter().cloned()); + } + } + + // Check rule count matches total constructor count + if rec.rules.len() != all_ctors.len() { + return Err(TcError::KernelException { + msg: format!( + "recursor has {} rules but inductive(s) have {} constructors", + rec.rules.len(), + all_ctors.len() + ), + }); + } + + // Check each rule + for (i, rule) in rec.rules.iter().enumerate() { + // Rule's constructor must match 
expected constructor in order + if rule.ctor != all_ctors[i] { + return Err(TcError::KernelException { + msg: format!( + "recursor rule {} has constructor {} but expected {}", + i, + rule.ctor.pretty(), + all_ctors[i].pretty() + ), + }); + } + + // Look up the constructor and validate n_fields + let ctor = match env.get(&rule.ctor) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor rule constructor {} not found or not a constructor", + rule.ctor.pretty() + ), + }) + }, + }; + + if rule.n_fields != ctor.num_fields { + return Err(TcError::KernelException { + msg: format!( + "recursor rule for {} has n_fields={} but constructor has {} fields", + rule.ctor.pretty(), + rule.n_fields, + ctor.num_fields + ), + }); + } + } + + // Validate structural counts against the inductive + let expected_params = ind.num_params.to_u64().unwrap(); + let rec_params = rec.num_params.to_u64().unwrap(); + if rec_params != expected_params { + return Err(TcError::KernelException { + msg: format!( + "recursor num_params={} but inductive has {} params", + rec_params, expected_params + ), + }); + } + + let expected_indices = ind.num_indices.to_u64().unwrap(); + let rec_indices = rec.num_indices.to_u64().unwrap(); + if rec_indices != expected_indices { + return Err(TcError::KernelException { + msg: format!( + "recursor num_indices={} but inductive has {} indices", + rec_indices, expected_indices + ), + }); + } + + // Validate elimination restriction for Prop inductives. + // If the inductive is in Prop and requires elimination only at universe zero, + // then the recursor must not have extra universe parameters beyond the inductive's. 
+ if !rec.is_unsafe { + if let Some(elim_zero) = elim_only_at_universe_zero(ind, env) { + if elim_zero { + // Recursor should have same number of level params as the inductive + // (no extra universe parameter for the motive's result sort) + let ind_level_count = ind.cnst.level_params.len(); + let rec_level_count = rec.cnst.level_params.len(); + if rec_level_count > ind_level_count { + return Err(TcError::KernelException { + msg: format!( + "recursor has {} universe params but inductive has {} — \ + large elimination is not allowed for this Prop inductive", + rec_level_count, ind_level_count + ), + }); + } + } + } + } + + Ok(()) +} + +/// Compute whether a Prop inductive can only eliminate to Prop (universe zero). +/// +/// Returns `Some(true)` if elimination is restricted to Prop, +/// `Some(false)` if large elimination is allowed, +/// `None` if the inductive is not in Prop (no restriction applies). +/// +/// Matches the C++ kernel's `elim_only_at_universe_zero`: +/// 1. If result universe is always non-zero: None (not a predicate) +/// 2. If mutual: restricted +/// 3. If >1 constructor: restricted +/// 4. If 0 constructors: not restricted (e.g., False) +/// 5. If 1 constructor: restricted iff any non-Prop field doesn't appear in result indices +fn elim_only_at_universe_zero( + ind: &InductiveVal, + env: &Env, +) -> Option { + // Check if the inductive's result is in Prop. + // Walk past all binders to find the final Sort. + let mut ty = ind.cnst.typ.clone(); + loop { + match ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ty = body.clone(); + }, + _ => break, + } + } + let result_level = match ty.as_data() { + ExprData::Sort(l, _) => l, + _ => return None, + }; + + // If the result sort is definitively non-zero (e.g., Sort 1, Sort (u+1)), + // this is not a predicate. + if !level::could_be_zero(result_level) { + return None; + } + + // Must be possibly Prop. Apply the 5 conditions. 
+ + // Condition 2: Mutual inductives → restricted + if ind.all.len() > 1 { + return Some(true); + } + + // Condition 3: >1 constructor → restricted + if ind.ctors.len() > 1 { + return Some(true); + } + + // Condition 4: 0 constructors → not restricted (e.g., False) + if ind.ctors.is_empty() { + return Some(false); + } + + // Condition 5: Single constructor — check fields + let ctor = match env.get(&ind.ctors[0]) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => return Some(true), // can't look up ctor, be conservative + }; + + // If zero fields, not restricted + if ctor.num_fields == Nat::ZERO { + return Some(false); + } + + // For single-constructor with fields: restricted if any non-Prop field + // doesn't appear in the result type's indices. + // Conservative approximation: if any field exists that could be non-Prop, + // assume restricted. This is safe (may reject some valid large eliminations + // but never allows unsound ones). + Some(true) +} + /// Check if an expression mentions a constant by name. fn expr_mentions_const(e: &Expr, name: &Name) -> bool { let mut stack: Vec<&Expr> = vec![e]; @@ -364,14 +574,33 @@ fn check_field_universe_constraints( /// Verify that a constructor's return type targets the parent inductive. /// Walks the constructor type telescope, then checks that the resulting /// type is an application of the parent inductive with at least `num_params` args. +/// Also validates: +/// - The first `num_params` arguments are definitionally equal to the inductive's parameters. +/// - Index arguments (after params) don't mention the inductive being declared. fn check_ctor_return_type( ctor: &ConstructorVal, ind: &InductiveVal, tc: &mut TypeChecker, ) -> TcResult<()> { - let mut ty = ctor.cnst.typ.clone(); + let num_params = ind.num_params.to_u64().unwrap() as usize; + + // Walk the inductive's type telescope to collect parameter locals. 
+ let mut ind_ty = ind.cnst.typ.clone(); + let mut param_locals = Vec::with_capacity(num_params); + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ind_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + param_locals.push(local.clone()); + ind_ty = inst(body, &[local]); + }, + _ => break, + } + } - // Walk past all Pi binders + // Walk past all Pi binders in the constructor type. + let mut ty = ctor.cnst.typ.clone(); loop { let whnf_ty = tc.whnf(&ty); match whnf_ty.as_data() { @@ -411,7 +640,6 @@ fn check_ctor_return_type( }); } - let num_params = ind.num_params.to_u64().unwrap() as usize; if args.len() < num_params { return Err(TcError::KernelException { msg: format!( @@ -423,6 +651,35 @@ fn check_ctor_return_type( }); } + // Check that the first num_params arguments match the inductive's parameters. + for i in 0..num_params { + if i < param_locals.len() && !tc.def_eq(&args[i], ¶m_locals[i]) { + return Err(TcError::KernelException { + msg: format!( + "constructor {} parameter {} does not match inductive's parameter", + ctor.cnst.name.pretty(), + i + ), + }); + } + } + + // Check that index arguments (after params) don't mention the inductive. 
+ for i in num_params..args.len() { + for ind_name in &ind.all { + if expr_mentions_const(&args[i], ind_name) { + return Err(TcError::KernelException { + msg: format!( + "constructor {} index argument {} mentions the inductive {}", + ctor.cnst.name.pretty(), + i - num_params, + ind_name.pretty() + ), + }); + } + } + } + Ok(()) } @@ -784,4 +1041,782 @@ mod tests { let mut tc = TypeChecker::new(&env); assert!(check_inductive(ind, &mut tc).is_err()); } + + // ========================================================================== + // Recursor rule validation + // ========================================================================== + + #[test] + fn validate_rec_rules_wrong_count() { + // Nat has 2 ctors but we provide 1 rule + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_wrong_ctor_order() { + // Provide rules in wrong order (succ first, zero second) + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + 
rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_wrong_nfields() { + // zero has 0 fields but we claim 3 + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(3u64), // wrong! + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_bogus_ctor() { + // Rule references a non-existent constructor + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "bogus"), // doesn't exist + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_correct() { + // Correct rules for Nat + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: 
Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_ok()); + } + + #[test] + fn validate_rec_rules_wrong_num_params() { + // Recursor claims 5 params but Nat has 0 + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(5u64), // wrong + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + // ========================================================================== + // K-flag validation + // ========================================================================== + + /// Build a Prop inductive with 1 ctor and 0 fields (Eq-like). 
+ fn mk_k_valid_env() -> Env { + let mut env = mk_nat_env(); + let eq_name = mk_name("KEq"); + let eq_refl = mk_name2("KEq", "refl"); + let u = mk_name("u"); + + // KEq.{u} (α : Sort u) (a b : α) : Prop + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), // Prop + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + // KEq.refl.{u} (α : Sort u) (a : α) : KEq α a a + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn validate_k_flag_valid_prop_single_zero_fields() { + let env = mk_k_valid_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("KEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("KEq")], + num_params: 
Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![RecursorRule { + ctor: mk_name2("KEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_ok()); + } + + #[test] + fn validate_k_flag_fails_not_prop() { + // Nat is in Sort 1, not Prop — K should fail + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + #[test] + fn validate_k_flag_fails_multiple_ctors() { + // Even a Prop inductive with 2 ctors can't be K + // We need a Prop inductive with 2 ctors for this test + let mut env = Env::default(); + let p_name = mk_name("P"); + let mk1 = mk_name2("P", "mk1"); + let mk2 = mk_name2("P", "mk2"); + env.insert( + p_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), // Prop + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![p_name.clone()], + ctors: vec![mk1.clone(), mk2.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + mk1.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk1, + level_params: vec![], + typ: Expr::cnst(p_name.clone(), vec![]), + }, + induct: p_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env.insert( + mk2.clone(), + ConstantInfo::CtorInfo(ConstructorVal { 
+ cnst: ConstantVal { + name: mk2, + level_params: vec![], + typ: Expr::cnst(p_name.clone(), vec![]), + }, + induct: p_name, + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("P", "rec"), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + all: vec![mk_name("P")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + #[test] + fn validate_k_flag_false_always_ok() { + // k=false is always conservative, never rejected + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: false, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_ok()); + } + + #[test] + fn validate_k_flag_fails_mutual() { + // K requires all.len() == 1 + let env = mk_k_valid_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("KEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("KEq"), mk_name("OtherInd")], // mutual + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + // ========================================================================== + // Elimination restriction + // ========================================================================== + + #[test] + fn 
elim_restriction_non_prop_is_none() { + // Nat is in Sort 1, not Prop — no restriction applies + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), None); + } + + #[test] + fn elim_restriction_prop_2_ctors_restricted() { + // A Prop inductive with 2 constructors: restricted to Prop elimination + let mut env = Env::default(); + let p_name = mk_name("P2"); + let mk1 = mk_name2("P2", "mk1"); + let mk2 = mk_name2("P2", "mk2"); + env.insert( + p_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![p_name.clone()], + ctors: vec![mk1.clone(), mk2.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert(mk1.clone(), ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { name: mk1, level_params: vec![], typ: Expr::cnst(p_name.clone(), vec![]) }, + induct: p_name.clone(), cidx: Nat::from(0u64), num_params: Nat::from(0u64), num_fields: Nat::from(0u64), is_unsafe: false, + })); + env.insert(mk2.clone(), ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { name: mk2, level_params: vec![], typ: Expr::cnst(p_name.clone(), vec![]) }, + induct: p_name.clone(), cidx: Nat::from(1u64), num_params: Nat::from(0u64), num_fields: Nat::from(0u64), is_unsafe: false, + })); + let ind = match env.get(&p_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(true)); + } + + #[test] + fn elim_restriction_prop_0_ctors_not_restricted() { + // Empty Prop inductive (like False): can eliminate to any universe + let env_name = mk_name("MyFalse"); + let mut env = Env::default(); + env.insert( + env_name.clone(), + 
ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: env_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![env_name.clone()], + ctors: vec![], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let ind = match env.get(&env_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(false)); + } + + #[test] + fn elim_restriction_prop_1_ctor_0_fields_not_restricted() { + // Prop inductive, 1 ctor, 0 fields (like True): not restricted + let mut env = Env::default(); + let t_name = mk_name("MyTrue"); + let t_mk = mk_name2("MyTrue", "intro"); + env.insert( + t_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: t_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![t_name.clone()], + ctors: vec![t_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + t_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: t_mk, + level_params: vec![], + typ: Expr::cnst(t_name.clone(), vec![]), + }, + induct: t_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let ind = match env.get(&t_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(false)); + } + + #[test] + fn elim_restriction_prop_1_ctor_with_fields_restricted() { + // Prop inductive, 1 ctor with fields: conservatively restricted + // (like Exists) + let mut env = Env::default(); + let ex_name = mk_name("MyExists"); + let ex_mk = mk_name2("MyExists", "intro"); + // For simplicity: MyExists : 
Prop, MyExists.intro : Prop → MyExists + // (simplified from the real Exists which is polymorphic) + env.insert( + ex_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: ex_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![ex_name.clone()], + ctors: vec![ex_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + ex_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: ex_mk, + level_params: vec![], + typ: Expr::all( + mk_name("h"), + Expr::sort(Level::zero()), // a Prop field + Expr::cnst(ex_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: ex_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + let ind = match env.get(&ex_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + // Conservative: any fields means restricted + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(true)); + } + + // ========================================================================== + // Index-mentions-inductive check + // ========================================================================== + + #[test] + fn index_mentions_inductive_rejected() { + // Construct an inductive with 1 param and 1 index where the index + // mentions the inductive itself. This should be rejected. + // + // inductive Bad (α : Type) : Bad α → Type + // | mk : Bad α + // + // The ctor return type is `Bad α (Bad.mk α)`, but for the test + // we manually build a ctor whose index arg mentions `Bad`. 
+ let mut env = mk_nat_env(); + let bad_name = mk_name("BadIdx"); + let bad_mk = mk_name2("BadIdx", "mk"); + + // BadIdx (α : Sort 1) : Sort 1 + // (For simplicity, we make it have 1 param and 1 index) + env.insert( + bad_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bad_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("α"), + Expr::sort(Level::succ(Level::zero())), + Expr::all( + mk_name("_idx"), + nat_type(), // index of type Nat + Expr::sort(Level::succ(Level::zero())), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + }, + num_params: Nat::from(1u64), + num_indices: Nat::from(1u64), + all: vec![bad_name.clone()], + ctors: vec![bad_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // BadIdx.mk (α : Sort 1) : BadIdx α + // The return type's index argument mentions BadIdx + let bad_idx_expr = Expr::app( + Expr::cnst(bad_name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), // dummy + ); + let ctor_ret = Expr::app( + Expr::app( + Expr::cnst(bad_name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), // param α + ), + bad_idx_expr, // index mentions BadIdx! 
+ ); + env.insert( + bad_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: bad_mk, + level_params: vec![], + typ: Expr::all( + mk_name("α"), + Expr::sort(Level::succ(Level::zero())), + ctor_ret, + BinderInfo::Default, + ), + }, + induct: bad_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(1u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&bad_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // expr_mentions_const + // ========================================================================== + + #[test] + fn expr_mentions_const_direct() { + let name = mk_name("Foo"); + let e = Expr::cnst(name.clone(), vec![]); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_nested_app() { + let name = mk_name("Foo"); + let e = Expr::app( + Expr::cnst(mk_name("bar"), vec![]), + Expr::cnst(name.clone(), vec![]), + ); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_absent() { + let name = mk_name("Foo"); + let e = Expr::app( + Expr::cnst(mk_name("bar"), vec![]), + Expr::cnst(mk_name("baz"), vec![]), + ); + assert!(!expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_in_forall_domain() { + let name = mk_name("Foo"); + let e = Expr::all( + mk_name("x"), + Expr::cnst(name.clone(), vec![]), + Expr::sort(Level::zero()), + BinderInfo::Default, + ); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_in_let() { + let name = mk_name("Foo"); + let e = Expr::letE( + mk_name("x"), + Expr::sort(Level::zero()), + Expr::cnst(name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(expr_mentions_const(&e, &name)); + } } diff --git a/src/ix/kernel/level.rs 
b/src/ix/kernel/level.rs index 80195e35..624f8fb2 100644 --- a/src/ix/kernel/level.rs +++ b/src/ix/kernel/level.rs @@ -54,6 +54,23 @@ pub fn is_zero(l: &Level) -> bool { leq(l, &Level::zero()) } +/// Check if a level could possibly be zero (i.e., not definitively non-zero). +/// Returns false only if the level is guaranteed to be ≥ 1 for all parameter assignments. +pub fn could_be_zero(l: &Level) -> bool { + let s = simplify(l); + could_be_zero_core(&s) +} + +fn could_be_zero_core(l: &Level) -> bool { + match l.as_data() { + LevelData::Zero(_) => true, + LevelData::Succ(..) => false, // n+1 is never zero + LevelData::Param(..) | LevelData::Mvar(..) => true, // parameter could be instantiated to zero + LevelData::Max(a, b, _) => could_be_zero_core(a) && could_be_zero_core(b), + LevelData::Imax(_, b, _) => could_be_zero_core(b), // imax(a, 0) = 0 + } +} + /// Check if `l <= r`. pub fn leq(l: &Level, r: &Level) -> bool { let l_s = simplify(l); @@ -400,4 +417,72 @@ mod tests { let expected = Level::succ(Level::zero()); assert_eq!(result, expected); } + + // ========================================================================== + // could_be_zero + // ========================================================================== + + #[test] + fn could_be_zero_zero() { + assert!(could_be_zero(&Level::zero())); + } + + #[test] + fn could_be_zero_succ_is_false() { + // Succ(0) = 1, never zero + assert!(!could_be_zero(&Level::succ(Level::zero()))); + } + + #[test] + fn could_be_zero_succ_param_is_false() { + // u+1 is never zero regardless of u + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(!could_be_zero(&Level::succ(u))); + } + + #[test] + fn could_be_zero_param_is_true() { + // Param u could be zero (instantiated to 0) + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(could_be_zero(&u)); + } + + #[test] + fn could_be_zero_max_both_could() { + // max(u, v) could be zero if both u and v could be zero + let u = 
Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(could_be_zero(&Level::max(u, v))); + } + + #[test] + fn could_be_zero_max_one_nonzero() { + // max(u+1, v) cannot be zero because u+1 ≥ 1 + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(!could_be_zero(&Level::max(Level::succ(u), v))); + } + + #[test] + fn could_be_zero_imax_zero_right() { + // imax(u, 0) = 0, so could be zero + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(could_be_zero(&Level::imax(u, Level::zero()))); + } + + #[test] + fn could_be_zero_imax_succ_right() { + // imax(u, v+1) = max(u, v+1), never zero since v+1 ≥ 1 + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(!could_be_zero(&Level::imax(u, Level::succ(v)))); + } + + #[test] + fn could_be_zero_imax_param_right() { + // imax(u, v): if v=0 then imax(u,0)=0, so could be zero + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(could_be_zero(&Level::imax(u, v))); + } } diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs index 604fbf02..59685192 100644 --- a/src/ix/kernel/tc.rs +++ b/src/ix/kernel/tc.rs @@ -573,6 +573,7 @@ impl<'env> TypeChecker<'env> { } } super::inductive::validate_k_flag(v, self.env)?; + super::inductive::validate_recursor_rules(v, self.env)?; }, } Ok(()) @@ -1542,7 +1543,8 @@ mod tests { } #[test] - fn check_rec_with_inductive() { + fn check_rec_empty_rules_fails() { + // Nat has 2 constructors, so 0 rules should fail let env = mk_nat_env(); let mut tc = TypeChecker::new(&env); let rec = ConstantInfo::RecInfo(RecursorVal { @@ -1560,7 +1562,16 @@ mod tests { k: false, is_unsafe: false, }); - assert!(tc.check_declar(&rec).is_ok()); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn 
check_rec_with_valid_rules() { + // Use the full mk_nat_env which includes Nat.rec with proper rules + let env = mk_nat_env(); + let nat_rec = env.get(&mk_name2("Nat", "rec")).unwrap(); + let mut tc = TypeChecker::new(&env); + assert!(tc.check_declar(nat_rec).is_ok()); } // ========================================================================== @@ -1940,7 +1951,11 @@ mod tests { num_indices: Nat::from(1u64), num_motives: Nat::from(1u64), num_minors: Nat::from(1u64), - rules: vec![], + rules: vec![RecursorRule { + ctor: mk_name2("MyEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), // placeholder + }], k: true, is_unsafe: false, }); @@ -2184,4 +2199,301 @@ mod tests { let ty = tc.infer(&e).unwrap(); assert_eq!(ty, nat_type()); } + + // ========================================================================== + // check_declar: Recursor rule validation (integration tests) + // ========================================================================== + + #[test] + fn check_rec_wrong_nfields_via_check_declar() { + // Nat.rec with zero rule claiming 5 fields instead of 0 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let motive_type = Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ); + let rec_type = Expr::all( + mk_name("motive"), + motive_type, + Expr::sort(Level::param(u.clone())), // simplified + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(5u64), // WRONG + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: 
Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_wrong_ctor_order_via_check_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let rec_type = Expr::all( + mk_name("motive"), + Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ), + Expr::sort(Level::param(u.clone())), + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + // WRONG ORDER: succ then zero + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_wrong_num_params_via_check_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let rec_type = Expr::all( + mk_name("motive"), + Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ), + Expr::sort(Level::param(u.clone())), + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(99u64), // WRONG: Nat has 0 params + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + 
n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_valid_rules_passes() { + // Full Nat.rec declaration from mk_nat_env passes check_declar + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let nat_rec = env.get(&mk_name2("Nat", "rec")).unwrap(); + assert!(tc.check_declar(nat_rec).is_ok()); + } + + // ========================================================================== + // check_declar: K-flag via check_declar + // ========================================================================== + + /// Build an env with an Eq-like Prop inductive that supports K. + fn mk_k_env() -> Env { + let mut env = mk_nat_env(); + let u = mk_name("u"); + let eq_name = mk_name("MyEq"); + let eq_refl = mk_name2("MyEq", "refl"); + + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + 
Expr::bvar(Nat::from(0u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn check_k_flag_valid_via_check_declar() { + let env = mk_k_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("MyEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("MyEq")], + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![RecursorRule { + ctor: mk_name2("MyEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + #[test] + fn check_k_flag_invalid_on_nat_via_check_declar() { + // K=true on Nat (Sort 1, 2 ctors) should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "recK"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } } diff --git 
a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs index d7cef49a..d4500e85 100644 --- a/src/ix/kernel/whnf.rs +++ b/src/ix/kernel/whnf.rs @@ -509,9 +509,8 @@ pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { eprintln!("[whnf_dag] depth={depth} total={total} no_delta={no_delta}"); } if depth > 200 { - eprintln!("[whnf_dag] DEPTH LIMIT depth={depth}, bailing"); WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); - return; + panic!("[whnf_dag] DEPTH LIMIT exceeded (depth={depth}): possible infinite reduction or extremely deep term"); } const WHNF_STEP_LIMIT: u64 = 100_000; @@ -520,9 +519,8 @@ pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { loop { steps += 1; if steps > WHNF_STEP_LIMIT { - eprintln!("[whnf_dag] step limit exceeded ({steps}) depth={depth}"); whnf_done(depth); - return; + panic!("[whnf_dag] step limit exceeded ({steps} steps at depth={depth}): possible infinite reduction"); } if steps <= 5 || steps % 10_000 == 0 { let head_variant = match dag.head { @@ -925,7 +923,9 @@ pub(crate) fn try_reduce_nat_dag( } else if *name == mk_name2("Nat", "ble") { Some(bool_to_dag(a <= b)) } else if *name == mk_name2("Nat", "pow") { + // Limit exponent to prevent OOM (matches yatima's 2^24 limit) let exp = u32::try_from(&b).unwrap_or(u32::MAX); + if exp > (1 << 24) { return None; } Some(nat_lit_dag(Nat(a.pow(exp)))) } else if *name == mk_name2("Nat", "land") { Some(nat_lit_dag(Nat(a & b))) @@ -934,7 +934,9 @@ pub(crate) fn try_reduce_nat_dag( } else if *name == mk_name2("Nat", "xor") { Some(nat_lit_dag(Nat(a ^ b))) } else if *name == mk_name2("Nat", "shiftLeft") { + // Limit shift to prevent OOM let shift = u64::try_from(&b).unwrap_or(u64::MAX); + if shift > (1 << 24) { return None; } Some(nat_lit_dag(Nat(a << shift))) } else if *name == mk_name2("Nat", "shiftRight") { let shift = u64::try_from(&b).unwrap_or(u64::MAX); @@ -1094,7 +1096,8 @@ pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { } (&d.cnst.level_params, &d.value) }, - 
ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + // Theorems are never unfolded — proof irrelevance handles them. + // ConstantInfo::ThmInfo(_) => return None, _ => return None, }; From 904c3fb9d46ec61f0545a5c934025777a9b2974d Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Mon, 2 Mar 2026 11:32:59 -0500 Subject: [PATCH 06/14] Rewrite Lean kernel from NbE to environment-based substitution Replace the closure-based NbE (Normalization by Evaluation) kernel with a direct environment-based approach where types are Exprs throughout. - Remove Value/Neutral/ValEnv/SusValue semantic domain (Datatypes.lean) - Replace Eval.lean with Whnf.lean (WHNF via structural + delta reduction) - Replace Equal.lean with DefEq.lean (staged definitional equality with lazy delta reduction guided by ReducibilityHints) - Rewrite Infer.lean to operate on Expr types instead of Values - Simplify TypecheckM: remove NbE-specific state (evalCacheRef, equalCacheRef), add whnf/defEq/infer caches as pure state - Add proof irrelevance, eta expansion, structure eta, nat/string literal expansion to isDefEq - Flatten app spines and binder chains in infer/isDefEq to avoid deep recursion --- Ix/Cli/CheckCmd.lean | 6 +- Ix/Kernel.lean | 4 +- Ix/Kernel/Datatypes.lean | 114 +-- Ix/Kernel/DefEq.lean | 41 + Ix/Kernel/{Equal.lean => Equal.lean.bak} | 0 Ix/Kernel/{Eval.lean => Eval.lean.bak} | 0 Ix/Kernel/Infer.lean | 919 ++++++++++++++--------- Ix/Kernel/TypecheckM.lean | 167 ++-- Ix/Kernel/Types.lean | 125 +++ Ix/Kernel/Whnf.lean | 538 +++++++++++++ Tests/Ix/Check.lean | 24 +- Tests/Ix/KernelTests.lean | 201 ++--- Tests/Ix/PP.lean | 51 +- Tests/Main.lean | 1 - 14 files changed, 1441 insertions(+), 750 deletions(-) create mode 100644 Ix/Kernel/DefEq.lean rename Ix/Kernel/{Equal.lean => Equal.lean.bak} (100%) rename Ix/Kernel/{Eval.lean => Eval.lean.bak} (100%) create mode 100644 Ix/Kernel/Whnf.lean diff --git a/Ix/Cli/CheckCmd.lean b/Ix/Cli/CheckCmd.lean index f8e388f0..f570ea65 100644 --- 
a/Ix/Cli/CheckCmd.lean +++ b/Ix/Cli/CheckCmd.lean @@ -46,7 +46,7 @@ private def buildFile (path : FilePath) : IO Unit := do if exitCode != 0 then throw $ IO.userError "lake build failed" -/-- Run the Lean NbE kernel checker. -/ +/-- Run the Lean kernel checker. -/ private def runLeanCheck (leanEnv : Lean.Environment) : IO UInt32 := do println! "Compiling to Ixon..." let compileStart ← IO.monoMsNow @@ -106,7 +106,7 @@ def runCheckCmd (p : Cli.Parsed) : IO UInt32 := do let leanEnv ← getFileEnv pathStr if useLean then - println! "Running Lean NbE kernel checker on {pathStr}" + println! "Running Lean kernel checker on {pathStr}" runLeanCheck leanEnv else println! "Running Rust kernel checker on {pathStr}" @@ -118,5 +118,5 @@ def checkCmd : Cli.Cmd := `[Cli| FLAGS: path : String; "Path to file to check" - lean; "Use Lean NbE kernel instead of Rust kernel" + lean; "Use Lean kernel instead of Rust kernel" ] diff --git a/Ix/Kernel.lean b/Ix/Kernel.lean index cbb6c467..ba19b0b4 100644 --- a/Ix/Kernel.lean +++ b/Ix/Kernel.lean @@ -4,8 +4,8 @@ import Ix.Kernel.Types import Ix.Kernel.Datatypes import Ix.Kernel.Level import Ix.Kernel.TypecheckM -import Ix.Kernel.Eval -import Ix.Kernel.Equal +import Ix.Kernel.Whnf +import Ix.Kernel.DefEq import Ix.Kernel.Infer import Ix.Kernel.Convert diff --git a/Ix/Kernel/Datatypes.lean b/Ix/Kernel/Datatypes.lean index d94d8701..f19f983d 100644 --- a/Ix/Kernel/Datatypes.lean +++ b/Ix/Kernel/Datatypes.lean @@ -1,7 +1,7 @@ /- - Kernel Datatypes: Value, Neutral, SusValue, TypedExpr, Env, TypedConst. + Kernel Datatypes: TypeInfo, TypedExpr, TypedConst. - Closure-based semantic domain for NbE typechecking. + Simplified for environment-based kernel (no Value/Neutral/ValEnv). Parameterized over MetaMode for compile-time metadata erasure. -/ import Ix.Kernel.Types @@ -22,41 +22,10 @@ structure AddInfo (Info Body : Type) where body : Body deriving Inhabited -/-! ## Forward declarations for mutual types -/ +/-! 
## TypedExpr -/ abbrev TypedExpr (m : MetaMode) := AddInfo (TypeInfo m) (Expr m) -/-! ## Value / Neutral / SusValue -/ - -mutual - inductive Value (m : MetaMode) where - | sort : Level m → Value m - | app : Neutral m → List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (TypeInfo m) → Value m - | lam : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m - → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m - | pi : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m - → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m - | lit : Lean.Literal → Value m - | exception : String → Value m - - inductive Neutral (m : MetaMode) where - | fvar : Nat → MetaField m Ix.Name → Neutral m - | const : Address → Array (Level m) → MetaField m Ix.Name → Neutral m - | proj : Address → Nat → AddInfo (TypeInfo m) (Value m) → MetaField m Ix.Name → Neutral m - - inductive ValEnv (m : MetaMode) where - | mk : List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (Level m) → ValEnv m -end - -instance : Inhabited (Value m) where default := .exception "uninit" -instance : Inhabited (Neutral m) where default := .fvar 0 default -instance : Inhabited (ValEnv m) where default := .mk [] [] - -abbrev SusValue (m : MetaMode) := AddInfo (TypeInfo m) (Thunk (Value m)) - -instance : Inhabited (SusValue m) where - default := .mk default { fn := fun _ => default } - /-! 
## TypedConst -/ inductive TypedConst (m : MetaMode) where @@ -86,11 +55,6 @@ def TypedConst.type : TypedConst m → TypedExpr m namespace AddInfo def expr (t : TypedExpr m) : Expr m := t.body -def thunk (sus : SusValue m) : Thunk (Value m) := sus.body -def get (sus : SusValue m) : Value m := sus.body.get -def getTyped (sus : SusValue m) : AddInfo (TypeInfo m) (Value m) := ⟨sus.info, sus.body.get⟩ -def value (val : AddInfo (TypeInfo m) (Value m)) : Value m := val.body -def sus (val : AddInfo (TypeInfo m) (Value m)) : SusValue m := ⟨val.info, val.body⟩ end AddInfo @@ -100,31 +64,7 @@ partial def TypedExpr.toImplicitLambda : TypedExpr m → TypedExpr m | .mk _ (.lam _ body _ _) => toImplicitLambda ⟨default, body⟩ | x => x -/-! ## Value helpers -/ - -def Value.neu (n : Neutral m) : Value m := .app n [] [] - -def Value.ctorName : Value m → String - | .sort .. => "sort" - | .app .. => "app" - | .lam .. => "lam" - | .pi .. => "pi" - | .lit .. => "lit" - | .exception .. => "exception" - -def Neutral.summary : Neutral m → String - | .fvar idx name => s!"fvar({name}, {idx})" - | .const addr _ name => s!"const({name}, {addr})" - | .proj _ idx _ name => s!"proj({name}, {idx})" - -def Value.summary : Value m → String - | .sort _ => "Sort" - | .app neu args _ => s!"{neu.summary} applied to {args.length} args" - | .lam .. => "lam" - | .pi .. => "Pi" - | .lit (.natVal n) => s!"natLit({n})" - | .lit (.strVal s) => s!"strLit(\"{s}\")" - | .exception e => s!"exception({e})" +/-! ## TypeInfo helpers -/ def TypeInfo.pp : TypeInfo m → String | .unit => ".unit" @@ -132,50 +72,4 @@ def TypeInfo.pp : TypeInfo m → String | .none => ".none" | .sort _ => ".sort" -private def listGetOpt (l : List α) (i : Nat) : Option α := - match l, i with - | [], _ => none - | x :: _, 0 => some x - | _ :: xs, n+1 => listGetOpt xs n - -/-- Deep structural dump (one level into args) for debugging stuck terms. 
-/ -def Value.dump : Value m → String - | .sort _ => "Sort" - | .app neu args infos => - let argStrs := args.zipIdx.map fun (a, i) => - let info := match listGetOpt infos i with | some ti => TypeInfo.pp ti | none => "?" - s!" [{i}] info={info} val={a.get.summary}" - s!"{neu.summary} applied to {args.length} args:\n" ++ String.intercalate "\n" argStrs - | .lam dom _ _ _ _ => s!"lam(dom={dom.get.summary}, info={dom.info.pp})" - | .pi dom _ _ _ _ => s!"Pi(dom={dom.get.summary}, info={dom.info.pp})" - | .lit (.natVal n) => s!"natLit({n})" - | .lit (.strVal s) => s!"strLit(\"{s}\")" - | .exception e => s!"exception({e})" - -/-! ## ValEnv helpers -/ - -namespace ValEnv - -def exprs : ValEnv m → List (SusValue m) - | .mk es _ => es - -def univs : ValEnv m → List (Level m) - | .mk _ us => us - -def extendWith (env : ValEnv m) (thunk : SusValue m) : ValEnv m := - .mk (thunk :: env.exprs) env.univs - -def withExprs (env : ValEnv m) (exprs : List (SusValue m)) : ValEnv m := - .mk exprs env.univs - -end ValEnv - -/-! ## Smart constructors -/ - -def mkConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : Value m := - .neu (.const addr univs name) - -def mkSusVar (info : TypeInfo m) (idx : Nat) (name : MetaField m Ix.Name := default) : SusValue m := - .mk info (.mk fun _ => .neu (.fvar idx name)) - end Ix.Kernel diff --git a/Ix/Kernel/DefEq.lean b/Ix/Kernel/DefEq.lean new file mode 100644 index 00000000..92bdac62 --- /dev/null +++ b/Ix/Kernel/DefEq.lean @@ -0,0 +1,41 @@ +/- + Kernel DefEq: Definitional equality with lazy delta reduction. + + Uses ReducibilityHints to guide delta unfolding order. + Handles proof irrelevance, eta expansion, structure eta. +-/ +import Ix.Kernel.Whnf + +namespace Ix.Kernel + +/-! ## Helpers -/ + +/-- Compare two arrays of levels for equality. -/ +def equalUnivArrays (us us' : Array (Level m)) : Bool := + us.size == us'.size && Id.run do + let mut i := 0 + while i < us.size do + if !Level.equalLevel us[i]! us'[i]! 
then return false + i := i + 1 + return true + +/-- Check if two expressions have the same const head. -/ +def sameHeadConst (t s : Expr m) : Bool := + match t.getAppFn, s.getAppFn with + | .const a _ _, .const b _ _ => a == b + | _, _ => false + +/-- Unfold a delta-reducible definition one step. -/ +def unfoldDelta (ci : ConstantInfo m) (e : Expr m) : Option (Expr m) := + match ci with + | .defnInfo v => + let levels := e.getAppFn.constLevels! + let body := v.value.instantiateLevelParams levels + some (body.mkAppN (e.getAppArgs)) + | .thmInfo v => + let levels := e.getAppFn.constLevels! + let body := v.value.instantiateLevelParams levels + some (body.mkAppN (e.getAppArgs)) + | _ => none + +end Ix.Kernel diff --git a/Ix/Kernel/Equal.lean b/Ix/Kernel/Equal.lean.bak similarity index 100% rename from Ix/Kernel/Equal.lean rename to Ix/Kernel/Equal.lean.bak diff --git a/Ix/Kernel/Eval.lean b/Ix/Kernel/Eval.lean.bak similarity index 100% rename from Ix/Kernel/Eval.lean rename to Ix/Kernel/Eval.lean.bak diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 0dacf465..abf8a9f2 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -1,9 +1,9 @@ /- Kernel Infer: Type inference and declaration checking. - Adapted from Yatima.Typechecker.Infer, parameterized over MetaMode. + Environment-based kernel: types are Exprs, uses whnf/isDefEq. -/ -import Ix.Kernel.Equal +import Ix.Kernel.DefEq namespace Ix.Kernel @@ -20,26 +20,20 @@ partial def exprMentionsConst (e : Expr m) (addr : Address) : Bool := | .proj _ _ s _ => exprMentionsConst s addr | _ => false -/-- Check strict positivity of a field type w.r.t. a set of inductive addresses. - Returns true if positive, false if negative occurrence found. -/ +/-- Check strict positivity of a field type w.r.t. a set of inductive addresses. 
-/ partial def checkStrictPositivity (ty : Expr m) (indAddrs : Array Address) : Bool := - -- If no inductive is mentioned, we're fine if !indAddrs.any (exprMentionsConst ty ·) then true else match ty with | .forallE domain body _ _ => - -- Domain must NOT mention any inductive if indAddrs.any (exprMentionsConst domain ·) then false - -- Continue checking body else checkStrictPositivity body indAddrs | e => - -- Not a forall — must be the inductive at the head let fn := e.getAppFn match fn with | .const addr _ _ => indAddrs.any (· == addr) | _ => false -/-- Walk a Pi chain, skip numParams binders, then check positivity of each field. - Returns an error message or none on success. -/ +/-- Walk a Pi chain, skip numParams binders, then check positivity of each field. -/ partial def checkCtorPositivity (ctorType : Expr m) (numParams : Nat) (indAddrs : Array Address) : Option String := go ctorType numParams @@ -50,7 +44,6 @@ where if remainingParams > 0 then go body (remainingParams - 1) else - -- This is a field — check positivity of its domain let domain := ty.bindingDomain! if !checkStrictPositivity domain indAddrs then some "inductive occurs in negative position (strict positivity violation)" @@ -58,8 +51,7 @@ where go body 0 | _ => none -/-- Walk a Pi chain past numParams + numFields binders to get the return type. - Returns the return type expression (with bvars). -/ +/-- Walk a Pi chain past numParams + numFields binders to get the return type. -/ def getCtorReturnType (ctorType : Expr m) (numParams numFields : Nat) : Expr m := go ctorType (numParams + numFields) where @@ -69,8 +61,7 @@ where | n+1, .forallE _ body _ _ => go body n | _, e => e -/-- Extract result universe level from an inductive type expression. - Walks past all forall binders to find the final Sort. -/ +/-- Extract result universe level from an inductive type expression. 
-/ def getIndResultLevel (indType : Expr m) : Option (Level m) := go indType where @@ -79,11 +70,11 @@ where | .sort lvl => some lvl | _ => none -/-- Check if a level is definitively non-zero (always ≥ 1). -/ +/-- Check if a level is definitively non-zero (always >= 1). -/ partial def levelIsNonZero : Level m → Bool | .succ _ => true | .zero => false - | .param .. => false -- could be zero + | .param .. => false | .max a b => levelIsNonZero a || levelIsNonZero b | .imax _ b => levelIsNonZero b @@ -93,24 +84,22 @@ def lamInfo : TypeInfo m → TypeInfo m | .proof => .proof | _ => .none -def piInfo (dom img : TypeInfo m) : TypecheckM m σ (TypeInfo m) := match dom, img with - | .sort lvl, .sort lvl' => pure (.sort (Level.reduceIMax lvl lvl')) - | _, _ => pure .none +def piInfo (dom img : TypeInfo m) : TypeInfo m := match dom, img with + | .sort lvl, .sort lvl' => .sort (Level.reduceIMax lvl lvl') + | _, _ => .none -def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m σ Bool := do - match inferType.info, expectType.info with - | .sort lvl, .sort lvl' => pure (Level.equalLevel lvl lvl') - | _, _ => pure true -- info unavailable; defer to structural equality - -def infoFromType (typ : SusValue m) : TypecheckM m σ (TypeInfo m) := - match typ.info with +/-- Infer TypeInfo from a type expression (after whnf). -/ +def infoFromType (typ : Expr m) : TypecheckM m (TypeInfo m) := do + let typ' ← whnf typ + match typ' with | .sort (.zero) => pure .proof - | _ => - match typ.get with - | .app (.const addr _ _) _ _ => do + | .sort lvl => pure (.sort lvl) + | .app .. => + let head := typ'.getAppFn + match head with + | .const addr _ _ => match (← read).kenv.find? addr with | some (.inductInfo v) => - -- Check if it's unit-like: one constructor with zero fields if v.ctors.size == 1 then match (← read).kenv.find? v.ctors[0]! 
with | some (.ctorInfo cv) => @@ -118,292 +107,275 @@ def infoFromType (typ : SusValue m) : TypecheckM m σ (TypeInfo m) := | _ => pure .none else pure .none | _ => pure .none - | .sort lvl => pure (.sort lvl) | _ => pure .none + | _ => pure .none /-! ## Inference / Checking -/ mutual /-- Check that a term has a given type. -/ - partial def check (term : Expr m) (type : SusValue m) : TypecheckM m σ (TypedExpr m) := do + partial def check (term : Expr m) (expectedType : Expr m) : TypecheckM m (TypedExpr m) := do if (← read).trace then dbg_trace s!"check: {term.tag}" - let (te, inferType) ← infer term - if !(← eqSortInfo inferType type) then - throw s!"Info mismatch on {term.tag}" - if !(← equal (← read).lvl type inferType) then - let lvl := (← read).lvl - let ppInferred ← tryPpValue lvl inferType.get - let ppExpected ← tryPpValue lvl type.get - let dumpInferred := inferType.get.dump - let dumpExpected := type.get.dump - throw s!"Type mismatch on {term.tag}\n inferred: {ppInferred}\n expected: {ppExpected}\n inferred dump: {dumpInferred}\n expected dump: {dumpExpected}\n inferred info: {inferType.info.pp}\n expected info: {type.info.pp}" + let (te, inferredType) ← infer term + if !(← isDefEq inferredType expectedType) then + let ppInferred := inferredType.pp + let ppExpected := expectedType.pp + throw s!"Type mismatch on {term.tag}\n inferred: {ppInferred}\n expected: {ppExpected}" pure te /-- Infer the type of an expression, returning the typed expression and its type. -/ - partial def infer (term : Expr m) : TypecheckM m σ (TypedExpr m × SusValue m) := withFuelCheck do + partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × Expr m) := withFuelCheck do + -- Check infer cache: keyed on Expr, context verified on retrieval + let types := (← read).types + if let some (cachedCtx, cachedType) := (← get).inferCache.get? 
term then + -- Ptr equality first, structural BEq fallback + if unsafe ptrAddrUnsafe cachedCtx == ptrAddrUnsafe types || cachedCtx == types then + let te : TypedExpr m := ⟨← infoFromType cachedType, term⟩ + return (te, cachedType) if (← read).trace then dbg_trace s!"infer: {term.tag}" - match term with - | .bvar idx bvarName => do - let ctx ← read - if idx < ctx.lvl then - let some type := listGet? ctx.types idx - | throw s!"var@{idx} out of environment range (size {ctx.types.length})" - let te : TypedExpr m := ⟨← infoFromType type, .bvar idx bvarName⟩ - pure (te, type) - else - -- Mutual reference - match ctx.mutTypes.get? (idx - ctx.lvl) with - | some (addr, typeValFn) => - if some addr == ctx.recAddr? then - throw s!"Invalid recursion" - let univs := ctx.env.univs.toArray - let type := typeValFn univs - let name ← lookupName addr - let te : TypedExpr m := ⟨← infoFromType type, .const addr univs name⟩ - pure (te, type) - | none => - throw s!"var@{idx} out of environment range and does not represent a mutual constant" - | .sort lvl => do - let univs := (← read).env.univs.toArray - let lvl := Level.instBulkReduce univs lvl - let lvl' := Level.succ lvl - let typ : SusValue m := .mk (.sort (Level.succ lvl')) (.mk fun _ => .sort lvl') - let te : TypedExpr m := ⟨.sort lvl', .sort lvl⟩ - pure (te, typ) - | .app fnc arg => do - let (fnTe, fncType) ← infer fnc - match fncType.get with - | .pi dom img piEnv _ _ => do - let argTe ← check arg dom + let result ← do match term with + | .bvar idx bvarName => do let ctx ← read - let stt ← get - let typ := suspend img { ctx with env := piEnv.extendWith (suspend argTe ctx stt) } stt - let te : TypedExpr m := ⟨← infoFromType typ, .app fnTe.body argTe.body⟩ + let depth := ctx.types.size + if idx < depth then + let arrayIdx := depth - 1 - idx + if h : arrayIdx < ctx.types.size then + let rawType := ctx.types[arrayIdx] + let typ := rawType.liftBVars (idx + 1) + let te : TypedExpr m := ⟨← infoFromType typ, .bvar idx bvarName⟩ + pure 
(te, typ) + else + throw s!"var@{idx} out of environment range (size {ctx.types.size})" + else + match ctx.mutTypes.get? (idx - depth) with + | some (addr, typeExprFn) => + if some addr == ctx.recAddr? then + throw s!"Invalid recursion" + let univs := Array.ofFn (n := 0) fun i => Level.param i.val (default : MetaField m Ix.Name) + let typ := typeExprFn univs + let name ← lookupName addr + let te : TypedExpr m := ⟨← infoFromType typ, .const addr univs name⟩ + pure (te, typ) + | none => + throw s!"var@{idx} out of environment range and does not represent a mutual constant" + | .sort lvl => do + let lvl' := Level.succ lvl + let typ := Expr.mkSort lvl' + let te : TypedExpr m := ⟨.sort lvl', .sort lvl⟩ pure (te, typ) - | v => - let ppV ← tryPpValue (← read).lvl v - throw s!"Expected a pi type, got {ppV}\n dump: {v.dump}\n fncType info: {fncType.info.pp}\n function: {fnc.pp}\n argument: {arg.pp}" - | .lam ty body lamName lamBi => do - let (domTe, _) ← isSort ty - let ctx ← read - let stt ← get - let domVal := suspend domTe ctx stt - let var := mkSusVar (← infoFromType domVal) ctx.lvl lamName - let (bodTe, imgVal) ← withExtendedCtx var domVal (infer body) - let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩ - let imgTE ← quoteTyped (ctx.lvl+1) imgVal.getTyped - let typ : SusValue m := ⟨← piInfo domVal.info imgVal.info, - Thunk.mk fun _ => Value.pi domVal imgTE ctx.env lamName lamBi⟩ - pure (te, typ) - | .forallE ty body piName _ => do - let (domTe, domLvl) ← isSort ty - let ctx ← read - let stt ← get - let domVal := suspend domTe ctx stt - let domSusVal := mkSusVar (← infoFromType domVal) ctx.lvl piName - withExtendedCtx domSusVal domVal do - let (imgTe, imgLvl) ← isSort body + | .app .. 
=> do + -- Flatten app spine to avoid O(num_args) stack depth + let args := term.getAppArgs + let fn := term.getAppFn + let (fnTe, fncType) ← infer fn + let mut currentType := fncType + let mut resultBody := fnTe.body + for h : i in [:args.size] do + let arg := args[i] + let currentType' ← whnf currentType + match currentType' with + | .forallE dom body _ _ => do + let argTe ← check arg dom + resultBody := Expr.mkApp resultBody argTe.body + currentType := body.instantiate1 arg + | _ => + throw s!"Expected a pi type, got {currentType'.pp}\n function: {fn.pp}\n arg #{i}: {arg.pp}" + let te : TypedExpr m := ⟨← infoFromType currentType, resultBody⟩ + pure (te, currentType) + | .lam ty body lamName lamBi => do + let (domTe, _) ← isSort ty + let (bodTe, imgType) ← withExtendedCtx ty (infer body) + let piType := Expr.forallE ty imgType lamName default + let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩ + pure (te, piType) + | .forallE ty body piName _ => do + let (domTe, domLvl) ← isSort ty + let (imgTe, imgLvl) ← withExtendedCtx ty (isSort body) let sortLvl := Level.reduceIMax domLvl imgLvl - let typ : SusValue m := .mk (.sort (Level.succ sortLvl)) (.mk fun _ => .sort sortLvl) + let typ := Expr.mkSort sortLvl let te : TypedExpr m := ⟨← infoFromType typ, .forallE domTe.body imgTe.body piName default⟩ pure (te, typ) - | .letE ty val body letName => do - let (tyTe, _) ← isSort ty - let ctx ← read - let stt ← get - let tyVal := suspend tyTe ctx stt - let valTe ← check val tyVal - let valVal := suspend valTe ctx stt - let (bodTe, typ) ← withExtendedCtx valVal tyVal (infer body) - let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ - pure (te, typ) - | .lit (.natVal _) => do - let prims := (← read).prims - let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.nat #[]) - let te : TypedExpr m := ⟨.none, term⟩ - pure (te, typ) - | .lit (.strVal _) => do - let prims := (← read).prims 
- let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.string #[]) - let te : TypedExpr m := ⟨.none, term⟩ - pure (te, typ) - | .const addr constUnivs _ => do - ensureTypedConst addr - let ctx ← read - let univs := ctx.env.univs.toArray - let reducedUnivs := constUnivs.toList.map (Level.instBulkReduce univs) - -- Check const type cache (must also match universe parameters) - match (← get).constTypeCache.get? addr with - | some (cachedUnivs, cachedTyp) => - if cachedUnivs == reducedUnivs then - let te : TypedExpr m := ⟨← infoFromType cachedTyp, term⟩ - pure (te, cachedTyp) - else + | .letE ty val body letName => do + let (tyTe, _) ← isSort ty + let valTe ← check val ty + let (bodTe, bodType) ← withExtendedCtx ty (infer body) + let resultType := bodType.instantiate1 val + let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ + pure (te, resultType) + | .lit (.natVal _) => do + let prims := (← read).prims + let typ := Expr.mkConst prims.nat #[] + let te : TypedExpr m := ⟨.none, term⟩ + pure (te, typ) + | .lit (.strVal _) => do + let prims := (← read).prims + let typ := Expr.mkConst prims.string #[] + let te : TypedExpr m := ⟨.none, term⟩ + pure (te, typ) + | .const addr constUnivs _ => do + ensureTypedConst addr + match (← get).constTypeCache.get? 
addr with + | some (cachedUnivs, cachedTyp) => + if cachedUnivs == constUnivs then + let te : TypedExpr m := ⟨← infoFromType cachedTyp, term⟩ + pure (te, cachedTyp) + else + let tconst ← derefTypedConst addr + let typ := tconst.type.body.instantiateLevelParams constUnivs + modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (constUnivs, typ) } + let te : TypedExpr m := ⟨← infoFromType typ, term⟩ + pure (te, typ) + | none => let tconst ← derefTypedConst addr - let env : ValEnv m := .mk [] reducedUnivs - let stt ← get - let typ := suspend tconst.type { ctx with env := env } stt - modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) } + let typ := tconst.type.body.instantiateLevelParams constUnivs + modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (constUnivs, typ) } let te : TypedExpr m := ⟨← infoFromType typ, term⟩ pure (te, typ) - | none => - let tconst ← derefTypedConst addr - let env : ValEnv m := .mk [] reducedUnivs - let stt ← get - let typ := suspend tconst.type { ctx with env := env } stt - modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) } - let te : TypedExpr m := ⟨← infoFromType typ, term⟩ - pure (te, typ) - | .proj typeAddr idx struct _ => do - let (structTe, structType) ← infer struct - let (ctorType, univs, params) ← getStructInfo structType.get - let mut ct ← applyType (← withEnv (.mk [] univs) (eval ctorType)) params.reverse - for i in [:idx] do + | .proj typeAddr idx struct _ => do + let (structTe, structType) ← infer struct + let (ctorType, ctorUnivs, numParams, params) ← getStructInfo structType + let mut ct := ctorType.instantiateLevelParams ctorUnivs + for _ in [:numParams] do + ct ← whnf ct + match ct with + | .forallE _ body _ _ => ct := body + | _ => throw "Structure constructor has too few parameters" + ct := ct.instantiate params.reverse + for i in [:idx] do + ct ← whnf ct + match ct with + | 
.forallE _ body _ _ => + let projExpr := Expr.mkProj typeAddr i structTe.body + ct := body.instantiate1 projExpr + | _ => throw "Structure type does not have enough fields" + ct ← whnf ct match ct with - | .pi dom img piEnv _ _ => do - let info ← infoFromType dom - let ctx ← read - let stt ← get - let proj := suspend ⟨info, .proj typeAddr i structTe.body default⟩ ctx stt - ct ← withNewExtendedEnv piEnv proj (eval img) - | _ => pure () - match ct with - | .pi dom _ _ _ _ => - let te : TypedExpr m := ⟨← infoFromType dom, .proj typeAddr idx structTe.body default⟩ - pure (te, dom) - | _ => throw "Impossible case: structure type does not have enough fields" + | .forallE dom _ _ _ => + let te : TypedExpr m := ⟨← infoFromType dom, .proj typeAddr idx structTe.body default⟩ + pure (te, dom) + | _ => throw "Impossible case: structure type does not have enough fields" + -- Cache the inferred type with the binding context + modify fun stt => { stt with inferCache := stt.inferCache.insert term (types, result.2) } + pure result /-- Check if an expression is a Sort, returning the typed expr and the universe level. -/ - partial def isSort (expr : Expr m) : TypecheckM m σ (TypedExpr m × Level m) := do + partial def isSort (expr : Expr m) : TypecheckM m (TypedExpr m × Level m) := do let (te, typ) ← infer expr - match typ.get with + let typ' ← whnf typ + match typ' with | .sort u => pure (te, u) - | v => - let ppV ← tryPpValue (← read).lvl v - throw s!"Expected a sort type, got {ppV}\n expr: {expr.pp}" - - /-- Get structure info from a value that should be a structure type. -/ - partial def getStructInfo (v : Value m) : - TypecheckM m σ (TypedExpr m × List (Level m) × List (SusValue m)) := do - match v with - | .app (.const indAddr univs _) params _ => + | _ => + throw s!"Expected a sort type, got {typ'.pp}\n expr: {expr.pp}" + + /-- Get structure info from a type that should be a structure. + Returns (constructor type expr, universe levels, numParams, param exprs). 
-/ + partial def getStructInfo (structType : Expr m) : + TypecheckM m (Expr m × Array (Level m) × Nat × Array (Expr m)) := do + let structType' ← whnf structType + let fn := structType'.getAppFn + match fn with + | .const indAddr univs _ => match (← read).kenv.find? indAddr with | some (.inductInfo v) => - if v.ctors.size != 1 || params.length != v.numParams then - throw s!"Expected a structure type, but {v.name} ({indAddr}) has {v.ctors.size} ctors and {params.length}/{v.numParams} params" + let params := structType'.getAppArgs + if v.ctors.size != 1 || params.size != v.numParams then + throw s!"Expected a structure type, but {v.name} ({indAddr}) has {v.ctors.size} ctors and {params.size}/{v.numParams} params" ensureTypedConst indAddr let ctorAddr := v.ctors[0]! ensureTypedConst ctorAddr match (← get).typedConsts.get? ctorAddr with | some (.constructor type _ _) => - return (type, univs.toList, params) + return (type.body, univs, v.numParams, params) | _ => throw s!"Constructor {ctorAddr} is not in typed consts" | some ci => throw s!"Expected a structure type, but {indAddr} is a {ci.kindName}" | none => throw s!"Expected a structure type, but {indAddr} not found in env" | _ => - let ppV ← tryPpValue (← read).lvl v - throw s!"Expected a structure type, got {ppV}" + throw s!"Expected a structure type, got {structType'.pp}" - /-- Typecheck a constant. With fresh state per declaration, dependencies get - provisional entries via `ensureTypedConst` and are assumed well-typed. -/ - partial def checkConst (addr : Address) : TypecheckM m σ Unit := withResetCtx do + /-- Typecheck a constant. 
-/ + partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do -- Reset fuel and per-constant caches - modify fun stt => { stt with constTypeCache := {} } - let ctx ← read - let _ ← ctx.fuelRef.set defaultFuel - let _ ← ctx.evalCacheRef.set {} - let _ ← ctx.equalCacheRef.set {} - -- Skip if already in typedConsts (provisional entry is fine — dependency assumed well-typed) + modify fun stt => { stt with + constTypeCache := {}, + whnfCache := {}, + whnfCoreCache := {}, + inferCache := {}, + eqvCache := {}, + failureCache := {}, + fuel := defaultFuel + } + -- Skip if already in typedConsts if (← get).typedConsts.get? addr |>.isSome then return () let ci ← derefConst addr let univs := ci.cv.mkUnivParams - withEnv (.mk [] univs.toList) do - let newConst ← match ci with - | .axiomInfo _ => - let (type, _) ← isSort ci.type - pure (TypedConst.axiom type) - | .opaqueInfo _ => - let (type, _) ← isSort ci.type - let typeSus := suspend type (← read) (← get) - let value ← withRecAddr addr (check ci.value?.get! typeSus) - pure (TypedConst.opaque type value) - | .thmInfo _ => - let (type, lvl) ← isSort ci.type - if !Level.isZero lvl then - throw s!"theorem type must be a proposition (Sort 0)" - let typeSus := suspend type (← read) (← get) - let value ← withRecAddr addr (check ci.value?.get! 
typeSus) - pure (TypedConst.theorem type value) - | .defnInfo v => - let (type, _) ← isSort ci.type - let ctx ← read - let stt ← get - let typeSus := suspend type ctx stt - let part := v.safety == .partial - let value ← - if part then - let typeSusFn := suspend type { ctx with env := ValEnv.mk ctx.env.exprs ctx.env.univs } stt - let mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare := - (Std.TreeMap.empty).insert 0 (addr, fun _ => typeSusFn) - withMutTypes mutTypes (withRecAddr addr (check v.value typeSus)) - else withRecAddr addr (check v.value typeSus) - pure (TypedConst.definition type value part) - | .quotInfo v => - let (type, _) ← isSort ci.type - pure (TypedConst.quotient type v.kind) - | .inductInfo _ => - checkIndBlock addr - return () - | .ctorInfo v => - checkIndBlock v.induct - return () - | .recInfo v => do - -- Extract the major premise's inductive from the recursor type - let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices - |>.getD default - -- Ensure the inductive has a provisional entry (assumed well-typed with fresh state per decl) - ensureTypedConst indAddr - -- Check recursor type - let (type, _) ← isSort ci.type - -- (#3) Validate K-flag instead of trusting the environment - if v.k then - validateKFlag v indAddr - -- (#4) Validate recursor rules - validateRecursorRules v indAddr - -- Check recursor rules (type-check RHS) - let typedRules ← v.rules.mapM fun rule => do - let (rhs, _) ← infer rule.rhs - pure (rule.nfields, rhs) - pure (TypedConst.recursor type v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules) - modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr newConst } - - /-- Walk a Pi chain to extract the return sort level (the universe of the result type). - Assumes the expression ends in `Sort u` after `numBinders` forall binders. 
-/ - partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m σ (Level m) := + -- Universe level instantiation for the constant's own level params + let newConst ← match ci with + | .axiomInfo _ => + let (type, _) ← isSort ci.type + pure (TypedConst.axiom type) + | .opaqueInfo _ => + let (type, _) ← isSort ci.type + let value ← withRecAddr addr (check ci.value?.get! type.body) + pure (TypedConst.opaque type value) + | .thmInfo _ => + let (type, lvl) ← isSort ci.type + if !Level.isZero lvl then + throw s!"theorem type must be a proposition (Sort 0)" + let value ← withRecAddr addr (check ci.value?.get! type.body) + pure (TypedConst.theorem type value) + | .defnInfo v => + let (type, _) ← isSort ci.type + let part := v.safety == .partial + let value ← + if part then + let typExpr := type.body + let mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → Expr m)) compare := + (Std.TreeMap.empty).insert 0 (addr, fun _ => typExpr) + withMutTypes mutTypes (withRecAddr addr (check v.value type.body)) + else withRecAddr addr (check v.value type.body) + pure (TypedConst.definition type value part) + | .quotInfo v => + let (type, _) ← isSort ci.type + pure (TypedConst.quotient type v.kind) + | .inductInfo _ => + checkIndBlock addr + return () + | .ctorInfo v => + checkIndBlock v.induct + return () + | .recInfo v => do + let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices + |>.getD default + ensureTypedConst indAddr + let (type, _) ← isSort ci.type + if v.k then + validateKFlag v indAddr + validateRecursorRules v indAddr + let typedRules ← v.rules.mapM fun rule => do + let (rhs, _) ← infer rule.rhs + pure (rule.nfields, rhs) + pure (TypedConst.recursor type v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules) + modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr newConst } + + /-- Walk a Pi chain to extract the return sort level. 
-/ + partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m (Level m) := match numBinders, expr with - | 0, .sort u => do - let univs := (← read).env.univs.toArray - pure (Level.instBulkReduce univs u) + | 0, .sort u => pure u | 0, _ => do - -- Not syntactically a sort; try to infer let (_, typ) ← infer expr - match typ.get with + let typ' ← whnf typ + match typ' with | .sort u => pure u | _ => throw "inductive return type is not a sort" | n+1, .forallE dom body _ _ => do - let (domTe, _) ← isSort dom - let ctx ← read - let stt ← get - let domVal := suspend domTe ctx stt - let var := mkSusVar (← infoFromType domVal) ctx.lvl - withExtendedCtx var domVal (getReturnSort body n) + let _ ← isSort dom + withExtendedCtx dom (getReturnSort body n) | _, _ => throw "inductive type has fewer binders than expected" - /-- Typecheck a mutual inductive block starting from one of its addresses. -/ - partial def checkIndBlock (addr : Address) : TypecheckM m σ Unit := do + /-- Typecheck a mutual inductive block. -/ + partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do let ci ← derefConst addr - -- Find the inductive info let indInfo ← match ci with | .inductInfo _ => pure ci | .ctorInfo v => @@ -412,111 +384,71 @@ mutual | _ => throw "Constructor's inductive not found" | _ => throw "Expected an inductive" let .inductInfo iv := indInfo | throw "unreachable" - -- Check if already done if (← get).typedConsts.get? addr |>.isSome then return () - -- Check the inductive type - let univs := iv.toConstantVal.mkUnivParams - let (type, _) ← withEnv (.mk [] univs.toList) (isSort iv.type) + let (type, _) ← isSort iv.type let isStruct := !iv.isRec && iv.numIndices == 0 && iv.ctors.size == 1 && match (← read).kenv.find? iv.ctors[0]! 
with | some (.ctorInfo cv) => cv.numFields > 0 | _ => false modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr (TypedConst.inductive type isStruct) } - - -- Collect all inductive addresses in this mutual block let indAddrs := iv.all - - -- Get the inductive's result universe level let indResultLevel := getIndResultLevel iv.type - - -- Check constructors - for (ctorAddr, cidx) in iv.ctors.toList.zipIdx do + for (ctorAddr, _cidx) in iv.ctors.toList.zipIdx do match (← read).kenv.find? ctorAddr with | some (.ctorInfo cv) => do - let ctorUnivs := cv.toConstantVal.mkUnivParams - let (ctorType, _) ← withEnv (.mk [] ctorUnivs.toList) (isSort cv.type) - modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cidx cv.numFields) } - - -- (#5) Check constructor parameter count matches inductive + let (ctorType, _) ← isSort cv.type + modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cv.cidx cv.numFields) } if cv.numParams != iv.numParams then throw s!"Constructor {ctorAddr} has {cv.numParams} params but inductive has {iv.numParams}" - - -- (#1) Positivity checking (skip for unsafe inductives) if !iv.isUnsafe then match checkCtorPositivity cv.type cv.numParams indAddrs with | some msg => throw s!"Constructor {ctorAddr}: {msg}" | none => pure () - - -- (#2) Universe constraint checking on constructor fields - -- Each non-parameter field's sort must be ≤ the inductive's result sort. - -- We check this by inferring the sort of each field type and comparing levels. 
if !iv.isUnsafe then if let some indLvl := indResultLevel then - let indLvlReduced := Level.instBulkReduce univs indLvl - checkFieldUniverses cv.type cv.numParams ctorAddr indLvlReduced - - -- (#6) Check indices in ctor return type don't mention the inductive + checkFieldUniverses cv.type cv.numParams ctorAddr indLvl if !iv.isUnsafe then let retType := getCtorReturnType cv.type cv.numParams cv.numFields let args := retType.getAppArgs - -- Index arguments are those after numParams for i in [iv.numParams:args.size] do for indAddr in indAddrs do if exprMentionsConst args[i]! indAddr then throw s!"Constructor {ctorAddr} index argument mentions the inductive (unsound)" - | _ => throw s!"Constructor {ctorAddr} not found" - -- Note: recursors are checked individually via checkConst's .recInfo branch, - -- which calls checkConst on the inductives first then checks rules. - /-- Check that constructor field types have sorts ≤ the inductive's result sort. -/ + /-- Check that constructor field types have sorts <= the inductive's result sort. 
-/ partial def checkFieldUniverses (ctorType : Expr m) (numParams : Nat) - (ctorAddr : Address) (indLvl : Level m) : TypecheckM m σ Unit := + (ctorAddr : Address) (indLvl : Level m) : TypecheckM m Unit := go ctorType numParams where - go (ty : Expr m) (remainingParams : Nat) : TypecheckM m σ Unit := + go (ty : Expr m) (remainingParams : Nat) : TypecheckM m Unit := match ty with - | .forallE dom body piName _ => + | .forallE dom body _piName _ => if remainingParams > 0 then do - let (domTe, _) ← isSort dom - let ctx ← read - let stt ← get - let domVal := suspend domTe ctx stt - let var := mkSusVar (← infoFromType domVal) ctx.lvl piName - withExtendedCtx var domVal (go body (remainingParams - 1)) + let _ ← isSort dom + withExtendedCtx dom (go body (remainingParams - 1)) else do - -- This is a field — infer its sort level and check ≤ indLvl - let (domTe, fieldSortLvl) ← isSort dom + let (_, fieldSortLvl) ← isSort dom let fieldReduced := Level.reduce fieldSortLvl let indReduced := Level.reduce indLvl - -- Allow if field ≤ ind, OR if ind is Prop (is_zero allows any field) if !Level.leq fieldReduced indReduced 0 && !Level.isZero indReduced then throw s!"Constructor {ctorAddr} field type lives in a universe larger than the inductive's universe" - let ctx ← read - let stt ← get - let domVal := suspend domTe ctx stt - let var := mkSusVar (← infoFromType domVal) ctx.lvl piName - withExtendedCtx var domVal (go body 0) + withExtendedCtx dom (go body 0) | _ => pure () - /-- (#3) Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/ - partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do - -- Must be non-mutual + /-- Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/ + partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m Unit := do if rec.all.size != 1 then throw "recursor claims K but inductive is mutual" - -- Look up the inductive match (← read).kenv.find? 
indAddr with | some (.inductInfo iv) => - -- Must be in Prop match getIndResultLevel iv.type with | some lvl => if levelIsNonZero lvl then throw s!"recursor claims K but inductive is not in Prop" | none => throw "recursor claims K but cannot determine inductive's result sort" - -- Must have single constructor if iv.ctors.size != 1 then throw s!"recursor claims K but inductive has {iv.ctors.size} constructors (need 1)" - -- Constructor must have zero fields match (← read).kenv.find? iv.ctors[0]! with | some (.ctorInfo cv) => if cv.numFields != 0 then @@ -524,31 +456,25 @@ mutual | _ => throw "recursor claims K but constructor not found" | _ => throw s!"recursor claims K but {indAddr} is not an inductive" - /-- (#4) Validate recursor rules: check rule count, ctor membership, field counts. -/ - partial def validateRecursorRules (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do - -- Collect all constructors from the mutual block + /-- Validate recursor rules: check rule count, ctor membership, field counts. -/ + partial def validateRecursorRules (rec : RecursorVal m) (indAddr : Address) : TypecheckM m Unit := do let mut allCtors : Array Address := #[] for iAddr in rec.all do match (← read).kenv.find? iAddr with | some (.inductInfo iv) => allCtors := allCtors ++ iv.ctors | _ => throw s!"recursor references {iAddr} which is not an inductive" - -- Check rule count if rec.rules.size != allCtors.size then throw s!"recursor has {rec.rules.size} rules but inductive(s) have {allCtors.size} constructors" - -- Check each rule for h : i in [:rec.rules.size] do let rule := rec.rules[i] - -- Rule's constructor must match expected constructor in order if rule.ctor != allCtors[i]! then throw s!"recursor rule {i} has constructor {rule.ctor} but expected {allCtors[i]!}" - -- Look up the constructor and validate nfields match (← read).kenv.find? 
rule.ctor with | some (.ctorInfo cv) => if rule.nfields != cv.numFields then throw s!"recursor rule for {rule.ctor} has nfields={rule.nfields} but constructor has {cv.numFields} fields" | _ => throw s!"recursor rule constructor {rule.ctor} not found" - -- Validate structural counts against the inductive match (← read).kenv.find? indAddr with | some (.inductInfo iv) => if rec.numParams != iv.numParams then @@ -557,6 +483,311 @@ mutual throw s!"recursor numIndices={rec.numIndices} but inductive has {iv.numIndices}" | _ => pure () + /-- Quick structural equality check without WHNF. Returns: + - some true: definitely equal + - some false: definitely not equal + - none: unknown, need deeper checks -/ + partial def quickIsDefEq (t s : Expr m) (useHash : Bool := true) : TypecheckM m (Option Bool) := do + if t == s then return some true + let key := eqCacheKey t s + if let some r := (← get).eqvCache.get? key then return some r + if (← get).failureCache.contains key then return some false + match t, s with + | .sort u, .sort u' => pure (some (Level.equalLevel u u')) + | .const a us _, .const b us' _ => pure (some (a == b && equalUnivArrays us us')) + | .lit l, .lit l' => pure (some (l == l')) + | .bvar i _, .bvar j _ => pure (some (i == j)) + | .lam ty body _ _, .lam ty' body' _ _ => + match ← quickIsDefEq ty ty' with + | some true => quickIsDefEq body body' + | other => pure other + | .forallE ty body _ _, .forallE ty' body' _ _ => + match ← quickIsDefEq ty ty' with + | some true => quickIsDefEq body body' + | other => pure other + | _, _ => pure none + + /-- Check if two expressions are definitionally equal. + Uses a staged approach matching lean4/lean4lean: + 1. quickIsDefEq — structural shape match without WHNF + 2. whnfCore(cheapProj=true) — structural reduction, projections stay cheap + 3. Lazy delta reduction — unfold definitions one step at a time + 4. whnfCore(cheapProj=false) — full projection resolution (only if needed) + 5. 
Structural comparison -/ + partial def isDefEq (t s : Expr m) : TypecheckM m Bool := withFuelCheck do + -- 0. Quick structural check (avoids WHNF for trivially equal/unequal terms) + match ← quickIsDefEq t s with + | some result => return result + | none => pure () + + -- 1. Stage 1: structural reduction + let tn ← whnfCore t + let sn ← whnfCore s + + -- 2. Quick check after whnfCore + match ← quickIsDefEq tn sn with + | some true => cacheResult t s true; return true + | some false => pure () -- don't cache — deeper checks may still succeed + | none => pure () + + -- 3. Proof irrelevance + match ← isDefEqProofIrrel tn sn with + | some result => + cacheResult t s result + return result + | none => pure () + + -- 4. Lazy delta reduction (incremental unfolding) + let (tn', sn', deltaResult) ← lazyDeltaReduction tn sn + if deltaResult == some true then + cacheResult t s true + return true + + -- 5. Stage 2: full whnf (resolves projections + remaining delta) + let tnn ← whnf tn' + let snn ← whnf sn' + if tnn == snn then + cacheResult t s true + return true + + -- 6. Structural comparison on fully-reduced terms + let result ← isDefEqCore tnn snn + + cacheResult t s result + return result + + /-- Check if both terms are proofs of the same Prop type (proof irrelevance). + Returns `none` if inference fails (e.g., free bound variables) or the type isn't Prop. -/ + partial def isDefEqProofIrrel (t s : Expr m) : TypecheckM m (Option Bool) := do + let tType ← try let (_, ty) ← infer t; pure (some ty) catch _ => pure none + let some tType := tType | return none + let tType' ← whnf tType + match tType' with + | .sort .zero => + let sType ← try let (_, ty) ← infer s; pure (some ty) catch _ => pure none + let some sType := sType | return none + let result ← isDefEq tType sType + return some result + | _ => return none + + /-- Core structural comparison after whnf. 
-/ + partial def isDefEqCore (t s : Expr m) : TypecheckM m Bool := do + match t, s with + -- Sort + | .sort u, .sort u' => pure (Level.equalLevel u u') + + -- Bound variable + | .bvar i _, .bvar j _ => pure (i == j) + + -- Constant + | .const a us _, .const b us' _ => + pure (a == b && equalUnivArrays us us') + + -- Lambda: flatten binder chain to avoid O(num_binders) stack depth + | .lam .., .lam .. => do + let mut a := t + let mut b := s + repeat + match a, b with + | .lam ty body _ _, .lam ty' body' _ _ => + if !(← isDefEq ty ty') then return false + a := body; b := body' + | _, _ => break + isDefEq a b + + -- Pi/ForallE: flatten binder chain to avoid O(num_binders) stack depth + | .forallE .., .forallE .. => do + let mut a := t + let mut b := s + repeat + match a, b with + | .forallE ty body _ _, .forallE ty' body' _ _ => + if !(← isDefEq ty ty') then return false + a := body; b := body' + | _, _ => break + isDefEq a b + + -- Application: flatten app spine to avoid O(num_args) stack depth + | .app .., .app .. => do + let tFn := t.getAppFn + let sFn := s.getAppFn + let tArgs := t.getAppArgs + let sArgs := s.getAppArgs + if tArgs.size != sArgs.size then return false + if !(← isDefEq tFn sFn) then return false + for h : i in [:tArgs.size] do + if !(← isDefEq tArgs[i] sArgs[i]!) 
then return false + return true + + -- Projection + | .proj a i struct _, .proj b j struct' _ => + if a == b && i == j then isDefEq struct struct' + else pure false + + -- Literals + | .lit l, .lit l' => pure (l == l') + + -- Eta expansion: lambda vs non-lambda + | .lam ty body _ _, _ => do + -- eta: (\x => body) =?= s iff body =?= s x where x = bvar 0 + let sLifted := s.liftBVars 1 + let sApp := Expr.mkApp sLifted (Expr.mkBVar 0) + isDefEq body sApp + + | _, .lam ty body _ _ => do + -- eta: t =?= (\x => body) iff t x =?= body + let tLifted := t.liftBVars 1 + let tApp := Expr.mkApp tLifted (Expr.mkBVar 0) + isDefEq tApp body + + -- Nat literal vs constructor expansion + | .lit (.natVal _), _ => do + let prims := (← read).prims + let expanded := toCtorIfLit prims t + if expanded == t then pure false + else isDefEq expanded s + + | _, .lit (.natVal _) => do + let prims := (← read).prims + let expanded := toCtorIfLit prims s + if expanded == s then pure false + else isDefEq t expanded + + -- String literal vs constructor expansion + | .lit (.strVal str), _ => do + let prims := (← read).prims + let expanded := strLitToConstructor prims str + isDefEq expanded s + + | _, .lit (.strVal str) => do + let prims := (← read).prims + let expanded := strLitToConstructor prims str + isDefEq t expanded + + -- Structure eta + | _, .app _ _ => tryEtaStruct t s + | .app _ _, _ => tryEtaStruct s t + + | _, _ => pure false + + /-- Lazy delta reduction loop. Unfolds definitions one step at a time, + guided by ReducibilityHints, until a conclusive comparison or both + sides are stuck. 
-/ + partial def lazyDeltaReduction (t s : Expr m) + : TypecheckM m (Expr m × Expr m × Option Bool) := do + let mut tn := t + let mut sn := s + let kenv := (← read).kenv + let mut steps := 0 + repeat + if steps > 10000 then return (tn, sn, none) + steps := steps + 1 + + -- Syntactic check + if tn == sn then return (tn, sn, some true) + + -- Try nat reduction + if let some r := ← tryReduceNat tn then + tn ← whnfCore r; continue + if let some r := ← tryReduceNat sn then + sn ← whnfCore r; continue + + -- Lazy delta step + let tDelta := isDelta tn kenv + let sDelta := isDelta sn kenv + match tDelta, sDelta with + | none, none => return (tn, sn, none) -- both stuck + | some dt, none => + match unfoldDelta dt tn with + | some r => tn ← whnfCore r; continue + | none => return (tn, sn, none) + | none, some ds => + match unfoldDelta ds sn with + | some r => sn ← whnfCore r; continue + | none => return (tn, sn, none) + | some dt, some ds => + let ht := dt.hints + let hs := ds.hints + -- Same head optimization: try comparing args first + if sameHeadConst tn sn && ht.isRegular && hs.isRegular then + if ← isDefEqApp tn sn then return (tn, sn, some true) + if ht.lt' hs then + match unfoldDelta ds sn with + | some r => sn ← whnfCore r; continue + | none => + match unfoldDelta dt tn with + | some r => tn ← whnfCore r; continue + | none => return (tn, sn, none) + else if hs.lt' ht then + match unfoldDelta dt tn with + | some r => tn ← whnfCore r; continue + | none => + match unfoldDelta ds sn with + | some r => sn ← whnfCore r; continue + | none => return (tn, sn, none) + else + -- Same height: unfold both + match unfoldDelta dt tn, unfoldDelta ds sn with + | some rt, some rs => + tn ← whnfCore rt (cheapProj := true) + sn ← whnfCore rs (cheapProj := true) + continue + | some rt, none => tn ← whnfCore rt (cheapProj := true); continue + | none, some rs => sn ← whnfCore rs (cheapProj := true); continue + | none, none => return (tn, sn, none) + return (tn, sn, none) + + /-- Compare 
arguments of two applications with the same head constant. -/ + partial def isDefEqApp (t s : Expr m) : TypecheckM m Bool := do + let tArgs := t.getAppArgs + let sArgs := s.getAppArgs + if tArgs.size != sArgs.size then return false + -- Also compare universe params + let tFn := t.getAppFn + let sFn := s.getAppFn + match tFn, sFn with + | .const _ us _, .const _ us' _ => + if !equalUnivArrays us us' then return false + | _, _ => pure () + for h : i in [:tArgs.size] do + if !(← isDefEq tArgs[i] sArgs[i]!) then return false + return true + + /-- Try eta expansion for structure-like types. -/ + partial def tryEtaStruct (t s : Expr m) : TypecheckM m Bool := do + -- s should be a constructor application + let sFn := s.getAppFn + match sFn with + | .const ctorAddr _ _ => + match (← read).kenv.find? ctorAddr with + | some (.ctorInfo cv) => + let indAddr := cv.induct + if !(← read).kenv.isStructureLike indAddr then return false + let sArgs := s.getAppArgs + -- Check that each field arg is a projection of t + let numParams := cv.numParams + for h : i in [:cv.numFields] do + let argIdx := numParams + i + if argIdx < sArgs.size then + let arg := sArgs[argIdx]! + match arg with + | .proj a idx struct _ => + if a != indAddr || idx != i then return false + if !(← isDefEq t struct) then return false + | _ => return false + else return false + return true + | _ => return false + | _ => return false + + /-- Cache a def-eq result (both successes and failures). -/ + partial def cacheResult (t s : Expr m) (result : Bool) : TypecheckM m Unit := do + let key := eqCacheKey t s + if result then + modify fun stt => { stt with eqvCache := stt.eqvCache.insert key result } + else + modify fun stt => { stt with failureCache := stt.failureCache.insert key } + end -- mutual /-! ## Top-level entry points -/ @@ -564,21 +795,16 @@ end -- mutual /-- Typecheck a single constant by address. 
-/ def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address) (quotInit : Bool := true) : Except String Unit := - runST fun σ => do - let fuelRef ← ST.mkRef defaultFuel - let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level m) × Value m)) - let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool) - let ctx : TypecheckCtx m σ := { - lvl := 0, env := default, types := [], kenv := kenv, - prims := prims, safety := .safe, quotInit := quotInit, - mutTypes := default, recAddr? := none, - fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef - } - let stt : TypecheckState m := { typedConsts := default } - TypecheckM.run ctx stt (checkConst addr) - -/-- Typecheck all constants in a kernel environment. - Uses fresh state per declaration — dependencies are assumed well-typed. -/ + let ctx : TypecheckCtx m := { + types := #[], kenv := kenv, + prims := prims, safety := .safe, quotInit := quotInit, + mutTypes := default, recAddr? := none + } + let stt : TypecheckState m := { typedConsts := default } + let (result, _) := TypecheckM.run ctx stt (checkConst addr) + result + +/-- Typecheck all constants in a kernel environment. -/ def typecheckAll (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) : Except String Unit := do for (addr, ci) in kenv do @@ -592,8 +818,7 @@ def typecheckAll (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) | none => "" throw s!"{header}: {e}\n type: {typ}{val}" -/-- Typecheck all constants with IO progress reporting. - Uses fresh state per declaration — dependencies are assumed well-typed. -/ +/-- Typecheck all constants with IO progress reporting. 
-/ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) : IO (Except String Unit) := do let mut items : Array (Address × ConstantInfo m) := #[] diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 9fb0d2cd..45385b5a 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -1,144 +1,121 @@ /- TypecheckM: Monad stack, context, state, and utilities for the kernel typechecker. + + Environment-based kernel: no ST, no thunks, no Value domain. + Types and values are Expr m throughout. -/ import Ix.Kernel.Datatypes import Ix.Kernel.Level namespace Ix.Kernel +/-! ## Level substitution on Expr -/ + +/-- Substitute universe level params in an expression using `instBulkReduce`. -/ +def Expr.instantiateLevelParams (e : Expr m) (levels : Array (Level m)) : Expr m := + if levels.isEmpty then e + else e.instantiateLevelParamsBy (Level.instBulkReduce levels) + /-! ## Typechecker Context -/ -structure TypecheckCtx (m : MetaMode) (σ : Type) where - lvl : Nat - env : ValEnv m - types : List (SusValue m) +structure TypecheckCtx (m : MetaMode) where + /-- Type of each bound variable, indexed by de Bruijn index. + types[0] is the type of bvar 0 (most recently bound). -/ + types : Array (Expr m) kenv : Env m prims : Primitives safety : DefinitionSafety quotInit : Bool - /-- Maps a variable index (mutual reference) to (address, type-value function). -/ - mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare + /-- Maps a variable index (mutual reference) to (address, type function). -/ + mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → Expr m)) compare /-- Tracks the address of the constant currently being checked, for recursion detection. -/ recAddr? : Option Address - /-- Depth fuel: bounds the call-stack depth to prevent native stack overflow. - Decremented via the reader on each entry to eval/equal/infer. - Thunks inherit the depth from their capture point. 
-/ - depth : Nat := 10000 /-- Enable dbg_trace on major entry points for debugging. -/ trace : Bool := false - /-- Global fuel counter: bounds total recursive work across all thunks via ST.Ref. -/ - fuelRef : ST.Ref σ Nat - /-- Mutable eval cache: persists across thunk evaluations via ST.Ref. -/ - evalCacheRef : ST.Ref σ (Std.HashMap Address (Array (Level m) × Value m)) - /-- Mutable equality cache: persists across thunk evaluations via ST.Ref. -/ - equalCacheRef : ST.Ref σ (Std.HashMap (USize × USize) Bool) /-! ## Typechecker State -/ /-- Default fuel for bounding total recursive work per constant. -/ -def defaultFuel : Nat := 200000 +def defaultFuel : Nat := 1_000_000 structure TypecheckState (m : MetaMode) where - typedConsts : Std.TreeMap Address (TypedConst m) Address.compare - /-- Cache for constant type SusValues. When `infer (.const addr _)` computes a - suspended type, it is cached here so repeated references to the same constant - share the same SusValue pointer, enabling fast-path pointer equality in `equal`. - Stores universe parameters alongside the value for correctness with polymorphic constants. -/ - constTypeCache : Std.HashMap Address (List (Level m) × SusValue m) := {} + typedConsts : Std.TreeMap Address (TypedConst m) Address.compare + whnfCache : Std.HashMap (Expr m) (Expr m) := {} + /-- Cache for structural-only WHNF (whnfCore with cheapRec=false, cheapProj=false). + Separate from whnfCache to avoid stale entries from cheap reductions. -/ + whnfCoreCache : Std.HashMap (Expr m) (Expr m) := {} + /-- Infer cache: maps term → (binding context, inferred type). + Keyed on Expr only; context verified on retrieval via ptr equality + BEq fallback. 
-/ + inferCache : Std.HashMap (Expr m) (Array (Expr m) × Expr m) := {} + eqvCache : Std.HashMap (Expr m × Expr m) Bool := {} + failureCache : Std.HashSet (Expr m × Expr m) := {} + constTypeCache : Std.HashMap Address (Array (Level m) × Expr m) := {} + fuel : Nat := defaultFuel + /-- Tracks nesting depth of whnf calls from within recursor reduction (tryReduceApp → whnf). + When this exceeds a threshold, whnfCore is used instead of whnf to prevent stack overflow. -/ + whnfDepth : Nat := 0 deriving Inhabited /-! ## TypecheckM monad -/ -abbrev TypecheckM (m : MetaMode) (σ : Type) := - ReaderT (TypecheckCtx m σ) (ExceptT String (StateT (TypecheckState m) (ST σ))) - -def TypecheckM.run (ctx : TypecheckCtx m σ) (stt : TypecheckState m) - (x : TypecheckM m σ α) : ST σ (Except String α) := do - let (result, _) ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt - pure result +abbrev TypecheckM (m : MetaMode) := + ReaderT (TypecheckCtx m) (ExceptT String (StateM (TypecheckState m))) -def TypecheckM.runState (ctx : TypecheckCtx m σ) (stt : TypecheckState m) (x : TypecheckM m σ α) - : ST σ (Except String (α × TypecheckState m)) := do - let (result, stt') ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt - pure (match result with | .ok a => .ok (a, stt') | .error e => .error e) - -/-! ## pureRunST -/ - -/-- Unsafe bridge: run ST σ from pure code (for Thunk bodies). - Safe because the only side effects are append-only cache mutations. -/ -@[inline] unsafe def pureRunSTImpl {σ α : Type} [Inhabited α] (x : ST σ α) : α := - (x (unsafeCast ())).val - -@[implemented_by pureRunSTImpl] -opaque pureRunST {σ α : Type} [Inhabited α] : ST σ α → α +def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) + (x : TypecheckM m α) : Except String α × TypecheckState m := + let (result, stt') := StateT.run (ExceptT.run (ReaderT.run x ctx)) stt + (result, stt') /-! 
## Context modifiers -/ -def withEnv (env : ValEnv m) : TypecheckM m σ α → TypecheckM m σ α := - withReader fun ctx => { ctx with env := env } - -def withResetCtx : TypecheckM m σ α → TypecheckM m σ α := +def withResetCtx : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with - lvl := 0, env := default, types := default, mutTypes := default, recAddr? := none } + types := #[], mutTypes := default, recAddr? := none } -def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare) : - TypecheckM m σ α → TypecheckM m σ α := +def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → Expr m)) compare) : + TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with mutTypes := mutTypes } -def withExtendedCtx (val typ : SusValue m) : TypecheckM m σ α → TypecheckM m σ α := - withReader fun ctx => { ctx with - lvl := ctx.lvl + 1, - types := typ :: ctx.types, - env := ctx.env.extendWith val } - -def withExtendedEnv (thunk : SusValue m) : TypecheckM m σ α → TypecheckM m σ α := - withReader fun ctx => { ctx with env := ctx.env.extendWith thunk } +/-- Extend the context with a new bound variable of the given type. -/ +def withExtendedCtx (varType : Expr m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with types := ctx.types.push varType } -def withNewExtendedEnv (env : ValEnv m) (thunk : SusValue m) : - TypecheckM m σ α → TypecheckM m σ α := - withReader fun ctx => { ctx with env := env.extendWith thunk } - -def withRecAddr (addr : Address) : TypecheckM m σ α → TypecheckM m σ α := +def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with recAddr? := some addr } -/-- Check both fuel counters, decrement them, and run the action. - - State fuel bounds total work (prevents exponential blowup / hanging). - - Reader depth bounds call-stack depth (prevents native stack overflow). 
-/ -def withFuelCheck (action : TypecheckM m σ α) : TypecheckM m σ α := do - let ctx ← read - if ctx.depth == 0 then - throw "deep recursion depth limit reached" - let fuel ← ctx.fuelRef.get - if fuel == 0 then throw "deep recursion fuel limit reached" - let _ ← ctx.fuelRef.set (fuel - 1) - withReader (fun ctx => { ctx with depth := ctx.depth - 1 }) action +/-- The current binding depth (number of bound variables in scope). -/ +def lvl : TypecheckM m Nat := do pure (← read).types.size + +/-- Check fuel and decrement it. -/ +def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do + let stt ← get + if stt.fuel == 0 then throw "deep recursion fuel limit reached" + modify fun s => { s with fuel := s.fuel - 1 } + action /-! ## Name lookup -/ /-- Look up the MetaField name for a constant address from the kernel environment. -/ -def lookupName (addr : Address) : TypecheckM m σ (MetaField m Ix.Name) := do +def lookupName (addr : Address) : TypecheckM m (MetaField m Ix.Name) := do match (← read).kenv.find? addr with | some ci => pure ci.cv.name | none => pure default /-! ## Const dereferencing -/ -def derefConst (addr : Address) : TypecheckM m σ (ConstantInfo m) := do - let ctx ← read - match ctx.kenv.find? addr with +def derefConst (addr : Address) : TypecheckM m (ConstantInfo m) := do + match (← read).kenv.find? addr with | some ci => pure ci | none => throw s!"unknown constant {addr}" -def derefTypedConst (addr : Address) : TypecheckM m σ (TypedConst m) := do +def derefTypedConst (addr : Address) : TypecheckM m (TypedConst m) := do match (← get).typedConsts.get? addr with | some tc => pure tc | none => throw s!"typed constant not found: {addr}" /-! ## Provisional TypedConst -/ -/-- Extract the major premise's inductive address from a recursor type. - Skips numParams + numMotives + numMinors + numIndices foralls, - then the next forall's domain's app head is the inductive const. -/ +/-- Extract the major premise's inductive address from a recursor type. 
-/ def getMajorInduct (type : Expr m) (numParams numMotives numMinors numIndices : Nat) : Option Address := go (numParams + numMotives + numMinors + numIndices) type where @@ -150,10 +127,7 @@ where | .forallE _ body _ _ => go n body | _ => none -/-- Build a provisional TypedConst entry from raw ConstantInfo. - Used when `infer` encounters a `.const` reference before the constant - has been fully typechecked. The entry uses default TypeInfo and raw - expressions directly from the kernel environment. -/ +/-- Build a provisional TypedConst entry from raw ConstantInfo. -/ def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := let rawType : TypedExpr m := ⟨default, ci.type⟩ match ci with @@ -164,7 +138,7 @@ def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := | .opaqueInfo v => .opaque rawType ⟨default, v.value⟩ | .quotInfo v => .quotient rawType v.kind | .inductInfo v => - let isStruct := v.ctors.size == 1 -- approximate; refined by checkIndBlock + let isStruct := v.ctors.size == 1 .inductive rawType isStruct | .ctorInfo v => .constructor rawType v.cidx v.numFields | .recInfo v => @@ -173,14 +147,23 @@ def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := let typedRules := v.rules.map fun r => (r.nfields, (⟨default, r.rhs⟩ : TypedExpr m)) .recursor rawType v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules -/-- Ensure a constant has a TypedConst entry. If not already present, build a - provisional one from raw ConstantInfo. This avoids the deep recursion of - `checkConst` when called from `infer`. -/ -def ensureTypedConst (addr : Address) : TypecheckM m σ Unit := do +/-- Ensure a constant has a TypedConst entry. -/ +def ensureTypedConst (addr : Address) : TypecheckM m Unit := do if (← get).typedConsts.get? addr |>.isSome then return () let ci ← derefConst addr let tc := provisionalTypedConst ci modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr tc } +/-! 
## Def-eq cache helpers -/ + +instance : Hashable (Expr m × Expr m) where + hash p := mixHash (Hashable.hash p.1) (Hashable.hash p.2) + +/-- Symmetric cache key for def-eq pairs. Orders by structural hash to make key(a,b) == key(b,a). -/ +def eqCacheKey (a b : Expr m) : Expr m × Expr m := + let ha := Hashable.hash a + let hb := Hashable.hash b + if ha ≤ hb then (a, b) else (b, a) + end Ix.Kernel diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index fba45b00..6a8ff1d1 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -240,8 +240,133 @@ def tag : Expr m → String | .lit (.strVal s) => s!"strLit({s})" | .proj _ idx _ _ => s!"proj({idx})" +/-! ### Substitution helpers -/ + +/-- Lift free bvar indices by `n`. Under `depth` binders, bvars < depth are + bound and stay; bvars >= depth are free and get shifted by n. -/ +partial def liftBVars (e : Expr m) (n : Nat) (depth : Nat := 0) : Expr m := + if n == 0 then e + else go e depth +where + go (e : Expr m) (d : Nat) : Expr m := + match e with + | .bvar idx name => if idx >= d then .bvar (idx + n) name else e + | .app fn arg => .app (go fn d) (go arg d) + | .lam ty body name bi => .lam (go ty d) (go body (d + 1)) name bi + | .forallE ty body name bi => .forallE (go ty d) (go body (d + 1)) name bi + | .letE ty val body name => .letE (go ty d) (go val d) (go body (d + 1)) name + | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct d) typeName + | .sort .. | .const .. | .lit .. => e + +/-- Bulk substitution: replace bvar i with subst[i] for i < subst.size. + Free bvars (i >= subst.size) become bvar (i - subst.size). + Under binders, substitution values are lifted appropriately. 
-/ +partial def instantiate (e : Expr m) (subst : Array (Expr m)) : Expr m := + if subst.isEmpty then e + else go e 0 +where + go (e : Expr m) (shift : Nat) : Expr m := + match e with + | .bvar idx name => + if idx < shift then e -- bound by inner binder + else + let realIdx := idx - shift + if h : realIdx < subst.size then + (subst[realIdx]).liftBVars shift + else + .bvar (idx - subst.size) name + | .app fn arg => .app (go fn shift) (go arg shift) + | .lam ty body name bi => .lam (go ty shift) (go body (shift + 1)) name bi + | .forallE ty body name bi => .forallE (go ty shift) (go body (shift + 1)) name bi + | .letE ty val body name => .letE (go ty shift) (go val shift) (go body (shift + 1)) name + | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct shift) typeName + | .sort .. | .const .. | .lit .. => e + +/-- Single substitution: replace bvar 0 with val. -/ +def instantiate1 (body val : Expr m) : Expr m := body.instantiate #[val] + +/-- Substitute universe level params in an expression's Level nodes using a given + level substitution function. -/ +partial def instantiateLevelParamsBy (e : Expr m) (substFn : Level m → Level m) : Expr m := + go e +where + go (e : Expr m) : Expr m := + match e with + | .sort lvl => .sort (substFn lvl) + | .const addr ls name => .const addr (ls.map substFn) name + | .app fn arg => .app (go fn) (go arg) + | .lam ty body name bi => .lam (go ty) (go body) name bi + | .forallE ty body name bi => .forallE (go ty) (go body) name bi + | .letE ty val body name => .letE (go ty) (go val) (go body) name + | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct) typeName + | .bvar .. | .lit .. => e + +/-- Check if expression has any bvars with index >= depth. 
-/ +partial def hasLooseBVarsAbove (e : Expr m) (depth : Nat) : Bool := + match e with + | .bvar idx _ => idx >= depth + | .app fn arg => hasLooseBVarsAbove fn depth || hasLooseBVarsAbove arg depth + | .lam ty body _ _ => hasLooseBVarsAbove ty depth || hasLooseBVarsAbove body (depth + 1) + | .forallE ty body _ _ => hasLooseBVarsAbove ty depth || hasLooseBVarsAbove body (depth + 1) + | .letE ty val body _ => + hasLooseBVarsAbove ty depth || hasLooseBVarsAbove val depth || hasLooseBVarsAbove body (depth + 1) + | .proj _ _ struct _ => hasLooseBVarsAbove struct depth + | .sort .. | .const .. | .lit .. => false + +/-- Does the expression have any loose (free) bvars? -/ +def hasLooseBVars (e : Expr m) : Bool := e.hasLooseBVarsAbove 0 + +/-- Accessor for binding name. -/ +def bindingName! : Expr m → MetaField m Ix.Name + | forallE _ _ n _ => n | lam _ _ n _ => n | _ => panic! "bindingName!" + +/-- Accessor for binding binder info. -/ +def bindingInfo! : Expr m → MetaField m Lean.BinderInfo + | forallE _ _ _ bi => bi | lam _ _ _ bi => bi | _ => panic! "bindingInfo!" + +/-- Accessor for letE name. -/ +def letName! : Expr m → MetaField m Ix.Name + | letE _ _ _ n => n | _ => panic! "letName!" + +/-- Accessor for letE type. -/ +def letType! : Expr m → Expr m + | letE ty _ _ _ => ty | _ => panic! "letType!" + +/-- Accessor for letE value. -/ +def letValue! : Expr m → Expr m + | letE _ v _ _ => v | _ => panic! "letValue!" + +/-- Accessor for letE body. -/ +def letBody! : Expr m → Expr m + | letE _ _ b _ => b | _ => panic! "letBody!" + end Expr +/-! 
## Hashable instances -/ + +partial def Level.hash : Level m → UInt64 + | .zero => 7 + | .succ l => mixHash 13 (Level.hash l) + | .max l₁ l₂ => mixHash 17 (mixHash (Level.hash l₁) (Level.hash l₂)) + | .imax l₁ l₂ => mixHash 23 (mixHash (Level.hash l₁) (Level.hash l₂)) + | .param idx _ => mixHash 29 (Hashable.hash idx) + +instance : Hashable (Level m) where hash := Level.hash + +partial def Expr.hash : Expr m → UInt64 + | .bvar idx _ => mixHash 31 (Hashable.hash idx) + | .sort lvl => mixHash 37 (Level.hash lvl) + | .const addr lvls _ => mixHash 41 (mixHash (Hashable.hash addr) (lvls.foldl (fun h l => mixHash h (Level.hash l)) 0)) + | .app fn arg => mixHash 43 (mixHash (Expr.hash fn) (Expr.hash arg)) + | .lam ty body _ _ => mixHash 47 (mixHash (Expr.hash ty) (Expr.hash body)) + | .forallE ty body _ _ => mixHash 53 (mixHash (Expr.hash ty) (Expr.hash body)) + | .letE ty val body _ => mixHash 59 (mixHash (Expr.hash ty) (mixHash (Expr.hash val) (Expr.hash body))) + | .lit (.natVal n) => mixHash 61 (Hashable.hash n) + | .lit (.strVal s) => mixHash 67 (Hashable.hash s) + | .proj addr idx struct _ => mixHash 71 (mixHash (Hashable.hash addr) (mixHash (Hashable.hash idx) (Expr.hash struct))) + +instance : Hashable (Expr m) where hash := Expr.hash + /-! ## Enums -/ inductive DefinitionSafety where diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean new file mode 100644 index 00000000..591b66d7 --- /dev/null +++ b/Ix/Kernel/Whnf.lean @@ -0,0 +1,538 @@ +/- + Kernel Whnf: Environment-based weak head normal form reduction. + + Works directly on `Expr m` with deferred substitution via closures. +-/ +import Ix.Kernel.TypecheckM + +namespace Ix.Kernel + +open Level (instBulkReduce reduceIMax) + +/-! ## Helpers -/ + +/-- Check if an address is a primitive operation that takes arguments. 
-/ +private def isPrimOp (prims : Primitives) (addr : Address) : Bool := + addr == prims.natAdd || addr == prims.natSub || addr == prims.natMul || + addr == prims.natPow || addr == prims.natGcd || addr == prims.natMod || + addr == prims.natDiv || addr == prims.natBeq || addr == prims.natBle || + addr == prims.natLand || addr == prims.natLor || addr == prims.natXor || + addr == prims.natShiftLeft || addr == prims.natShiftRight || + addr == prims.natSucc + +/-- Look up element in a list by index. -/ +def listGet? (l : List α) (n : Nat) : Option α := + match l, n with + | [], _ => none + | a :: _, 0 => some a + | _ :: l, n+1 => listGet? l n + +/-! ## Nat primitive reduction on Expr -/ + +/-- Try to reduce a Nat primitive applied to literal arguments. Returns the reduced Expr. -/ +def tryReduceNat (e : Expr m) : TypecheckM m (Option (Expr m)) := do + let fn := e.getAppFn + match fn with + | .const addr _ _ => + let prims := (← read).prims + if !isPrimOp prims addr then return none + let args := e.getAppArgs + -- Nat.succ: 1 arg + if addr == prims.natSucc then + if args.size >= 1 then + match args[0]! with + | .lit (.natVal n) => return some (.lit (.natVal (n + 1))) + | _ => return none + else return none + -- Binary nat operations: 2 args + else if args.size >= 2 then + match args[0]!, args[1]! 
with + | .lit (.natVal x), .lit (.natVal y) => + if addr == prims.natAdd then return some (.lit (.natVal (x + y))) + else if addr == prims.natSub then return some (.lit (.natVal (x - y))) + else if addr == prims.natMul then return some (.lit (.natVal (x * y))) + else if addr == prims.natPow then + if y > 16777216 then return none + return some (.lit (.natVal (Nat.pow x y))) + else if addr == prims.natMod then return some (.lit (.natVal (x % y))) + else if addr == prims.natDiv then return some (.lit (.natVal (x / y))) + else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y))) + else if addr == prims.natBeq then + let boolAddr := if x == y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natBle then + let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y))) + else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y))) + else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y))) + else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y))) + else if addr == prims.natShiftRight then return some (.lit (.natVal (Nat.shiftRight x y))) + else return none + | _, _ => return none + else return none + | _ => return none + +/-- Convert a nat literal to Nat.succ/Nat.zero constructor expressions. -/ +def toCtorIfLit (prims : Primitives) : Expr m → Expr m + | .lit (.natVal 0) => Expr.mkConst prims.natZero #[] + | .lit (.natVal (n+1)) => + Expr.mkApp (Expr.mkConst prims.natSucc #[]) (.lit (.natVal n)) + | e => e + +/-- Expand a string literal to its constructor form: String.mk (list-of-chars). 
-/ +def strLitToConstructor (prims : Primitives) (s : String) : Expr m := + let mkCharOfNat (c : Char) : Expr m := + Expr.mkApp (Expr.mkConst prims.charMk #[]) (.lit (.natVal c.toNat)) + let charType : Expr m := Expr.mkConst prims.char #[] + let nilVal : Expr m := + Expr.mkApp (Expr.mkConst prims.listNil #[.zero]) charType + let listVal := s.toList.foldr (fun c acc => + let head := mkCharOfNat c + Expr.mkApp (Expr.mkApp (Expr.mkApp (Expr.mkConst prims.listCons #[.zero]) charType) head) acc + ) nilVal + Expr.mkApp (Expr.mkConst prims.stringMk #[]) listVal + +/-! ## WHNF core (structural reduction) -/ + +/-- Reduce a projection if the struct is a constructor application. -/ +partial def reduceProj (typeAddr : Address) (idx : Nat) (struct : Expr m) : TypecheckM m (Option (Expr m)) := do + -- Expand string literals to constructor form before projecting + let prims := (← read).prims + let struct' := match struct with + | .lit (.strVal s) => strLitToConstructor prims s + | e => e + let fn := struct'.getAppFn + match fn with + | .const ctorAddr _ _ => do + match (← read).kenv.find? ctorAddr with + | some (.ctorInfo v) => + let args := struct'.getAppArgs + let realIdx := v.numParams + idx + if h : realIdx < args.size then + return some args[realIdx] + else + return none + | _ => return none + | _ => return none + +mutual + /-- Structural WHNF: beta, let-zeta, iota-proj. No delta unfolding. + Uses an iterative loop to avoid deep stack usage: + - App spines are collected iteratively (not recursively) + - Beta/let/iota/proj results loop back instead of tail-calling + When cheapProj=true, projections are returned as-is (no struct reduction). + When cheapRec=true, recursor applications are returned as-is (no iota reduction). 
-/ + partial def whnfCore (e : Expr m) (cheapRec := false) (cheapProj := false) + : TypecheckM m (Expr m) := do + -- Cache lookup (only for full structural reduction, not cheap) + let useCache := !cheapRec && !cheapProj + if useCache then + if let some r := (← get).whnfCoreCache.get? e then return r + let r ← whnfCoreImpl e cheapRec cheapProj + if useCache then + modify fun s => { s with whnfCoreCache := s.whnfCoreCache.insert e r } + pure r + + partial def whnfCoreImpl (e : Expr m) (cheapRec : Bool) (cheapProj : Bool) + : TypecheckM m (Expr m) := do + let mut t := e + repeat + -- Fuel check + let stt ← get + if stt.fuel == 0 then throw "deep recursion fuel limit reached" + modify fun s => { s with fuel := s.fuel - 1 } + match t with + | .app .. => do + -- Collect app args iteratively (O(1) stack for app spine) + let args := t.getAppArgs + let fn := t.getAppFn + let fn' ← whnfCore fn cheapRec cheapProj -- recurse only on non-app head + -- Beta-reduce: consume as many args as possible + let mut result := fn' + let mut i : Nat := 0 + while i < args.size do + match result with + | .lam _ body _ _ => + result := body.instantiate1 args[i]! + i := i + 1 + | _ => break + if i > 0 then + -- Beta reductions happened. Apply remaining args and loop. + for h : j in [i:args.size] do + result := Expr.mkApp result args[j]! + t := result; continue -- loop instead of recursive tail call + else + -- No beta reductions. Try recursor/proj reduction. 
+ let e' := if fn == fn' then t else fn'.mkAppN args + if cheapRec then return e' -- skip recursor reduction + let r ← tryReduceApp e' + if r == e' then return r -- stuck, return + t := r; continue -- iota/quot reduced, loop to re-process + | .letE _ val body _ => + t := body.instantiate1 val; continue -- loop instead of recursion + | .proj typeAddr idx struct _ => do + if cheapProj then return t -- skip projection reduction + let struct' ← whnfCore struct cheapRec cheapProj + match ← reduceProj typeAddr idx struct' with + | some result => t := result; continue -- loop instead of recursion + | none => + return if struct == struct' then t else .proj typeAddr idx struct' default + | _ => return t + return t -- unreachable, but needed for type checking + + /-- Try to reduce an application whose head is in WHNF. + Handles recursor iota-reduction and quotient reduction. -/ + partial def tryReduceApp (e : Expr m) : TypecheckM m (Expr m) := do + let fn := e.getAppFn + match fn with + | .const addr _ _ => do + ensureTypedConst addr + match (← get).typedConsts.get? addr with + | some (.recursor _ params motives minors indices isK indAddr rules) => + let args := e.getAppArgs + let majorIdx := params + motives + minors + indices + if h : majorIdx < args.size then + let major := args[majorIdx] + let major' ← whnf major + if isK then + tryKReduction e addr args major' params motives indAddr + else + tryIotaReduction e addr args major' params indices indAddr rules motives minors + else pure e + | some (.quotient _ kind) => + match kind with + | .lift => tryQuotReduction e 6 3 + | .ind => tryQuotReduction e 5 3 + | _ => pure e + | _ => pure e + | _ => pure e + + /-- K-reduction: for Prop inductives with single zero-field constructor. 
-/ + partial def tryKReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) + (major : Expr m) (params motives : Nat) (indAddr : Address) + : TypecheckM m (Expr m) := do + let ctx ← read + let prims := ctx.prims + let kenv := ctx.kenv + -- Check if major is a constructor + let majorCtor := toCtorIfLit prims major + let isCtor := match majorCtor.getAppFn with + | .const ctorAddr _ _ => + match kenv.find? ctorAddr with + | some (.ctorInfo _) => true + | _ => false + | _ => false + -- Also check if the inductive is in Prop + let isPropInd := match kenv.find? indAddr with + | some (.inductInfo v) => + let rec getSort : Expr m → Bool + | .forallE _ body _ _ => getSort body + | .sort (.zero) => true + | _ => false + getSort v.type + | _ => false + if isCtor || isPropInd then + -- K-reduction: return the (only) minor premise + let minorIdx := params + motives + if h : minorIdx < args.size then + return args[minorIdx] + pure e + else pure e + + /-- Iota-reduction: reduce a recursor applied to a constructor. + Follows the lean4 algorithm: + 1. Apply params + motives + minors from recursor args to rule RHS + 2. Apply constructor fields (skip constructor params) to rule RHS + 3. Apply extra args after major premise to rule RHS + Beta reduction happens in the subsequent whnfCore call. -/ + partial def tryIotaReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) + (major : Expr m) (params indices : Nat) (indAddr : Address) + (rules : Array (Nat × TypedExpr m)) + (motives minors : Nat) : TypecheckM m (Expr m) := do + let prims := (← read).prims + -- Skip large nat literals to avoid O(n) overhead + let skipLargeNat := match major with + | .lit (.natVal n) => indAddr == prims.nat && n > 256 + | _ => false + if skipLargeNat then return e + let majorCtor := toCtorIfLit prims major + let majorFn := majorCtor.getAppFn + match majorFn with + | .const ctorAddr _ _ => do + let kenv := (← read).kenv + let typedConsts := (← get).typedConsts + let ctorInfo? 
:= match kenv.find? ctorAddr with + | some (.ctorInfo cv) => some (cv.cidx, cv.numFields) + | _ => + match typedConsts.get? ctorAddr with + | some (.constructor _ ctorIdx numFields) => some (ctorIdx, numFields) + | _ => none + match ctorInfo? with + | some (ctorIdx, _) => + match rules[ctorIdx]? with + | some (nfields, rhs) => + let majorArgs := majorCtor.getAppArgs + if nfields > majorArgs.size then return e + -- Instantiate universe level params in the rule RHS + let recFn := e.getAppFn + let recLevels := recFn.constLevels! + let mut result := rhs.body.instantiateLevelParams recLevels + -- Phase 1: Apply params + motives + minors from recursor args + let pmmEnd := params + motives + minors + result := result.mkAppRange 0 pmmEnd args + -- Phase 2: Apply constructor fields (skip constructor's own params) + let ctorParamCount := majorArgs.size - nfields + result := result.mkAppRange ctorParamCount majorArgs.size majorArgs + -- Phase 3: Apply remaining arguments after major premise + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + pure result -- return raw result; whnfCore's loop will re-process + | none => pure e + | none => + -- Not a constructor, try structure eta + tryStructEta e args indices indAddr rules major motives minors + | _ => + tryStructEta e args indices indAddr rules major motives minors + + /-- Structure eta: expand struct-like major via projections. -/ + partial def tryStructEta (e : Expr m) (args : Array (Expr m)) + (indices : Nat) (indAddr : Address) + (rules : Array (Nat × TypedExpr m)) (major : Expr m) + (motives minors : Nat) : TypecheckM m (Expr m) := do + let kenv := (← read).kenv + if !kenv.isStructureLike indAddr then return e + match rules[0]? with + | some (nfields, rhs) => + let recFn := e.getAppFn + let recLevels := recFn.constLevels! 
+ let params := args.size - motives - minors - indices - 1 + let mut result := rhs.body.instantiateLevelParams recLevels + -- Phase 1: params + motives + minors + let pmmEnd := params + motives + minors + result := result.mkAppRange 0 pmmEnd args + -- Phase 2: projections as fields + let mut projArgs : Array (Expr m) := #[] + for i in [:nfields] do + projArgs := projArgs.push (Expr.mkProj indAddr i major) + result := projArgs.foldl (fun acc a => Expr.mkApp acc a) result + -- Phase 3: extra args after major + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + pure result -- return raw result; whnfCore's loop will re-process + | none => pure e + + /-- Quotient reduction: Quot.lift / Quot.ind. + For Quot.lift: `@Quot.lift α r β f h q` — reduceSize=6, fPos=3 (f is at index 3) + For Quot.ind: `@Quot.ind α r β f q` — reduceSize=5, fPos=3 (f is at index 3) + When major (q) reduces to `@Quot.mk α r a`, result is `f a`. -/ + partial def tryQuotReduction (e : Expr m) (reduceSize fPos : Nat) : TypecheckM m (Expr m) := do + let args := e.getAppArgs + if args.size < reduceSize then return e + let majorIdx := reduceSize - 1 + if h : majorIdx < args.size then + let major := args[majorIdx] + let major' ← whnf major + let majorFn := major'.getAppFn + match majorFn with + | .const majorAddr _ _ => + ensureTypedConst majorAddr + match (← get).typedConsts.get? majorAddr with + | some (.quotient _ .ctor) => + let majorArgs := major'.getAppArgs + -- Quot.mk has 3 args: [α, r, a]. The data 'a' is the last one. + if majorArgs.size < 3 then throw "Quot.mk should have at least 3 args" + let dataArg := majorArgs[majorArgs.size - 1]! 
+ if h2 : fPos < args.size then + let f := args[fPos] + let result := Expr.mkApp f dataArg + -- Apply any extra args after the major premise + let result := if majorIdx + 1 < args.size then + result.mkAppRange (majorIdx + 1) args.size args + else result + pure result -- return raw result; whnfCore's loop will re-process + else return e + | _ => return e + | _ => return e + else return e + + /-- Full WHNF with delta unfolding loop. + whnfCore handles structural reduction (beta, let, iota, cheap proj). + This loop adds: nat primitives, stuck projection resolution, delta unfolding. + Projection chains are flattened to avoid deep recursion: + proj₁(proj₂(proj₃(struct))) → strip all projs, whnf(struct) ONCE, + then resolve projections iteratively from inside out. + Tracks nesting depth: when whnf calls nest too deep (from isDefEq ↔ whnf cycles), + degrades to whnfCore to prevent native stack overflow. -/ + partial def whnf (e : Expr m) : TypecheckM m (Expr m) := withFuelCheck do + -- Depth guard: when whnf nesting is too deep, degrade to structural-only + let depth := (← get).whnfDepth + if depth > 64 then return ← whnfCore e + modify fun s => { s with whnfDepth := s.whnfDepth + 1 } + let r ← whnfImpl e + modify fun s => { s with whnfDepth := s.whnfDepth - 1 } + pure r + + partial def whnfImpl (e : Expr m) : TypecheckM m (Expr m) := do + -- Check cache + if let some r := (← get).whnfCache.get? e then return r + let mut t ← whnfCore e + let mut steps := 0 + repeat + if steps > 10000 then break -- safety bound + -- Try nat primitive reduction + if let some r := ← tryReduceNat t then + t ← whnfCore r; steps := steps + 1; continue + -- Handle stuck projections (including inside app chains). + -- Flatten nested projection chains to avoid deep whnf→whnf recursion. 
+ match t.getAppFn with + | .proj _ _ _ _ => + -- Collect the projection chain from outside in + let mut projStack : Array (Address × Nat × Array (Expr m)) := #[] + let mut inner := t + repeat + match inner.getAppFn with + | .proj typeAddr idx struct _ => + projStack := projStack.push (typeAddr, idx, inner.getAppArgs) + inner := struct + | _ => break + -- Reduce the innermost struct with depth-guarded whnf + let innerReduced ← whnf inner + -- Resolve projections from inside out (last pushed = innermost) + let mut current := innerReduced + let mut allResolved := true + let mut i := projStack.size + while i > 0 do + i := i - 1 + let (typeAddr, idx, args) := projStack[i]! + match ← reduceProj typeAddr idx current with + | some result => + let applied := if args.isEmpty then result else result.mkAppN args + current ← whnfCore applied + | none => + -- This projection couldn't be resolved. Reconstruct remaining chain. + let stuck := if args.isEmpty then + Expr.mkProj typeAddr idx current + else + (Expr.mkProj typeAddr idx current).mkAppN args + current ← whnfCore stuck + -- Reconstruct outer projections + while i > 0 do + i := i - 1 + let (ta, ix, as) := projStack[i]! + current := if as.isEmpty then + Expr.mkProj ta ix current + else + (Expr.mkProj ta ix current).mkAppN as + allResolved := false + break + if allResolved || current != t then + t := current; steps := steps + 1; continue + | _ => pure () + -- Try delta unfolding + if let some r := ← unfoldDefinition t then + t ← whnfCore r; steps := steps + 1; continue + break + modify fun s => { s with whnfCache := s.whnfCache.insert e t } + pure t + + /-- Unfold a single delta step (definition body). 
-/ + partial def unfoldDefinition (e : Expr m) : TypecheckM m (Option (Expr m)) := do + let head := e.getAppFn + match head with + | .const addr levels _ => do + let ci ← derefConst addr + match ci with + | .defnInfo v => + if v.safety == .partial then return none + let body := v.value.instantiateLevelParams levels + let args := e.getAppArgs + return some (body.mkAppN args) + | .thmInfo v => + let body := v.value.instantiateLevelParams levels + let args := e.getAppArgs + return some (body.mkAppN args) + | _ => return none + | _ => return none +end + +/-! ## Literal folding for pretty printing -/ + +/-- Try to extract a Char from a Char.ofNat application in an Expr. -/ +private partial def tryFoldChar (prims : Primitives) (e : Expr m) : Option Char := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.charMk then + let args := e.getAppArgs + if args.size == 1 then + match args[0]! with + | .lit (.natVal n) => some (Char.ofNat n) + | _ => none + else none + else none + | _ => none + +/-- Try to extract a List Char from a List.cons/List.nil chain in an Expr. -/ +private partial def tryFoldCharList (prims : Primitives) (e : Expr m) : Option (List Char) := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.listNil then some [] + else if addr == prims.listCons then + let args := e.getAppArgs + if args.size == 3 then + match tryFoldChar prims args[1]!, tryFoldCharList prims args[2]! with + | some c, some cs => some (c :: cs) + | _, _ => none + else none + else none + | _ => none + +/-- Walk an Expr and fold Nat.zero/Nat.succ chains to nat literals, + and String.mk (char list) to string literals. 
-/ +partial def foldLiterals (prims : Primitives) : Expr m → Expr m + | .const addr lvls name => + if addr == prims.natZero then .lit (.natVal 0) + else .const addr lvls name + | .app fn arg => + let fn' := foldLiterals prims fn + let arg' := foldLiterals prims arg + let e := Expr.app fn' arg' + match e.getAppFn with + | .const addr _ _ => + if addr == prims.natSucc && e.getAppNumArgs == 1 then + match e.appArg! with + | .lit (.natVal n) => .lit (.natVal (n + 1)) + | _ => e + else if addr == prims.stringMk && e.getAppNumArgs == 1 then + match tryFoldCharList prims e.appArg! with + | some cs => .lit (.strVal (String.ofList cs)) + | none => e + else e + | _ => e + | .lam ty body n bi => + .lam (foldLiterals prims ty) (foldLiterals prims body) n bi + | .forallE ty body n bi => + .forallE (foldLiterals prims ty) (foldLiterals prims body) n bi + | .letE ty val body n => + .letE (foldLiterals prims ty) (foldLiterals prims val) (foldLiterals prims body) n + | .proj ta idx s tn => + .proj ta idx (foldLiterals prims s) tn + | e => e + +/-! ## isDelta helper -/ + +/-- Check if an expression's head is a delta-reducible constant. + Returns the DefinitionVal if so. -/ +def isDelta (e : Expr m) (kenv : Env m) : Option (ConstantInfo m) := + match e.getAppFn with + | .const addr _ _ => + match kenv.find? addr with + | some ci@(.defnInfo v) => + if v.safety == .partial then none else some ci + | some ci@(.thmInfo _) => some ci + | _ => none + | _ => none + +end Ix.Kernel diff --git a/Tests/Ix/Check.lean b/Tests/Ix/Check.lean index 404b478d..99a9bcc1 100644 --- a/Tests/Ix/Check.lean +++ b/Tests/Ix/Check.lean @@ -1,6 +1,6 @@ /- Kernel type-checker integration tests. - Tests both the Rust kernel (via FFI) and the Lean NbE kernel. + Tests both the Rust kernel (via FFI) and the Lean kernel. -/ import Ix.Kernel @@ -54,39 +54,39 @@ def testCheckConst (name : String) : TestSeq := return (false, some s!"{name} failed: {repr err}") ) .done -/-! ## Lean NbE kernel tests -/ +/-! 
## Lean kernel tests -/ def testKernelCheckEnv : TestSeq := - .individualIO "Lean NbE kernel check_env" (do + .individualIO "Lean kernel check_env" (do let leanEnv ← get_env! - IO.println s!"[Kernel-NbE] Compiling to Ixon..." + IO.println s!"[Kernel] Compiling to Ixon..." let compileStart ← IO.monoMsNow let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv let compileElapsed := (← IO.monoMsNow) - compileStart let numConsts := ixonEnv.consts.size - IO.println s!"[Kernel-NbE] Compiled {numConsts} constants in {compileElapsed.formatMs}" + IO.println s!"[Kernel] Compiled {numConsts} constants in {compileElapsed.formatMs}" - IO.println s!"[Kernel-NbE] Converting..." + IO.println s!"[Kernel] Converting..." let convertStart ← IO.monoMsNow match Ix.Kernel.Convert.convertEnv .meta ixonEnv with | .error e => - IO.println s!"[Kernel-NbE] convertEnv error: {e}" + IO.println s!"[Kernel] convertEnv error: {e}" return (false, some e) | .ok (kenv, prims, quotInit) => let convertElapsed := (← IO.monoMsNow) - convertStart - IO.println s!"[Kernel-NbE] Converted {kenv.size} constants in {convertElapsed.formatMs}" + IO.println s!"[Kernel] Converted {kenv.size} constants in {convertElapsed.formatMs}" - IO.println s!"[Kernel-NbE] Typechecking {kenv.size} constants..." + IO.println s!"[Kernel] Typechecking {kenv.size} constants..." 
let checkStart ← IO.monoMsNow match ← Ix.Kernel.typecheckAllIO kenv prims quotInit with | .error e => let elapsed := (← IO.monoMsNow) - checkStart - IO.println s!"[Kernel-NbE] typecheckAll error in {elapsed.formatMs}: {e}" - return (false, some s!"Kernel NbE check failed: {e}") + IO.println s!"[Kernel] typecheckAll error in {elapsed.formatMs}: {e}" + return (false, some s!"Kernel check failed: {e}") | .ok () => let elapsed := (← IO.monoMsNow) - checkStart - IO.println s!"[Kernel-NbE] All constants passed in {elapsed.formatMs}" + IO.println s!"[Kernel] All constants passed in {elapsed.formatMs}" return (true, none) ) .done diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index b14dbff4..4922cb17 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -216,33 +216,79 @@ def testConvertEnv : TestSeq := return (false, some s!"{missing.size} constants missing from Kernel.Env") ) .done -/-- Const pipeline: compile, convert, typecheck specific constants. -/ -def testConstPipeline : TestSeq := - .individualIO "kernel const pipeline" (do +/-- Typecheck specific constants through the Lean kernel. -/ +def testConsts : TestSeq := + .individualIO "kernel const checks" (do let leanEnv ← get_env! 
let start ← IO.monoMsNow let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv let compileMs := (← IO.monoMsNow) - start - IO.println s!"[kernel] rsCompileEnv: {ixonEnv.consts.size} consts in {compileMs.formatMs}" + IO.println s!"[kernel-const] rsCompileEnv: {ixonEnv.consts.size} consts in {compileMs.formatMs}" let convertStart ← IO.monoMsNow match Ix.Kernel.Convert.convertEnv .meta ixonEnv with | .error e => - IO.println s!"[kernel] convertEnv error: {e}" + IO.println s!"[kernel-const] convertEnv error: {e}" return (false, some e) | .ok (kenv, prims, quotInit) => let convertMs := (← IO.monoMsNow) - convertStart - IO.println s!"[kernel] convertEnv: {kenv.size} consts in {convertMs.formatMs}" + IO.println s!"[kernel-const] convertEnv: {kenv.size} consts in {convertMs.formatMs}" - -- Check specific constants let constNames := #[ + -- Basic inductives "Nat", "Nat.zero", "Nat.succ", "Nat.rec", "Bool", "Bool.true", "Bool.false", "Bool.rec", "Eq", "Eq.refl", "List", "List.nil", "List.cons", - "Nat.below" + "Nat.below", + -- Quotient types + "Quot", "Quot.mk", "Quot.lift", "Quot.ind", + -- K-reduction exercisers + "Eq.rec", "Eq.subst", "Eq.symm", "Eq.trans", + -- Proof irrelevance + "And.intro", "Or.inl", "Or.inr", + -- K-like reduction with congr + "congr", "congrArg", "congrFun", + -- Structure projections + eta + "Prod.fst", "Prod.snd", "Prod.mk", "Sigma.mk", "Subtype.mk", + -- Nat primitives + "Nat.add", "Nat.sub", "Nat.mul", "Nat.div", "Nat.mod", + "Nat.gcd", "Nat.beq", "Nat.ble", + "Nat.land", "Nat.lor", "Nat.xor", + "Nat.shiftLeft", "Nat.shiftRight", "Nat.pow", + -- Recursors + "List.rec", + -- Delta unfolding + "id", "Function.comp", + -- Various inductives + "Empty", "PUnit", "Fin", "Sigma", "Prod", + -- Proofs / proof irrelevance + "True", "False", "And", "Or", + -- Mutual/nested inductives + "List.map", "List.foldl", "List.append", + -- Universe polymorphism + "ULift", "PLift", + -- More complex + "Option", "Option.some", "Option.none", + "String", "String.mk", 
"Char", + -- Partial definitions + "WellFounded.fix", + -- Well-founded recursion scaffolding + "Nat.brecOn", + -- PProd (used by Nat.below) + "PProd", "PProd.mk", "PProd.fst", "PProd.snd", + "PUnit.unit", + -- noConfusion + "Lean.Meta.Grind.Origin.noConfusionType", + "Lean.Meta.Grind.Origin.noConfusion", + "Lean.Meta.Grind.Origin.stx.noConfusion", + -- Complex proofs (fuel-sensitive) + "Nat.Linear.Poly.of_denote_eq_cancel", + "String.length_empty", + "_private.Init.Grind.Ring.Basic.«0».Lean.Grind.IsCharP.mk'_aux._proof_1_5", + -- BVDecide regression test (fuel-sensitive) + "Std.Tactic.BVDecide.BVExpr.bitblast.blastUdiv.instLawfulVecOperatorShiftConcatInputBlastShiftConcat" ] - let checkStart ← IO.monoMsNow let mut passed := 0 let mut failures : Array String := #[] for name in constNames do @@ -250,15 +296,22 @@ def testConstPipeline : TestSeq := let some cNamed := ixonEnv.named.get? ixName | do failures := failures.push s!"{name}: not found"; continue let addr := cNamed.addr + IO.println s!" checking {name} ..." + (← IO.getStdout).flush + let start ← IO.monoMsNow match Ix.Kernel.typecheckConst kenv prims addr quotInit with - | .ok () => passed := passed + 1 - | .error e => failures := failures.push s!"{name}: {e}" - let checkMs := (← IO.monoMsNow) - checkStart - IO.println s!"[kernel] {passed}/{constNames.size} passed in {checkMs.formatMs}" + | .ok () => + let ms := (← IO.monoMsNow) - start + IO.println s!" ✓ {name} ({ms.formatMs})" + passed := passed + 1 + | .error e => + let ms := (← IO.monoMsNow) - start + IO.println s!" ✗ {name} ({ms.formatMs}): {e}" + failures := failures.push s!"{name}: {e}" + IO.println s!"[kernel-const] {passed}/{constNames.size} passed" if failures.isEmpty then return (true, none) else - for f in failures do IO.println s!" [fail] {f}" return (false, some s!"{failures.size} failure(s)") ) .done @@ -447,65 +500,6 @@ def testReducibilityHintsLt : TestSeq := /-! 
## Expanded integration tests -/ -/-- Expanded constant pipeline: more constants including quotients, recursors, projections. -/ -def testMoreConstants : TestSeq := - .individualIO "expanded kernel const pipeline" (do - let leanEnv ← get_env! - let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv - match Ix.Kernel.Convert.convertEnv .meta ixonEnv with - | .error e => return (false, some e) - | .ok (kenv, prims, quotInit) => - let constNames := #[ - -- Quotient types - "Quot", "Quot.mk", "Quot.lift", "Quot.ind", - -- K-reduction exercisers - "Eq.rec", "Eq.subst", "Eq.symm", "Eq.trans", - -- Proof irrelevance - "And.intro", "Or.inl", "Or.inr", - -- K-like reduction with congr - "congr", "congrArg", "congrFun", - -- Structure projections + eta - "Prod.fst", "Prod.snd", "Prod.mk", "Sigma.mk", "Subtype.mk", - -- Nat primitives - "Nat.add", "Nat.sub", "Nat.mul", "Nat.div", "Nat.mod", - "Nat.gcd", "Nat.beq", "Nat.ble", - "Nat.land", "Nat.lor", "Nat.xor", - "Nat.shiftLeft", "Nat.shiftRight", "Nat.pow", - -- Recursors - "Bool.rec", "List.rec", - -- Delta unfolding - "id", "Function.comp", - -- Various inductives - "Empty", "PUnit", "Fin", "Sigma", "Prod", - -- Proofs / proof irrelevance - "True", "False", "And", "Or", - -- Mutual/nested inductives - "List.map", "List.foldl", "List.append", - -- Universe polymorphism - "ULift", "PLift", - -- More complex - "Option", "Option.some", "Option.none", - "String", "String.mk", "Char", - -- Partial definitions - "WellFounded.fix" - ] - let mut passed := 0 - let mut failures : Array String := #[] - for name in constNames do - let ixName := parseIxName name - let some cNamed := ixonEnv.named.get? 
ixName - | do failures := failures.push s!"{name}: not found"; continue - let addr := cNamed.addr - match Ix.Kernel.typecheckConst kenv prims addr quotInit with - | .ok () => passed := passed + 1 - | .error e => failures := failures.push s!"{name}: {e}" - IO.println s!"[kernel-expanded] {passed}/{constNames.size} passed" - if failures.isEmpty then - return (true, none) - else - for f in failures do IO.println s!" [fail] {f}" - return (false, some s!"{failures.size} failure(s)") - ) .done /-! ## Anon mode conversion test -/ @@ -1088,64 +1082,6 @@ def testHelperFunctions : TestSeq := (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) ++ .done -/-! ## Focused NbE constant tests -/ - -/-- Test individual constants through the NbE kernel to isolate failures. -/ -def testNbeConsts : TestSeq := - .individualIO "nbe focused const checks" (do - let leanEnv ← get_env! - let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv - match Ix.Kernel.Convert.convertEnv .meta ixonEnv with - | .error e => return (false, some s!"convertEnv: {e}") - | .ok (kenv, prims, quotInit) => - let constNames := #[ - -- Nat basics - "Nat", "Nat.zero", "Nat.succ", "Nat.rec", - -- Below / brecOn (well-founded recursion scaffolding) - "Nat.below", "Nat.brecOn", - -- PProd (used by Nat.below) - "PProd", "PProd.mk", "PProd.fst", "PProd.snd", - "PUnit", "PUnit.unit", - -- noConfusion (stuck neutral in fresh-state mode) - "Lean.Meta.Grind.Origin.noConfusionType", - "Lean.Meta.Grind.Origin.noConfusion", - "Lean.Meta.Grind.Origin.stx.noConfusion", - -- The previously-hanging constant - "Nat.Linear.Poly.of_denote_eq_cancel", - -- String theorem (fuel-sensitive) - "String.length_empty", - "_private.Init.Grind.Ring.Basic.«0».Lean.Grind.IsCharP.mk'_aux._proof_1_5", - ] - let mut passed := 0 - let mut failures : Array String := #[] - for name in constNames do - let ixName := parseIxName name - let some cNamed := ixonEnv.named.get? 
ixName - | do failures := failures.push s!"{name}: not found"; continue - let addr := cNamed.addr - IO.println s!" checking {name} ..." - (← IO.getStdout).flush - let start ← IO.monoMsNow - match Ix.Kernel.typecheckConst kenv prims addr quotInit with - | .ok () => - let ms := (← IO.monoMsNow) - start - IO.println s!" ✓ {name} ({ms.formatMs})" - passed := passed + 1 - | .error e => - let ms := (← IO.monoMsNow) - start - IO.println s!" ✗ {name} ({ms.formatMs}): {e}" - failures := failures.push s!"{name}: {e}" - IO.println s!"[nbe-focus] {passed}/{constNames.size} passed" - if failures.isEmpty then - return (true, none) - else - return (false, some s!"{failures.size} failure(s)") - ) .done - -def nbeFocusSuite : List TestSeq := [ - testNbeConsts, -] - /-! ## Test suites -/ def unitSuite : List TestSeq := [ @@ -1165,8 +1101,7 @@ def convertSuite : List TestSeq := [ ] def constSuite : List TestSeq := [ - testConstPipeline, - testMoreConstants, + testConsts, ] def negativeSuite : List TestSeq := [ diff --git a/Tests/Ix/PP.lean b/Tests/Ix/PP.lean index ab52ea3e..2f66249c 100644 --- a/Tests/Ix/PP.lean +++ b/Tests/Ix/PP.lean @@ -232,55 +232,7 @@ def testPpComplex : TestSeq := test "nested let" (outerLet.pp == "let x : Nat := 0; let y : Nat := x; y") ++ .done -/-! ## Quote round-trip: names survive eval → quote → pp -/ - -/-- Build a Value with named binders and verify names survive through quote → pp. - Uses a minimal TypecheckM context. 
-/ -def testQuoteRoundtrip : TestSeq := - .individualIO "quote round-trip preserves names" (do - let xName : MetaField .meta Ix.Name := mkName "x" - let yName : MetaField .meta Ix.Name := mkName "y" - let nat : Expr .meta := .const testAddr #[] (mkName "Nat") - -- Build Value.pi: ∀ (x : Nat), Nat - let domVal : SusValue .meta := ⟨.none, Thunk.mk fun _ => Value.neu (.const testAddr #[] (mkName "Nat"))⟩ - let imgTE : TypedExpr .meta := ⟨.none, nat⟩ - let piVal : Value .meta := .pi domVal imgTE (.mk [] []) xName .default - -- Build Value.lam: fun (y : Nat) => y - let bodyTE : TypedExpr .meta := ⟨.none, .bvar 0 yName⟩ - let lamVal : Value .meta := .lam domVal bodyTE (.mk [] []) yName .default - -- Quote and pp in a minimal TypecheckM context (wrapped in runST for ST.Ref allocation) - let result := runST fun σ => do - let fuelRef ← ST.mkRef Ix.Kernel.defaultFuel - let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level .meta) × Value .meta)) - let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool) - let ctx : TypecheckCtx .meta σ := { - lvl := 0, env := .mk [] [], types := [], - kenv := default, prims := buildPrimitives, - safety := .safe, quotInit := true, mutTypes := default, recAddr? 
:= none, - fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef - } - let stt : TypecheckState .meta := { typedConsts := default } - let piResult ← TypecheckM.run ctx stt (ppValue 0 piVal) - let lamResult ← TypecheckM.run ctx stt (ppValue 0 lamVal) - pure (piResult, lamResult) - -- Test pi - match result.1 with - | .ok s => - if s != "∀ (x : Nat), Nat" then - return (false, some s!"pi round-trip: expected '∀ (x : Nat), Nat', got '{s}'") - else pure () - | .error e => return (false, some s!"pi round-trip error: {e}") - -- Test lam - match result.2 with - | .ok s => - if s != "λ (y : Nat) => y" then - return (false, some s!"lam round-trip: expected 'λ (y : Nat) => y', got '{s}'") - else pure () - | .error e => return (false, some s!"lam round-trip error: {e}") - return (true, none) - ) .done - -/-! ## Literal folding: Nat/String constructor chains → literals in ppValue -/ +/-! ## Literal folding: Nat/String constructor chains → literals in Expr -/ def testFoldLiterals : TestSeq := let prims := buildPrimitives @@ -334,7 +286,6 @@ def suite : List TestSeq := [ testPpAnon, testPpMetaDefaultNames, testPpComplex, - testQuoteRoundtrip, testFoldLiterals, ] diff --git a/Tests/Main.lean b/Tests/Main.lean index e7ca61c2..b146142e 100644 --- a/Tests/Main.lean +++ b/Tests/Main.lean @@ -61,7 +61,6 @@ def ignoredSuites : Std.HashMap String (List LSpec.TestSeq) := .ofList [ ("kernel-const", Tests.KernelTests.constSuite), ("kernel-verify-prims", [Tests.KernelTests.testVerifyPrimAddrs]), ("kernel-dump-prims", [Tests.KernelTests.testDumpPrimAddrs]), - ("nbe-focus", Tests.KernelTests.nbeFocusSuite), ("kernel-roundtrip", Tests.KernelTests.roundtripSuite), ("ixon-full-roundtrip", Tests.Compile.ixonRoundtripSuiteIO), ] From 406a7a33f5de34c7b1634a31b4794742ffd5b09c Mon Sep 17 00:00:00 2001 From: "John C. 
Burnham" Date: Mon, 2 Mar 2026 13:07:48 -0500 Subject: [PATCH 07/14] Add EquivManager, inferOnly mode, and isDefEq optimizations Replace HashMap-based eqvCache with union-find EquivManager (ported from lean4lean) for congruence-aware structural equality caching. Add inferOnly mode that skips argument/let type-checking during inference, used for theorem value checking to handle sub-term type mismatches. Additional isDefEq improvements: - isDefEqUnitLike for non-recursive single-ctor zero-field types - isDefEqOffset for Nat.succ chain short-circuiting - tryUnfoldProjApp in lazy delta for projection-headed stuck terms - cheapProj=true in stage 1 defers full projection reduction to stage 2 - Failure cache on same-head optimization in lazyDeltaReduction - Fix ReducibilityHints.lt' to handle all cases correctly --- Ix/Kernel.lean | 1 + Ix/Kernel/EquivManager.lean | 92 ++++++++++++++++++ Ix/Kernel/Infer.lean | 181 ++++++++++++++++++++++++++++-------- Ix/Kernel/TypecheckM.lean | 11 ++- Ix/Kernel/Types.lean | 7 +- Ix/Kernel/Whnf.lean | 5 +- Tests/Ix/KernelTests.lean | 4 +- 7 files changed, 254 insertions(+), 47 deletions(-) create mode 100644 Ix/Kernel/EquivManager.lean diff --git a/Ix/Kernel.lean b/Ix/Kernel.lean index ba19b0b4..2ce31362 100644 --- a/Ix/Kernel.lean +++ b/Ix/Kernel.lean @@ -3,6 +3,7 @@ import Ix.Environment import Ix.Kernel.Types import Ix.Kernel.Datatypes import Ix.Kernel.Level +import Ix.Kernel.EquivManager import Ix.Kernel.TypecheckM import Ix.Kernel.Whnf import Ix.Kernel.DefEq diff --git a/Ix/Kernel/EquivManager.lean b/Ix/Kernel/EquivManager.lean new file mode 100644 index 00000000..9521922c --- /dev/null +++ b/Ix/Kernel/EquivManager.lean @@ -0,0 +1,92 @@ +/- + EquivManager: Union-find based equivalence tracking for definitional equality. + + Ported from lean4lean's EquivManager. Provides structural expression walking + with union-find to recognize congruence: if a ~ b and c ~ d, then f a c ~ f b d + is detected without re-entering isDefEq. 
+-/ +import Batteries.Data.UnionFind.Basic +import Ix.Kernel.Datatypes + +namespace Ix.Kernel + +abbrev NodeRef := Nat + +structure EquivManager (m : MetaMode) where + uf : Batteries.UnionFind := {} + toNodeMap : Std.HashMap (Expr m) NodeRef := {} + +instance : Inhabited (EquivManager m) := ⟨{}⟩ + +namespace EquivManager + +/-- Map an expression to a union-find node, creating one if it doesn't exist. -/ +def toNode (e : Expr m) : StateM (EquivManager m) NodeRef := fun mgr => + match mgr.toNodeMap.get? e with + | some n => (n, mgr) + | none => + let n := mgr.uf.size + (n, { uf := mgr.uf.push, toNodeMap := mgr.toNodeMap.insert e n }) + +/-- Find the root of a node with path compression. -/ +def find (n : NodeRef) : StateM (EquivManager m) NodeRef := fun mgr => + let (uf', root) := mgr.uf.findD n + (root, { mgr with uf := uf' }) + +/-- Merge two nodes into the same equivalence class. -/ +def merge (n1 n2 : NodeRef) : StateM (EquivManager m) Unit := fun mgr => + if n1 < mgr.uf.size && n2 < mgr.uf.size then + ((), { mgr with uf := mgr.uf.union! n1 n2 }) + else + ((), mgr) + +/-- Check structural equivalence with union-find memoization. + Recursively walks expression structure, checking if corresponding + sub-expressions are in the same union-find equivalence class. + Merges nodes on success for future O(α(n)) lookups. + + When `useHash = true`, expressions with different hashes are immediately + rejected without structural walking (fast path for obviously different terms). -/ +partial def isEquiv (useHash : Bool) (e1 e2 : Expr m) : StateM (EquivManager m) Bool := do + -- 1. Pointer/structural equality (O(1) via Blake3 content-addressing) + if e1 == e2 then return true + -- 2. Hash mismatch → definitely not structurally equal + if useHash && Hashable.hash e1 != Hashable.hash e2 then return false + -- 3. BVar fast path (compare indices directly, don't add to union-find) + match e1, e2 with + | .bvar i _, .bvar j _ => return i == j + | _, _ => pure () + -- 4. 
Union-find root comparison + let r1 ← find (← toNode e1) + let r2 ← find (← toNode e2) + if r1 == r2 then return true + -- 5. Structural decomposition + let result ← match e1, e2 with + | .const a1 l1 _, .const a2 l2 _ => pure (a1 == a2 && l1 == l2) + | .sort l1, .sort l2 => pure (l1 == l2) + | .lit l1, .lit l2 => pure (l1 == l2) + | .app f1 a1, .app f2 a2 => + if ← isEquiv useHash f1 f2 then isEquiv useHash a1 a2 else pure false + | .lam d1 b1 _ _, .lam d2 b2 _ _ => + if ← isEquiv useHash d1 d2 then isEquiv useHash b1 b2 else pure false + | .forallE d1 b1 _ _, .forallE d2 b2 _ _ => + if ← isEquiv useHash d1 d2 then isEquiv useHash b1 b2 else pure false + | .proj ta1 i1 s1 _, .proj ta2 i2 s2 _ => + if ta1 == ta2 && i1 == i2 then isEquiv useHash s1 s2 else pure false + | .letE t1 v1 b1 _, .letE t2 v2 b2 _ => + if ← isEquiv useHash t1 t2 then + if ← isEquiv useHash v1 v2 then isEquiv useHash b1 b2 else pure false + else pure false + | _, _ => pure false + -- 6. Merge on success + if result then merge r1 r2 + return result + +/-- Directly merge two expressions into the same equivalence class. 
-/ +def addEquiv (e1 e2 : Expr m) : StateM (EquivManager m) Unit := do + let r1 ← find (← toNode e1) + let r2 ← find (← toNode e2) + merge r1 r2 + +end EquivManager +end Ix.Kernel diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index abf8a9f2..5218d476 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -170,20 +170,27 @@ mutual let (fnTe, fncType) ← infer fn let mut currentType := fncType let mut resultBody := fnTe.body + let inferOnly := (← read).inferOnly for h : i in [:args.size] do let arg := args[i] let currentType' ← whnf currentType match currentType' with | .forallE dom body _ _ => do - let argTe ← check arg dom - resultBody := Expr.mkApp resultBody argTe.body + if inferOnly then + resultBody := Expr.mkApp resultBody arg + else + let argTe ← check arg dom + resultBody := Expr.mkApp resultBody argTe.body currentType := body.instantiate1 arg | _ => throw s!"Expected a pi type, got {currentType'.pp}\n function: {fn.pp}\n arg #{i}: {arg.pp}" let te : TypedExpr m := ⟨← infoFromType currentType, resultBody⟩ pure (te, currentType) | .lam ty body lamName lamBi => do - let (domTe, _) ← isSort ty + let domTe ← if (← read).inferOnly then + pure ⟨.none, ty⟩ + else + let (te, _) ← isSort ty; pure te let (bodTe, imgType) ← withExtendedCtx ty (infer body) let piType := Expr.forallE ty imgType lamName default let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩ @@ -196,12 +203,18 @@ mutual let te : TypedExpr m := ⟨← infoFromType typ, .forallE domTe.body imgTe.body piName default⟩ pure (te, typ) | .letE ty val body letName => do - let (tyTe, _) ← isSort ty - let valTe ← check val ty - let (bodTe, bodType) ← withExtendedCtx ty (infer body) - let resultType := bodType.instantiate1 val - let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ - pure (te, resultType) + if (← read).inferOnly then + let (bodTe, bodType) ← withExtendedCtx ty (infer body) + let resultType := bodType.instantiate1 val 
+ let te : TypedExpr m := ⟨bodTe.info, .letE ty val bodTe.body letName⟩ + pure (te, resultType) + else + let (tyTe, _) ← isSort ty + let valTe ← check val ty + let (bodTe, bodType) ← withExtendedCtx ty (infer body) + let resultType := bodType.instantiate1 val + let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ + pure (te, resultType) | .lit (.natVal _) => do let prims := (← read).prims let typ := Expr.mkConst prims.nat #[] @@ -300,7 +313,7 @@ mutual whnfCache := {}, whnfCoreCache := {}, inferCache := {}, - eqvCache := {}, + eqvManager := {}, failureCache := {}, fuel := defaultFuel } @@ -319,10 +332,13 @@ mutual let value ← withRecAddr addr (check ci.value?.get! type.body) pure (TypedConst.opaque type value) | .thmInfo _ => - let (type, lvl) ← isSort ci.type + let (type, lvl) ← withInferOnly (isSort ci.type) if !Level.isZero lvl then throw s!"theorem type must be a proposition (Sort 0)" - let value ← withRecAddr addr (check ci.value?.get! type.body) + let (_, valType) ← withRecAddr addr (withInferOnly (infer ci.value?.get!)) + if !(← withInferOnly (isDefEq valType type.body)) then + throw s!"theorem value type doesn't match declared type" + let value : TypedExpr m := ⟨.proof, ci.value?.get!⟩ pure (TypedConst.theorem type value) | .defnInfo v => let (type, _) ← isSort ci.type @@ -488,13 +504,19 @@ mutual - some false: definitely not equal - none: unknown, need deeper checks -/ partial def quickIsDefEq (t s : Expr m) (useHash : Bool := true) : TypecheckM m (Option Bool) := do - if t == s then return some true + -- Run EquivManager structural walk with union-find + let stt ← get + let (result, mgr') := EquivManager.isEquiv useHash t s |>.run stt.eqvManager + modify fun stt => { stt with eqvManager := mgr' } + if result then return some true + -- Failure cache (EquivManager only tracks successes) let key := eqCacheKey t s - if let some r := (← get).eqvCache.get? 
key then return some r if (← get).failureCache.contains key then return some false + -- Shape-specific checks with richer equality (Level.equalLevel, etc.) match t, s with | .sort u, .sort u' => pure (some (Level.equalLevel u u')) - | .const a us _, .const b us' _ => pure (some (a == b && equalUnivArrays us us')) + | .const a us _, .const b us' _ => + if a == b && equalUnivArrays us us' then pure (some true) else pure none | .lit l, .lit l' => pure (some (l == l')) | .bvar i _, .bvar j _ => pure (some (i == j)) | .lam ty body _ _, .lam ty' body' _ _ => @@ -520,12 +542,12 @@ mutual | some result => return result | none => pure () - -- 1. Stage 1: structural reduction - let tn ← whnfCore t - let sn ← whnfCore s + -- 1. Stage 1: structural reduction (cheapProj=true: defer full projection resolution) + let tn ← whnfCore t (cheapProj := true) + let sn ← whnfCore s (cheapProj := true) - -- 2. Quick check after whnfCore - match ← quickIsDefEq tn sn with + -- 2. Quick check after whnfCore (useHash=false for thorough union-find walking) + match ← quickIsDefEq tn sn (useHash := false) with | some true => cacheResult t s true; return true | some false => pure () -- don't cache — deeper checks may still succeed | none => pure () @@ -543,12 +565,25 @@ mutual cacheResult t s true return true - -- 5. Stage 2: full whnf (resolves projections + remaining delta) - let tnn ← whnf tn' - let snn ← whnf sn' - if tnn == snn then - cacheResult t s true - return true + -- 4b. Cheap structural checks after lazy delta (before full whnfCore) + match tn', sn' with + | .const a us _, .const b us' _ => + if a == b && equalUnivArrays us us' then + cacheResult t s true; return true + | .proj _ ti te _, .proj _ si se _ => + if ti == si then + if ← isDefEq te se then + cacheResult t s true; return true + | _, _ => pure () + + -- 5. 
Stage 2: full structural reduction (no cheapProj — resolve all projections) + let tnn ← whnfCore tn' + let snn ← whnfCore sn' + -- Only recurse into isDefEqCore if something actually changed + if !(tnn == tn' && snn == sn') then + let result ← isDefEqCore tnn snn + cacheResult t s result + return result -- 6. Structural comparison on fully-reduced terms let result ← isDefEqCore tnn snn @@ -668,7 +703,50 @@ mutual | _, .app _ _ => tryEtaStruct t s | .app _ _, _ => tryEtaStruct s t - | _, _ => pure false + -- Unit-like fallback: non-recursive, single ctor with 0 fields, 0 indices + | _, _ => isDefEqUnitLike t s + + /-- For unit-like types (non-recursive, single ctor with 0 fields, 0 indices), + two terms are defeq if their types are defeq. -/ + partial def isDefEqUnitLike (t s : Expr m) : TypecheckM m Bool := do + let kenv := (← read).kenv + let (_, tType) ← infer t + let tType' ← whnf tType + let fn := tType'.getAppFn + match fn with + | .const addr _ _ => + match kenv.find? addr with + | some (.inductInfo v) => + if v.isRec || v.numIndices != 0 || v.ctors.size != 1 then return false + match kenv.find? v.ctors[0]! with + | some (.ctorInfo cv) => + if cv.numFields != 0 then return false + let (_, sType) ← infer s + isDefEq tType sType + | _ => return false + | _ => return false + | _ => return false + + /-- If e is an application whose head is a projection, try whnfCore to reduce it. -/ + partial def tryUnfoldProjApp (e : Expr m) : TypecheckM m (Option (Expr m)) := do + match e.getAppFn with + | .proj .. => + let e' ← whnfCore e + if e' == e then return none else return some e' + | _ => return none + + /-- Check if two Nat.succ chains or zero values match structurally. -/ + partial def isDefEqOffset (t s : Expr m) : TypecheckM m (Option Bool) := do + let prims := (← read).prims + let isZero (e : Expr m) := e.isConstOf prims.natZero || match e with | .lit (.natVal 0) => true | _ => false + let succOf? 
(e : Expr m) : Option (Expr m) := match e with + | .lit (.natVal (n+1)) => some (.lit (.natVal n)) + | .app fn arg => if fn.isConstOf prims.natSucc then some arg else none + | _ => none + if isZero t && isZero s then return some true + match succOf? t, succOf? s with + | some t', some s' => some <$> isDefEq t' s' + | _, _ => return none /-- Lazy delta reduction loop. Unfolds definitions one step at a time, guided by ReducibilityHints, until a conclusive comparison or both @@ -686,11 +764,22 @@ mutual -- Syntactic check if tn == sn then return (tn, sn, some true) + -- Quick structural check (EquivManager + lambda/forall matching) + -- Only trust "definitely equal"; delta reduction may still make unequal terms equal + match ← quickIsDefEq tn sn (useHash := false) with + | some true => return (tn, sn, some true) + | _ => pure () + + -- isDefEqOffset: short-circuit Nat.succ chain comparison + match ← isDefEqOffset tn sn with + | some result => return (tn, sn, some result) + | none => pure () + -- Try nat reduction if let some r := ← tryReduceNat tn then - tn ← whnfCore r; continue + tn ← whnfCore r (cheapProj := true); continue if let some r := ← tryReduceNat sn then - sn ← whnfCore r; continue + sn ← whnfCore r (cheapProj := true); continue -- Lazy delta step let tDelta := isDelta tn kenv @@ -698,32 +787,42 @@ mutual match tDelta, sDelta with | none, none => return (tn, sn, none) -- both stuck | some dt, none => + -- Try reducing projection-headed app on the stuck side first + if let some sn' ← tryUnfoldProjApp sn then + sn := sn'; continue match unfoldDelta dt tn with - | some r => tn ← whnfCore r; continue + | some r => tn ← whnfCore r (cheapProj := true); continue | none => return (tn, sn, none) | none, some ds => + -- Try reducing projection-headed app on the stuck side first + if let some tn' ← tryUnfoldProjApp tn then + tn := tn'; continue match unfoldDelta ds sn with - | some r => sn ← whnfCore r; continue + | some r => sn ← whnfCore r (cheapProj := true); 
continue | none => return (tn, sn, none) | some dt, some ds => let ht := dt.hints let hs := ds.hints - -- Same head optimization: try comparing args first - if sameHeadConst tn sn && ht.isRegular && hs.isRegular then - if ← isDefEqApp tn sn then return (tn, sn, some true) + -- Same head optimization: try comparing args first (with failure cache) + if tn.isApp && sn.isApp && sameHeadConst tn sn && ht.isRegular then + let key := eqCacheKey tn sn + if !(← get).failureCache.contains key then + if equalUnivArrays tn.getAppFn.constLevels! sn.getAppFn.constLevels! then + if ← isDefEqApp tn sn then return (tn, sn, some true) + modify fun stt => { stt with failureCache := stt.failureCache.insert key } if ht.lt' hs then match unfoldDelta ds sn with - | some r => sn ← whnfCore r; continue + | some r => sn ← whnfCore r (cheapProj := true); continue | none => match unfoldDelta dt tn with - | some r => tn ← whnfCore r; continue + | some r => tn ← whnfCore r (cheapProj := true); continue | none => return (tn, sn, none) else if hs.lt' ht then match unfoldDelta dt tn with - | some r => tn ← whnfCore r; continue + | some r => tn ← whnfCore r (cheapProj := true); continue | none => match unfoldDelta ds sn with - | some r => sn ← whnfCore r; continue + | some r => sn ← whnfCore r (cheapProj := true); continue | none => return (tn, sn, none) else -- Same height: unfold both @@ -782,10 +881,12 @@ mutual /-- Cache a def-eq result (both successes and failures). 
-/ partial def cacheResult (t s : Expr m) (result : Bool) : TypecheckM m Unit := do - let key := eqCacheKey t s if result then - modify fun stt => { stt with eqvCache := stt.eqvCache.insert key result } + modify fun stt => + let (_, mgr') := EquivManager.addEquiv t s |>.run stt.eqvManager + { stt with eqvManager := mgr' } else + let key := eqCacheKey t s modify fun stt => { stt with failureCache := stt.failureCache.insert key } end -- mutual diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 45385b5a..317fac09 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -6,6 +6,7 @@ -/ import Ix.Kernel.Datatypes import Ix.Kernel.Level +import Ix.Kernel.EquivManager namespace Ix.Kernel @@ -30,6 +31,8 @@ structure TypecheckCtx (m : MetaMode) where mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → Expr m)) compare /-- Tracks the address of the constant currently being checked, for recursion detection. -/ recAddr? : Option Address + /-- When true, skip argument type-checking during inference (lean4lean inferOnly). -/ + inferOnly : Bool := false /-- Enable dbg_trace on major entry points for debugging. -/ trace : Bool := false @@ -47,13 +50,16 @@ structure TypecheckState (m : MetaMode) where /-- Infer cache: maps term → (binding context, inferred type). Keyed on Expr only; context verified on retrieval via ptr equality + BEq fallback. -/ inferCache : Std.HashMap (Expr m) (Array (Expr m) × Expr m) := {} - eqvCache : Std.HashMap (Expr m × Expr m) Bool := {} + eqvManager : EquivManager m := {} failureCache : Std.HashSet (Expr m × Expr m) := {} constTypeCache : Std.HashMap Address (Array (Level m) × Expr m) := {} fuel : Nat := defaultFuel /-- Tracks nesting depth of whnf calls from within recursor reduction (tryReduceApp → whnf). When this exceeds a threshold, whnfCore is used instead of whnf to prevent stack overflow. -/ whnfDepth : Nat := 0 + /-- Global recursion depth across isDefEq/infer/whnf for stack overflow prevention. 
-/ + recDepth : Nat := 0 + maxRecDepth : Nat := 0 deriving Inhabited /-! ## TypecheckM monad -/ @@ -83,6 +89,9 @@ def withExtendedCtx (varType : Expr m) : TypecheckM m α → TypecheckM m α := def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with recAddr? := some addr } +def withInferOnly : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with inferOnly := true } + /-- The current binding depth (number of bound variables in scope). -/ def lvl : TypecheckM m Nat := do pure (← read).types.size diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index 6a8ff1d1..15d077c1 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -380,10 +380,11 @@ inductive ReducibilityHints where namespace ReducibilityHints def lt' : ReducibilityHints → ReducibilityHints → Bool + | _, .opaque => false + | .abbrev, _ => false + | .opaque, _ => true + | _, .abbrev => true | .regular d₁, .regular d₂ => d₁ < d₂ - | .regular _, .opaque => true - | .abbrev, .opaque => true - | _, _ => false def isRegular : ReducibilityHints → Bool | .regular _ => true diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean index 591b66d7..21bc566d 100644 --- a/Ix/Kernel/Whnf.lean +++ b/Ix/Kernel/Whnf.lean @@ -172,8 +172,9 @@ mutual | .letE _ val body _ => t := body.instantiate1 val; continue -- loop instead of recursion | .proj typeAddr idx struct _ => do - if cheapProj then return t -- skip projection reduction - let struct' ← whnfCore struct cheapRec cheapProj + -- cheapProj=true: try structural-only reduction (whnfCore, no delta) + -- cheapProj=false: full reduction (whnf, with delta) + let struct' ← if cheapProj then whnfCore struct cheapRec cheapProj else whnf struct match ← reduceProj typeAddr idx struct' with | some result => t := result; continue -- loop instead of recursion | none => diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index 4922cb17..d3c9f7ab 100644 --- a/Tests/Ix/KernelTests.lean +++ 
b/Tests/Ix/KernelTests.lean @@ -287,7 +287,9 @@ def testConsts : TestSeq := "String.length_empty", "_private.Init.Grind.Ring.Basic.«0».Lean.Grind.IsCharP.mk'_aux._proof_1_5", -- BVDecide regression test (fuel-sensitive) - "Std.Tactic.BVDecide.BVExpr.bitblast.blastUdiv.instLawfulVecOperatorShiftConcatInputBlastShiftConcat" + "Std.Tactic.BVDecide.BVExpr.bitblast.blastUdiv.instLawfulVecOperatorShiftConcatInputBlastShiftConcat", + -- Theorem with sub-term type mismatch (requires inferOnly) + "Std.Do.Spec.tryCatch_ExceptT" ] let mut passed := 0 let mut failures : Array String := #[] From 573abad6e7b8841155795b4a91c0cbc54fcbd44a Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Mon, 2 Mar 2026 14:57:35 -0500 Subject: [PATCH 08/14] Make positivity checking monadic with whnf and nested inductive support Move checkStrictPositivity/checkCtorPositivity into the mutual block as monadic checkPositivity/checkCtorFields/checkNestedCtorFields, enabling whnf calls during positivity analysis. This matches lean4lean's checkPositivity and correctly handles nested inductives (e.g. an inductive appearing as a param of a previously-defined inductive). Split KernelTests.lean into Helpers, Unit, and Soundness submodules. Add targeted soundness tests for nested positivity: positive nesting via Wrap, double nesting, multi-field, multi-param, contravariant rejection, index-position rejection, non-inductive head, and unsafe outer. Add Lean.Elab.Term.Do.Code.action as an integration test case requiring whnf-based nested positivity. 
--- Ix/Kernel/Infer.lean | 117 +++-- Tests/Ix/Kernel/Helpers.lean | 110 +++++ Tests/Ix/Kernel/Soundness.lean | 410 +++++++++++++++++ Tests/Ix/Kernel/Unit.lean | 298 +++++++++++++ Tests/Ix/KernelTests.lean | 780 ++------------------------------- 5 files changed, 934 insertions(+), 781 deletions(-) create mode 100644 Tests/Ix/Kernel/Helpers.lean create mode 100644 Tests/Ix/Kernel/Soundness.lean create mode 100644 Tests/Ix/Kernel/Unit.lean diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 5218d476..b3867342 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -20,36 +20,8 @@ partial def exprMentionsConst (e : Expr m) (addr : Address) : Bool := | .proj _ _ s _ => exprMentionsConst s addr | _ => false -/-- Check strict positivity of a field type w.r.t. a set of inductive addresses. -/ -partial def checkStrictPositivity (ty : Expr m) (indAddrs : Array Address) : Bool := - if !indAddrs.any (exprMentionsConst ty ·) then true - else match ty with - | .forallE domain body _ _ => - if indAddrs.any (exprMentionsConst domain ·) then false - else checkStrictPositivity body indAddrs - | e => - let fn := e.getAppFn - match fn with - | .const addr _ _ => indAddrs.any (· == addr) - | _ => false - -/-- Walk a Pi chain, skip numParams binders, then check positivity of each field. -/ -partial def checkCtorPositivity (ctorType : Expr m) (numParams : Nat) (indAddrs : Array Address) - : Option String := - go ctorType numParams -where - go (ty : Expr m) (remainingParams : Nat) : Option String := - match ty with - | .forallE _domain body _name _bi => - if remainingParams > 0 then - go body (remainingParams - 1) - else - let domain := ty.bindingDomain! - if !checkStrictPositivity domain indAddrs then - some "inductive occurs in negative position (strict positivity violation)" - else - go body 0 - | _ => none +-- checkStrictPositivity and checkCtorPositivity are now monadic (inside the mutual block) +-- to allow calling whnf, matching lean4lean's checkPositivity. 
/-- Walk a Pi chain past numParams + numFields binders to get the return type. -/ def getCtorReturnType (ctorType : Expr m) (numParams numFields : Nat) : Expr m := @@ -389,6 +361,89 @@ mutual withExtendedCtx dom (getReturnSort body n) | _, _ => throw "inductive type has fewer binders than expected" + /-- Check that the fields of a nested inductive's constructor use the current + inductives only in positive positions. Walks past numParams binders of the + outer ctor type, substituting actual param args, then checks each field. -/ + partial def checkNestedCtorFields (ctorType : Expr m) (numParams : Nat) + (paramArgs : Array (Expr m)) (indAddrs : Array Address) : TypecheckM m Bool := do + -- Walk past param binders to get the field portion of the ctor type + let mut ty := ctorType + for _ in [:numParams] do + match ty with + | .forallE _ body _ _ => ty := body + | _ => return true + -- Substitute all param bvars: bvar 0 = last param, bvar (n-1) = first param + ty := ty.instantiate paramArgs.reverse + -- Check each field for positivity + loop ty + where + loop (ty : Expr m) : TypecheckM m Bool := do + let ty ← whnf ty + match ty with + | .forallE dom body _ _ => + if !(← checkPositivity dom indAddrs) then return false + loop body + | _ => return true + + /-- Check strict positivity of a field type w.r.t. a set of inductive addresses. + Handles direct recursion, negative-position rejection, and nested inductives + (where the inductive appears as a param of a previously-defined inductive). 
-/ + partial def checkPositivity (ty : Expr m) (indAddrs : Array Address) : TypecheckM m Bool := do + let ty ← whnf ty + if !indAddrs.any (exprMentionsConst ty ·) then return true + match ty with + | .forallE dom body _ _ => + if indAddrs.any (exprMentionsConst dom ·) then + return false + checkPositivity body indAddrs + | e => + let fn := e.getAppFn + match fn with + | .const addr _ _ => + if indAddrs.any (· == addr) then return true + -- Nested inductive: head is a previously-defined inductive + match (← read).kenv.find? addr with + | some (.inductInfo fv) => + if fv.isUnsafe then return false + let args := e.getAppArgs + -- Index args must not mention current inductives + for i in [fv.numParams:args.size] do + if indAddrs.any (exprMentionsConst args[i]! ·) then return false + -- Check all constructors of the outer inductive use params positively. + -- Augment indAddrs with the outer inductive's own addresses so that + -- its self-recursive fields (e.g., List α in List.cons) are accepted + -- immediately rather than causing infinite recursion. + let paramArgs := args[:fv.numParams].toArray + let augmented := indAddrs ++ fv.all + for ctorAddr in fv.ctors do + match (← read).kenv.find? ctorAddr with + | some (.ctorInfo cv) => + if !(← checkNestedCtorFields cv.type fv.numParams paramArgs augmented) then + return false + | _ => return false + return true + | _ => return false + | _ => return false + + /-- Walk a Pi chain, skip numParams binders, then check positivity of each field. + Monadic to call whnf, matching lean4lean. -/ + partial def checkCtorFields (ctorType : Expr m) (numParams : Nat) (indAddrs : Array Address) + : TypecheckM m (Option String) := + go ctorType numParams + where + go (ty : Expr m) (remainingParams : Nat) : TypecheckM m (Option String) := do + let ty ← whnf ty + match ty with + | .forallE _dom body _name _bi => + if remainingParams > 0 then + go body (remainingParams - 1) + else + let domain := ty.bindingDomain! 
+ if !(← checkPositivity domain indAddrs) then + return some "inductive occurs in negative position (strict positivity violation)" + go body 0 + | _ => return none + /-- Typecheck a mutual inductive block. -/ partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do let ci ← derefConst addr @@ -417,7 +472,7 @@ mutual if cv.numParams != iv.numParams then throw s!"Constructor {ctorAddr} has {cv.numParams} params but inductive has {iv.numParams}" if !iv.isUnsafe then - match checkCtorPositivity cv.type cv.numParams indAddrs with + match ← checkCtorFields cv.type cv.numParams indAddrs with | some msg => throw s!"Constructor {ctorAddr}: {msg}" | none => pure () if !iv.isUnsafe then diff --git a/Tests/Ix/Kernel/Helpers.lean b/Tests/Ix/Kernel/Helpers.lean new file mode 100644 index 00000000..6510abe8 --- /dev/null +++ b/Tests/Ix/Kernel/Helpers.lean @@ -0,0 +1,110 @@ +/- + Shared test utilities for kernel tests. + - Address helpers (mkAddr) + - Name parsing (parseIxName, leanNameToIx) + - Env-building helpers (addInductive, addCtor, addAxiom) + - Expect helpers (expectError, expectOk) +-/ +import Ix.Kernel + +open Ix.Kernel + +namespace Tests.Ix.Kernel.Helpers + +/-- Helper: make unique addresses from a seed byte. -/ +def mkAddr (seed : UInt8) : Address := + Address.blake3 (ByteArray.mk #[seed, 0xAA, 0xBB]) + +/-- Parse a dotted name string like "Nat.add" into an Ix.Name. + Handles `«...»` quoted name components (e.g. `Foo.«0».Bar`). -/ +partial def parseIxName (s : String) : Ix.Name := + let parts := splitParts s.toList [] + parts.foldl (fun acc part => + match part with + | .inl str => Ix.Name.mkStr acc str + | .inr nat => Ix.Name.mkNat acc nat + ) Ix.Name.mkAnon +where + /-- Split a dotted name into parts: .inl for string components, .inr for numeric (guillemet). -/ + splitParts : List Char → List (String ⊕ Nat) → List (String ⊕ Nat) + | [], acc => acc + | '.' 
:: rest, acc => splitParts rest acc + | '«' :: rest, acc => + let (inside, rest') := collectUntilClose rest "" + let part := match inside.toNat? with + | some n => .inr n + | none => .inl inside + splitParts rest' (acc ++ [part]) + | cs, acc => + let (word, rest) := collectUntilDot cs "" + splitParts rest (if word.isEmpty then acc else acc ++ [.inl word]) + collectUntilClose : List Char → String → String × List Char + | [], s => (s, []) + | '»' :: rest, s => (s, rest) + | c :: rest, s => collectUntilClose rest (s.push c) + collectUntilDot : List Char → String → String × List Char + | [], s => (s, []) + | '.' :: rest, s => (s, '.' :: rest) + | '«' :: rest, s => (s, '«' :: rest) + | c :: rest, s => collectUntilDot rest (s.push c) + +/-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/ +partial def leanNameToIx : Lean.Name → Ix.Name + | .anonymous => Ix.Name.mkAnon + | .str pre s => Ix.Name.mkStr (leanNameToIx pre) s + | .num pre n => Ix.Name.mkNat (leanNameToIx pre) n + +/-- Build an inductive and insert it into the env. -/ +def addInductive (env : Env .anon) (addr : Address) + (type : Expr .anon) (ctors : Array Address) + (numParams numIndices : Nat := 0) (isRec := false) + (isUnsafe := false) (numNested := 0) : Env .anon := + env.insert addr (.inductInfo { + toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, + numParams, numIndices, all := #[addr], ctors, numNested, + isRec, isUnsafe, isReflexive := false + }) + +/-- Build a constructor and insert it into the env. -/ +def addCtor (env : Env .anon) (addr : Address) (induct : Address) + (type : Expr .anon) (cidx numParams numFields : Nat) + (isUnsafe := false) : Env .anon := + env.insert addr (.ctorInfo { + toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, + induct, cidx, numParams, numFields, isUnsafe + }) + +/-- Build an axiom and insert it into the env. 
-/ +def addAxiom (env : Env .anon) (addr : Address) + (type : Expr .anon) (isUnsafe := false) : Env .anon := + env.insert addr (.axiomInfo { + toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, + isUnsafe + }) + +/-- Build a recursor and insert it into the env. -/ +def addRec (env : Env .anon) (addr : Address) + (numLevels : Nat) (type : Expr .anon) (all : Array Address) + (numParams numIndices numMotives numMinors : Nat) + (rules : Array (RecursorRule .anon)) + (k := false) (isUnsafe := false) : Env .anon := + env.insert addr (.recInfo { + toConstantVal := { numLevels, type, name := (), levelParams := () }, + all, numParams, numIndices, numMotives, numMinors, rules, k, isUnsafe + }) + +/-- Assert typecheckConst fails. Returns (passed_delta, failure_msg?). -/ +def expectError (env : Env .anon) (prims : Primitives) (addr : Address) + (label : String) : Bool × Option String := + match typecheckConst env prims addr with + | .error _ => (true, none) + | .ok () => (false, some s!"{label}: expected error") + +/-- Assert typecheckConst succeeds. Returns (passed_delta, failure_msg?). -/ +def expectOk (env : Env .anon) (prims : Primitives) (addr : Address) + (label : String) : Bool × Option String := + match typecheckConst env prims addr with + | .ok () => (true, none) + | .error e => (false, some s!"{label}: unexpected error: {e}") + +end Tests.Ix.Kernel.Helpers diff --git a/Tests/Ix/Kernel/Soundness.lean b/Tests/Ix/Kernel/Soundness.lean new file mode 100644 index 00000000..406bc840 --- /dev/null +++ b/Tests/Ix/Kernel/Soundness.lean @@ -0,0 +1,410 @@ +/- + Soundness negative tests: verify that the typechecker rejects unsound + inductive declarations (positivity, universe constraints, K-flag, recursor rules). + + Each test is an individual named function using shared helpers. +-/ +import Ix.Kernel +import Tests.Ix.Kernel.Helpers +import LSpec + +open LSpec +open Ix.Kernel +open Tests.Ix.Kernel.Helpers + +namespace Tests.Ix.Kernel.Soundness + +/-! 
## Shared Wrap inductive (reused across several positive-nesting tests) -/ + +/-- Insert Wrap : Sort 1 → Sort 1 and Wrap.mk into the env. -/ +private def addWrap (env : Env .anon) : Env .anon := + let wrapAddr := mkAddr 110 + let wrapMkAddr := mkAddr 111 + -- Wrap : Sort 1 → Sort 1 + let wrapType : Expr .anon := .forallE (.sort (.succ .zero)) (.sort (.succ .zero)) () () + let env := addInductive env wrapAddr wrapType #[wrapMkAddr] (numParams := 1) + -- Wrap.mk : ∀ (α : Sort 1), α → Wrap α + let wrapMkType : Expr .anon := + .forallE (.sort (.succ .zero)) + (.forallE (.bvar 0 ()) (.app (.const wrapAddr #[] ()) (.bvar 1 ())) () ()) + () () + addCtor env wrapMkAddr wrapAddr wrapMkType 0 1 1 + +private def wrapAddr := mkAddr 110 + +/-! ## Positivity tests -/ + +/-- Test 1: Positivity violation — Bad | mk : (Bad → Bad) → Bad -/ +def positivityViolation : TestSeq := + test "rejects (Bad → Bad) → Bad" ( + let badAddr := mkAddr 10 + let badMkAddr := mkAddr 11 + let env := addInductive default badAddr (.sort (.succ .zero)) #[badMkAddr] (isRec := true) + -- mk : (Bad → Bad) → Bad — Bad in negative position + let mkType : Expr .anon := + .forallE + (.forallE (.const badAddr #[] ()) (.const badAddr #[] ()) () ()) + (.const badAddr #[] ()) + () () + let env := addCtor env badMkAddr badAddr mkType 0 0 1 + (expectError env buildPrimitives badAddr "positivity").1 + ) + +/-- Test 11: Nested positive via Wrap (should PASS) — Tree | node : Wrap Tree → Tree -/ +def nestedWrapPositive : TestSeq := + test "accepts Wrap Tree → Tree" ( + let treeAddr := mkAddr 112 + let treeMkAddr := mkAddr 113 + let env := addWrap default + let env := addInductive env treeAddr (.sort (.succ .zero)) #[treeMkAddr] + (numNested := 1) (isRec := true) + -- Tree.node : Wrap Tree → Tree + let treeMkType : Expr .anon := + .forallE (.app (.const wrapAddr #[] ()) (.const treeAddr #[] ())) + (.const treeAddr #[] ()) () () + let env := addCtor env treeMkAddr treeAddr treeMkType 0 0 1 + (expectOk env buildPrimitives 
treeAddr "nested-wrap").1 + ) + +/-- Test 12: Double nesting (should PASS) — Forest | grove : Wrap (Wrap Forest) → Forest -/ +def doubleNestedPositive : TestSeq := + test "accepts Wrap (Wrap Forest) → Forest" ( + let forestAddr := mkAddr 114 + let forestMkAddr := mkAddr 115 + let env := addWrap default + let env := addInductive env forestAddr (.sort (.succ .zero)) #[forestMkAddr] + (numNested := 1) (isRec := true) + let forestMkType : Expr .anon := + .forallE + (.app (.const wrapAddr #[] ()) (.app (.const wrapAddr #[] ()) (.const forestAddr #[] ()))) + (.const forestAddr #[] ()) () () + let env := addCtor env forestMkAddr forestAddr forestMkType 0 0 1 + (expectOk env buildPrimitives forestAddr "double-nested").1 + ) + +/-- Test 13: Multi-field nested (should PASS) — Rose | node : Rose → Wrap Rose → Rose -/ +def multiFieldNestedPositive : TestSeq := + test "accepts Rose → Wrap Rose → Rose" ( + let roseAddr := mkAddr 116 + let roseMkAddr := mkAddr 117 + let env := addWrap default + let env := addInductive env roseAddr (.sort (.succ .zero)) #[roseMkAddr] + (numNested := 1) (isRec := true) + let roseMkType : Expr .anon := + .forallE (.const roseAddr #[] ()) + (.forallE (.app (.const wrapAddr #[] ()) (.const roseAddr #[] ())) + (.const roseAddr #[] ()) () ()) + () () + let env := addCtor env roseMkAddr roseAddr roseMkType 0 0 2 + (expectOk env buildPrimitives roseAddr "multi-field-nested").1 + ) + +/-- Test 14: Nested with multiple params — only one tainted (should PASS) + Pair α β | mk : α → β → Pair α β; U | star; MyInd | mk : Pair MyInd U → MyInd -/ +def multiParamNestedPositive : TestSeq := + test "accepts Pair MyInd U → MyInd" ( + let pairAddr := mkAddr 120 + let pairMkAddr := mkAddr 121 + let uAddr := mkAddr 122 + let uMkAddr := mkAddr 123 + let myAddr := mkAddr 124 + let myMkAddr := mkAddr 125 + -- Pair : Sort 1 → Sort 1 → Sort 1 + let pairType : Expr .anon := + .forallE (.sort (.succ .zero)) (.forallE (.sort (.succ .zero)) (.sort (.succ .zero)) () ()) () () + 
let env := addInductive default pairAddr pairType #[pairMkAddr] (numParams := 2) + -- Pair.mk : ∀ (α β : Sort 1), α → β → Pair α β + let pairMkType : Expr .anon := + .forallE (.sort (.succ .zero)) + (.forallE (.sort (.succ .zero)) + (.forallE (.bvar 1 ()) + (.forallE (.bvar 1 ()) + (.app (.app (.const pairAddr #[] ()) (.bvar 3 ())) (.bvar 2 ())) + () ()) + () ()) + () ()) + () () + let env := addCtor env pairMkAddr pairAddr pairMkType 0 2 2 + -- U : Sort 1 + let env := addInductive env uAddr (.sort (.succ .zero)) #[uMkAddr] + let env := addCtor env uMkAddr uAddr (.const uAddr #[] ()) 0 0 0 + -- MyInd : Sort 1 + let env := addInductive env myAddr (.sort (.succ .zero)) #[myMkAddr] + (numNested := 1) (isRec := true) + -- MyInd.mk : Pair MyInd U → MyInd + let myMkType : Expr .anon := + .forallE (.app (.app (.const pairAddr #[] ()) (.const myAddr #[] ())) (.const uAddr #[] ())) + (.const myAddr #[] ()) () () + let env := addCtor env myMkAddr myAddr myMkType 0 0 1 + (expectOk env buildPrimitives myAddr "multi-param-nested").1 + ) + +/-- Test 15: Negative via nested contravariant param (should FAIL) + Contra α | mk : (α → Prop) → Contra α; Bad | mk : Contra Bad → Bad -/ +def nestedContravariantFails : TestSeq := + test "rejects Contra Bad → Bad (α negative in Contra)" ( + let contraAddr := mkAddr 130 + let contraMkAddr := mkAddr 131 + let badAddr := mkAddr 132 + let badMkAddr := mkAddr 133 + -- Contra : Sort 1 → Sort 1 + let contraType : Expr .anon := .forallE (.sort (.succ .zero)) (.sort (.succ .zero)) () () + let env := addInductive default contraAddr contraType #[contraMkAddr] (numParams := 1) + -- Contra.mk : ∀ (α : Sort 1), (α → Prop) → Contra α + let contraMkType : Expr .anon := + .forallE (.sort (.succ .zero)) + (.forallE + (.forallE (.bvar 0 ()) (.sort .zero) () ()) + (.app (.const contraAddr #[] ()) (.bvar 1 ())) + () ()) + () () + let env := addCtor env contraMkAddr contraAddr contraMkType 0 1 1 + -- Bad : Sort 1 + let env := addInductive env badAddr (.sort 
(.succ .zero)) #[badMkAddr] (isRec := true) + let badMkType : Expr .anon := + .forallE (.app (.const contraAddr #[] ()) (.const badAddr #[] ())) + (.const badAddr #[] ()) () () + let env := addCtor env badMkAddr badAddr badMkType 0 0 1 + (expectError env buildPrimitives badAddr "nested-contravariant").1 + ) + +/-- Test 16: Inductive in index position (should FAIL) + PIdx : Prop → Prop (numParams=0, numIndices=1); PBad | mk : PIdx PBad → PBad -/ +def inductiveInIndexFails : TestSeq := + test "rejects PBad in index of PIdx" ( + let pidxAddr := mkAddr 140 + let pidxMkAddr := mkAddr 141 + let pbadAddr := mkAddr 142 + let pbadMkAddr := mkAddr 143 + -- PIdx : Prop → Prop + let pidxType : Expr .anon := .forallE (.sort .zero) (.sort .zero) () () + let env := addInductive default pidxAddr pidxType #[pidxMkAddr] (numIndices := 1) + -- PIdx.mk : ∀ (p : Prop), PIdx p + let pidxMkType : Expr .anon := + .forallE (.sort .zero) (.app (.const pidxAddr #[] ()) (.bvar 0 ())) () () + let env := addCtor env pidxMkAddr pidxAddr pidxMkType 0 0 1 + -- PBad : Prop + let env := addInductive env pbadAddr (.sort .zero) #[pbadMkAddr] (isRec := true) + let pbadMkType : Expr .anon := + .forallE (.app (.const pidxAddr #[] ()) (.const pbadAddr #[] ())) + (.const pbadAddr #[] ()) () () + let env := addCtor env pbadMkAddr pbadAddr pbadMkType 0 0 1 + (expectError env buildPrimitives pbadAddr "inductive-in-index").1 + ) + +/-- Test 17: Non-inductive head — axiom wrapping inductive (should FAIL) + axiom F : Sort 1 → Sort 1; Bad | mk : F Bad → Bad -/ +def nonInductiveHeadFails : TestSeq := + test "rejects F Bad → Bad (F is axiom)" ( + let fAddr := mkAddr 150 + let badAddr := mkAddr 152 + let badMkAddr := mkAddr 153 + -- F : Sort 1 → Sort 1 (axiom) + let fType : Expr .anon := .forallE (.sort (.succ .zero)) (.sort (.succ .zero)) () () + let env := addAxiom default fAddr fType + -- Bad : Sort 1 + let env := addInductive env badAddr (.sort (.succ .zero)) #[badMkAddr] (isRec := true) + let badMkType : Expr 
.anon := + .forallE (.app (.const fAddr #[] ()) (.const badAddr #[] ())) + (.const badAddr #[] ()) () () + let env := addCtor env badMkAddr badAddr badMkType 0 0 1 + (expectError env buildPrimitives badAddr "non-inductive-head").1 + ) + +/-- Test 18: Unsafe outer inductive — not trusted for nesting (should FAIL) + unsafe UWrap α | mk : (α → α) → UWrap α; Bad | mk : UWrap Bad → Bad -/ +def unsafeOuterFails : TestSeq := + test "rejects UWrap Bad → Bad (UWrap is unsafe)" ( + let uwAddr := mkAddr 160 + let uwMkAddr := mkAddr 161 + let badAddr := mkAddr 162 + let badMkAddr := mkAddr 163 + -- UWrap : Sort 1 → Sort 1 (unsafe) + let uwType : Expr .anon := .forallE (.sort (.succ .zero)) (.sort (.succ .zero)) () () + let env := addInductive default uwAddr uwType #[uwMkAddr] (numParams := 1) (isUnsafe := true) + -- UWrap.mk : ∀ (α : Sort 1), (α → α) → UWrap α (unsafe) + let uwMkType : Expr .anon := + .forallE (.sort (.succ .zero)) + (.forallE (.forallE (.bvar 0 ()) (.bvar 1 ()) () ()) + (.app (.const uwAddr #[] ()) (.bvar 1 ())) + () ()) + () () + let env := addCtor env uwMkAddr uwAddr uwMkType 0 1 1 (isUnsafe := true) + -- Bad : Sort 1 + let env := addInductive env badAddr (.sort (.succ .zero)) #[badMkAddr] (isRec := true) + let badMkType : Expr .anon := + .forallE (.app (.const uwAddr #[] ()) (.const badAddr #[] ())) + (.const badAddr #[] ()) () () + let env := addCtor env badMkAddr badAddr badMkType 0 0 1 + (expectError env buildPrimitives badAddr "unsafe-outer").1 + ) + +/-! 
## Universe constraints -/ + +/-- Test 2: Universe constraint violation — Sort 2 field in Sort 1 inductive -/ +def universeViolation : TestSeq := + test "rejects Sort 2 field in Sort 1 inductive" ( + let ubAddr := mkAddr 20 + let ubMkAddr := mkAddr 21 + let env := addInductive default ubAddr (.sort (.succ .zero)) #[ubMkAddr] + -- mk : Sort 2 → Uni1Bad — Sort 2 : Sort 3, but inductive is Sort 1 + let mkType : Expr .anon := + .forallE (.sort (.succ (.succ .zero))) (.const ubAddr #[] ()) () () + let env := addCtor env ubMkAddr ubAddr mkType 0 0 1 + (expectError env buildPrimitives ubAddr "universe-constraint").1 + ) + +/-! ## K-flag tests -/ + +/-- Test 3: K=true on non-Prop inductive (Sort 1, 2 ctors) -/ +def kFlagNotProp : TestSeq := + test "rejects K=true on Sort 1 inductive" ( + let indAddr := mkAddr 30 + let mk1Addr := mkAddr 31 + let mk2Addr := mkAddr 32 + let recAddr := mkAddr 33 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mk1Addr, mk2Addr] + let env := addCtor env mk1Addr indAddr (.const indAddr #[] ()) 0 0 0 + let env := addCtor env mk2Addr indAddr (.const indAddr #[] ()) 1 0 0 + let env := addRec env recAddr 1 (.sort (.param 0 ())) #[indAddr] 0 0 1 2 + #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } + ] (k := true) + (expectError env buildPrimitives recAddr "k-flag-not-prop").1 + ) + +/-- Test 8: K=true on Prop inductive with 2 ctors -/ +def kFlagTwoCtors : TestSeq := + test "rejects K=true with 2 ctors in Prop" ( + let indAddr := mkAddr 80 + let mk1Addr := mkAddr 81 + let mk2Addr := mkAddr 82 + let recAddr := mkAddr 83 + let env := addInductive default indAddr (.sort .zero) #[mk1Addr, mk2Addr] + let env := addCtor env mk1Addr indAddr (.const indAddr #[] ()) 0 0 0 + let env := addCtor env mk2Addr indAddr (.const indAddr #[] ()) 1 0 0 + let env := addRec env recAddr 0 (.sort .zero) #[indAddr] 0 0 1 2 + #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor 
:= mk2Addr, nfields := 0, rhs := .sort .zero } + ] (k := true) + (expectError env buildPrimitives recAddr "k-flag-two-ctors").1 + ) + +/-! ## Recursor tests -/ + +/-- Test 4: Recursor wrong rule count — 1 rule for 2-ctor inductive -/ +def recWrongRuleCount : TestSeq := + test "rejects 1 rule for 2-ctor inductive" ( + let indAddr := mkAddr 40 + let mk1Addr := mkAddr 41 + let mk2Addr := mkAddr 42 + let recAddr := mkAddr 43 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mk1Addr, mk2Addr] + let env := addCtor env mk1Addr indAddr (.const indAddr #[] ()) 0 0 0 + let env := addCtor env mk2Addr indAddr (.const indAddr #[] ()) 1 0 0 + let env := addRec env recAddr 1 (.sort (.param 0 ())) #[indAddr] 0 0 1 2 + #[{ ctor := mk1Addr, nfields := 0, rhs := .sort .zero }] -- only 1! + (expectError env buildPrimitives recAddr "rec-wrong-rule-count").1 + ) + +/-- Test 5: Recursor wrong nfields — ctor has 0 fields but rule claims 5 -/ +def recWrongNfields : TestSeq := + test "rejects nfields=5 for 0-field ctor" ( + let indAddr := mkAddr 50 + let mkAddr' := mkAddr 51 + let recAddr := mkAddr 52 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mkAddr'] + let env := addCtor env mkAddr' indAddr (.const indAddr #[] ()) 0 0 0 + let env := addRec env recAddr 1 (.sort (.param 0 ())) #[indAddr] 0 0 1 1 + #[{ ctor := mkAddr', nfields := 5, rhs := .sort .zero }] -- wrong nfields + (expectError env buildPrimitives recAddr "rec-wrong-nfields").1 + ) + +/-- Test 6: Recursor wrong num_params — rec claims 5 params, inductive has 0 -/ +def recWrongNumParams : TestSeq := + test "rejects numParams=5 for 0-param inductive" ( + let indAddr := mkAddr 60 + let mkAddr' := mkAddr 61 + let recAddr := mkAddr 62 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mkAddr'] + let env := addCtor env mkAddr' indAddr (.const indAddr #[] ()) 0 0 0 + let env := addRec env recAddr 1 (.sort (.param 0 ())) #[indAddr] + (numParams := 5) 0 1 1 -- wrong: inductive has 0 + #[{ 
ctor := mkAddr', nfields := 0, rhs := .sort .zero }] + (expectError env buildPrimitives recAddr "rec-wrong-num-params").1 + ) + +/-- Test 9: Recursor wrong ctor order — rules in wrong order -/ +def recWrongCtorOrder : TestSeq := + test "rejects wrong ctor order in rules" ( + let indAddr := mkAddr 90 + let mk1Addr := mkAddr 91 + let mk2Addr := mkAddr 92 + let recAddr := mkAddr 93 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mk1Addr, mk2Addr] + let env := addCtor env mk1Addr indAddr (.const indAddr #[] ()) 0 0 0 + let env := addCtor env mk2Addr indAddr (.const indAddr #[] ()) 1 0 0 + let env := addRec env recAddr 1 (.sort (.param 0 ())) #[indAddr] 0 0 1 2 + #[ + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero }, -- wrong order! + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero } + ] + (expectError env buildPrimitives recAddr "rec-wrong-ctor-order").1 + ) + +/-! ## Constructor validation -/ + +/-- Test 7: Constructor param count mismatch — ctor claims 3 params, ind has 0 -/ +def ctorParamMismatch : TestSeq := + test "rejects ctor with numParams=3 for 0-param inductive" ( + let indAddr := mkAddr 70 + let mkAddr' := mkAddr 71 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mkAddr'] + let env := addCtor env mkAddr' indAddr (.const indAddr #[] ()) 0 3 0 -- wrong: 3 params + (expectError env buildPrimitives indAddr "ctor-param-mismatch").1 + ) + +/-! ## Sanity -/ + +/-- Test 10: Valid single-ctor inductive passes -/ +def validSingleCtor : TestSeq := + test "accepts valid single-ctor inductive" ( + let indAddr := mkAddr 100 + let mkAddr' := mkAddr 101 + let env := addInductive default indAddr (.sort (.succ .zero)) #[mkAddr'] + let env := addCtor env mkAddr' indAddr (.const indAddr #[] ()) 0 0 0 + (expectOk env buildPrimitives indAddr "valid-inductive").1 + ) + +/-! 
## Suite -/ + +def suite : List TestSeq := [ + group "Positivity" + (positivityViolation ++ + nestedWrapPositive ++ + doubleNestedPositive ++ + multiFieldNestedPositive ++ + multiParamNestedPositive ++ + nestedContravariantFails ++ + inductiveInIndexFails ++ + nonInductiveHeadFails ++ + unsafeOuterFails), + group "Universe constraints" + universeViolation, + group "K-flag" + (kFlagNotProp ++ + kFlagTwoCtors), + group "Recursors" + (recWrongRuleCount ++ + recWrongNfields ++ + recWrongNumParams ++ + recWrongCtorOrder), + group "Constructor validation" + ctorParamMismatch, + group "Sanity" + validSingleCtor, +] + +end Tests.Ix.Kernel.Soundness diff --git a/Tests/Ix/Kernel/Unit.lean b/Tests/Ix/Kernel/Unit.lean new file mode 100644 index 00000000..3fc42f29 --- /dev/null +++ b/Tests/Ix/Kernel/Unit.lean @@ -0,0 +1,298 @@ +/- + Unit tests for kernel types: Expr equality, Expr operations, Level operations, + reducibility hints, and inductive helper functions. +-/ +import Ix.Kernel +import Tests.Ix.Kernel.Helpers +import LSpec + +open LSpec +open Ix.Kernel +open Tests.Ix.Kernel.Helpers + +namespace Tests.Ix.Kernel.Unit + +/-! 
## Expression equality -/ + +def testExprHashEq : TestSeq := + let bv0 : Expr .anon := Expr.mkBVar 0 + let bv0' : Expr .anon := Expr.mkBVar 0 + let bv1 : Expr .anon := Expr.mkBVar 1 + test "mkBVar 0 == mkBVar 0" (bv0 == bv0') $ + test "mkBVar 0 != mkBVar 1" (bv0 != bv1) $ + -- Sort equality + let s0 : Expr .anon := Expr.mkSort Level.zero + let s0' : Expr .anon := Expr.mkSort Level.zero + let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) + test "mkSort 0 == mkSort 0" (s0 == s0') $ + test "mkSort 0 != mkSort 1" (s0 != s1) $ + -- App equality + let app1 := Expr.mkApp bv0 bv1 + let app1' := Expr.mkApp bv0 bv1 + let app2 := Expr.mkApp bv1 bv0 + test "mkApp bv0 bv1 == mkApp bv0 bv1" (app1 == app1') $ + test "mkApp bv0 bv1 != mkApp bv1 bv0" (app1 != app2) $ + -- Lambda equality + let lam1 := Expr.mkLam s0 bv0 + let lam1' := Expr.mkLam s0 bv0 + let lam2 := Expr.mkLam s1 bv0 + test "mkLam s0 bv0 == mkLam s0 bv0" (lam1 == lam1') $ + test "mkLam s0 bv0 != mkLam s1 bv0" (lam1 != lam2) $ + -- Forall equality + let pi1 := Expr.mkForallE s0 s1 + let pi1' := Expr.mkForallE s0 s1 + test "mkForallE s0 s1 == mkForallE s0 s1" (pi1 == pi1') $ + -- Const equality + let addr1 := Address.blake3 (ByteArray.mk #[1]) + let addr2 := Address.blake3 (ByteArray.mk #[2]) + let c1 : Expr .anon := Expr.mkConst addr1 #[] + let c1' : Expr .anon := Expr.mkConst addr1 #[] + let c2 : Expr .anon := Expr.mkConst addr2 #[] + test "mkConst addr1 == mkConst addr1" (c1 == c1') $ + test "mkConst addr1 != mkConst addr2" (c1 != c2) $ + -- Const with levels + let c1l : Expr .anon := Expr.mkConst addr1 #[Level.zero] + let c1l' : Expr .anon := Expr.mkConst addr1 #[Level.zero] + let c1l2 : Expr .anon := Expr.mkConst addr1 #[Level.succ Level.zero] + test "mkConst addr1 [0] == mkConst addr1 [0]" (c1l == c1l') $ + test "mkConst addr1 [0] != mkConst addr1 [1]" (c1l != c1l2) $ + -- Literal equality + let nat0 : Expr .anon := Expr.mkLit (.natVal 0) + let nat0' : Expr .anon := Expr.mkLit (.natVal 0) + let nat1 : 
Expr .anon := Expr.mkLit (.natVal 1) + let str1 : Expr .anon := Expr.mkLit (.strVal "hello") + let str1' : Expr .anon := Expr.mkLit (.strVal "hello") + let str2 : Expr .anon := Expr.mkLit (.strVal "world") + test "lit nat 0 == lit nat 0" (nat0 == nat0') $ + test "lit nat 0 != lit nat 1" (nat0 != nat1) $ + test "lit str hello == lit str hello" (str1 == str1') $ + test "lit str hello != lit str world" (str1 != str2) + +/-! ## Expression operations -/ + +def testExprOps : TestSeq := + -- getAppFn / getAppArgs + let bv0 : Expr .anon := Expr.mkBVar 0 + let bv1 : Expr .anon := Expr.mkBVar 1 + let bv2 : Expr .anon := Expr.mkBVar 2 + let app := Expr.mkApp (Expr.mkApp bv0 bv1) bv2 + test "getAppFn (app (app bv0 bv1) bv2) == bv0" (app.getAppFn == bv0) $ + test "getAppNumArgs == 2" (app.getAppNumArgs == 2) $ + test "getAppArgs[0] == bv1" (app.getAppArgs[0]! == bv1) $ + test "getAppArgs[1] == bv2" (app.getAppArgs[1]! == bv2) $ + -- mkAppN round-trips + let rebuilt := Expr.mkAppN bv0 #[bv1, bv2] + test "mkAppN round-trips" (rebuilt == app) $ + -- Predicates + test "isApp" app.isApp $ + test "isSort" (Expr.mkSort (Level.zero : Level .anon)).isSort $ + test "isLambda" (Expr.mkLam bv0 bv1).isLambda $ + test "isForall" (Expr.mkForallE bv0 bv1).isForall $ + test "isLit" (Expr.mkLit (.natVal 42) : Expr .anon).isLit $ + test "isBVar" bv0.isBVar $ + test "isConst" (Expr.mkConst (m := .anon) default #[]).isConst $ + -- Accessors + let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) + test "sortLevel!" (s1.sortLevel! == Level.succ Level.zero) $ + test "bvarIdx!" (bv1.bvarIdx! == 1) + +/-! 
## Level operations -/ + +def testLevelOps : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- reduce + test "reduce zero" (Level.reduce l0 == l0) $ + test "reduce (succ zero)" (Level.reduce l1 == l1) $ + -- equalLevel + test "zero equiv zero" (Level.equalLevel l0 l0) $ + test "succ zero equiv succ zero" (Level.equalLevel l1 l1) $ + test "max a b equiv max b a" + (Level.equalLevel (Level.max p0 p1) (Level.max p1 p0)) $ + test "zero not equiv succ zero" (!Level.equalLevel l0 l1) $ + -- leq + test "zero <= zero" (Level.leq l0 l0 0) $ + test "succ zero <= zero + 1" (Level.leq l1 l0 1) $ + test "not (succ zero <= zero)" (!Level.leq l1 l0 0) $ + test "param 0 <= param 0" (Level.leq p0 p0 0) $ + test "succ (param 0) <= param 0 + 1" + (Level.leq (Level.succ p0) p0 1) $ + test "not (succ (param 0) <= param 0)" + (!Level.leq (Level.succ p0) p0 0) + +def testLevelReduceIMax : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- imax u 0 = 0 + test "imax u 0 = 0" (Level.reduceIMax p0 l0 == l0) $ + -- imax u (succ v) = max u (succ v) + test "imax u (succ v) = max u (succ v)" + (Level.equalLevel (Level.reduceIMax p0 l1) (Level.reduceMax p0 l1)) $ + -- imax u u = u (same param) + test "imax u u = u" (Level.reduceIMax p0 p0 == p0) $ + -- imax u v stays imax (different params) + test "imax u v stays imax" + (Level.reduceIMax p0 p1 == Level.imax p0 p1) $ + -- nested: imax u (imax v 0) — reduce inner first, then outer + let inner := Level.reduceIMax p1 l0 -- = 0 + test "imax u (imax v 0) = imax u 0 = 0" + (Level.reduceIMax p0 inner == l0) + +def testLevelReduceMax : TestSeq := + let l0 : Level .anon := Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 
1 default + -- max 0 u = u + test "max 0 u = u" (Level.reduceMax l0 p0 == p0) $ + -- max u 0 = u + test "max u 0 = u" (Level.reduceMax p0 l0 == p0) $ + -- max (succ u) (succ v) = succ (max u v) + test "max (succ u) (succ v) = succ (max u v)" + (Level.reduceMax (Level.succ p0) (Level.succ p1) + == Level.succ (Level.reduceMax p0 p1)) $ + -- max p0 p0 = p0 + test "max p0 p0 = p0" (Level.reduceMax p0 p0 == p0) + +def testLevelLeqComplex : TestSeq := + let l0 : Level .anon := Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- max u v <= max v u (symmetry) + test "max u v <= max v u" + (Level.leq (Level.max p0 p1) (Level.max p1 p0) 0) $ + -- u <= max u v + test "u <= max u v" + (Level.leq p0 (Level.max p0 p1) 0) $ + -- imax u (succ v) <= max u (succ v) — after reduce they're equal + let lhs := Level.reduce (Level.imax p0 (.succ p1)) + let rhs := Level.reduce (Level.max p0 (.succ p1)) + test "imax u (succ v) <= max u (succ v)" + (Level.leq lhs rhs 0) $ + -- imax u 0 <= 0 + test "imax u 0 <= 0" + (Level.leq (Level.reduce (.imax p0 l0)) l0 0) $ + -- not (succ (max u v) <= max u v) + test "not (succ (max u v) <= max u v)" + (!Level.leq (Level.succ (Level.max p0 p1)) (Level.max p0 p1) 0) $ + -- imax u u <= u + test "imax u u <= u" + (Level.leq (Level.reduce (Level.imax p0 p0)) p0 0) $ + -- imax 1 (imax 1 u) = u (nested imax decomposition) + let l1 : Level .anon := Level.succ Level.zero + let nested := Level.reduce (Level.imax l1 (Level.imax l1 p0)) + test "imax 1 (imax 1 u) <= u" + (Level.leq nested p0 0) $ + test "u <= imax 1 (imax 1 u)" + (Level.leq p0 nested 0) + +def testLevelInstBulkReduce : TestSeq := + let l0 : Level .anon := Level.zero + let l1 : Level .anon := Level.succ Level.zero + let p0 : Level .anon := Level.param 0 default + let p1 : Level .anon := Level.param 1 default + -- Basic: param 0 with [zero] = zero + test "param 0 with [zero] = zero" + (Level.instBulkReduce #[l0] p0 == l0) $ + -- Multi: 
param 1 with [zero, succ zero] = succ zero + test "param 1 with [zero, succ zero] = succ zero" + (Level.instBulkReduce #[l0, l1] p1 == l1) $ + -- Out-of-bounds: param 2 with 2-element array shifts + let p2 : Level .anon := Level.param 2 default + test "param 2 with 2-elem array shifts to param 0" + (Level.instBulkReduce #[l0, l1] p2 == Level.param 0 default) $ + -- Compound: imax (param 0) (param 1) with [zero, succ zero] + let compound := Level.imax p0 p1 + let result := Level.instBulkReduce #[l0, l1] compound + -- imax 0 (succ 0) = max 0 (succ 0) = succ 0 + test "imax (param 0) (param 1) subst [zero, succ zero]" + (Level.equalLevel result l1) + +/-! ## Reducibility hints -/ + +def testReducibilityHintsLt : TestSeq := + -- ordering: opaque < regular(n) < abbrev (abbrev unfolds first) + test "regular 1 < regular 2" (ReducibilityHints.lt' (.regular 1) (.regular 2)) $ + test "not (regular 2 < regular 1)" (!ReducibilityHints.lt' (.regular 2) (.regular 1)) $ + test "opaque < regular" (ReducibilityHints.lt' .opaque (.regular 5)) $ + test "opaque < abbrev" (ReducibilityHints.lt' .opaque .abbrev) $ + test "regular < abbrev" (ReducibilityHints.lt' (.regular 5) .abbrev) $ + test "not (regular < opaque)" (!ReducibilityHints.lt' (.regular 5) .opaque) $ + test "not (abbrev < regular)" (!ReducibilityHints.lt' .abbrev (.regular 5)) $ + test "not (abbrev < opaque)" (!ReducibilityHints.lt' .abbrev .opaque) $ + test "not (opaque < opaque)" (!ReducibilityHints.lt' .opaque .opaque) $ + test "not (regular 5 < regular 5)" (!ReducibilityHints.lt' (.regular 5) (.regular 5)) + +/-! 
## Inductive helper functions -/ + +def testHelperFunctions : TestSeq := + -- exprMentionsConst + let addr1 := mkAddr 200 + let addr2 := mkAddr 201 + let c1 : Expr .anon := .const addr1 #[] () + let c2 : Expr .anon := .const addr2 #[] () + test "exprMentionsConst: direct match" + (exprMentionsConst c1 addr1) $ + test "exprMentionsConst: no match" + (!exprMentionsConst c2 addr1) $ + test "exprMentionsConst: in app fn" + (exprMentionsConst (.app c1 c2) addr1) $ + test "exprMentionsConst: in app arg" + (exprMentionsConst (.app c2 c1) addr1) $ + test "exprMentionsConst: in forallE domain" + (exprMentionsConst (.forallE c1 c2 () () : Expr .anon) addr1) $ + test "exprMentionsConst: in forallE body" + (exprMentionsConst (.forallE c2 c1 () () : Expr .anon) addr1) $ + test "exprMentionsConst: in lam" + (exprMentionsConst (.lam c1 c2 () () : Expr .anon) addr1) $ + test "exprMentionsConst: absent in sort" + (!exprMentionsConst (.sort .zero : Expr .anon) addr1) $ + test "exprMentionsConst: absent in bvar" + (!exprMentionsConst (.bvar 0 () : Expr .anon) addr1) $ + -- getIndResultLevel + test "getIndResultLevel: sort zero" + (getIndResultLevel (.sort .zero : Expr .anon) == some .zero) $ + test "getIndResultLevel: sort (succ zero)" + (getIndResultLevel (.sort (.succ .zero) : Expr .anon) == some (.succ .zero)) $ + test "getIndResultLevel: forallE _ (sort zero)" + (getIndResultLevel (.forallE (.sort .zero) (.sort (.succ .zero)) () () : Expr .anon) == some (.succ .zero)) $ + test "getIndResultLevel: bvar (no sort)" + (getIndResultLevel (.bvar 0 () : Expr .anon) == none) $ + -- levelIsNonZero + test "levelIsNonZero: zero is false" + (!levelIsNonZero (.zero : Level .anon)) $ + test "levelIsNonZero: succ zero is true" + (levelIsNonZero (.succ .zero : Level .anon)) $ + test "levelIsNonZero: param is false" + (!levelIsNonZero (.param 0 () : Level .anon)) $ + test "levelIsNonZero: max(succ 0, param) is true" + (levelIsNonZero (.max (.succ .zero) (.param 0 ()) : Level .anon)) $ + test 
"levelIsNonZero: imax(param, succ 0) is true" + (levelIsNonZero (.imax (.param 0 ()) (.succ .zero) : Level .anon)) $ + test "levelIsNonZero: imax(succ, param) depends on second" + (!levelIsNonZero (.imax (.succ .zero) (.param 0 ()) : Level .anon)) $ + -- getCtorReturnType + test "getCtorReturnType: no binders returns expr" + (getCtorReturnType c1 0 0 == c1) $ + test "getCtorReturnType: skips foralls" + (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) + +/-! ## Suite -/ + +def suite : List TestSeq := [ + group "Expr equality" testExprHashEq, + group "Expr operations" testExprOps, + group "Level operations" $ + testLevelOps ++ + group "imax reduction" testLevelReduceIMax ++ + group "max reduction" testLevelReduceMax ++ + group "complex leq" testLevelLeqComplex ++ + group "bulk instantiation" testLevelInstBulkReduce, + group "Reducibility hints" testReducibilityHintsLt, + group "Inductive helpers" testHelperFunctions, +] + +end Tests.Ix.Kernel.Unit diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index d3c9f7ab..360e6a14 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -1,175 +1,27 @@ /- Kernel test suite. - - Unit tests for Kernel types, expression operations, and level operations - - Convert tests (Ixon.Env → Kernel.Env) - - Targeted constant-checking tests (individual constants through the full pipeline) + - Integration tests (convertEnv, const checks, roundtrip) + - Negative tests (malformed declarations) + - Re-exports unit and soundness suites from submodules -/ import Ix.Kernel import Ix.Kernel.DecompileM import Ix.CompileM import Ix.Common import Ix.Meta +import Tests.Ix.Kernel.Helpers +import Tests.Ix.Kernel.Unit +import Tests.Ix.Kernel.Soundness import LSpec open LSpec open Ix.Kernel +open Tests.Ix.Kernel.Helpers namespace Tests.KernelTests -/-! 
## Unit tests: Expression equality -/ - -def testExprHashEq : TestSeq := - let bv0 : Expr .anon := Expr.mkBVar 0 - let bv0' : Expr .anon := Expr.mkBVar 0 - let bv1 : Expr .anon := Expr.mkBVar 1 - test "mkBVar 0 == mkBVar 0" (bv0 == bv0') ++ - test "mkBVar 0 != mkBVar 1" (bv0 != bv1) ++ - -- Sort equality - let s0 : Expr .anon := Expr.mkSort Level.zero - let s0' : Expr .anon := Expr.mkSort Level.zero - let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) - test "mkSort 0 == mkSort 0" (s0 == s0') ++ - test "mkSort 0 != mkSort 1" (s0 != s1) ++ - -- App equality - let app1 := Expr.mkApp bv0 bv1 - let app1' := Expr.mkApp bv0 bv1 - let app2 := Expr.mkApp bv1 bv0 - test "mkApp bv0 bv1 == mkApp bv0 bv1" (app1 == app1') ++ - test "mkApp bv0 bv1 != mkApp bv1 bv0" (app1 != app2) ++ - -- Lambda equality - let lam1 := Expr.mkLam s0 bv0 - let lam1' := Expr.mkLam s0 bv0 - let lam2 := Expr.mkLam s1 bv0 - test "mkLam s0 bv0 == mkLam s0 bv0" (lam1 == lam1') ++ - test "mkLam s0 bv0 != mkLam s1 bv0" (lam1 != lam2) ++ - -- Forall equality - let pi1 := Expr.mkForallE s0 s1 - let pi1' := Expr.mkForallE s0 s1 - test "mkForallE s0 s1 == mkForallE s0 s1" (pi1 == pi1') ++ - -- Const equality - let addr1 := Address.blake3 (ByteArray.mk #[1]) - let addr2 := Address.blake3 (ByteArray.mk #[2]) - let c1 : Expr .anon := Expr.mkConst addr1 #[] - let c1' : Expr .anon := Expr.mkConst addr1 #[] - let c2 : Expr .anon := Expr.mkConst addr2 #[] - test "mkConst addr1 == mkConst addr1" (c1 == c1') ++ - test "mkConst addr1 != mkConst addr2" (c1 != c2) ++ - -- Const with levels - let c1l : Expr .anon := Expr.mkConst addr1 #[Level.zero] - let c1l' : Expr .anon := Expr.mkConst addr1 #[Level.zero] - let c1l2 : Expr .anon := Expr.mkConst addr1 #[Level.succ Level.zero] - test "mkConst addr1 [0] == mkConst addr1 [0]" (c1l == c1l') ++ - test "mkConst addr1 [0] != mkConst addr1 [1]" (c1l != c1l2) ++ - -- Literal equality - let nat0 : Expr .anon := Expr.mkLit (.natVal 0) - let nat0' : Expr .anon := Expr.mkLit 
(.natVal 0) - let nat1 : Expr .anon := Expr.mkLit (.natVal 1) - let str1 : Expr .anon := Expr.mkLit (.strVal "hello") - let str1' : Expr .anon := Expr.mkLit (.strVal "hello") - let str2 : Expr .anon := Expr.mkLit (.strVal "world") - test "lit nat 0 == lit nat 0" (nat0 == nat0') ++ - test "lit nat 0 != lit nat 1" (nat0 != nat1) ++ - test "lit str hello == lit str hello" (str1 == str1') ++ - test "lit str hello != lit str world" (str1 != str2) ++ - .done - -/-! ## Unit tests: Expression operations -/ - -def testExprOps : TestSeq := - -- getAppFn / getAppArgs - let bv0 : Expr .anon := Expr.mkBVar 0 - let bv1 : Expr .anon := Expr.mkBVar 1 - let bv2 : Expr .anon := Expr.mkBVar 2 - let app := Expr.mkApp (Expr.mkApp bv0 bv1) bv2 - test "getAppFn (app (app bv0 bv1) bv2) == bv0" (app.getAppFn == bv0) ++ - test "getAppNumArgs == 2" (app.getAppNumArgs == 2) ++ - test "getAppArgs[0] == bv1" (app.getAppArgs[0]! == bv1) ++ - test "getAppArgs[1] == bv2" (app.getAppArgs[1]! == bv2) ++ - -- mkAppN round-trips - let rebuilt := Expr.mkAppN bv0 #[bv1, bv2] - test "mkAppN round-trips" (rebuilt == app) ++ - -- Predicates - test "isApp" app.isApp ++ - test "isSort" (Expr.mkSort (Level.zero : Level .anon)).isSort ++ - test "isLambda" (Expr.mkLam bv0 bv1).isLambda ++ - test "isForall" (Expr.mkForallE bv0 bv1).isForall ++ - test "isLit" (Expr.mkLit (.natVal 42) : Expr .anon).isLit ++ - test "isBVar" bv0.isBVar ++ - test "isConst" (Expr.mkConst (m := .anon) default #[]).isConst ++ - -- Accessors - let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero) - test "sortLevel!" (s1.sortLevel! == Level.succ Level.zero) ++ - test "bvarIdx!" (bv1.bvarIdx! == 1) ++ - .done - -/-! 
## Unit tests: Level operations -/ - -def testLevelOps : TestSeq := - let l0 : Level .anon := Level.zero - let l1 : Level .anon := Level.succ Level.zero - let p0 : Level .anon := Level.param 0 default - let p1 : Level .anon := Level.param 1 default - -- reduce - test "reduce zero" (Level.reduce l0 == l0) ++ - test "reduce (succ zero)" (Level.reduce l1 == l1) ++ - -- equalLevel - test "zero equiv zero" (Level.equalLevel l0 l0) ++ - test "succ zero equiv succ zero" (Level.equalLevel l1 l1) ++ - test "max a b equiv max b a" - (Level.equalLevel (Level.max p0 p1) (Level.max p1 p0)) ++ - test "zero not equiv succ zero" (!Level.equalLevel l0 l1) ++ - -- leq - test "zero <= zero" (Level.leq l0 l0 0) ++ - test "succ zero <= zero + 1" (Level.leq l1 l0 1) ++ - test "not (succ zero <= zero)" (!Level.leq l1 l0 0) ++ - test "param 0 <= param 0" (Level.leq p0 p0 0) ++ - test "succ (param 0) <= param 0 + 1" - (Level.leq (Level.succ p0) p0 1) ++ - test "not (succ (param 0) <= param 0)" - (!Level.leq (Level.succ p0) p0 0) ++ - .done - /-! ## Integration tests: Const pipeline -/ -/-- Parse a dotted name string like "Nat.add" into an Ix.Name. - Handles `«...»` quoted name components (e.g. `Foo.«0».Bar`). -/ -private partial def parseIxName (s : String) : Ix.Name := - let parts := splitParts s.toList [] - parts.foldl (fun acc part => - match part with - | .inl str => Ix.Name.mkStr acc str - | .inr nat => Ix.Name.mkNat acc nat - ) Ix.Name.mkAnon -where - /-- Split a dotted name into parts: .inl for string components, .inr for numeric (guillemet). -/ - splitParts : List Char → List (String ⊕ Nat) → List (String ⊕ Nat) - | [], acc => acc - | '.' :: rest, acc => splitParts rest acc - | '«' :: rest, acc => - let (inside, rest') := collectUntilClose rest "" - let part := match inside.toNat? 
with - | some n => .inr n - | none => .inl inside - splitParts rest' (acc ++ [part]) - | cs, acc => - let (word, rest) := collectUntilDot cs "" - splitParts rest (if word.isEmpty then acc else acc ++ [.inl word]) - collectUntilClose : List Char → String → String × List Char - | [], s => (s, []) - | '»' :: rest, s => (s, rest) - | c :: rest, s => collectUntilClose rest (s.push c) - collectUntilDot : List Char → String → String × List Char - | [], s => (s, []) - | '.' :: rest, s => (s, '.' :: rest) - | '«' :: rest, s => (s, '«' :: rest) - | c :: rest, s => collectUntilDot rest (s.push c) - -/-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/ -private partial def leanNameToIx : Lean.Name → Ix.Name - | .anonymous => Ix.Name.mkAnon - | .str pre s => Ix.Name.mkStr (leanNameToIx pre) s - | .num pre n => Ix.Name.mkNat (leanNameToIx pre) n - def testConvertEnv : TestSeq := .individualIO "rsCompileEnv + convertEnv" (do let leanEnv ← get_env! @@ -289,7 +141,9 @@ def testConsts : TestSeq := -- BVDecide regression test (fuel-sensitive) "Std.Tactic.BVDecide.BVExpr.bitblast.blastUdiv.instLawfulVecOperatorShiftConcatInputBlastShiftConcat", -- Theorem with sub-term type mismatch (requires inferOnly) - "Std.Do.Spec.tryCatch_ExceptT" + "Std.Do.Spec.tryCatch_ExceptT", + -- Nested inductive positivity check (requires whnf) + "Lean.Elab.Term.Do.Code.action" ] let mut passed := 0 let mut failures : Array String := #[] @@ -396,113 +250,6 @@ def testDumpPrimAddrs : TestSeq := return (true, none) ) .done -/-! 
## Unit tests: Level reduce/imax edge cases -/ - -def testLevelReduceIMax : TestSeq := - let l0 : Level .anon := Level.zero - let l1 : Level .anon := Level.succ Level.zero - let p0 : Level .anon := Level.param 0 default - let p1 : Level .anon := Level.param 1 default - -- imax u 0 = 0 - test "imax u 0 = 0" (Level.reduceIMax p0 l0 == l0) ++ - -- imax u (succ v) = max u (succ v) - test "imax u (succ v) = max u (succ v)" - (Level.equalLevel (Level.reduceIMax p0 l1) (Level.reduceMax p0 l1)) ++ - -- imax u u = u (same param) - test "imax u u = u" (Level.reduceIMax p0 p0 == p0) ++ - -- imax u v stays imax (different params) - test "imax u v stays imax" - (Level.reduceIMax p0 p1 == Level.imax p0 p1) ++ - -- nested: imax u (imax v 0) — reduce inner first, then outer - let inner := Level.reduceIMax p1 l0 -- = 0 - test "imax u (imax v 0) = imax u 0 = 0" - (Level.reduceIMax p0 inner == l0) ++ - .done - -def testLevelReduceMax : TestSeq := - let l0 : Level .anon := Level.zero - let p0 : Level .anon := Level.param 0 default - let p1 : Level .anon := Level.param 1 default - -- max 0 u = u - test "max 0 u = u" (Level.reduceMax l0 p0 == p0) ++ - -- max u 0 = u - test "max u 0 = u" (Level.reduceMax p0 l0 == p0) ++ - -- max (succ u) (succ v) = succ (max u v) - test "max (succ u) (succ v) = succ (max u v)" - (Level.reduceMax (Level.succ p0) (Level.succ p1) - == Level.succ (Level.reduceMax p0 p1)) ++ - -- max p0 p0 = p0 - test "max p0 p0 = p0" (Level.reduceMax p0 p0 == p0) ++ - .done - -def testLevelLeqComplex : TestSeq := - let l0 : Level .anon := Level.zero - let p0 : Level .anon := Level.param 0 default - let p1 : Level .anon := Level.param 1 default - -- max u v <= max v u (symmetry) - test "max u v <= max v u" - (Level.leq (Level.max p0 p1) (Level.max p1 p0) 0) ++ - -- u <= max u v - test "u <= max u v" - (Level.leq p0 (Level.max p0 p1) 0) ++ - -- imax u (succ v) <= max u (succ v) — after reduce they're equal - let lhs := Level.reduce (Level.imax p0 (.succ p1)) - let rhs := 
Level.reduce (Level.max p0 (.succ p1)) - test "imax u (succ v) <= max u (succ v)" - (Level.leq lhs rhs 0) ++ - -- imax u 0 <= 0 - test "imax u 0 <= 0" - (Level.leq (Level.reduce (.imax p0 l0)) l0 0) ++ - -- not (succ (max u v) <= max u v) - test "not (succ (max u v) <= max u v)" - (!Level.leq (Level.succ (Level.max p0 p1)) (Level.max p0 p1) 0) ++ - -- imax u u <= u - test "imax u u <= u" - (Level.leq (Level.reduce (Level.imax p0 p0)) p0 0) ++ - -- imax 1 (imax 1 u) = u (nested imax decomposition) - let l1 : Level .anon := Level.succ Level.zero - let nested := Level.reduce (Level.imax l1 (Level.imax l1 p0)) - test "imax 1 (imax 1 u) <= u" - (Level.leq nested p0 0) ++ - test "u <= imax 1 (imax 1 u)" - (Level.leq p0 nested 0) ++ - .done - -def testLevelInstBulkReduce : TestSeq := - let l0 : Level .anon := Level.zero - let l1 : Level .anon := Level.succ Level.zero - let p0 : Level .anon := Level.param 0 default - let p1 : Level .anon := Level.param 1 default - -- Basic: param 0 with [zero] = zero - test "param 0 with [zero] = zero" - (Level.instBulkReduce #[l0] p0 == l0) ++ - -- Multi: param 1 with [zero, succ zero] = succ zero - test "param 1 with [zero, succ zero] = succ zero" - (Level.instBulkReduce #[l0, l1] p1 == l1) ++ - -- Out-of-bounds: param 2 with 2-element array shifts - let p2 : Level .anon := Level.param 2 default - test "param 2 with 2-elem array shifts to param 0" - (Level.instBulkReduce #[l0, l1] p2 == Level.param 0 default) ++ - -- Compound: imax (param 0) (param 1) with [zero, succ zero] - let compound := Level.imax p0 p1 - let result := Level.instBulkReduce #[l0, l1] compound - -- imax 0 (succ 0) = max 0 (succ 0) = succ 0 - test "imax (param 0) (param 1) subst [zero, succ zero]" - (Level.equalLevel result l1) ++ - .done - -def testReducibilityHintsLt : TestSeq := - test "regular 1 < regular 2" (ReducibilityHints.lt' (.regular 1) (.regular 2)) ++ - test "not (regular 2 < regular 1)" (!ReducibilityHints.lt' (.regular 2) (.regular 1)) ++ - test "regular 
_ < opaque" (ReducibilityHints.lt' (.regular 5) .opaque) ++ - test "abbrev < opaque" (ReducibilityHints.lt' .abbrev .opaque) ++ - test "not (opaque < opaque)" (!ReducibilityHints.lt' .opaque .opaque) ++ - test "not (regular 5 < regular 5)" (!ReducibilityHints.lt' (.regular 5) (.regular 5)) ++ - .done - -/-! ## Expanded integration tests -/ - - /-! ## Anon mode conversion test -/ /-- Test that convertEnv in .anon mode produces the same number of constants. -/ @@ -604,7 +351,7 @@ def negativeTests : TestSeq := let cv : ConstantVal .anon := { numLevels := 0, type := .sort (.succ (.succ .zero)), name := (), levelParams := () } let lam : Expr .anon := .lam (.sort .zero) (.bvar 0 ()) () () - let ci : ConstantInfo .anon := .defnInfo + let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .app lam (.sort (.succ .zero)), hints := .opaque, safety := .safe, all := #[] } let env := (default : Env .anon).insert testAddr ci match typecheckConst env prims testAddr with @@ -629,492 +376,6 @@ def negativeTests : TestSeq := return (false, some s!"{failures.size} failure(s)") ) .done -/-! ## Soundness negative tests (inductive validation) -/ - -/-- Helper: make unique addresses from a seed byte. -/ -private def mkAddr (seed : UInt8) : Address := - Address.blake3 (ByteArray.mk #[seed, 0xAA, 0xBB]) - -/-- Soundness negative test suite: verify that the typechecker rejects unsound - inductive declarations (positivity, universe constraints, K-flag, recursor rules). -/ -def soundnessNegativeTests : TestSeq := - .individualIO "kernel soundness negative tests" (do - let prims := buildPrimitives - let mut passed := 0 - let mut failures : Array String := #[] - - -- ======================================================================== - -- Test 1: Positivity violation — Bad | mk : (Bad → Bad) → Bad - -- The inductive appears in negative position (Pi domain). 
- -- ======================================================================== - do - let badAddr := mkAddr 10 - let badMkAddr := mkAddr 11 - let badType : Expr .anon := .sort (.succ .zero) -- Sort 1 - let badCv : ConstantVal .anon := - { numLevels := 0, type := badType, name := (), levelParams := () } - let badInd : ConstantInfo .anon := .inductInfo { - toConstantVal := badCv, numParams := 0, numIndices := 0, - all := #[badAddr], ctors := #[badMkAddr], numNested := 0, - isRec := true, isUnsafe := false, isReflexive := false - } - -- mk : (Bad → Bad) → Bad - -- The domain (Bad → Bad) has Bad in negative position - let mkType : Expr .anon := - .forallE - (.forallE (.const badAddr #[] ()) (.const badAddr #[] ()) () ()) - (.const badAddr #[] ()) - () () - let mkCv : ConstantVal .anon := - { numLevels := 0, type := mkType, name := (), levelParams := () } - let mkCtor : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := badAddr, cidx := 0, - numParams := 0, numFields := 1, isUnsafe := false - } - let env := ((default : Env .anon).insert badAddr badInd).insert badMkAddr mkCtor - match typecheckConst env prims badAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "positivity-violation: expected error (Bad → Bad in domain)" - - -- ======================================================================== - -- Test 2: Universe constraint violation — Uni1Bad : Sort 1 | mk : Sort 2 → Uni1Bad - -- Field lives in Sort 3 (Sort 2 : Sort 3) but inductive is in Sort 1. - -- (Note: Prop inductives have special exception allowing any field universe, - -- so we test with a Sort 1 inductive instead.) 
- -- ======================================================================== - do - let ubAddr := mkAddr 20 - let ubMkAddr := mkAddr 21 - let ubType : Expr .anon := .sort (.succ .zero) -- Sort 1 - let ubCv : ConstantVal .anon := - { numLevels := 0, type := ubType, name := (), levelParams := () } - let ubInd : ConstantInfo .anon := .inductInfo { - toConstantVal := ubCv, numParams := 0, numIndices := 0, - all := #[ubAddr], ctors := #[ubMkAddr], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - -- mk : Sort 2 → Uni1Bad - -- Sort 2 : Sort 3, so field sort = 3. Inductive sort = 1. 3 ≤ 1 fails. - let mkType : Expr .anon := - .forallE (.sort (.succ (.succ .zero))) (.const ubAddr #[] ()) () () - let mkCv : ConstantVal .anon := - { numLevels := 0, type := mkType, name := (), levelParams := () } - let mkCtor : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := ubAddr, cidx := 0, - numParams := 0, numFields := 1, isUnsafe := false - } - let env := ((default : Env .anon).insert ubAddr ubInd).insert ubMkAddr mkCtor - match typecheckConst env prims ubAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "universe-constraint: expected error (Sort 2 field in Sort 1 inductive)" - - -- ======================================================================== - -- Test 3: K-flag invalid — K=true on non-Prop inductive (Sort 1, 2 ctors) - -- ======================================================================== - do - let indAddr := mkAddr 30 - let mk1Addr := mkAddr 31 - let mk2Addr := mkAddr 32 - let recAddr := mkAddr 33 - let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 (not Prop) - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, - isRec := false, isUnsafe := false, 
isReflexive := false - } - let mk1Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk1CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk1Cv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let mk2Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk2CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk2Cv, induct := indAddr, cidx := 1, - numParams := 0, numFields := 0, isUnsafe := false - } - -- Recursor with k=true on a non-Prop inductive - let recCv : ConstantVal .anon := - { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, - rules := #[ - { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, - { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } - ], - k := true, -- INVALID: not Prop - isUnsafe := false - } - let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "k-flag-not-prop: expected error" - - -- ======================================================================== - -- Test 4: Recursor wrong rule count — 1 rule for 2-ctor inductive - -- ======================================================================== - do - let indAddr := mkAddr 40 - let mk1Addr := mkAddr 41 - let mk2Addr := mkAddr 42 - let recAddr := mkAddr 43 - let indType : Expr .anon := .sort (.succ .zero) - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := 
#[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mk1Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk1CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk1Cv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let mk2Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk2CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk2Cv, induct := indAddr, cidx := 1, - numParams := 0, numFields := 0, isUnsafe := false - } - -- Recursor with only 1 rule (should be 2) - let recCv : ConstantVal .anon := - { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, - rules := #[{ ctor := mk1Addr, nfields := 0, rhs := .sort .zero }], -- only 1! 
- k := false, isUnsafe := false - } - let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "rec-wrong-rule-count: expected error" - - -- ======================================================================== - -- Test 5: Recursor wrong nfields — ctor has 0 fields but rule claims 5 - -- ======================================================================== - do - let indAddr := mkAddr 50 - let mkAddr' := mkAddr 51 - let recAddr := mkAddr 52 - let indType : Expr .anon := .sort (.succ .zero) - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mkAddr'], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mkCv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mkCI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let recCv : ConstantVal .anon := - { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 0, numIndices := 0, numMotives := 1, numMinors := 1, - rules := #[{ ctor := mkAddr', nfields := 5, rhs := .sort .zero }], -- wrong nfields - k := false, isUnsafe := false - } - let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "rec-wrong-nfields: expected error" - - -- 
======================================================================== - -- Test 6: Recursor wrong num_params — rec claims 5 params, inductive has 0 - -- ======================================================================== - do - let indAddr := mkAddr 60 - let mkAddr' := mkAddr 61 - let recAddr := mkAddr 62 - let indType : Expr .anon := .sort (.succ .zero) - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mkAddr'], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mkCv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mkCI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let recCv : ConstantVal .anon := - { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 5, -- wrong: inductive has 0 - numIndices := 0, numMotives := 1, numMinors := 1, - rules := #[{ ctor := mkAddr', nfields := 0, rhs := .sort .zero }], - k := false, isUnsafe := false - } - let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "rec-wrong-num-params: expected error" - - -- ======================================================================== - -- Test 7: Constructor param count mismatch — ctor claims 3 params, ind has 0 - -- ======================================================================== - do - let indAddr := mkAddr 70 - let mkAddr' := mkAddr 71 - let indType : Expr .anon := .sort (.succ .zero) - let 
indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mkAddr'], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mkCv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mkCI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := indAddr, cidx := 0, - numParams := 3, -- wrong: inductive has 0 - numFields := 0, isUnsafe := false - } - let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI - match typecheckConst env prims indAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "ctor-param-mismatch: expected error" - - -- ======================================================================== - -- Test 8: K-flag invalid — K=true on Prop inductive with 2 ctors - -- ======================================================================== - do - let indAddr := mkAddr 80 - let mk1Addr := mkAddr 81 - let mk2Addr := mkAddr 82 - let recAddr := mkAddr 83 - let indType : Expr .anon := .sort .zero -- Prop - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mk1Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk1CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk1Cv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let mk2Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - 
let mk2CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk2Cv, induct := indAddr, cidx := 1, - numParams := 0, numFields := 0, isUnsafe := false - } - let recCv : ConstantVal .anon := - { numLevels := 0, type := .sort .zero, name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, - rules := #[ - { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, - { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } - ], - k := true, -- INVALID: 2 ctors - isUnsafe := false - } - let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "k-flag-two-ctors: expected error" - - -- ======================================================================== - -- Test 9: Recursor wrong ctor order — rules in wrong order - -- ======================================================================== - do - let indAddr := mkAddr 90 - let mk1Addr := mkAddr 91 - let mk2Addr := mkAddr 92 - let recAddr := mkAddr 93 - let indType : Expr .anon := .sort (.succ .zero) - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mk1Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mk1CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk1Cv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let mk2Cv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), 
levelParams := () } - let mk2CI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mk2Cv, induct := indAddr, cidx := 1, - numParams := 0, numFields := 0, isUnsafe := false - } - let recCv : ConstantVal .anon := - { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } - let recCI : ConstantInfo .anon := .recInfo { - toConstantVal := recCv, all := #[indAddr], - numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, - rules := #[ - { ctor := mk2Addr, nfields := 0, rhs := .sort .zero }, -- wrong order! - { ctor := mk1Addr, nfields := 0, rhs := .sort .zero } - ], - k := false, isUnsafe := false - } - let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI - match typecheckConst env prims recAddr with - | .error _ => passed := passed + 1 - | .ok () => failures := failures.push "rec-wrong-ctor-order: expected error" - - -- ======================================================================== - -- Test 10: Valid single-ctor inductive passes (sanity check) - -- ======================================================================== - do - let indAddr := mkAddr 100 - let mkAddr' := mkAddr 101 - let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 - let indCv : ConstantVal .anon := - { numLevels := 0, type := indType, name := (), levelParams := () } - let indCI : ConstantInfo .anon := .inductInfo { - toConstantVal := indCv, numParams := 0, numIndices := 0, - all := #[indAddr], ctors := #[mkAddr'], numNested := 0, - isRec := false, isUnsafe := false, isReflexive := false - } - let mkCv : ConstantVal .anon := - { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } - let mkCI : ConstantInfo .anon := .ctorInfo { - toConstantVal := mkCv, induct := indAddr, cidx := 0, - numParams := 0, numFields := 0, isUnsafe := false - } - let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI - match typecheckConst env prims indAddr 
with - | .ok () => passed := passed + 1 - | .error e => failures := failures.push s!"valid-inductive: unexpected error: {e}" - - let totalTests := 10 - IO.println s!"[kernel-soundness] {passed}/{totalTests} passed" - if failures.isEmpty then - return (true, none) - else - for f in failures do IO.println s!" [fail] {f}" - return (false, some s!"{failures.size} failure(s)") - ) .done - -/-! ## Unit tests: helper functions -/ - -def testHelperFunctions : TestSeq := - -- exprMentionsConst - let addr1 := mkAddr 200 - let addr2 := mkAddr 201 - let c1 : Expr .anon := .const addr1 #[] () - let c2 : Expr .anon := .const addr2 #[] () - test "exprMentionsConst: direct match" - (exprMentionsConst c1 addr1) ++ - test "exprMentionsConst: no match" - (!exprMentionsConst c2 addr1) ++ - test "exprMentionsConst: in app fn" - (exprMentionsConst (.app c1 c2) addr1) ++ - test "exprMentionsConst: in app arg" - (exprMentionsConst (.app c2 c1) addr1) ++ - test "exprMentionsConst: in forallE domain" - (exprMentionsConst (.forallE c1 c2 () () : Expr .anon) addr1) ++ - test "exprMentionsConst: in forallE body" - (exprMentionsConst (.forallE c2 c1 () () : Expr .anon) addr1) ++ - test "exprMentionsConst: in lam" - (exprMentionsConst (.lam c1 c2 () () : Expr .anon) addr1) ++ - test "exprMentionsConst: absent in sort" - (!exprMentionsConst (.sort .zero : Expr .anon) addr1) ++ - test "exprMentionsConst: absent in bvar" - (!exprMentionsConst (.bvar 0 () : Expr .anon) addr1) ++ - -- checkStrictPositivity - let indAddrs := #[addr1] - test "checkStrictPositivity: no mention is positive" - (checkStrictPositivity c2 indAddrs) ++ - test "checkStrictPositivity: head occurrence is positive" - (checkStrictPositivity c1 indAddrs) ++ - test "checkStrictPositivity: in Pi domain is negative" - (!checkStrictPositivity (.forallE c1 c2 () () : Expr .anon) indAddrs) ++ - test "checkStrictPositivity: in Pi codomain positive" - (checkStrictPositivity (.forallE c2 c1 () () : Expr .anon) indAddrs) ++ - -- 
getIndResultLevel - test "getIndResultLevel: sort zero" - (getIndResultLevel (.sort .zero : Expr .anon) == some .zero) ++ - test "getIndResultLevel: sort (succ zero)" - (getIndResultLevel (.sort (.succ .zero) : Expr .anon) == some (.succ .zero)) ++ - test "getIndResultLevel: forallE _ (sort zero)" - (getIndResultLevel (.forallE (.sort .zero) (.sort (.succ .zero)) () () : Expr .anon) == some (.succ .zero)) ++ - test "getIndResultLevel: bvar (no sort)" - (getIndResultLevel (.bvar 0 () : Expr .anon) == none) ++ - -- levelIsNonZero - test "levelIsNonZero: zero is false" - (!levelIsNonZero (.zero : Level .anon)) ++ - test "levelIsNonZero: succ zero is true" - (levelIsNonZero (.succ .zero : Level .anon)) ++ - test "levelIsNonZero: param is false" - (!levelIsNonZero (.param 0 () : Level .anon)) ++ - test "levelIsNonZero: max(succ 0, param) is true" - (levelIsNonZero (.max (.succ .zero) (.param 0 ()) : Level .anon)) ++ - test "levelIsNonZero: imax(param, succ 0) is true" - (levelIsNonZero (.imax (.param 0 ()) (.succ .zero) : Level .anon)) ++ - test "levelIsNonZero: imax(succ, param) depends on second" - (!levelIsNonZero (.imax (.succ .zero) (.param 0 ()) : Level .anon)) ++ - -- checkCtorPositivity - test "checkCtorPositivity: no inductive mention is ok" - (checkCtorPositivity c2 0 indAddrs == none) ++ - test "checkCtorPositivity: negative occurrence" - (checkCtorPositivity (.forallE (.forallE c1 c2 () ()) (.const addr1 #[] ()) () () : Expr .anon) 0 indAddrs != none) ++ - -- getCtorReturnType - test "getCtorReturnType: no binders returns expr" - (getCtorReturnType c1 0 0 == c1) ++ - test "getCtorReturnType: skips foralls" - (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) ++ - .done - -/-! 
## Test suites -/ - -def unitSuite : List TestSeq := [ - testExprHashEq, - testExprOps, - testLevelOps, - testLevelReduceIMax, - testLevelReduceMax, - testLevelLeqComplex, - testLevelInstBulkReduce, - testReducibilityHintsLt, - testHelperFunctions, -] - -def convertSuite : List TestSeq := [ - testConvertEnv, -] - -def constSuite : List TestSeq := [ - testConsts, -] - -def negativeSuite : List TestSeq := [ - negativeTests, - soundnessNegativeTests, -] - -def anonConvertSuite : List TestSeq := [ - testAnonConvert, -] - /-! ## Roundtrip test: Lean → Ixon → Kernel → Lean -/ /-- Roundtrip test: compile Lean env to Ixon, convert to Kernel, decompile back to Lean, @@ -1177,6 +438,25 @@ def testRoundtrip : TestSeq := return (false, some s!"{mismatches}/{checked} constants have structural mismatches") ) .done +/-! ## Test suites -/ + +def unitSuite : List TestSeq := Tests.Ix.Kernel.Unit.suite + +def convertSuite : List TestSeq := [ + testConvertEnv, +] + +def constSuite : List TestSeq := [ + testConsts, +] + +def negativeSuite : List TestSeq := + [negativeTests] ++ Tests.Ix.Kernel.Soundness.suite + +def anonConvertSuite : List TestSeq := [ + testAnonConvert, +] + def roundtripSuite : List TestSeq := [ testRoundtrip, ] From 0055ecc9b21e5b47be890e955b6504d3b52d3c0f Mon Sep 17 00:00:00 2001 From: "John C. 
Burnham" Date: Mon, 2 Mar 2026 21:39:36 -0500 Subject: [PATCH 09/14] Iterativize binder chains and fix recursor validation for Ixon blocks - Rewrite lam/forallE/letE inference to iterate binder chains instead of recursing, preventing stack overflow on deeply nested terms - Add inductBlock/inductNames to RecursorVal to track the major inductive separately from rec.all, which can be empty in recursor-only Ixon blocks - Build InductiveBlockIndex to extract the correct major inductive from Ixon recursor types at conversion time - Fix validateRecursorRules to look up ctors from the major inductive directly instead of iterating rec.all - Fix isDefEq call in lazyDeltaReduction (was calling isDefEqCore) - Add regression tests for UInt64 isDefEq, recursor-only blocks, and deeply nested let chains --- Ix/Kernel/Convert.lean | 125 ++++++++++++++++++++++++++++---- Ix/Kernel/DecompileM.lean | 5 +- Ix/Kernel/Infer.lean | 147 ++++++++++++++++++++++++-------------- Ix/Kernel/Types.lean | 2 + Tests/Ix/KernelTests.lean | 8 ++- 5 files changed, 218 insertions(+), 69 deletions(-) diff --git a/Ix/Kernel/Convert.lean b/Ix/Kernel/Convert.lean index 369ffca2..6d0ebb5e 100644 --- a/Ix/Kernel/Convert.lean +++ b/Ix/Kernel/Convert.lean @@ -415,7 +415,10 @@ def convertRecursor (m : MetaMode) (r : Ixon.Recursor) (levelParams : MetaField m (Array Ix.Name) := default) (cMeta : ConstantMeta := .empty) (allNames : MetaField m (Array Ix.Name) := default) - (ruleCtorNames : Array (MetaField m Ix.Name) := #[]) : ConvertM m (Ix.Kernel.ConstantInfo m) := do + (ruleCtorNames : Array (MetaField m Ix.Name) := #[]) + (inductBlock : Array Address := #[]) + (inductNames : MetaField m (Array (Array Ix.Name)) := default) + : ConvertM m (Ix.Kernel.ConstantInfo m) := do let typ ← convertExpr m r.typ (metaTypeRoot? 
cMeta) let cv := mkConstantVal m r.lvls typ name levelParams let ruleRoots := (metaRuleRoots cMeta) @@ -426,7 +429,7 @@ def convertRecursor (m : MetaMode) (r : Ixon.Recursor) let ruleRoot := if h : i < ruleRoots.size then some ruleRoots[i] else none rules := rules.push (← convertRule m r.rules[i]! ctorAddr ctorName ruleRoot) let v : Ix.Kernel.RecursorVal m := - { toConstantVal := cv, all, allNames, + { toConstantVal := cv, all, allNames, inductBlock, inductNames, numParams := r.params.toNat, numIndices := r.indices.toNat, numMotives := r.motives.toNat, numMinors := r.minors.toNat, rules, k := r.k, isUnsafe := r.isUnsafe } @@ -514,6 +517,77 @@ def buildRecurAddrs (bIdx : BlockIndex) (numMembers : Nat) : Except ConvertError | none => throw (.missingMemberAddr i numMembers) return addrs +/-! ## Ixon-level major inductive extraction -/ + +/-- Expand Ixon.Expr.share nodes. -/ +private partial def ixonExpandShare (sharing : Array Ixon.Expr) : Ixon.Expr → Ixon.Expr + | .share idx => + if h : idx.toNat < sharing.size then ixonExpandShare sharing sharing[idx.toNat] + else .share idx + | e => e + +/-- Extract the major inductive's ref index from an Ixon recursor type. + Walks `n` forall (`.all`) binders, then extracts the head `.ref` of the domain. + Returns `none` if the structure doesn't match. -/ +private partial def ixonGetMajorRef (sharing : Array Ixon.Expr) (typ : Ixon.Expr) (n : Nat) : Option UInt64 := + let e := ixonExpandShare sharing typ + match n, e with + | 0, .all dom _ => + let dom' := ixonExpandShare sharing dom + getHead dom' + | n+1, .all _ body => ixonGetMajorRef sharing body n + | _, _ => none +where + getHead : Ixon.Expr → Option UInt64 + | .ref refIdx _ => some refIdx + | .app fn _ => getHead (ixonExpandShare sharing fn) + | _ => none + +/-- Pre-built index mapping each iPrj address to its block's (allInductAddrs, ctorAddrsInOrder). + Built once per convertEnv call, then used for O(1) lookups. 
-/ +structure InductiveBlockIndex where + /-- iPrj address → (allInductAddrs, ctorAddrsInOrder) for its block -/ + entries : Std.HashMap Address (Array Address × Array Address) := {} + +def InductiveBlockIndex.get (idx : InductiveBlockIndex) (indAddr : Address) + : Array Address × Array Address := + idx.entries.getD indAddr (#[indAddr], #[]) + +/-- Build the InductiveBlockIndex by scanning the Ixon env once. -/ +def buildInductiveBlockIndex (ixonEnv : Ixon.Env) : InductiveBlockIndex := Id.run do + -- Pass 1: group iPrj and cPrj by block address + let mut inductByBlock : Std.HashMap Address (Array (UInt64 × Address)) := {} + let mut ctorByBlock : Std.HashMap Address (Array (UInt64 × UInt64 × Address)) := {} + for (addr, c) in ixonEnv.consts do + match c.info with + | .iPrj prj => + inductByBlock := inductByBlock.insert prj.block + ((inductByBlock.getD prj.block #[]).push (prj.idx, addr)) + | .cPrj prj => + ctorByBlock := ctorByBlock.insert prj.block + ((ctorByBlock.getD prj.block #[]).push (prj.idx, prj.cidx, addr)) + | _ => pure () + -- Pass 2: for each block, sort and build the (inductAddrs, ctorAddrs) pair, + -- then map each iPrj address to that pair + let mut entries : Std.HashMap Address (Array Address × Array Address) := {} + for (blockAddr, rawInduct) in inductByBlock do + let sortedInduct := rawInduct.insertionSort (fun a b => a.1 < b.1) + let inductAddrs := sortedInduct.map (·.2) + let rawCtor := ctorByBlock.getD blockAddr #[] + let sortedCtor := rawCtor.insertionSort (fun a b => a.1 < b.1 || (a.1 == b.1 && a.2.1 < b.2.1)) + let ctorAddrs := sortedCtor.map (·.2.2) + let pair := (inductAddrs, ctorAddrs) + for (_, addr) in sortedInduct do + entries := entries.insert addr pair + { entries } + +/-- Pre-built reverse index mapping constant address → Array of Ix.Names. 
-/ +def buildAddrToNames (ixonEnv : Ixon.Env) : Std.HashMap Address (Array Ix.Name) := Id.run do + let mut acc : Std.HashMap Address (Array Ix.Name) := {} + for (ixName, entry) in ixonEnv.named do + acc := acc.insert entry.addr ((acc.getD entry.addr #[]).push ixName) + acc + /-! ## Projection conversion -/ /-- Convert a single projection constant as a ConvertM action. @@ -521,6 +595,9 @@ def buildRecurAddrs (bIdx : BlockIndex) (numMembers : Nat) : Except ConvertError def convertProjAction (m : MetaMode) (addr : Address) (c : Constant) (blockConst : Constant) (bIdx : BlockIndex) + (ixonEnv : Ixon.Env) + (indBlockIdx : InductiveBlockIndex) + (addrToNames : Std.HashMap Address (Array Ix.Name)) (name : MetaField m Ix.Name := default) (levelParams : MetaField m (Array Ix.Name) := default) (cMeta : ConstantMeta := .empty) @@ -555,11 +632,24 @@ def convertProjAction (m : MetaMode) if h : prj.idx.toNat < members.size then match members[prj.idx.toNat] with | .recr r => - let ruleCtorAs := bIdx.allCtorAddrsInOrder + -- Extract the major inductive from the Ixon type expression (metadata-free). 
+ let skip := r.params.toNat + r.motives.toNat + r.minors.toNat + r.indices.toNat + let (inductBlock, ruleCtorAs) := + match ixonGetMajorRef blockConst.sharing r.typ skip with + | some refIdx => + if h2 : refIdx.toNat < blockConst.refs.size then + indBlockIdx.get blockConst.refs[refIdx.toNat] + else (bIdx.allInductAddrs, bIdx.allCtorAddrsInOrder) + | none => (bIdx.allInductAddrs, bIdx.allCtorAddrsInOrder) + let inductNs : MetaField m (Array (Array Ix.Name)) := match m with + | .anon => () + | .meta => inductBlock.map fun a => addrToNames.getD a #[] + let ruleCtorNs : Array (MetaField m Ix.Name) := match m with + | .anon => ruleCtorAs.map fun _ => () + | .meta => ruleCtorAs.map fun a => + (addrToNames.getD a #[])[0]?.getD default let allNs := resolveMetaNames m names (match cMeta with | .recr _ _ _ a _ _ _ _ => a | _ => #[]) - let metaRules := match cMeta with | .recr _ _ rules _ _ _ _ _ => rules | _ => #[] - let ruleCtorNs := metaRules.map fun x => resolveMetaName m names x - .ok (convertRecursor m r bIdx.allInductAddrs ruleCtorAs name levelParams cMeta allNs ruleCtorNs) + .ok (convertRecursor m r bIdx.allInductAddrs ruleCtorAs name levelParams cMeta allNs ruleCtorNs inductBlock inductNs) | _ => .error s!"rPrj at {addr} does not point to a recursor" else .error s!"rPrj index out of bounds at {addr}" | .dPrj prj => @@ -647,14 +737,19 @@ def convertStandalone (m : MetaMode) (hashToAddr : Std.HashMap Address Address) let ruleCtorAddrs := metaRules.map fun x => hashToAddr.getD x x let allNames := resolveMetaNames m ixonEnv.names metaAll let ruleCtorNames := metaRules.map fun x => resolveMetaName m ixonEnv.names x - let ci ← (ConvertM.run cEnv (convertRecursor m r all ruleCtorAddrs entry.name lps cMeta allNames ruleCtorNames)).mapError toString + let inductNs : MetaField m (Array (Array Ix.Name)) := match m with + | .anon => () + | .meta => metaAll.map fun x => #[ixonEnv.names.getD x default] + let ci ← (ConvertM.run cEnv (convertRecursor m r all ruleCtorAddrs 
entry.name lps cMeta allNames ruleCtorNames (inductBlock := all) (inductNames := inductNs))).mapError toString return some ci | .muts _ => return none | _ => return none -- projections handled separately /-- Convert a complete block group (all projections share cache + recurAddrs). -/ def convertWorkBlock (m : MetaMode) - (ixonEnv : Ixon.Env) (blockAddr : Address) + (ixonEnv : Ixon.Env) (indBlockIdx : InductiveBlockIndex) + (addrToNames : Std.HashMap Address (Array Ix.Name)) + (blockAddr : Address) (entries : Array (ConvertEntry m)) (results : Array (Address × Ix.Kernel.ConstantInfo m)) (errors : Array (Address × String)) : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do @@ -690,7 +785,7 @@ def convertWorkBlock (m : MetaMode) let lvlNames := resolveLevelParams m ixonEnv.names (metaLvlAddrs cMeta) let lps := mkLevelParams m ixonEnv.names (metaLvlAddrs cMeta) let cEnv := { baseEnv with arena := (metaArena cMeta), levelParamNames := lvlNames } - match convertProjAction m entry.addr entry.const blockConst bIdx entry.name lps cMeta ixonEnv.names with + match convertProjAction m entry.addr entry.const blockConst bIdx ixonEnv indBlockIdx addrToNames entry.name lps cMeta ixonEnv.names with | .ok action => match ConvertM.runWith cEnv state action with | .ok (ci, state') => @@ -706,7 +801,9 @@ def convertWorkBlock (m : MetaMode) /-- Convert a chunk of work items. 
-/ def convertChunk (m : MetaMode) (hashToAddr : Std.HashMap Address Address) - (ixonEnv : Ixon.Env) (chunk : Array (WorkItem m)) + (ixonEnv : Ixon.Env) (indBlockIdx : InductiveBlockIndex) + (addrToNames : Std.HashMap Address (Array Ix.Name)) + (chunk : Array (WorkItem m)) : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do let mut results : Array (Address × Ix.Kernel.ConstantInfo m) := #[] let mut errors : Array (Address × String) := #[] @@ -718,7 +815,7 @@ def convertChunk (m : MetaMode) (hashToAddr : Std.HashMap Address Address) | .ok none => pure () | .error e => errors := errors.push (entry.addr, e) | .block blockAddr entries => - (results, errors) := convertWorkBlock m ixonEnv blockAddr entries results errors + (results, errors) := convertWorkBlock m ixonEnv indBlockIdx addrToNames blockAddr entries results errors (results, errors) /-! ## Top-level conversion -/ @@ -803,7 +900,9 @@ def convertEnv (m : MetaMode) (ixonEnv : Ixon.Env) (numWorkers : Nat := 32) workItems := workItems.push (.standalone entry) for ((blockAddr, _), blockEntries) in blockGroups do workItems := workItems.push (.block blockAddr blockEntries) - -- Phase 5: Chunk work items and parallelize + -- Phase 5: Build indexes and chunk work items for parallel conversion + let indBlockIdx := buildInductiveBlockIndex ixonEnv + let addrToNames := buildAddrToNames ixonEnv let total := workItems.size let chunkSize := (total + numWorkers - 1) / numWorkers let mut tasks : Array (Task (Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String))) := #[] @@ -812,7 +911,7 @@ def convertEnv (m : MetaMode) (ixonEnv : Ixon.Env) (numWorkers : Nat := 32) let endIdx := min (offset + chunkSize) total let chunk := workItems[offset:endIdx] let task := Task.spawn (prio := .dedicated) fun () => - convertChunk m hashToAddr ixonEnv chunk.toArray + convertChunk m hashToAddr ixonEnv indBlockIdx addrToNames chunk.toArray tasks := tasks.push task offset := endIdx -- Phase 6: 
Collect results diff --git a/Ix/Kernel/DecompileM.lean b/Ix/Kernel/DecompileM.lean index d52bda4a..e0dabddf 100644 --- a/Ix/Kernel/DecompileM.lean +++ b/Ix/Kernel/DecompileM.lean @@ -149,9 +149,12 @@ def decompileConstantInfo (ci : ConstantInfo .meta) : Lean.ConstantInfo := isUnsafe := v.isUnsafe } | .recInfo v => + -- Use inductNames (the associated inductives) for Lean's `all` field. + -- inductNames is Array (Array Ix.Name) — flatten to a single list. + let allLean := (v.inductNames.foldl (fun acc group => acc ++ group) #[]).toList.map ixNameToLean .recInfo { name, levelParams := lps, type := decompTy - all := v.allNames.toList.map ixNameToLean + all := allLean numParams := v.numParams, numIndices := v.numIndices numMotives := v.numMotives, numMinors := v.numMinors k := v.k, isUnsafe := v.isUnsafe diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index b3867342..c35d513c 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -158,35 +158,82 @@ mutual throw s!"Expected a pi type, got {currentType'.pp}\n function: {fn.pp}\n arg #{i}: {arg.pp}" let te : TypedExpr m := ⟨← infoFromType currentType, resultBody⟩ pure (te, currentType) - | .lam ty body lamName lamBi => do - let domTe ← if (← read).inferOnly then - pure ⟨.none, ty⟩ - else - let (te, _) ← isSort ty; pure te - let (bodTe, imgType) ← withExtendedCtx ty (infer body) - let piType := Expr.forallE ty imgType lamName default - let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩ - pure (te, piType) - | .forallE ty body piName _ => do - let (domTe, domLvl) ← isSort ty - let (imgTe, imgLvl) ← withExtendedCtx ty (isSort body) - let sortLvl := Level.reduceIMax domLvl imgLvl - let typ := Expr.mkSort sortLvl - let te : TypedExpr m := ⟨← infoFromType typ, .forallE domTe.body imgTe.body piName default⟩ - pure (te, typ) - | .letE ty val body letName => do - if (← read).inferOnly then - let (bodTe, bodType) ← withExtendedCtx ty (infer body) - let resultType := 
bodType.instantiate1 val - let te : TypedExpr m := ⟨bodTe.info, .letE ty val bodTe.body letName⟩ - pure (te, resultType) - else - let (tyTe, _) ← isSort ty - let valTe ← check val ty - let (bodTe, bodType) ← withExtendedCtx ty (infer body) - let resultType := bodType.instantiate1 val - let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩ - pure (te, resultType) + | .lam .. => do + -- Iterate lambda chain to avoid O(n) stack depth + let inferOnly := (← read).inferOnly + let mut cur := term + let mut extTypes := (← read).types + let mut binderMeta : Array (Expr m × Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .lam ty body lamName lamBi => + let domBody ← if inferOnly then pure ty + else do let (te, _) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort ty); pure te.body + binderMeta := binderMeta.push (domBody, ty, lamName, lamBi) + extTypes := extTypes.push ty + cur := body + | _ => break + let (bodTe, imgType) ← withReader (fun ctx => { ctx with types := extTypes }) (infer cur) + let mut resultType := imgType + let mut resultBody := bodTe.body + let mut resultInfo := bodTe.info + for i in [:binderMeta.size] do + let j := binderMeta.size - 1 - i + let (domBody, origTy, lamName, lamBi) := binderMeta[j]! + resultType := .forallE origTy resultType lamName default + resultBody := .lam domBody resultBody lamName lamBi + resultInfo := lamInfo resultInfo + pure (⟨resultInfo, resultBody⟩, resultType) + | .forallE .. 
=> do + -- Iterate forallE chain to avoid O(n) stack depth + let mut cur := term + let mut extTypes := (← read).types + let mut binderMeta : Array (Expr m × Level m × MetaField m Ix.Name) := #[] + repeat + match cur with + | .forallE ty body piName _ => + let (domTe, domLvl) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort ty) + binderMeta := binderMeta.push (domTe.body, domLvl, piName) + extTypes := extTypes.push ty + cur := body + | _ => break + let (imgTe, imgLvl) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort cur) + let mut resultLvl := imgLvl + let mut resultBody := imgTe.body + for i in [:binderMeta.size] do + let j := binderMeta.size - 1 - i + let (domBody, domLvl, piName) := binderMeta[j]! + resultLvl := Level.reduceIMax domLvl resultLvl + resultBody := .forallE domBody resultBody piName default + let typ := Expr.mkSort resultLvl + pure (⟨← infoFromType typ, resultBody⟩, typ) + | .letE .. => do + -- Iterate let chain to avoid O(n) stack depth + let inferOnly := (← read).inferOnly + let mut cur := term + let mut extTypes := (← read).types + let mut binderInfo : Array (Expr m × Expr m × Expr m × MetaField m Ix.Name) := #[] + repeat + match cur with + | .letE ty val body letName => + if inferOnly then + binderInfo := binderInfo.push (ty, val, val, letName) + else + let (tyTe, _) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort ty) + let valTe ← withReader (fun ctx => { ctx with types := extTypes }) (check val ty) + binderInfo := binderInfo.push (tyTe.body, valTe.body, val, letName) + extTypes := extTypes.push ty + cur := body + | _ => break + let (bodTe, bodType) ← withReader (fun ctx => { ctx with types := extTypes }) (infer cur) + let mut resultType := bodType + let mut resultBody := bodTe.body + for i in [:binderInfo.size] do + let j := binderInfo.size - 1 - i + let (tyBody, valBody, origVal, letName) := binderInfo[j]! 
+ resultType := resultType.instantiate1 origVal + resultBody := .letE tyBody valBody resultBody letName + pure (⟨bodTe.info, resultBody⟩, resultType) | .lit (.natVal _) => do let prims := (← read).prims let typ := Expr.mkConst prims.nat #[] @@ -509,10 +556,10 @@ mutual /-- Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/ partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m Unit := do - if rec.all.size != 1 then - throw "recursor claims K but inductive is mutual" match (← read).kenv.find? indAddr with | some (.inductInfo iv) => + if iv.all.size != 1 then + throw "recursor claims K but inductive is mutual" match getIndResultLevel iv.type with | some lvl => if levelIsNonZero lvl then @@ -527,31 +574,23 @@ mutual | _ => throw "recursor claims K but constructor not found" | _ => throw s!"recursor claims K but {indAddr} is not an inductive" - /-- Validate recursor rules: check rule count, ctor membership, field counts. -/ + /-- Validate recursor rules: check rule count, ctor membership, field counts. + Uses `indAddr` (from getMajorInduct) to look up the inductive directly, + since rec.all may be empty for recursor-only Ixon blocks. + Does NOT check numParams/numIndices — auxiliary recursors (rec_1, etc.) + can have different param counts than the major inductive. -/ partial def validateRecursorRules (rec : RecursorVal m) (indAddr : Address) : TypecheckM m Unit := do - let mut allCtors : Array Address := #[] - for iAddr in rec.all do - match (← read).kenv.find? iAddr with - | some (.inductInfo iv) => - allCtors := allCtors ++ iv.ctors - | _ => throw s!"recursor references {iAddr} which is not an inductive" - if rec.rules.size != allCtors.size then - throw s!"recursor has {rec.rules.size} rules but inductive(s) have {allCtors.size} constructors" - for h : i in [:rec.rules.size] do - let rule := rec.rules[i] - if rule.ctor != allCtors[i]! 
then - throw s!"recursor rule {i} has constructor {rule.ctor} but expected {allCtors[i]!}" - match (← read).kenv.find? rule.ctor with - | some (.ctorInfo cv) => - if rule.nfields != cv.numFields then - throw s!"recursor rule for {rule.ctor} has nfields={rule.nfields} but constructor has {cv.numFields} fields" - | _ => throw s!"recursor rule constructor {rule.ctor} not found" match (← read).kenv.find? indAddr with | some (.inductInfo iv) => - if rec.numParams != iv.numParams then - throw s!"recursor numParams={rec.numParams} but inductive has {iv.numParams}" - if rec.numIndices != iv.numIndices then - throw s!"recursor numIndices={rec.numIndices} but inductive has {iv.numIndices}" + if rec.rules.size != iv.ctors.size then + throw s!"recursor has {rec.rules.size} rules but inductive has {iv.ctors.size} constructors" + for h : i in [:rec.rules.size] do + let rule := rec.rules[i] + match (← read).kenv.find? iv.ctors[i]! with + | some (.ctorInfo cv) => + if rule.nfields != cv.numFields then + throw s!"recursor rule for {iv.ctors[i]!} has nfields={rule.nfields} but constructor has {cv.numFields} fields" + | _ => throw s!"constructor {iv.ctors[i]!} not found" | _ => pure () /-- Quick structural equality check without WHNF. 
Returns: @@ -636,7 +675,7 @@ mutual let snn ← whnfCore sn' -- Only recurse into isDefEqCore if something actually changed if !(tnn == tn' && snn == sn') then - let result ← isDefEqCore tnn snn + let result ← isDefEq tnn snn cacheResult t s result return result diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index 15d077c1..4c2adabb 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -466,6 +466,8 @@ structure RecursorRule (m : MetaMode) where structure RecursorVal (m : MetaMode) extends ConstantVal m where all : Array Address allNames : MetaField m (Array Ix.Name) := default + inductBlock : Array Address := #[] + inductNames : MetaField m (Array (Array Ix.Name)) := default numParams : Nat numIndices : Nat numMotives : Nat diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index 360e6a14..d037f3a7 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -143,7 +143,13 @@ def testConsts : TestSeq := -- Theorem with sub-term type mismatch (requires inferOnly) "Std.Do.Spec.tryCatch_ExceptT", -- Nested inductive positivity check (requires whnf) - "Lean.Elab.Term.Do.Code.action" + "Lean.Elab.Term.Do.Code.action", + -- UInt64/BitVec isDefEq regression + "UInt64.decLt", + -- Recursor-only Ixon block regression (rec.all was empty) + "Lean.Elab.Tactic.RCases.RCasesPatt.rec_1", + -- Deeply nested let chain (stack overflow regression) + "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go._sunfold" ] let mut passed := 0 let mut failures : Array String := #[] From cb5fe1a55b53fbdbcaad2e1773d48622fac17e27 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Mon, 2 Mar 2026 22:42:55 -0500 Subject: [PATCH 10/14] Replace HashMap with TreeMap and iterativize Expr traversals Switch all kernel caches from Std.HashMap to Std.TreeMap, replacing hash-based lookups with structural comparison (Expr.compare, Level.compare). 
Expr.compare is fully iterative using an explicit worklist stack, and Expr.beq/liftLooseBVars/instantiate/substLevels/hasLooseBVarsAbove now loop over binder chains to avoid stack overflow on deeply nested terms. Add pointer equality fast paths (ptrEq) for Level and Expr, and a pointer-address comparator (ptrCompare) for the def-eq failure cache. --- Ix/Kernel/EquivManager.lean | 20 +-- Ix/Kernel/Infer.lean | 4 +- Ix/Kernel/TypecheckM.lean | 21 +-- Ix/Kernel/Types.lean | 333 +++++++++++++++++++++++++++++++----- 4 files changed, 308 insertions(+), 70 deletions(-) diff --git a/Ix/Kernel/EquivManager.lean b/Ix/Kernel/EquivManager.lean index 9521922c..cfabc626 100644 --- a/Ix/Kernel/EquivManager.lean +++ b/Ix/Kernel/EquivManager.lean @@ -14,7 +14,7 @@ abbrev NodeRef := Nat structure EquivManager (m : MetaMode) where uf : Batteries.UnionFind := {} - toNodeMap : Std.HashMap (Expr m) NodeRef := {} + toNodeMap : Std.TreeMap (Expr m) NodeRef Expr.compare := {} instance : Inhabited (EquivManager m) := ⟨{}⟩ @@ -47,12 +47,10 @@ def merge (n1 n2 : NodeRef) : StateM (EquivManager m) Unit := fun mgr => When `useHash = true`, expressions with different hashes are immediately rejected without structural walking (fast path for obviously different terms). -/ -partial def isEquiv (useHash : Bool) (e1 e2 : Expr m) : StateM (EquivManager m) Bool := do +partial def isEquiv (_useHash : Bool) (e1 e2 : Expr m) : StateM (EquivManager m) Bool := do -- 1. Pointer/structural equality (O(1) via Blake3 content-addressing) if e1 == e2 then return true - -- 2. Hash mismatch → definitely not structurally equal - if useHash && Hashable.hash e1 != Hashable.hash e2 then return false - -- 3. BVar fast path (compare indices directly, don't add to union-find) + -- 2. 
BVar fast path (compare indices directly, don't add to union-find) match e1, e2 with | .bvar i _, .bvar j _ => return i == j | _, _ => pure () @@ -66,16 +64,16 @@ partial def isEquiv (useHash : Bool) (e1 e2 : Expr m) : StateM (EquivManager m) | .sort l1, .sort l2 => pure (l1 == l2) | .lit l1, .lit l2 => pure (l1 == l2) | .app f1 a1, .app f2 a2 => - if ← isEquiv useHash f1 f2 then isEquiv useHash a1 a2 else pure false + if ← isEquiv _useHash f1 f2 then isEquiv _useHash a1 a2 else pure false | .lam d1 b1 _ _, .lam d2 b2 _ _ => - if ← isEquiv useHash d1 d2 then isEquiv useHash b1 b2 else pure false + if ← isEquiv _useHash d1 d2 then isEquiv _useHash b1 b2 else pure false | .forallE d1 b1 _ _, .forallE d2 b2 _ _ => - if ← isEquiv useHash d1 d2 then isEquiv useHash b1 b2 else pure false + if ← isEquiv _useHash d1 d2 then isEquiv _useHash b1 b2 else pure false | .proj ta1 i1 s1 _, .proj ta2 i2 s2 _ => - if ta1 == ta2 && i1 == i2 then isEquiv useHash s1 s2 else pure false + if ta1 == ta2 && i1 == i2 then isEquiv _useHash s1 s2 else pure false | .letE t1 v1 b1 _, .letE t2 v2 b2 _ => - if ← isEquiv useHash t1 t2 then - if ← isEquiv useHash v1 v2 then isEquiv useHash b1 b2 else pure false + if ← isEquiv _useHash t1 t2 then + if ← isEquiv _useHash v1 v2 then isEquiv _useHash b1 b2 else pure false else pure false | _, _ => pure false -- 6. Merge on success diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index c35d513c..546a3a9b 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -903,7 +903,7 @@ mutual if !(← get).failureCache.contains key then if equalUnivArrays tn.getAppFn.constLevels! sn.getAppFn.constLevels! 
then if ← isDefEqApp tn sn then return (tn, sn, some true) - modify fun stt => { stt with failureCache := stt.failureCache.insert key } + modify fun stt => { stt with failureCache := stt.failureCache.insert key () } if ht.lt' hs then match unfoldDelta ds sn with | some r => sn ← whnfCore r (cheapProj := true); continue @@ -981,7 +981,7 @@ mutual { stt with eqvManager := mgr' } else let key := eqCacheKey t s - modify fun stt => { stt with failureCache := stt.failureCache.insert key } + modify fun stt => { stt with failureCache := stt.failureCache.insert key () } end -- mutual diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 317fac09..c9428245 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -39,20 +39,20 @@ structure TypecheckCtx (m : MetaMode) where /-! ## Typechecker State -/ /-- Default fuel for bounding total recursive work per constant. -/ -def defaultFuel : Nat := 1_000_000 +def defaultFuel : Nat := 10_000_000 structure TypecheckState (m : MetaMode) where typedConsts : Std.TreeMap Address (TypedConst m) Address.compare - whnfCache : Std.HashMap (Expr m) (Expr m) := {} + whnfCache : Std.TreeMap (Expr m) (Expr m) Expr.compare := {} /-- Cache for structural-only WHNF (whnfCore with cheapRec=false, cheapProj=false). Separate from whnfCache to avoid stale entries from cheap reductions. -/ - whnfCoreCache : Std.HashMap (Expr m) (Expr m) := {} + whnfCoreCache : Std.TreeMap (Expr m) (Expr m) Expr.compare := {} /-- Infer cache: maps term → (binding context, inferred type). Keyed on Expr only; context verified on retrieval via ptr equality + BEq fallback. 
-/ - inferCache : Std.HashMap (Expr m) (Array (Expr m) × Expr m) := {} + inferCache : Std.TreeMap (Expr m) (Array (Expr m) × Expr m) Expr.compare := {} eqvManager : EquivManager m := {} - failureCache : Std.HashSet (Expr m × Expr m) := {} - constTypeCache : Std.HashMap Address (Array (Level m) × Expr m) := {} + failureCache : Std.TreeMap (Expr m × Expr m) Unit Expr.pairCompare := {} + constTypeCache : Std.TreeMap Address (Array (Level m) × Expr m) Address.compare := {} fuel : Nat := defaultFuel /-- Tracks nesting depth of whnf calls from within recursor reduction (tryReduceApp → whnf). When this exceeds a threshold, whnfCore is used instead of whnf to prevent stack overflow. -/ @@ -166,13 +166,8 @@ def ensureTypedConst (addr : Address) : TypecheckM m Unit := do /-! ## Def-eq cache helpers -/ -instance : Hashable (Expr m × Expr m) where - hash p := mixHash (Hashable.hash p.1) (Hashable.hash p.2) - -/-- Symmetric cache key for def-eq pairs. Orders by structural hash to make key(a,b) == key(b,a). -/ +/-- Symmetric cache key for def-eq pairs. Orders by pointer address to make key(a,b) == key(b,a). -/ def eqCacheKey (a b : Expr m) : Expr m × Expr m := - let ha := Hashable.hash a - let hb := Hashable.hash b - if ha ≤ hb then (a, b) else (b, a) + if Expr.ptrCompare a b != .gt then (a, b) else (b, a) end Ix.Kernel diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index 4c2adabb..ed4d07f6 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -71,7 +71,36 @@ inductive Expr (m : MetaMode) where | lit (l : Lean.Literal) | proj (typeAddr : Address) (idx : Nat) (struct : Expr m) (typeName : MetaField m Ix.Name) - deriving Inhabited, BEq + deriving Inhabited + +/-- Structural equality for Expr, iterating over binder body spines to avoid + stack overflow on deeply nested let/lam/forallE chains. 
-/ +partial def Expr.beq : Expr m → Expr m → Bool := go where + go (a b : Expr m) : Bool := Id.run do + let mut ca := a; let mut cb := b + repeat + match ca, cb with + | .lam ty1 body1 n1 bi1, .lam ty2 body2 n2 bi2 => + if !(go ty1 ty2 && n1 == n2 && bi1 == bi2) then return false + ca := body1; cb := body2 + | .forallE ty1 body1 n1 bi1, .forallE ty2 body2 n2 bi2 => + if !(go ty1 ty2 && n1 == n2 && bi1 == bi2) then return false + ca := body1; cb := body2 + | .letE ty1 val1 body1 n1, .letE ty2 val2 body2 n2 => + if !(go ty1 ty2 && go val1 val2 && n1 == n2) then return false + ca := body1; cb := body2 + | _, _ => break + match ca, cb with + | .bvar i1 n1, .bvar i2 n2 => return i1 == i2 && n1 == n2 + | .sort l1, .sort l2 => return l1 == l2 + | .const a1 ls1 n1, .const a2 ls2 n2 => return a1 == a2 && ls1 == ls2 && n1 == n2 + | .app fn1 arg1, .app fn2 arg2 => return go fn1 fn2 && go arg1 arg2 + | .lit l1, .lit l2 => return l1 == l2 + | .proj a1 i1 s1 n1, .proj a2 i2 s2 n2 => + return a1 == a2 && i1 == i2 && go s1 s2 && n1 == n2 + | _, _ => return false + +instance : BEq (Expr m) where beq := Expr.beq /-! ## Pretty printing helpers -/ @@ -252,9 +281,42 @@ where match e with | .bvar idx name => if idx >= d then .bvar (idx + n) name else e | .app fn arg => .app (go fn d) (go arg d) - | .lam ty body name bi => .lam (go ty d) (go body (d + 1)) name bi - | .forallE ty body name bi => .forallE (go ty d) (go body (d + 1)) name bi - | .letE ty val body name => .letE (go ty d) (go val d) (go body (d + 1)) name + | .lam .. => Id.run do + let mut cur := e; let mut curD := d + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .lam ty body name bi => acc := acc.push (go ty curD, name, bi); curD := curD + 1; cur := body + | _ => break + let mut result := go cur curD + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .lam ty result name bi + return result + | .forallE .. 
=> Id.run do + let mut cur := e; let mut curD := d + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .forallE ty body name bi => acc := acc.push (go ty curD, name, bi); curD := curD + 1; cur := body + | _ => break + let mut result := go cur curD + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .forallE ty result name bi + return result + | .letE .. => Id.run do + let mut cur := e; let mut curD := d + let mut acc : Array (Expr m × Expr m × MetaField m Ix.Name) := #[] + repeat + match cur with + | .letE ty val body name => acc := acc.push (go ty curD, go val curD, name); curD := curD + 1; cur := body + | _ => break + let mut result := go cur curD + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, val, name) := acc[j]! + result := .letE ty val result name + return result | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct d) typeName | .sort .. | .const .. | .lit .. => e @@ -276,9 +338,42 @@ where else .bvar (idx - subst.size) name | .app fn arg => .app (go fn shift) (go arg shift) - | .lam ty body name bi => .lam (go ty shift) (go body (shift + 1)) name bi - | .forallE ty body name bi => .forallE (go ty shift) (go body (shift + 1)) name bi - | .letE ty val body name => .letE (go ty shift) (go val shift) (go body (shift + 1)) name + | .lam .. => Id.run do + let mut cur := e; let mut curShift := shift + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .lam ty body name bi => acc := acc.push (go ty curShift, name, bi); curShift := curShift + 1; cur := body + | _ => break + let mut result := go cur curShift + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .lam ty result name bi + return result + | .forallE .. 
=> Id.run do + let mut cur := e; let mut curShift := shift + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .forallE ty body name bi => acc := acc.push (go ty curShift, name, bi); curShift := curShift + 1; cur := body + | _ => break + let mut result := go cur curShift + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .forallE ty result name bi + return result + | .letE .. => Id.run do + let mut cur := e; let mut curShift := shift + let mut acc : Array (Expr m × Expr m × MetaField m Ix.Name) := #[] + repeat + match cur with + | .letE ty val body name => acc := acc.push (go ty curShift, go val curShift, name); curShift := curShift + 1; cur := body + | _ => break + let mut result := go cur curShift + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, val, name) := acc[j]! + result := .letE ty val result name + return result | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct shift) typeName | .sort .. | .const .. | .lit .. => e @@ -295,23 +390,66 @@ where | .sort lvl => .sort (substFn lvl) | .const addr ls name => .const addr (ls.map substFn) name | .app fn arg => .app (go fn) (go arg) - | .lam ty body name bi => .lam (go ty) (go body) name bi - | .forallE ty body name bi => .forallE (go ty) (go body) name bi - | .letE ty val body name => .letE (go ty) (go val) (go body) name + | .lam .. => Id.run do + let mut cur := e + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .lam ty body name bi => acc := acc.push (go ty, name, bi); cur := body + | _ => break + let mut result := go cur + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .lam ty result name bi + return result + | .forallE .. 
=> Id.run do + let mut cur := e + let mut acc : Array (Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] + repeat + match cur with + | .forallE ty body name bi => acc := acc.push (go ty, name, bi); cur := body + | _ => break + let mut result := go cur + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, name, bi) := acc[j]! + result := .forallE ty result name bi + return result + | .letE .. => Id.run do + let mut cur := e + let mut acc : Array (Expr m × Expr m × MetaField m Ix.Name) := #[] + repeat + match cur with + | .letE ty val body name => acc := acc.push (go ty, go val, name); cur := body + | _ => break + let mut result := go cur + for i in [:acc.size] do + let j := acc.size - 1 - i; let (ty, val, name) := acc[j]! + result := .letE ty val result name + return result | .proj typeAddr idx struct typeName => .proj typeAddr idx (go struct) typeName | .bvar .. | .lit .. => e /-- Check if expression has any bvars with index >= depth. -/ -partial def hasLooseBVarsAbove (e : Expr m) (depth : Nat) : Bool := - match e with - | .bvar idx _ => idx >= depth - | .app fn arg => hasLooseBVarsAbove fn depth || hasLooseBVarsAbove arg depth - | .lam ty body _ _ => hasLooseBVarsAbove ty depth || hasLooseBVarsAbove body (depth + 1) - | .forallE ty body _ _ => hasLooseBVarsAbove ty depth || hasLooseBVarsAbove body (depth + 1) - | .letE ty val body _ => - hasLooseBVarsAbove ty depth || hasLooseBVarsAbove val depth || hasLooseBVarsAbove body (depth + 1) - | .proj _ _ struct _ => hasLooseBVarsAbove struct depth - | .sort .. | .const .. | .lit .. 
=> false +partial def hasLooseBVarsAbove (e : Expr m) (depth : Nat) : Bool := Id.run do + let mut cur := e; let mut curDepth := depth + repeat + match cur with + | .lam ty body _ _ => + if hasLooseBVarsAbove ty curDepth then return true + curDepth := curDepth + 1; cur := body + | .forallE ty body _ _ => + if hasLooseBVarsAbove ty curDepth then return true + curDepth := curDepth + 1; cur := body + | .letE ty val body _ => + if hasLooseBVarsAbove ty curDepth then return true + if hasLooseBVarsAbove val curDepth then return true + curDepth := curDepth + 1; cur := body + | _ => break + match cur with + | .bvar idx _ => return idx >= curDepth + | .app fn arg => return hasLooseBVarsAbove fn curDepth || hasLooseBVarsAbove arg curDepth + | .proj _ _ struct _ => return hasLooseBVarsAbove struct curDepth + | _ => return false /-- Does the expression have any loose (free) bvars? -/ def hasLooseBVars (e : Expr m) : Bool := e.hasLooseBVarsAbove 0 @@ -342,30 +480,137 @@ def letBody! : Expr m → Expr m end Expr -/-! 
## Hashable instances -/ - -partial def Level.hash : Level m → UInt64 - | .zero => 7 - | .succ l => mixHash 13 (Level.hash l) - | .max l₁ l₂ => mixHash 17 (mixHash (Level.hash l₁) (Level.hash l₂)) - | .imax l₁ l₂ => mixHash 23 (mixHash (Level.hash l₁) (Level.hash l₂)) - | .param idx _ => mixHash 29 (Hashable.hash idx) - -instance : Hashable (Level m) where hash := Level.hash - -partial def Expr.hash : Expr m → UInt64 - | .bvar idx _ => mixHash 31 (Hashable.hash idx) - | .sort lvl => mixHash 37 (Level.hash lvl) - | .const addr lvls _ => mixHash 41 (mixHash (Hashable.hash addr) (lvls.foldl (fun h l => mixHash h (Level.hash l)) 0)) - | .app fn arg => mixHash 43 (mixHash (Expr.hash fn) (Expr.hash arg)) - | .lam ty body _ _ => mixHash 47 (mixHash (Expr.hash ty) (Expr.hash body)) - | .forallE ty body _ _ => mixHash 53 (mixHash (Expr.hash ty) (Expr.hash body)) - | .letE ty val body _ => mixHash 59 (mixHash (Expr.hash ty) (mixHash (Expr.hash val) (Expr.hash body))) - | .lit (.natVal n) => mixHash 61 (Hashable.hash n) - | .lit (.strVal s) => mixHash 67 (Hashable.hash s) - | .proj addr idx struct _ => mixHash 71 (mixHash (Hashable.hash addr) (mixHash (Hashable.hash idx) (Expr.hash struct))) - -instance : Hashable (Expr m) where hash := Expr.hash +/-! ## Structural ordering -/ + +/-- Numeric tag for Level constructors, used for ordering. -/ +private def Level.tag : Level m → UInt8 + | .zero => 0 + | .succ _ => 1 + | .max _ _ => 2 + | .imax _ _ => 3 + | .param _ _ => 4 + +/-- Pointer equality check for Levels (O(1) fast path). -/ +private unsafe def Level.ptrEqUnsafe (a : @& Level m) (b : @& Level m) : Bool := + ptrAddrUnsafe a == ptrAddrUnsafe b + +@[implemented_by Level.ptrEqUnsafe] +opaque Level.ptrEq : @& Level m → @& Level m → Bool + +/-- Structural ordering on universe levels. Pointer-equal levels short-circuit to .eq. 
-/ +partial def Level.compare (a b : Level m) : Ordering := + if Level.ptrEq a b then .eq + else match a, b with + | .zero, .zero => .eq + | .succ l₁, .succ l₂ => Level.compare l₁ l₂ + | .max a₁ a₂, .max b₁ b₂ => + match Level.compare a₁ b₁ with | .eq => Level.compare a₂ b₂ | o => o + | .imax a₁ a₂, .imax b₁ b₂ => + match Level.compare a₁ b₁ with | .eq => Level.compare a₂ b₂ | o => o + | .param i₁ _, .param i₂ _ => Ord.compare i₁ i₂ + | _, _ => Ord.compare a.tag b.tag + +private def Level.compareArray (a b : Array (Level m)) : Ordering := Id.run do + match Ord.compare a.size b.size with + | .eq => + for i in [:a.size] do + match Level.compare a[i]! b[i]! with + | .eq => continue + | o => return o + return .eq + | o => return o + +/-- Numeric tag for Expr constructors, used for ordering. -/ +private def Expr.tag' : Expr m → UInt8 + | .bvar .. => 0 + | .sort .. => 1 + | .const .. => 2 + | .app .. => 3 + | .lam .. => 4 + | .forallE .. => 5 + | .letE .. => 6 + | .lit .. => 7 + | .proj .. => 8 + +/-- Pointer equality check for Exprs (O(1) fast path). -/ +private unsafe def Expr.ptrEqUnsafe (a : @& Expr m) (b : @& Expr m) : Bool := + ptrAddrUnsafe a == ptrAddrUnsafe b + +@[implemented_by Expr.ptrEqUnsafe] +opaque Expr.ptrEq : @& Expr m → @& Expr m → Bool + +/-- Fully iterative structural ordering on expressions using an explicit worklist. + Pointer-equal exprs short-circuit to .eq. Never recurses — uses a stack of + pending comparison pairs to avoid call-stack overflow on huge expressions. 
-/ +partial def Expr.compare (a b : Expr m) : Ordering := Id.run do + let mut stack : Array (Expr m × Expr m) := #[(a, b)] + while h : stack.size > 0 do + let (e1, e2) := stack[stack.size - 1] + stack := stack.pop + if Expr.ptrEq e1 e2 then continue + -- Flatten binder chains + let mut ca := e1; let mut cb := e2 + repeat + match ca, cb with + | .lam ty1 body1 _ _, .lam ty2 body2 _ _ => + stack := stack.push (ty1, ty2); ca := body1; cb := body2 + | .forallE ty1 body1 _ _, .forallE ty2 body2 _ _ => + stack := stack.push (ty1, ty2); ca := body1; cb := body2 + | .letE ty1 val1 body1 _, .letE ty2 val2 body2 _ => + stack := stack.push (ty1, ty2); stack := stack.push (val1, val2) + ca := body1; cb := body2 + | _, _ => break + -- Flatten app spines, then push heads back for further processing + match ca, cb with + | .app .., .app .. => + let mut f1 := ca; let mut f2 := cb + repeat match f1, f2 with + | .app fn1 arg1, .app fn2 arg2 => + stack := stack.push (arg1, arg2); f1 := fn1; f2 := fn2 + | _, _ => break + -- Push heads back onto stack so binder/leaf handling runs on them + stack := stack.push (f1, f2) + continue + | _, _ => pure () + -- Compare leaf nodes (non-binder, non-app) + match ca, cb with + | .bvar i1 _, .bvar i2 _ => + match Ord.compare i1 i2 with | .eq => pure () | o => return o + | .sort l1, .sort l2 => + match Level.compare l1 l2 with | .eq => pure () | o => return o + | .const a1 ls1 _, .const a2 ls2 _ => + match Ord.compare a1 a2 with | .eq => pure () | o => return o + match Level.compareArray ls1 ls2 with | .eq => pure () | o => return o + | .lit l1, .lit l2 => + let o := match l1, l2 with + | .natVal n1, .natVal n2 => Ord.compare n1 n2 + | .natVal _, .strVal _ => .lt + | .strVal _, .natVal _ => .gt + | .strVal s1, .strVal s2 => Ord.compare s1 s2 + match o with | .eq => pure () | o => return o + | .proj a1 i1 s1 _, .proj a2 i2 s2 _ => + match Ord.compare a1 a2 with | .eq => pure () | o => return o + match Ord.compare i1 i2 with | .eq => pure () | o => 
return o + stack := stack.push (s1, s2) + | _, _ => + match Ord.compare ca.tag' cb.tag' with | .eq => pure () | o => return o + return .eq + +/-- Pointer-based comparison for expressions. + Structurally-equal expressions at different addresses are considered distinct. + This is fine for def-eq failure caches (we just get occasional misses). + Lean 4 uses refcounting (no moving GC), so addresses are stable. -/ +private unsafe def Expr.ptrCompareUnsafe (a : @& Expr m) (b : @& Expr m) : Ordering := + Ord.compare (ptrAddrUnsafe a) (ptrAddrUnsafe b) + +@[implemented_by Expr.ptrCompareUnsafe] +opaque Expr.ptrCompare : @& Expr m → @& Expr m → Ordering + +/-- Compare pairs of expressions by pointer address (first component, then second). -/ +def Expr.pairCompare (a b : Expr m × Expr m) : Ordering := + match Expr.ptrCompare a.1 b.1 with + | .eq => Expr.ptrCompare a.2 b.2 + | ord => ord /-! ## Enums -/ From c3f16c653d48100929c80523ba807a971d4539e2 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Thu, 5 Mar 2026 13:48:11 -0500 Subject: [PATCH 11/14] Unify recursion depth tracking, move caches before fuel guards, and iterativize isDefEq MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace per-function whnfDepth with unified withRecDepthCheck (limit 2000) across isDefEq, infer, and whnf — simpler and more predictable stack overflow prevention - Move cache lookups (inferCache, whnfCache, whnfCoreCache) before withFuelCheck/withRecDepthCheck so cache hits incur zero fuel or stack cost - Iterativize isDefEq main loop: steps 1-5 now loop via continue instead of recursing back into isDefEq when whnfCore(cheapProj=false) changes terms - Iterativize quickIsDefEq lam/forallE binder chains to avoid deep recursion on nested binders - Add pointer equality fast path to Expr.beq; move Expr.ptrEq decl earlier - Skip context check for closed expressions (const/sort/lit) in inferCache - Add Expr.nodeCount, trace parameter to typecheckConst - 
Add Std.Time.* dependency chain test constants for _sunfold regression --- Ix/Kernel/Infer.lean | 161 ++++++++++++++++++++++++-------------- Ix/Kernel/TypecheckM.lean | 17 +++- Ix/Kernel/Types.lean | 15 ++-- Ix/Kernel/Whnf.lean | 18 ++--- Tests/Ix/KernelTests.lean | 35 +++++++++ 5 files changed, 167 insertions(+), 79 deletions(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 546a3a9b..eedf0702 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -87,7 +87,7 @@ def infoFromType (typ : Expr m) : TypecheckM m (TypeInfo m) := do mutual /-- Check that a term has a given type. -/ partial def check (term : Expr m) (expectedType : Expr m) : TypecheckM m (TypedExpr m) := do - if (← read).trace then dbg_trace s!"check: {term.tag}" + -- if (← read).trace then dbg_trace s!"check: {term.tag}" let (te, inferredType) ← infer term if !(← isDefEq inferredType expectedType) then let ppInferred := inferredType.pp @@ -96,15 +96,21 @@ mutual pure te /-- Infer the type of an expression, returning the typed expression and its type. -/ - partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × Expr m) := withFuelCheck do - -- Check infer cache: keyed on Expr, context verified on retrieval + partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × Expr m) := do + -- Check infer cache FIRST — no fuel or stack cost for cache hits let types := (← read).types if let some (cachedCtx, cachedType) := (← get).inferCache.get? term then -- Ptr equality first, structural BEq fallback - if unsafe ptrAddrUnsafe cachedCtx == ptrAddrUnsafe types || cachedCtx == types then + -- For consts/sorts/lits, context doesn't matter (always closed) + let contextOk := match term with + | .const .. | .sort .. | .lit .. 
=> true + | _ => unsafe ptrAddrUnsafe cachedCtx == ptrAddrUnsafe types || cachedCtx == types + if contextOk then let te : TypedExpr m := ⟨← infoFromType cachedType, term⟩ return (te, cachedType) - if (← read).trace then dbg_trace s!"infer: {term.tag}" + withRecDepthCheck do + withFuelCheck do + -- if (← read).trace then dbg_trace s!"infer: {term.tag}" let result ← do match term with | .bvar idx bvarName => do let ctx ← read @@ -334,7 +340,9 @@ mutual inferCache := {}, eqvManager := {}, failureCache := {}, - fuel := defaultFuel + fuel := defaultFuel, + recDepth := 0, + maxRecDepth := 0 } -- Skip if already in typedConsts if (← get).typedConsts.get? addr |>.isSome then @@ -613,14 +621,26 @@ mutual if a == b && equalUnivArrays us us' then pure (some true) else pure none | .lit l, .lit l' => pure (some (l == l')) | .bvar i _, .bvar j _ => pure (some (i == j)) - | .lam ty body _ _, .lam ty' body' _ _ => - match ← quickIsDefEq ty ty' with - | some true => quickIsDefEq body body' - | other => pure other - | .forallE ty body _ _, .forallE ty' body' _ _ => - match ← quickIsDefEq ty ty' with - | some true => quickIsDefEq body body' - | other => pure other + | .lam .., .lam .. => do + let mut a := t; let mut b := s + repeat + match a, b with + | .lam ty body _ _, .lam ty' body' _ _ => + match ← quickIsDefEq ty ty' with + | some true => a := body; b := body' + | other => return other + | _, _ => break + quickIsDefEq a b + | .forallE .., .forallE .. => do + let mut a := t; let mut b := s + repeat + match a, b with + | .forallE ty body _ _, .forallE ty' body' _ _ => + match ← quickIsDefEq ty ty' with + | some true => a := body; b := body' + | other => return other + | _, _ => break + quickIsDefEq a b | _, _ => pure none /-- Check if two expressions are definitionally equal. @@ -630,60 +650,66 @@ mutual 3. Lazy delta reduction — unfold definitions one step at a time 4. whnfCore(cheapProj=false) — full projection resolution (only if needed) 5. 
Structural comparison -/ - partial def isDefEq (t s : Expr m) : TypecheckM m Bool := withFuelCheck do - -- 0. Quick structural check (avoids WHNF for trivially equal/unequal terms) + partial def isDefEq (t s : Expr m) : TypecheckM m Bool := do + -- 0. Quick structural check FIRST — no fuel/stack cost for trivial cases match ← quickIsDefEq t s with | some result => return result | none => pure () + withRecDepthCheck do + withFuelCheck do - -- 1. Stage 1: structural reduction (cheapProj=true: defer full projection resolution) - let tn ← whnfCore t (cheapProj := true) - let sn ← whnfCore s (cheapProj := true) + -- Loop: steps 1-5 may restart when whnfCore(cheapProj=false) changes terms + let mut ct := t + let mut cs := s + repeat + -- 1. Stage 1: structural reduction (cheapProj=true: defer full projection resolution) + let tn ← whnfCore ct (cheapProj := true) + let sn ← whnfCore cs (cheapProj := true) - -- 2. Quick check after whnfCore (useHash=false for thorough union-find walking) - match ← quickIsDefEq tn sn (useHash := false) with - | some true => cacheResult t s true; return true - | some false => pure () -- don't cache — deeper checks may still succeed - | none => pure () + -- 2. Quick check after whnfCore (useHash=false for thorough union-find walking) + match ← quickIsDefEq tn sn (useHash := false) with + | some true => cacheResult t s true; return true + | some false => pure () -- don't cache — deeper checks may still succeed + | none => pure () - -- 3. Proof irrelevance - match ← isDefEqProofIrrel tn sn with - | some result => - cacheResult t s result - return result - | none => pure () + -- 3. Proof irrelevance + match ← isDefEqProofIrrel tn sn with + | some result => + cacheResult t s result + return result + | none => pure () - -- 4. Lazy delta reduction (incremental unfolding) - let (tn', sn', deltaResult) ← lazyDeltaReduction tn sn - if deltaResult == some true then - cacheResult t s true - return true + -- 4. 
Lazy delta reduction (incremental unfolding) + let (tn', sn', deltaResult) ← lazyDeltaReduction tn sn + if deltaResult == some true then + cacheResult t s true + return true - -- 4b. Cheap structural checks after lazy delta (before full whnfCore) - match tn', sn' with - | .const a us _, .const b us' _ => - if a == b && equalUnivArrays us us' then - cacheResult t s true; return true - | .proj _ ti te _, .proj _ si se _ => - if ti == si then - if ← isDefEq te se then + -- 4b. Cheap structural checks after lazy delta (before full whnfCore) + match tn', sn' with + | .const a us _, .const b us' _ => + if a == b && equalUnivArrays us us' then cacheResult t s true; return true - | _, _ => pure () - - -- 5. Stage 2: full structural reduction (no cheapProj — resolve all projections) - let tnn ← whnfCore tn' - let snn ← whnfCore sn' - -- Only recurse into isDefEqCore if something actually changed - if !(tnn == tn' && snn == sn') then - let result ← isDefEq tnn snn + | .proj _ ti te _, .proj _ si se _ => + if ti == si then + if ← isDefEq te se then + cacheResult t s true; return true + | _, _ => pure () + + -- 5. Stage 2: full structural reduction (no cheapProj — resolve all projections) + let tnn ← whnfCore tn' + let snn ← whnfCore sn' + -- If terms changed, loop back to step 1 instead of recursing into isDefEq + if !(tnn == tn' && snn == sn') then + ct := tnn; cs := snn; continue + + -- 6. Structural comparison on fully-reduced terms + let result ← isDefEqCore tnn snn cacheResult t s result return result - -- 6. Structural comparison on fully-reduced terms - let result ← isDefEqCore tnn snn - - cacheResult t s result - return result + -- unreachable, but needed for type checking + return false /-- Check if both terms are proofs of the same Prop type (proof irrelevance). Returns `none` if inference fails (e.g., free bound variables) or the type isn't Prop. -/ @@ -985,15 +1011,34 @@ mutual end -- mutual +/-! 
## Expr size -/ + +/-- Count the number of nodes in an expression (iterative). -/ +partial def Expr.nodeCount (e : Expr m) : Nat := Id.run do + let mut stack : Array (Expr m) := #[e] + let mut count : Nat := 0 + while h : stack.size > 0 do + let cur := stack[stack.size - 1] + stack := stack.pop + count := count + 1 + match cur with + | .app fn arg => stack := stack.push fn |>.push arg + | .lam ty body _ _ => stack := stack.push ty |>.push body + | .forallE ty body _ _ => stack := stack.push ty |>.push body + | .letE ty val body _ => stack := stack.push ty |>.push val |>.push body + | .proj _ _ s _ => stack := stack.push s + | _ => pure () + return count + /-! ## Top-level entry points -/ /-- Typecheck a single constant by address. -/ def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address) - (quotInit : Bool := true) : Except String Unit := + (quotInit : Bool := true) (trace : Bool := false) : Except String Unit := let ctx : TypecheckCtx m := { types := #[], kenv := kenv, prims := prims, safety := .safe, quotInit := quotInit, - mutTypes := default, recAddr? := none + mutTypes := default, recAddr? := none, trace := trace } let stt : TypecheckState m := { typedConsts := default } let (result, _) := TypecheckM.run ctx stt (checkConst addr) diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index c9428245..7f4d078b 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -54,9 +54,6 @@ structure TypecheckState (m : MetaMode) where failureCache : Std.TreeMap (Expr m × Expr m) Unit Expr.pairCompare := {} constTypeCache : Std.TreeMap Address (Array (Level m) × Expr m) Address.compare := {} fuel : Nat := defaultFuel - /-- Tracks nesting depth of whnf calls from within recursor reduction (tryReduceApp → whnf). - When this exceeds a threshold, whnfCore is used instead of whnf to prevent stack overflow. -/ - whnfDepth : Nat := 0 /-- Global recursion depth across isDefEq/infer/whnf for stack overflow prevention. 
-/ recDepth : Nat := 0 maxRecDepth : Nat := 0 @@ -102,6 +99,20 @@ def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do modify fun s => { s with fuel := s.fuel - 1 } action +/-- Maximum recursion depth for the mutual isDefEq/whnf/infer cycle. + Prevents native stack overflow. Hard error when exceeded. -/ +def maxRecursionDepth : Nat := 2000 + +/-- Check and increment recursion depth. Throws on exceeding limit. -/ +def withRecDepthCheck (action : TypecheckM m α) : TypecheckM m α := do + let d := (← get).recDepth + if d >= maxRecursionDepth then + throw s!"maximum recursion depth ({maxRecursionDepth}) exceeded" + modify fun s => { s with recDepth := d + 1, maxRecDepth := max s.maxRecDepth (d + 1) } + let r ← action + modify fun s => { s with recDepth := d } + pure r + /-! ## Name lookup -/ /-- Look up the MetaField name for a constant address from the kernel environment. -/ diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index ed4d07f6..3d176d8d 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -73,10 +73,18 @@ inductive Expr (m : MetaMode) where (typeName : MetaField m Ix.Name) deriving Inhabited +/-- Pointer equality check for Exprs (O(1) fast path). -/ +private unsafe def Expr.ptrEqUnsafe (a : @& Expr m) (b : @& Expr m) : Bool := + ptrAddrUnsafe a == ptrAddrUnsafe b + +@[implemented_by Expr.ptrEqUnsafe] +opaque Expr.ptrEq : @& Expr m → @& Expr m → Bool + /-- Structural equality for Expr, iterating over binder body spines to avoid stack overflow on deeply nested let/lam/forallE chains. -/ partial def Expr.beq : Expr m → Expr m → Bool := go where go (a b : Expr m) : Bool := Id.run do + if Expr.ptrEq a b then return true let mut ca := a; let mut cb := b repeat match ca, cb with @@ -532,13 +540,6 @@ private def Expr.tag' : Expr m → UInt8 | .lit .. => 7 | .proj .. => 8 -/-- Pointer equality check for Exprs (O(1) fast path). 
-/ -private unsafe def Expr.ptrEqUnsafe (a : @& Expr m) (b : @& Expr m) : Bool := - ptrAddrUnsafe a == ptrAddrUnsafe b - -@[implemented_by Expr.ptrEqUnsafe] -opaque Expr.ptrEq : @& Expr m → @& Expr m → Bool - /-- Fully iterative structural ordering on expressions using an explicit worklist. Pointer-equal exprs short-circuit to .eq. Never recurses — uses a stack of pending comparison pairs to avoid call-stack overflow on huge expressions. -/ diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean index 21bc566d..e086230e 100644 --- a/Ix/Kernel/Whnf.lean +++ b/Ix/Kernel/Whnf.lean @@ -125,7 +125,7 @@ mutual When cheapRec=true, recursor applications are returned as-is (no iota reduction). -/ partial def whnfCore (e : Expr m) (cheapRec := false) (cheapProj := false) : TypecheckM m (Expr m) := do - -- Cache lookup (only for full structural reduction, not cheap) + -- Cache check FIRST — no stack cost for cache hits let useCache := !cheapRec && !cheapProj if useCache then if let some r := (← get).whnfCoreCache.get? e then return r @@ -367,18 +367,14 @@ mutual then resolve projections iteratively from inside out. Tracks nesting depth: when whnf calls nest too deep (from isDefEq ↔ whnf cycles), degrades to whnfCore to prevent native stack overflow. -/ - partial def whnf (e : Expr m) : TypecheckM m (Expr m) := withFuelCheck do - -- Depth guard: when whnf nesting is too deep, degrade to structural-only - let depth := (← get).whnfDepth - if depth > 64 then return ← whnfCore e - modify fun s => { s with whnfDepth := s.whnfDepth + 1 } - let r ← whnfImpl e - modify fun s => { s with whnfDepth := s.whnfDepth - 1 } - pure r + partial def whnf (e : Expr m) : TypecheckM m (Expr m) := do + -- Cache check FIRST — no fuel or stack cost for cache hits + if let some r := (← get).whnfCache.get? e then return r + withRecDepthCheck do + withFuelCheck do + whnfImpl e partial def whnfImpl (e : Expr m) : TypecheckM m (Expr m) := do - -- Check cache - if let some r := (← get).whnfCache.get? 
e then return r let mut t ← whnfCore e let mut steps := 0 repeat diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index d037f3a7..ecb2aee1 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -148,6 +148,33 @@ def testConsts : TestSeq := "UInt64.decLt", -- Recursor-only Ixon block regression (rec.all was empty) "Lean.Elab.Tactic.RCases.RCasesPatt.rec_1", + -- Dependencies of _sunfold (check these first to rule out lazy blowup) + "Std.Time.FormatPart", + "Std.Time.FormatConfig", + "Std.Time.FormatString", + "Std.Time.FormatType", + "Std.Time.FormatType.match_1", + "Std.Time.TypeFormat", + "Std.Time.Modifier", + "List.below", + "List.brecOn", + "Std.Internal.Parsec.String.Parser", + "Std.Internal.Parsec.instMonad", + "Std.Internal.Parsec.instAlternative", + "Std.Internal.Parsec.String.skipString", + "Std.Internal.Parsec.eof", + "Std.Internal.Parsec.fail", + "Bind.bind", + "Monad.toBind", + "SeqRight.seqRight", + "Applicative.toSeqRight", + "Applicative.toPure", + "Alternative.toApplicative", + "Pure.pure", + "_private.Std.Time.Format.Basic.«0».Std.Time.parseWith", + "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go.match_3", + "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go.match_1", + "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go", -- Deeply nested let chain (stack overflow regression) "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go._sunfold" ] @@ -160,6 +187,14 @@ def testConsts : TestSeq := let addr := cNamed.addr IO.println s!" checking {name} ..." (← IO.getStdout).flush + -- if name.containsSubstr "builderParser" then + -- if let some ci := kenv.find? addr then + -- let safety := match ci with | .defnInfo v => s!"{repr v.safety}" | _ => "n/a" + -- IO.println s!" [{name}] kind={ci.kindName} safety={safety}" + -- IO.println s!" type: {ci.type.pp}" + -- if let some val := ci.value? then + -- IO.println s!" 
value ({val.nodeCount} nodes): {val.pp}" + -- (← IO.getStdout).flush let start ← IO.monoMsNow match Ix.Kernel.typecheckConst kenv prims addr quotInit with | .ok () => From 3f08273702064171731bc8a0ee30a8ea3fc4fb0f Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Thu, 5 Mar 2026 18:30:08 -0500 Subject: [PATCH 12/14] Add let-bound bvar zeta-reduction, fix proof irrelevance and K-reduction - Track letValues/numLetBindings in TypecheckCtx so whnfCore can zeta-reduce let-bound bvars by looking up stored values - Thread let context through iterativized binder chains in infer (lam, forallE, letE) and isDefEqCore (lam/pi flattening, eta) - Add isProp that checks type_of(type_of(t)) == Sort 0 and rewrite isDefEqProofIrrel to use it with withInferOnly - Fix K-reduction to apply extra args after major premise - Add cheapBetaReduce for let body result types - Whnf nat primitive args when they aren't already literals - Skip whnf/whnfCore caches when let bindings are in scope - Increase maxRecursionDepth to 10000 --- Ix/Kernel/Infer.lean | 78 ++++++++++++++++++++++++++------------- Ix/Kernel/TypecheckM.lean | 24 ++++++++++-- Ix/Kernel/Types.lean | 22 +++++++++++ Ix/Kernel/Whnf.lean | 55 +++++++++++++++++++++++---- Tests/Ix/KernelTests.lean | 10 ++++- 5 files changed, 151 insertions(+), 38 deletions(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index eedf0702..ffc62d97 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -169,17 +169,19 @@ mutual let inferOnly := (← read).inferOnly let mut cur := term let mut extTypes := (← read).types + let mut extLetValues := (← read).letValues let mut binderMeta : Array (Expr m × Expr m × MetaField m Ix.Name × MetaField m Lean.BinderInfo) := #[] repeat match cur with | .lam ty body lamName lamBi => let domBody ← if inferOnly then pure ty - else do let (te, _) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort ty); pure te.body + else do let (te, _) ← withReader (fun ctx => { ctx with types := 
extTypes, letValues := extLetValues }) (isSort ty); pure te.body binderMeta := binderMeta.push (domBody, ty, lamName, lamBi) extTypes := extTypes.push ty + extLetValues := extLetValues.push none cur := body | _ => break - let (bodTe, imgType) ← withReader (fun ctx => { ctx with types := extTypes }) (infer cur) + let (bodTe, imgType) ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (infer cur) let mut resultType := imgType let mut resultBody := bodTe.body let mut resultInfo := bodTe.info @@ -194,16 +196,18 @@ mutual -- Iterate forallE chain to avoid O(n) stack depth let mut cur := term let mut extTypes := (← read).types + let mut extLetValues := (← read).letValues let mut binderMeta : Array (Expr m × Level m × MetaField m Ix.Name) := #[] repeat match cur with | .forallE ty body piName _ => - let (domTe, domLvl) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort ty) + let (domTe, domLvl) ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isSort ty) binderMeta := binderMeta.push (domTe.body, domLvl, piName) extTypes := extTypes.push ty + extLetValues := extLetValues.push none cur := body | _ => break - let (imgTe, imgLvl) ← withReader (fun ctx => { ctx with types := extTypes }) (isSort cur) + let (imgTe, imgLvl) ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isSort cur) let mut resultLvl := imgLvl let mut resultBody := imgTe.body for i in [:binderMeta.size] do @@ -218,6 +222,8 @@ mutual let inferOnly := (← read).inferOnly let mut cur := term let mut extTypes := (← read).types + let mut extLetValues := (← read).letValues + let mut extNumLets := (← read).numLetBindings let mut binderInfo : Array (Expr m × Expr m × Expr m × MetaField m Ix.Name) := #[] repeat match cur with @@ -225,14 +231,16 @@ mutual if inferOnly then binderInfo := binderInfo.push (ty, val, val, letName) else - let (tyTe, _) ← withReader (fun ctx => { ctx with types := extTypes }) 
(isSort ty) - let valTe ← withReader (fun ctx => { ctx with types := extTypes }) (check val ty) + let (tyTe, _) ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues, numLetBindings := extNumLets }) (isSort ty) + let valTe ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues, numLetBindings := extNumLets }) (check val ty) binderInfo := binderInfo.push (tyTe.body, valTe.body, val, letName) extTypes := extTypes.push ty + extLetValues := extLetValues.push (some val) + extNumLets := extNumLets + 1 cur := body | _ => break - let (bodTe, bodType) ← withReader (fun ctx => { ctx with types := extTypes }) (infer cur) - let mut resultType := bodType + let (bodTe, bodType) ← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues, numLetBindings := extNumLets }) (infer cur) + let mut resultType := bodType.cheapBetaReduce let mut resultBody := bodTe.body for i in [:binderInfo.size] do let j := binderInfo.size - 1 - i @@ -657,6 +665,8 @@ mutual | none => pure () withRecDepthCheck do withFuelCheck do + let depth := (← get).recDepth + -- Temporarily removed for call-site tracing -- Loop: steps 1-5 may restart when whnfCore(cheapProj=false) changes terms let mut ct := t @@ -711,19 +721,27 @@ mutual -- unreachable, but needed for type checking return false + /-- Check if e lives in Prop: type_of(e) reduces to Sort 0. + Matches lean4lean's `isProp`. -/ + partial def isProp (e : Expr m) : TypecheckM m Bool := do + let (_, ty) ← withInferOnly (infer e) + let ty' ← whnf ty + return ty' == .sort .zero + /-- Check if both terms are proofs of the same Prop type (proof irrelevance). Returns `none` if inference fails (e.g., free bound variables) or the type isn't Prop. 
-/ partial def isDefEqProofIrrel (t s : Expr m) : TypecheckM m (Option Bool) := do - let tType ← try let (_, ty) ← infer t; pure (some ty) catch _ => pure none + let tType ← try let (_, ty) ← withInferOnly (infer t); pure (some ty) catch _ => pure none let some tType := tType | return none - let tType' ← whnf tType - match tType' with - | .sort .zero => - let sType ← try let (_, ty) ← infer s; pure (some ty) catch _ => pure none - let some sType := sType | return none - let result ← isDefEq tType sType - return some result - | _ => return none + let isPropType ← try isProp tType catch e => do + if (← get).recDepth > 100 then + dbg_trace s!"isProp FAILED at depth {(← get).recDepth}: {e}" + pure false + if !isPropType then return none + let sType ← try let (_, ty) ← withInferOnly (infer s); pure (some ty) catch _ => pure none + let some sType := sType | return none + let result ← isDefEq tType sType + return some result /-- Core structural comparison after whnf. -/ partial def isDefEqCore (t s : Expr m) : TypecheckM m Bool := do @@ -739,28 +757,38 @@ mutual pure (a == b && equalUnivArrays us us') -- Lambda: flatten binder chain to avoid O(num_binders) stack depth + -- Extend context at each binder so proof irrelevance / infer work on bodies | .lam .., .lam .. 
=> do let mut a := t let mut b := s + let mut extTypes := (← read).types + let mut extLetValues := (← read).letValues repeat match a, b with | .lam ty body _ _, .lam ty' body' _ _ => - if !(← isDefEq ty ty') then return false + if !(← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isDefEq ty ty')) then return false + extTypes := extTypes.push ty + extLetValues := extLetValues.push none a := body; b := body' | _, _ => break - isDefEq a b + withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isDefEq a b) -- Pi/ForallE: flatten binder chain to avoid O(num_binders) stack depth + -- Extend context at each binder so proof irrelevance / infer work on bodies | .forallE .., .forallE .. => do let mut a := t let mut b := s + let mut extTypes := (← read).types + let mut extLetValues := (← read).letValues repeat match a, b with | .forallE ty body _ _, .forallE ty' body' _ _ => - if !(← isDefEq ty ty') then return false + if !(← withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isDefEq ty ty')) then return false + extTypes := extTypes.push ty + extLetValues := extLetValues.push none a := body; b := body' | _, _ => break - isDefEq a b + withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isDefEq a b) -- Application: flatten app spine to avoid O(num_args) stack depth | .app .., .app .. 
=> do @@ -787,13 +815,13 @@ mutual -- eta: (\x => body) =?= s iff body =?= s x where x = bvar 0 let sLifted := s.liftBVars 1 let sApp := Expr.mkApp sLifted (Expr.mkBVar 0) - isDefEq body sApp + withExtendedCtx ty (isDefEq body sApp) | _, .lam ty body _ _ => do -- eta: t =?= (\x => body) iff t x =?= body let tLifted := t.liftBVars 1 let tApp := Expr.mkApp tLifted (Expr.mkBVar 0) - isDefEq tApp body + withExtendedCtx ty (isDefEq tApp body) -- Nat literal vs constructor expansion | .lit (.natVal _), _ => do @@ -830,7 +858,7 @@ mutual two terms are defeq if their types are defeq. -/ partial def isDefEqUnitLike (t s : Expr m) : TypecheckM m Bool := do let kenv := (← read).kenv - let (_, tType) ← infer t + let (_, tType) ← withInferOnly (infer t) let tType' ← whnf tType let fn := tType'.getAppFn match fn with @@ -841,7 +869,7 @@ mutual match kenv.find? v.ctors[0]! with | some (.ctorInfo cv) => if cv.numFields != 0 then return false - let (_, sType) ← infer s + let (_, sType) ← withInferOnly (infer s) isDefEq tType sType | _ => return false | _ => return false diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 7f4d078b..94172711 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -23,6 +23,12 @@ structure TypecheckCtx (m : MetaMode) where /-- Type of each bound variable, indexed by de Bruijn index. types[0] is the type of bvar 0 (most recently bound). -/ types : Array (Expr m) + /-- Let-bound values parallel to `types`. `letValues[i] = some val` means the + binding at position `i` was introduced by a `letE` with value `val`. + `none` means it was introduced by a lambda/forall binder. -/ + letValues : Array (Option (Expr m)) := #[] + /-- Number of let bindings currently in scope (for cache gating). 
-/ + numLetBindings : Nat := 0 kenv : Env m prims : Primitives safety : DefinitionSafety @@ -73,15 +79,25 @@ def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) def withResetCtx : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with - types := #[], mutTypes := default, recAddr? := none } + types := #[], letValues := #[], numLetBindings := 0, + mutTypes := default, recAddr? := none } def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → Expr m)) compare) : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with mutTypes := mutTypes } -/-- Extend the context with a new bound variable of the given type. -/ +/-- Extend the context with a new bound variable of the given type (lambda/forall). -/ def withExtendedCtx (varType : Expr m) : TypecheckM m α → TypecheckM m α := - withReader fun ctx => { ctx with types := ctx.types.push varType } + withReader fun ctx => { ctx with + types := ctx.types.push varType, + letValues := ctx.letValues.push none } + +/-- Extend the context with a let-bound variable (stores both type and value for zeta-reduction). -/ +def withExtendedLetCtx (varType : Expr m) (val : Expr m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with + types := ctx.types.push varType, + letValues := ctx.letValues.push (some val), + numLetBindings := ctx.numLetBindings + 1 } def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with recAddr? := some addr } @@ -101,7 +117,7 @@ def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do /-- Maximum recursion depth for the mutual isDefEq/whnf/infer cycle. Prevents native stack overflow. Hard error when exceeded. -/ -def maxRecursionDepth : Nat := 2000 +def maxRecursionDepth : Nat := 10000 /-- Check and increment recursion depth. Throws on exceeding limit. 
-/ def withRecDepthCheck (action : TypecheckM m α) : TypecheckM m α := do diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index 3d176d8d..8b2a90d5 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -388,6 +388,28 @@ where /-- Single substitution: replace bvar 0 with val. -/ def instantiate1 (body val : Expr m) : Expr m := body.instantiate #[val] +/-- Cheap beta reduction: if `e` is `(fun x₁ ... xₙ => body) a₁ ... aₘ`, and `body` is + either a bvar or has no loose bvars, substitute without a full traversal. + Matches lean4lean's `Expr.cheapBetaReduce`. -/ +def cheapBetaReduce (e : Expr m) : Expr m := Id.run do + let fn := e.getAppFn + match fn with + | .lam .. => pure () + | _ => return e + let args := e.getAppArgs + -- Walk lambda binders, counting how many args we can consume + let mut cur := fn + let mut i : Nat := 0 + repeat + if i >= args.size then break + match cur with + | .lam _ body _ _ => cur := body; i := i + 1 + | _ => break + -- cur is the lambda body after consuming i args; substitute + if i == 0 then return e + let body := cur.instantiate (args[:i].toArray.reverse) + return body.mkAppRange i args.size args + /-- Substitute universe level params in an expression's Level nodes using a given level substitution function. -/ partial def instantiateLevelParamsBy (e : Expr m) (substFn : Level m → Level m) : Expr m := diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean index e086230e..f10b1c70 100644 --- a/Ix/Kernel/Whnf.lean +++ b/Ix/Kernel/Whnf.lean @@ -126,7 +126,8 @@ mutual partial def whnfCore (e : Expr m) (cheapRec := false) (cheapProj := false) : TypecheckM m (Expr m) := do -- Cache check FIRST — no stack cost for cache hits - let useCache := !cheapRec && !cheapProj + -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) + let useCache := !cheapRec && !cheapProj && (← read).numLetBindings == 0 if useCache then if let some r := (← get).whnfCoreCache.get? 
e then return r let r ← whnfCoreImpl e cheapRec cheapProj @@ -169,6 +170,17 @@ mutual let r ← tryReduceApp e' if r == e' then return r -- stuck, return t := r; continue -- iota/quot reduced, loop to re-process + | .bvar idx _ => do + -- Zeta-reduce let-bound bvars: look up the stored value and substitute + let ctx ← read + let depth := ctx.types.size + if idx < depth then + let arrayIdx := depth - 1 - idx + if h : arrayIdx < ctx.letValues.size then + if let some val := ctx.letValues[arrayIdx] then + -- Shift free bvars in val past the intermediate binders + t := val.liftBVars (idx + 1); continue + return t | .letE _ val body _ => t := body.instantiate1 val; continue -- loop instead of recursion | .proj typeAddr idx struct _ => do @@ -197,7 +209,7 @@ mutual let major := args[majorIdx] let major' ← whnf major if isK then - tryKReduction e addr args major' params motives indAddr + tryKReduction e addr args major' params motives minors indices indAddr else tryIotaReduction e addr args major' params indices indAddr rules motives minors else pure e @@ -209,9 +221,10 @@ mutual | _ => pure e | _ => pure e - /-- K-reduction: for Prop inductives with single zero-field constructor. -/ + /-- K-reduction: for Prop inductives with single zero-field constructor. + Returns the (only) minor premise, plus any extra args after the major. 
-/ partial def tryKReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) - (major : Expr m) (params motives : Nat) (indAddr : Address) + (major : Expr m) (params motives minors indices : Nat) (indAddr : Address) : TypecheckM m (Expr m) := do let ctx ← read let prims := ctx.prims @@ -237,7 +250,12 @@ mutual -- K-reduction: return the (only) minor premise let minorIdx := params + motives if h : minorIdx < args.size then - return args[minorIdx] + let mut result := args[minorIdx] + -- Apply extra args after major premise (matching lean4 kernel behavior) + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + return result pure e else pure e @@ -369,10 +387,16 @@ mutual degrades to whnfCore to prevent native stack overflow. -/ partial def whnf (e : Expr m) : TypecheckM m (Expr m) := do -- Cache check FIRST — no fuel or stack cost for cache hits - if let some r := (← get).whnfCache.get? e then return r + -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) + let useWhnfCache := (← read).numLetBindings == 0 + if useWhnfCache then + if let some r := (← get).whnfCache.get? 
e then return r withRecDepthCheck do withFuelCheck do - whnfImpl e + let r ← whnfImpl e + if useWhnfCache then + modify fun s => { s with whnfCache := s.whnfCache.insert e r } + pure r partial def whnfImpl (e : Expr m) : TypecheckM m (Expr m) := do let mut t ← whnfCore e @@ -382,6 +406,22 @@ mutual -- Try nat primitive reduction if let some r := ← tryReduceNat t then t ← whnfCore r; steps := steps + 1; continue + -- If head is a nat primitive but args aren't literals, whnf args and retry + match t.getAppFn with + | .const addr _ _ => + if isPrimOp (← read).prims addr then + let args := t.getAppArgs + let mut changed := false + let mut newArgs : Array (Expr m) := #[] + for arg in args do + let arg' ← whnf arg + newArgs := newArgs.push arg' + if arg' != arg then changed := true + if changed then + let t' := t.getAppFn.mkAppN newArgs + if let some r := ← tryReduceNat t' then + t ← whnfCore r; steps := steps + 1; continue + | _ => pure () -- Handle stuck projections (including inside app chains). -- Flatten nested projection chains to avoid deep whnf→whnf recursion. match t.getAppFn with @@ -432,7 +472,6 @@ mutual if let some r := ← unfoldDefinition t then t ← whnfCore r; steps := steps + 1; continue break - modify fun s => { s with whnfCache := s.whnfCache.insert e t } pure t /-- Unfold a single delta step (definition body). 
-/ diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index ecb2aee1..7dab1364 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -176,7 +176,15 @@ def testConsts : TestSeq := "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go.match_1", "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go", -- Deeply nested let chain (stack overflow regression) - "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go._sunfold" + "_private.Std.Time.Format.Basic.«0».Std.Time.GenericFormat.builderParser.go._sunfold", + -- Let-bound bvar zeta-reduction regression (requires whnf to resolve let-bound bvars) + "Std.Sat.AIG.mkGate", + -- Proof irrelevance regression (requires isProp to check type_of(type_of(t)) == Sort 0) + "Fin.dfoldrM.loop._sunfold", + -- rfl theorem: both sides must be defeq via delta unfolding + "Std.Tactic.BVDecide.BVExpr.eval.eq_10", + -- K-reduction: extra args after major premise must be applied + "UInt8.toUInt64_toUSize" ] let mut passed := 0 let mut failures : Array String := #[] From 92069bcc15e3f0d6346c01053061d2cdcb2e1003 Mon Sep 17 00:00:00 2001 From: "John C. 
Burnham" Date: Thu, 5 Mar 2026 23:12:18 -0500 Subject: [PATCH 13/14] Fix nat reduction to whnf args, use content-based def-eq cache keys - Move tryReduceNat inside mutual block so it can whnf arguments before reducing (matching lean4lean's reduceNat), replacing the separate whnf-args-and-retry loop in whnf - Use isDefEqCore (not isDefEq) for nat literal expansion to avoid cycle where Nat.succ(lit n) gets reduced back to lit(n+1) - Return nat reduction results directly in lazyDeltaReduction instead of looping back through whnfCore - Switch def-eq cache keys from pointer-based to content-based comparison so cache hits work across pointer-distinct copies - Consolidate imax reduction: reuse reduceIMax in reduce, instReduce, and instBulkReduce; add imax(0,b)=b and imax(1,b)=b rules - Simplify K-reduction to only fire when major premise is a constructor - Remove unused depth variable and debug traces --- Ix/Kernel/Infer.lean | 27 +++----- Ix/Kernel/Level.lean | 32 +++------- Ix/Kernel/TypecheckM.lean | 4 +- Ix/Kernel/Types.lean | 8 ++- Ix/Kernel/Whnf.lean | 127 +++++++++++++++++++++----------------- 5 files changed, 99 insertions(+), 99 deletions(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index ffc62d97..5239ab78 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -108,9 +108,7 @@ mutual if contextOk then let te : TypedExpr m := ⟨← infoFromType cachedType, term⟩ return (te, cachedType) - withRecDepthCheck do withFuelCheck do - -- if (← read).trace then dbg_trace s!"infer: {term.tag}" let result ← do match term with | .bvar idx bvarName => do let ctx ← read @@ -665,8 +663,6 @@ mutual | none => pure () withRecDepthCheck do withFuelCheck do - let depth := (← get).recDepth - -- Temporarily removed for call-site tracing -- Loop: steps 1-5 may restart when whnfCore(cheapProj=false) changes terms let mut ct := t @@ -712,7 +708,6 @@ mutual -- If terms changed, loop back to step 1 instead of recursing into isDefEq if !(tnn == tn' && snn == sn') then 
ct := tnn; cs := snn; continue - -- 6. Structural comparison on fully-reduced terms let result ← isDefEqCore tnn snn cacheResult t s result @@ -733,10 +728,7 @@ mutual partial def isDefEqProofIrrel (t s : Expr m) : TypecheckM m (Option Bool) := do let tType ← try let (_, ty) ← withInferOnly (infer t); pure (some ty) catch _ => pure none let some tType := tType | return none - let isPropType ← try isProp tType catch e => do - if (← get).recDepth > 100 then - dbg_trace s!"isProp FAILED at depth {(← get).recDepth}: {e}" - pure false + let isPropType ← try isProp tType catch _ => pure false if !isPropType then return none let sType ← try let (_, ty) ← withInferOnly (infer s); pure (some ty) catch _ => pure none let some sType := sType | return none @@ -823,18 +815,19 @@ mutual let tApp := Expr.mkApp tLifted (Expr.mkBVar 0) withExtendedCtx ty (isDefEq tApp body) - -- Nat literal vs constructor expansion + -- Nat literal vs non-literal: expand to constructor form but stay in isDefEqCore + -- (calling full isDefEq would reduce Nat.succ(lit n) back to lit(n+1), causing a cycle) | .lit (.natVal _), _ => do let prims := (← read).prims let expanded := toCtorIfLit prims t if expanded == t then pure false - else isDefEq expanded s + else isDefEqCore expanded s | _, .lit (.natVal _) => do let prims := (← read).prims let expanded := toCtorIfLit prims s if expanded == s then pure false - else isDefEq t expanded + else isDefEqCore t expanded -- String literal vs constructor expansion | .lit (.strVal str), _ => do @@ -923,11 +916,11 @@ mutual | some result => return (tn, sn, some result) | none => pure () - -- Try nat reduction - if let some r := ← tryReduceNat tn then - tn ← whnfCore r (cheapProj := true); continue - if let some r := ← tryReduceNat sn then - sn ← whnfCore r (cheapProj := true); continue + -- Try nat reduction (whnf's args like lean4lean's reduceNat) + if let some tn' ← tryReduceNat tn then + return (tn', sn, some (← isDefEq tn' sn)) + if let some sn' ← tryReduceNat 
sn then + return (tn, sn', some (← isDefEq tn sn')) -- Lazy delta step let tDelta := isDelta tn kenv diff --git a/Ix/Kernel/Level.lean b/Ix/Kernel/Level.lean index f22bcb53..43b34b9d 100644 --- a/Ix/Kernel/Level.lean +++ b/Ix/Kernel/Level.lean @@ -27,21 +27,20 @@ def reduceIMax (a b : Level m) : Level m := match b with | .zero => .zero | .succ _ => reduceMax a b - | .param idx _ => match a with - | .param idx' _ => if idx == idx' then a else .imax a b + | _ => + match a with + | .zero => b + | .succ .zero => b -- imax(1, b) = b + | .param idx' _ => match b with + | .param idx _ => if idx == idx' then a else .imax a b + | _ => .imax a b | _ => .imax a b - | _ => .imax a b /-- Reduce a level to normal form. -/ def reduce : Level m → Level m | .succ u => .succ (reduce u) | .max a b => reduceMax (reduce a) (reduce b) - | .imax a b => - let b' := reduce b - match b' with - | .zero => .zero - | .succ _ => reduceMax (reduce a) b' - | _ => .imax (reduce a) b' + | .imax a b => reduceIMax (reduce a) (reduce b) | u => u /-! ## Instantiation -/ @@ -52,13 +51,7 @@ def instReduce (u : Level m) (idx : Nat) (subst : Level m) : Level m := match u with | .succ u => .succ (instReduce u idx subst) | .max a b => reduceMax (instReduce a idx subst) (instReduce b idx subst) - | .imax a b => - let a' := instReduce a idx subst - let b' := instReduce b idx subst - match b' with - | .zero => .zero - | .succ _ => reduceMax a' b' - | _ => .imax a' b' + | .imax a b => reduceIMax (instReduce a idx subst) (instReduce b idx subst) | .param idx' _ => if idx' == idx then subst else u | .zero => u @@ -68,12 +61,7 @@ def instBulkReduce (substs : Array (Level m)) : Level m → Level m | z@(.zero ..) 
=> z | .succ u => .succ (instBulkReduce substs u) | .max a b => reduceMax (instBulkReduce substs a) (instBulkReduce substs b) - | .imax a b => - let b' := instBulkReduce substs b - match b' with - | .zero => .zero - | .succ _ => reduceMax (instBulkReduce substs a) b' - | _ => .imax (instBulkReduce substs a) b' + | .imax a b => reduceIMax (instBulkReduce substs a) (instBulkReduce substs b) | .param idx name => if h : idx < substs.size then substs[idx] else .param (idx - substs.size) name diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 94172711..5ead128e 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -193,8 +193,8 @@ def ensureTypedConst (addr : Address) : TypecheckM m Unit := do /-! ## Def-eq cache helpers -/ -/-- Symmetric cache key for def-eq pairs. Orders by pointer address to make key(a,b) == key(b,a). -/ +/-- Symmetric cache key for def-eq pairs. Orders by content to make key(a,b) == key(b,a). -/ def eqCacheKey (a b : Expr m) : Expr m × Expr m := - if Expr.ptrCompare a b != .gt then (a, b) else (b, a) + if Expr.compare a b != .gt then (a, b) else (b, a) end Ix.Kernel diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index 8b2a90d5..a9e95818 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -629,10 +629,12 @@ private unsafe def Expr.ptrCompareUnsafe (a : @& Expr m) (b : @& Expr m) : Order @[implemented_by Expr.ptrCompareUnsafe] opaque Expr.ptrCompare : @& Expr m → @& Expr m → Ordering -/-- Compare pairs of expressions by pointer address (first component, then second). -/ +/-- Compare pairs of expressions by content (first component, then second). + Uses structural `Expr.compare` so the failure cache works across pointer-distinct + copies of the same expression. -/ def Expr.pairCompare (a b : Expr m × Expr m) : Ordering := - match Expr.ptrCompare a.1 b.1 with - | .eq => Expr.ptrCompare a.2 b.2 + match Expr.compare a.1 b.1 with + | .eq => Expr.compare a.2 b.2 | ord => ord /-! 
## Enums -/ diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean index f10b1c70..cbb17621 100644 --- a/Ix/Kernel/Whnf.lean +++ b/Ix/Kernel/Whnf.lean @@ -29,8 +29,9 @@ def listGet? (l : List α) (n : Nat) : Option α := /-! ## Nat primitive reduction on Expr -/ -/-- Try to reduce a Nat primitive applied to literal arguments. Returns the reduced Expr. -/ -def tryReduceNat (e : Expr m) : TypecheckM m (Option (Expr m)) := do +/-- Try to reduce a Nat primitive applied to literal arguments (no whnf on args). + Used in lazyDeltaReduction where args are already partially reduced. -/ +def tryReduceNatLit (e : Expr m) : TypecheckM m (Option (Expr m)) := do let fn := e.getAppFn match fn with | .const addr _ _ => @@ -222,42 +223,34 @@ mutual | _ => pure e /-- K-reduction: for Prop inductives with single zero-field constructor. - Returns the (only) minor premise, plus any extra args after the major. -/ + Returns the (only) minor premise, plus any extra args after the major. + Only fires when the major premise has already been reduced to a constructor. + (lean4lean's toCtorWhenK also handles non-constructor majors by checking + indices via isDefEq, but that requires infer/isDefEq which are in a + separate mutual block. The whnf of the major should handle most cases.) -/ partial def tryKReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) - (major : Expr m) (params motives minors indices : Nat) (indAddr : Address) + (major : Expr m) (params motives minors indices : Nat) (_indAddr : Address) : TypecheckM m (Expr m) := do + -- Check if major is a constructor (including nat literal → ctor conversion) let ctx ← read - let prims := ctx.prims - let kenv := ctx.kenv - -- Check if major is a constructor - let majorCtor := toCtorIfLit prims major + let majorCtor := toCtorIfLit ctx.prims major let isCtor := match majorCtor.getAppFn with | .const ctorAddr _ _ => - match kenv.find? ctorAddr with + match ctx.kenv.find? 
ctorAddr with | some (.ctorInfo _) => true | _ => false | _ => false - -- Also check if the inductive is in Prop - let isPropInd := match kenv.find? indAddr with - | some (.inductInfo v) => - let rec getSort : Expr m → Bool - | .forallE _ body _ _ => getSort body - | .sort (.zero) => true - | _ => false - getSort v.type - | _ => false - if isCtor || isPropInd then - -- K-reduction: return the (only) minor premise - let minorIdx := params + motives - if h : minorIdx < args.size then - let mut result := args[minorIdx] - -- Apply extra args after major premise (matching lean4 kernel behavior) - let majorIdx := params + motives + minors + indices - if majorIdx + 1 < args.size then - result := result.mkAppRange (majorIdx + 1) args.size args - return result - pure e - else pure e + if !isCtor then return e + -- K-reduction: return the (only) minor premise + let minorIdx := params + motives + if h : minorIdx < args.size then + let mut result := args[minorIdx] + -- Apply extra args after major premise (matching lean4 kernel behavior) + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + return result + pure e /-- Iota-reduction: reduce a recursor applied to a constructor. Follows the lean4 algorithm: @@ -377,14 +370,54 @@ mutual | _ => return e else return e - /-- Full WHNF with delta unfolding loop. - whnfCore handles structural reduction (beta, let, iota, cheap proj). - This loop adds: nat primitives, stuck projection resolution, delta unfolding. - Projection chains are flattened to avoid deep recursion: - proj₁(proj₂(proj₃(struct))) → strip all projs, whnf(struct) ONCE, - then resolve projections iteratively from inside out. - Tracks nesting depth: when whnf calls nest too deep (from isDefEq ↔ whnf cycles), - degrades to whnfCore to prevent native stack overflow. -/ + /-- Try to reduce a Nat primitive, whnf'ing args if needed (like lean4lean's reduceNat). 
+ Inside the mutual block so it can call `whnf` on arguments. -/ + partial def tryReduceNat (e : Expr m) : TypecheckM m (Option (Expr m)) := do + let fn := e.getAppFn + match fn with + | .const addr _ _ => + let prims := (← read).prims + if !isPrimOp prims addr then return none + let args := e.getAppArgs + -- Nat.succ: 1 arg + if addr == prims.natSucc then + if args.size >= 1 then + let a ← whnf args[0]! + match a with + | .lit (.natVal n) => return some (.lit (.natVal (n + 1))) + | _ => return none + else return none + -- Binary nat operations: 2 args, whnf both (matches lean4lean reduceBinNatOp) + else if args.size >= 2 then + let a ← whnf args[0]! + let b ← whnf args[1]! + match a, b with + | .lit (.natVal x), .lit (.natVal y) => + if addr == prims.natAdd then return some (.lit (.natVal (x + y))) + else if addr == prims.natSub then return some (.lit (.natVal (x - y))) + else if addr == prims.natMul then return some (.lit (.natVal (x * y))) + else if addr == prims.natPow then + if y > 16777216 then return none + return some (.lit (.natVal (Nat.pow x y))) + else if addr == prims.natMod then return some (.lit (.natVal (x % y))) + else if addr == prims.natDiv then return some (.lit (.natVal (x / y))) + else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y))) + else if addr == prims.natBeq then + let boolAddr := if x == y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natBle then + let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y))) + else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y))) + else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y))) + else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y))) + else if addr == prims.natShiftRight then return some (.lit (.natVal 
(Nat.shiftRight x y))) + else return none + | _, _ => return none + else return none + | _ => return none + partial def whnf (e : Expr m) : TypecheckM m (Expr m) := do -- Cache check FIRST — no fuel or stack cost for cache hits -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) @@ -403,25 +436,9 @@ mutual let mut steps := 0 repeat if steps > 10000 then break -- safety bound - -- Try nat primitive reduction + -- Try nat primitive reduction (whnf's args like lean4lean's reduceNat) if let some r := ← tryReduceNat t then t ← whnfCore r; steps := steps + 1; continue - -- If head is a nat primitive but args aren't literals, whnf args and retry - match t.getAppFn with - | .const addr _ _ => - if isPrimOp (← read).prims addr then - let args := t.getAppArgs - let mut changed := false - let mut newArgs : Array (Expr m) := #[] - for arg in args do - let arg' ← whnf arg - newArgs := newArgs.push arg' - if arg' != arg then changed := true - if changed then - let t' := t.getAppFn.mkAppN newArgs - if let some r := ← tryReduceNat t' then - t ← whnfCore r; steps := steps + 1; continue - | _ => pure () -- Handle stuck projections (including inside app chains). -- Flatten nested projection chains to avoid deep whnf→whnf recursion. match t.getAppFn with From 7abd736a0b311f8faf8fa37740c407cdbb10c7a0 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 6 Mar 2026 15:39:51 -0500 Subject: [PATCH 14/14] Add primitive validation, recursor rule type checking, and merge whnf into mutual block Move whnf/whnfCore/unfoldDefinition from Whnf.lean into the Infer.lean mutual block so they can call infer/isDefEq (needed for toCtorWhenK, isProp in struct-eta, and checkRecursorRuleType). Add new Primitive.lean module that validates Bool/Nat inductives and primitive definitions (add/sub/mul/pow/beq/ble/shiftLeft/shiftRight/land/lor/xor/pred/charMk/ stringOfList) against their expected types and reduction rules. 
Validate Eq and Quot type signatures at quotient init time. Key changes: - checkRecursorRuleType: builds expected type from recursor + ctor types, handles nested inductives (cnp > np) with level/bvar substitution - checkElimLevel: validates large elimination for Prop inductives - toCtorWhenK: infers major's type and constructs nullary ctor (was stub) - tryEtaStruct: now symmetric (tries both directions) with type check - isDefEq: add Bool.true proof-by-reflection, fix bvar quick check to return none (not false) on mismatch, add eta-struct fallback for apps - Safety/universe validation in infer .const, withSafety in checkConst - Constructor param domain matching and return type validation - Hardcode ~30 new primitive addresses in buildPrimitives - Add unit tests for toCtorIfLit/strLitToConstructor/isPrimOp/foldLiterals - Add soundness tests for mutual recursors, parametric/nested recursors - Previously failing RCasesPatt.rec_1 now passes --- Ix/Kernel.lean | 1 + Ix/Kernel/Convert.lean | 3 +- Ix/Kernel/Infer.lean | 944 +++++++++++++++++++++++++++++++-- Ix/Kernel/Primitive.lean | 402 ++++++++++++++ Ix/Kernel/TypecheckM.lean | 3 + Ix/Kernel/Types.lean | 75 ++- Ix/Kernel/Whnf.lean | 398 +------------- Tests/Ix/Kernel/Helpers.lean | 15 +- Tests/Ix/Kernel/Soundness.lean | 403 ++++++++++++++ Tests/Ix/Kernel/Unit.lean | 85 +++ Tests/Ix/KernelTests.lean | 86 ++- 11 files changed, 1930 insertions(+), 485 deletions(-) create mode 100644 Ix/Kernel/Primitive.lean diff --git a/Ix/Kernel.lean b/Ix/Kernel.lean index 2ce31362..c76129c8 100644 --- a/Ix/Kernel.lean +++ b/Ix/Kernel.lean @@ -8,6 +8,7 @@ import Ix.Kernel.TypecheckM import Ix.Kernel.Whnf import Ix.Kernel.DefEq import Ix.Kernel.Infer +import Ix.Kernel.Primitive import Ix.Kernel.Convert namespace Ix.Kernel diff --git a/Ix/Kernel/Convert.lean b/Ix/Kernel/Convert.lean index 6d0ebb5e..46b80b01 100644 --- a/Ix/Kernel/Convert.lean +++ b/Ix/Kernel/Convert.lean @@ -825,7 +825,8 @@ def convertChunk (m : MetaMode) (hashToAddr : 
Std.HashMap Address Address) constants not in named. Groups projections by block and parallelizes. -/ def convertEnv (m : MetaMode) (ixonEnv : Ixon.Env) (numWorkers : Nat := 32) : Except String (Ix.Kernel.Env m × Primitives × Bool) := - -- Build primitives with quot addresses + -- Build primitives with quot addresses and name-based lookup for extra addresses + -- Build primitives: hardcoded addresses + Quot from .quot tags let prims : Primitives := Id.run do let mut p := buildPrimitives for (addr, c) in ixonEnv.consts do diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 5239ab78..bce1b3d5 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -4,9 +4,68 @@ Environment-based kernel: types are Exprs, uses whnf/isDefEq. -/ import Ix.Kernel.DefEq +import Ix.Kernel.Primitive namespace Ix.Kernel +/-! ## Recursor rule type helpers -/ + +/-- Shift bvar indices and level params in an expression from a constructor context + to a recursor rule context. + - `fieldDepth`: number of field binders above this expr in the ctor type + - `bvarShift`: amount to shift param bvar refs (= numMotives + numMinors) + - `levelShift`: amount to shift Level.param indices (= recLevelCount - ctorLevelCount) + Bvar i at depth d is a param ref when i >= d + fieldDepth. 
-/ +partial def shiftCtorToRule (e : Expr m) (fieldDepth : Nat) (bvarShift : Nat) (levelSubst : Array (Level m)) : Expr m := + if bvarShift == 0 && levelSubst.size == 0 then e else go e 0 +where + substLevel : Level m → Level m + | .param i n => if h : i < levelSubst.size then levelSubst[i] else .param i n + | .succ l => .succ (substLevel l) + | .max a b => .max (substLevel a) (substLevel b) + | .imax a b => .imax (substLevel a) (substLevel b) + | l => l + go (e : Expr m) (depth : Nat) : Expr m := + match e with + | .bvar i n => + if i >= depth + fieldDepth then .bvar (i + bvarShift) n + else e + | .app fn arg => .app (go fn depth) (go arg depth) + | .lam ty body n bi => .lam (go ty depth) (go body (depth + 1)) n bi + | .forallE ty body n bi => .forallE (go ty depth) (go body (depth + 1)) n bi + | .letE ty val body n => .letE (go ty depth) (go val depth) (go body (depth + 1)) n + | .proj ta idx s n => .proj ta idx (go s depth) n + | .sort l => .sort (substLevel l) + | .const addr lvls name => .const addr (lvls.map substLevel) name + | _ => e + +/-- Substitute extra nested param bvars in a constructor body expression. + After peeling `cnp` params from the ctor type, extra param bvars occupy + indices `fieldDepth..fieldDepth+numExtra-1` at depth 0 (they are the innermost + free param bvars, below the shared params). Replace them with `vals` and + shift shared param bvars down by `numExtra` to close the gap. 
+ - `fieldDepth`: number of field binders enclosing this expr (0 for return type) + - `numExtra`: number of extra nested params (cnp - np) + - `vals`: replacement values (already shifted for the rule context) -/ +partial def substNestedParams (e : Expr m) (fieldDepth : Nat) (numExtra : Nat) (vals : Array (Expr m)) : Expr m := + if numExtra == 0 then e else go e 0 +where + go (e : Expr m) (depth : Nat) : Expr m := + match e with + | .bvar i n => + let freeIdx := i - (depth + fieldDepth) -- which param bvar (0 = innermost extra) + if i < depth + fieldDepth then e -- bound by field/local binder + else if freeIdx < numExtra then + -- Extra nested param: substitute with vals[freeIdx] shifted up by depth + shiftCtorToRule vals[freeIdx]! 0 depth #[] + else .bvar (i - numExtra) n -- Shared param: shift down + | .app fn arg => .app (go fn depth) (go arg depth) + | .lam ty body n bi => .lam (go ty depth) (go body (depth + 1)) n bi + | .forallE ty body n bi => .forallE (go ty depth) (go body (depth + 1)) n bi + | .letE ty val body n => .letE (go ty depth) (go val depth) (go body (depth + 1)) n + | .proj ta idx s n => .proj ta idx (go s depth) n + | _ => e + /-! ## Inductive validation helpers -/ /-- Check if an expression mentions a constant at the given address. -/ @@ -42,6 +101,23 @@ where | .sort lvl => some lvl | _ => none +/-- Extract the motive's return sort from a recursor type. + Walks past numParams Pi binders, then walks the motive's domain to the final Sort. -/ +def getMotiveSort (recType : Expr m) (numParams : Nat) : Option (Level m) := + go recType numParams +where + go (ty : Expr m) : Nat → Option (Level m) + | 0 => match ty with + | .forallE motiveDom _ _ _ => walkToSort motiveDom + | _ => none + | n+1 => match ty with + | .forallE _ body _ _ => go body n + | _ => none + walkToSort : Expr m → Option (Level m) + | .forallE _ body _ _ => walkToSort body + | .sort lvl => some lvl + | _ => none + /-- Check if a level is definitively non-zero (always >= 1). 
-/ partial def levelIsNonZero : Level m → Bool | .succ _ => true @@ -60,31 +136,457 @@ def piInfo (dom img : TypeInfo m) : TypeInfo m := match dom, img with | .sort lvl, .sort lvl' => .sort (Level.reduceIMax lvl lvl') | _, _ => .none -/-- Infer TypeInfo from a type expression (after whnf). -/ -def infoFromType (typ : Expr m) : TypecheckM m (TypeInfo m) := do - let typ' ← whnf typ - match typ' with - | .sort (.zero) => pure .proof - | .sort lvl => pure (.sort lvl) - | .app .. => - let head := typ'.getAppFn - match head with - | .const addr _ _ => - match (← read).kenv.find? addr with - | some (.inductInfo v) => - if v.ctors.size == 1 then - match (← read).kenv.find? v.ctors[0]! with - | some (.ctorInfo cv) => - if cv.numFields == 0 then pure .unit else pure .none - | _ => pure .none - else pure .none +mutual + /-- Infer TypeInfo from a type expression (after whnf). -/ + partial def infoFromType (typ : Expr m) : TypecheckM m (TypeInfo m) := do + let typ' ← whnf typ + match typ' with + | .sort (.zero) => pure .proof + | .sort lvl => pure (.sort lvl) + | .app .. => + let head := typ'.getAppFn + match head with + | .const addr _ _ => + match (← read).kenv.find? addr with + | some (.inductInfo v) => + if v.ctors.size == 1 then + match (← read).kenv.find? v.ctors[0]! with + | some (.ctorInfo cv) => + if cv.numFields == 0 then pure .unit else pure .none + | _ => pure .none + else pure .none + | _ => pure .none | _ => pure .none | _ => pure .none - | _ => pure .none -/-! ## Inference / Checking -/ + -- WHNF (moved from Whnf.lean to share mutual block with infer/isDefEq) + + /-- Structural WHNF: beta, let-zeta, iota-proj. No delta unfolding. + Uses an iterative loop to avoid deep stack usage: + - App spines are collected iteratively (not recursively) + - Beta/let/iota/proj results loop back instead of tail-calling + When cheapProj=true, projections are returned as-is (no struct reduction). + When cheapRec=true, recursor applications are returned as-is (no iota reduction). 
-/ + partial def whnfCore (e : Expr m) (cheapRec := false) (cheapProj := false) + : TypecheckM m (Expr m) := do + -- Cache check FIRST — no stack cost for cache hits + -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) + let useCache := !cheapRec && !cheapProj && (← read).numLetBindings == 0 + if useCache then + if let some r := (← get).whnfCoreCache.get? e then return r + let r ← whnfCoreImpl e cheapRec cheapProj + if useCache then + modify fun s => { s with whnfCoreCache := s.whnfCoreCache.insert e r } + pure r + + partial def whnfCoreImpl (e : Expr m) (cheapRec : Bool) (cheapProj : Bool) + : TypecheckM m (Expr m) := do + let mut t := e + repeat + -- Fuel check + let stt ← get + if stt.fuel == 0 then throw "deep recursion fuel limit reached" + modify fun s => { s with fuel := s.fuel - 1 } + match t with + | .app .. => do + -- Collect app args iteratively (O(1) stack for app spine) + let args := t.getAppArgs + let fn := t.getAppFn + let fn' ← whnfCore fn cheapRec cheapProj -- recurse only on non-app head + -- Beta-reduce: consume as many args as possible + let mut result := fn' + let mut i : Nat := 0 + while i < args.size do + match result with + | .lam _ body _ _ => + result := body.instantiate1 args[i]! + i := i + 1 + | _ => break + if i > 0 then + -- Beta reductions happened. Apply remaining args and loop. + for h : j in [i:args.size] do + result := Expr.mkApp result args[j]! + t := result; continue -- loop instead of recursive tail call + else + -- No beta reductions. Try recursor/proj reduction. 
+ let e' := if fn == fn' then t else fn'.mkAppN args + if cheapRec then return e' -- skip recursor reduction + let r ← tryReduceApp e' + if r == e' then return r -- stuck, return + t := r; continue -- iota/quot reduced, loop to re-process + | .bvar idx _ => do + -- Zeta-reduce let-bound bvars: look up the stored value and substitute + let ctx ← read + let depth := ctx.types.size + if idx < depth then + let arrayIdx := depth - 1 - idx + if h : arrayIdx < ctx.letValues.size then + if let some val := ctx.letValues[arrayIdx] then + -- Shift free bvars in val past the intermediate binders + t := val.liftBVars (idx + 1); continue + return t + | .letE _ val body _ => + t := body.instantiate1 val; continue -- loop instead of recursion + | .proj typeAddr idx struct _ => do + -- cheapProj=true: try structural-only reduction (whnfCore, no delta) + -- cheapProj=false: full reduction (whnf, with delta) + let struct' ← if cheapProj then whnfCore struct cheapRec cheapProj else whnf struct + match ← reduceProj typeAddr idx struct' with + | some result => t := result; continue -- loop instead of recursion + | none => + return if struct == struct' then t else .proj typeAddr idx struct' default + | _ => return t + return t -- unreachable, but needed for type checking + + /-- Try to reduce an application whose head is in WHNF. + Handles recursor iota-reduction and quotient reduction. -/ + partial def tryReduceApp (e : Expr m) : TypecheckM m (Expr m) := do + let fn := e.getAppFn + match fn with + | .const addr _ _ => do + ensureTypedConst addr + match (← get).typedConsts.get? 
addr with + | some (.recursor _ params motives minors indices isK indAddr rules) => + let args := e.getAppArgs + let majorIdx := params + motives + minors + indices + if h : majorIdx < args.size then + let major := args[majorIdx] + let major' ← whnf major + if isK then + tryKReduction e addr args major' params motives minors indices indAddr + else + tryIotaReduction e addr args major' params indices indAddr rules motives minors + else pure e + | some (.quotient _ kind) => + match kind with + | .lift => tryQuotReduction e 6 3 + | .ind => tryQuotReduction e 5 3 + | _ => pure e + | _ => pure e + | _ => pure e + + /-- K-reduction: for Prop inductives with single zero-field constructor. + Returns the (only) minor premise, plus any extra args after the major. + When the major is not a constructor, tries toCtorWhenK: infers the major's type, + checks it matches the inductive, and constructs the nullary constructor. -/ + partial def tryKReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) + (major : Expr m) (params motives minors indices : Nat) (indAddr : Address) + : TypecheckM m (Expr m) := do + -- Check if major is a constructor (including nat literal → ctor conversion) + let ctx ← read + let majorCtor := toCtorIfLit ctx.prims major + let isCtor := match majorCtor.getAppFn with + | .const ctorAddr _ _ => + match ctx.kenv.find? ctorAddr with + | some (.ctorInfo _) => true + | _ => false + | _ => false + if !isCtor then + -- toCtorWhenK: verify the major's type matches the K-inductive. + -- K-types have zero fields, so the ctor itself isn't needed — we just return the minor. 
+ match ← toCtorWhenK major indAddr with + | some _ => pure () -- type matches, fall through to K-reduction + | none => return e + -- K-reduction: return the (only) minor premise + let minorIdx := params + motives + if h : minorIdx < args.size then + let mut result := args[minorIdx] + -- Apply extra args after major premise (matching lean4 kernel behavior) + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + return result + pure e + + /-- For K-like inductives, try to construct the nullary constructor from the major's type. + Infers the major's type, checks it matches the inductive, and returns the constructor. + Matches lean4lean's `toCtorWhenK` / lean4 C++ `to_cnstr_when_K`. -/ + partial def toCtorWhenK (major : Expr m) (indAddr : Address) : TypecheckM m (Option (Expr m)) := do + let kenv := (← read).kenv + match kenv.find? indAddr with + | some (.inductInfo iv) => + if iv.ctors.isEmpty then return none + let ctorAddr := iv.ctors[0]! + -- Infer major's type and check it matches the inductive + let (_, majorType) ← try withInferOnly (infer major) catch _ => return none + let majorType' ← whnf majorType + let majorHead := majorType'.getAppFn + match majorHead with + | .const headAddr _ _ => + if headAddr != indAddr then return none + -- Construct the nullary constructor applied to params from the type + let typeArgs := majorType'.getAppArgs + let ctorUnivs := majorHead.constLevels! + let mut ctor : Expr m := Expr.mkConst ctorAddr ctorUnivs + -- Apply params (first numParams args of the type) + for i in [:iv.numParams] do + if i < typeArgs.size then + ctor := Expr.mkApp ctor typeArgs[i]! 
+ -- Verify ctor type matches major type (prevents K-reduction when indices differ) + let (_, ctorType) ← try withInferOnly (infer ctor) catch _ => return none + if !(← isDefEq majorType' ctorType) then return none + return some ctor + | _ => return none + | _ => return none + + /-- Iota-reduction: reduce a recursor applied to a constructor. + Follows the lean4 algorithm: + 1. Apply params + motives + minors from recursor args to rule RHS + 2. Apply constructor fields (skip constructor params) to rule RHS + 3. Apply extra args after major premise to rule RHS + Beta reduction happens in the subsequent whnfCore call. -/ + partial def tryIotaReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) + (major : Expr m) (params indices : Nat) (indAddr : Address) + (rules : Array (Nat × TypedExpr m)) + (motives minors : Nat) : TypecheckM m (Expr m) := do + let prims := (← read).prims + let majorCtor := toCtorIfLit prims major + let majorFn := majorCtor.getAppFn + match majorFn with + | .const ctorAddr _ _ => do + let kenv := (← read).kenv + let typedConsts := (← get).typedConsts + let ctorInfo? := match kenv.find? ctorAddr with + | some (.ctorInfo cv) => some (cv.cidx, cv.numFields) + | _ => + match typedConsts.get? ctorAddr with + | some (.constructor _ ctorIdx numFields) => some (ctorIdx, numFields) + | _ => none + match ctorInfo? with + | some (ctorIdx, _) => + match rules[ctorIdx]? with + | some (nfields, rhs) => + let majorArgs := majorCtor.getAppArgs + if nfields > majorArgs.size then return e + -- Instantiate universe level params in the rule RHS + let recFn := e.getAppFn + let recLevels := recFn.constLevels! 
+ let mut result := rhs.body.instantiateLevelParams recLevels + -- Phase 1: Apply params + motives + minors from recursor args + let pmmEnd := params + motives + minors + result := result.mkAppRange 0 pmmEnd args + -- Phase 2: Apply constructor fields (skip constructor's own params) + let ctorParamCount := majorArgs.size - nfields + result := result.mkAppRange ctorParamCount majorArgs.size majorArgs + -- Phase 3: Apply remaining arguments after major premise + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + pure result -- return raw result; whnfCore's loop will re-process + | none => pure e + | none => + -- Not a constructor, try structure eta + tryStructEta e args params indices indAddr rules major motives minors + | _ => + tryStructEta e args params indices indAddr rules major motives minors + + /-- Structure eta: expand struct-like major via projections. + Skips Prop structures (proof irrelevance handles those; projections may not reduce). -/ + partial def tryStructEta (e : Expr m) (args : Array (Expr m)) + (params : Nat) (indices : Nat) (indAddr : Address) + (rules : Array (Nat × TypedExpr m)) (major : Expr m) + (motives minors : Nat) : TypecheckM m (Expr m) := do + let kenv := (← read).kenv + if !kenv.isStructureLike indAddr then return e + -- Skip Prop structures: proof irrelevance handles them, projections may not reduce. + let (_, majorType) ← try withInferOnly (infer major) catch _ => return e + if ← (try isProp majorType catch _ => pure false) then return e + match rules[0]? with + | some (nfields, rhs) => + let recFn := e.getAppFn + let recLevels := recFn.constLevels! 
+ let mut result := rhs.body.instantiateLevelParams recLevels + -- Phase 1: params + motives + minors + let pmmEnd := params + motives + minors + result := result.mkAppRange 0 pmmEnd args + -- Phase 2: projections as fields + let mut projArgs : Array (Expr m) := #[] + for i in [:nfields] do + projArgs := projArgs.push (Expr.mkProj indAddr i major) + result := projArgs.foldl (fun acc a => Expr.mkApp acc a) result + -- Phase 3: extra args after major + let majorIdx := params + motives + minors + indices + if majorIdx + 1 < args.size then + result := result.mkAppRange (majorIdx + 1) args.size args + pure result -- return raw result; whnfCore's loop will re-process + | none => pure e + + /-- Quotient reduction: Quot.lift / Quot.ind. + For Quot.lift: `@Quot.lift α r β f h q` — reduceSize=6, fPos=3 (f is at index 3) + For Quot.ind: `@Quot.ind α r β f q` — reduceSize=5, fPos=3 (f is at index 3) + When major (q) reduces to `@Quot.mk α r a`, result is `f a`. -/ + partial def tryQuotReduction (e : Expr m) (reduceSize fPos : Nat) : TypecheckM m (Expr m) := do + let args := e.getAppArgs + if args.size < reduceSize then return e + let majorIdx := reduceSize - 1 + if h : majorIdx < args.size then + let major := args[majorIdx] + let major' ← whnf major + let majorFn := major'.getAppFn + match majorFn with + | .const majorAddr _ _ => + ensureTypedConst majorAddr + match (← get).typedConsts.get? majorAddr with + | some (.quotient _ .ctor) => + let majorArgs := major'.getAppArgs + -- Quot.mk has 3 args: [α, r, a]. The data 'a' is the last one. + if majorArgs.size < 3 then throw "Quot.mk should have at least 3 args" + let dataArg := majorArgs[majorArgs.size - 1]! 
+ if h2 : fPos < args.size then + let f := args[fPos] + let result := Expr.mkApp f dataArg + -- Apply any extra args after the major premise + let result := if majorIdx + 1 < args.size then + result.mkAppRange (majorIdx + 1) args.size args + else result + pure result -- return raw result; whnfCore's loop will re-process + else return e + | _ => return e + | _ => return e + else return e + + /-- Try to reduce a Nat primitive, whnf'ing args if needed (like lean4lean's reduceNat). + Inside the mutual block so it can call `whnf` on arguments. -/ + partial def tryReduceNat (e : Expr m) : TypecheckM m (Option (Expr m)) := do + let fn := e.getAppFn + match fn with + | .const addr _ _ => + let prims := (← read).prims + if !isPrimOp prims addr then return none + let args := e.getAppArgs + -- Nat.succ: 1 arg + if addr == prims.natSucc then + if args.size >= 1 then + let a ← whnf args[0]! + match a with + | .lit (.natVal n) => return some (.lit (.natVal (n + 1))) + | _ => return none + else return none + -- Binary nat operations: 2 args, whnf both (matches lean4lean reduceBinNatOp) + else if args.size >= 2 then + let a ← whnf args[0]! + let b ← whnf args[1]! 
+ match a, b with + | .lit (.natVal x), .lit (.natVal y) => + if addr == prims.natAdd then return some (.lit (.natVal (x + y))) + else if addr == prims.natSub then return some (.lit (.natVal (x - y))) + else if addr == prims.natMul then return some (.lit (.natVal (x * y))) + else if addr == prims.natPow then + if y > 16777216 then return none + return some (.lit (.natVal (Nat.pow x y))) + else if addr == prims.natMod then return some (.lit (.natVal (x % y))) + else if addr == prims.natDiv then return some (.lit (.natVal (x / y))) + else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y))) + else if addr == prims.natBeq then + let boolAddr := if x == y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natBle then + let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse + return some (Expr.mkConst boolAddr #[]) + else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y))) + else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y))) + else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y))) + else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y))) + else if addr == prims.natShiftRight then return some (.lit (.natVal (Nat.shiftRight x y))) + else return none + | _, _ => return none + else return none + | _ => return none + + partial def whnf (e : Expr m) : TypecheckM m (Expr m) := do + -- Cache check FIRST — no fuel or stack cost for cache hits + -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) + let useWhnfCache := (← read).numLetBindings == 0 + if useWhnfCache then + if let some r := (← get).whnfCache.get? 
e then return r + withRecDepthCheck do + withFuelCheck do + let r ← whnfImpl e + if useWhnfCache then + modify fun s => { s with whnfCache := s.whnfCache.insert e r } + pure r + + partial def whnfImpl (e : Expr m) : TypecheckM m (Expr m) := do + let mut t ← whnfCore e + let mut steps := 0 + repeat + if steps > 10000 then throw "whnf delta step limit (10000) exceeded" + -- Try nat primitive reduction (whnf's args like lean4lean's reduceNat) + if let some r := ← tryReduceNat t then + t ← whnfCore r; steps := steps + 1; continue + -- Handle stuck projections (including inside app chains). + -- Flatten nested projection chains to avoid deep whnf→whnf recursion. + match t.getAppFn with + | .proj _ _ _ _ => + -- Collect the projection chain from outside in + let mut projStack : Array (Address × Nat × Array (Expr m)) := #[] + let mut inner := t + repeat + match inner.getAppFn with + | .proj typeAddr idx struct _ => + projStack := projStack.push (typeAddr, idx, inner.getAppArgs) + inner := struct + | _ => break + -- Reduce the innermost struct with depth-guarded whnf + let innerReduced ← whnf inner + -- Resolve projections from inside out (last pushed = innermost) + let mut current := innerReduced + let mut allResolved := true + let mut i := projStack.size + while i > 0 do + i := i - 1 + let (typeAddr, idx, args) := projStack[i]! + match ← reduceProj typeAddr idx current with + | some result => + let applied := if args.isEmpty then result else result.mkAppN args + current ← whnfCore applied + | none => + -- This projection couldn't be resolved. Reconstruct remaining chain. + let stuck := if args.isEmpty then + Expr.mkProj typeAddr idx current + else + (Expr.mkProj typeAddr idx current).mkAppN args + current ← whnfCore stuck + -- Reconstruct outer projections + while i > 0 do + i := i - 1 + let (ta, ix, as) := projStack[i]! 
+ current := if as.isEmpty then + Expr.mkProj ta ix current + else + (Expr.mkProj ta ix current).mkAppN as + allResolved := false + break + if allResolved || current != t then + t := current; steps := steps + 1; continue + | _ => pure () + -- Try delta unfolding + if let some r := ← unfoldDefinition t then + t ← whnfCore r; steps := steps + 1; continue + break + pure t + + /-- Unfold a single delta step (definition body). -/ + partial def unfoldDefinition (e : Expr m) : TypecheckM m (Option (Expr m)) := do + let head := e.getAppFn + match head with + | .const addr levels _ => do + let ci ← derefConst addr + match ci with + | .defnInfo v => + if v.safety == .partial then return none + let body := v.value.instantiateLevelParams levels + let args := e.getAppArgs + return some (body.mkAppN args) + | .thmInfo v => + let body := v.value.instantiateLevelParams levels + let args := e.getAppArgs + return some (body.mkAppN args) + | _ => return none + | _ => return none + + -- Type Inference and Checking -mutual /-- Check that a term has a given type. 
-/ partial def check (term : Expr m) (expectedType : Expr m) : TypecheckM m (TypedExpr m) := do -- if (← read).trace then dbg_trace s!"check: {term.tag}" @@ -258,6 +760,19 @@ mutual pure (te, typ) | .const addr constUnivs _ => do ensureTypedConst addr + -- Safety check: safe declarations cannot reference unsafe/partial constants + let inferOnly := (← read).inferOnly + if !inferOnly then + let ci ← derefConst addr + let curSafety := (← read).safety + if ci.isUnsafe && curSafety != .unsafe then + throw s!"invalid declaration, it uses unsafe declaration {addr}" + if let .defnInfo v := ci then + if v.safety == .partial && curSafety == .safe then + throw s!"invalid declaration, safe declaration must not contain partial declaration {addr}" + -- Universe level param count validation + if constUnivs.size != ci.numLevels then + throw s!"incorrect number of universe levels for {addr}: expected {ci.numLevels}, got {constUnivs.size}" match (← get).constTypeCache.get? addr with | some (cachedUnivs, cachedTyp) => if cachedUnivs == constUnivs then @@ -338,6 +853,10 @@ mutual /-- Typecheck a constant. -/ partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do + -- Determine safety early for withSafety wrapper + let ci? := (← read).kenv.find? addr + let declSafety := match ci? 
with | some ci => ci.safety | none => .safe + withSafety declSafety do -- Reset fuel and per-constant caches modify fun stt => { stt with constTypeCache := {}, @@ -355,7 +874,6 @@ mutual return () let ci ← derefConst addr let univs := ci.cv.mkUnivParams - -- Universe level instantiation for the constant's own level params let newConst ← match ci with | .axiomInfo _ => let (type, _) ← isSort ci.type @@ -383,9 +901,12 @@ mutual (Std.TreeMap.empty).insert 0 (addr, fun _ => typExpr) withMutTypes mutTypes (withRecAddr addr (check v.value type.body)) else withRecAddr addr (check v.value type.body) + validatePrimitive addr pure (TypedConst.definition type value part) | .quotInfo v => let (type, _) ← isSort ci.type + if (← read).quotInit then + validateQuotient pure (TypedConst.quotient type v.kind) | .inductInfo _ => checkIndBlock addr @@ -401,6 +922,15 @@ mutual if v.k then validateKFlag v indAddr validateRecursorRules v indAddr + checkElimLevel ci.type v indAddr + -- Check each rule RHS has the expected type + match (← read).kenv.find? indAddr with + | some (.inductInfo iv) => + for h : i in [:v.rules.size] do + let rule := v.rules[i] + if i < iv.ctors.size then + checkRecursorRuleType ci.type v iv.ctors[i]! rule.nfields rule.rhs + | _ => pure () let typedRules ← v.rules.mapM fun rule => do let (rhs, _) ← infer rule.rhs pure (rule.nfields, rhs) @@ -518,6 +1048,7 @@ mutual let .inductInfo iv := indInfo | throw "unreachable" if (← get).typedConsts.get? addr |>.isSome then return () let (type, _) ← isSort iv.type + validatePrimitive addr let isStruct := !iv.isRec && iv.numIndices == 0 && iv.ctors.size == 1 && match (← read).kenv.find? iv.ctors[0]! 
with | some (.ctorInfo cv) => cv.numFields > 0 @@ -532,6 +1063,19 @@ mutual modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cv.cidx cv.numFields) } if cv.numParams != iv.numParams then throw s!"Constructor {ctorAddr} has {cv.numParams} params but inductive has {iv.numParams}" + -- Validate constructor parameter domains match inductive parameter domains + if !iv.isUnsafe then do + let mut indTy := iv.type + let mut ctorTy := cv.type + for i in [:iv.numParams] do + match indTy, ctorTy with + | .forallE indDom indBody _ _, .forallE ctorDom ctorBody _ _ => + if !(← isDefEq indDom ctorDom) then + throw s!"Constructor {ctorAddr} parameter {i} domain doesn't match inductive parameter domain" + indTy := indBody + ctorTy := ctorBody + | _, _ => + throw s!"Constructor {ctorAddr} has fewer Pi binders than expected parameters" if !iv.isUnsafe then match ← checkCtorFields cv.type cv.numParams indAddrs with | some msg => throw s!"Constructor {ctorAddr}: {msg}" @@ -541,7 +1085,26 @@ mutual checkFieldUniverses cv.type cv.numParams ctorAddr indLvl if !iv.isUnsafe then let retType := getCtorReturnType cv.type cv.numParams cv.numFields + -- Validate return type head is one of the inductives being defined + let retHead := retType.getAppFn + match retHead with + | .const retAddr _ _ => + if !indAddrs.any (· == retAddr) then + throw s!"Constructor {ctorAddr} return type head is not the inductive being defined" + | _ => + throw s!"Constructor {ctorAddr} return type is not an inductive application" let args := retType.getAppArgs + -- Validate param args are correct bvars (bvar (numFields + numParams - 1 - i) for param i) + for i in [:iv.numParams] do + if i < args.size then + let expectedBvar := cv.numFields + iv.numParams - 1 - i + match args[i]! 
with + | .bvar idx _ => + if idx != expectedBvar then + throw s!"Constructor {ctorAddr} return type has wrong parameter at position {i}" + | _ => + throw s!"Constructor {ctorAddr} return type parameter {i} is not a bound variable" + -- Validate index args don't mention the inductives for i in [iv.numParams:args.size] do for indAddr in indAddrs do if exprMentionsConst args[i]! indAddr then @@ -553,7 +1116,8 @@ mutual (ctorAddr : Address) (indLvl : Level m) : TypecheckM m Unit := go ctorType numParams where - go (ty : Expr m) (remainingParams : Nat) : TypecheckM m Unit := + go (ty : Expr m) (remainingParams : Nat) : TypecheckM m Unit := do + let ty ← whnf ty match ty with | .forallE dom body _piName _ => if remainingParams > 0 then do @@ -568,6 +1132,70 @@ mutual withExtendedCtx dom (go body 0) | _ => pure () + /-- Check if a single-ctor Prop inductive allows large elimination. + All non-Prop fields must appear directly as index arguments in the return type. + Matches lean4lean's `isLargeEliminator` / lean4 C++ `elim_only_at_universe_zero`. 
-/ + partial def checkLargeElimSingleCtor (ctorType : Expr m) (numParams numFields : Nat) + : TypecheckM m Bool := + go ctorType numParams numFields #[] + where + go (ty : Expr m) (remainingParams : Nat) (remainingFields : Nat) + (nonPropBvars : Array Nat) : TypecheckM m Bool := do + let ty ← whnf ty + match ty with + | .forallE dom body _ _ => + if remainingParams > 0 then + withExtendedCtx dom (go body (remainingParams - 1) remainingFields nonPropBvars) + else if remainingFields > 0 then + let (_, fieldSortLvl) ← isSort dom + let nonPropBvars := if !Level.isZero fieldSortLvl then + -- After all remaining fields, this field is bvar (remainingFields - 1) + nonPropBvars.push (remainingFields - 1) + else nonPropBvars + withExtendedCtx dom (go body 0 (remainingFields - 1) nonPropBvars) + else pure true + | _ => + if nonPropBvars.isEmpty then return true + let args := ty.getAppArgs + for bvarIdx in nonPropBvars do + let mut found := false + for i in [numParams:args.size] do + match args[i]! with + | .bvar idx _ => if idx == bvarIdx then found := true + | _ => pure () + if !found then return false + return true + + /-- Validate that the recursor's elimination level is appropriate for the inductive. + If the inductive doesn't allow large elimination, the motive must return Prop. -/ + partial def checkElimLevel (recType : Expr m) (rec : RecursorVal m) (indAddr : Address) + : TypecheckM m Unit := do + let kenv := (← read).kenv + match kenv.find? indAddr with + | some (.inductInfo iv) => + let some indLvl := getIndResultLevel iv.type | return () + -- Non-zero result level → large elimination always allowed + if levelIsNonZero indLvl then return () + -- Extract motive sort from recursor type + let some motiveSort := getMotiveSort recType rec.numParams | return () + -- If motive is already Prop, nothing to check + if Level.isZero motiveSort then return () + -- Motive wants non-Prop elimination. Check if it's allowed. 
+ -- Mutual inductives in Prop → no large elimination + if iv.all.size != 1 then + throw s!"recursor claims large elimination but mutual Prop inductive only allows Prop elimination" + if iv.ctors.isEmpty then return () -- empty Prop type can eliminate into any Sort + if iv.ctors.size != 1 then + throw s!"recursor claims large elimination but Prop inductive with multiple constructors only allows Prop elimination" + let ctorAddr := iv.ctors[0]! + match kenv.find? ctorAddr with + | some (.ctorInfo cv) => + let allowed ← checkLargeElimSingleCtor cv.type iv.numParams cv.numFields + if !allowed then + throw s!"recursor claims large elimination but inductive has non-Prop fields not appearing in indices" + | _ => return () + | _ => return () + /-- Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/ partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m Unit := do match (← read).kenv.find? indAddr with @@ -607,6 +1235,192 @@ mutual | _ => throw s!"constructor {iv.ctors[i]!} not found" | _ => pure () + /-- Check that a recursor rule RHS has the expected type. + Builds the expected type from the recursor type and constructor type, + then verifies the inferred RHS type matches via isDefEq. + The expected type for rule j (constructor ctor_j with nf fields) is: + Π (rec_params) (motives) (minors) (ctor_fields) . motive indices (ctor_j params fields) + where the first (np+nm+nk) Pi binders come from the recursor type and + the field binders come from the constructor type (with param bvars shifted + to skip motive/minor binders). -/ + partial def checkRecursorRuleType (recType : Expr m) (rec : RecursorVal m) + (ctorAddr : Address) (nf : Nat) (ruleRhs : Expr m) : TypecheckM m Unit := do + let np := rec.numParams + let nm := rec.numMotives + let nk := rec.numMinors + let shift := nm + nk + -- Look up constructor info + let ctorCi ← derefConst ctorAddr + let ctorType := ctorCi.type + -- 1. 
Extract recursor binder domains (params + motives + minors) + let mut recTy := recType + let mut recDoms : Array (Expr m) := #[] + for _ in [:np + nm + nk] do + match recTy with + | .forallE dom body _ _ => + recDoms := recDoms.push dom + recTy := body + | _ => throw "recursor type has too few Pi binders for params+motives+minors" + -- Determine motive position from recursor return type. + -- After stripping indices+major, the return expr head is bvar(ni+nk+nm-d) + -- where d is the motive index for the major inductive. + let ni := rec.numIndices + let motivePos : Nat := Id.run do + let mut ty := recTy + for _ in [:ni + 1] do + match ty with + | .forallE _ body _ _ => ty := body + | _ => return 0 + match ty.getAppFn with + | .bvar idx _ => return (ni + nk + nm - idx) + | _ => return 0 + -- 2. Extract field domains from ctor type and handle nested params. + -- The constructor may have more params than the recursor (nested inductive pattern): + -- rec.numParams = shared params; cv.numParams may include extra "nested" params. + let cnp := match ctorCi with | .ctorInfo cv => cv.numParams | _ => np + -- Extract the major premise domain (needed for nested param values and level extraction). + -- recTy (after stripping np+nm+nk) = Π (indices) (major : IndType args), ret + let majorPremiseDom : Option (Expr m) := Id.run do + let mut ty := recTy + for _ in [:ni] do + match ty with + | .forallE _ body _ _ => ty := body + | _ => return none + match ty with + | .forallE dom _ _ _ => return some dom + | _ => return none + -- Compute constructor level substitution. + -- For nested inductives (cnp > np): extract actual levels from the major premise domain head + -- (e.g., List.{0} RCasesPatt → levels = [Level.zero]). + -- For standard case: map ctor level param i → rec level param (levelOffset + i). 
+ let recLevelCount := rec.numLevels + let ctorLevelCount := ctorCi.cv.numLevels + let levelSubst : Array (Level m) := + if cnp > np then + match majorPremiseDom with + | some dom => match dom.getAppFn with + | .const _ lvls _ => lvls + | _ => #[] + | none => #[] + else + let levelOffset := recLevelCount - ctorLevelCount + Array.ofFn (n := ctorLevelCount) fun i => + .param (levelOffset + i.val) (default : MetaField m Ix.Name) + let ctorLevels := levelSubst + -- Extract nested param values from the major premise domain args. + let nestedParams : Array (Expr m) := + if cnp > np then + match majorPremiseDom with + | some dom => + let args := dom.getAppArgs + -- args[np..cnp-1] are nested param values (under np+nm+nk+ni binders) + -- Shift up by nf to account for field binders in rule context + Array.ofFn (n := cnp - np) fun i => + if np + i.val < args.size then + shiftCtorToRule args[np + i.val]! 0 nf #[] + else default + | none => #[] + else #[] + -- Peel ALL constructor params (cnp, not just np) + let mut cty := ctorType + for _ in [:cnp] do + match cty with + | .forallE _ body _ _ => cty := body + | _ => throw "constructor type has too few Pi binders for params" + -- cty has nf field Pi binders and cnp free param bvars + let mut fieldDoms : Array (Expr m) := #[] + let mut ctorRetType := cty + for _ in [:nf] do + match ctorRetType with + | .forallE dom body _ _ => + fieldDoms := fieldDoms.push dom + ctorRetType := body + | _ => throw "constructor type has too few Pi binders for fields" + -- ctorRetType has cnp free param bvars and nf free field bvars. + -- Extra nested param bvars (0..cnp-np-1 at depth 0, i.e. indices nf..nf+cnp-np-1 in body) + -- need to be substituted with nestedParams before shifting. + -- Substitute extra param bvars: in the body, extra params are bvar indices + -- 0..cnp-np-1 (after fields). We instantiate them and shift shared params down. 
+ let ctorRet := if cnp > np then + substNestedParams ctorRetType nf (cnp - np) nestedParams + else ctorRetType + let fieldDomsAdj := if cnp > np then + Array.ofFn (n := fieldDoms.size) fun i => + substNestedParams fieldDoms[i]! i.val (cnp - np) nestedParams + else fieldDoms + -- Now ctorRet has np free param bvars and nf free field bvars + -- Shift param bvars (>= nf) up by nm+nk for the rule context + let ctorRetShifted := shiftCtorToRule ctorRet nf shift levelSubst + -- 3. Build expected return type: motive indices (ctor params fields) + -- Under all np+nm+nk+nf binders: + -- motive_d = bvar (nf + nk + nm - 1 - d) [d = position of major inductive in rec.all] + -- param i = bvar (nf + nk + nm + np - 1 - i) + -- field k = bvar (nf - 1 - k) + let motiveIdx := nf + nk + nm - 1 - motivePos + let mut ret := Expr.mkBVar motiveIdx + -- Apply indices from shifted ctor return type (skip all cnp param args) + let ctorRetArgs := ctorRetShifted.getAppArgs + for i in [cnp:ctorRetArgs.size] do + ret := Expr.mkApp ret ctorRetArgs[i]! + -- Build ctor application: ctor levels params fields nested-params + let mut ctorApp : Expr m := Expr.mkConst ctorAddr ctorLevels + for i in [:np] do + ctorApp := Expr.mkApp ctorApp (Expr.mkBVar (nf + shift + np - 1 - i)) + for v in nestedParams do + ctorApp := Expr.mkApp ctorApp v + for k in [:nf] do + ctorApp := Expr.mkApp ctorApp (Expr.mkBVar (nf - 1 - k)) + ret := Expr.mkApp ret ctorApp + -- 4. Wrap return type with field Pi binders (innermost first, shifted) + let mut fullType := ret + for i in [:nf] do + let j := nf - 1 - i + let dom := shiftCtorToRule fieldDomsAdj[j]! j shift levelSubst + fullType := .forallE dom fullType default default + -- 5. Wrap with recursor binder Pi's (minors, motives, params - outermost first → innermost first) + for i in [:np + nm + nk] do + let j := np + nm + nk - 1 - i + fullType := .forallE recDoms[j]! fullType default default + -- 6. 
Check inferred RHS type matches expected type + let (_, rhsType) ← withInferOnly (infer ruleRhs) + if !(← withInferOnly (isDefEq rhsType fullType)) then + -- Walk both types in parallel, peeling Pi binders, to find where they diverge + let mut rTy := rhsType + let mut eTy := fullType + let mut binderIdx := 0 + let mut divergeMsg := "types differ at top level" + let mut found := false + for _ in [:np + nm + nk + nf + 10] do -- enough iterations + if found then break + match rTy, eTy with + | .forallE rd rb _ _, .forallE ed eb _ _ => + if !(← withInferOnly (isDefEq rd ed)) then + divergeMsg := s!"binder {binderIdx} domain differs" + found := true + else + rTy := rb; eTy := eb; binderIdx := binderIdx + 1 + | _, _ => + if !(← withInferOnly (isDefEq rTy eTy)) then + let rHead := rTy.getAppFn + let eHead := eTy.getAppFn + let rArgs := rTy.getAppArgs + let eArgs := eTy.getAppArgs + let headEq ← withInferOnly (isDefEq rHead eHead) + let rTag := if rHead.isBVar then s!"bvar{rHead.bvarIdx!}" else if rHead.isConst then "const" else "other" + let eTag := if eHead.isBVar then s!"bvar{eHead.bvarIdx!}" else if eHead.isConst then "const" else "other" + let mut argDiag := s!"rHead={rTag} eHead={eTag} headEq={headEq} rArgs={rArgs.size} eArgs={eArgs.size}" + if headEq then + for j in [:min rArgs.size eArgs.size] do + if !(← withInferOnly (isDefEq rArgs[j]! eArgs[j]!)) then + argDiag := argDiag ++ s!" arg{j}differs" + break + divergeMsg := s!"return type differs after {binderIdx} binders; {argDiag}" + found := true + else + divergeMsg := s!"types are actually equal after {binderIdx} binders??" + found := true + throw s!"recursor rule RHS type mismatch for constructor {ctorCi.cv.name} ({ctorAddr}): {divergeMsg} (np={np} cnp={cnp})" + /-- Quick structural equality check without WHNF. 
Returns: - some true: definitely equal - some false: definitely not equal @@ -626,7 +1440,7 @@ mutual | .const a us _, .const b us' _ => if a == b && equalUnivArrays us us' then pure (some true) else pure none | .lit l, .lit l' => pure (some (l == l')) - | .bvar i _, .bvar j _ => pure (some (i == j)) + | .bvar i _, .bvar j _ => if i == j then pure (some true) else pure none | .lam .., .lam .. => do let mut a := t; let mut b := s repeat @@ -664,6 +1478,16 @@ mutual withRecDepthCheck do withFuelCheck do + -- Bool.true proof-by-reflection (matches lean4 C++ is_def_eq_core) + -- If one side is Bool.true, fully reduce the other and check + let prims := (← read).prims + if s.isConstOf prims.boolTrue then + let t' ← whnf t + if t'.isConstOf prims.boolTrue then cacheResult t s true; return true + if t.isConstOf prims.boolTrue then + let s' ← whnf s + if s'.isConstOf prims.boolTrue then cacheResult t s true; return true + -- Loop: steps 1-5 may restart when whnfCore(cheapProj=false) changes terms let mut ct := t let mut cs := s @@ -782,17 +1606,20 @@ mutual | _, _ => break withReader (fun ctx => { ctx with types := extTypes, letValues := extLetValues }) (isDefEq a b) - -- Application: flatten app spine to avoid O(num_args) stack depth + -- Application: flatten app spine, with eta-struct fallback (matches lean4lean) | .app .., .app .. => do let tFn := t.getAppFn let sFn := s.getAppFn let tArgs := t.getAppArgs let sArgs := s.getAppArgs - if tArgs.size != sArgs.size then return false - if !(← isDefEq tFn sFn) then return false - for h : i in [:tArgs.size] do - if !(← isDefEq tArgs[i] sArgs[i]!) then return false - return true + if tArgs.size == sArgs.size then + if (← isDefEq tFn sFn) then + let mut ok := true + for h : i in [:tArgs.size] do + if !(← isDefEq tArgs[i] sArgs[i]!) 
then ok := false; break + if ok then return true + -- Fallback: try eta-struct when isDefEqApp fails + tryEtaStruct t s -- Projection | .proj a i struct _, .proj b j struct' _ => @@ -840,9 +1667,10 @@ mutual let expanded := strLitToConstructor prims str isDefEq t expanded - -- Structure eta - | _, .app _ _ => tryEtaStruct t s - | .app _ _, _ => tryEtaStruct s t + -- Structure eta (one side is app, other is not), with unit-like fallback + | _, .app _ _ | .app _ _, _ => do + if ← tryEtaStruct t s then return true + isDefEqUnitLike t s -- Unit-like fallback: non-recursive, single ctor with 0 fields, 0 indices | _, _ => isDefEqUnitLike t s @@ -899,7 +1727,7 @@ mutual let kenv := (← read).kenv let mut steps := 0 repeat - if steps > 10000 then return (tn, sn, none) + if steps > 10000 then throw "lazyDeltaReduction step limit (10000) exceeded" steps := steps + 1 -- Syntactic check @@ -993,32 +1821,28 @@ mutual if !(← isDefEq tArgs[i] sArgs[i]!) then return false return true - /-- Try eta expansion for structure-like types. -/ + /-- Try eta expansion for structure-like types. + Matches lean4lean's `tryEtaStruct`: constructs projections and compares via `isDefEq`. -/ partial def tryEtaStruct (t s : Expr m) : TypecheckM m Bool := do - -- s should be a constructor application - let sFn := s.getAppFn - match sFn with - | .const ctorAddr _ _ => + if ← tryEtaStructCore t s then return true + tryEtaStructCore s t + where + tryEtaStructCore (t s : Expr m) : TypecheckM m Bool := do + let .const ctorAddr _ _ := s.getAppFn | return false match (← read).kenv.find? 
ctorAddr with | some (.ctorInfo cv) => - let indAddr := cv.induct - if !(← read).kenv.isStructureLike indAddr then return false let sArgs := s.getAppArgs - -- Check that each field arg is a projection of t - let numParams := cv.numParams + unless sArgs.size == cv.numParams + cv.numFields do return false + unless (← read).kenv.isStructureLike cv.induct do return false + let (_, tType) ← withInferOnly (infer t) + let (_, sType) ← withInferOnly (infer s) + unless ← isDefEq tType sType do return false for h : i in [:cv.numFields] do - let argIdx := numParams + i - if argIdx < sArgs.size then - let arg := sArgs[argIdx]! - match arg with - | .proj a idx struct _ => - if a != indAddr || idx != i then return false - if !(← isDefEq t struct) then return false - | _ => return false - else return false + let argIdx := cv.numParams + i + let proj := Expr.mkProj cv.induct i t + unless ← isDefEq proj sArgs[argIdx]! do return false return true | _ => return false - | _ => return false /-- Cache a def-eq result (both successes and failures). -/ partial def cacheResult (t s : Expr m) (result : Bool) : TypecheckM m Unit := do @@ -1030,6 +1854,20 @@ mutual let key := eqCacheKey t s modify fun stt => { stt with failureCache := stt.failureCache.insert key () } + /-- Validate a primitive definition/inductive/quotient using the KernelOps callback. -/ + partial def validatePrimitive (addr : Address) : TypecheckM m Unit := do + let ops : KernelOps m := { isDefEq, whnf, infer, isProp, isSort } + let prims := (← read).prims + let kenv := (← read).kenv + let _ ← checkPrimitive ops prims kenv addr + + /-- Validate quotient constant type signatures. -/ + partial def validateQuotient : TypecheckM m Unit := do + let ops : KernelOps m := { isDefEq, whnf, infer, isProp, isSort } + let prims := (← read).prims + checkEqType ops prims + checkQuotTypes ops prims + end -- mutual /-! 
## Expr size -/ diff --git a/Ix/Kernel/Primitive.lean b/Ix/Kernel/Primitive.lean new file mode 100644 index 00000000..4df64fef --- /dev/null +++ b/Ix/Kernel/Primitive.lean @@ -0,0 +1,402 @@ +/- + Kernel Primitive: Validation of primitive definitions, inductives, and quotient types. + + Translates lean4lean's Primitive.lean and Quot.lean checks to work with + Ix's address-based, de Bruijn-indexed expressions. Called from the mutual + block in Infer.lean via the KernelOps callback struct. + + All comparisons use isDefEq (not structural equality) so that .meta mode + name/binder-info differences don't cause spurious failures. +-/ +import Ix.Kernel.TypecheckM + +namespace Ix.Kernel + +/-! ## KernelOps — callback struct to access mutual-block functions -/ + +structure KernelOps (m : MetaMode) where + isDefEq : Expr m → Expr m → TypecheckM m Bool + whnf : Expr m → TypecheckM m (Expr m) + infer : Expr m → TypecheckM m (TypedExpr m × Expr m) + isProp : Expr m → TypecheckM m Bool + isSort : Expr m → TypecheckM m (TypedExpr m × Level m) + +/-! 
## Expression builders -/ + +private def natConst (p : Primitives) : Expr m := Expr.mkConst p.nat #[] +private def boolConst (p : Primitives) : Expr m := Expr.mkConst p.bool #[] +private def trueConst (p : Primitives) : Expr m := Expr.mkConst p.boolTrue #[] +private def falseConst (p : Primitives) : Expr m := Expr.mkConst p.boolFalse #[] +private def zeroConst (p : Primitives) : Expr m := Expr.mkConst p.natZero #[] +private def charConst (p : Primitives) : Expr m := Expr.mkConst p.char #[] +private def stringConst (p : Primitives) : Expr m := Expr.mkConst p.string #[] +private def listCharConst (p : Primitives) : Expr m := + Expr.mkApp (Expr.mkConst p.list #[Level.succ .zero]) (charConst p) + +private def succApp (p : Primitives) (e : Expr m) : Expr m := + Expr.mkApp (Expr.mkConst p.natSucc #[]) e +private def predApp (p : Primitives) (e : Expr m) : Expr m := + Expr.mkApp (Expr.mkConst p.natPred #[]) e +private def addApp (p : Primitives) (a b : Expr m) : Expr m := + Expr.mkApp (Expr.mkApp (Expr.mkConst p.natAdd #[]) a) b +private def subApp (p : Primitives) (a b : Expr m) : Expr m := + Expr.mkApp (Expr.mkApp (Expr.mkConst p.natSub #[]) a) b +private def mulApp (p : Primitives) (a b : Expr m) : Expr m := + Expr.mkApp (Expr.mkApp (Expr.mkConst p.natMul #[]) a) b +private def modApp (p : Primitives) (a b : Expr m) : Expr m := + Expr.mkApp (Expr.mkApp (Expr.mkConst p.natMod #[]) a) b +private def divApp (p : Primitives) (a b : Expr m) : Expr m := + Expr.mkApp (Expr.mkApp (Expr.mkConst p.natDiv #[]) a) b + +/-- Arrow type: `a → b` (non-dependent forall). 
-/ +private def mkArrow (a b : Expr m) : Expr m := Expr.mkForallE a (b.liftBVars 1) + +/-- `Nat → Nat → Nat` -/ +private def natBinType (p : Primitives) : Expr m := + mkArrow (natConst p) (mkArrow (natConst p) (natConst p)) + +/-- `Nat → Nat` -/ +private def natUnaryType (p : Primitives) : Expr m := + mkArrow (natConst p) (natConst p) + +/-- `Nat → Nat → Bool` -/ +private def natBinBoolType (p : Primitives) : Expr m := + mkArrow (natConst p) (mkArrow (natConst p) (boolConst p)) + +/-- Wrap both sides in `∀ (_ : Nat), _` so bvar 0 is well-typed as Nat. -/ +private def defeq1 (ops : KernelOps m) (p : Primitives) (a b : Expr m) : TypecheckM m Bool := + ops.isDefEq (mkArrow (natConst p) a) (mkArrow (natConst p) b) + +/-- Wrap both sides in `∀ (_ : Nat), ∀ (_ : Nat), _` for two free variables. -/ +private def defeq2 (ops : KernelOps m) (p : Primitives) (a b : Expr m) : TypecheckM m Bool := + defeq1 ops p (mkArrow (natConst p) a) (mkArrow (natConst p) b) + +/-- Check if an address is non-default (i.e., was actually resolved). -/ +private def resolved (addr : Address) : Bool := addr != default + +/-! ## Primitive inductive validation -/ + +/-- Check that Bool or Nat inductives have the expected form. + Uses isDefEq for type comparison so it works in both .meta and .anon modes. + Matches constructors by address from Primitives, not by position. 
-/ +def checkPrimitiveInductive (ops : KernelOps m) (p : Primitives) (kenv : Env m) + (addr : Address) : TypecheckM m Bool := do + let ci ← derefConst addr + let .inductInfo iv := ci | return false + if iv.isUnsafe then return false + if iv.numLevels != 0 then return false + if iv.numParams != 0 then return false + unless ← ops.isDefEq iv.type (Expr.mkSort (Level.succ .zero)) do return false + -- Check Bool + if addr == p.bool then + if iv.ctors.size != 2 then + throw "Bool must have exactly 2 constructors" + for ctorAddr in iv.ctors do + let ctor ← derefConst ctorAddr + unless ← ops.isDefEq ctor.type (boolConst p) do + throw s!"Bool constructor has unexpected type" + return true + -- Check Nat + if addr == p.nat then + if iv.ctors.size != 2 then + throw "Nat must have exactly 2 constructors" + for ctorAddr in iv.ctors do + let ctor ← derefConst ctorAddr + if ctorAddr == p.natZero then + unless ← ops.isDefEq ctor.type (natConst p) do + throw "Nat.zero has unexpected type" + else if ctorAddr == p.natSucc then + unless ← ops.isDefEq ctor.type (natUnaryType p) do + throw "Nat.succ has unexpected type" + else + throw s!"unexpected Nat constructor" + return true + return false + +/-! ## Simple primitive definition checks -/ + +/-- Check a primitive definition's type and reduction rules. + Returns true if the address matches a known primitive and passes validation. -/ +def checkPrimitiveDef (ops : KernelOps m) (p : Primitives) (kenv : Env m) (addr : Address) + : TypecheckM m Bool := do + let ci ← derefConst addr + let .defnInfo v := ci | return false + -- Skip if addr doesn't match any known primitive (avoid false positives). + -- stringOfList is excluded when it equals stringMk (constructor, validated via inductive path). 
+ let isPrimAddr := addr == p.natAdd || addr == p.natSub || addr == p.natMul || + addr == p.natPow || addr == p.natBeq || addr == p.natBle || + addr == p.natShiftLeft || addr == p.natShiftRight || + addr == p.natLand || addr == p.natLor || addr == p.natXor || + addr == p.natPred || addr == p.natBitwise || + addr == p.charMk || + (addr == p.stringOfList && p.stringOfList != p.stringMk) + if !isPrimAddr then return false + let fail {α : Type} (msg : String := "invalid form for primitive def") : TypecheckM m α := + throw msg + let nat : Expr m := natConst p + let tru : Expr m := trueConst p + let fal : Expr m := falseConst p + let zero : Expr m := zeroConst p + let succ : Expr m → Expr m := succApp p + let pred : Expr m → Expr m := predApp p + let add : Expr m → Expr m → Expr m := addApp p + let _sub : Expr m → Expr m → Expr m := subApp p + let mul : Expr m → Expr m → Expr m := mulApp p + let _mod' : Expr m → Expr m → Expr m := modApp p + let div' : Expr m → Expr m → Expr m := divApp p + let one : Expr m := succ zero + let two : Expr m := succ one + -- x = bvar 0, y = bvar 1 (inside wrapping binders) + let x : Expr m := .mkBVar 0 + let y : Expr m := .mkBVar 1 + + -- Nat.add + if addr == p.natAdd then + if !kenv.contains p.nat || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let addV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (addV x zero) x do fail + unless ← defeq2 ops p (addV y (succ x)) (succ (addV y x)) do fail + return true + + -- Nat.pred + if addr == p.natPred then + if !kenv.contains p.nat || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natUnaryType p) do fail + let predV := fun a => Expr.mkApp v.value a + unless ← ops.isDefEq (predV zero) zero do fail + unless ← defeq1 ops p (predV (succ x)) x do fail + return true + + -- Nat.sub + if addr == p.natSub then + if !kenv.contains p.natPred || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let subV := 
fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (subV x zero) x do fail + unless ← defeq2 ops p (subV y (succ x)) (pred (subV y x)) do fail + return true + + -- Nat.mul + if addr == p.natMul then + if !kenv.contains p.natAdd || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let mulV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (mulV x zero) zero do fail + unless ← defeq2 ops p (mulV y (succ x)) (add (mulV y x) y) do fail + return true + + -- Nat.pow + if addr == p.natPow then + if !kenv.contains p.natMul || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let powV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (powV x zero) one do fail + unless ← defeq2 ops p (powV y (succ x)) (mul (powV y x) y) do fail + return true + + -- Nat.beq + if addr == p.natBeq then + if !kenv.contains p.nat || !kenv.contains p.bool || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinBoolType p) do fail + let beqV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← ops.isDefEq (beqV zero zero) tru do fail + unless ← defeq1 ops p (beqV zero (succ x)) fal do fail + unless ← defeq1 ops p (beqV (succ x) zero) fal do fail + unless ← defeq2 ops p (beqV (succ y) (succ x)) (beqV y x) do fail + return true + + -- Nat.ble + if addr == p.natBle then + if !kenv.contains p.nat || !kenv.contains p.bool || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinBoolType p) do fail + let bleV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← ops.isDefEq (bleV zero zero) tru do fail + unless ← defeq1 ops p (bleV zero (succ x)) tru do fail + unless ← defeq1 ops p (bleV (succ x) zero) fal do fail + unless ← defeq2 ops p (bleV (succ y) (succ x)) (bleV y x) do fail + return true + + -- Nat.shiftLeft + if addr == p.natShiftLeft then + if !kenv.contains p.natMul || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type 
(natBinType p) do fail + let shlV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (shlV x zero) x do fail + unless ← defeq2 ops p (shlV x (succ y)) (shlV (mul two x) y) do fail + return true + + -- Nat.shiftRight + if addr == p.natShiftRight then + if !kenv.contains p.natDiv || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let shrV := fun a b => Expr.mkApp (Expr.mkApp v.value a) b + unless ← defeq1 ops p (shrV x zero) x do fail + unless ← defeq2 ops p (shrV x (succ y)) (div' (shrV x y) two) do fail + return true + + -- Nat.land + if addr == p.natLand then + if !kenv.contains p.natBitwise || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let (.app fn f) := v.value | fail "Nat.land value must be Nat.bitwise applied to a function" + unless fn.isConstOf p.natBitwise do fail "Nat.land value head must be Nat.bitwise" + let andF := fun a b => Expr.mkApp (Expr.mkApp f a) b + unless ← defeq1 ops p (andF fal x) fal do fail + unless ← defeq1 ops p (andF tru x) x do fail + return true + + -- Nat.lor + if addr == p.natLor then + if !kenv.contains p.natBitwise || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let (.app fn f) := v.value | fail "Nat.lor value must be Nat.bitwise applied to a function" + unless fn.isConstOf p.natBitwise do fail "Nat.lor value head must be Nat.bitwise" + let orF := fun a b => Expr.mkApp (Expr.mkApp f a) b + unless ← defeq1 ops p (orF fal x) x do fail + unless ← defeq1 ops p (orF tru x) tru do fail + return true + + -- Nat.xor + if addr == p.natXor then + if !kenv.contains p.natBitwise || v.numLevels != 0 then fail + unless ← ops.isDefEq v.type (natBinType p) do fail + let (.app fn f) := v.value | fail "Nat.xor value must be Nat.bitwise applied to a function" + unless fn.isConstOf p.natBitwise do fail "Nat.xor value head must be Nat.bitwise" + let xorF := fun a b => Expr.mkApp (Expr.mkApp f a) b + unless ← ops.isDefEq 
(xorF fal fal) fal do fail + unless ← ops.isDefEq (xorF tru fal) tru do fail + unless ← ops.isDefEq (xorF fal tru) tru do fail + unless ← ops.isDefEq (xorF tru tru) fal do fail + return true + + -- Char.ofNat (charMk field) + if addr == p.charMk then + if !kenv.contains p.nat || v.numLevels != 0 then fail + let expectedType := mkArrow nat (charConst p) + unless ← ops.isDefEq v.type expectedType do fail + return true + + -- String.ofList + if addr == p.stringOfList then + if v.numLevels != 0 then fail + let listChar := listCharConst p + let expectedType := mkArrow listChar (stringConst p) + unless ← ops.isDefEq v.type expectedType do fail + -- Check List.nil Char : List Char + let nilChar := Expr.mkApp (Expr.mkConst p.listNil #[Level.succ .zero]) (charConst p) + let (_, nilType) ← ops.infer nilChar + unless ← ops.isDefEq nilType listChar do fail + -- Check List.cons Char : Char → List Char → List Char + let consChar := Expr.mkApp (Expr.mkConst p.listCons #[Level.succ .zero]) (charConst p) + let (_, consType) ← ops.infer consChar + let expectedConsType := mkArrow (charConst p) (mkArrow listChar listChar) + unless ← ops.isDefEq consType expectedConsType do fail + return true + + return false + +/-! ## Quotient validation -/ + +/-- Check that the Eq inductive has the correct form using isDefEq. + Eq must be an inductive with 1 univ param, 1 constructor. 
+ Eq type: ∀ {α : Sort u}, α → α → Prop + Eq.refl type: ∀ {α : Sort u} (a : α), @Eq α a a -/ +def checkEqType (ops : KernelOps m) (p : Primitives) : TypecheckM m Unit := do + if !(← read).kenv.contains p.eq then + throw "Eq type not found in environment" + let ci ← derefConst p.eq + let .inductInfo iv := ci | throw "Eq is not an inductive" + if iv.numLevels != 1 then + throw "Eq must have exactly 1 universe parameter" + if iv.ctors.size != 1 then + throw "Eq must have exactly 1 constructor" + -- Check Eq type: ∀ {α : Sort u}, α → α → Prop + let u : Level m := .param 0 default + let sortU : Expr m := Expr.mkSort u + let expectedEqType : Expr m := + Expr.mkForallE sortU -- {α : Sort u} + (Expr.mkForallE (.mkBVar 0) -- (a : α) + (Expr.mkForallE (.mkBVar 1) -- (b : α) + Expr.prop)) -- Prop + unless ← ops.isDefEq ci.type expectedEqType do + throw "Eq has unexpected type" + + -- Check Eq.refl + if !(← read).kenv.contains p.eqRefl then + throw "Eq.refl not found in environment" + let refl ← derefConst p.eqRefl + if refl.numLevels != 1 then + throw "Eq.refl must have exactly 1 universe parameter" + let eqConst : Expr m := Expr.mkConst p.eq #[u] + let expectedReflType : Expr m := + Expr.mkForallE sortU -- {α : Sort u} + (Expr.mkForallE (.mkBVar 0) -- (a : α) + (Expr.mkApp (Expr.mkApp (Expr.mkApp eqConst (.mkBVar 1)) (.mkBVar 0)) (.mkBVar 0))) + unless ← ops.isDefEq refl.type expectedReflType do + throw "Eq.refl has unexpected type" + +/-- Check quotient type signatures against expected forms. -/ +def checkQuotTypes (ops : KernelOps m) (p : Primitives) + : TypecheckM m Unit := do + let u : Level m := .param 0 default + let sortU : Expr m := Expr.mkSort u + + -- Build `α → α → Prop` where α = bvar depth at the current level. + -- Under one binder, α = bvar (depth+1). Direct forallE, no mkArrow lift. 
+ let relType (depth : Nat) : Expr m := + Expr.mkForallE (.mkBVar depth) -- ∀ (_ : α) + (Expr.mkForallE (.mkBVar (depth + 1)) -- ∀ (_ : α) + Expr.prop) + + -- Quot.{u} : ∀ {α : Sort u} (r : α → α → Prop), Sort u + if resolved p.quotType then + let ci ← derefConst p.quotType + let expectedType : Expr m := + Expr.mkForallE sortU -- {α : Sort u} + (Expr.mkForallE (relType 0) -- (r : α → α → Prop) + (Expr.mkSort u)) + unless ← ops.isDefEq ci.type expectedType do + throw "Quot type signature mismatch" + + -- Quot.mk.{u} : ∀ {α : Sort u} (r : α → α → Prop) (a : α), @Quot α r + -- Under {α=2, r=1, a=0}: Quot α r = Quot (bvar 2) (bvar 1) + if resolved p.quotCtor then + let ci ← derefConst p.quotCtor + let quotApp : Expr m := Expr.mkApp (Expr.mkApp (Expr.mkConst p.quotType #[u]) (.mkBVar 2)) (.mkBVar 1) + let expectedType : Expr m := + Expr.mkForallE sortU -- {α : Sort u} + (Expr.mkForallE (relType 0) -- (r : α → α → Prop) + (Expr.mkForallE (.mkBVar 1) -- (a : α) — α=bvar 1 under {α=1, r=0} + quotApp)) + unless ← ops.isDefEq ci.type expectedType do + throw "Quot.mk type signature mismatch" + + -- Quot.lift and Quot.ind have complex types with deeply nested dependent binders. + -- Verify structural properties: correct number of universe params. + -- The type-checking of quotient reduction rules (in Whnf.lean) provides + -- the semantic guarantee that these constants have correct behavior. + -- TODO: Full de Bruijn type signature validation for Quot.lift and Quot.ind. + if resolved p.quotLift then + let ci ← derefConst p.quotLift + if ci.numLevels != 2 then + throw "Quot.lift must have exactly 2 universe parameters" + + if resolved p.quotInd then + let ci ← derefConst p.quotInd + if ci.numLevels != 1 then + throw "Quot.ind must have exactly 1 universe parameter" + +/-! ## Top-level dispatch -/ + +/-- Check if `addr` is a known primitive and validate it. + Returns true if the address matches a known primitive and passes validation. 
-/ +def checkPrimitive (ops : KernelOps m) (p : Primitives) (kenv : Env m) (addr : Address) + : TypecheckM m Bool := do + -- Try primitive inductives first + if addr == p.bool || addr == p.nat then + return ← checkPrimitiveInductive ops p kenv addr + -- Try primitive definitions + checkPrimitiveDef ops p kenv addr + +end Ix.Kernel diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 5ead128e..3d182c8a 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -105,6 +105,9 @@ def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := def withInferOnly : TypecheckM m α → TypecheckM m α := withReader fun ctx => { ctx with inferOnly := true } +def withSafety (s : DefinitionSafety) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with safety := s } + /-- The current binding depth (number of bound variables in scope). -/ def lvl : TypecheckM m Nat := do pure (← read).types.size diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean index a9e95818..a24c808d 100644 --- a/Ix/Kernel/Types.lean +++ b/Ix/Kernel/Types.lean @@ -797,7 +797,7 @@ def hints : ConstantInfo m → ReducibilityHints def safety : ConstantInfo m → DefinitionSafety | defnInfo v => v.safety - | _ => .safe + | ci => if ci.isUnsafe then .unsafe else .safe def all? 
: ConstantInfo m → Option (Array Address) | defnInfo v => some v.all @@ -917,6 +917,10 @@ structure Primitives where natXor : Address := default natShiftLeft : Address := default natShiftRight : Address := default + natPred : Address := default + natBitwise : Address := default + natModCoreGo : Address := default + natDivGo : Address := default bool : Address := default boolTrue : Address := default boolFalse : Address := default @@ -924,19 +928,50 @@ structure Primitives where stringMk : Address := default char : Address := default charMk : Address := default + stringOfList : Address := default list : Address := default listNil : Address := default listCons : Address := default + eq : Address := default + eqRefl : Address := default quotType : Address := default quotCtor : Address := default quotLift : Address := default quotInd : Address := default + /-- Extra addresses for complex primitive validation (mod/div/gcd/bitwise). + These are only needed for checking primitive definitions, not for WHNF/etc. -/ + natLE : Address := default + natDecLe : Address := default + natDecEq : Address := default + natBleRefl : Address := default + natNotBleRefl : Address := default + natBeqRefl : Address := default + natNotBeqRefl : Address := default + ite : Address := default + dite : Address := default + «not» : Address := default + accRec : Address := default + accIntro : Address := default + natLtSuccSelf : Address := default + natDivRecFuelLemma : Address := default deriving Repr, Inhabited def buildPrimitives : Primitives := - { nat := addr! "fc0e1e912f2d7f12049a5b315d76eec29562e34dc39ebca25287ae58807db137" + { -- Core types and constructors + nat := addr! "fc0e1e912f2d7f12049a5b315d76eec29562e34dc39ebca25287ae58807db137" natZero := addr! "fac82f0d2555d6a63e1b8a1fe8d86bd293197f39c396fdc23c1275c60f182b37" natSucc := addr! "7190ce56f6a2a847b944a355e3ec595a4036fb07e3c3db9d9064fc041be72b64" + bool := addr! 
"6405a455ba70c2b2179c7966c6f610bf3417bd0f3dd2ba7a522533c2cd9e1d0b" + boolTrue := addr! "420dead2168abd16a7050edfd8e17d45155237d3118782d0e68b6de87742cb8d" + boolFalse := addr! "c127f89f92e0481f7a3e0631c5615fe7f6cbbf439d5fd7eba400fb0603aedf2f" + string := addr! "591cf1c489d505d4082f2767500f123e29db5227eb1bae4721eeedd672f36190" + stringMk := addr! "f055b87da4265d980cdede04ce5c7d986866e55816dc94d32a5d90e805101230" + char := addr! "563b426b73cdf1538b767308d12d10d746e1f0b3b55047085bf690319a86f893" + charMk := addr! "7156fef44bc309789375d784e5c36e387f7119363dd9cd349226c52df43d2075" + list := addr! "abed9ff1aba4634abc0bd3af76ca544285a32dcfe43dc27b129aea8867457620" + listNil := addr! "0ebe345dc46917c824b6c3f6c42b101f2ac8c0e2c99f033a0ee3c60acb9cd84d" + listCons := addr! "f79842f10206598929e6ba60ce3ebaa00d11f201c99e80285f46cc0e90932832" + -- Nat arithmetic primitives natAdd := addr! "dcc96f3f914e363d1e906a8be4c8f49b994137bfdb077d07b6c8a4cf88a4f7bf" natSub := addr! "6903e9bbd169b6c5515b27b3fc0c289ba2ff8e7e0c7f984747d572de4e6a7853" natMul := addr! "8e641c3df8fe3878e5a219c888552802743b9251c3c37c32795f5b9b9e0818a5" @@ -951,17 +986,31 @@ def buildPrimitives : Primitives := natXor := addr! "a711ef2cb4fa8221bebaa17ef8f4a965cf30678a89bc45ff18a13c902e683cc5" natShiftLeft := addr! "16e4558f51891516843a5b30ddd9d9b405ec096d3e1c728d09ff152b345dd607" natShiftRight := addr! "b9515e6c2c6b18635b1c65ebca18b5616483ebd53936f78e4ae123f6a27a089e" - bool := addr! "6405a455ba70c2b2179c7966c6f610bf3417bd0f3dd2ba7a522533c2cd9e1d0b" - boolTrue := addr! "420dead2168abd16a7050edfd8e17d45155237d3118782d0e68b6de87742cb8d" - boolFalse := addr! "c127f89f92e0481f7a3e0631c5615fe7f6cbbf439d5fd7eba400fb0603aedf2f" - string := addr! "591cf1c489d505d4082f2767500f123e29db5227eb1bae4721eeedd672f36190" - stringMk := addr! "f055b87da4265d980cdede04ce5c7d986866e55816dc94d32a5d90e805101230" - char := addr! "563b426b73cdf1538b767308d12d10d746e1f0b3b55047085bf690319a86f893" - charMk := addr! 
"7156fef44bc309789375d784e5c36e387f7119363dd9cd349226c52df43d2075" - list := addr! "abed9ff1aba4634abc0bd3af76ca544285a32dcfe43dc27b129aea8867457620" - listNil := addr! "0ebe345dc46917c824b6c3f6c42b101f2ac8c0e2c99f033a0ee3c60acb9cd84d" - listCons := addr! "f79842f10206598929e6ba60ce3ebaa00d11f201c99e80285f46cc0e90932832" - -- Quot primitives need to be computed; use default until wired up + natPred := addr! "27ccc47de9587564d0c87f4b84d231c523f835af76bae5c7176f694ae78e7d65" + natBitwise := addr! "f3c9111f01de3d46cb3e3f6ad2e35991c0283257e6c75ae56d2a7441e8c63e8b" + natModCoreGo := addr! "7304267986fb0f6d398b45284aa6d64a953a72faa347128bf17c52d1eaf55c8e" + natDivGo := addr! "b3266f662eb973cafd1c5a61e0036d4f9a8f5db6dab7d9f1fe4421c4fb4e1251" + -- String/Char definitions + stringOfList := addr! "f055b87da4265d980cdede04ce5c7d986866e55816dc94d32a5d90e805101230" + -- Eq + eq := addr! "c1b8d6903a3966bfedeccb63b6702fe226f893740d5c7ecf40045e7ac7635db3" + eqRefl := addr! "154ff4baae9cd74c5ffd813f61d3afee0168827ce12fd49aad8141ebe011ae35" + -- Quot primitives are resolved from .quot tags at conversion time + -- Extra: mod/div/gcd validation helpers (for future complex primitive validation) + natLE := addr! "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262" + natDecLe := addr! "fa523228c653841d5ad7f149c1587d0743f259209306458195510ed5bf1bfb14" + natDecEq := addr! "84817cd97c5054a512c3f0a6273c7cd81808eb2dec2916c1df737e864df6b23a" + natBleRefl := addr! "204286820d20add0c3f1bda45865297b01662876fc06c0d5c44347d5850321fe" + natNotBleRefl := addr! "2b2da52eecb98350a7a7c5654c0f6f07125808c5188d74f8a6196a9e1ca66c0c" + natBeqRefl := addr! "db18a07fc2d71d4f0303a17521576dc3020ab0780f435f6760cc9294804004f9" + natNotBeqRefl := addr! "d5ae71af8c02a6839275a2e212b7ee8e31a9ae07870ab721c4acf89644ef8128" + ite := addr! "4ddf0c98eee233ec746f52468f10ee754c2e05f05bdf455b1c77555a15107b8b" + dite := addr! "a942a2b85dd20f591163fad2e84e573476736d852ad95bcfba50a22736cd3c79" + «not» := addr! 
"236b6e6720110bc351a8ad6cbd22437c3e0ef014981a37d45ba36805c81364f3" + accRec := addr! "23104251c3618f32eb77bec895e99f54edd97feed7ac27f3248da378d05e3289" + accIntro := addr! "7ff829fa1057b6589e25bac87f500ad979f9b93f77d47ca9bde6b539a8842d87" + natLtSuccSelf := addr! "2d2e51025b6e0306fdc45b79492becea407881d5137573d23ff144fc38a29519" + natDivRecFuelLemma := addr! "026b6f9a63f5fe7ac20b41b81e4180d95768ca78d7d1962aa8280be6b27362b7" } end Ix.Kernel diff --git a/Ix/Kernel/Whnf.lean b/Ix/Kernel/Whnf.lean index cbb17621..466ae21c 100644 --- a/Ix/Kernel/Whnf.lean +++ b/Ix/Kernel/Whnf.lean @@ -12,7 +12,7 @@ open Level (instBulkReduce reduceIMax) /-! ## Helpers -/ /-- Check if an address is a primitive operation that takes arguments. -/ -private def isPrimOp (prims : Primitives) (addr : Address) : Bool := +def isPrimOp (prims : Primitives) (addr : Address) : Bool := addr == prims.natAdd || addr == prims.natSub || addr == prims.natMul || addr == prims.natPow || addr == prims.natGcd || addr == prims.natMod || addr == prims.natDiv || addr == prims.natBeq || addr == prims.natBle || @@ -117,399 +117,9 @@ partial def reduceProj (typeAddr : Address) (idx : Nat) (struct : Expr m) : Type | _ => return none | _ => return none -mutual - /-- Structural WHNF: beta, let-zeta, iota-proj. No delta unfolding. - Uses an iterative loop to avoid deep stack usage: - - App spines are collected iteratively (not recursively) - - Beta/let/iota/proj results loop back instead of tail-calling - When cheapProj=true, projections are returned as-is (no struct reduction). - When cheapRec=true, recursor applications are returned as-is (no iota reduction). 
-/ - partial def whnfCore (e : Expr m) (cheapRec := false) (cheapProj := false) - : TypecheckM m (Expr m) := do - -- Cache check FIRST — no stack cost for cache hits - -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) - let useCache := !cheapRec && !cheapProj && (← read).numLetBindings == 0 - if useCache then - if let some r := (← get).whnfCoreCache.get? e then return r - let r ← whnfCoreImpl e cheapRec cheapProj - if useCache then - modify fun s => { s with whnfCoreCache := s.whnfCoreCache.insert e r } - pure r - - partial def whnfCoreImpl (e : Expr m) (cheapRec : Bool) (cheapProj : Bool) - : TypecheckM m (Expr m) := do - let mut t := e - repeat - -- Fuel check - let stt ← get - if stt.fuel == 0 then throw "deep recursion fuel limit reached" - modify fun s => { s with fuel := s.fuel - 1 } - match t with - | .app .. => do - -- Collect app args iteratively (O(1) stack for app spine) - let args := t.getAppArgs - let fn := t.getAppFn - let fn' ← whnfCore fn cheapRec cheapProj -- recurse only on non-app head - -- Beta-reduce: consume as many args as possible - let mut result := fn' - let mut i : Nat := 0 - while i < args.size do - match result with - | .lam _ body _ _ => - result := body.instantiate1 args[i]! - i := i + 1 - | _ => break - if i > 0 then - -- Beta reductions happened. Apply remaining args and loop. - for h : j in [i:args.size] do - result := Expr.mkApp result args[j]! - t := result; continue -- loop instead of recursive tail call - else - -- No beta reductions. Try recursor/proj reduction. 
- let e' := if fn == fn' then t else fn'.mkAppN args - if cheapRec then return e' -- skip recursor reduction - let r ← tryReduceApp e' - if r == e' then return r -- stuck, return - t := r; continue -- iota/quot reduced, loop to re-process - | .bvar idx _ => do - -- Zeta-reduce let-bound bvars: look up the stored value and substitute - let ctx ← read - let depth := ctx.types.size - if idx < depth then - let arrayIdx := depth - 1 - idx - if h : arrayIdx < ctx.letValues.size then - if let some val := ctx.letValues[arrayIdx] then - -- Shift free bvars in val past the intermediate binders - t := val.liftBVars (idx + 1); continue - return t - | .letE _ val body _ => - t := body.instantiate1 val; continue -- loop instead of recursion - | .proj typeAddr idx struct _ => do - -- cheapProj=true: try structural-only reduction (whnfCore, no delta) - -- cheapProj=false: full reduction (whnf, with delta) - let struct' ← if cheapProj then whnfCore struct cheapRec cheapProj else whnf struct - match ← reduceProj typeAddr idx struct' with - | some result => t := result; continue -- loop instead of recursion - | none => - return if struct == struct' then t else .proj typeAddr idx struct' default - | _ => return t - return t -- unreachable, but needed for type checking - - /-- Try to reduce an application whose head is in WHNF. - Handles recursor iota-reduction and quotient reduction. -/ - partial def tryReduceApp (e : Expr m) : TypecheckM m (Expr m) := do - let fn := e.getAppFn - match fn with - | .const addr _ _ => do - ensureTypedConst addr - match (← get).typedConsts.get? 
addr with - | some (.recursor _ params motives minors indices isK indAddr rules) => - let args := e.getAppArgs - let majorIdx := params + motives + minors + indices - if h : majorIdx < args.size then - let major := args[majorIdx] - let major' ← whnf major - if isK then - tryKReduction e addr args major' params motives minors indices indAddr - else - tryIotaReduction e addr args major' params indices indAddr rules motives minors - else pure e - | some (.quotient _ kind) => - match kind with - | .lift => tryQuotReduction e 6 3 - | .ind => tryQuotReduction e 5 3 - | _ => pure e - | _ => pure e - | _ => pure e - - /-- K-reduction: for Prop inductives with single zero-field constructor. - Returns the (only) minor premise, plus any extra args after the major. - Only fires when the major premise has already been reduced to a constructor. - (lean4lean's toCtorWhenK also handles non-constructor majors by checking - indices via isDefEq, but that requires infer/isDefEq which are in a - separate mutual block. The whnf of the major should handle most cases.) -/ - partial def tryKReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) - (major : Expr m) (params motives minors indices : Nat) (_indAddr : Address) - : TypecheckM m (Expr m) := do - -- Check if major is a constructor (including nat literal → ctor conversion) - let ctx ← read - let majorCtor := toCtorIfLit ctx.prims major - let isCtor := match majorCtor.getAppFn with - | .const ctorAddr _ _ => - match ctx.kenv.find? 
ctorAddr with - | some (.ctorInfo _) => true - | _ => false - | _ => false - if !isCtor then return e - -- K-reduction: return the (only) minor premise - let minorIdx := params + motives - if h : minorIdx < args.size then - let mut result := args[minorIdx] - -- Apply extra args after major premise (matching lean4 kernel behavior) - let majorIdx := params + motives + minors + indices - if majorIdx + 1 < args.size then - result := result.mkAppRange (majorIdx + 1) args.size args - return result - pure e - - /-- Iota-reduction: reduce a recursor applied to a constructor. - Follows the lean4 algorithm: - 1. Apply params + motives + minors from recursor args to rule RHS - 2. Apply constructor fields (skip constructor params) to rule RHS - 3. Apply extra args after major premise to rule RHS - Beta reduction happens in the subsequent whnfCore call. -/ - partial def tryIotaReduction (e : Expr m) (_addr : Address) (args : Array (Expr m)) - (major : Expr m) (params indices : Nat) (indAddr : Address) - (rules : Array (Nat × TypedExpr m)) - (motives minors : Nat) : TypecheckM m (Expr m) := do - let prims := (← read).prims - -- Skip large nat literals to avoid O(n) overhead - let skipLargeNat := match major with - | .lit (.natVal n) => indAddr == prims.nat && n > 256 - | _ => false - if skipLargeNat then return e - let majorCtor := toCtorIfLit prims major - let majorFn := majorCtor.getAppFn - match majorFn with - | .const ctorAddr _ _ => do - let kenv := (← read).kenv - let typedConsts := (← get).typedConsts - let ctorInfo? := match kenv.find? ctorAddr with - | some (.ctorInfo cv) => some (cv.cidx, cv.numFields) - | _ => - match typedConsts.get? ctorAddr with - | some (.constructor _ ctorIdx numFields) => some (ctorIdx, numFields) - | _ => none - match ctorInfo? with - | some (ctorIdx, _) => - match rules[ctorIdx]? 
with - | some (nfields, rhs) => - let majorArgs := majorCtor.getAppArgs - if nfields > majorArgs.size then return e - -- Instantiate universe level params in the rule RHS - let recFn := e.getAppFn - let recLevels := recFn.constLevels! - let mut result := rhs.body.instantiateLevelParams recLevels - -- Phase 1: Apply params + motives + minors from recursor args - let pmmEnd := params + motives + minors - result := result.mkAppRange 0 pmmEnd args - -- Phase 2: Apply constructor fields (skip constructor's own params) - let ctorParamCount := majorArgs.size - nfields - result := result.mkAppRange ctorParamCount majorArgs.size majorArgs - -- Phase 3: Apply remaining arguments after major premise - let majorIdx := params + motives + minors + indices - if majorIdx + 1 < args.size then - result := result.mkAppRange (majorIdx + 1) args.size args - pure result -- return raw result; whnfCore's loop will re-process - | none => pure e - | none => - -- Not a constructor, try structure eta - tryStructEta e args indices indAddr rules major motives minors - | _ => - tryStructEta e args indices indAddr rules major motives minors - - /-- Structure eta: expand struct-like major via projections. -/ - partial def tryStructEta (e : Expr m) (args : Array (Expr m)) - (indices : Nat) (indAddr : Address) - (rules : Array (Nat × TypedExpr m)) (major : Expr m) - (motives minors : Nat) : TypecheckM m (Expr m) := do - let kenv := (← read).kenv - if !kenv.isStructureLike indAddr then return e - match rules[0]? with - | some (nfields, rhs) => - let recFn := e.getAppFn - let recLevels := recFn.constLevels! 
- let params := args.size - motives - minors - indices - 1 - let mut result := rhs.body.instantiateLevelParams recLevels - -- Phase 1: params + motives + minors - let pmmEnd := params + motives + minors - result := result.mkAppRange 0 pmmEnd args - -- Phase 2: projections as fields - let mut projArgs : Array (Expr m) := #[] - for i in [:nfields] do - projArgs := projArgs.push (Expr.mkProj indAddr i major) - result := projArgs.foldl (fun acc a => Expr.mkApp acc a) result - -- Phase 3: extra args after major - let majorIdx := params + motives + minors + indices - if majorIdx + 1 < args.size then - result := result.mkAppRange (majorIdx + 1) args.size args - pure result -- return raw result; whnfCore's loop will re-process - | none => pure e - - /-- Quotient reduction: Quot.lift / Quot.ind. - For Quot.lift: `@Quot.lift α r β f h q` — reduceSize=6, fPos=3 (f is at index 3) - For Quot.ind: `@Quot.ind α r β f q` — reduceSize=5, fPos=3 (f is at index 3) - When major (q) reduces to `@Quot.mk α r a`, result is `f a`. -/ - partial def tryQuotReduction (e : Expr m) (reduceSize fPos : Nat) : TypecheckM m (Expr m) := do - let args := e.getAppArgs - if args.size < reduceSize then return e - let majorIdx := reduceSize - 1 - if h : majorIdx < args.size then - let major := args[majorIdx] - let major' ← whnf major - let majorFn := major'.getAppFn - match majorFn with - | .const majorAddr _ _ => - ensureTypedConst majorAddr - match (← get).typedConsts.get? majorAddr with - | some (.quotient _ .ctor) => - let majorArgs := major'.getAppArgs - -- Quot.mk has 3 args: [α, r, a]. The data 'a' is the last one. - if majorArgs.size < 3 then throw "Quot.mk should have at least 3 args" - let dataArg := majorArgs[majorArgs.size - 1]! 
- if h2 : fPos < args.size then - let f := args[fPos] - let result := Expr.mkApp f dataArg - -- Apply any extra args after the major premise - let result := if majorIdx + 1 < args.size then - result.mkAppRange (majorIdx + 1) args.size args - else result - pure result -- return raw result; whnfCore's loop will re-process - else return e - | _ => return e - | _ => return e - else return e - - /-- Try to reduce a Nat primitive, whnf'ing args if needed (like lean4lean's reduceNat). - Inside the mutual block so it can call `whnf` on arguments. -/ - partial def tryReduceNat (e : Expr m) : TypecheckM m (Option (Expr m)) := do - let fn := e.getAppFn - match fn with - | .const addr _ _ => - let prims := (← read).prims - if !isPrimOp prims addr then return none - let args := e.getAppArgs - -- Nat.succ: 1 arg - if addr == prims.natSucc then - if args.size >= 1 then - let a ← whnf args[0]! - match a with - | .lit (.natVal n) => return some (.lit (.natVal (n + 1))) - | _ => return none - else return none - -- Binary nat operations: 2 args, whnf both (matches lean4lean reduceBinNatOp) - else if args.size >= 2 then - let a ← whnf args[0]! - let b ← whnf args[1]! 
- match a, b with - | .lit (.natVal x), .lit (.natVal y) => - if addr == prims.natAdd then return some (.lit (.natVal (x + y))) - else if addr == prims.natSub then return some (.lit (.natVal (x - y))) - else if addr == prims.natMul then return some (.lit (.natVal (x * y))) - else if addr == prims.natPow then - if y > 16777216 then return none - return some (.lit (.natVal (Nat.pow x y))) - else if addr == prims.natMod then return some (.lit (.natVal (x % y))) - else if addr == prims.natDiv then return some (.lit (.natVal (x / y))) - else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y))) - else if addr == prims.natBeq then - let boolAddr := if x == y then prims.boolTrue else prims.boolFalse - return some (Expr.mkConst boolAddr #[]) - else if addr == prims.natBle then - let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse - return some (Expr.mkConst boolAddr #[]) - else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y))) - else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y))) - else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y))) - else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y))) - else if addr == prims.natShiftRight then return some (.lit (.natVal (Nat.shiftRight x y))) - else return none - | _, _ => return none - else return none - | _ => return none - - partial def whnf (e : Expr m) : TypecheckM m (Expr m) := do - -- Cache check FIRST — no fuel or stack cost for cache hits - -- Skip cache when let bindings are in scope (bvar zeta makes results context-dependent) - let useWhnfCache := (← read).numLetBindings == 0 - if useWhnfCache then - if let some r := (← get).whnfCache.get? 
e then return r - withRecDepthCheck do - withFuelCheck do - let r ← whnfImpl e - if useWhnfCache then - modify fun s => { s with whnfCache := s.whnfCache.insert e r } - pure r - - partial def whnfImpl (e : Expr m) : TypecheckM m (Expr m) := do - let mut t ← whnfCore e - let mut steps := 0 - repeat - if steps > 10000 then break -- safety bound - -- Try nat primitive reduction (whnf's args like lean4lean's reduceNat) - if let some r := ← tryReduceNat t then - t ← whnfCore r; steps := steps + 1; continue - -- Handle stuck projections (including inside app chains). - -- Flatten nested projection chains to avoid deep whnf→whnf recursion. - match t.getAppFn with - | .proj _ _ _ _ => - -- Collect the projection chain from outside in - let mut projStack : Array (Address × Nat × Array (Expr m)) := #[] - let mut inner := t - repeat - match inner.getAppFn with - | .proj typeAddr idx struct _ => - projStack := projStack.push (typeAddr, idx, inner.getAppArgs) - inner := struct - | _ => break - -- Reduce the innermost struct with depth-guarded whnf - let innerReduced ← whnf inner - -- Resolve projections from inside out (last pushed = innermost) - let mut current := innerReduced - let mut allResolved := true - let mut i := projStack.size - while i > 0 do - i := i - 1 - let (typeAddr, idx, args) := projStack[i]! - match ← reduceProj typeAddr idx current with - | some result => - let applied := if args.isEmpty then result else result.mkAppN args - current ← whnfCore applied - | none => - -- This projection couldn't be resolved. Reconstruct remaining chain. - let stuck := if args.isEmpty then - Expr.mkProj typeAddr idx current - else - (Expr.mkProj typeAddr idx current).mkAppN args - current ← whnfCore stuck - -- Reconstruct outer projections - while i > 0 do - i := i - 1 - let (ta, ix, as) := projStack[i]! 
- current := if as.isEmpty then - Expr.mkProj ta ix current - else - (Expr.mkProj ta ix current).mkAppN as - allResolved := false - break - if allResolved || current != t then - t := current; steps := steps + 1; continue - | _ => pure () - -- Try delta unfolding - if let some r := ← unfoldDefinition t then - t ← whnfCore r; steps := steps + 1; continue - break - pure t - - /-- Unfold a single delta step (definition body). -/ - partial def unfoldDefinition (e : Expr m) : TypecheckM m (Option (Expr m)) := do - let head := e.getAppFn - match head with - | .const addr levels _ => do - let ci ← derefConst addr - match ci with - | .defnInfo v => - if v.safety == .partial then return none - let body := v.value.instantiateLevelParams levels - let args := e.getAppArgs - return some (body.mkAppN args) - | .thmInfo v => - let body := v.value.instantiateLevelParams levels - let args := e.getAppArgs - return some (body.mkAppN args) - | _ => return none - | _ => return none -end +-- NOTE: The whnf mutual block has been moved to Infer.lean to enable +-- whnf functions to call infer/isDefEq (needed for toCtorWhenK, isProp checks). +-- Non-mutual helpers (reduceProj, toCtorIfLit, etc.) remain here. /-! 
## Literal folding for pretty printing -/ diff --git a/Tests/Ix/Kernel/Helpers.lean b/Tests/Ix/Kernel/Helpers.lean index 6510abe8..77bc840a 100644 --- a/Tests/Ix/Kernel/Helpers.lean +++ b/Tests/Ix/Kernel/Helpers.lean @@ -58,27 +58,28 @@ partial def leanNameToIx : Lean.Name → Ix.Name def addInductive (env : Env .anon) (addr : Address) (type : Expr .anon) (ctors : Array Address) (numParams numIndices : Nat := 0) (isRec := false) - (isUnsafe := false) (numNested := 0) : Env .anon := + (isUnsafe := false) (numNested := 0) + (numLevels : Nat := 0) (all : Array Address := #[addr]) : Env .anon := env.insert addr (.inductInfo { - toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, - numParams, numIndices, all := #[addr], ctors, numNested, + toConstantVal := { numLevels, type, name := (), levelParams := () }, + numParams, numIndices, all, ctors, numNested, isRec, isUnsafe, isReflexive := false }) /-- Build a constructor and insert it into the env. -/ def addCtor (env : Env .anon) (addr : Address) (induct : Address) (type : Expr .anon) (cidx numParams numFields : Nat) - (isUnsafe := false) : Env .anon := + (isUnsafe := false) (numLevels : Nat := 0) : Env .anon := env.insert addr (.ctorInfo { - toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, + toConstantVal := { numLevels, type, name := (), levelParams := () }, induct, cidx, numParams, numFields, isUnsafe }) /-- Build an axiom and insert it into the env. 
-/ def addAxiom (env : Env .anon) (addr : Address) - (type : Expr .anon) (isUnsafe := false) : Env .anon := + (type : Expr .anon) (isUnsafe := false) (numLevels : Nat := 0) : Env .anon := env.insert addr (.axiomInfo { - toConstantVal := { numLevels := 0, type, name := (), levelParams := () }, + toConstantVal := { numLevels, type, name := (), levelParams := () }, isUnsafe }) diff --git a/Tests/Ix/Kernel/Soundness.lean b/Tests/Ix/Kernel/Soundness.lean index 406bc840..818438a3 100644 --- a/Tests/Ix/Kernel/Soundness.lean +++ b/Tests/Ix/Kernel/Soundness.lean @@ -378,6 +378,398 @@ def validSingleCtor : TestSeq := (expectOk env buildPrimitives indAddr "valid-inductive").1 ) +/-! ## Mutual recursor motive tests -/ + +/-- Shared mutual inductive: A and B, each with a 0-field constructor. + mutual + inductive A : Type where | mk : A + inductive B : Type where | mk : B + end -/ +private def mutualAddrs := do + let aAddr := mkAddr 120 + let bAddr := mkAddr 121 + let aMkAddr := mkAddr 122 + let bMkAddr := mkAddr 123 + (aAddr, bAddr, aMkAddr, bMkAddr) + +private def buildMutualEnv : Env .anon := + let (aAddr, bAddr, aMkAddr, bMkAddr) := mutualAddrs + -- A : Sort 1 + let env : Env .anon := default + let env := env.insert aAddr (.inductInfo { + toConstantVal := { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () }, + numParams := 0, numIndices := 0, all := #[aAddr, bAddr], ctors := #[aMkAddr], + numNested := 0, isRec := false, isUnsafe := false, isReflexive := false + }) + -- A.mk : A + let env := addCtor env aMkAddr aAddr (.const aAddr #[] ()) 0 0 0 + -- B : Sort 1 + let env := env.insert bAddr (.inductInfo { + toConstantVal := { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () }, + numParams := 0, numIndices := 0, all := #[aAddr, bAddr], ctors := #[bMkAddr], + numNested := 0, isRec := false, isUnsafe := false, isReflexive := false + }) + -- B.mk : B + addCtor env bMkAddr bAddr (.const bAddr #[] ()) 0 0 0 + +/-- Build recursor 
type: + Π (mA : A → Sort u) (mB : B → Sort u) (cA : mA A.mk) (cB : mB B.mk) + (major : majorInd), motive major + where `motive` is bvar idx for the appropriate motive. -/ +private def mkMutualRecType (majorAddr : Address) (motiveRetBvar : Nat) : Expr .anon := + let (aAddr, bAddr, aMkAddr, bMkAddr) := mutualAddrs + -- mA : A → Sort u + .forallE (.forallE (.const aAddr #[] ()) (.sort (.param 0 ())) () ()) + -- mB : B → Sort u + (.forallE (.forallE (.const bAddr #[] ()) (.sort (.param 0 ())) () ()) + -- cA : mA A.mk (under [mA, mB]: mA = bvar 1) + (.forallE (.app (.bvar 1 ()) (.const aMkAddr #[] ())) + -- cB : mB B.mk (under [mA, mB, cA]: mB = bvar 1) + (.forallE (.app (.bvar 1 ()) (.const bMkAddr #[] ())) + -- major : majorInd + (.forallE (.const majorAddr #[] ()) + -- return: motive major (under [mA,mB,cA,cB,major]) + (.app (.bvar motiveRetBvar ()) (.bvar 0 ())) + () ()) + () ()) + () ()) + () ()) + () () + +/-- Test: A.rec with correct motive (motive_0 = outermost, bvar 4) passes -/ +def mutualRecMotiveFirst : TestSeq := + test "accepts A.rec with motive_0 (outermost)" ( + let (aAddr, bAddr, aMkAddr, bMkAddr) := mutualAddrs + let recAddr := mkAddr 130 + let env := buildMutualEnv + -- A.rec type: return type uses mA = bvar 4 + let recType := mkMutualRecType aAddr 4 + -- RHS for A.mk rule: λ mA mB cA cB, cA + -- Under [mA, mB, cA, cB]: cA = bvar 1 + let rhs : Expr .anon := + .lam (.forallE (.const aAddr #[] ()) (.sort (.param 0 ())) () ()) -- mA + (.lam (.forallE (.const bAddr #[] ()) (.sort (.param 0 ())) () ()) -- mB + (.lam (.app (.bvar 1 ()) (.const aMkAddr #[] ())) -- cA + (.lam (.app (.bvar 1 ()) (.const bMkAddr #[] ())) -- cB + (.bvar 1 ()) -- body: cA + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 1 recType #[aAddr, bAddr] 0 0 2 2 + #[{ ctor := aMkAddr, nfields := 0, rhs }] + (expectOk env buildPrimitives recAddr "mutual-rec-motive-first").1 + ) + +/-- Test: B.rec with correct motive (motive_1 = second, bvar 3) passes -/ +def 
mutualRecMotiveSecond : TestSeq := + test "accepts B.rec with motive_1 (second motive)" ( + let (aAddr, bAddr, aMkAddr, bMkAddr) := mutualAddrs + let recAddr := mkAddr 131 + let env := buildMutualEnv + -- B.rec type: return type uses mB = bvar 3 + let recType := mkMutualRecType bAddr 3 + -- RHS for B.mk rule: λ mA mB cA cB, cB + -- Under [mA, mB, cA, cB]: cB = bvar 0 + let rhs : Expr .anon := + .lam (.forallE (.const aAddr #[] ()) (.sort (.param 0 ())) () ()) -- mA + (.lam (.forallE (.const bAddr #[] ()) (.sort (.param 0 ())) () ()) -- mB + (.lam (.app (.bvar 1 ()) (.const aMkAddr #[] ())) -- cA + (.lam (.app (.bvar 1 ()) (.const bMkAddr #[] ())) -- cB + (.bvar 0 ()) -- body: cB + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 1 recType #[aAddr, bAddr] 0 0 2 2 + #[{ ctor := bMkAddr, nfields := 0, rhs }] + (expectOk env buildPrimitives recAddr "mutual-rec-motive-second").1 + ) + +/-- Test: B.rec with wrong motive (uses mA instead of mB in return) fails -/ +def mutualRecWrongMotive : TestSeq := + test "rejects B.rec with wrong motive in return type" ( + let (aAddr, bAddr, aMkAddr, bMkAddr) := mutualAddrs + let recAddr := mkAddr 132 + let env := buildMutualEnv + -- B.rec type but with return using mA (bvar 4) instead of mB (bvar 3) + let recType := mkMutualRecType bAddr 4 -- wrong: should be 3 + -- RHS for B.mk: λ mA mB cA cB, cB (type is mB B.mk, but recType says mA) + let rhs : Expr .anon := + .lam (.forallE (.const aAddr #[] ()) (.sort (.param 0 ())) () ()) + (.lam (.forallE (.const bAddr #[] ()) (.sort (.param 0 ())) () ()) + (.lam (.app (.bvar 1 ()) (.const aMkAddr #[] ())) + (.lam (.app (.bvar 1 ()) (.const bMkAddr #[] ())) + (.bvar 0 ()) + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 1 recType #[aAddr, bAddr] 0 0 2 2 + #[{ ctor := bMkAddr, nfields := 0, rhs }] + (expectError env buildPrimitives recAddr "mutual-rec-wrong-motive").1 + ) + +/-! 
## Mutual recursor with fields (nested-inductive pattern) -/ + +/-- Mutual block with 1-field constructors and a standalone type T: + axiom T : Sort 1 + mutual + inductive C : Sort 1 where | mk : T → C + inductive D : Sort 1 where | mk : T → D + end + Tests field binder shifting and motive selection together. -/ +private def fieldAddrs := do + let tAddr := mkAddr 140 + let cAddr := mkAddr 141 + let dAddr := mkAddr 142 + let cMkAddr := mkAddr 143 + let dMkAddr := mkAddr 144 + (tAddr, cAddr, dAddr, cMkAddr, dMkAddr) + +private def buildFieldMutualEnv : Env .anon := + let (tAddr, cAddr, dAddr, cMkAddr, dMkAddr) := fieldAddrs + -- T : Sort 1 (axiom) + let env : Env .anon := default + let env := addAxiom env tAddr (.sort (.succ .zero)) + -- C : Sort 1 + let env := env.insert cAddr (.inductInfo { + toConstantVal := { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () }, + numParams := 0, numIndices := 0, all := #[cAddr, dAddr], ctors := #[cMkAddr], + numNested := 0, isRec := false, isUnsafe := false, isReflexive := false + }) + -- C.mk : T → C + let env := addCtor env cMkAddr cAddr + (.forallE (.const tAddr #[] ()) (.const cAddr #[] ()) () ()) 0 0 1 + -- D : Sort 1 + let env := env.insert dAddr (.inductInfo { + toConstantVal := { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () }, + numParams := 0, numIndices := 0, all := #[cAddr, dAddr], ctors := #[dMkAddr], + numNested := 0, isRec := false, isUnsafe := false, isReflexive := false + }) + -- D.mk : T → D + addCtor env dMkAddr dAddr + (.forallE (.const tAddr #[] ()) (.const dAddr #[] ()) () ()) 0 0 1 + +/-- Build C.rec or D.rec type with 1-field constructors. 
+ Π (mC : C → Sort u) (mD : D → Sort u) + (cC : Π (t : T), mC (C.mk t)) + (cD : Π (t : T), mD (D.mk t)) + (major : majorInd), motive major + motiveRetBvar: bvar index of motive in the return type (4=mC, 3=mD) -/ +private def mkFieldRecType (majorAddr : Address) (motiveRetBvar : Nat) : Expr .anon := + let (tAddr, cAddr, dAddr, cMkAddr, dMkAddr) := fieldAddrs + -- mC : C → Sort u + .forallE (.forallE (.const cAddr #[] ()) (.sort (.param 0 ())) () ()) + -- mD : D → Sort u + (.forallE (.forallE (.const dAddr #[] ()) (.sort (.param 0 ())) () ()) + -- cC : Π (t : T), mC (C.mk t) [under mC,mD: mC=bvar 1; inner body under mC,mD,t: mC=bvar 2] + (.forallE (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const cMkAddr #[] ()) (.bvar 0 ()))) () ()) + -- cD : Π (t : T), mD (D.mk t) [under mC,mD,cC; inner body under mC,mD,cC,t: mD=bvar 2] + (.forallE (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const dMkAddr #[] ()) (.bvar 0 ()))) () ()) + -- major : majorInd + (.forallE (.const majorAddr #[] ()) + -- return: motive major [under mC,mD,cC,cD,major] + (.app (.bvar motiveRetBvar ()) (.bvar 0 ())) + () ()) + () ()) + () ()) + () ()) + () () + +/-- Test: C.rec with 1-field ctor, motive_0 (bvar 4) passes -/ +def mutualFieldRecFirst : TestSeq := + test "accepts C.rec with fields and motive_0" ( + let (tAddr, cAddr, dAddr, cMkAddr, dMkAddr) := fieldAddrs + let recAddr := mkAddr 150 + let env := buildFieldMutualEnv + let recType := mkFieldRecType cAddr 4 + -- RHS: λ mC mD cC cD (t : T), cC t + -- Under [mC,mD,cC,cD,t]: cC=bvar 2, t=bvar 0 + let rhs : Expr .anon := + .lam (.forallE (.const cAddr #[] ()) (.sort (.param 0 ())) () ()) -- mC + (.lam (.forallE (.const dAddr #[] ()) (.sort (.param 0 ())) () ()) -- mD + (.lam (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const cMkAddr #[] ()) (.bvar 0 ()))) () ()) -- cC + (.lam (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const dMkAddr #[] ()) (.bvar 0 ()))) () ()) -- cD + (.lam (.const tAddr 
#[] ()) -- t + (.app (.bvar 2 ()) (.bvar 0 ())) -- cC t + () ()) + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 1 recType #[cAddr, dAddr] 0 0 2 2 + #[{ ctor := cMkAddr, nfields := 1, rhs }] + (expectOk env buildPrimitives recAddr "mutual-field-rec-first").1 + ) + +/-- Test: D.rec with 1-field ctor, motive_1 (bvar 3) passes -/ +def mutualFieldRecSecond : TestSeq := + test "accepts D.rec with fields and motive_1" ( + let (tAddr, cAddr, dAddr, cMkAddr, dMkAddr) := fieldAddrs + let recAddr := mkAddr 151 + let env := buildFieldMutualEnv + let recType := mkFieldRecType dAddr 3 + -- RHS: λ mC mD cC cD (t : T), cD t + -- Under [mC,mD,cC,cD,t]: cD=bvar 1, t=bvar 0 + let rhs : Expr .anon := + .lam (.forallE (.const cAddr #[] ()) (.sort (.param 0 ())) () ()) -- mC + (.lam (.forallE (.const dAddr #[] ()) (.sort (.param 0 ())) () ()) -- mD + (.lam (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const cMkAddr #[] ()) (.bvar 0 ()))) () ()) -- cC + (.lam (.forallE (.const tAddr #[] ()) (.app (.bvar 2 ()) (.app (.const dMkAddr #[] ()) (.bvar 0 ()))) () ()) -- cD + (.lam (.const tAddr #[] ()) -- t + (.app (.bvar 1 ()) (.bvar 0 ())) -- cD t + () ()) + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 1 recType #[cAddr, dAddr] 0 0 2 2 + #[{ ctor := dMkAddr, nfields := 1, rhs }] + (expectOk env buildPrimitives recAddr "mutual-field-rec-second").1 + ) + +/-! ## Parametric and nested recursor tests -/ + +/-- Shared universe-polymorphic wrapper W.{u} : Sort (succ u) → Sort (succ u) -/ +private def polyWAddr := mkAddr 170 +private def polyWmAddr := mkAddr 171 + +/-- Build env with W.{u} and W.mk.{u}. 
-/ +private def addPolyW (env : Env .anon) : Env .anon := + -- W : Sort (succ u) → Sort (succ u) [1 level param] + let wType : Expr .anon := + .forallE (.sort (.succ (.param 0 ()))) (.sort (.succ (.param 0 ()))) () () + let env := addInductive env polyWAddr wType #[polyWmAddr] (numParams := 1) (numLevels := 1) + -- W.mk : ∀ (α : Sort (succ u)), α → W.{u} α [1 level, 1 param, 1 field] + let wmType : Expr .anon := + .forallE (.sort (.succ (.param 0 ()))) + (.forallE (.bvar 0 ()) (.app (.const polyWAddr #[.param 0 ()] ()) (.bvar 1 ())) () ()) + () () + addCtor env polyWmAddr polyWAddr wmType 0 1 1 (numLevels := 1) + +/-- Test: Parametric recursor W.rec.{v,u} with correct level offset. + W.rec : ∀ {α : Sort (succ u)} (motive : W.{u} α → Sort v) + (h : ∀ (a : α), motive (W.mk.{u} α a)) (w : W.{u} α), motive w + RHS for W.mk: λ α motive h a, h a -/ +def parametricRecursor : TestSeq := + test "accepts parametric W.rec with level offset" ( + let recAddr := mkAddr 172 + let env := addPolyW default + -- W.rec type: 2 levels (param 0 = v, param 1 = u), 1 param, 1 motive, 1 minor + let recType : Expr .anon := + -- ∀ (α : Sort (succ u)) + .forallE (.sort (.succ (.param 1 ()))) + -- (motive : W.{u} α → Sort v) + (.forallE (.forallE (.app (.const polyWAddr #[.param 1 ()] ()) (.bvar 0 ())) (.sort (.param 0 ())) () ()) + -- (h : ∀ (a : α), motive (W.mk.{u} α a)) + (.forallE (.forallE (.bvar 1 ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.param 1 ()] ()) (.bvar 2 ())) (.bvar 0 ()))) () ()) + -- (w : W.{u} α) + (.forallE (.app (.const polyWAddr #[.param 1 ()] ()) (.bvar 2 ())) + -- motive w + (.app (.bvar 2 ()) (.bvar 0 ())) + () ()) + () ()) + () ()) + () () + -- RHS: λ α motive h a, h a + let rhs : Expr .anon := + .lam (.sort (.succ (.param 1 ()))) + (.lam (.forallE (.app (.const polyWAddr #[.param 1 ()] ()) (.bvar 0 ())) (.sort (.param 0 ())) () ()) + (.lam (.forallE (.bvar 1 ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.param 1 ()] ()) (.bvar 2 ())) (.bvar 0 
()))) () ()) + (.lam (.bvar 2 ()) + (.app (.bvar 1 ()) (.bvar 0 ())) + () ()) + () ()) + () ()) + () () + let env := addRec env recAddr 2 recType #[polyWAddr] 1 0 1 1 + #[{ ctor := polyWmAddr, nfields := 1, rhs }] + (expectOk env buildPrimitives recAddr "parametric-rec").1 + ) + +/-- Test: Nested auxiliary recursor I.rec_1 for W.{0} I. + I : Sort 1, I.mk : W.{0} I → I + I.rec_1 : ∀ (motive : W.{0} I → Sort v) (h : ∀ (a : I), motive (W.mk.{0} I a)) + (w : W.{0} I), motive w + RHS: λ motive h a, h a + Key: constructor W.mk uses Level.zero (not Level.param 0 which is the elim level). -/ +def nestedAuxRecursor : TestSeq := + test "accepts nested auxiliary recursor I.rec_1 with concrete levels" ( + let iAddr := mkAddr 173 + let imAddr := mkAddr 174 + let rec1Addr := mkAddr 175 + let env := addPolyW default + -- I : Sort 1 [0 levels] + let env := addInductive env iAddr (.sort (.succ .zero)) #[imAddr] (numNested := 1) + -- I.mk : W.{0} I → I [0 levels, 0 params, 1 field] + let imType : Expr .anon := + .forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) + (.const iAddr #[] ()) + () () + let env := addCtor env imAddr iAddr imType 0 0 1 + -- I.rec_1 type: 1 level (param 0 = elim level v), 0 params, 1 motive, 1 minor + let rec1Type : Expr .anon := + -- ∀ (motive : W.{0} I → Sort v) + .forallE (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) (.sort (.param 0 ())) () ()) + -- (h : ∀ (a : I), motive (W.mk.{0} I a)) + (.forallE (.forallE (.const iAddr #[] ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.zero] ()) (.const iAddr #[] ())) (.bvar 0 ()))) () ()) + -- (w : W.{0} I) + (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) + -- motive w + (.app (.bvar 2 ()) (.bvar 0 ())) + () ()) + () ()) + () () + -- RHS: λ motive h a, h a (W.mk uses Level.zero, NOT param 0) + let rhs : Expr .anon := + .lam (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) (.sort (.param 0 ())) () ()) + (.lam (.forallE 
(.const iAddr #[] ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.zero] ()) (.const iAddr #[] ())) (.bvar 0 ()))) () ()) + (.lam (.const iAddr #[] ()) + (.app (.bvar 1 ()) (.bvar 0 ())) + () ()) + () ()) + () () + let env := addRec env rec1Addr 1 rec1Type #[polyWAddr] 0 0 1 1 + #[{ ctor := polyWmAddr, nfields := 1, rhs }] + (expectOk env buildPrimitives rec1Addr "nested-aux-rec").1 + ) + +/-- Test: Nested auxiliary recursor with wrong RHS (body returns a constant, not h a). + Should be rejected because the inferred RHS type won't match the expected type. -/ +def nestedAuxRecWrongRhs : TestSeq := + test "rejects nested auxiliary recursor with wrong RHS" ( + let iAddr := mkAddr 176 + let imAddr := mkAddr 177 + let rec1Addr := mkAddr 178 + let env := addPolyW default + let env := addInductive env iAddr (.sort (.succ .zero)) #[imAddr] (numNested := 1) + let imType : Expr .anon := + .forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) + (.const iAddr #[] ()) () () + let env := addCtor env imAddr iAddr imType 0 0 1 + let rec1Type : Expr .anon := + .forallE (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) (.sort (.param 0 ())) () ()) + (.forallE (.forallE (.const iAddr #[] ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.zero] ()) (.const iAddr #[] ())) (.bvar 0 ()))) () ()) + (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) + (.app (.bvar 2 ()) (.bvar 0 ())) + () ()) + () ()) + () () + -- Wrong RHS: λ motive h a, motive (instead of h a) + let rhs : Expr .anon := + .lam (.forallE (.app (.const polyWAddr #[.zero] ()) (.const iAddr #[] ())) (.sort (.param 0 ())) () ()) + (.lam (.forallE (.const iAddr #[] ()) (.app (.bvar 1 ()) (.app (.app (.const polyWmAddr #[.zero] ()) (.const iAddr #[] ())) (.bvar 0 ()))) () ()) + (.lam (.const iAddr #[] ()) + (.bvar 2 ()) -- wrong: returns motive instead of h a + () ()) + () ()) + () () + let env := addRec env rec1Addr 1 rec1Type #[polyWAddr] 0 0 1 1 + #[{ ctor := 
polyWmAddr, nfields := 1, rhs }] + (expectError env buildPrimitives rec1Addr "nested-aux-rec-wrong-rhs").1 + ) + /-! ## Suite -/ def suite : List TestSeq := [ @@ -401,6 +793,17 @@ def suite : List TestSeq := [ recWrongNfields ++ recWrongNumParams ++ recWrongCtorOrder), + group "Mutual recursor motives" + (mutualRecMotiveFirst ++ + mutualRecMotiveSecond ++ + mutualRecWrongMotive), + group "Mutual recursor with fields" + (mutualFieldRecFirst ++ + mutualFieldRecSecond), + group "Parametric and nested recursors" + (parametricRecursor ++ + nestedAuxRecursor ++ + nestedAuxRecWrongRhs), group "Constructor validation" ctorParamMismatch, group "Sanity" diff --git a/Tests/Ix/Kernel/Unit.lean b/Tests/Ix/Kernel/Unit.lean index 3fc42f29..3575a6ca 100644 --- a/Tests/Ix/Kernel/Unit.lean +++ b/Tests/Ix/Kernel/Unit.lean @@ -280,6 +280,86 @@ def testHelperFunctions : TestSeq := test "getCtorReturnType: skips foralls" (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) +/-! ## Primitive helpers -/ + +def testToCtorIfLit : TestSeq := + let prims := buildPrimitives + -- natVal 0 => Nat.zero + test "toCtorIfLit 0 = Nat.zero" + (toCtorIfLit prims (.lit (.natVal 0) : Expr .anon) == Expr.mkConst prims.natZero #[]) $ + -- natVal 1 => Nat.succ (natVal 0) + test "toCtorIfLit 1 = Nat.succ 0" + (toCtorIfLit prims (.lit (.natVal 1) : Expr .anon) == + Expr.mkApp (Expr.mkConst prims.natSucc #[]) (.lit (.natVal 0))) $ + -- natVal 5 => Nat.succ (natVal 4) + test "toCtorIfLit 5 = Nat.succ 4" + (toCtorIfLit prims (.lit (.natVal 5) : Expr .anon) == + Expr.mkApp (Expr.mkConst prims.natSucc #[]) (.lit (.natVal 4))) $ + -- non-nat unchanged + test "toCtorIfLit sort = sort" + (toCtorIfLit prims (.sort .zero : Expr .anon) == (.sort .zero : Expr .anon)) $ + test "toCtorIfLit strVal = strVal" + (toCtorIfLit prims (.lit (.strVal "hi") : Expr .anon) == (.lit (.strVal "hi") : Expr .anon)) + +def testStrLitToConstructor : TestSeq := + let prims := buildPrimitives + -- empty string => String.mk 
(List.nil Char) + let empty := strLitToConstructor (m := .anon) prims "" + test "strLitToConstructor empty head is stringMk" + (empty.getAppFn.isConstOf prims.stringMk) $ + test "strLitToConstructor empty has 1 arg" + (empty.getAppNumArgs == 1) $ + -- the arg of empty string should be List.nil applied to Char + test "strLitToConstructor empty arg head is listNil" + (empty.appArg!.getAppFn.isConstOf prims.listNil) $ + -- single char string + let single := strLitToConstructor (m := .anon) prims "a" + test "strLitToConstructor \"a\" head is stringMk" + (single.getAppFn.isConstOf prims.stringMk) $ + -- roundtrip: foldLiterals should recover the string literal + test "foldLiterals roundtrips empty" + (foldLiterals prims empty == .lit (.strVal "")) $ + test "foldLiterals roundtrips \"a\"" + (foldLiterals prims single == .lit (.strVal "a")) + +def testIsPrimOp : TestSeq := + let prims := buildPrimitives + test "isPrimOp natAdd" (isPrimOp prims prims.natAdd) $ + test "isPrimOp natSucc" (isPrimOp prims prims.natSucc) $ + test "isPrimOp natSub" (isPrimOp prims prims.natSub) $ + test "isPrimOp natMul" (isPrimOp prims prims.natMul) $ + test "isPrimOp natGcd" (isPrimOp prims prims.natGcd) $ + test "isPrimOp natMod" (isPrimOp prims prims.natMod) $ + test "isPrimOp natDiv" (isPrimOp prims prims.natDiv) $ + test "isPrimOp natBeq" (isPrimOp prims prims.natBeq) $ + test "isPrimOp natBle" (isPrimOp prims prims.natBle) $ + test "isPrimOp natLand" (isPrimOp prims prims.natLand) $ + test "isPrimOp natLor" (isPrimOp prims prims.natLor) $ + test "isPrimOp natXor" (isPrimOp prims prims.natXor) $ + test "isPrimOp natShiftLeft" (isPrimOp prims prims.natShiftLeft) $ + test "isPrimOp natShiftRight" (isPrimOp prims prims.natShiftRight) $ + test "isPrimOp natPow" (isPrimOp prims prims.natPow) $ + test "not isPrimOp nat" (!isPrimOp prims prims.nat) $ + test "not isPrimOp bool" (!isPrimOp prims prims.bool) $ + test "not isPrimOp default" (!isPrimOp prims default) + +def testFoldLiterals : TestSeq 
:= + let prims := buildPrimitives + -- Nat.zero => lit 0 + test "foldLiterals Nat.zero = lit 0" + (foldLiterals prims (Expr.mkConst prims.natZero #[] : Expr .anon) == .lit (.natVal 0)) $ + -- Nat.succ (lit 0) => lit 1 + let succZero : Expr .anon := Expr.mkApp (Expr.mkConst prims.natSucc #[]) (.lit (.natVal 0)) + test "foldLiterals Nat.succ(lit 0) = lit 1" + (foldLiterals prims succZero == .lit (.natVal 1)) $ + -- Nat.succ (lit 4) => lit 5 + let succ4 : Expr .anon := Expr.mkApp (Expr.mkConst prims.natSucc #[]) (.lit (.natVal 4)) + test "foldLiterals Nat.succ(lit 4) = lit 5" + (foldLiterals prims succ4 == .lit (.natVal 5)) $ + -- non-nat expressions are unchanged + test "foldLiterals bvar = bvar" + (foldLiterals prims (.bvar 0 () : Expr .anon) == (.bvar 0 () : Expr .anon)) + /-! ## Suite -/ def suite : List TestSeq := [ @@ -293,6 +373,11 @@ def suite : List TestSeq := [ group "bulk instantiation" testLevelInstBulkReduce, group "Reducibility hints" testReducibilityHintsLt, group "Inductive helpers" testHelperFunctions, + group "Primitive helpers" $ + group "toCtorIfLit" testToCtorIfLit ++ + group "strLitToConstructor" testStrLitToConstructor ++ + group "isPrimOp" testIsPrimOp ++ + group "foldLiterals" testFoldLiterals, ] end Tests.Ix.Kernel.Unit diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index 7dab1364..c8995332 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -108,6 +108,9 @@ def testConsts : TestSeq := "Nat.gcd", "Nat.beq", "Nat.ble", "Nat.land", "Nat.lor", "Nat.xor", "Nat.shiftLeft", "Nat.shiftRight", "Nat.pow", + "Nat.pred", "Nat.bitwise", + -- String/Char primitives + "Char.ofNat", "String.ofList", -- Recursors "List.rec", -- Delta unfolding @@ -146,8 +149,6 @@ def testConsts : TestSeq := "Lean.Elab.Term.Do.Code.action", -- UInt64/BitVec isDefEq regression "UInt64.decLt", - -- Recursor-only Ixon block regression (rec.all was empty) - "Lean.Elab.Tactic.RCases.RCasesPatt.rec_1", -- Dependencies of _sunfold (check 
these first to rule out lazy blowup) "Std.Time.FormatPart", "Std.Time.FormatConfig", @@ -184,7 +185,13 @@ def testConsts : TestSeq := -- rfl theorem: both sides must be defeq via delta unfolding "Std.Tactic.BVDecide.BVExpr.eval.eq_10", -- K-reduction: extra args after major premise must be applied - "UInt8.toUInt64_toUSize" + "UInt8.toUInt64_toUSize", + -- DHashMap: rfl theorem requiring projection reduction + eta-struct + "Std.DHashMap.Internal.Raw₀.contains_eq_containsₘ", + -- K-reduction: toCtorWhenK must check isDefEq before reducing + "instDecidableEqVector.decEq", + -- Recursor-only Ixon block regression (rec.all was empty) + "Lean.Elab.Tactic.RCases.RCasesPatt.rec_1", ] let mut passed := 0 let mut failures : Array String := #[] @@ -237,9 +244,21 @@ def testVerifyPrimAddrs : TestSeq := let hardcoded := Ix.Kernel.buildPrimitives let mut failures : Array String := #[] let checks : Array (String × String × Address) := #[ + -- Core types and constructors ("nat", "Nat", hardcoded.nat), ("natZero", "Nat.zero", hardcoded.natZero), ("natSucc", "Nat.succ", hardcoded.natSucc), + ("bool", "Bool", hardcoded.bool), + ("boolTrue", "Bool.true", hardcoded.boolTrue), + ("boolFalse", "Bool.false", hardcoded.boolFalse), + ("string", "String", hardcoded.string), + ("stringMk", "String.mk", hardcoded.stringMk), + ("char", "Char", hardcoded.char), + ("charMk", "Char.ofNat", hardcoded.charMk), + ("list", "List", hardcoded.list), + ("listNil", "List.nil", hardcoded.listNil), + ("listCons", "List.cons", hardcoded.listCons), + -- Nat arithmetic primitives ("natAdd", "Nat.add", hardcoded.natAdd), ("natSub", "Nat.sub", hardcoded.natSub), ("natMul", "Nat.mul", hardcoded.natMul), @@ -254,16 +273,30 @@ def testVerifyPrimAddrs : TestSeq := ("natXor", "Nat.xor", hardcoded.natXor), ("natShiftLeft", "Nat.shiftLeft", hardcoded.natShiftLeft), ("natShiftRight", "Nat.shiftRight", hardcoded.natShiftRight), - ("bool", "Bool", hardcoded.bool), - ("boolTrue", "Bool.true", hardcoded.boolTrue), - 
("boolFalse", "Bool.false", hardcoded.boolFalse), - ("string", "String", hardcoded.string), - ("stringMk", "String.mk", hardcoded.stringMk), - ("char", "Char", hardcoded.char), - ("charMk", "Char.ofNat", hardcoded.charMk), - ("list", "List", hardcoded.list), - ("listNil", "List.nil", hardcoded.listNil), - ("listCons", "List.cons", hardcoded.listCons) + ("natPred", "Nat.pred", hardcoded.natPred), + ("natBitwise", "Nat.bitwise", hardcoded.natBitwise), + ("natModCoreGo", "Nat.modCore.go", hardcoded.natModCoreGo), + ("natDivGo", "Nat.div.go", hardcoded.natDivGo), + -- String/Char definitions + ("stringOfList", "String.ofList", hardcoded.stringOfList), + -- Eq + ("eq", "Eq", hardcoded.eq), + ("eqRefl", "Eq.refl", hardcoded.eqRefl), + -- Extra: mod/div/gcd validation helpers + ("natLE", "Nat.instLE.le", hardcoded.natLE), + ("natDecLe", "Nat.decLe", hardcoded.natDecLe), + ("natDecEq", "Nat.decEq", hardcoded.natDecEq), + ("natBleRefl", "Nat.le_of_ble_eq_true", hardcoded.natBleRefl), + ("natNotBleRefl", "Nat.not_le_of_not_ble_eq_true", hardcoded.natNotBleRefl), + ("natBeqRefl", "Nat.eq_of_beq_eq_true", hardcoded.natBeqRefl), + ("natNotBeqRefl", "Nat.ne_of_beq_eq_false", hardcoded.natNotBeqRefl), + ("ite", "ite", hardcoded.ite), + ("dite", "dite", hardcoded.dite), + ("not", "Not", hardcoded.«not»), + ("accRec", "Acc.rec", hardcoded.accRec), + ("accIntro", "Acc.intro", hardcoded.accIntro), + ("natLtSuccSelf", "Nat.lt_succ_self", hardcoded.natLtSuccSelf), + ("natDivRecFuelLemma", "Nat.div_rec_fuel_lemma", hardcoded.natDivRecFuelLemma) ] for (field, name, expected) in checks do let actual := lookupPrim ixonEnv name @@ -283,16 +316,35 @@ def testDumpPrimAddrs : TestSeq := let leanEnv ← get_env! 
let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv let names := #[ + -- Core types and constructors ("nat", "Nat"), ("natZero", "Nat.zero"), ("natSucc", "Nat.succ"), + ("bool", "Bool"), ("boolTrue", "Bool.true"), ("boolFalse", "Bool.false"), + ("string", "String"), ("stringMk", "String.mk"), + ("char", "Char"), ("charMk", "Char.ofNat"), + ("list", "List"), ("listNil", "List.nil"), ("listCons", "List.cons"), + -- Nat arithmetic primitives ("natAdd", "Nat.add"), ("natSub", "Nat.sub"), ("natMul", "Nat.mul"), ("natPow", "Nat.pow"), ("natGcd", "Nat.gcd"), ("natMod", "Nat.mod"), ("natDiv", "Nat.div"), ("natBeq", "Nat.beq"), ("natBle", "Nat.ble"), ("natLand", "Nat.land"), ("natLor", "Nat.lor"), ("natXor", "Nat.xor"), ("natShiftLeft", "Nat.shiftLeft"), ("natShiftRight", "Nat.shiftRight"), - ("bool", "Bool"), ("boolTrue", "Bool.true"), ("boolFalse", "Bool.false"), - ("string", "String"), ("stringMk", "String.mk"), - ("char", "Char"), ("charMk", "Char.ofNat"), - ("list", "List"), ("listNil", "List.nil"), ("listCons", "List.cons") + ("natPred", "Nat.pred"), ("natBitwise", "Nat.bitwise"), + ("natModCoreGo", "Nat.modCore.go"), ("natDivGo", "Nat.div.go"), + -- String/Char definitions + ("stringOfList", "String.ofList"), + -- Eq + ("eq", "Eq"), ("eqRefl", "Eq.refl"), + -- Extra: mod/div/gcd validation helpers + ("natLE", "Nat.instLE.le"), ("natDecLe", "Nat.decLe"), + ("natDecEq", "Nat.decEq"), + ("natBleRefl", "Nat.le_of_ble_eq_true"), + ("natNotBleRefl", "Nat.not_le_of_not_ble_eq_true"), + ("natBeqRefl", "Nat.eq_of_beq_eq_true"), + ("natNotBeqRefl", "Nat.ne_of_beq_eq_false"), + ("ite", "ite"), ("dite", "dite"), ("«not»", "Not"), + ("accRec", "Acc.rec"), ("accIntro", "Acc.intro"), + ("natLtSuccSelf", "Nat.lt_succ_self"), + ("natDivRecFuelLemma", "Nat.div_rec_fuel_lemma") ] for (field, name) in names do IO.println s!"{field} := \"{lookupPrim ixonEnv name}\""