Skip to content

PoC: Python UDFs (for IVF Flat) via numba-cuda-mlir on top of JIT/LTO#2133

Draft
dantegd wants to merge 3 commits into
rapidsai:mainfrom
dantegd:fea-pyudf
Draft

PoC: Python UDFs (for IVF Flat) via numba-cuda-mlir on top of JIT/LTO#2133
dantegd wants to merge 3 commits into
rapidsai:mainfrom
dantegd:fea-pyudf

Conversation

@dantegd
Copy link
Copy Markdown
Member

@dantegd dantegd commented May 27, 2026

This is a personal PoC exploring the idea of using the recently released numba-cuda-mlir as a Python frontend on top of the existing cuVS JIT/LTO infrastructure to add python UDF capabilities to cuVS.

The goal is to validate the end-to-end shape rather than propose a finalized public API: define an IVF Flat metric in Python, lower it to LTO-IR, package it as a cuVS device UDF artifact, pass it through Python/C/C++, and link it into the IVF Flat JIT/LTO search path.

This also includes an "expert" CUDA/C++ source-string path via ivf_flat.cuda_source_metric(...), which helps compare the new Python/LTO-IR flow against the existing JIT/LTO UDF mechanism.

Yep, that’s a cleaner framing. Replace the separate Example API and Validation sections with this combined Example section:

Example

This PoC includes an end-to-end demo at:

python examples/experimental/ivf_flat_udf_e2e_demo.py

Python metric UDF with a CuPy capture

@ivf_flat.metric(
    order="min",
    initial=0.0,
    coarse_metric="sqeuclidean",
    captures={"weights": weights},
    symbol_name="cuvs_demo_weighted_l2_update_f32",
)
def weighted_l2_update(x, y, acc, ctx):
    d = x - y
    return acc + ctx.weights[ctx.dim] * d * d

udf_distances, udf_neighbors = ivf_flat.search(
    ivf_flat.SearchParams(n_probes=1, metric=weighted_l2_update),
    index,
    queries,
    3,
)

Output:

========================================================================================
Example 1: Python @ivf_flat.metric weighted L2 with a CuPy capture
========================================================================================
weights:
[0.25 1.5  3.   0.75]
queries:
[[ 0.2   0.1  -0.1   0.5 ]
 [ 2.1  -0.7   0.4  -0.25]]
dataset:
[[ 0.    0.    0.    0.  ]
 [ 1.    0.5  -0.5   2.  ]
 [ 2.   -1.    0.25 -0.5 ]
 [-1.5   2.    1.    0.75]
 [ 3.    1.5  -2.   -1.  ]
 [-2.   -1.    2.5   1.25]]
UDF neighbors:
[[0 1 2]
 [2 0 1]]
reference neighbors:
[[0 1 2]
 [2 0 1]]
UDF distances:
[[0.2425 2.5675 3.7425]
 [0.2519 2.3644 8.6894]]
reference distances:
[[0.2425 2.5675 3.7425]
 [0.2519 2.3644 8.6894]]
max abs distance error: 0.00000024
neighbors match: True
RESULT: PASS

Expert CUDA/C++ source-string metric

source = r'''
namespace cuvs::neighbors::ivf_flat::detail {
template <typename T, typename AccT, int Veclen>
__device__ __forceinline__ void compute_dist_udf_impl(
    AccT& acc, AccT x, AccT y)
{
  auto d = x - y;
  acc += d * d;
}
}
'''
source_metric = ivf_flat.cuda_source_metric(
    source,
    symbol_name="cuvs_demo_cuda_source_l2_metric",
)

source_distances, source_neighbors = ivf_flat.search(
    ivf_flat.SearchParams(n_probes=1, metric=source_metric),
    index,
    queries,
    3,
)

Output:

========================================================================================
Example 2: Expert CUDA/C++ source-string L2 metric
========================================================================================
queries:
[[ 0.2  0.1 -0.1]
 [ 2.1 -0.7  0.4]]
dataset:
[[ 0.    0.    0.  ]
 [ 1.    0.5  -0.5 ]
 [ 2.   -1.    0.25]
 [-1.5   2.    1.  ]
 [ 3.    1.5  -2.  ]
 [-2.   -1.    2.5 ]]
CUDA source neighbors:
[[0 1 2]
 [2 1 0]]
built-in L2 neighbors:
[[0 1 2]
 [2 1 0]]
CUDA source distances:
[[0.06   0.96   4.5725]
 [0.1225 3.46   5.06  ]]
built-in L2 distances:
[[0.06   0.96   4.5725]
 [0.1225 3.46   5.06  ]]
max abs distance error: 0.00000000
neighbors match: True
RESULT: PASS

