diff --git a/AGENTS.md b/AGENTS.md index 959d508..223cfcb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -34,7 +34,7 @@ pathrex/ │ ├── mod.rs # FormatError enum, re-exports │ ├── csv.rs # Csv — CSV → Edge iterator (CsvConfig, ColumnSpec) │ ├── mm.rs # MatrixMarket directory loader (vertices.txt, edges.txt, *.txt) -│ └── nt.rs # NTriples — N-Triples → Edge iterator (full predicate IRI labels) +│ └── rdf.rs # Rdf — unified RDF parser (N-Triples, Turtle) → Edge iterator ├── tests/ │ ├── inmemory_tests.rs # Integration tests for InMemoryBuilder / InMemoryGraph │ ├── mm_tests.rs # Integration tests for MatrixMarket format @@ -139,7 +139,7 @@ feed itself into a specific [`GraphBuilder`]: - [`apply_to(self, builder: B) -> Result`](src/graph/mod.rs:169) — consumes the source and returns the populated builder. -[`Csv`](src/formats/csv.rs), [`MatrixMarket`](src/formats/mm.rs), and [`NTriples`](src/formats/nt.rs) +[`Csv`](src/formats/csv.rs), [`MatrixMarket`](src/formats/mm.rs), and [`Rdf`](src/formats/rdf.rs) implement `GraphSource` (see [`src/graph/inmemory.rs`](src/graph/inmemory.rs)), so they can be passed to [`GraphBuilder::load`] and [`Graph::try_from`]. @@ -207,7 +207,6 @@ which is used by the MatrixMarket loader. Three built-in parsers are available, each yielding `Iterator>` and pluggable into `GraphBuilder::load()` via `GraphSource` (see [`src/graph/inmemory.rs`](src/graph/inmemory.rs)). -CSV and MatrixMarket edge loaders are available: #### `Csv` @@ -251,26 +250,40 @@ Helper functions: `MatrixMarket` implements `GraphSource` in [`src/graph/inmemory.rs`](src/graph/inmemory.rs) (see the `impl` at line 215): `vertices.txt` maps are converted from 1-based file indices to 0-based matrix ids before [`set_node_map`](src/graph/inmemory.rs:67); `edges.txt` indices are unchanged for `n.txt` lookup. -#### `NTriples` +#### `Rdf` — Unified RDF Parser -[`NTriples`](src/formats/nt.rs:64) parses [W3C N-Triples](https://www.w3.org/TR/n-triples/) -RDF files using `oxttl` and `oxrdf`. Each triple `(subject, predicate, object)` becomes an -[`Edge`](src/graph/mod.rs:158) where: +[`Rdf`](src/formats/rdf.rs) is a unified parser for RDF formats using `oxttl` and `oxrdf`. +It supports both **N-Triples** (`.nt`) and **Turtle** (`.ttl`) formats via the [`RdfFormat`](src/formats/rdf.rs) enum. + +Each triple `(subject, predicate, object)` becomes an [`Edge`](src/graph/mod.rs:158) where: - `source` — subject IRI or blank-node ID (`_:label`). - `target` — object IRI or blank-node ID; triples whose object is an RDF literal yield `Err(FormatError::LiteralAsNode)` (callers may filter these out). -- `label` — predicate IRI, transformed by [`LabelExtraction`](src/formats/nt.rs:38): +- `label` — full predicate IRI string (including fragment `#…` when present). + +Constructor: -| Variant | Behaviour | +- [`Rdf::from_path(path)`](src/formats/rdf.rs) — auto-detects format from file extension (`.nt` → N-Triples, `.ttl` → Turtle). Parses in parallel using memory-mapping and rayon. + +Format detection via [`RdfFormat::from_path(path)`](src/formats/rdf.rs): + +| Extension | Format | |---|---| -| `LocalName` (default) | Fragment (`#name`) or last path segment of the predicate IRI | -| `FullIri` | Full predicate IRI string | +| `.nt`, `.ntriples` | `RdfFormat::NTriples` | +| `.ttl`, `.turtle` | `RdfFormat::Turtle` | -Constructors: +Example usage: -- [`NTriples::new(reader)`](src/formats/nt.rs:70) — uses `LabelExtraction::LocalName`. -- [`NTriples::with_label_extraction(reader, strategy)`](src/formats/nt.rs:74) — explicit strategy. +```rust +use pathrex::formats::Rdf; +use pathrex::graph::{Graph, InMemory}; + +// Auto-detect from extension +let graph = Graph::::try_from( + Rdf::from_path("data.ttl")? +)?; +``` ### SPARQL parsing (`src/sparql/mod.rs`) @@ -421,7 +434,10 @@ LAGraph. Safe Rust wrappers live in [`graph::mod`](src/graph/mod.rs): - [`GraphblasVector`](src/graph/mod.rs:128) — RAII wrapper around `GrB_Vector` (derives `Debug`). - [`GraphblasMatrix`](src/graph/mod.rs) — RAII wrapper around `GrB_Matrix` (`dup` + `free` on drop). -- [`ensure_grb_init()`](src/graph/mod.rs:39) — one-time `LAGraph_Init` via `std::sync::Once`. +- [`ensure_grb_init()`](src/graph/wrappers.rs:11) — internal one-time `LAGraph_Init` via + `std::sync::Once`. Called automatically by RAII-wrapped constructors + (`LagraphGraph::from_coo`, `LagraphGraph::from_matrix`, `ThreadScope::enter`) and by + `load_mm_file`. Crate-private; no other code should call it. ### Macros & helpers (`src/utils.rs`) @@ -472,7 +488,7 @@ Tests in `src/graph/mod.rs` use `CountingBuilder` / `CountOutput` / `VecSource` [`src/utils.rs`](src/utils.rs) — these do **not** call into GraphBLAS and run without native libraries. -Tests in `src/formats/csv.rs` and `src/formats/nt.rs` are pure Rust and need no native dependencies. +Tests in `src/formats/csv.rs` and `src/formats/rdf.rs` are pure Rust and need no native dependencies. Tests in `src/sparql/mod.rs` are pure Rust and need no native dependencies. diff --git a/Cargo.toml b/Cargo.toml index 9825c87..c454400 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,10 @@ edition = "2024" csv = "1.4.0" egg = "0.10.0" libc = "0.2" +memmap2 = "0.9" oxrdf = "0.3.3" oxttl = "0.2.3" +rayon = "1" rustfst = "1.2" spargebra = "0.4.6" thiserror = "1.0" diff --git a/build.rs b/build.rs index 9b12997..243bbcb 100644 --- a/build.rs +++ b/build.rs @@ -82,6 +82,8 @@ fn regenerate_bindings() { .allowlist_function("LAGraph_CheckGraph") .allowlist_function("LAGraph_Init") .allowlist_function("LAGraph_Finalize") + .allowlist_function("LAGraph_SetNumThreads") + .allowlist_function("LAGraph_GetNumThreads") .allowlist_function("LAGraph_New") .allowlist_function("LAGraph_Delete") .allowlist_function("LAGraph_Cached_AT") diff --git a/src/formats/mm.rs b/src/formats/mm.rs index 9b16576..5bf44b9 100644 --- a/src/formats/mm.rs +++ b/src/formats/mm.rs @@ -24,56 +24,12 @@ //! ``` use std::collections::HashMap; -use std::ffi::CString; use std::fs::File; use std::io::{BufRead, BufReader}; -use std::os::fd::IntoRawFd; use std::path::{Path, PathBuf}; use crate::formats::FormatError; -use crate::graph::{GraphError, ensure_grb_init}; -use crate::la_ok; -use crate::lagraph_sys::{FILE, GrB_Matrix, LAGraph_MMRead}; - -/// Read a single MatrixMarket file and return the raw [`GrB_Matrix`]. -pub fn load_mm_file(path: impl AsRef) -> Result { - let path = path.as_ref(); - - ensure_grb_init().map_err(|e| match e { - GraphError::LAGraph(info, msg) => FormatError::MatrixMarket { - code: info, - message: msg, - }, - _ => FormatError::MatrixMarket { - code: crate::lagraph_sys::GrB_Info::GrB_PANIC, - message: "Failed to initialize GraphBLAS".to_string(), - }, - })?; - - let file = File::open(path)?; - let fd = file.into_raw_fd(); - - let c_mode = CString::new("r").unwrap(); - let f = unsafe { libc::fdopen(fd, c_mode.as_ptr()) }; - if f.is_null() { - unsafe { libc::close(fd) }; - return Err(std::io::Error::last_os_error().into()); - } - - let mut matrix: GrB_Matrix = std::ptr::null_mut(); - - let err = la_ok!(LAGraph_MMRead(&mut matrix, f as *mut FILE)); - unsafe { libc::fclose(f) }; - - match err { - Ok(_) => Ok(matrix), - Err(GraphError::LAGraph(info, msg)) => Err(FormatError::MatrixMarket { - code: info, - message: msg, - }), - _ => unreachable!("should be either mm read error or ok"), - } -} +pub use crate::graph::load_mm_file; // Trims first "<" and last ">". fn normalize_map_name(name: &str) -> String { @@ -92,12 +48,12 @@ pub(crate) fn apply_base_iri(name: String, base: Option<&str>) -> String { } } +type IndexMap = (HashMap, HashMap); + /// Parse a ` ` mapping file. /// /// Throws error on non-positive or duplicate indicies -pub(crate) fn parse_index_map( - path: &Path, -) -> Result<(HashMap, HashMap), FormatError> { +pub(crate) fn parse_index_map(path: &Path) -> Result { let file_name = path .file_name() .map(|n| n.to_string_lossy().into_owned()) @@ -189,7 +145,7 @@ impl MatrixMarket { self } - pub(crate) fn mm_path(&self, idx: usize) -> PathBuf { + pub fn mm_path(&self, idx: usize) -> PathBuf { self.dir.join(format!("{}.txt", idx)) } } @@ -278,10 +234,12 @@ mod tests { #[test] fn test_load_nonexistent_mm_file_returns_io_error() { + use crate::formats::FormatError; + use crate::graph::GraphError; let result = load_mm_file("/nonexistent/path/to/file.txt"); assert!( - matches!(result, Err(FormatError::Io(_))), - "expected Io error for missing file, got: {:?}", + matches!(result, Err(GraphError::Format(FormatError::Io(_)))), + "expected Format(Io) error for missing file, got: {:?}", result ); } diff --git a/src/formats/mod.rs b/src/formats/mod.rs index 480d720..b78a5cc 100644 --- a/src/formats/mod.rs +++ b/src/formats/mod.rs @@ -4,28 +4,29 @@ //! //! ```no_run //! use pathrex::graph::{Graph, InMemory, GraphDecomposition}; -//! use pathrex::formats::{Csv, NTriples}; +//! use pathrex::formats::{Csv, Rdf}; //! use std::fs::File; //! -//! // Build from CSV in one line +//! // Build from CSV //! let g = Graph::::try_from( //! Csv::from_reader(File::open("edges.csv").unwrap()).unwrap() //! ).unwrap(); //! -//! // Build from N-Triples in one line +//! // Build from Turtle (auto-detect from extension) //! let g2 = Graph::::try_from( -//! NTriples::new(File::open("data.nt").unwrap()) +//! Rdf::from_path("data.ttl").unwrap() //! ).unwrap(); //! ``` pub mod csv; pub mod mm; -pub mod nt; +pub mod rdf; pub use csv::Csv; pub use mm::MatrixMarket; -pub use nt::NTriples; +pub use rdf::{Rdf, RdfFormat}; +use oxttl::TurtleSyntaxError; use thiserror::Error; use crate::lagraph_sys::GrB_Info; @@ -57,9 +58,9 @@ pub enum FormatError { reason: String, }, - /// An error produced by the N-Triples parser. - #[error("N-Triples parse error: {0}")] - NTriples(String), + /// An error produced by an RDF parser (N-Triples, Turtle, etc.) + #[error("RDF parse error: {0}")] + Rdf(#[from] TurtleSyntaxError), /// An RDF literal appeared as a subject or object where a node IRI or /// blank node was expected. diff --git a/src/formats/nt.rs b/src/formats/nt.rs deleted file mode 100644 index aa08880..0000000 --- a/src/formats/nt.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! N-Triples edge iterator for the formats layer. -//! -//! ```no_run -//! use pathrex::formats::NTriples; -//! use pathrex::formats::FormatError; -//! -//! # let reader = std::io::empty(); -//! let iter = NTriples::new(reader) -//! .filter_map(|r| match r { -//! Err(FormatError::LiteralAsNode) => None, // skip -//! other => Some(other), -//! }); -//! ``` -//! -//! To load into a graph: -//! -//! ```no_run -//! use pathrex::graph::{Graph, InMemory, GraphDecomposition}; -//! use pathrex::formats::NTriples; -//! use std::fs::File; -//! -//! let graph = Graph::::try_from( -//! NTriples::new(File::open("data.nt").unwrap()) -//! ).unwrap(); -//! ``` - -use std::io::Read; - -use oxrdf::{NamedOrBlankNode, Term}; -use oxttl::NTriplesParser; -use oxttl::ntriples::ReaderNTriplesParser; - -use crate::formats::FormatError; -use crate::graph::Edge; - -/// An iterator that reads N-Triples and yields `Result`. -/// -/// # Example -/// -/// ```no_run -/// use pathrex::formats::nt::NTriples; -/// use std::fs::File; -/// -/// let file = File::open("data.nt").unwrap(); -/// let iter = NTriples::new(file); -/// for result in iter { -/// let edge = result.unwrap(); -/// println!("{} --{}--> {}", edge.source, edge.label, edge.target); -/// } -/// ``` -pub struct NTriples { - inner: ReaderNTriplesParser, -} - -impl NTriples { - pub fn new(reader: R) -> Self { - Self { - inner: NTriplesParser::new().for_reader(reader), - } - } - - fn subject_to_node_id(subject: NamedOrBlankNode) -> String { - match subject { - NamedOrBlankNode::NamedNode(n) => n.into_string(), - NamedOrBlankNode::BlankNode(b) => format!("_:{}", b.as_str()), - } - } - - fn object_to_node_id(object: Term) -> Result { - match object { - Term::NamedNode(n) => Ok(n.into_string()), - Term::BlankNode(b) => Ok(format!("_:{}", b.as_str())), - Term::Literal(_) => Err(FormatError::LiteralAsNode), - } - } -} - -impl Iterator for NTriples { - type Item = Result; - - fn next(&mut self) -> Option { - let triple = match self.inner.next()? { - Ok(t) => t, - Err(e) => return Some(Err(FormatError::NTriples(e.to_string()))), - }; - - let source = Self::subject_to_node_id(triple.subject.into()); - let label = triple.predicate.as_str().to_owned(); - let target = match Self::object_to_node_id(triple.object) { - Ok(t) => t, - Err(e) => return Some(Err(e)), - }; - - Some(Ok(Edge { - source, - target, - label, - })) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn parse(nt: &str) -> Vec> { - NTriples::new(nt.as_bytes()).collect() - } - - #[test] - fn test_basic_ntriples() { - let nt = " .\n\ - .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 2); - - let e0 = edges[0].as_ref().unwrap(); - assert_eq!(e0.source, "http://example.org/Alice"); - assert_eq!(e0.target, "http://example.org/Bob"); - assert_eq!(e0.label, "http://example.org/knows"); - - let e1 = edges[1].as_ref().unwrap(); - assert_eq!(e1.source, "http://example.org/Bob"); - assert_eq!(e1.target, "http://example.org/Charlie"); - assert_eq!(e1.label, "http://example.org/likes"); - } - - #[test] - fn test_blank_node_subject_and_object() { - let nt = "_:b1 _:b2 .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 1); - - let e = edges[0].as_ref().unwrap(); - assert_eq!(e.source, "_:b1"); - assert_eq!(e.target, "_:b2"); - } - - #[test] - fn test_literal_object_yields_error() { - let nt = " \"Alice\" .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 1); - assert!( - matches!(edges[0], Err(FormatError::LiteralAsNode)), - "literal object should yield LiteralAsNode error" - ); - } - - #[test] - fn test_caller_can_skip_literal_triples() { - let nt = " .\n\ - \"Alice\" .\n\ - .\n"; - let edges: Vec<_> = NTriples::new(nt.as_bytes()) - .filter_map(|r| match r { - Err(FormatError::LiteralAsNode) => None, - other => Some(other), - }) - .collect(); - - assert_eq!(edges.len(), 2, "literal triple should be skipped"); - assert!(edges.iter().all(|r| r.is_ok())); - } - - #[test] - fn test_predicate_with_fragment_is_full_iri_string() { - let nt = - " .\n"; - let edges = parse(nt); - assert_eq!( - edges[0].as_ref().unwrap().label, - "http://example.org/ns#knows" - ); - } - - #[test] - fn test_non_ascii_in_iris() { - let nt = " .\n\ - .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 2); - - let e0 = edges[0].as_ref().unwrap(); - assert_eq!(e0.source, "http://example.org/人甲"); - assert_eq!(e0.target, "http://example.org/人乙"); - assert_eq!(e0.label, "http://example.org/关系/认识"); - - let e1 = edges[1].as_ref().unwrap(); - assert_eq!(e1.source, "http://example.org/Алиса"); - assert_eq!(e1.target, "http://example.org/Боб"); - assert_eq!(e1.label, "http://example.org/знает"); - } - - #[test] - fn test_ntriples_graph_source() { - use crate::graph::{GraphBuilder, GraphDecomposition, InMemoryBuilder}; - - let nt = " .\n\ - .\n"; - let iter = NTriples::new(nt.as_bytes()); - - let graph = InMemoryBuilder::default() - .load(iter) - .expect("load should succeed") - .build() - .expect("build should succeed"); - assert_eq!(graph.num_nodes(), 3); - } -} diff --git a/src/formats/rdf.rs b/src/formats/rdf.rs new file mode 100644 index 0000000..4fc0c2a --- /dev/null +++ b/src/formats/rdf.rs @@ -0,0 +1,301 @@ +//! RDF parser supporting N-Tripples and Turtle formats. +//! +//! # Example +//! ```no_run +//! use pathrex::formats::{Rdf, RdfFormat}; +//! use pathrex::graph::{Graph, InMemory}; +//! +//! // Auto-detect from path +//! let graph = Graph::::try_from( +//! Rdf::from_path("data.ttl").unwrap() +//! ).unwrap(); +//! ``` + +use std::ops::Deref; +use std::path::Path; + +use oxrdf::{NamedOrBlankNode, Term, Triple}; +use oxttl::{NTriplesParser, TurtleParser}; +use rayon::prelude::*; + +use crate::formats::FormatError; +use crate::graph::Edge; + +enum RdfData { + Mapped(memmap2::Mmap), + Owned(Vec), +} + +impl Deref for RdfData { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + match self { + RdfData::Mapped(m) => m, + RdfData::Owned(v) => v, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RdfFormat { + /// N-Triples format (.nt) + NTriples, + /// Turtle format (.ttl) + Turtle, +} + +impl RdfFormat { + /// Detect format from file extension. + pub fn from_path>(path: P) -> Option { + match path.as_ref().extension()?.to_str()? { + "nt" | "ntriples" => Some(Self::NTriples), + "ttl" | "turtle" => Some(Self::Turtle), + _ => None, + } + } +} + +/// RDF parser supporting N-Triples and Turtle formats. +/// +/// # Example +/// ```no_run +/// use pathrex::formats::{Rdf, RdfFormat}; +/// use pathrex::graph::{Graph, InMemory}; +/// +/// let graph = Graph::::try_from( +/// Rdf::from_path("data.ttl").unwrap() +/// ).unwrap(); +/// ``` +pub struct Rdf { + data: RdfData, + format: RdfFormat, +} + +impl Rdf { + /// Create a parser from an in-memory byte slice and a known format. + pub fn new(data: impl Into>, format: RdfFormat) -> Self { + Self { + data: RdfData::Owned(data.into()), + format, + } + } + + /// Load a file from `path`, detecting its format from the extension. + pub fn from_path(path: impl AsRef) -> Result { + let path = path.as_ref(); + let format = RdfFormat::from_path(path).ok_or_else(|| { + FormatError::Io(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Unknown RDF extension: {:?}", path.extension()), + )) + })?; + + let file = std::fs::File::open(path)?; + + let data = match unsafe { memmap2::Mmap::map(&file) } { + Ok(mmap) => RdfData::Mapped(mmap), + Err(_) => { + use std::io::Read; + let mut buf = Vec::new(); + let mut file = file; + file.read_to_end(&mut buf)?; + RdfData::Owned(buf) + } + }; + + Ok(Self { data, format }) + } + + /// Parse the stored bytes in parallel, returning an iterator of edges and errors. + pub fn parse(self) -> impl Iterator> { + let target_parallelism = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); + + let bytes: &[u8] = &self.data; + + let edges: Vec> = match self.format { + RdfFormat::NTriples => NTriplesParser::new() + .lenient() + .split_slice_for_parallel_parsing(bytes, target_parallelism) + .into_par_iter() + .flat_map_iter(|parser| parser.map(|item| triple_to_edge(item?))) + .collect(), + RdfFormat::Turtle => TurtleParser::new() + .lenient() + .split_slice_for_parallel_parsing(bytes, target_parallelism) + .into_par_iter() + .flat_map_iter(|parser| parser.map(|item| triple_to_edge(item?))) + .collect(), + }; + + edges.into_iter() + } +} + +fn subject_to_node_id(subject: NamedOrBlankNode) -> String { + match subject { + NamedOrBlankNode::NamedNode(n) => n.into_string(), + NamedOrBlankNode::BlankNode(b) => format!("_:{}", b.as_str()), + } +} + +fn object_to_node_id(object: Term) -> Result { + match object { + Term::NamedNode(n) => Ok(n.into_string()), + Term::BlankNode(b) => Ok(format!("_:{}", b.as_str())), + Term::Literal(_) => Err(FormatError::LiteralAsNode), + } +} + +/// Convert a parsed [`Triple`] into an [`Edge`]. +pub(crate) fn triple_to_edge(triple: Triple) -> Result { + let source = subject_to_node_id(triple.subject); + let label = triple.predicate.as_str().to_owned(); + let target = object_to_node_id(triple.object)?; + Ok(Edge { + source, + target, + label, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse_turtle(ttl: &[u8]) -> Vec> { + Rdf::new(ttl, RdfFormat::Turtle).parse().collect() + } + + fn parse_ntriples(nt: &[u8]) -> Vec> { + Rdf::new(nt, RdfFormat::NTriples).parse().collect() + } + + fn ok_edges(results: Vec>) -> Vec { + results.into_iter().filter_map(|r| r.ok()).collect() + } + + #[test] + fn test_turtle_basic() { + let ttl = br#" + @prefix ex: . + ex:Alice ex:knows ex:Bob . + ex:Bob ex:knows ex:Charlie . + "#; + let mut edges = ok_edges(parse_turtle(ttl)); + edges.sort_by(|a, b| a.source.cmp(&b.source).then(a.target.cmp(&b.target))); + assert_eq!(edges.len(), 2); + // Find Alice->Bob edge + let alice_bob = edges + .iter() + .find(|e| e.source == "http://example.org/Alice") + .unwrap(); + assert_eq!(alice_bob.label, "http://example.org/knows"); + assert_eq!(alice_bob.target, "http://example.org/Bob"); + } + + #[test] + fn test_turtle_predicate_object_lists() { + let ttl = br#" + @prefix ex: . + ex:Alice ex:knows ex:Bob, ex:Charlie ; + ex:likes ex:Dave . + "#; + let edges = ok_edges(parse_turtle(ttl)); + assert_eq!(edges.len(), 3); + } + + #[test] + fn test_ntriples_basic() { + let nt = b" .\n\ + .\n"; + let edges = ok_edges(parse_ntriples(nt)); + assert_eq!(edges.len(), 2); + let alice = edges + .iter() + .find(|e| e.source == "http://example.org/Alice") + .unwrap(); + assert_eq!(alice.label, "http://example.org/knows"); + } + + #[test] + fn test_literal_yields_error() { + // parallel_rdf_edges now returns Err(FormatError::LiteralAsNode) for literal-object triples + let ttl = br#" + @prefix ex: . + ex:Alice ex:name "Alice" . + ex:Alice ex:knows ex:Bob . + "#; + let results = parse_turtle(ttl); + let errors: Vec<_> = results.iter().filter(|r| r.is_err()).collect(); + let edges: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect(); + // The literal-object triple produces an error + assert_eq!(errors.len(), 1); + assert!(matches!(errors[0], Err(FormatError::LiteralAsNode))); + // The valid triple still produces an edge + assert_eq!(edges.len(), 1); + assert_eq!(edges[0].source, "http://example.org/Alice"); + assert_eq!(edges[0].label, "http://example.org/knows"); + } + + #[test] + fn test_format_detection() { + assert_eq!(RdfFormat::from_path("data.ttl"), Some(RdfFormat::Turtle)); + assert_eq!(RdfFormat::from_path("data.turtle"), Some(RdfFormat::Turtle)); + assert_eq!(RdfFormat::from_path("data.nt"), Some(RdfFormat::NTriples)); + assert_eq!( + RdfFormat::from_path("data.ntriples"), + Some(RdfFormat::NTriples) + ); + assert_eq!(RdfFormat::from_path("data.csv"), None); + } + + #[test] + fn test_blank_nodes() { + let nt = b"_:b1 _:b2 .\n"; + let edges = ok_edges(parse_ntriples(nt)); + assert_eq!(edges.len(), 1); + assert_eq!(edges[0].source, "_:b1"); + assert_eq!(edges[0].target, "_:b2"); + } + + #[test] + fn test_non_ascii_in_iris() { + let nt = " .\n\ + .\n"; + let edges = ok_edges(parse_ntriples(nt.as_bytes())); + assert_eq!(edges.len(), 2); + + let person1 = edges + .iter() + .find(|e| e.source == "http://example.org/人甲") + .unwrap(); + assert_eq!(person1.target, "http://example.org/人乙"); + assert_eq!(person1.label, "http://example.org/关系/认识"); + + let alice = edges + .iter() + .find(|e| e.source == "http://example.org/Алиса") + .unwrap(); + assert_eq!(alice.target, "http://example.org/Боб"); + assert_eq!(alice.label, "http://example.org/знает"); + } + + #[test] + fn test_from_path_mmap() { + use std::io::Write; + use tempfile::NamedTempFile; + + let nt = b" .\n"; + let mut f = NamedTempFile::with_suffix(".nt").unwrap(); + f.write_all(nt).unwrap(); + f.flush().unwrap(); + + let edges = ok_edges(Rdf::from_path(f.path()).unwrap().parse().collect()); + assert_eq!(edges.len(), 1); + assert_eq!(edges[0].source, "http://example.org/A"); + assert_eq!(edges[0].target, "http://example.org/B"); + } +} diff --git a/src/graph/inmemory.rs b/src/graph/inmemory.rs index 7832198..5c764e4 100644 --- a/src/graph/inmemory.rs +++ b/src/graph/inmemory.rs @@ -1,15 +1,18 @@ use std::sync::Arc; use std::{collections::HashMap, io::Read}; -use crate::formats::mm::{apply_base_iri, load_mm_file, parse_index_map}; -use crate::formats::{Csv, MatrixMarket, NTriples}; +use rayon::prelude::*; + +use crate::formats::mm::{apply_base_iri, parse_index_map}; +use crate::formats::{Csv, MatrixMarket, Rdf}; use crate::{ graph::GraphSource, - lagraph_sys::{GrB_Index, GrB_Matrix, GrB_Matrix_free, LAGraph_Kind}, + lagraph_sys::{GrB_Index, LAGraph_Kind}, }; use super::{ - Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, LagraphGraph, ensure_grb_init, + compute_outer_inner, load_mm_file, Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, + LagraphGraph, ThreadScope, }; /// Marker type for the in-memory GraphBLAS-backed backend. @@ -101,26 +104,12 @@ impl InMemoryBuilder { Ok(self) } - /// Accept a pre-built [`GrB_Matrix`] for `label`, wrapping it in an - /// [`LAGraph_Graph`] immediately. - pub fn push_grb_matrix( + /// Bulk-install pre-wrapped `(label, LagraphGraph)` pairs into `prebuilt`. + pub(crate) fn extend_prebuilt>( &mut self, - label: impl Into, - mut matrix: GrB_Matrix, - ) -> Result<(), GraphError> { - ensure_grb_init()?; - let lg: LagraphGraph = - match LagraphGraph::new(matrix, LAGraph_Kind::LAGraph_ADJACENCY_DIRECTED) { - Ok(g) => g, - Err(e) => { - if !matrix.is_null() { - unsafe { GrB_Matrix_free(&mut matrix) }; - } - return Err(e); - } - }; - self.prebuilt.insert(label.into(), lg); - Ok(()) + iter: I, + ) { + self.prebuilt.extend(iter); } } @@ -129,8 +118,6 @@ impl GraphBuilder for InMemoryBuilder { type Error = GraphError; fn build(self) -> Result { - ensure_grb_init()?; - let n: GrB_Index = self .id_to_node .keys() @@ -146,20 +133,34 @@ impl GraphBuilder for InMemoryBuilder { graphs.insert(label, Arc::new(lg)); } - for (label, pairs) in &self.label_buffers { - let rows: Vec = pairs.iter().map(|(r, _)| *r as GrB_Index).collect(); - let cols: Vec = pairs.iter().map(|(_, c)| *c as GrB_Index).collect(); - let vals: Vec = vec![true; pairs.len()]; - - let lg = LagraphGraph::from_coo( - &rows, - &cols, - &vals, - n, - LAGraph_Kind::LAGraph_ADJACENCY_DIRECTED, - )?; - - graphs.insert(label.clone(), Arc::new(lg)); + let label_buffers: Vec<(String, Vec<(usize, usize)>)> = + self.label_buffers.into_iter().collect(); + + let (outer, inner) = compute_outer_inner(label_buffers.len()); + let _scope = ThreadScope::enter(outer, inner)?; + + let built: Vec<(String, LagraphGraph)> = label_buffers + .into_par_iter() + .map( + |(label, pairs)| -> Result<(String, LagraphGraph), GraphError> { + let rows: Vec = pairs.iter().map(|(r, _)| *r as GrB_Index).collect(); + let cols: Vec = pairs.iter().map(|(_, c)| *c as GrB_Index).collect(); + let vals: Vec = vec![true; pairs.len()]; + + let lg = LagraphGraph::from_coo( + &rows, + &cols, + &vals, + n, + LAGraph_Kind::LAGraph_ADJACENCY_DIRECTED, + )?; + Ok((label, lg)) + }, + ) + .collect::, GraphError>>()?; + + for (label, lg) in built { + graphs.insert(label, Arc::new(lg)); } Ok(InMemoryGraph { @@ -225,27 +226,42 @@ impl GraphSource for MatrixMarket { .collect(); let (edge_by_idx, _) = parse_index_map(&self.dir.join("edges.txt"))?; - let edge_by_idx: HashMap = edge_by_idx + let edge_by_idx: Vec<(usize, String)> = edge_by_idx .into_iter() .map(|(i, label)| (i, apply_base_iri(label, base))) .collect(); builder.set_node_map(vert_by_idx, vert_by_name); - for (idx, label) in edge_by_idx { - let path = self.mm_path(idx); - let matrix = load_mm_file(&path)?; - builder.push_grb_matrix(label, matrix)?; - } + let (outer, inner) = compute_outer_inner(edge_by_idx.len()); + let _scope = ThreadScope::enter(outer, inner)?; + + let mm_dir = self.dir.clone(); + let loaded: Vec<(String, LagraphGraph)> = edge_by_idx + .into_par_iter() + .map( + |(idx, label)| -> Result<(String, LagraphGraph), GraphError> { + let path = mm_dir.join(format!("{}.txt", idx)); + let matrix = load_mm_file(&path)?; + let lg = LagraphGraph::from_matrix( + matrix, + LAGraph_Kind::LAGraph_ADJACENCY_DIRECTED, + )?; + Ok((label, lg)) + }, + ) + .collect::, GraphError>>()?; + + builder.extend_prebuilt(loaded); Ok(builder) } } -impl GraphSource for NTriples { +impl GraphSource for Rdf { fn apply_to(self, mut builder: InMemoryBuilder) -> Result { - for item in self { - builder.push_edge(item?)?; + for edge in self.parse().flatten() { + builder.push_edge(edge)?; } Ok(builder) } @@ -340,15 +356,36 @@ mod tests { } #[test] - fn test_with_stream_from_ntriples() { - use crate::formats::nt::NTriples; + fn test_rdf_skip_bad_syntax_lines() { + use crate::formats::rdf::{Rdf, RdfFormat}; + + let nt = b" .\n\ + THIS IS NOT VALID RDF SYNTAX .\n\ + .\n"; + + let graph = InMemoryBuilder::new() + .load(Rdf::new(nt.as_ref(), RdfFormat::NTriples)) + .expect("load should succeed despite bad line") + .build() + .expect("build should succeed"); + + assert_eq!(graph.num_nodes(), 3, "A, B, C must all be present"); + assert!( + graph.get_graph("http://example.org/knows").is_ok(), + "label matrix must exist" + ); + } + + #[test] + fn test_with_stream_from_rdf() { + use crate::formats::rdf::{Rdf, RdfFormat}; - let nt = " .\n\ + let nt = b" .\n\ .\n\ .\n"; let graph = InMemoryBuilder::new() - .load(NTriples::new(nt.as_bytes())) + .load(Rdf::new(nt.as_ref(), RdfFormat::NTriples)) .expect("load should succeed") .build() .expect("build should succeed"); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 5a48dfa..5b489d0 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,13 +1,16 @@ //! Core graph abstractions for pathrex. pub mod inmemory; +pub mod wrappers; pub use inmemory::{InMemory, InMemoryBuilder, InMemoryGraph}; +pub(crate) use wrappers::{compute_outer_inner, ensure_grb_init, ThreadScope}; +pub use wrappers::{load_mm_file, GraphblasMatrix, GraphblasVector, LagraphGraph}; use std::marker::PhantomData; -use std::sync::{Arc, Once}; +use std::sync::Arc; -use crate::{grb_ok, la_ok, lagraph_sys::*}; +use crate::lagraph_sys::GrB_Info; use thiserror::Error; @@ -21,184 +24,19 @@ pub enum GraphError { #[error("GraphBLAS error: info code {0}; msg: {1}")] LAGraph(GrB_Info, String), + /// GraphBLAS/LAGraph initialisation failed. + #[error("LAGraph initialization failed")] + InitFailed, + /// [`GraphDecomposition::get_graph`] was called with an unknown label. #[error("Label not found: '{0}'")] LabelNotFound(String), - /// [`ensure_grb_init`] was called but `LAGraph_Init` returned a failure code. - #[error("LAGraph initialization failed")] - InitFailed, - /// A format-layer error propagated through [`GraphBuilder::load`]. #[error("Format error: {0}")] Format(#[from] crate::formats::FormatError), } -static GRB_INIT: Once = Once::new(); - -pub fn ensure_grb_init() -> Result<(), GraphError> { - let mut result = Ok(()); - GRB_INIT.call_once(|| { - result = la_ok!(LAGraph_Init()); - }); - result -} - -#[derive(Debug)] -pub struct LagraphGraph { - pub(crate) inner: LAGraph_Graph, -} - -impl LagraphGraph { - pub fn new(mut matrix: GrB_Matrix, kind: LAGraph_Kind) -> Result { - let mut g: LAGraph_Graph = std::ptr::null_mut(); - la_ok!(LAGraph_New(&mut g, &mut matrix, kind,))?; - - return Ok(Self { inner: g }); - } - - /// Build a new `LagraphGraph` from coordinate (COO) format. - /// - /// Creates a boolean adjacency matrix from parallel arrays of row indices, - /// column indices, and boolean values, then wraps it in an `LAGraph_Graph`. - /// - /// # Parameters - /// - `rows`: Row indices - /// - `cols`: Column indices - /// - `vals`: Boolean values for each edge - /// - `n`: Number of nodes - /// - `kind`: Graph kind (e.g., `LAGraph_ADJACENCY_DIRECTED`) - /// - /// # Safety - /// Caller must ensure LAGraph/GraphBLAS has been initialised via - /// [`ensure_grb_init`]. - /// - /// # Example - /// ```ignore - /// let rows = vec![0, 1, 2]; - /// let cols = vec![1, 2, 0]; - /// let vals = vec![true, true, true]; - /// let graph = unsafe { - /// LagraphGraph::from_coo(&rows, &cols, &vals, 3, LAGraph_ADJACENCY_DIRECTED) - /// }?; - /// ``` - pub fn from_coo( - rows: &[GrB_Index], - cols: &[GrB_Index], - vals: &[bool], - n: GrB_Index, - kind: LAGraph_Kind, - ) -> Result { - let nvals = rows.len() as GrB_Index; - - let mut matrix: GrB_Matrix = std::ptr::null_mut(); - grb_ok!(GrB_Matrix_new(&mut matrix, GrB_BOOL, n, n))?; - - if let Err(e) = grb_ok!(GrB_Matrix_build_BOOL( - matrix, - rows.as_ptr(), - cols.as_ptr(), - vals.as_ptr(), - nvals, - GrB_LOR, - )) { - let _ = grb_ok!(GrB_Matrix_free(&mut matrix)); - return Err(e); - } - - Self::new(matrix, kind) - } - - pub fn check_graph(&self) -> Result<(), GraphError> { - la_ok!(LAGraph_CheckGraph(self.inner)) - } -} - -impl Drop for LagraphGraph { - fn drop(&mut self) { - if !self.inner.is_null() { - let _ = la_ok!(LAGraph_Delete(&mut self.inner)); - } - } -} - -unsafe impl Send for LagraphGraph {} -unsafe impl Sync for LagraphGraph {} - -#[derive(Debug)] -pub struct GraphblasVector { - pub inner: GrB_Vector, -} - -impl GraphblasVector { - /// Allocate a new N-element boolean `GrB_Vector`. - /// - /// # Safety - /// Caller must ensure LAGraph/GraphBLAS has been initialised via - /// [`ensure_grb_init`]. - pub unsafe fn new_bool(n: GrB_Index) -> Result { - let mut v: GrB_Vector = std::ptr::null_mut(); - grb_ok!(GrB_Vector_new(&mut v, GrB_BOOL, n))?; - Ok(Self { inner: v }) - } - - /// Returns the number of stored values in this vector. - pub fn nvals(&self) -> Result { - let mut nvals: GrB_Index = 0; - grb_ok!(GrB_Vector_nvals(&mut nvals, self.inner))?; - Ok(nvals) - } - - /// Extracts all stored indices from boolean vector. - pub fn indices(&self) -> Result, GraphError> { - let nvals = self.nvals()?; - if nvals == 0 { - return Ok(Vec::new()); - } - - let mut indices = vec![0u64; nvals as usize]; - let mut values = vec![false; nvals as usize]; - let mut actual_nvals = nvals; - - grb_ok!(GrB_Vector_extractTuples_BOOL( - indices.as_mut_ptr(), - values.as_mut_ptr(), - &mut actual_nvals, - self.inner, - ))?; - - indices.truncate(actual_nvals as usize); - Ok(indices) - } -} - -impl Drop for GraphblasVector { - fn drop(&mut self) { - if !self.inner.is_null() { - let _ = grb_ok!(GrB_Vector_free(&mut self.inner)); - } - } -} - -unsafe impl Send for GraphblasVector {} -unsafe impl Sync for GraphblasVector {} - -#[derive(Debug)] -pub struct GraphblasMatrix { - pub inner: GrB_Matrix, -} - -impl Drop for GraphblasMatrix { - fn drop(&mut self) { - if !self.inner.is_null() { - let _ = grb_ok!(GrB_Matrix_free(&mut self.inner)); - } - } -} - -unsafe impl Send for GraphblasMatrix {} -unsafe impl Sync for GraphblasMatrix {} - /// A directed, labelled edge as produced by format parsers. #[derive(Debug, Clone)] pub struct Edge { @@ -322,6 +160,32 @@ mod tests { assert_eq!(output.num_nodes(), 3); } + #[test] + fn test_compute_outer_inner_product_bounded_by_cores() { + let cores = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); + for num_tasks in [0usize, 1, 2, 4, 8, 16, 64, 1024] { + let (outer, inner) = compute_outer_inner(num_tasks); + assert!(outer >= 1, "outer must be >= 1 for num_tasks={num_tasks}"); + assert!(inner >= 1, "inner must be >= 1 for num_tasks={num_tasks}"); + let product = (outer as usize) * (inner as usize); + assert!( + product <= cores.max(1), + "outer*inner ({outer}*{inner}={product}) must not exceed cores ({cores}) for num_tasks={num_tasks}" + ); + } + } + + #[test] + fn test_compute_outer_inner_caps_outer_at_tasks() { + // With a very small number of tasks, outer should never exceed that. + let (outer, _inner) = compute_outer_inner(1); + assert_eq!(outer, 1); + let (outer, _inner) = compute_outer_inner(2); + assert!(outer <= 2); + } + #[test] fn test_graph_try_from() { struct TestBackend; diff --git a/src/graph/wrappers.rs b/src/graph/wrappers.rs new file mode 100644 index 0000000..f8128e6 --- /dev/null +++ b/src/graph/wrappers.rs @@ -0,0 +1,275 @@ +//! RAII wrappers and init helpers for GraphBLAS and LAGraph C handles. +//! +//! GraphBLAS initialisation is performed lazily inside the RAII-wrapped constructors +//! here (`LagraphGraph::from_coo`, `LagraphGraph::from_matrix`) and inside +//! `ThreadScope::enter`. Consumers of these wrappers — including format loaders, +//! builders, and RPQ evaluators — do not need (and should not) call init themselves. + +use std::ffi::CString; +use std::fs::File; +use std::os::fd::IntoRawFd; +use std::path::Path; +use std::sync::Once; + +use crate::{grb_ok, la_ok, lagraph_sys::*}; + +use super::GraphError; + +static GRB_INIT: Once = Once::new(); + +pub(crate) fn ensure_grb_init() -> Result<(), GraphError> { + let mut result = Ok(()); + GRB_INIT.call_once(|| { + result = unsafe { la_ok!(LAGraph_Init()) }; + }); + result +} + +/// Compute a balanced `(outer, inner)` split for LAGraph's two-level threading. +/// +/// `outer` is the number of user-level concurrent tasks (rayon workers); +/// `inner` is the number of GraphBLAS/OpenMP threads per task. The product is +/// kept close to `available_parallelism()` so the OS scheduler does not +/// thrash. +pub(crate) fn compute_outer_inner(num_tasks: usize) -> (i32, i32) { + let cores = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); + let tasks = num_tasks.max(1); + let outer = tasks.min(cores).max(1); + let inner = (cores / outer).max(1); + (outer as i32, inner as i32) +} + +/// RAII guard that temporarily sets LAGraph's `(outer, inner)` thread counts. +/// +/// On entry calls `LAGraph_SetNumThreads(outer, inner)`. On drop restores +/// `(1, available_parallelism())` so subsequent callers +/// keep full GraphBLAS parallelism. +pub(crate) struct ThreadScope { + _private: (), +} + +impl ThreadScope { + pub(crate) fn enter(outer: i32, inner: i32) -> Result { + ensure_grb_init()?; + unsafe { la_ok!(LAGraph_SetNumThreads(outer, inner))? }; + Ok(Self { _private: () }) + } +} + +impl Drop for ThreadScope { + fn drop(&mut self) { + let cores = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1) as i32; + if let Err(e) = unsafe { la_ok!(LAGraph_SetNumThreads(1, cores)) } { + eprintln!("ThreadScope: failed to restore thread counts: {e}"); + } + } +} + +#[derive(Debug)] +pub struct LagraphGraph { + pub(crate) inner: LAGraph_Graph, +} + +impl LagraphGraph { + /// Build a `LagraphGraph` from an RAII-wrapped [`GraphblasMatrix`]. + /// + /// On success, ownership of the underlying `GrB_Matrix` is transferred + /// into the `LAGraph_Graph` and the [`GraphblasMatrix`] guard is forgotten + /// + /// On failure, the [`GraphblasMatrix`] is dropped normally, freeing the + /// matrix. + pub fn from_matrix(matrix: GraphblasMatrix, kind: LAGraph_Kind) -> Result { + ensure_grb_init()?; + let mut raw = matrix.inner; + let mut g: LAGraph_Graph = std::ptr::null_mut(); + match unsafe { la_ok!(LAGraph_New(&mut g, &mut raw, kind)) } { + Ok(()) => { + std::mem::forget(matrix); + Ok(Self { inner: g }) + } + Err(e) => Err(e), + } + } + + /// Build a new `LagraphGraph` from coordinate (COO) format. + /// + /// Creates a boolean adjacency matrix from parallel arrays of row indices, + /// column indices, and boolean values, then wraps it in an `LAGraph_Graph`. + /// + /// # Parameters + /// - `rows`: Row indices + /// - `cols`: Column indices + /// - `vals`: Boolean values for each edge + /// - `n`: Number of nodes + /// - `kind`: Graph kind (e.g., `LAGraph_ADJACENCY_DIRECTED`) + /// + /// # Example + /// ```ignore + /// let rows = vec![0, 1, 2]; + /// let cols = vec![1, 2, 0]; + /// let vals = vec![true, true, true]; + /// let graph = LagraphGraph::from_coo(&rows, &cols, &vals, 3, LAGraph_ADJACENCY_DIRECTED)?; + /// ``` + pub fn from_coo( + rows: &[GrB_Index], + cols: &[GrB_Index], + vals: &[bool], + n: GrB_Index, + kind: LAGraph_Kind, + ) -> Result { + ensure_grb_init()?; + let nvals = rows.len() as GrB_Index; + + let mut matrix: GrB_Matrix = std::ptr::null_mut(); + unsafe { grb_ok!(GrB_Matrix_new(&mut matrix, GrB_BOOL, n, n))? }; + + let owned = GraphblasMatrix::from_raw(matrix); + + grb_ok!(unsafe { + GrB_Matrix_build_BOOL( + owned.inner, + rows.as_ptr(), + cols.as_ptr(), + vals.as_ptr(), + nvals, + GrB_LOR, + ) + })?; + + Self::from_matrix(owned, kind) + } + + pub fn check_graph(&self) -> Result<(), GraphError> { + unsafe { la_ok!(LAGraph_CheckGraph(self.inner)) } + } + + /// Number of stored (non-zero) values in the underlying adjacency matrix. + pub fn nvals(&self) -> Result { + if self.inner.is_null() { + return Ok(0); + } + let matrix: GrB_Matrix = unsafe { (*self.inner).A }; + let mut nvals: GrB_Index = 0; + unsafe { grb_ok!(GrB_Matrix_nvals(&mut nvals, matrix))? }; + Ok(nvals) + } +} + +impl Drop for LagraphGraph { + fn drop(&mut self) { + if !self.inner.is_null() { + let _ = unsafe { la_ok!(LAGraph_Delete(&mut self.inner)) }; + } + } +} + +unsafe impl Send for LagraphGraph {} +unsafe impl Sync for LagraphGraph {} + +#[derive(Debug)] +pub struct GraphblasVector { + pub inner: GrB_Vector, +} + +impl GraphblasVector { + /// Returns the number of stored values in this vector. + pub fn nvals(&self) -> Result { + let mut nvals: GrB_Index = 0; + unsafe { grb_ok!(GrB_Vector_nvals(&mut nvals, self.inner))? }; + Ok(nvals) + } + + /// Extracts all stored indices from boolean vector. + pub fn indices(&self) -> Result, GraphError> { + let nvals = self.nvals()?; + if nvals == 0 { + return Ok(Vec::new()); + } + + let mut indices = vec![0u64; nvals as usize]; + let mut values = vec![false; nvals as usize]; + let mut actual_nvals = nvals; + + unsafe { + grb_ok!(GrB_Vector_extractTuples_BOOL( + indices.as_mut_ptr(), + values.as_mut_ptr(), + &mut actual_nvals, + self.inner, + ))? + }; + + indices.truncate(actual_nvals as usize); + Ok(indices) + } +} + +impl Drop for GraphblasVector { + fn drop(&mut self) { + if !self.inner.is_null() { + let _ = unsafe { grb_ok!(GrB_Vector_free(&mut self.inner)) }; + } + } +} + +unsafe impl Send for GraphblasVector {} +unsafe impl Sync for GraphblasVector {} + +#[derive(Debug)] +pub struct GraphblasMatrix { + pub inner: GrB_Matrix, +} + +impl GraphblasMatrix { + /// Wrap a raw [`GrB_Matrix`] pointer in an RAII guard. + /// + /// The caller must ensure the pointer is either null or a valid, + /// live `GrB_Matrix` that is not shared with any other owner. + /// [`Drop`] will call `GrB_Matrix_free` when the guard is dropped. + pub fn from_raw(raw: GrB_Matrix) -> Self { + Self { inner: raw } + } +} + +impl Drop for GraphblasMatrix { + fn drop(&mut self) { + if !self.inner.is_null() { + let _ = unsafe { grb_ok!(GrB_Matrix_free(&mut self.inner)) }; + } + } +} + +unsafe impl Send for GraphblasMatrix {} +unsafe impl Sync for GraphblasMatrix {} + +/// Read a single MatrixMarket file and return a RAII-wrapped [`GraphblasMatrix`]. +/// +/// Initialises GraphBLAS on first call. The file must be in MatrixMarket +/// coordinate format as produced by LAGraph. +pub fn load_mm_file(path: impl AsRef) -> Result { + ensure_grb_init()?; + + let file = File::open(path.as_ref()) + .map_err(|e| GraphError::Format(crate::formats::FormatError::Io(e)))?; + let fd = file.into_raw_fd(); + + let c_mode = CString::new("r").unwrap(); + let f = unsafe { libc::fdopen(fd, c_mode.as_ptr()) }; + if f.is_null() { + unsafe { libc::close(fd) }; + return Err(GraphError::Format(crate::formats::FormatError::Io( + std::io::Error::last_os_error(), + ))); + } + + let mut matrix = std::ptr::null_mut(); + let err = unsafe { la_ok!(LAGraph_MMRead(&mut matrix, f as *mut FILE)) }; + unsafe { libc::fclose(f) }; + err?; + + Ok(GraphblasMatrix::from_raw(matrix)) +} diff --git a/src/lagraph_sys_generated.rs b/src/lagraph_sys_generated.rs index 5e0de30..5e02d97 100644 --- a/src/lagraph_sys_generated.rs +++ b/src/lagraph_sys_generated.rs @@ -262,6 +262,20 @@ unsafe extern "C" { msg: *mut ::std::os::raw::c_char, ) -> ::std::os::raw::c_int; } +unsafe extern "C" { + pub fn LAGraph_GetNumThreads( + nthreads_outer: *mut ::std::os::raw::c_int, + nthreads_inner: *mut ::std::os::raw::c_int, + msg: *mut ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn LAGraph_SetNumThreads( + nthreads_outer: ::std::os::raw::c_int, + nthreads_inner: ::std::os::raw::c_int, + msg: *mut ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} unsafe extern "C" { pub fn LAGraph_MMRead( A: *mut GrB_Matrix, diff --git a/src/rpq/mod.rs b/src/rpq/mod.rs index a916890..db2d220 100644 --- a/src/rpq/mod.rs +++ b/src/rpq/mod.rs @@ -51,10 +51,8 @@ impl RpqQuery { } fn strip_endpoint(ep: &mut Endpoint, base: &str) { - if let Endpoint::Named(s) = ep { - if s.starts_with(base) { - *s = s[base.len()..].to_owned(); - } + if let Endpoint::Named(s) = ep && s.starts_with(base) { + *s = s[base.len()..].to_owned(); } } diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index e52797c..a616b64 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -1,11 +1,11 @@ //! NFA-based RPQ evaluation using `LAGraph_RegularPathQuery`. -use crate::graph::{GraphDecomposition, GraphblasVector, LagraphGraph, ensure_grb_init}; +use crate::graph::{GraphDecomposition, GraphblasVector, LagraphGraph}; use crate::la_ok; -use crate::lagraph_sys::*; use crate::lagraph_sys::LAGraph_Kind; +use crate::lagraph_sys::*; use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; -use rustfst::algorithms::closure::{ClosureType, closure}; +use rustfst::algorithms::closure::{closure, ClosureType}; use rustfst::algorithms::concat::concat; use rustfst::algorithms::rm_epsilon::rm_epsilon; use rustfst::algorithms::union::union; @@ -83,7 +83,6 @@ impl Nfa { /// Convert NFA transitions to LAGraph matrices for RPQ evaluation. pub fn build_lagraph_matrices(&self) -> Result, RpqError> { - ensure_grb_init()?; let n = self.num_states as GrB_Index; let mut result = Vec::with_capacity(self.transitions.len()); @@ -243,18 +242,20 @@ impl RpqEvaluator for NfaRpqEvaluator { let mut reachable: GrB_Vector = std::ptr::null_mut(); - la_ok!(LAGraph_RegularPathQuery( - &mut reachable, - nfa_graph_ptrs.as_mut_ptr(), - nfa_matrices.len(), - nfa.start_states.as_ptr(), - nfa.start_states.len(), - nfa.final_states.as_ptr(), - nfa.final_states.len(), - data_graph_ptrs.as_mut_ptr(), - source_vertices.as_ptr(), - source_vertices.len(), - ))?; + unsafe { + la_ok!(LAGraph_RegularPathQuery( + &mut reachable, + nfa_graph_ptrs.as_mut_ptr(), + nfa_matrices.len(), + nfa.start_states.as_ptr(), + nfa.start_states.len(), + nfa.final_states.as_ptr(), + nfa.final_states.len(), + data_graph_ptrs.as_mut_ptr(), + source_vertices.as_ptr(), + source_vertices.len(), + ))? + }; let result_vec = GraphblasVector { inner: reachable }; diff --git a/src/rpq/rpqmatrix.rs b/src/rpq/rpqmatrix.rs index 8587674..72f9110 100644 --- a/src/rpq/rpqmatrix.rs +++ b/src/rpq/rpqmatrix.rs @@ -4,7 +4,7 @@ use std::ptr::null_mut; use egg::{Id, RecExpr, define_language}; -use crate::graph::{GraphDecomposition, GraphblasMatrix, ensure_grb_init}; +use crate::graph::{GraphDecomposition, GraphblasMatrix}; use crate::lagraph_sys::*; use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; use crate::{grb_ok, la_ok}; @@ -119,7 +119,10 @@ pub fn materialize( .ok_or_else(|| RpqError::VertexNotFound(name.clone()))? as GrB_Index; let mut mat: GrB_Matrix = null_mut(); - grb_ok!(LAGraph_RPQMatrix_label(&mut mat, vertex_id, n, n,))?; + unsafe { + crate::graph::ensure_grb_init()?; + grb_ok!(LAGraph_RPQMatrix_label(&mut mat, vertex_id, n, n,))? + }; if mat.is_null() { return Err(RpqError::Graph(crate::graph::GraphError::GraphBlas( GrB_Info::GrB_INVALID_VALUE, @@ -182,15 +185,13 @@ impl RpqEvaluator for RpqMatrixEvaluator { query: &RpqQuery, graph: &G, ) -> Result { - ensure_grb_init()?; - let expr = query_to_expr(query)?; let (mut plans, owned_matrices) = materialize(&expr, graph)?; let root_ptr = unsafe { plans.as_mut_ptr().add(plans.len() - 1) }; let mut nnz: GrB_Index = 0; - la_ok!(LAGraph_RPQMatrix(&mut nnz, root_ptr))?; + unsafe { la_ok!(LAGraph_RPQMatrix(&mut nnz, root_ptr))? }; let matrix = unsafe { let mat = (*root_ptr).res_mat; @@ -198,7 +199,7 @@ impl RpqEvaluator for RpqMatrixEvaluator { GraphblasMatrix { inner: mat } }; - grb_ok!(LAGraph_DestroyRpqMatrixPlan(root_ptr))?; + unsafe { grb_ok!(LAGraph_DestroyRpqMatrixPlan(root_ptr))? }; // Free diagonal matrices created for named vertices. for mut mat in owned_matrices { diff --git a/src/utils.rs b/src/utils.rs index 626bf8f..a6d5d51 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -135,7 +135,7 @@ impl From for GrB_Info { #[macro_export] macro_rules! grb_ok { ($grb_func:expr) => { - unsafe { + { let info: $crate::lagraph_sys::GrB_Info = $grb_func.into(); if info == $crate::lagraph_sys::GrB_Info::GrB_SUCCESS { Ok(()) @@ -163,7 +163,7 @@ macro_rules! grb_ok { /// ``` #[macro_export] macro_rules! la_ok { - ( $($func:ident)::+ ( $($arg:expr),* $(,)? ) ) => { unsafe { + ( $($func:ident)::+ ( $($arg:expr),* $(,)? ) ) => { { let mut msg = [0i8; $crate::lagraph_sys::LAGRAPH_MSG_LEN as usize]; let info: $crate::lagraph_sys::GrB_Info = $($func)::+($($arg,)* msg.as_mut_ptr()).into(); @@ -193,7 +193,7 @@ pub fn build_graph(edges: &[(&str, &str, &str)]) -> ::Graph }) .collect::>>(); builder - .with_stream(edges.into_iter()) + .with_stream(edges) .expect("Should insert edges stream") .build() .expect("build must succeed") diff --git a/tests/mm_tests.rs b/tests/mm_tests.rs index 60c7d47..0998e1b 100644 --- a/tests/mm_tests.rs +++ b/tests/mm_tests.rs @@ -122,10 +122,10 @@ fn test_mm_graph_matrix_dimensions() { for label in expected_labels { let matrix = graph .get_graph(label) - .expect(&format!("Should have matrix for label {}", label)); + .unwrap_or_else(|_| panic!("Should have matrix for label {}", label)); matrix .check_graph() - .expect(&format!("Matrix for {} should be valid", label)); + .unwrap_or_else(|_| panic!("Matrix for {} should be valid", label)); } } diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index 7483fc5..fbdf42a 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -21,17 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line - .splitn(2, ',') - .nth(1) + let query_str = line.split_once(',').map(|x| x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . }}"); - let query = - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")); - query + + parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } fn load_queries(case_dir: &Path) -> Vec { @@ -90,8 +87,7 @@ fn run_la_n_egg_case(case_name: &str) { let actual_nnz = result.reachable.nvals().expect("failed to get nvals"); assert_eq!( - actual_nnz, - *expected_nnz, + actual_nnz, *expected_nnz, "case '{case_name}' query #{i} nnz mismatch\n query: {query:?}\n expected: {expected_nnz}\n actual: {actual_nnz}", ); } @@ -143,7 +139,10 @@ fn test_single_label_named_source() { .evaluate(&rq(named_ep("A"), label("knows"), var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let b_id = graph.get_node_id("B").expect("B should exist"); assert!( indices.contains(&(b_id as GrB_Index)), @@ -181,7 +180,10 @@ fn test_sequence_path_named_source() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let c_id = graph.get_node_id("C").expect("C should exist"); assert!( indices.contains(&(c_id as GrB_Index)), @@ -202,7 +204,10 @@ fn test_alternative_path() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let b_id = graph.get_node_id("B").expect("B should exist"); let c_id = graph.get_node_id("C").expect("C should exist"); assert!( @@ -228,7 +233,10 @@ fn test_zero_or_more_path() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let a_id = graph.get_node_id("A").expect("A should exist"); let b_id = graph.get_node_id("B").expect("B should exist"); let c_id = graph.get_node_id("C").expect("C should exist"); @@ -260,7 +268,10 @@ fn test_one_or_more_path() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let a_id = graph.get_node_id("A").expect("A should exist"); let b_id = graph.get_node_id("B").expect("B should exist"); let c_id = graph.get_node_id("C").expect("C should exist"); @@ -292,7 +303,10 @@ fn test_zero_or_one_path() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let a_id = graph.get_node_id("A").expect("A should exist"); let b_id = graph.get_node_id("B").expect("B should exist"); let c_id = graph.get_node_id("C").expect("C should exist"); @@ -402,7 +416,10 @@ fn test_complex_path() { .evaluate(&rq(named_ep("A"), path, var("y")), &graph) .expect("evaluate should succeed"); - let indices = result.reachable.indices().expect("failed to extract indices"); + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); let d_id = graph.get_node_id("D").expect("D should exist"); assert!( indices.contains(&(d_id as GrB_Index)), diff --git a/tests/rpqmatrix_tests.rs b/tests/rpqmatrix_tests.rs index 353c727..b23e0e9 100644 --- a/tests/rpqmatrix_tests.rs +++ b/tests/rpqmatrix_tests.rs @@ -21,17 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line - .splitn(2, ',') - .nth(1) + let query_str = line.split_once(',').map(|x| x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . }}"); - let query = - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")); - query + + parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } fn load_queries(case_dir: &Path) -> Vec {