Design Notes / ABI Framing

The Python UDF API in this PoC should be read as a coordinate-wise accumulator ABI, not as a fully general metric API.

The current supported shape is:

def metric(x, y, acc, ctx):
    ...
    return new_acc

This maps cleanly onto the existing IVF Flat scan kernel and the existing C++ UDF model, where the kernel already accumulates a distance one coordinate update at a time. In the previous C++ source UDF path, the practical shape was based on x, y, and acc; this PoC adds ctx so Python UDFs can access limited contextual state such as ctx.dim and one captured CUDA array.

This is intentionally not “the one true UDF API forever.” It is useful for metrics and transforms that can be expressed as independent coordinate updates plus an accumulator, including L2-like metrics, weighted L2, simple per-dimension transforms, and other custom distances that fit the existing ANN fine-scoring loop.

More general metrics should get a separate ABI/version rather than overloading this coordinate-wise ABI until it becomes confusing. A future block/vector-level ABI might look more like:

def metric(query_vector, database_vector, ctx):
    return distance

or use a block-level/kernel-adapter shape. That would be needed for metrics that require custom reductions, multiple passes, normalization over the whole vector, shared memory, synchronization, control flow across dimensions, or richer per-query/per-vector state.

That future API would require real kernel support for how user code participates in loading, reducing, shared memory, synchronization, and output selection. This PR does not imply plug-and-play block-level primitives yet.

What Changed

  • Added a shared device UDF artifact model for ABI, payload kind, symbol name, cache key, target metadata, and captures.
  • Added C/C++ cuvsDeviceUDF descriptors and conversion helpers.
  • Added Python ivf_flat.metric(...) support that compiles Python metric update functions to LTO-IR using numba-cuda-mlir.
  • Added Python ivf_flat.cuda_source_metric(...) for expert CUDA/C++ source-string metrics.
  • Extended IVF Flat SearchParams in Python/C/C++ to accept an optional metric UDF artifact.
  • Extended the IVF Flat JIT/LTO interleaved scan path to link LTO-IR metric payloads and adapt them into the existing compute_dist call path.
  • Added support for passing the current dimension and one capture pointer into the custom metric path.
  • Added an end-to-end demo at examples/experimental/ivf_flat_udf_e2e_demo.py.
  • Added Python and C++ tests for artifact metadata, validation, lowering, backend compilation, descriptor conversion, and IVF Flat search correctness.

Current PoC Scope

  • IVF Flat metric UDFs only.
  • V1 ABI: rapids.cuvs.ivf_flat.metric.v1.
  • Python path currently supports order="min" and initial=0.0.
  • Python captures are currently limited to at most one contiguous 1-D float32 CUDA array.
  • Captures are accessed as ctx.<capture_name>[ctx.dim].
  • CUDA source metrics currently do not support captures.
  • Coarse routing still uses the index/coarse metric behavior; this PoC customizes the fine scoring path.
  • This is a coordinate-wise accumulator ABI, not a block/vector-level metric ABI.
  • Requires a CUDA-capable environment with numba-cuda-mlir for the Python-to-LTO-IR path.

Potential Future Directions

  • Broaden the Python UDF surface beyond IVF Flat metrics to other cuVS algorithms that already use or could use the JIT/LTO infrastructure.
  • Support more metric metadata, including additional order, initial, and coarse-routing combinations.
  • Support richer capture shapes, multiple captures, and more capture dtypes beyond the current single contiguous 1-D float32 capture.
  • Explore a more stable/public artifact ABI so Python, C, and C++ integrations can share UDF payloads cleanly.
  • Add artifact caching across process boundaries so repeated UDF searches avoid recompilation/relinking when the same source, target, and capture metadata are reused.
  • Improve diagnostics for Python compilation failures, unsupported syntax, ABI mismatches, and JIT/link errors.
  • Investigate performance characteristics of the Python/LTO-IR path versus hand-written CUDA source UDFs and built-in metrics.
  • Evaluate whether captures should support structured descriptors, constants, or small POD values in addition to CUDA array pointers.
  • Design a separate block/vector-level ABI for metrics that cannot be represented as coordinate-wise accumulator updates.

Test Coverage

This PR adds coverage for:

  • UDF artifact metadata and cache-key behavior.
  • Python UDF validation and AST lowering.
  • numba-cuda-mlir backend compilation to LTO-IR.
  • IVF Flat Python metric UDFs matching built-in/reference results.
  • CUDA source metric artifacts matching built-in L2.
  • C/C++ UDF descriptor conversion and validation.
  • C++ IVF Flat custom metric search paths, including LTO-IR artifacts.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

2 participants