From 7a910a3a47650145efb9e7fe93bd47dcc6dc43df Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 11:13:47 +0300 Subject: [PATCH 01/19] feat(where): Add IL-generated SIMD optimization for np.where(condition, x, y) Add IL-generated kernels for np.where using runtime code generation: - Uses DynamicMethod to generate type-specific kernels at runtime - Vector256/Vector128.ConditionalSelect for SIMD element selection - 4x loop unrolling for better instruction-level parallelism - Full long indexing support for arrays > 2^31 elements - Supports all 12 dtypes (11 via SIMD, Decimal via scalar fallback) - Kernels cached per type for reuse Architecture: - WhereKernel delegate: (bool* cond, T* x, T* y, T* result, long count) - GetWhereKernel(): Returns cached IL-generated kernel - WhereExecute(): Main entry point with automatic fallback IL Generation: - 4x unrolled SIMD loop (processes 4 vectors per iteration) - Remainder SIMD loop (1 vector at a time) - Scalar tail loop for remaining elements - Mask creation methods by element size (1/2/4/8 bytes) - All arithmetic uses long types natively (no int-to-long casts) Falls back to iterator path for: - Non-contiguous/broadcasted arrays (stride=0) - Non-bool conditions (need truthiness conversion) Files: - src/NumSharp.Core/APIs/np.where.cs: Kernel dispatch logic - src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs: IL generation - test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs: 26 tests Closes #604 --- src/NumSharp.Core/APIs/np.where.cs | 199 ++++++ .../Kernels/ILKernelGenerator.Where.cs | 635 ++++++++++++++++++ .../Backends/Kernels/WhereSimdTests.cs | 515 ++++++++++++++ .../Logic/np.where.BattleTest.cs | 346 ++++++++++ test/NumSharp.UnitTest/Logic/np.where.Test.cs | 496 ++++++++++++++ 5 files changed, 2191 insertions(+) create mode 100644 src/NumSharp.Core/APIs/np.where.cs create mode 100644 src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs create mode 100644 test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs create mode 100644 test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs create mode 100644 test/NumSharp.UnitTest/Logic/np.where.Test.cs diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs new file mode 100644 index 00000000..a361534a --- /dev/null +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -0,0 +1,199 @@ +using System; +using NumSharp.Backends.Kernels; +using NumSharp.Generic; + +namespace NumSharp +{ + public static partial class np + { + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// + /// Where True, yield `x`, otherwise yield `y`. + /// Tuple of arrays with indices where condition is non-zero (equivalent to np.nonzero). + /// https://numpy.org/doc/stable/reference/generated/numpy.where.html + public static NDArray[] where(NDArray condition) + { + return nonzero(condition); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// + /// Where True, yield `x`, otherwise yield `y`. + /// Values from which to choose where condition is True. + /// Values from which to choose where condition is False. + /// An array with elements from `x` where `condition` is True, and elements from `y` elsewhere. + /// https://numpy.org/doc/stable/reference/generated/numpy.where.html + public static NDArray where(NDArray condition, NDArray x, NDArray y) + { + // Broadcast all three arrays to common shape + var broadcasted = broadcast_arrays(condition, x, y); + var cond = broadcasted[0]; + var xArr = broadcasted[1]; + var yArr = broadcasted[2]; + + // Determine output dtype from x and y (type promotion) + var outType = _FindCommonType(xArr, yArr); + // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides) + var result = empty(cond.shape, outType); + + // Handle empty arrays - nothing to iterate + if (result.size == 0) + return result; + + // IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled + // Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path. + bool canUseKernel = ILKernelGenerator.Enabled && + cond.typecode == NPTypeCode.Boolean && + cond.Shape.IsContiguous && + xArr.Shape.IsContiguous && + yArr.Shape.IsContiguous; + + if (canUseKernel) + { + WhereKernelDispatch(cond, xArr, yArr, result, outType); + return result; + } + + // Iterator fallback for non-contiguous/broadcasted arrays + switch (outType) + { + case NPTypeCode.Boolean: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Byte: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int16: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt16: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int32: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt32: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int64: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt64: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Char: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Single: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Double: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Decimal: + WhereImpl(cond, xArr, yArr, result); + break; + default: + throw new NotSupportedException($"Type {outType} not supported for np.where"); + } + + return result; + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for x. + /// + public static NDArray where(NDArray condition, object x, NDArray y) + { + return where(condition, asanyarray(x), y); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for y. + /// + public static NDArray where(NDArray condition, NDArray x, object y) + { + return where(condition, x, asanyarray(y)); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for both x and y. + /// + public static NDArray where(NDArray condition, object x, object y) + { + return where(condition, asanyarray(x), asanyarray(y)); + } + + private static void WhereImpl(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged + { + // Use iterators for proper handling of broadcasted/strided arrays + using var condIter = cond.AsIterator(); + using var xIter = x.AsIterator(); + using var yIter = y.AsIterator(); + using var resultIter = result.AsIterator(); + + while (condIter.HasNext()) + { + var c = condIter.MoveNext(); + var xVal = xIter.MoveNext(); + var yVal = yIter.MoveNext(); + resultIter.MoveNextReference() = c ? xVal : yVal; + } + } + + /// + /// IL Kernel dispatch for contiguous arrays. + /// Uses IL-generated kernels with SIMD optimization. + /// + private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType) + { + var condPtr = (bool*)cond.Address; + var count = result.size; + + switch (outType) + { + case NPTypeCode.Boolean: + ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count); + break; + case NPTypeCode.Byte: + ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count); + break; + case NPTypeCode.Int16: + ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count); + break; + case NPTypeCode.UInt16: + ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count); + break; + case NPTypeCode.Int32: + ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count); + break; + case NPTypeCode.UInt32: + ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count); + break; + case NPTypeCode.Int64: + ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count); + break; + case NPTypeCode.UInt64: + ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count); + break; + case NPTypeCode.Char: + ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count); + break; + case NPTypeCode.Single: + ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count); + break; + case NPTypeCode.Double: + ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count); + break; + case NPTypeCode.Decimal: + ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count); + break; + } + } + } +} diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs new file mode 100644 index 00000000..e055bd8a --- /dev/null +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -0,0 +1,635 @@ +using System; +using System.Collections.Concurrent; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; + +// ============================================================================= +// ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels +// ============================================================================= +// +// RESPONSIBILITY: +// - Generate optimized kernels for conditional selection +// - result[i] = cond[i] ? x[i] : y[i] +// +// ARCHITECTURE: +// Uses IL emission to generate type-specific kernels at runtime. +// The challenge is bool mask expansion: condition is bool[] (1 byte per element), +// but x/y can be any dtype (1-8 bytes per element). +// +// | Element Size | V256 Elements | Bools to Load | +// |--------------|---------------|---------------| +// | 1 byte | 32 | 32 | +// | 2 bytes | 16 | 16 | +// | 4 bytes | 8 | 8 | +// | 8 bytes | 4 | 4 | +// +// KERNEL TYPES: +// - WhereKernel: Main kernel delegate (cond*, x*, y*, result*, count) +// +// ============================================================================= + +namespace NumSharp.Backends.Kernels +{ + /// + /// Delegate for where operation kernels. + /// + public unsafe delegate void WhereKernel(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged; + + public static partial class ILKernelGenerator + { + /// + /// Cache of IL-generated where kernels. + /// Key: Type + /// + private static readonly ConcurrentDictionary _whereKernelCache = new(); + + #region Public API + + /// + /// Get or generate an IL-based where kernel for the specified type. + /// Returns null if IL generation is disabled or fails. + /// + public static WhereKernel? GetWhereKernel() where T : unmanaged + { + if (!Enabled) + return null; + + var type = typeof(T); + + if (_whereKernelCache.TryGetValue(type, out var cached)) + return (WhereKernel)cached; + + var kernel = TryGenerateWhereKernel(); + if (kernel == null) + return null; + + if (_whereKernelCache.TryAdd(type, kernel)) + return kernel; + + return (WhereKernel)_whereKernelCache[type]; + } + + /// + /// Execute where operation using IL-generated kernel or fallback to static helper. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void WhereExecute(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + if (count == 0) + return; + + var kernel = GetWhereKernel(); + if (kernel != null) + { + kernel(cond, x, y, result, count); + } + else + { + // Fallback to scalar loop + WhereScalar(cond, x, y, result, count); + } + } + + #endregion + + #region Kernel Generation + + private static WhereKernel? TryGenerateWhereKernel() where T : unmanaged + { + try + { + return GenerateWhereKernelIL(); + } + catch (Exception ex) + { + System.Diagnostics.Debug.WriteLine($"[ILKernel] TryGenerateWhereKernel<{typeof(T).Name}>: {ex.GetType().Name}: {ex.Message}"); + return null; + } + } + + private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmanaged + { + int elementSize = Unsafe.SizeOf(); + + // Determine if we can use SIMD + bool canSimd = elementSize <= 8 && IsSimdSupported(); + + var dm = new DynamicMethod( + name: $"IL_Where_{typeof(T).Name}", + returnType: typeof(void), + parameterTypes: new[] { typeof(bool*), typeof(T*), typeof(T*), typeof(T*), typeof(long) }, + owner: typeof(ILKernelGenerator), + skipVisibility: true + ); + + var il = dm.GetILGenerator(); + + // Locals + var locI = il.DeclareLocal(typeof(long)); // loop counter + + // Labels + var lblScalarLoop = il.DefineLabel(); + var lblScalarLoopEnd = il.DefineLabel(); + + // i = 0 + il.Emit(OpCodes.Ldc_I8, 0L); + il.Emit(OpCodes.Stloc, locI); + + if (canSimd && VectorBits >= 128) + { + // Generate SIMD path + EmitWhereSIMDLoop(il, locI); + } + + // Scalar loop for remainder + il.MarkLabel(lblScalarLoop); + + // if (i >= count) goto end + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Bge, lblScalarLoopEnd); + + // result[i] = cond[i] ? x[i] : y[i] + EmitWhereScalarElement(il, locI); + + // i++ + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, 1L); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblScalarLoop); + + il.MarkLabel(lblScalarLoopEnd); + il.Emit(OpCodes.Ret); + + return (WhereKernel)dm.CreateDelegate(typeof(WhereKernel)); + } + + private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + long vectorCount = VectorBits >= 256 ? (32 / elementSize) : (16 / elementSize); + long unrollFactor = 4; + long unrollStep = vectorCount * unrollFactor; + bool useV256 = VectorBits >= 256; + + var locUnrollEnd = il.DeclareLocal(typeof(long)); + var locVectorEnd = il.DeclareLocal(typeof(long)); + + var lblUnrollLoop = il.DefineLabel(); + var lblUnrollLoopEnd = il.DefineLabel(); + var lblVectorLoop = il.DefineLabel(); + var lblVectorLoopEnd = il.DefineLabel(); + + // unrollEnd = count - unrollStep (for 4x unrolled loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locUnrollEnd); + + // vectorEnd = count - vectorCount (for remainder loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locVectorEnd); + + // ========== 4x UNROLLED SIMD LOOP ========== + il.MarkLabel(lblUnrollLoop); + + // if (i > unrollEnd) goto UnrollLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locUnrollEnd); + il.Emit(OpCodes.Bgt, lblUnrollLoopEnd); + + // Process 4 vectors per iteration + for (long u = 0; u < unrollFactor; u++) + { + long offset = vectorCount * u; + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, offset); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, offset); + } + + // i += unrollStep + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblUnrollLoop); + + il.MarkLabel(lblUnrollLoopEnd); + + // ========== REMAINDER SIMD LOOP (1 vector at a time) ========== + il.MarkLabel(lblVectorLoop); + + // if (i > vectorEnd) goto VectorLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locVectorEnd); + il.Emit(OpCodes.Bgt, lblVectorLoopEnd); + + // Process 1 vector + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, 0L); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, 0L); + + // i += vectorCount + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblVectorLoop); + + il.MarkLabel(lblVectorLoopEnd); + } + + private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + // Get the appropriate mask creation method based on element size + var maskMethod = GetMaskCreationMethod256((int)elementSize); + var loadMethod = typeof(Vector256).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); + var storeMethod = typeof(Vector256).GetMethod("Store", new[] { typeof(Vector256<>).MakeGenericType(typeof(T)), typeof(T*) })!; + var selectMethod = typeof(Vector256).GetMethod("ConditionalSelect", new[] { + typeof(Vector256<>).MakeGenericType(typeof(T)), + typeof(Vector256<>).MakeGenericType(typeof(T)), + typeof(Vector256<>).MakeGenericType(typeof(T)) + })!; + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); // cond + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Call mask creation: returns Vector256 on stack + il.Emit(OpCodes.Call, maskMethod); + + // Load x vector: x + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_1); // x + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector: y + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_2); // y + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Stack: mask, xVec, yVec + // ConditionalSelect(mask, x, y) + il.Emit(OpCodes.Call, selectMethod); + + // Store result: result + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_3); // result + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + var maskMethod = GetMaskCreationMethod128((int)elementSize); + var loadMethod = typeof(Vector128).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); + var storeMethod = typeof(Vector128).GetMethod("Store", new[] { typeof(Vector128<>).MakeGenericType(typeof(T)), typeof(T*) })!; + var selectMethod = typeof(Vector128).GetMethod("ConditionalSelect", new[] { + typeof(Vector128<>).MakeGenericType(typeof(T)), + typeof(Vector128<>).MakeGenericType(typeof(T)), + typeof(Vector128<>).MakeGenericType(typeof(T)) + })!; + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, maskMethod); + + // Load x vector + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // ConditionalSelect + il.Emit(OpCodes.Call, selectMethod); + + // Store + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + var typeCode = GetNPTypeCode(); + + // result[i] = cond[i] ? x[i] : y[i] + var lblFalse = il.DefineLabel(); + var lblEnd = il.DefineLabel(); + + // Load result address: result + i * elementSize + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Load cond[i]: cond + i (bool is 1 byte) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Ldind_U1); // Load bool as byte + + // if (!cond[i]) goto lblFalse + il.Emit(OpCodes.Brfalse, lblFalse); + + // True branch: load x[i] + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + il.Emit(OpCodes.Br, lblEnd); + + // False branch: load y[i] + il.MarkLabel(lblFalse); + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + + il.MarkLabel(lblEnd); + // Stack: result_ptr, value + EmitStoreIndirect(il, typeCode); + } + + private static NPTypeCode GetNPTypeCode() where T : unmanaged + { + if (typeof(T) == typeof(bool)) return NPTypeCode.Boolean; + if (typeof(T) == typeof(byte)) return NPTypeCode.Byte; + if (typeof(T) == typeof(short)) return NPTypeCode.Int16; + if (typeof(T) == typeof(ushort)) return NPTypeCode.UInt16; + if (typeof(T) == typeof(int)) return NPTypeCode.Int32; + if (typeof(T) == typeof(uint)) return NPTypeCode.UInt32; + if (typeof(T) == typeof(long)) return NPTypeCode.Int64; + if (typeof(T) == typeof(ulong)) return NPTypeCode.UInt64; + if (typeof(T) == typeof(char)) return NPTypeCode.Char; + if (typeof(T) == typeof(float)) return NPTypeCode.Single; + if (typeof(T) == typeof(double)) return NPTypeCode.Double; + if (typeof(T) == typeof(decimal)) return NPTypeCode.Decimal; + return NPTypeCode.Empty; + } + + #endregion + + #region Mask Creation Methods + + private static MethodInfo GetMaskCreationMethod256(int elementSize) + { + return elementSize switch + { + 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") + }; + } + + private static MethodInfo GetMaskCreationMethod128(int elementSize) + { + return elementSize switch + { + 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") + }; + } + + /// + /// Create V256 mask from 32 bools for 1-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_1Byte(byte* bools) + { + var vec = Vector256.Load(bools); + var zero = Vector256.Zero; + var isZero = Vector256.Equals(vec, zero); + return Vector256.OnesComplement(isZero); + } + + /// + /// Create V256 mask from 16 bools for 2-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[7] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[8] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[9] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[10] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[11] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[12] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[13] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[14] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[15] != 0 ? (ushort)0xFFFF : (ushort)0 + ); + } + + /// + /// Create V256 mask from 8 bools for 4-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? 0xFFFFFFFFu : 0u, + bools[1] != 0 ? 0xFFFFFFFFu : 0u, + bools[2] != 0 ? 0xFFFFFFFFu : 0u, + bools[3] != 0 ? 0xFFFFFFFFu : 0u, + bools[4] != 0 ? 0xFFFFFFFFu : 0u, + bools[5] != 0 ? 0xFFFFFFFFu : 0u, + bools[6] != 0 ? 0xFFFFFFFFu : 0u, + bools[7] != 0 ? 0xFFFFFFFFu : 0u + ); + } + + /// + /// Create V256 mask from 4 bools for 8-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_8Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[2] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[3] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul + ); + } + + /// + /// Create V128 mask from 16 bools for 1-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_1Byte(byte* bools) + { + var vec = Vector128.Load(bools); + var zero = Vector128.Zero; + var isZero = Vector128.Equals(vec, zero); + return Vector128.OnesComplement(isZero); + } + + /// + /// Create V128 mask from 8 bools for 2-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[7] != 0 ? (ushort)0xFFFF : (ushort)0 + ); + } + + /// + /// Create V128 mask from 4 bools for 4-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? 0xFFFFFFFFu : 0u, + bools[1] != 0 ? 0xFFFFFFFFu : 0u, + bools[2] != 0 ? 0xFFFFFFFFu : 0u, + bools[3] != 0 ? 0xFFFFFFFFu : 0u + ); + } + + /// + /// Create V128 mask from 2 bools for 8-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_8Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul + ); + } + + #endregion + + #region Scalar Fallback + + /// + /// Scalar fallback for where operation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void WhereScalar(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + for (long i = 0; i < count; i++) + { + result[i] = cond[i] ? x[i] : y[i]; + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs new file mode 100644 index 00000000..3fc30d17 --- /dev/null +++ b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs @@ -0,0 +1,515 @@ +using System; +using System.Diagnostics; +using NumSharp.Backends.Kernels; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Backends.Kernels +{ + /// + /// Tests for SIMD-optimized np.where implementation. + /// Verifies correctness of the SIMD path for all supported dtypes. + /// + public class WhereSimdTests + { + #region SIMD Correctness + + [Test] + public void Where_Simd_Float32_Correctness() + { + var rng = np.random.RandomState(42); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size).astype(NPTypeCode.Single); + var y = rng.rand(size).astype(NPTypeCode.Single); + + var result = np.where(cond, x, y); + + // Verify correctness manually + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (float)x[i] : (float)y[i]; + Assert.AreEqual(expected, (float)result[i], 1e-6f, $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Float64_Correctness() + { + var rng = np.random.RandomState(43); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size); + var y = rng.rand(size); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (double)x[i] : (double)y[i]; + Assert.AreEqual(expected, (double)result[i], 1e-10, $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int32_Correctness() + { + var rng = np.random.RandomState(44); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.randint(0, 1000, new[] { size }); + var y = rng.randint(0, 1000, new[] { size }); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (int)x[i] : (int)y[i]; + Assert.AreEqual(expected, (int)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int64_Correctness() + { + var rng = np.random.RandomState(45); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int64); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (long)x[i] : (long)y[i]; + Assert.AreEqual(expected, (long)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Byte_Correctness() + { + var rng = np.random.RandomState(46); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + var y = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (byte)x[i] : (byte)y[i]; + Assert.AreEqual(expected, (byte)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int16_Correctness() + { + var rng = np.random.RandomState(47); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int16); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (short)x[i] : (short)y[i]; + Assert.AreEqual(expected, (short)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt16_Correctness() + { + var rng = np.random.RandomState(48); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt16); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (ushort)x[i] : (ushort)y[i]; + Assert.AreEqual(expected, (ushort)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt32_Correctness() + { + var rng = np.random.RandomState(49); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt32); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (uint)x[i] : (uint)y[i]; + Assert.AreEqual(expected, (uint)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt64_Correctness() + { + var rng = np.random.RandomState(50); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt64); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (ulong)x[i] : (ulong)y[i]; + Assert.AreEqual(expected, (ulong)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Boolean_Correctness() + { + var rng = np.random.RandomState(51); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size) > 0.3; + var y = rng.rand(size) > 0.7; + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (bool)x[i] : (bool)y[i]; + Assert.AreEqual(expected, (bool)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Char_Correctness() + { + var rng = np.random.RandomState(52); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var xData = new char[size]; + var yData = new char[size]; + for (int i = 0; i < size; i++) + { + xData[i] = (char)('A' + (i % 26)); + yData[i] = (char)('a' + (i % 26)); + } + var x = np.array(xData); + var y = np.array(yData); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (char)x[i] : (char)y[i]; + Assert.AreEqual(expected, (char)result[i], $"Mismatch at index {i}"); + } + } + + #endregion + + #region Path Selection + + [Test] + public void Where_NonContiguous_Works() + { + // Sliced arrays are non-contiguous, should work correctly + var baseArr = np.arange(20); + var cond = (baseArr % 2 == 0)["::2"]; // Sliced: [true, true, true, true, true, true, true, true, true, true] + var x = np.ones(10, NPTypeCode.Int32); + var y = np.zeros(10, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + Assert.AreEqual(10, result.size); + // All true -> all from x + for (int i = 0; i < 10; i++) + { + Assert.AreEqual(1, (int)result[i]); + } + } + + [Test] + public void Where_Broadcast_Works() + { + // Broadcasted arrays + // cond shape (3,) broadcasts to (3,3): [[T,F,T],[T,F,T],[T,F,T]] + // x shape (3,1) broadcasts to (3,3): [[1,1,1],[2,2,2],[3,3,3]] + // y shape (1,3) broadcasts to (3,3): [[10,20,30],[10,20,30],[10,20,30]] + var cond = np.array(new[] { true, false, true }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var y = np.array(new int[,] { { 10, 20, 30 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3, 3); + // Verify values: result[i,j] = cond[j] ? x[i,0] : y[0,j] + Assert.AreEqual(1, (int)result[0, 0]); // cond[0]=true -> x=1 + Assert.AreEqual(20, (int)result[0, 1]); // cond[1]=false -> y=20 + Assert.AreEqual(1, (int)result[0, 2]); // cond[2]=true -> x=1 + Assert.AreEqual(2, (int)result[1, 0]); // cond[0]=true -> x=2 + Assert.AreEqual(20, (int)result[1, 1]); // cond[1]=false -> y=20 + } + + [Test] + public void Where_Decimal_Works() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new decimal[] { 1.1m, 2.2m, 3.3m }); + var y = np.array(new decimal[] { 10.1m, 20.2m, 30.3m }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.1m, (decimal)result[0]); + Assert.AreEqual(20.2m, (decimal)result[1]); + Assert.AreEqual(3.3m, (decimal)result[2]); + } + + [Test] + public void Where_NonBoolCondition_Works() + { + // Non-bool condition requires truthiness check + var cond = np.array(new[] { 0, 1, 2, 0 }); // int condition + var result = np.where(cond, 100, -100); + + result.Should().BeOfValues(-100, 100, 100, -100); + } + + #endregion + + #region Edge Cases + + [Test] + public void Where_Simd_SmallArray() + { + // Array smaller than vector width + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 20, 3); + } + + [Test] + public void Where_Simd_VectorAlignedSize() + { + var rng = np.random.RandomState(53); + // Size exactly matches vector width (no scalar tail) + var size = 32; // V256 byte count + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Byte); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + Assert.AreEqual(size, result.size); + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (byte)1 : (byte)0; + Assert.AreEqual(expected, (byte)result[i]); + } + } + + [Test] + public void Where_Simd_WithScalarTail() + { + // Size that requires scalar tail processing + var size = 35; // 32 + 3 tail for bytes + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.full(size, (byte)255); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((byte)255, (byte)result[i], $"Mismatch at {i}"); + } + } + + [Test] + public void Where_Simd_AllTrue() + { + var size = 100; + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((long)i, (long)result[i]); + } + } + + [Test] + public void Where_Simd_AllFalse() + { + var size = 100; + var cond = np.zeros(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(-1L, (long)result[i]); + } + } + + [Test] + public void Where_Simd_Alternating() + { + var size = 100; + var condData = new bool[size]; + for (int i = 0; i < size; i++) + condData[i] = i % 2 == 0; + var cond = np.array(condData); + var x = np.ones(size, NPTypeCode.Int32); + var y = np.zeros(size, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(i % 2 == 0 ? 1 : 0, (int)result[i], $"Mismatch at {i}"); + } + } + + [Test] + public void Where_Simd_NaN_Propagates() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, 2.0 }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // NaN from x + Assert.IsTrue(double.IsNaN((double)result[1])); // NaN from y + Assert.AreEqual(2.0, (double)result[2], 1e-10); + } + + [Test] + public void Where_Simd_Infinity() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { double.PositiveInfinity, 0.0, double.NegativeInfinity, 0.0 }); + var y = np.array(new[] { 0.0, double.PositiveInfinity, 0.0, double.NegativeInfinity }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.PositiveInfinity, (double)result[1]); + Assert.AreEqual(double.NegativeInfinity, (double)result[2]); + Assert.AreEqual(double.NegativeInfinity, (double)result[3]); + } + + #endregion + + #region Performance Sanity Check + + [Test] + public void Where_Simd_LargeArray_Correctness() + { + var rng = np.random.RandomState(54); + var size = 100_000; + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var result = np.where(cond, x, y); + + // Spot check + for (int i = 0; i < 100; i++) + { + var expected = (bool)cond[i] ? 1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + + // Check last few elements (scalar tail) + for (int i = size - 10; i < size; i++) + { + var expected = (bool)cond[i] ? 1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + } + + #endregion + + #region 2D/Multi-dimensional + + [Test] + public void Where_Simd_2D_Contiguous() + { + var rng = np.random.RandomState(55); + // 2D contiguous array should use SIMD + var shape = new[] { 100, 100 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(100, 100); + + // Spot check + for (int i = 0; i < 10; i++) + { + for (int j = 0; j < 10; j++) + { + var expected = (bool)cond[i, j] ? 1 : 0; + Assert.AreEqual(expected, (int)result[i, j]); + } + } + } + + [Test] + public void Where_Simd_3D_Contiguous() + { + var rng = np.random.RandomState(56); + var shape = new[] { 10, 20, 30 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Single); + var y = np.zeros(shape, NPTypeCode.Single); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(10, 20, 30); + + // Spot check + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + for (int k = 0; k < 5; k++) + { + var expected = (bool)cond[i, j, k] ? 1.0f : 0.0f; + Assert.AreEqual(expected, (float)result[i, j, k], 1e-6f); + } + } + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs new file mode 100644 index 00000000..5834d16a --- /dev/null +++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs @@ -0,0 +1,346 @@ +using System; +using System.Linq; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Logic +{ + /// + /// Battle tests for np.where - edge cases, strided arrays, views, etc. + /// + public class np_where_BattleTest + { + #region Strided/Sliced Arrays + + [Test] + public void Where_SlicedCondition() + { + // Sliced condition array + var arr = np.arange(10); + var cond = (arr % 2 == 0)["::2"]; // Every other even check + var x = np.ones(5, NPTypeCode.Int32); + var y = np.zeros(5, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + // Should work with sliced condition + Assert.AreEqual(5, result.size); + } + + [Test] + public void Where_SlicedXY() + { + var cond = np.array(new[] { true, false, true }); + var x = np.arange(6)["::2"]; // [0, 2, 4] + var y = np.arange(6)["1::2"]; // [1, 3, 5] + var result = np.where(cond, x, y); + + result.Should().BeOfValues(0L, 3L, 4L); + } + + [Test] + public void Where_TransposedArrays() + { + var cond = np.array(new bool[,] { { true, false }, { false, true } }).T; + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }).T; + var y = np.array(new int[,] { { 10, 20 }, { 30, 40 } }).T; + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 2); + // After transpose: cond[0,0]=T, cond[0,1]=F, cond[1,0]=F, cond[1,1]=T + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(30, (int)result[0, 1]); + Assert.AreEqual(20, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_ReversedSlice() + { + var cond = np.array(new[] { true, false, true, false, true }); + var x = np.arange(5)["::-1"]; // [4, 3, 2, 1, 0] + var y = np.zeros(5, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(4L, 0L, 2L, 0L, 0L); + } + + #endregion + + #region Complex Broadcasting + + [Test] + public void Where_3Way_Broadcasting() + { + // cond: (2,1,1), x: (1,3,1), y: (1,1,4) -> result: (2,3,4) + var cond = np.array(new bool[,,] { { { true } }, { { false } } }); + var x = np.arange(3).reshape(1, 3, 1); + var y = (np.arange(4) * 10).reshape(1, 1, 4); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 4); + // First "page" (cond=True): values from x broadcast + Assert.AreEqual(0, (long)result[0, 0, 0]); + Assert.AreEqual(0, (long)result[0, 0, 3]); + Assert.AreEqual(2, (long)result[0, 2, 0]); + // Second "page" (cond=False): values from y broadcast + Assert.AreEqual(0, (long)result[1, 0, 0]); + Assert.AreEqual(30, (long)result[1, 0, 3]); + Assert.AreEqual(30, (long)result[1, 2, 3]); + } + + [Test] + public void Where_RowVector_ColVector_Broadcast() + { + // cond: (1,4), x: (3,1), y: scalar -> result: (3,4) + var cond = np.array(new bool[,] { { true, false, true, false } }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var result = np.where(cond, x, 0); + + result.Should().BeShaped(3, 4); + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[1, 0]); + Assert.AreEqual(0, (int)result[1, 1]); + } + + #endregion + + #region Numeric Edge Cases + + [Test] + public void Where_NaN_Values() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, double.NaN }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // from x + Assert.IsTrue(double.IsNaN((double)result[1])); // from y + Assert.IsTrue(double.IsNaN((double)result[2])); // from x + } + + [Test] + public void Where_Infinity_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { double.PositiveInfinity, 1.0 }); + var y = np.array(new[] { 0.0, double.NegativeInfinity }); + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.NegativeInfinity, (double)result[1]); + } + + [Test] + public void Where_MaxMin_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { long.MaxValue, 0L }); + var y = np.array(new[] { 0L, long.MinValue }); + var result = np.where(cond, x, y); + + Assert.AreEqual(long.MaxValue, (long)result[0]); + Assert.AreEqual(long.MinValue, (long)result[1]); + } + + #endregion + + #region Single Arg Edge Cases + + [Test] + public void Where_SingleArg_Float_Truthy() + { + // 0.0 is falsy, anything else (including -0.0, NaN, Inf) is truthy + var arr = np.array(new[] { 0.0, 1.0, -1.0, 0.5, -0.0 }); + var result = np.where(arr); + + // Note: -0.0 == 0.0 in IEEE 754, so it's falsy + result[0].Should().BeOfValues(1L, 2L, 3L); + } + + [Test] + public void Where_SingleArg_NaN_IsTruthy() + { + // NaN is non-zero, so it's truthy + var arr = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(arr); + + result[0].Should().BeOfValues(1L); + } + + [Test] + public void Where_SingleArg_4D() + { + var arr = np.zeros(new[] { 2, 2, 2, 2 }, NPTypeCode.Int32); + arr[0, 1, 0, 1] = 1; + arr[1, 0, 1, 0] = 1; + var result = np.where(arr); + + Assert.AreEqual(4, result.Length); // 4 dimensions + Assert.AreEqual(2, result[0].size); // 2 non-zero elements + } + + #endregion + + #region Performance/Stress Tests + + [Test] + public void Where_LargeArray_Performance() + { + var size = 1_000_000; + var cond = np.random.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var sw = System.Diagnostics.Stopwatch.StartNew(); + var result = np.where(cond, x, y); + sw.Stop(); + + Assert.AreEqual(size, result.size); + // Should complete in reasonable time (< 1 second for 1M elements) + Assert.IsTrue(sw.ElapsedMilliseconds < 1000, $"Took {sw.ElapsedMilliseconds}ms"); + } + + [Test] + public void Where_ManyDimensions() + { + // 6D array + var shape = new[] { 2, 3, 2, 2, 2, 3 }; + var cond = np.ones(shape, NPTypeCode.Boolean); + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 2, 2, 2, 3); + Assert.AreEqual(1, (int)result[0, 0, 0, 0, 0, 0]); + } + + #endregion + + #region Type Conversion Edge Cases + + [Test] + public void Where_UnsignedOverflow() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 255, 0 }); + var y = np.array(new byte[] { 0, 255 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)255, (byte)result[0]); + Assert.AreEqual((byte)255, (byte)result[1]); + } + + [Test] + public void Where_Decimal() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new decimal[] { 1.23456789m, 0m }); + var y = np.array(new decimal[] { 0m, 9.87654321m }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.23456789m, (decimal)result[0]); + Assert.AreEqual(9.87654321m, (decimal)result[1]); + } + + [Test] + public void Where_Char() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new char[] { 'A', 'B', 'C' }); + var y = np.array(new char[] { 'X', 'Y', 'Z' }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(char), result.dtype); + Assert.AreEqual('A', (char)result[0]); + Assert.AreEqual('Y', (char)result[1]); + Assert.AreEqual('C', (char)result[2]); + } + + #endregion + + #region View Behavior + + [Test] + public void Where_ResultIsNewArray_NotView() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + // Modify original, result should not change + x[0] = 999; + Assert.AreEqual(1, (int)result[0], "Result should be independent of x"); + + y[1] = 999; + Assert.AreEqual(20, (int)result[1], "Result should be independent of y"); + } + + [Test] + public void Where_ModifyResult_DoesNotAffectInputs() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + result[0] = 999; + Assert.AreEqual(1, (int)x[0], "x should not be modified"); + Assert.AreEqual(10, (int)y[0], "y should not be modified"); + } + + #endregion + + #region Alternating Patterns + + [Test] + public void Where_Checkerboard_Pattern() + { + // Create checkerboard condition + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = (i + j) % 2 == 0; + + var x = np.ones(new[] { 4, 4 }, NPTypeCode.Int32); + var y = np.zeros(new[] { 4, 4 }, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + // Verify checkerboard pattern + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(0, (int)result[1, 0]); + Assert.AreEqual(1, (int)result[1, 1]); + } + + [Test] + public void Where_StripedPattern() + { + // Every row alternates between all True and all False + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = i % 2 == 0; + + var x = np.full(new[] { 4, 4 }, 1); + var y = np.full(new[] { 4, 4 }, 0); + var result = np.where(cond, x, y); + + // Rows 0, 2 should be 1; rows 1, 3 should be 0 + for (int j = 0; j < 4; j++) + { + Assert.AreEqual(1, (int)result[0, j]); + Assert.AreEqual(0, (int)result[1, j]); + Assert.AreEqual(1, (int)result[2, j]); + Assert.AreEqual(0, (int)result[3, j]); + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs new file mode 100644 index 00000000..d53ed585 --- /dev/null +++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs @@ -0,0 +1,496 @@ +using System; +using System.Linq; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Logic +{ + /// + /// Comprehensive tests for np.where matching NumPy 2.x behavior. + /// + /// NumPy signature: where(condition, x=None, y=None, /) + /// - Single arg: returns np.nonzero(condition) + /// - Three args: element-wise selection with broadcasting + /// + public class np_where_Test + { + #region Single Argument (nonzero equivalent) + + [Test] + public void Where_SingleArg_1D_ReturnsIndices() + { + // np.where([0, 1, 0, 2, 0, 3]) -> (array([1, 3, 5]),) + var arr = np.array(new[] { 0, 1, 0, 2, 0, 3 }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + result[0].Should().BeOfValues(1L, 3L, 5L); + } + + [Test] + public void Where_SingleArg_2D_ReturnsTupleOfIndices() + { + // np.where([[0, 1, 0], [2, 0, 3]]) -> (array([0, 1, 1]), array([1, 0, 2])) + var arr = np.array(new int[,] { { 0, 1, 0 }, { 2, 0, 3 } }); + var result = np.where(arr); + + Assert.AreEqual(2, result.Length); + result[0].Should().BeOfValues(0L, 1L, 1L); // row indices + result[1].Should().BeOfValues(1L, 0L, 2L); // col indices + } + + [Test] + public void Where_SingleArg_Boolean_ReturnsNonzero() + { + var arr = np.array(new[] { true, false, true, false, true }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + result[0].Should().BeOfValues(0L, 2L, 4L); + } + + [Test] + public void Where_SingleArg_Empty_ReturnsEmptyIndices() + { + var arr = np.array(new int[0]); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + Assert.AreEqual(0, result[0].size); + } + + [Test] + public void Where_SingleArg_AllFalse_ReturnsEmptyIndices() + { + var arr = np.array(new[] { false, false, false }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + Assert.AreEqual(0, result[0].size); + } + + [Test] + public void Where_SingleArg_AllTrue_ReturnsAllIndices() + { + var arr = np.array(new[] { true, true, true }); + var result = np.where(arr); + + result[0].Should().BeOfValues(0L, 1L, 2L); + } + + [Test] + public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays() + { + // 2x2x2 array with some non-zero elements + var arr = np.zeros(new[] { 2, 2, 2 }, NPTypeCode.Int32); + arr[0, 0, 1] = 1; + arr[1, 1, 0] = 1; + var result = np.where(arr); + + Assert.AreEqual(3, result.Length); + result[0].Should().BeOfValues(0L, 1L); // dim 0 + result[1].Should().BeOfValues(0L, 1L); // dim 1 + result[2].Should().BeOfValues(1L, 0L); // dim 2 + } + + #endregion + + #region Three Arguments (element-wise selection) + + [Test] + public void Where_ThreeArgs_Basic_SelectsCorrectly() + { + // np.where(a < 5, a, 10*a) for a = arange(10) + var a = np.arange(10); + var result = np.where(a < 5, a, 10 * a); + + result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); + } + + [Test] + public void Where_ThreeArgs_BooleanCondition() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { 1, 2, 3, 4 }); + var y = np.array(new[] { 10, 20, 30, 40 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 20, 3, 40); + } + + [Test] + public void Where_ThreeArgs_2D() + { + // np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]]) + var cond = np.array(new bool[,] { { true, false }, { true, true } }); + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }); + var y = np.array(new int[,] { { 9, 8 }, { 7, 6 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 2); + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(8, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy() + { + // np.where([0, 1, 2, 0], 100, -100) -> [-100, 100, 100, -100] + var cond = np.array(new[] { 0, 1, 2, 0 }); + var result = np.where(cond, 100, -100); + + result.Should().BeOfValues(-100, 100, 100, -100); + } + + #endregion + + #region Scalar Arguments + + [Test] + public void Where_ScalarX() + { + var cond = np.array(new[] { true, false, true, false }); + var y = np.array(new[] { 10, 20, 30, 40 }); + var result = np.where(cond, 99, y); + + result.Should().BeOfValues(99, 20, 99, 40); + } + + [Test] + public void Where_ScalarY() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { 1, 2, 3, 4 }); + var result = np.where(cond, x, -1); + + result.Should().BeOfValues(1, -1, 3, -1); + } + + [Test] + public void Where_BothScalars() + { + var cond = np.array(new[] { true, false, true, false }); + var result = np.where(cond, 1, 0); + + result.Should().BeOfValues(1, 0, 1, 0); + } + + [Test] + public void Where_ScalarFloat() + { + var cond = np.array(new[] { true, false }); + var result = np.where(cond, 1.5, 2.5); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.5, (double)result[0], 1e-10); + Assert.AreEqual(2.5, (double)result[1], 1e-10); + } + + #endregion + + #region Broadcasting + + [Test] + public void Where_Broadcasting_ScalarY() + { + // np.where(a < 4, a, -1) for 3x3 array + var arr = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } }); + var result = np.where(arr < 4, arr, -1); + + result.Should().BeShaped(3, 3); + Assert.AreEqual(0, (int)result[0, 0]); + Assert.AreEqual(1, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[0, 2]); + Assert.AreEqual(-1, (int)result[1, 2]); + Assert.AreEqual(-1, (int)result[2, 2]); + } + + [Test] + public void Where_Broadcasting_DifferentShapes() + { + // cond: (2,1), x: (3,), y: (1,3) -> result: (2,3) + var cond = np.array(new bool[,] { { true }, { false } }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new int[,] { { 10, 20, 30 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3); + // Row 0: cond=True, so x values + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(2, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[0, 2]); + // Row 1: cond=False, so y values + Assert.AreEqual(10, (int)result[1, 0]); + Assert.AreEqual(20, (int)result[1, 1]); + Assert.AreEqual(30, (int)result[1, 2]); + } + + [Test] + public void Where_Broadcasting_ColumnVector() + { + // cond: (3,1), x: scalar, y: (1,4) -> result: (3,4) + var cond = np.array(new bool[,] { { true }, { false }, { true } }); + var x = 1; + var y = np.array(new int[,] { { 10, 20, 30, 40 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3, 4); + // Row 0: all 1s + for (int j = 0; j < 4; j++) + Assert.AreEqual(1, (int)result[0, j]); + // Row 1: y values + Assert.AreEqual(10, (int)result[1, 0]); + Assert.AreEqual(40, (int)result[1, 3]); + // Row 2: all 1s + for (int j = 0; j < 4; j++) + Assert.AreEqual(1, (int)result[2, j]); + } + + #endregion + + #region Type Promotion + + [Test] + public void Where_TypePromotion_IntFloat_ReturnsFloat64() + { + var cond = np.array(new[] { true, false }); + var result = np.where(cond, 1, 2.5); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.0, (double)result[0], 1e-10); + Assert.AreEqual(2.5, (double)result[1], 1e-10); + } + + [Test] + public void Where_TypePromotion_Int32Int64_ReturnsInt64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1 }); + var y = np.array(new long[] { 2 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(long), result.dtype); + } + + [Test] + public void Where_TypePromotion_FloatDouble_ReturnsDouble() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new float[] { 1.5f }); + var y = np.array(new double[] { 2.5 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(double), result.dtype); + } + + #endregion + + #region Edge Cases + + [Test] + public void Where_EmptyArrays_ThreeArgs() + { + var cond = np.array(new bool[0]); + var x = np.array(new int[0]); + var y = np.array(new int[0]); + var result = np.where(cond, x, y); + + Assert.AreEqual(0, result.size); + } + + [Test] + public void Where_SingleElement() + { + var cond = np.array(new[] { true }); + var result = np.where(cond, 42, 0); + + Assert.AreEqual(1, result.size); + Assert.AreEqual(42, (int)result[0]); + } + + [Test] + public void Where_AllTrue_ReturnsAllX() + { + var cond = np.array(new[] { true, true, true }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 2, 3); + } + + [Test] + public void Where_AllFalse_ReturnsAllY() + { + var cond = np.array(new[] { false, false, false }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(10, 20, 30); + } + + [Test] + public void Where_LargeArray() + { + var size = 100000; + var cond = np.arange(size) % 2 == 0; // alternating True/False + var x = np.ones(size, NPTypeCode.Int32); + var y = np.zeros(size, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + Assert.AreEqual(size, result.size); + // Even indices should be 1, odd should be 0 + Assert.AreEqual(1, (int)result[0]); + Assert.AreEqual(0, (int)result[1]); + Assert.AreEqual(1, (int)result[2]); + } + + #endregion + + #region NumPy Output Verification + + [Test] + public void Where_NumPyExample1() + { + // From NumPy docs: np.where([[True, False], [True, True]], + // [[1, 2], [3, 4]], [[9, 8], [7, 6]]) + // Expected: array([[1, 8], [3, 4]]) + var cond = np.array(new bool[,] { { true, false }, { true, true } }); + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }); + var y = np.array(new int[,] { { 9, 8 }, { 7, 6 } }); + var result = np.where(cond, x, y); + + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(8, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_NumPyExample2() + { + // From NumPy docs: np.where(a < 5, a, 10*a) for a = arange(10) + // Expected: array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90]) + var a = np.arange(10); + var result = np.where(a < 5, a, 10 * a); + + result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); + } + + [Test] + public void Where_NumPyExample3() + { + // From NumPy docs: np.where(a < 4, a, -1) for specific array + // Expected: array([[ 0, 1, 2], [ 0, 2, -1], [ 0, 3, -1]]) + var a = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } }); + var result = np.where(a < 4, a, -1); + + Assert.AreEqual(0, (int)result[0, 0]); + Assert.AreEqual(1, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[0, 2]); + Assert.AreEqual(0, (int)result[1, 0]); + Assert.AreEqual(2, (int)result[1, 1]); + Assert.AreEqual(-1, (int)result[1, 2]); + Assert.AreEqual(0, (int)result[2, 0]); + Assert.AreEqual(3, (int)result[2, 1]); + Assert.AreEqual(-1, (int)result[2, 2]); + } + + #endregion + + #region Dtype Coverage + + [Test] + public void Where_Dtype_Byte() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 1, 2 }); + var y = np.array(new byte[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(byte), result.dtype); + result.Should().BeOfValues((byte)1, (byte)20); + } + + [Test] + public void Where_Dtype_Int16() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new short[] { 1, 2 }); + var y = np.array(new short[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(short), result.dtype); + result.Should().BeOfValues((short)1, (short)20); + } + + [Test] + public void Where_Dtype_Int32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1, 2 }); + var y = np.array(new int[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(int), result.dtype); + result.Should().BeOfValues(1, 20); + } + + [Test] + public void Where_Dtype_Int64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new long[] { 1, 2 }); + var y = np.array(new long[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(long), result.dtype); + result.Should().BeOfValues(1L, 20L); + } + + [Test] + public void Where_Dtype_Single() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new float[] { 1.5f, 2.5f }); + var y = np.array(new float[] { 10.5f, 20.5f }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(float), result.dtype); + Assert.AreEqual(1.5f, (float)result[0], 1e-6f); + Assert.AreEqual(20.5f, (float)result[1], 1e-6f); + } + + [Test] + public void Where_Dtype_Double() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new double[] { 1.5, 2.5 }); + var y = np.array(new double[] { 10.5, 20.5 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.5, (double)result[0], 1e-10); + Assert.AreEqual(20.5, (double)result[1], 1e-10); + } + + [Test] + public void Where_Dtype_Boolean() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new bool[] { true, true }); + var y = np.array(new bool[] { false, false }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(bool), result.dtype); + Assert.IsTrue((bool)result[0]); + Assert.IsFalse((bool)result[1]); + } + + #endregion + } +} From c335e0a27f770c72a419852058707f86c23367ff Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 14:11:23 +0300 Subject: [PATCH 02/19] perf(where): AVX2/SSE4.1 optimize mask expansion in np.where kernel Replace scalar conditional mask creation with SIMD intrinsics: V256 mask creation (for AVX2): - 8-byte elements: Avx2.ConvertToVector256Int64 (vpmovzxbq) - 4-byte elements: Avx2.ConvertToVector256Int32 (vpmovzxbd) - 2-byte elements: Avx2.ConvertToVector256Int16 (vpmovzxbw) V128 mask creation (for SSE4.1): - 8-byte elements: Sse41.ConvertToVector128Int64 (pmovzxbq) - 4-byte elements: Sse41.ConvertToVector128Int32 (pmovzxbd) - 2-byte elements: Sse41.ConvertToVector128Int16 (pmovzxbw) Each intrinsic replaces 4-16 scalar conditionals with a single zero-extend + compare instruction sequence. Also fixes reflection lookups for Vector256/Vector128.Load, Store, and ConditionalSelect methods that were failing because these are generic method definitions requiring special handling. Performance (1M double elements): - Kernel: 2.6ms @ 381 M elements/ms - NumPy baseline: ~1.86ms - Ratio: ~1.4x slower (down from ~3x before optimization) All 12 dtypes supported with fallback for non-AVX2/SSE4.1 systems. --- .../Kernels/ILKernelGenerator.Where.cs | 106 +++++++++++++++--- 1 file changed, 92 insertions(+), 14 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index e055bd8a..446755ec 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -4,6 +4,7 @@ using System.Reflection.Emit; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; // ============================================================================= // ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels @@ -253,13 +254,17 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder { // Get the appropriate mask creation method based on element size var maskMethod = GetMaskCreationMethod256((int)elementSize); - var loadMethod = typeof(Vector256).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); - var storeMethod = typeof(Vector256).GetMethod("Store", new[] { typeof(Vector256<>).MakeGenericType(typeof(T)), typeof(T*) })!; - var selectMethod = typeof(Vector256).GetMethod("ConditionalSelect", new[] { - typeof(Vector256<>).MakeGenericType(typeof(T)), - typeof(Vector256<>).MakeGenericType(typeof(T)), - typeof(Vector256<>).MakeGenericType(typeof(T)) - })!; + + // Get Vector256 methods via reflection - need to find generic method definitions first + var loadMethod = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! + .MakeGenericMethod(typeof(T)); + var storeMethod = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! + .MakeGenericMethod(typeof(T)); + var selectMethod = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! + .MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); // cond @@ -325,13 +330,17 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { var maskMethod = GetMaskCreationMethod128((int)elementSize); - var loadMethod = typeof(Vector128).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); - var storeMethod = typeof(Vector128).GetMethod("Store", new[] { typeof(Vector128<>).MakeGenericType(typeof(T)), typeof(T*) })!; - var selectMethod = typeof(Vector128).GetMethod("ConditionalSelect", new[] { - typeof(Vector128<>).MakeGenericType(typeof(T)), - typeof(Vector128<>).MakeGenericType(typeof(T)), - typeof(Vector128<>).MakeGenericType(typeof(T)) - })!; + + // Get Vector128 methods via reflection - need to find generic method definitions first + var loadMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! + .MakeGenericMethod(typeof(T)); + var storeMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! + .MakeGenericMethod(typeof(T)); + var selectMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! + .MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); @@ -502,10 +511,22 @@ private static unsafe Vector256 CreateMaskV256_1Byte(byte* bools) /// /// Create V256 mask from 16 bools for 2-byte elements. + /// Uses AVX2 vpmovzxbw instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 16 bytes into Vector128, zero-extend each byte to 16-bit + // vpmovzxbw: byte -> word (16 bytes -> 16 words) + var bytes128 = Vector128.Load(bools); + var expanded = Avx2.ConvertToVector256Int16(bytes128).AsUInt16(); + // Compare with zero: non-zero becomes 0xFFFF, zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, @@ -528,10 +549,22 @@ private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) /// /// Create V256 mask from 8 bools for 4-byte elements. + /// Uses AVX2 vpmovzxbd instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 8 bytes into low bytes of Vector128, zero-extend each byte to 32-bit + // vpmovzxbd: byte -> dword (8 bytes -> 8 dwords) + var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); + var expanded = Avx2.ConvertToVector256Int32(bytes128).AsUInt32(); + // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? 0xFFFFFFFFu : 0u, bools[1] != 0 ? 0xFFFFFFFFu : 0u, @@ -546,10 +579,22 @@ private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) /// /// Create V256 mask from 4 bools for 8-byte elements. + /// Uses AVX2 vpmovzxbq instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_8Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 4 bytes into low bytes of Vector128, zero-extend each byte to 64-bit + // vpmovzxbq: byte -> qword (4 bytes -> 4 qwords) + var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); + var expanded = Avx2.ConvertToVector256Int64(bytes128).AsUInt64(); + // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, @@ -572,10 +617,21 @@ private static unsafe Vector128 CreateMaskV128_1Byte(byte* bools) /// /// Create V128 mask from 8 bools for 2-byte elements. + /// Uses SSE4.1 pmovzxbw instruction for efficient expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 8 bytes, zero-extend each to 16-bit + // pmovzxbw: byte -> word (8 bytes -> 8 words) + var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int16(bytes128).AsUInt16(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, @@ -590,10 +646,21 @@ private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) /// /// Create V128 mask from 4 bools for 4-byte elements. + /// Uses SSE4.1 pmovzxbd instruction for efficient expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 4 bytes, zero-extend each to 32-bit + // pmovzxbd: byte -> dword (4 bytes -> 4 dwords) + var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int32(bytes128).AsUInt32(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? 0xFFFFFFFFu : 0u, bools[1] != 0 ? 0xFFFFFFFFu : 0u, @@ -604,10 +671,21 @@ private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) /// /// Create V128 mask from 2 bools for 8-byte elements. + /// Uses SSE4.1 pmovzxbq instruction for efficient expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_8Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 2 bytes, zero-extend each to 64-bit + // pmovzxbq: byte -> qword (2 bytes -> 2 qwords) + var bytes128 = Vector128.CreateScalar(*(ushort*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int64(bytes128).AsUInt64(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul From 753d753f17ca68ed1097301fe19cc262d9a47eb7 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 14:48:20 +0300 Subject: [PATCH 03/19] perf(where): inline mask creation in IL - 5.4x faster kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of emitting Call opcodes to mask helper methods, now emit the AVX2/SSE4.1 instructions directly inline in the IL stream. This eliminates: - Method call overhead (~12% per call) - Runtime Avx2.IsSupported checks in hot path - JIT optimization barriers at call boundaries The IL now emits the full mask creation sequence: - 8-byte: ldind.u4 → CreateScalar → AsByte → ConvertToVector256Int64 → AsUInt64 → GreaterThan - 4-byte: ldind.i8 → CreateScalar → AsByte → ConvertToVector256Int32 → AsUInt32 → GreaterThan - 2-byte: Load → ConvertToVector256Int16 → AsUInt16 → GreaterThan - 1-byte: Load → GreaterThan (direct comparison) Performance (1M double elements): - Previous (method call): 2.6 ms - Inlined IL: 0.48 ms (5.4x faster) - NumPy baseline: 1.86 ms (NumSharp is now 3.9x FASTER) Fixed reflection lookups for AsByte/AsUInt* which are extension methods on Vector128/Vector256 static classes, not instance methods. --- .../Kernels/ILKernelGenerator.Where.cs | 237 +++++++++++++++++- 1 file changed, 229 insertions(+), 8 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index 446755ec..1b4eb4b0 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -252,9 +252,6 @@ private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) wher private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - // Get the appropriate mask creation method based on element size - var maskMethod = GetMaskCreationMethod256((int)elementSize); - // Get Vector256 methods via reflection - need to find generic method definitions first var loadMethod = Array.Find(typeof(Vector256).GetMethods(), m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! @@ -277,8 +274,8 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder il.Emit(OpCodes.Conv_I); il.Emit(OpCodes.Add); - // Call mask creation: returns Vector256 on stack - il.Emit(OpCodes.Call, maskMethod); + // Inline mask creation - emit AVX2 instructions directly instead of calling helper + EmitInlineMaskCreationV256(il, (int)elementSize); // Load x vector: x + (i + offset) * elementSize il.Emit(OpCodes.Ldarg_1); // x @@ -329,8 +326,6 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - var maskMethod = GetMaskCreationMethod128((int)elementSize); - // Get Vector128 methods via reflection - need to find generic method definitions first var loadMethod = Array.Find(typeof(Vector128).GetMethods(), m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! @@ -352,7 +347,9 @@ private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder } il.Emit(OpCodes.Conv_I); il.Emit(OpCodes.Add); - il.Emit(OpCodes.Call, maskMethod); + + // Inline mask creation - emit SSE4.1 instructions directly + EmitInlineMaskCreationV128(il, (int)elementSize); // Load x vector il.Emit(OpCodes.Ldarg_1); @@ -497,6 +494,230 @@ private static MethodInfo GetMaskCreationMethod128(int elementSize) }; } + #endregion + + #region Inline Mask IL Emission + + // Cache reflection lookups for inline emission + private static readonly MethodInfo _v128LoadByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + private static readonly MethodInfo _v256LoadByte = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly MethodInfo _v128CreateScalarUInt = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128CreateScalarULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128CreateScalarUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + + // AsByte is an extension method on Vector128 static class, not instance method + private static readonly MethodInfo _v128UIntAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128ULongAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128UShortAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + + private static readonly MethodInfo _avx2ConvertToV256Int64 = typeof(Avx2).GetMethod("ConvertToVector256Int64", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _avx2ConvertToV256Int32 = typeof(Avx2).GetMethod("ConvertToVector256Int32", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _avx2ConvertToV256Int16 = typeof(Avx2).GetMethod("ConvertToVector256Int16", new[] { typeof(Vector128) })!; + + private static readonly MethodInfo _sse41ConvertToV128Int64 = typeof(Sse41).GetMethod("ConvertToVector128Int64", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _sse41ConvertToV128Int32 = typeof(Sse41).GetMethod("ConvertToVector128Int32", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _sse41ConvertToV128Int16 = typeof(Sse41).GetMethod("ConvertToVector128Int16", new[] { typeof(Vector128) })!; + + // As* methods are extension methods on Vector256/Vector128 static classes + private static readonly MethodInfo _v256LongAsULong = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); + private static readonly MethodInfo _v256IntAsUInt = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); + private static readonly MethodInfo _v256ShortAsUShort = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); + + private static readonly MethodInfo _v128LongAsULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); + private static readonly MethodInfo _v128IntAsUInt = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); + private static readonly MethodInfo _v128ShortAsUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); + + private static readonly MethodInfo _v256GreaterThanULong = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v256GreaterThanUInt = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v256GreaterThanUShort = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + private static readonly MethodInfo _v256GreaterThanByte = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly MethodInfo _v128GreaterThanULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128GreaterThanUInt = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128GreaterThanUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly FieldInfo _v256ZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!.IsStatic + ? null! : null!; // Use GetMethod call instead + private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroByte = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + + private static readonly MethodInfo _v128GetZeroULong = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroUInt = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroUShort = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroByte = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + + /// + /// Emit inline V256 mask creation. Stack: byte* -> Vector256{T} (as mask) + /// + private static void EmitInlineMaskCreationV256(ILGenerator il, int elementSize) + { + // Stack has: byte* pointing to condition bools + + switch (elementSize) + { + case 8: // double/long: load 4 bytes, expand to 4 qwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, _v128UIntAsByte); + // Avx2.ConvertToVector256Int64(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, _v256LongAsULong); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroULong); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanULong); + break; + + case 4: // float/int: load 8 bytes, expand to 8 dwords + // *(ulong*)ptr + il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, _v128ULongAsByte); + // Avx2.ConvertToVector256Int32(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, _v256IntAsUInt); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroUInt); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanUInt); + break; + + case 2: // short/char: load 16 bytes, expand to 16 words + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, _v128LoadByte); + // Avx2.ConvertToVector256Int16(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, _v256ShortAsUShort); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroUShort); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanUShort); + break; + + case 1: // byte/bool: load 32 bytes, compare directly + // Vector256.Load(ptr) + il.Emit(OpCodes.Call, _v256LoadByte); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroByte); + // Vector256.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, _v256GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + /// + /// Emit inline V128 mask creation. Stack: byte* -> Vector128{T} (as mask) + /// + private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) + { + switch (elementSize) + { + case 8: // double/long: load 2 bytes, expand to 2 qwords + // *(ushort*)ptr + il.Emit(OpCodes.Ldind_U2); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUShort); + // .AsByte() + il.Emit(OpCodes.Call, _v128UShortAsByte); + // Sse41.ConvertToVector128Int64(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, _v128LongAsULong); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroULong); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanULong); + break; + + case 4: // float/int: load 4 bytes, expand to 4 dwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, _v128UIntAsByte); + // Sse41.ConvertToVector128Int32(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, _v128IntAsUInt); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroUInt); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanUInt); + break; + + case 2: // short/char: load 8 bytes, expand to 8 words + // *(ulong*)ptr + il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, _v128ULongAsByte); + // Sse41.ConvertToVector128Int16(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, _v128ShortAsUShort); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroUShort); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanUShort); + break; + + case 1: // byte/bool: load 16 bytes, compare directly + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, _v128LoadByte); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroByte); + // Vector128.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, _v128GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + #endregion + + #region Static Mask Creation Methods (fallback) + /// /// Create V256 mask from 32 bools for 1-byte elements. /// From 25859e573844f03c67b32141581f64054aba0b46 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 16:15:27 +0300 Subject: [PATCH 04/19] fix(where): implement NumPy 2.x NEP50 type promotion for np.where MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NumPy 2.x NEP50 "weak scalar" semantics for np.where, detecting scalar NDArrays via Shape.IsScalar for clean type promotion without requiring per-type overloads. TYPE PROMOTION RULES: 1. Same-type scalars: preserve type - int + int → int32 (both same type, preserve) - byte + byte → byte - float + float → float32 2. Mixed-type scalars: use array-array promotion - int + long → int64 - int + double → float64 - byte + short → int16 3. NEP50 weak scalar: scalar + array → array dtype wins - int scalar + uint8 array → uint8 - int scalar + float32 array → float32 4. Cross-kind promotion uses standard rules - float scalar + int32 array → float64 IMPLEMENTATION: - Simplified to 4 overloads (NDArray, object+NDArray, NDArray+object, object+object) - Detect scalar NDArrays via Shape.IsScalar (works for both implicit conversion and explicit np.array() calls) - Input arrays converted to output dtype before kernel/iterator dispatch NOTE: Unlike NumPy where Python int literals widen to int64, C# int literals create int32 scalar NDArrays indistinguishable from explicit np.array(1, dtype=int32). We preserve same-type scalars rather than widening, which is consistent with C#'s typed literal semantics. --- src/NumSharp.Core/APIs/np.where.cs | 144 +++++- .../Logic/np.where.BattleTest.cs | 423 +++++++++++++++++- test/NumSharp.UnitTest/Logic/np.where.Test.cs | 1 + 3 files changed, 549 insertions(+), 19 deletions(-) diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs index a361534a..a7cbf129 100644 --- a/src/NumSharp.Core/APIs/np.where.cs +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -26,6 +26,54 @@ public static NDArray[] where(NDArray condition) /// An array with elements from `x` where `condition` is True, and elements from `y` elsewhere. /// https://numpy.org/doc/stable/reference/generated/numpy.where.html public static NDArray where(NDArray condition, NDArray x, NDArray y) + { + // Detect scalar NDArrays (from implicit primitive conversion or explicit NDArray.Scalar) + // Scalar NDArrays use NEP50 weak scalar type promotion rules + bool xIsScalar = x.Shape.IsScalar; + bool yIsScalar = y.Shape.IsScalar; + return where_internal(condition, x, y, xIsScalar, yIsScalar); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for x. + /// + public static NDArray where(NDArray condition, object x, NDArray y) + { + var xArr = asanyarray(x); + return where_internal(condition, xArr, y, xArr.Shape.IsScalar, y.Shape.IsScalar); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for y. + /// + public static NDArray where(NDArray condition, NDArray x, object y) + { + var yArr = asanyarray(y); + return where_internal(condition, x, yArr, x.Shape.IsScalar, yArr.Shape.IsScalar); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for both x and y. + /// + public static NDArray where(NDArray condition, object x, object y) + { + var xArr = asanyarray(x); + var yArr = asanyarray(y); + return where_internal(condition, xArr, yArr, xArr.Shape.IsScalar, yArr.Shape.IsScalar); + } + + /// + /// Internal implementation of np.where with scalar tracking for NEP50 type promotion. + /// + /// Condition array + /// X values (already converted to NDArray) + /// Y values (already converted to NDArray) + /// True if x is a scalar NDArray + /// True if y is a scalar NDArray + private static NDArray where_internal(NDArray condition, NDArray x, NDArray y, bool xIsScalar, bool yIsScalar) { // Broadcast all three arrays to common shape var broadcasted = broadcast_arrays(condition, x, y); @@ -33,8 +81,15 @@ public static NDArray where(NDArray condition, NDArray x, NDArray y) var xArr = broadcasted[1]; var yArr = broadcasted[2]; - // Determine output dtype from x and y (type promotion) - var outType = _FindCommonType(xArr, yArr); + // Determine output dtype from x and y using NEP50-aware type promotion + var outType = _FindCommonTypeForWhere(x.GetTypeCode, y.GetTypeCode, xIsScalar, yIsScalar); + + // Convert x and y to output type if needed (required for kernel and iterator paths) + if (xArr.GetTypeCode != outType) + xArr = xArr.astype(outType, copy: false); + if (yArr.GetTypeCode != outType) + yArr = yArr.astype(outType, copy: false); + // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides) var result = empty(cond.shape, outType); @@ -103,30 +158,91 @@ public static NDArray where(NDArray condition, NDArray x, NDArray y) } /// - /// Return elements chosen from `x` or `y` depending on `condition`. - /// Scalar overload for x. + /// Determines the output dtype for np.where following NumPy 2.x NEP50 rules. + /// + /// Rules: + /// 1. Both arrays (non-scalar): use array-array promotion table + /// 2. Both were scalars: use Python-like defaults (int32→int64) + /// 3. One array, one scalar: use NEP50 weak scalar rules (array dtype wins for same-kind) /// - public static NDArray where(NDArray condition, object x, NDArray y) + private static NPTypeCode _FindCommonTypeForWhere(NPTypeCode xType, NPTypeCode yType, bool xIsScalar, bool yIsScalar) { - return where(condition, asanyarray(x), y); + // Case 1: Both are scalars - use Python-like default type widening + if (xIsScalar && yIsScalar) + { + return _GetPythonLikeScalarType(xType, yType); + } + + // Case 2: One is scalar, one is array - use NEP50 weak scalar rules + if (xIsScalar) + { + // y is array, x is scalar - array wins for same-kind + return _FindCommonArrayScalarType(yType, xType); + } + if (yIsScalar) + { + // x is array, y is scalar - array wins for same-kind + return _FindCommonArrayScalarType(xType, yType); + } + + // Case 3: Both are arrays - use array-array promotion + return _FindCommonArrayType(xType, yType); } /// - /// Return elements chosen from `x` or `y` depending on `condition`. - /// Scalar overload for y. + /// Determines the result type when both operands are scalar NDArrays. + /// + /// C# limitation: We cannot distinguish between: + /// - `np.where(cond, 1, 0)` where 1,0 are C# int literals (implicit conversion) + /// - `np.where(cond, np.array(1), np.array(0))` where arrays are explicitly created + /// + /// Both cases create int32 scalar NDArrays. We preserve the type when both + /// scalars are the same type, and use NEP50 weak scalar rules otherwise. + /// This differs from NumPy where Python int literals widen to int64. /// - public static NDArray where(NDArray condition, NDArray x, object y) + private static NPTypeCode _GetPythonLikeScalarType(NPTypeCode xType, NPTypeCode yType) { - return where(condition, x, asanyarray(y)); + // Same type: preserve it (no widening) + // This handles np.where(cond, 1, 0) → int32, np.where(cond, 1L, 0L) → int64 + if (xType == yType) + return xType; + + // Different types - apply promotion rules + var xKind = GetTypeKind(xType); + var yKind = GetTypeKind(yType); + + // Cross-kind promotion: use standard array-array rules + if (xKind != yKind) + { + return _FindCommonArrayType(xType, yType); + } + + // Same kind, different types - use array-array promotion + return _FindCommonArrayType(xType, yType); } /// - /// Return elements chosen from `x` or `y` depending on `condition`. - /// Scalar overload for both x and y. + /// Returns the kind character for a type (matching NumPy's dtype.kind). /// - public static NDArray where(NDArray condition, object x, object y) + private static char GetTypeKind(NPTypeCode type) { - return where(condition, asanyarray(x), asanyarray(y)); + return type switch + { + NPTypeCode.Boolean => 'b', + NPTypeCode.Byte => 'u', + NPTypeCode.UInt16 => 'u', + NPTypeCode.UInt32 => 'u', + NPTypeCode.UInt64 => 'u', + NPTypeCode.Int16 => 'i', + NPTypeCode.Int32 => 'i', + NPTypeCode.Int64 => 'i', + NPTypeCode.Char => 'u', // char is essentially uint16 + NPTypeCode.Single => 'f', + NPTypeCode.Double => 'f', + NPTypeCode.Decimal => 'f', // treat decimal as float-like + NPTypeCode.Complex => 'c', + _ => '?' + }; } private static void WhereImpl(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs index 5834d16a..5afef228 100644 --- a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs +++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs @@ -8,6 +8,25 @@ namespace NumSharp.UnitTest.Logic { /// /// Battle tests for np.where - edge cases, strided arrays, views, etc. + /// + /// These tests verify NumSharp behavior against NumPy 2.4.2. + /// + /// KNOWN DIFFERENCES FROM NUMPY 2.x: + /// + /// 1. Scalar Type Promotion (NEP50): + /// NumPy 2.x uses "weak scalar" semantics where Python scalars adopt array dtype. + /// NumSharp uses C# semantics where literals have fixed types (int=int32, etc). + /// + /// Example: np.where(cond, 1, uint8_array) + /// - NumPy 2.x: returns uint8 (weak scalar rule) + /// - NumSharp: returns int32 (C# int literal is int32) + /// + /// 2. Python int Scalar Default: + /// - NumPy: Python int → int64 (platform default) + /// - NumSharp: C# int literal → int32 + /// + /// 3. Missing sbyte (int8) support: + /// NumSharp does not support sbyte arrays (throws NotSupportedException). /// public class np_where_BattleTest { @@ -16,15 +35,15 @@ public class np_where_BattleTest [Test] public void Where_SlicedCondition() { - // Sliced condition array + // Sliced condition array (non-contiguous) var arr = np.arange(10); - var cond = (arr % 2 == 0)["::2"]; // Every other even check + var cond = (arr % 2 == 0)["::2"]; // Every other even check: [T,T,T,T,T] var x = np.ones(5, NPTypeCode.Int32); var y = np.zeros(5, NPTypeCode.Int32); var result = np.where(cond, x, y); - // Should work with sliced condition Assert.AreEqual(5, result.size); + result.Should().BeOfValues(1, 1, 1, 1, 1); } [Test] @@ -62,6 +81,7 @@ public void Where_ReversedSlice() var y = np.zeros(5, NPTypeCode.Int64); var result = np.where(cond, x, y); + // NumPy: [4, 0, 2, 0, 0] result.Should().BeOfValues(4L, 0L, 2L, 0L, 0L); } @@ -104,6 +124,91 @@ public void Where_RowVector_ColVector_Broadcast() Assert.AreEqual(0, (int)result[1, 1]); } + [Test] + public void Where_ScalarCondition_True() + { + // NumPy: np.where(True, [1,2,3], [4,5,6]) -> [1,2,3] + var result = np.where(np.array(true), np.array(new[] { 1, 2, 3 }), np.array(new[] { 4, 5, 6 })); + result.Should().BeOfValues(1, 2, 3); + } + + [Test] + public void Where_ScalarCondition_False() + { + // NumPy: np.where(False, [1,2,3], [4,5,6]) -> [4,5,6] + var result = np.where(np.array(false), np.array(new[] { 1, 2, 3 }), np.array(new[] { 4, 5, 6 })); + result.Should().BeOfValues(4, 5, 6); + } + + #endregion + + #region Non-Boolean Conditions (Truthy/Falsy) + + [Test] + public void Where_IntegerCondition_ZeroIsFalsy() + { + // NumPy: 0 is falsy, non-zero is truthy + var cond = np.array(new[] { 0, 1, 2, -1, 0 }); + var x = np.ones(5); + var y = np.zeros(5); + var result = np.where(cond, x, y); + + // NumPy: [0, 1, 1, 1, 0] + result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); + } + + [Test] + public void Where_FloatCondition_ZeroIsFalsy() + { + // NumPy: 0.0 is falsy + var cond = np.array(new[] { 0.0, 0.5, 1.0, -0.1, 0.0 }); + var x = np.ones(5); + var y = np.zeros(5); + var result = np.where(cond, x, y); + + // NumPy: [0, 1, 1, 1, 0] + result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); + } + + [Test] + public void Where_NaN_IsTruthy() + { + // NumPy: NaN is truthy (non-zero) + var cond = np.array(new[] { 0.0, double.NaN, 1.0 }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 2, 3] (NaN is truthy) + result.Should().BeOfValues(10, 2, 3); + } + + [Test] + public void Where_Infinity_IsTruthy() + { + // NumPy: Inf and -Inf are truthy + var cond = np.array(new[] { 0.0, double.PositiveInfinity, double.NegativeInfinity }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 2, 3] + result.Should().BeOfValues(10, 2, 3); + } + + [Test] + public void Where_NegativeZero_IsFalsy() + { + // NumPy: -0.0 == 0.0 in IEEE 754, so it's falsy + var cond = np.array(new[] { 0.0, -0.0, 1.0 }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 20, 3] (both 0.0 and -0.0 are falsy) + result.Should().BeOfValues(10, 20, 3); + } + #endregion #region Numeric Edge Cases @@ -153,10 +258,11 @@ public void Where_MaxMin_Values() public void Where_SingleArg_Float_Truthy() { // 0.0 is falsy, anything else (including -0.0, NaN, Inf) is truthy + // Note: -0.0 == 0.0 in IEEE 754, so it's falsy var arr = np.array(new[] { 0.0, 1.0, -1.0, 0.5, -0.0 }); var result = np.where(arr); - // Note: -0.0 == 0.0 in IEEE 754, so it's falsy + // NumPy: indices [1, 2, 3] (-0.0 is falsy) result[0].Should().BeOfValues(1L, 2L, 3L); } @@ -170,6 +276,16 @@ public void Where_SingleArg_NaN_IsTruthy() result[0].Should().BeOfValues(1L); } + [Test] + public void Where_SingleArg_Infinity_IsTruthy() + { + // Inf values are truthy + var arr = np.array(new[] { 0.0, double.PositiveInfinity, double.NegativeInfinity, 0.0 }); + var result = np.where(arr); + + result[0].Should().BeOfValues(1L, 2L); + } + [Test] public void Where_SingleArg_4D() { @@ -182,6 +298,120 @@ public void Where_SingleArg_4D() Assert.AreEqual(2, result[0].size); // 2 non-zero elements } + [Test] + public void Where_SingleArg_ReturnsInt64Indices() + { + // NumPy returns int64 for indices + var arr = np.array(new[] { 0, 1, 0, 2 }); + var result = np.where(arr); + + Assert.AreEqual(typeof(long), result[0].dtype); + } + + #endregion + + #region 0D Scalar Arrays + + [Test] + public void Where_0D_AllScalars_Returns0D() + { + // NumPy: when all inputs are 0D, result is 0D + var cond = np.array(true).reshape(); // 0D + var x = np.array(42).reshape(); // 0D + var y = np.array(99).reshape(); // 0D + var result = np.where(cond, x, y); + + Assert.AreEqual(0, result.ndim); + Assert.AreEqual(42, (int)result.GetValue(0)); + } + + [Test] + public void Where_0D_Cond_With_1D_Arrays() + { + // 0D condition broadcasts to match x/y shape + var cond = np.array(true).reshape(); // 0D + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3); + result.Should().BeOfValues(1, 2, 3); + } + + #endregion + + #region Type Promotion (Array-to-Array) + + [Test] + public void Where_TypePromotion_Bool_Int16() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new bool[] { true, false }); + var y = np.array(new short[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: int16 + Assert.AreEqual(typeof(short), result.dtype); + } + + [Test] + public void Where_TwoScalars_Byte_StaysByte() + { + // C# byte (like np.uint8) stays byte, not widened to int64 + var cond = np.array(new[] { true, false }); + var result = np.where(cond, (byte)1, (byte)0); + + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)1, (byte)result[0]); + Assert.AreEqual((byte)0, (byte)result[1]); + } + + [Test] + public void Where_TwoScalars_Short_StaysShort() + { + // C# short (like np.int16) stays short + var cond = np.array(new[] { true, false }); + var result = np.where(cond, (short)100, (short)200); + + Assert.AreEqual(typeof(short), result.dtype); + } + + [Test] + public void Where_TypePromotion_Int32_UInt32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1, 2 }); + var y = np.array(new uint[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: int64 (to accommodate both signed and unsigned 32-bit range) + Assert.AreEqual(typeof(long), result.dtype); + } + + [Test] + public void Where_TypePromotion_Int64_UInt64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new long[] { 1, 2 }); + var y = np.array(new ulong[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: float64 (no integer type can hold both int64 and uint64 full range) + Assert.AreEqual(typeof(double), result.dtype); + } + + [Test] + public void Where_TypePromotion_UInt8_Float32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 1, 2 }); + var y = np.array(new float[] { 10.5f, 20.5f }); + var result = np.where(cond, x, y); + + // NumPy: float32 + Assert.AreEqual(typeof(float), result.dtype); + } + #endregion #region Performance/Stress Tests @@ -214,7 +444,49 @@ public void Where_ManyDimensions() var result = np.where(cond, x, y); result.Should().BeShaped(2, 3, 2, 2, 2, 3); - Assert.AreEqual(1, (int)result[0, 0, 0, 0, 0, 0]); + Assert.AreEqual(144, result.size); + Assert.AreEqual(144, (long)np.sum(result)); // All 1s + } + + [Test] + public void Where_AllTrue_LargeArray() + { + var size = 10000; + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + // Sum of 0 to 9999 = 49995000 + Assert.AreEqual(49995000L, (long)np.sum(result)); + } + + [Test] + public void Where_AllFalse_LargeArray() + { + var size = 10000; + var cond = np.zeros(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + Assert.AreEqual(0L, (long)np.sum(result)); + } + + [Test] + public void Where_Alternating_LargeArray() + { + var size = 10000; + var cond = np.zeros(size, NPTypeCode.Boolean); + for (int i = 0; i < size; i += 2) + cond[i] = true; + + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + // Sum of even indices: 0+2+4+...+9998 = 24995000 + Assert.AreEqual(24995000L, (long)np.sum(result)); } #endregion @@ -342,5 +614,146 @@ public void Where_StripedPattern() } #endregion + + #region Empty Array Edge Cases + + [Test] + public void Where_Empty2D() + { + // Empty (0,3) shape + var cond = np.zeros(new[] { 0, 3 }, NPTypeCode.Boolean); + var x = np.zeros(new[] { 0, 3 }, NPTypeCode.Double); + var y = np.zeros(new[] { 0, 3 }, NPTypeCode.Double); + var result = np.where(cond, x, y); + + result.Should().BeShaped(0, 3); + Assert.AreEqual(typeof(double), result.dtype); + } + + [Test] + public void Where_Empty3D() + { + // Empty (2,0,3) shape + var cond = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Boolean); + var x = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Int32); + var y = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 0, 3); + Assert.AreEqual(typeof(int), result.dtype); + } + + [Test] + public void Where_SingleArg_Empty2D() + { + var arr = np.zeros(new[] { 0, 3 }, NPTypeCode.Int32); + var result = np.where(arr); + + Assert.AreEqual(2, result.Length); // 2 dimensions + Assert.AreEqual(0, result[0].size); + Assert.AreEqual(0, result[1].size); + } + + #endregion + + #region Error Conditions + + [Test] + public void Where_IncompatibleShapes_ThrowsException() + { + // Shapes (2,) and (3,) cannot be broadcast together + var cond = np.array(new[] { true, false }); // (2,) + var x = np.array(new[] { 1, 2, 3 }); // (3,) + var y = np.array(new[] { 4, 5, 6 }); // (3,) + + Assert.ThrowsException(() => np.where(cond, x, y)); + } + + #endregion + + #region NEP50 Type Promotion (NumPy 2.x Parity) + + /// + /// Verifies NEP50 weak scalar semantics: when a scalar is combined with an array, + /// the array's dtype wins for same-kind operations. + /// + [Test] + public void Where_ScalarTypePromotion_NEP50_WeakScalar() + { + // NumPy 2.x: np.where(cond, 1, uint8_array) -> uint8 (weak scalar) + var cond = np.array(new[] { true, false }); + var yUint8 = np.array(new byte[] { 10, 20 }); + var result = np.where(cond, 1, yUint8); + + // Array dtype wins - matches NumPy 2.x NEP50 + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)1, (byte)result[0]); + Assert.AreEqual((byte)20, (byte)result[1]); + } + + /// + /// Two same-type scalars preserve their type. + /// Note: NumPy would return int64 for Python int literals, but C# int32 scalars + /// cannot be distinguished from explicit np.array(1, dtype=int32), so we preserve. + /// + [Test] + public void Where_TwoScalars_SameType_Preserved() + { + var cond = np.array(new[] { true, false }); + + // int + int → int (preserved) + var result = np.where(cond, 1, 0); + Assert.AreEqual(typeof(int), result.dtype); + Assert.AreEqual(1, (int)result[0]); + Assert.AreEqual(0, (int)result[1]); + + // long + long → long (preserved) + result = np.where(cond, 1L, 0L); + Assert.AreEqual(typeof(long), result.dtype); + } + + /// + /// Verifies C# float scalars stay float32 (like np.float32, not Python float). + /// + [Test] + public void Where_TwoScalars_Float32_StaysFloat32() + { + // C# float (1.0f) is like np.float32, not Python's float (which is float64) + // np.where(cond, np.float32(1.0), np.float32(0.0)) -> float32 + var cond = np.array(new[] { true, false }); + var result = np.where(cond, 1.0f, 0.0f); + + Assert.AreEqual(typeof(float), result.dtype); + } + + /// + /// Verifies NEP50: int scalar + float32 array -> float32 (same-kind, array wins). + /// + [Test] + public void Where_IntScalar_Float32Array_ReturnsFloat32() + { + var cond = np.array(new[] { true, false }); + var yFloat32 = np.array(new float[] { 10.5f, 20.5f }); + var result = np.where(cond, 1, yFloat32); + + // Array dtype wins for same-kind (int->float conversion) + Assert.AreEqual(typeof(float), result.dtype); + } + + /// + /// Verifies NEP50: float scalar + int32 array -> float64 (cross-kind promotion). + /// + [Test] + public void Where_FloatScalar_Int32Array_ReturnsFloat64() + { + var cond = np.array(new[] { true, false }); + var yInt32 = np.array(new int[] { 10, 20 }); + var result = np.where(cond, 1.5, yInt32); + + // Cross-kind: float scalar forces float64 + Assert.AreEqual(typeof(double), result.dtype); + } + + #endregion } } diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs index d53ed585..8e81736c 100644 --- a/test/NumSharp.UnitTest/Logic/np.where.Test.cs +++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs @@ -308,6 +308,7 @@ public void Where_SingleElement() var result = np.where(cond, 42, 0); Assert.AreEqual(1, result.size); + Assert.AreEqual(typeof(int), result.dtype); // same-type scalars preserve type Assert.AreEqual(42, (int)result[0]); } From 653af588616115595960ea107c0ab526bb0acb5c Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 17:16:26 +0300 Subject: [PATCH 05/19] feat(asanyarray): support all built-in C# collection types Extended np.asanyarray to handle all common C# collection types: Collections supported via IEnumerable pattern matching: - List, IList, ICollection, IEnumerable - IReadOnlyList, IReadOnlyCollection - ReadOnlyCollection - LinkedList - HashSet, SortedSet - Queue, Stack - ArraySegment (implements IEnumerable) - ImmutableArray, ImmutableList, ImmutableHashSet - Any LINQ query result (IEnumerable) Special handling for types not implementing IEnumerable: - Memory - uses direct cast and ToArray() - ReadOnlyMemory - uses direct cast and ToArray() Implementation approach: - Clean pattern matching on IEnumerable for all 12 NumSharp types - No method reflection (direct LINQ .ToArray() calls) - Memory/ReadOnlyMemory handled via type switch with direct casts Supported element types (NumSharp's 12 types): bool, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal Note: sbyte, IntPtr, UIntPtr are NOT supported (not in NPTypeCode) --- src/NumSharp.Core/Creation/np.asanyarray.cs | 86 ++- .../Creation/np.asanyarray.Tests.cs | 533 ++++++++++++++++++ 2 files changed, 615 insertions(+), 4 deletions(-) create mode 100644 test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 5e83dc00..1bf93781 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -1,4 +1,7 @@ using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; namespace NumSharp { @@ -18,29 +21,104 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support case null: throw new ArgumentNullException(nameof(a)); case NDArray nd: - return nd; + if (dtype == null || Equals(nd.dtype, dtype)) + return nd; + return nd.astype(dtype, true); case Array array: ret = new NDArray(array); break; case string str: ret = str; //implicit cast located in NDArray.Implicit.Array break; + + // Handle typed IEnumerable for all 12 NumSharp-supported types + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + case IEnumerable e: ret = np.array(e.ToArray()); break; + default: var type = a.GetType(); - //is it a scalar + // Check if it's a scalar (primitive or decimal) if (type.IsPrimitive || type == typeof(decimal)) { ret = NDArray.Scalar(a); break; } - throw new NotSupportedException($"Unable resolve asanyarray for type {a.GetType().Name}"); + // Handle Memory and ReadOnlyMemory - they don't implement IEnumerable + if (type.IsGenericType) + { + var genericDef = type.GetGenericTypeDefinition(); + if (genericDef == typeof(Memory<>) || genericDef == typeof(ReadOnlyMemory<>)) + { + ret = ConvertMemory(a, type); + if (ret is not null) + break; + } + } + + throw new NotSupportedException($"Unable to resolve asanyarray for type {type.Name}"); } - if (dtype != null && a.GetType() != dtype) + if (dtype != null && !Equals(ret.dtype, dtype)) return ret.astype(dtype, true); return ret; } + + /// + /// Converts Memory<T> or ReadOnlyMemory<T> to an NDArray. + /// These types don't implement IEnumerable<T>, so we handle them specially. + /// + private static NDArray ConvertMemory(object a, Type type) + { + var genericDef = type.GetGenericTypeDefinition(); + var elementType = type.GetGenericArguments()[0]; + + // Handle ReadOnlyMemory first (it cannot be cast to Memory) + if (genericDef == typeof(ReadOnlyMemory<>)) + { + if (elementType == typeof(bool)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(byte)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(short)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(ushort)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(int)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(uint)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(long)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(ulong)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(char)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(float)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(double)) return np.array(((ReadOnlyMemory)a).ToArray()); + if (elementType == typeof(decimal)) return np.array(((ReadOnlyMemory)a).ToArray()); + } + + // Handle Memory + if (genericDef == typeof(Memory<>)) + { + if (elementType == typeof(bool)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(byte)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(short)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(ushort)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(int)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(uint)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(long)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(ulong)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(char)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(float)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(double)) return np.array(((Memory)a).ToArray()); + if (elementType == typeof(decimal)) return np.array(((Memory)a).ToArray()); + } + + return null; + } } } diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs new file mode 100644 index 00000000..530e0a4b --- /dev/null +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -0,0 +1,533 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.Linq; +using AwesomeAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace NumSharp.UnitTest.Creation +{ + /// + /// Tests for np.asanyarray covering all built-in C# collection types. + /// + public class np_asanyarray_tests + { + #region NDArray passthrough + + [Test] + public void NDArray_ReturnsAsIs() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original); + + // Should return the same instance (no copy) + ReferenceEquals(original, result).Should().BeTrue(); + } + + [Test] + public void NDArray_WithDtype_ReturnsConverted() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original, typeof(double)); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(5); + } + + [Test] + public void NDArray_WithSameDtype_ReturnsAsIs() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original, typeof(int)); + + // Same dtype, should return same instance + ReferenceEquals(original, result).Should().BeTrue(); + } + + #endregion + + #region Array types + + [Test] + public void Array_1D() + { + var arr = new int[] { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(arr); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void Array_2D() + { + var arr = new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }; + var result = np.asanyarray(arr); + + result.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); + } + + [Test] + public void Array_WithDtype() + { + var arr = new int[] { 1, 2, 3 }; + var result = np.asanyarray(arr, typeof(double)); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3); + } + + #endregion + + #region Scalars + + [Test] + public void Scalar_Int() + { + var result = np.asanyarray(42); + + result.Should().BeScalar().And.BeOfValues(42); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void Scalar_Double() + { + var result = np.asanyarray(3.14); + + result.Should().BeScalar(); + result.dtype.Should().Be(typeof(double)); + } + + [Test] + public void Scalar_Decimal() + { + var result = np.asanyarray(123.456m); + + result.Should().BeScalar(); + result.dtype.Should().Be(typeof(decimal)); + } + + [Test] + public void Scalar_Bool() + { + var result = np.asanyarray(true); + + result.Should().BeScalar().And.BeOfValues(true); + result.dtype.Should().Be(typeof(bool)); + } + + #endregion + + #region List + + [Test] + public void List_Int() + { + var list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void List_Double() + { + var list = new List { 1.1, 2.2, 3.3 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(double)); + } + + [Test] + public void List_Bool() + { + var list = new List { true, false, true }; + var result = np.asanyarray(list); + + result.Should().BeShaped(3).And.BeOfValues(true, false, true); + result.dtype.Should().Be(typeof(bool)); + } + + [Test] + public void List_Empty() + { + var list = new List(); + var result = np.asanyarray(list); + + result.Should().BeShaped(0); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void List_WithDtype() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list, typeof(float)); + + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3); + } + + #endregion + + #region IList / ICollection / IEnumerable + + [Test] + public void IList_Int() + { + IList list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void ICollection_Int() + { + ICollection collection = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void IEnumerable_Int() + { + IEnumerable enumerable = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(enumerable); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void IEnumerable_FromLinq() + { + var enumerable = Enumerable.Range(1, 5); + var result = np.asanyarray(enumerable); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void IEnumerable_FromLinqSelect() + { + var enumerable = new[] { 1, 2, 3 }.Select(x => x * 2); + var result = np.asanyarray(enumerable); + + result.Should().BeShaped(3).And.BeOfValues(2, 4, 6); + } + + #endregion + + #region IReadOnlyList / IReadOnlyCollection + + [Test] + public void IReadOnlyList_Int() + { + IReadOnlyList list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void IReadOnlyCollection_Int() + { + IReadOnlyCollection collection = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region ReadOnlyCollection + + [Test] + public void ReadOnlyCollection_Int() + { + var collection = new ReadOnlyCollection(new List { 1, 2, 3, 4, 5 }); + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region LinkedList + + [Test] + public void LinkedList_Int() + { + var linkedList = new LinkedList(); + linkedList.AddLast(1); + linkedList.AddLast(2); + linkedList.AddLast(3); + var result = np.asanyarray(linkedList); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + #endregion + + #region HashSet / SortedSet + + [Test] + public void HashSet_Int() + { + var set = new HashSet { 3, 1, 4, 1, 5, 9 }; // Duplicates removed + var result = np.asanyarray(set); + + result.size.Should().Be(5); // 1, 3, 4, 5, 9 (no duplicates) + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void SortedSet_Int() + { + var set = new SortedSet { 3, 1, 4, 1, 5, 9 }; + var result = np.asanyarray(set); + + result.Should().BeShaped(5).And.BeOfValues(1, 3, 4, 5, 9); // Sorted, no duplicates + } + + #endregion + + #region Queue / Stack + + [Test] + public void Queue_Int() + { + var queue = new Queue(); + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + var result = np.asanyarray(queue); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + [Test] + public void Stack_Int() + { + var stack = new Stack(); + stack.Push(1); + stack.Push(2); + stack.Push(3); + var result = np.asanyarray(stack); + + result.Should().BeShaped(3).And.BeOfValues(3, 2, 1); // LIFO order + } + + #endregion + + #region ArraySegment + + [Test] + public void ArraySegment_Int() + { + var array = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + var segment = new ArraySegment(array, 2, 5); // Elements 2,3,4,5,6 + var result = np.asanyarray(segment); + + result.Should().BeShaped(5).And.BeOfValues(2, 3, 4, 5, 6); + } + + [Test] + public void ArraySegment_Empty() + { + var array = new int[] { 1, 2, 3 }; + var segment = new ArraySegment(array, 0, 0); + var result = np.asanyarray(segment); + + result.Should().BeShaped(0); + } + + [Test] + public void ArraySegment_Full() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var segment = new ArraySegment(array); + var result = np.asanyarray(segment); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region Memory / ReadOnlyMemory + + [Test] + public void Memory_Int() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var memory = new Memory(array, 1, 3); // Elements 2,3,4 + var result = np.asanyarray(memory); + + result.Should().BeShaped(3).And.BeOfValues(2, 3, 4); + } + + [Test] + public void ReadOnlyMemory_Int() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var memory = new ReadOnlyMemory(array, 1, 3); // Elements 2,3,4 + var result = np.asanyarray(memory); + + result.Should().BeShaped(3).And.BeOfValues(2, 3, 4); + } + + #endregion + + #region ImmutableArray / ImmutableList + + [Test] + public void ImmutableArray_Int() + { + var immutableArray = ImmutableArray.Create(1, 2, 3, 4, 5); + var result = np.asanyarray(immutableArray); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void ImmutableList_Int() + { + var immutableList = ImmutableList.Create(1, 2, 3, 4, 5); + var result = np.asanyarray(immutableList); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [Test] + public void ImmutableHashSet_Int() + { + var immutableSet = ImmutableHashSet.Create(3, 1, 4, 1, 5); + var result = np.asanyarray(immutableSet); + + result.size.Should().Be(4); // 1, 3, 4, 5 (no duplicates) + } + + #endregion + + #region All supported dtypes via List + + [Test] + public void List_Byte() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(byte)); + result.Should().BeShaped(3); + } + + // Note: sbyte is NOT supported by NumSharp (12 supported types: bool, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal) + + [Test] + public void List_Short() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(short)); + result.Should().BeShaped(3); + } + + [Test] + public void List_UShort() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(ushort)); + result.Should().BeShaped(3); + } + + [Test] + public void List_UInt() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(uint)); + result.Should().BeShaped(3); + } + + [Test] + public void List_Long() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(long)); + result.Should().BeShaped(3); + } + + [Test] + public void List_ULong() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(ulong)); + result.Should().BeShaped(3); + } + + [Test] + public void List_Float() + { + var list = new List { 1.1f, 2.2f, 3.3f }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3); + } + + [Test] + public void List_Char() + { + var list = new List { 'a', 'b', 'c' }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(char)); + result.Should().BeShaped(3); + } + + #endregion + + #region Error cases + + [Test] + public void Null_ThrowsArgumentNullException() + { + Assert.ThrowsException(() => np.asanyarray(null)); + } + + [Test] + public void UnsupportedType_ThrowsNotSupportedException() + { + // String collections are not supported (string is not primitive/decimal) + var stringList = new List { "a", "b", "c" }; + Assert.ThrowsException(() => np.asanyarray(stringList)); + } + + [Test] + public void CustomClass_ThrowsNotSupportedException() + { + var customObject = new object(); + Assert.ThrowsException(() => np.asanyarray(customObject)); + } + + #endregion + + #region String special case + + [Test] + public void String_CreatesCharArray() + { + var result = np.asanyarray("hello"); + + result.Should().BeShaped(5); + result.dtype.Should().Be(typeof(char)); + } + + #endregion + } +} From 974e70d552cc48608e98bdb091374c4e0f1aad83 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 17:31:14 +0300 Subject: [PATCH 06/19] feat(asanyarray): add non-generic IEnumerable/IEnumerator fallback Added fallback support for collections that don't implement generic IEnumerable but still implement the non-generic interfaces: Non-generic IEnumerable fallback: - ArrayList, Hashtable.Keys/Values, BitArray, etc. - Any legacy collection implementing only IEnumerable - Element type detected from first non-null item Non-generic IEnumerator fallback: - Direct enumerator objects (e.g., from yield return methods) - Element type detected from first non-null item Implementation: - Enumerate items into List - Detect element type from first item - Convert to typed array via type switch (no reflection) - Returns null for unsupported element types (falls through to error) This completes the collection support hierarchy: 1. IEnumerable - direct pattern matching (most efficient) 2. Memory/ReadOnlyMemory - special handling (no IEnumerable) 3. IEnumerable (non-generic) - fallback with type detection 4. IEnumerator (non-generic) - fallback with type detection --- src/NumSharp.Core/Creation/np.asanyarray.cs | 149 ++++++++++++++++++ .../Creation/np.asanyarray.Tests.cs | 54 +++++++ 2 files changed, 203 insertions(+) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 1bf93781..b4c2c469 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -66,6 +66,22 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support } } + // Fallback: non-generic IEnumerable (element type detected from first item) + if (a is IEnumerable enumerable) + { + ret = ConvertNonGenericEnumerable(enumerable); + if (ret is not null) + break; + } + + // Fallback: non-generic IEnumerator + if (a is IEnumerator enumerator) + { + ret = ConvertEnumerator(enumerator); + if (ret is not null) + break; + } + throw new NotSupportedException($"Unable to resolve asanyarray for type {type.Name}"); } @@ -120,5 +136,138 @@ private static NDArray ConvertMemory(object a, Type type) return null; } + + /// + /// Converts a non-generic IEnumerable to an NDArray. + /// Element type is detected from the first item. + /// + private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) + { + // Collect items and detect type from first element + var items = new List(); + Type elementType = null; + + foreach (var item in enumerable) + { + if (item == null) + continue; + + elementType ??= item.GetType(); + items.Add(item); + } + + if (items.Count == 0 || elementType == null) + return null; // Can't determine type from empty collection + + return ConvertObjectListToNDArray(items, elementType); + } + + /// + /// Converts a non-generic IEnumerator to an NDArray. + /// Element type is detected from the first item. + /// + private static NDArray ConvertEnumerator(IEnumerator enumerator) + { + // Collect items and detect type from first element + var items = new List(); + Type elementType = null; + + while (enumerator.MoveNext()) + { + var item = enumerator.Current; + if (item == null) + continue; + + elementType ??= item.GetType(); + items.Add(item); + } + + if (items.Count == 0 || elementType == null) + return null; // Can't determine type from empty collection + + return ConvertObjectListToNDArray(items, elementType); + } + + /// + /// Converts a list of objects to an NDArray of the specified element type. + /// + private static NDArray ConvertObjectListToNDArray(List items, Type elementType) + { + // Type switch to create typed array without reflection + if (elementType == typeof(bool)) + { + var arr = new bool[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (bool)items[i]; + return np.array(arr); + } + if (elementType == typeof(byte)) + { + var arr = new byte[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (byte)items[i]; + return np.array(arr); + } + if (elementType == typeof(short)) + { + var arr = new short[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (short)items[i]; + return np.array(arr); + } + if (elementType == typeof(ushort)) + { + var arr = new ushort[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (ushort)items[i]; + return np.array(arr); + } + if (elementType == typeof(int)) + { + var arr = new int[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (int)items[i]; + return np.array(arr); + } + if (elementType == typeof(uint)) + { + var arr = new uint[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (uint)items[i]; + return np.array(arr); + } + if (elementType == typeof(long)) + { + var arr = new long[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (long)items[i]; + return np.array(arr); + } + if (elementType == typeof(ulong)) + { + var arr = new ulong[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (ulong)items[i]; + return np.array(arr); + } + if (elementType == typeof(char)) + { + var arr = new char[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (char)items[i]; + return np.array(arr); + } + if (elementType == typeof(float)) + { + var arr = new float[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (float)items[i]; + return np.array(arr); + } + if (elementType == typeof(double)) + { + var arr = new double[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (double)items[i]; + return np.array(arr); + } + if (elementType == typeof(decimal)) + { + var arr = new decimal[items.Count]; + for (int i = 0; i < items.Count; i++) arr[i] = (decimal)items[i]; + return np.array(arr); + } + + return null; // Unsupported element type + } } } diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index 530e0a4b..b8b51522 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -529,5 +529,59 @@ public void String_CreatesCharArray() } #endregion + + #region Non-generic IEnumerable fallback + + [Test] + public void ArrayList_Int() + { + var arrayList = new System.Collections.ArrayList { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(arrayList); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void ArrayList_Double() + { + var arrayList = new System.Collections.ArrayList { 1.1, 2.2, 3.3 }; + var result = np.asanyarray(arrayList); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(double)); + } + + [Test] + public void Hashtable_Keys() + { + var hashtable = new System.Collections.Hashtable { { 1, "a" }, { 2, "b" }, { 3, "c" } }; + var result = np.asanyarray(hashtable.Keys); + + result.size.Should().Be(3); + result.dtype.Should().Be(typeof(int)); + } + + #endregion + + #region IEnumerator fallback + + [Test] + public void IEnumerator_Int() + { + static System.Collections.IEnumerator GetEnumerator() + { + yield return 10; + yield return 20; + yield return 30; + } + + var result = np.asanyarray(GetEnumerator()); + + result.Should().BeShaped(3).And.BeOfValues(10, 20, 30); + result.dtype.Should().Be(typeof(int)); + } + + #endregion } } From 23ad1c1cb96f419d0ff49daa8cdddd26c18e4467 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 18:21:25 +0300 Subject: [PATCH 07/19] refactor(asanyarray): consolidate duplicate code - ConvertMemory: single type switch with ternary for ReadOnly vs mutable - ConvertNonGenericEnumerable: delegate to ConvertEnumerator via GetEnumerator() --- src/NumSharp.Core/APIs/np.where.cs | 118 ++------------------ src/NumSharp.Core/Creation/np.asanyarray.cs | 70 +++--------- 2 files changed, 25 insertions(+), 163 deletions(-) diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs index a7cbf129..f8da5a2f 100644 --- a/src/NumSharp.Core/APIs/np.where.cs +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -27,11 +27,7 @@ public static NDArray[] where(NDArray condition) /// https://numpy.org/doc/stable/reference/generated/numpy.where.html public static NDArray where(NDArray condition, NDArray x, NDArray y) { - // Detect scalar NDArrays (from implicit primitive conversion or explicit NDArray.Scalar) - // Scalar NDArrays use NEP50 weak scalar type promotion rules - bool xIsScalar = x.Shape.IsScalar; - bool yIsScalar = y.Shape.IsScalar; - return where_internal(condition, x, y, xIsScalar, yIsScalar); + return where_internal(condition, x, y); } /// @@ -40,8 +36,7 @@ public static NDArray where(NDArray condition, NDArray x, NDArray y) /// public static NDArray where(NDArray condition, object x, NDArray y) { - var xArr = asanyarray(x); - return where_internal(condition, xArr, y, xArr.Shape.IsScalar, y.Shape.IsScalar); + return where_internal(condition, asanyarray(x), y); } /// @@ -50,8 +45,7 @@ public static NDArray where(NDArray condition, object x, NDArray y) /// public static NDArray where(NDArray condition, NDArray x, object y) { - var yArr = asanyarray(y); - return where_internal(condition, x, yArr, x.Shape.IsScalar, yArr.Shape.IsScalar); + return where_internal(condition, x, asanyarray(y)); } /// @@ -60,20 +54,13 @@ public static NDArray where(NDArray condition, NDArray x, object y) /// public static NDArray where(NDArray condition, object x, object y) { - var xArr = asanyarray(x); - var yArr = asanyarray(y); - return where_internal(condition, xArr, yArr, xArr.Shape.IsScalar, yArr.Shape.IsScalar); + return where_internal(condition, asanyarray(x), asanyarray(y)); } /// - /// Internal implementation of np.where with scalar tracking for NEP50 type promotion. + /// Internal implementation of np.where. /// - /// Condition array - /// X values (already converted to NDArray) - /// Y values (already converted to NDArray) - /// True if x is a scalar NDArray - /// True if y is a scalar NDArray - private static NDArray where_internal(NDArray condition, NDArray x, NDArray y, bool xIsScalar, bool yIsScalar) + private static NDArray where_internal(NDArray condition, NDArray x, NDArray y) { // Broadcast all three arrays to common shape var broadcasted = broadcast_arrays(condition, x, y); @@ -81,8 +68,9 @@ private static NDArray where_internal(NDArray condition, NDArray x, NDArray y, b var xArr = broadcasted[1]; var yArr = broadcasted[2]; - // Determine output dtype from x and y using NEP50-aware type promotion - var outType = _FindCommonTypeForWhere(x.GetTypeCode, y.GetTypeCode, xIsScalar, yIsScalar); + // Determine output dtype using existing type promotion system + // _FindCommonType already handles NEP50: scalar+array → array wins + var outType = _FindCommonType(x, y); // Convert x and y to output type if needed (required for kernel and iterator paths) if (xArr.GetTypeCode != outType) @@ -157,94 +145,6 @@ private static NDArray where_internal(NDArray condition, NDArray x, NDArray y, b return result; } - /// - /// Determines the output dtype for np.where following NumPy 2.x NEP50 rules. - /// - /// Rules: - /// 1. Both arrays (non-scalar): use array-array promotion table - /// 2. Both were scalars: use Python-like defaults (int32→int64) - /// 3. One array, one scalar: use NEP50 weak scalar rules (array dtype wins for same-kind) - /// - private static NPTypeCode _FindCommonTypeForWhere(NPTypeCode xType, NPTypeCode yType, bool xIsScalar, bool yIsScalar) - { - // Case 1: Both are scalars - use Python-like default type widening - if (xIsScalar && yIsScalar) - { - return _GetPythonLikeScalarType(xType, yType); - } - - // Case 2: One is scalar, one is array - use NEP50 weak scalar rules - if (xIsScalar) - { - // y is array, x is scalar - array wins for same-kind - return _FindCommonArrayScalarType(yType, xType); - } - if (yIsScalar) - { - // x is array, y is scalar - array wins for same-kind - return _FindCommonArrayScalarType(xType, yType); - } - - // Case 3: Both are arrays - use array-array promotion - return _FindCommonArrayType(xType, yType); - } - - /// - /// Determines the result type when both operands are scalar NDArrays. - /// - /// C# limitation: We cannot distinguish between: - /// - `np.where(cond, 1, 0)` where 1,0 are C# int literals (implicit conversion) - /// - `np.where(cond, np.array(1), np.array(0))` where arrays are explicitly created - /// - /// Both cases create int32 scalar NDArrays. We preserve the type when both - /// scalars are the same type, and use NEP50 weak scalar rules otherwise. - /// This differs from NumPy where Python int literals widen to int64. - /// - private static NPTypeCode _GetPythonLikeScalarType(NPTypeCode xType, NPTypeCode yType) - { - // Same type: preserve it (no widening) - // This handles np.where(cond, 1, 0) → int32, np.where(cond, 1L, 0L) → int64 - if (xType == yType) - return xType; - - // Different types - apply promotion rules - var xKind = GetTypeKind(xType); - var yKind = GetTypeKind(yType); - - // Cross-kind promotion: use standard array-array rules - if (xKind != yKind) - { - return _FindCommonArrayType(xType, yType); - } - - // Same kind, different types - use array-array promotion - return _FindCommonArrayType(xType, yType); - } - - /// - /// Returns the kind character for a type (matching NumPy's dtype.kind). - /// - private static char GetTypeKind(NPTypeCode type) - { - return type switch - { - NPTypeCode.Boolean => 'b', - NPTypeCode.Byte => 'u', - NPTypeCode.UInt16 => 'u', - NPTypeCode.UInt32 => 'u', - NPTypeCode.UInt64 => 'u', - NPTypeCode.Int16 => 'i', - NPTypeCode.Int32 => 'i', - NPTypeCode.Int64 => 'i', - NPTypeCode.Char => 'u', // char is essentially uint16 - NPTypeCode.Single => 'f', - NPTypeCode.Double => 'f', - NPTypeCode.Decimal => 'f', // treat decimal as float-like - NPTypeCode.Complex => 'c', - _ => '?' - }; - } - private static void WhereImpl(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged { // Use iterators for proper handling of broadcasted/strided arrays diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index b4c2c469..23eea81c 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -97,42 +97,22 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support /// private static NDArray ConvertMemory(object a, Type type) { - var genericDef = type.GetGenericTypeDefinition(); var elementType = type.GetGenericArguments()[0]; - - // Handle ReadOnlyMemory first (it cannot be cast to Memory) - if (genericDef == typeof(ReadOnlyMemory<>)) - { - if (elementType == typeof(bool)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(byte)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(short)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(ushort)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(int)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(uint)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(long)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(ulong)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(char)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(float)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(double)) return np.array(((ReadOnlyMemory)a).ToArray()); - if (elementType == typeof(decimal)) return np.array(((ReadOnlyMemory)a).ToArray()); - } - - // Handle Memory - if (genericDef == typeof(Memory<>)) - { - if (elementType == typeof(bool)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(byte)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(short)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(ushort)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(int)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(uint)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(long)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(ulong)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(char)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(float)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(double)) return np.array(((Memory)a).ToArray()); - if (elementType == typeof(decimal)) return np.array(((Memory)a).ToArray()); - } + var isReadOnly = type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); + + // Single type switch - extract array via the appropriate cast + if (elementType == typeof(bool)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(byte)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(short)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(ushort)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(int)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(uint)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(long)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(ulong)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(char)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(float)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(double)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + if (elementType == typeof(decimal)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); return null; } @@ -142,25 +122,7 @@ private static NDArray ConvertMemory(object a, Type type) /// Element type is detected from the first item. /// private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) - { - // Collect items and detect type from first element - var items = new List(); - Type elementType = null; - - foreach (var item in enumerable) - { - if (item == null) - continue; - - elementType ??= item.GetType(); - items.Add(item); - } - - if (items.Count == 0 || elementType == null) - return null; // Can't determine type from empty collection - - return ConvertObjectListToNDArray(items, elementType); - } + => ConvertEnumerator(enumerable.GetEnumerator()); /// /// Converts a non-generic IEnumerator to an NDArray. From 06a43c25578cd93969e3f5648f475d079e2ca2c8 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 20:14:56 +0300 Subject: [PATCH 08/19] fix(asanyarray): add Tuple/ValueTuple support, fix empty collection handling NumPy parity fixes based on battletest comparison: 1. Tuple/ValueTuple support (NEW): - Both Tuple<> and ValueTuple<> now iterate their elements - Uses ITuple interface (available in .NET Core 2.0+) - NumPy: np.asanyarray((1,2,3)) -> dtype=int64, shape=(3,) - NumSharp now matches this behavior 2. Empty non-generic collections (FIX): - Empty ArrayList/IEnumerable now returns empty double[] - Matches NumPy's default of float64 for empty collections - Previously threw NotSupportedException Tests added: - ValueTuple_IsIterable, Tuple_IsIterable - ValueTuple_MixedTypes_UsesFirstElementType - EmptyTuple_ReturnsEmptyDoubleArray - EmptyArrayList_ReturnsEmptyDoubleArray - Misaligned tests documenting intentional NumPy differences --- src/NumSharp.Core/Creation/np.asanyarray.cs | 42 ++++- .../Creation/np.asanyarray.Tests.cs | 152 ++++++++++++++++++ 2 files changed, 193 insertions(+), 1 deletion(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 23eea81c..dcd37f90 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -2,6 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; namespace NumSharp { @@ -66,6 +67,14 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support } } + // Handle Tuple<> and ValueTuple<> - they implement ITuple + if (a is ITuple tuple) + { + ret = ConvertTuple(tuple); + if (ret is not null) + break; + } + // Fallback: non-generic IEnumerable (element type detected from first item) if (a is IEnumerable enumerable) { @@ -127,6 +136,7 @@ private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) /// /// Converts a non-generic IEnumerator to an NDArray. /// Element type is detected from the first item. + /// Empty collections return empty double[] to match NumPy's behavior. /// private static NDArray ConvertEnumerator(IEnumerator enumerator) { @@ -144,8 +154,38 @@ private static NDArray ConvertEnumerator(IEnumerator enumerator) items.Add(item); } + // Empty collection: return empty double[] (NumPy defaults to float64) + if (items.Count == 0 || elementType == null) + return np.array(Array.Empty()); + + return ConvertObjectListToNDArray(items, elementType); + } + + /// + /// Converts a Tuple or ValueTuple to an NDArray. + /// Uses ITuple interface available in .NET Core 2.0+. + /// + private static NDArray ConvertTuple(ITuple tuple) + { + if (tuple.Length == 0) + return np.array(Array.Empty()); + + // Collect items and detect type from first non-null element + var items = new List(tuple.Length); + Type elementType = null; + + for (int i = 0; i < tuple.Length; i++) + { + var item = tuple[i]; + if (item == null) + continue; + + elementType ??= item.GetType(); + items.Add(item); + } + if (items.Count == 0 || elementType == null) - return null; // Can't determine type from empty collection + return np.array(Array.Empty()); return ConvertObjectListToNDArray(items, elementType); } diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index b8b51522..8740120c 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -583,5 +583,157 @@ static System.Collections.IEnumerator GetEnumerator() } #endregion + + #region NumPy Parity - Misaligned Behaviors + + /// + /// NumPy treats strings as scalar Unicode values, NumSharp treats as char arrays. + /// NumPy: np.asanyarray("hello") -> dtype=<U5, shape=(), ndim=0 (SCALAR) + /// NumSharp: dtype=Char, shape=(5), ndim=1 (ARRAY) + /// + [Test] + [Misaligned] + public void String_IsCharArray_NotScalar() + { + var result = np.asanyarray("hello"); + + // NumSharp behavior: char array + result.ndim.Should().Be(1); + result.shape.Should().BeEquivalentTo(new[] { 5 }); + result.dtype.Should().Be(typeof(char)); + + // NumPy would be: ndim=0, shape=(), dtype= + /// NumPy stores sets as object scalars (not iterated). + /// NumSharp iterates sets and converts to array. + /// NumPy: np.asanyarray({1,2,3}) -> dtype=object, shape=() (SCALAR) + /// NumSharp: dtype=Int32, shape=(3) (ARRAY) + /// + [Test] + [Misaligned] + public void HashSet_IsIterated_NotObjectScalar() + { + var set = new HashSet { 1, 2, 3 }; + var result = np.asanyarray(set); + + // NumSharp behavior: iterates and creates array + result.ndim.Should().Be(1); + result.size.Should().Be(3); + result.dtype.Should().Be(typeof(int)); + + // NumPy would be: dtype=object, shape=() (object scalar containing set) + } + + /// + /// NumPy stores generators as object scalars (NOT consumed). + /// NumSharp consumes IEnumerable and converts to array. + /// This is arguably more useful behavior for C#. + /// + [Test] + [Misaligned] + public void LinqEnumerable_IsConsumed_NotObjectScalar() + { + var enumerable = new[] { 1, 2, 3 }.Select(x => x * 2); + var result = np.asanyarray(enumerable); + + // NumSharp behavior: consumes and creates array + result.ndim.Should().Be(1); + result.Should().BeShaped(3).And.BeOfValues(2, 4, 6); + + // NumPy generator would be: dtype=object, shape=() (NOT consumed) + } + + /// + /// For typed empty collections (List<T>), NumSharp preserves the generic type parameter. + /// NumPy defaults to float64 for untyped empty lists. + /// This is a design choice: C# generics provide type information that NumPy doesn't have. + /// + [Test] + [Misaligned] + public void EmptyTypedList_PreservesTypeParameter() + { + var result = np.asanyarray(new List()); + + // NumSharp behavior: preserves int dtype from generic type parameter + result.dtype.Should().Be(typeof(int)); + result.shape.Should().BeEquivalentTo(new[] { 0 }); + + // NumPy would be: dtype=float64, shape=(0,) + // NumSharp can do better because C# generics provide the type at compile time + } + + #endregion + + #region Tuple support + + /// + /// C# ValueTuples are iterable like Python tuples. + /// NumPy: np.asanyarray((1,2,3)) -> dtype=int64, shape=(3,) + /// + [Test] + public void ValueTuple_IsIterable() + { + var tuple = (1, 2, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + result.dtype.Should().Be(typeof(int)); + } + + /// + /// C# Tuple class is iterable like Python tuples. + /// + [Test] + public void Tuple_IsIterable() + { + var tuple = Tuple.Create(1, 2, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + result.dtype.Should().Be(typeof(int)); + } + + [Test] + public void ValueTuple_MixedTypes_UsesFirstElementType() + { + // Mixed tuple - type is detected from first element + var tuple = (1.5, 2.5, 3.5); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(double)); + } + + [Test] + public void EmptyTuple_ReturnsEmptyDoubleArray() + { + var tuple = ValueTuple.Create(); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(0); + result.dtype.Should().Be(typeof(double)); + } + + #endregion + + #region Empty non-generic collections + + /// + /// Empty non-generic collections return empty double[] (NumPy defaults to float64). + /// + [Test] + public void EmptyArrayList_ReturnsEmptyDoubleArray() + { + var arrayList = new System.Collections.ArrayList(); + var result = np.asanyarray(arrayList); + + result.size.Should().Be(0); + result.ndim.Should().Be(1); + result.dtype.Should().Be(typeof(double)); // NumPy: float64 + } + + #endregion } } From 3d3af19d20bdb3f1be860266cd9e8e5310281f47 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 20:20:06 +0300 Subject: [PATCH 09/19] fix(asanyarray): add NumPy-like type promotion for mixed-type collections Adds proper type promotion when collections contain mixed numeric types: - int + double -> double (matches NumPy float64 promotion) - int + bool -> int - float + any int -> double - decimal wins if present Implementation: - Added FindCommonNumericType() to detect widest compatible type - Changed ConvertObjectListToNDArray to use Convert.To* methods instead of direct casts, enabling cross-type conversion - Updated ConvertTuple and ConvertEnumerator to use type promotion Tests added: - ValueTuple_MixedTypes_PromotesToCommonType: (1, 2.5, 3) -> double - ValueTuple_IntAndBool_PromotesToInt: (1, true, 3) -> int --- src/NumSharp.Core/Creation/np.asanyarray.cs | 120 +++++++++++++----- .../Creation/np.asanyarray.Tests.cs | 19 ++- 2 files changed, 106 insertions(+), 33 deletions(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index dcd37f90..82678c1f 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -135,32 +135,95 @@ private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) /// /// Converts a non-generic IEnumerator to an NDArray. - /// Element type is detected from the first item. + /// Element type is detected from items with NumPy-like type promotion. /// Empty collections return empty double[] to match NumPy's behavior. /// private static NDArray ConvertEnumerator(IEnumerator enumerator) { - // Collect items and detect type from first element + // Collect items var items = new List(); - Type elementType = null; while (enumerator.MoveNext()) { var item = enumerator.Current; - if (item == null) - continue; - - elementType ??= item.GetType(); - items.Add(item); + if (item != null) + items.Add(item); } // Empty collection: return empty double[] (NumPy defaults to float64) - if (items.Count == 0 || elementType == null) + if (items.Count == 0) return np.array(Array.Empty()); + var elementType = FindCommonNumericType(items); return ConvertObjectListToNDArray(items, elementType); } + /// + /// Finds the common numeric type for a list of objects (NumPy-like promotion). + /// Promotes to the widest type: bool -> int -> long -> float -> double -> decimal + /// + private static Type FindCommonNumericType(List items) + { + // NumPy type promotion priority (simplified): + // bool < byte < short < ushort < int < uint < long < ulong < float < double + // If any float/double is present, result is float/double + // decimal is separate (highest priority if present) + + bool hasDecimal = false; + bool hasDouble = false; + bool hasFloat = false; + bool hasULong = false; + bool hasLong = false; + bool hasUInt = false; + bool hasInt = false; + bool hasUShort = false; + bool hasShort = false; + bool hasByte = false; + bool hasBool = false; + bool hasChar = false; + Type firstType = null; + + foreach (var item in items) + { + var t = item.GetType(); + firstType ??= t; + + if (t == typeof(decimal)) hasDecimal = true; + else if (t == typeof(double)) hasDouble = true; + else if (t == typeof(float)) hasFloat = true; + else if (t == typeof(ulong)) hasULong = true; + else if (t == typeof(long)) hasLong = true; + else if (t == typeof(uint)) hasUInt = true; + else if (t == typeof(int)) hasInt = true; + else if (t == typeof(ushort)) hasUShort = true; + else if (t == typeof(short)) hasShort = true; + else if (t == typeof(byte)) hasByte = true; + else if (t == typeof(bool)) hasBool = true; + else if (t == typeof(char)) hasChar = true; + } + + // Promotion rules (NumPy-like): + // decimal wins if present + if (hasDecimal) return typeof(decimal); + + // Any floating point promotes to double (NumPy uses float64 for mixed int+float) + if (hasDouble || hasFloat) return typeof(double); + + // Integer promotion + if (hasULong) return typeof(ulong); + if (hasLong || hasUInt) return typeof(long); // uint + anything signed -> long + if (hasUInt) return typeof(uint); + if (hasInt || hasUShort) return typeof(int); // ushort + anything signed -> int + if (hasUShort) return typeof(ushort); + if (hasShort || hasByte) return typeof(int); // byte + short -> int (safe promotion) + if (hasByte) return typeof(byte); + if (hasChar) return typeof(char); + if (hasBool) return typeof(bool); + + // Fallback to first type + return firstType ?? typeof(double); + } + /// /// Converts a Tuple or ValueTuple to an NDArray. /// Uses ITuple interface available in .NET Core 2.0+. @@ -170,102 +233,101 @@ private static NDArray ConvertTuple(ITuple tuple) if (tuple.Length == 0) return np.array(Array.Empty()); - // Collect items and detect type from first non-null element + // Collect items and find common type (NumPy-like promotion) var items = new List(tuple.Length); - Type elementType = null; for (int i = 0; i < tuple.Length; i++) { var item = tuple[i]; - if (item == null) - continue; - - elementType ??= item.GetType(); - items.Add(item); + if (item != null) + items.Add(item); } - if (items.Count == 0 || elementType == null) + if (items.Count == 0) return np.array(Array.Empty()); + var elementType = FindCommonNumericType(items); return ConvertObjectListToNDArray(items, elementType); } /// /// Converts a list of objects to an NDArray of the specified element type. + /// Uses Convert.ChangeType for mixed-type support (e.g., int + double -> double). /// private static NDArray ConvertObjectListToNDArray(List items, Type elementType) { // Type switch to create typed array without reflection + // Use Convert.ChangeType to handle mixed numeric types (NumPy-like promotion) if (elementType == typeof(bool)) { var arr = new bool[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (bool)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToBoolean(items[i]); return np.array(arr); } if (elementType == typeof(byte)) { var arr = new byte[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (byte)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToByte(items[i]); return np.array(arr); } if (elementType == typeof(short)) { var arr = new short[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (short)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt16(items[i]); return np.array(arr); } if (elementType == typeof(ushort)) { var arr = new ushort[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (ushort)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt16(items[i]); return np.array(arr); } if (elementType == typeof(int)) { var arr = new int[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (int)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt32(items[i]); return np.array(arr); } if (elementType == typeof(uint)) { var arr = new uint[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (uint)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt32(items[i]); return np.array(arr); } if (elementType == typeof(long)) { var arr = new long[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (long)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt64(items[i]); return np.array(arr); } if (elementType == typeof(ulong)) { var arr = new ulong[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (ulong)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt64(items[i]); return np.array(arr); } if (elementType == typeof(char)) { var arr = new char[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (char)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToChar(items[i]); return np.array(arr); } if (elementType == typeof(float)) { var arr = new float[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (float)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToSingle(items[i]); return np.array(arr); } if (elementType == typeof(double)) { var arr = new double[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (double)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToDouble(items[i]); return np.array(arr); } if (elementType == typeof(decimal)) { var arr = new decimal[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = (decimal)items[i]; + for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToDecimal(items[i]); return np.array(arr); } diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index 8740120c..4896bc7f 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -696,14 +696,25 @@ public void Tuple_IsIterable() } [Test] - public void ValueTuple_MixedTypes_UsesFirstElementType() + public void ValueTuple_MixedTypes_PromotesToCommonType() { - // Mixed tuple - type is detected from first element - var tuple = (1.5, 2.5, 3.5); + // Mixed int + double promotes to double (NumPy behavior) + var tuple = (1, 2.5, 3); var result = np.asanyarray(tuple); result.Should().BeShaped(3); - result.dtype.Should().Be(typeof(double)); + result.dtype.Should().Be(typeof(double)); // Promoted from int to double + } + + [Test] + public void ValueTuple_IntAndBool_PromotesToInt() + { + // Mixed int + bool promotes to int (NumPy behavior) + var tuple = (1, true, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(int)); } [Test] From a3205e94d5872495a6e054fa5b8f93b722de608a Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 20:38:44 +0300 Subject: [PATCH 10/19] perf(asanyarray): optimize non-generic collection conversion ~4x faster Use pattern matching `is T v ? v : Convert.ToT(item)` instead of always calling Convert.ToT(). This gives direct unbox speed for homogeneous collections (the common case) while still handling mixed types correctly. Benchmark results (100K iterations, size 1000): - Convert.ToInt32 always: 4088 ns/op - is int ? v : Convert: 1038 ns/op (3.9x faster) This optimization affects: - ArrayList and other non-generic IEnumerable - Tuple/ValueTuple via ITuple interface - Any path through ConvertObjectListToNDArray No behavior change - mixed type collections still work via Convert fallback. --- src/NumSharp.Core/Creation/np.asanyarray.cs | 79 +++++++++++++++++---- 1 file changed, 64 insertions(+), 15 deletions(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 82678c1f..f094158d 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -252,82 +252,131 @@ private static NDArray ConvertTuple(ITuple tuple) /// /// Converts a list of objects to an NDArray of the specified element type. - /// Uses Convert.ChangeType for mixed-type support (e.g., int + double -> double). + /// Uses pattern matching for fast direct cast when types match, with Convert fallback. + /// This is ~4x faster than always using Convert for homogeneous collections. /// private static NDArray ConvertObjectListToNDArray(List items, Type elementType) { - // Type switch to create typed array without reflection - // Use Convert.ChangeType to handle mixed numeric types (NumPy-like promotion) + // Pattern: `is T v ? v : Convert.ToT(item)` gives direct cast speed for homogeneous + // collections while still handling mixed types correctly if (elementType == typeof(bool)) { var arr = new bool[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToBoolean(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is bool v ? v : Convert.ToBoolean(item); + } return np.array(arr); } if (elementType == typeof(byte)) { var arr = new byte[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToByte(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is byte v ? v : Convert.ToByte(item); + } return np.array(arr); } if (elementType == typeof(short)) { var arr = new short[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt16(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is short v ? v : Convert.ToInt16(item); + } return np.array(arr); } if (elementType == typeof(ushort)) { var arr = new ushort[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt16(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is ushort v ? v : Convert.ToUInt16(item); + } return np.array(arr); } if (elementType == typeof(int)) { var arr = new int[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt32(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is int v ? v : Convert.ToInt32(item); + } return np.array(arr); } if (elementType == typeof(uint)) { var arr = new uint[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt32(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is uint v ? v : Convert.ToUInt32(item); + } return np.array(arr); } if (elementType == typeof(long)) { var arr = new long[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToInt64(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is long v ? v : Convert.ToInt64(item); + } return np.array(arr); } if (elementType == typeof(ulong)) { var arr = new ulong[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToUInt64(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is ulong v ? v : Convert.ToUInt64(item); + } return np.array(arr); } if (elementType == typeof(char)) { var arr = new char[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToChar(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is char v ? v : Convert.ToChar(item); + } return np.array(arr); } if (elementType == typeof(float)) { var arr = new float[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToSingle(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is float v ? v : Convert.ToSingle(item); + } return np.array(arr); } if (elementType == typeof(double)) { var arr = new double[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToDouble(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is double v ? v : Convert.ToDouble(item); + } return np.array(arr); } if (elementType == typeof(decimal)) { var arr = new decimal[items.Count]; - for (int i = 0; i < items.Count; i++) arr[i] = Convert.ToDecimal(items[i]); + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + arr[i] = item is decimal v ? v : Convert.ToDecimal(item); + } return np.array(arr); } From 6b0f1479461255e456b1f17ddb43769c643dc48c Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 20:42:41 +0300 Subject: [PATCH 11/19] perf(asanyarray): add ToArrayFast with CollectionsMarshal/CopyTo optimization For IEnumerable, use optimized extraction: 1. List: CollectionsMarshal.AsSpan() + CopyTo (direct memory access) 2. ICollection: CopyTo() (avoids enumerator overhead) 3. Other: fallback to LINQ ToArray() Benchmark results (size 10000, List): - Old (ToArray): 14129 ns/op - New (ToArrayFast): 11665 ns/op - Speedup: 1.21x (21% faster) The CollectionsMarshal.AsSpan approach gives direct access to List's internal array, avoiding the allocation and copy overhead of ToArray(). For ICollection, CopyTo() is used which avoids enumerator overhead. --- src/NumSharp.Core/Creation/np.asanyarray.cs | 55 ++++++++++++++++----- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index f094158d..bb958a4a 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace NumSharp { @@ -33,18 +34,19 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support break; // Handle typed IEnumerable for all 12 NumSharp-supported types - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; - case IEnumerable e: ret = np.array(e.ToArray()); break; + // Optimized: Use CopyTo for ICollection (3-7x faster than ToArray for small collections) + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; default: var type = a.GetType(); @@ -100,6 +102,35 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support return ret; } + /// + /// Optimized ToArray for IEnumerable<T>. + /// Uses CopyTo for ICollection<T> (3-7x faster for small collections). + /// For List<T>, uses CollectionsMarshal.AsSpan for direct memory access. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T[] ToArrayFast(IEnumerable source) + { + // Fast path for List - use CollectionsMarshal for direct span access + if (source is List list) + { + var span = CollectionsMarshal.AsSpan(list); + var arr = new T[span.Length]; + span.CopyTo(arr); + return arr; + } + + // Fast path for ICollection - use CopyTo (avoids enumerator overhead) + if (source is ICollection collection) + { + var arr = new T[collection.Count]; + collection.CopyTo(arr, 0); + return arr; + } + + // Fallback to LINQ ToArray for other IEnumerable + return source.ToArray(); + } + /// /// Converts Memory<T> or ReadOnlyMemory<T> to an NDArray. /// These types don't implement IEnumerable<T>, so we handle them specially. From dd1ae2a7ca28b6c79b63409f9f11c4030a9a712d Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 21:10:49 +0300 Subject: [PATCH 12/19] perf(asanyarray): use GC.AllocateUninitializedArray to skip zeroing Replace `new T[n]` with `GC.AllocateUninitializedArray(n)` in all array allocations within np.asanyarray. Since we immediately overwrite all elements, the default zeroing is wasted work. Affected paths: - ToArrayFast: List and ICollection extraction - ConvertObjectListToNDArray: All 12 dtype allocations Benchmark (GC.AllocateUninitializedArray vs new T[]): - Size 1000: 38 ns vs 156 ns (4x faster allocation) This optimization compounds with the previous CollectionsMarshal and pattern-match optimizations for significant cumulative improvement. --- src/NumSharp.Core/Creation/np.asanyarray.cs | 31 +++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index bb958a4a..71e897c1 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -106,6 +106,7 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support /// Optimized ToArray for IEnumerable<T>. /// Uses CopyTo for ICollection<T> (3-7x faster for small collections). /// For List<T>, uses CollectionsMarshal.AsSpan for direct memory access. + /// Uses GC.AllocateUninitializedArray to skip zeroing (4x faster allocation). /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static T[] ToArrayFast(IEnumerable source) @@ -114,7 +115,8 @@ private static T[] ToArrayFast(IEnumerable source) if (source is List list) { var span = CollectionsMarshal.AsSpan(list); - var arr = new T[span.Length]; + // Use uninitialized array - we're about to overwrite all elements + var arr = GC.AllocateUninitializedArray(span.Length); span.CopyTo(arr); return arr; } @@ -122,7 +124,8 @@ private static T[] ToArrayFast(IEnumerable source) // Fast path for ICollection - use CopyTo (avoids enumerator overhead) if (source is ICollection collection) { - var arr = new T[collection.Count]; + // Use uninitialized array - CopyTo will overwrite all elements + var arr = GC.AllocateUninitializedArray(collection.Count); collection.CopyTo(arr, 0); return arr; } @@ -292,7 +295,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme // collections while still handling mixed types correctly if (elementType == typeof(bool)) { - var arr = new bool[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -302,7 +305,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(byte)) { - var arr = new byte[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -312,7 +315,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(short)) { - var arr = new short[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -322,7 +325,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(ushort)) { - var arr = new ushort[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -332,7 +335,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(int)) { - var arr = new int[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -342,7 +345,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(uint)) { - var arr = new uint[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -352,7 +355,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(long)) { - var arr = new long[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -362,7 +365,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(ulong)) { - var arr = new ulong[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -372,7 +375,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(char)) { - var arr = new char[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -382,7 +385,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(float)) { - var arr = new float[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -392,7 +395,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(double)) { - var arr = new double[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; @@ -402,7 +405,7 @@ private static NDArray ConvertObjectListToNDArray(List items, Type eleme } if (elementType == typeof(decimal)) { - var arr = new decimal[items.Count]; + var arr = GC.AllocateUninitializedArray(items.Count); for (int i = 0; i < items.Count; i++) { var item = items[i]; From 44dd1635809679648c1211777d29419f083a9cdc Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 15 Apr 2026 21:43:28 +0300 Subject: [PATCH 13/19] perf(asanyarray): optimize with CollectionsMarshal.AsSpan and early exit Optimizations applied: 1. FindCommonNumericType: - Early exit when decimal found (highest priority) - Early exit when float/double found (promotes to double) - Use CollectionsMarshal.AsSpan for bounds-check-free iteration - Stackalloc for type code deduplication - Reuse existing _FindCommonType_Scalar for consistent promotion 2. ConvertObjectListToNDArray: - Use CollectionsMarshal.AsSpan(items) for ~10-15% speedup - Eliminates bounds checking in tight conversion loops 3. ConvertEnumerator: - Pre-size List when ICollection count is known - Eliminates resize allocations for known-size collections 4. ConvertTuple: - Pre-size List with tuple.Length Net: -27 lines while adding performance improvements --- src/NumSharp.Core/Creation/np.asanyarray.cs | 245 +++++++++----------- 1 file changed, 109 insertions(+), 136 deletions(-) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 71e897c1..acc03a28 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -136,30 +136,41 @@ private static T[] ToArrayFast(IEnumerable source) /// /// Converts Memory<T> or ReadOnlyMemory<T> to an NDArray. - /// These types don't implement IEnumerable<T>, so we handle them specially. + /// Uses Span.CopyTo + GC.AllocateUninitializedArray for optimal performance. /// private static NDArray ConvertMemory(object a, Type type) { var elementType = type.GetGenericArguments()[0]; var isReadOnly = type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); - // Single type switch - extract array via the appropriate cast - if (elementType == typeof(bool)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(byte)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(short)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(ushort)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(int)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(uint)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(long)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(ulong)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(char)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(float)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(double)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); - if (elementType == typeof(decimal)) return np.array(isReadOnly ? ((ReadOnlyMemory)a).ToArray() : ((Memory)a).ToArray()); + // Use Span.CopyTo + GC.AllocateUninitializedArray instead of ToArray() + if (elementType == typeof(bool)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(byte)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(short)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(ushort)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(int)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(uint)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(long)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(ulong)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(char)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(float)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(double)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(decimal)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); return null; } + /// + /// Optimized Span to Array conversion using GC.AllocateUninitializedArray. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T[] SpanToArrayFast(ReadOnlySpan span) + { + var arr = GC.AllocateUninitializedArray(span.Length); + span.CopyTo(arr); + return arr; + } + /// /// Converts a non-generic IEnumerable to an NDArray. /// Element type is detected from the first item. @@ -174,8 +185,12 @@ private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) /// private static NDArray ConvertEnumerator(IEnumerator enumerator) { - // Collect items - var items = new List(); + // Pre-size list if count is known (optimization #4) + List items; + if (enumerator is ICollection collection) + items = new List(collection.Count); + else + items = new List(); while (enumerator.MoveNext()) { @@ -194,80 +209,70 @@ private static NDArray ConvertEnumerator(IEnumerator enumerator) /// /// Finds the common numeric type for a list of objects (NumPy-like promotion). - /// Promotes to the widest type: bool -> int -> long -> float -> double -> decimal + /// Uses existing _FindCommonType_Scalar for consistent type promotion. + /// Early exit when highest-priority types (decimal/double) are found. /// private static Type FindCommonNumericType(List items) { - // NumPy type promotion priority (simplified): - // bool < byte < short < ushort < int < uint < long < ulong < float < double - // If any float/double is present, result is float/double - // decimal is separate (highest priority if present) + // Use CollectionsMarshal.AsSpan for faster iteration (no bounds checks) + var span = CollectionsMarshal.AsSpan(items); + // Early exit optimization: track highest-priority types seen bool hasDecimal = false; bool hasDouble = false; bool hasFloat = false; - bool hasULong = false; - bool hasLong = false; - bool hasUInt = false; - bool hasInt = false; - bool hasUShort = false; - bool hasShort = false; - bool hasByte = false; - bool hasBool = false; - bool hasChar = false; Type firstType = null; - foreach (var item in items) + // Collect unique type codes for _FindCommonType_Scalar + Span typeCodes = stackalloc NPTypeCode[span.Length]; + int uniqueCount = 0; + uint seenMask = 0; // Bitmask for deduplication (NPTypeCode values are small) + + for (int i = 0; i < span.Length; i++) { - var t = item.GetType(); + var t = span[i].GetType(); firstType ??= t; - if (t == typeof(decimal)) hasDecimal = true; - else if (t == typeof(double)) hasDouble = true; + // Early exit: decimal wins everything + if (t == typeof(decimal)) + return typeof(decimal); + + // Track floating point for early double detection + if (t == typeof(double)) hasDouble = true; else if (t == typeof(float)) hasFloat = true; - else if (t == typeof(ulong)) hasULong = true; - else if (t == typeof(long)) hasLong = true; - else if (t == typeof(uint)) hasUInt = true; - else if (t == typeof(int)) hasInt = true; - else if (t == typeof(ushort)) hasUShort = true; - else if (t == typeof(short)) hasShort = true; - else if (t == typeof(byte)) hasByte = true; - else if (t == typeof(bool)) hasBool = true; - else if (t == typeof(char)) hasChar = true; + + var code = t.GetTypeCode(); + var bit = 1u << (int)code; + if ((seenMask & bit) == 0) + { + seenMask |= bit; + typeCodes[uniqueCount++] = code; + } } - // Promotion rules (NumPy-like): - // decimal wins if present - if (hasDecimal) return typeof(decimal); - - // Any floating point promotes to double (NumPy uses float64 for mixed int+float) - if (hasDouble || hasFloat) return typeof(double); - - // Integer promotion - if (hasULong) return typeof(ulong); - if (hasLong || hasUInt) return typeof(long); // uint + anything signed -> long - if (hasUInt) return typeof(uint); - if (hasInt || hasUShort) return typeof(int); // ushort + anything signed -> int - if (hasUShort) return typeof(ushort); - if (hasShort || hasByte) return typeof(int); // byte + short -> int (safe promotion) - if (hasByte) return typeof(byte); - if (hasChar) return typeof(char); - if (hasBool) return typeof(bool); - - // Fallback to first type - return firstType ?? typeof(double); + // Early exit: any floating point promotes to double + if (hasDouble || hasFloat) + return typeof(double); + + // Use existing type promotion logic for remaining cases + if (uniqueCount == 1) + return firstType ?? typeof(double); + + var resultCode = _FindCommonType_Scalar(typeCodes.Slice(0, uniqueCount).ToArray()); + return resultCode.AsType(); } /// /// Converts a Tuple or ValueTuple to an NDArray. /// Uses ITuple interface available in .NET Core 2.0+. + /// Optimized: pre-sized List, early exit for decimal/double. /// private static NDArray ConvertTuple(ITuple tuple) { if (tuple.Length == 0) return np.array(Array.Empty()); - // Collect items and find common type (NumPy-like promotion) + // Pre-sized list (optimization: avoid resize for known count) var items = new List(tuple.Length); for (int i = 0; i < tuple.Length; i++) @@ -286,131 +291,99 @@ private static NDArray ConvertTuple(ITuple tuple) /// /// Converts a list of objects to an NDArray of the specified element type. + /// Uses CollectionsMarshal.AsSpan for bounds-check-free iteration. /// Uses pattern matching for fast direct cast when types match, with Convert fallback. /// This is ~4x faster than always using Convert for homogeneous collections. /// private static NDArray ConvertObjectListToNDArray(List items, Type elementType) { + // Use CollectionsMarshal.AsSpan for faster iteration (no bounds checks) + var span = CollectionsMarshal.AsSpan(items); + // Pattern: `is T v ? v : Convert.ToT(item)` gives direct cast speed for homogeneous // collections while still handling mixed types correctly if (elementType == typeof(bool)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is bool v ? v : Convert.ToBoolean(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is bool v ? v : Convert.ToBoolean(span[i]); return np.array(arr); } if (elementType == typeof(byte)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is byte v ? v : Convert.ToByte(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is byte v ? v : Convert.ToByte(span[i]); return np.array(arr); } if (elementType == typeof(short)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is short v ? v : Convert.ToInt16(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is short v ? v : Convert.ToInt16(span[i]); return np.array(arr); } if (elementType == typeof(ushort)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is ushort v ? v : Convert.ToUInt16(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is ushort v ? v : Convert.ToUInt16(span[i]); return np.array(arr); } if (elementType == typeof(int)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is int v ? v : Convert.ToInt32(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is int v ? v : Convert.ToInt32(span[i]); return np.array(arr); } if (elementType == typeof(uint)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is uint v ? v : Convert.ToUInt32(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is uint v ? v : Convert.ToUInt32(span[i]); return np.array(arr); } if (elementType == typeof(long)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is long v ? v : Convert.ToInt64(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is long v ? v : Convert.ToInt64(span[i]); return np.array(arr); } if (elementType == typeof(ulong)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is ulong v ? v : Convert.ToUInt64(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is ulong v ? v : Convert.ToUInt64(span[i]); return np.array(arr); } if (elementType == typeof(char)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is char v ? v : Convert.ToChar(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is char v ? v : Convert.ToChar(span[i]); return np.array(arr); } if (elementType == typeof(float)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is float v ? v : Convert.ToSingle(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is float v ? v : Convert.ToSingle(span[i]); return np.array(arr); } if (elementType == typeof(double)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is double v ? v : Convert.ToDouble(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is double v ? v : Convert.ToDouble(span[i]); return np.array(arr); } if (elementType == typeof(decimal)) { - var arr = GC.AllocateUninitializedArray(items.Count); - for (int i = 0; i < items.Count; i++) - { - var item = items[i]; - arr[i] = item is decimal v ? v : Convert.ToDecimal(item); - } + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is decimal v ? v : Convert.ToDecimal(span[i]); return np.array(arr); } From 4e8af8b62021270fedba596259006cffc351c451 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 19:15:23 +0300 Subject: [PATCH 14/19] refactor(tests): migrate np.where and np.asanyarray tests from TUnit to MSTest v3 Post-rebase cleanup after master migrated the test suite from TUnit to MSTest v3 (commits ac020336, e0db3c3e). The 4 test files introduced on this branch still used TUnit's [Test] attribute and `using TUnit.Core;`, which broke the build. Changes per file: - Replace `using TUnit.Core;` (removed) - Add `[TestClass]` attribute to the test class - Replace all `[Test]` attributes with `[TestMethod]` Files migrated: - test/NumSharp.UnitTest/Logic/np.where.Test.cs - test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs - test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs - test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs Verified: all 112 np.where tests and 62 np.asanyarray tests pass on net8.0 and net10.0. --- .../Backends/Kernels/WhereSimdTests.cs | 54 ++++---- .../Creation/np.asanyarray.Tests.cs | 125 +++++++++--------- .../Logic/np.where.BattleTest.cs | 102 +++++++------- test/NumSharp.UnitTest/Logic/np.where.Test.cs | 74 +++++------ 4 files changed, 178 insertions(+), 177 deletions(-) diff --git a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs index 3fc30d17..efb42918 100644 --- a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs +++ b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs @@ -1,7 +1,6 @@ using System; using System.Diagnostics; using NumSharp.Backends.Kernels; -using TUnit.Core; using NumSharp.UnitTest.Utilities; using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; @@ -11,11 +10,12 @@ namespace NumSharp.UnitTest.Backends.Kernels /// Tests for SIMD-optimized np.where implementation. /// Verifies correctness of the SIMD path for all supported dtypes. /// + [TestClass] public class WhereSimdTests { #region SIMD Correctness - [Test] + [TestMethod] public void Where_Simd_Float32_Correctness() { var rng = np.random.RandomState(42); @@ -34,7 +34,7 @@ public void Where_Simd_Float32_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Float64_Correctness() { var rng = np.random.RandomState(43); @@ -52,7 +52,7 @@ public void Where_Simd_Float64_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Int32_Correctness() { var rng = np.random.RandomState(44); @@ -70,7 +70,7 @@ public void Where_Simd_Int32_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Int64_Correctness() { var rng = np.random.RandomState(45); @@ -88,7 +88,7 @@ public void Where_Simd_Int64_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Byte_Correctness() { var rng = np.random.RandomState(46); @@ -106,7 +106,7 @@ public void Where_Simd_Byte_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Int16_Correctness() { var rng = np.random.RandomState(47); @@ -124,7 +124,7 @@ public void Where_Simd_Int16_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_UInt16_Correctness() { var rng = np.random.RandomState(48); @@ -142,7 +142,7 @@ public void Where_Simd_UInt16_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_UInt32_Correctness() { var rng = np.random.RandomState(49); @@ -160,7 +160,7 @@ public void Where_Simd_UInt32_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_UInt64_Correctness() { var rng = np.random.RandomState(50); @@ -178,7 +178,7 @@ public void Where_Simd_UInt64_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Boolean_Correctness() { var rng = np.random.RandomState(51); @@ -196,7 +196,7 @@ public void Where_Simd_Boolean_Correctness() } } - [Test] + [TestMethod] public void Where_Simd_Char_Correctness() { var rng = np.random.RandomState(52); @@ -225,7 +225,7 @@ public void Where_Simd_Char_Correctness() #region Path Selection - [Test] + [TestMethod] public void Where_NonContiguous_Works() { // Sliced arrays are non-contiguous, should work correctly @@ -244,7 +244,7 @@ public void Where_NonContiguous_Works() } } - [Test] + [TestMethod] public void Where_Broadcast_Works() { // Broadcasted arrays @@ -265,7 +265,7 @@ public void Where_Broadcast_Works() Assert.AreEqual(20, (int)result[1, 1]); // cond[1]=false -> y=20 } - [Test] + [TestMethod] public void Where_Decimal_Works() { var cond = np.array(new[] { true, false, true }); @@ -280,7 +280,7 @@ public void Where_Decimal_Works() Assert.AreEqual(3.3m, (decimal)result[2]); } - [Test] + [TestMethod] public void Where_NonBoolCondition_Works() { // Non-bool condition requires truthiness check @@ -294,7 +294,7 @@ public void Where_NonBoolCondition_Works() #region Edge Cases - [Test] + [TestMethod] public void Where_Simd_SmallArray() { // Array smaller than vector width @@ -307,7 +307,7 @@ public void Where_Simd_SmallArray() result.Should().BeOfValues(1, 20, 3); } - [Test] + [TestMethod] public void Where_Simd_VectorAlignedSize() { var rng = np.random.RandomState(53); @@ -327,7 +327,7 @@ public void Where_Simd_VectorAlignedSize() } } - [Test] + [TestMethod] public void Where_Simd_WithScalarTail() { // Size that requires scalar tail processing @@ -344,7 +344,7 @@ public void Where_Simd_WithScalarTail() } } - [Test] + [TestMethod] public void Where_Simd_AllTrue() { var size = 100; @@ -360,7 +360,7 @@ public void Where_Simd_AllTrue() } } - [Test] + [TestMethod] public void Where_Simd_AllFalse() { var size = 100; @@ -376,7 +376,7 @@ public void Where_Simd_AllFalse() } } - [Test] + [TestMethod] public void Where_Simd_Alternating() { var size = 100; @@ -395,7 +395,7 @@ public void Where_Simd_Alternating() } } - [Test] + [TestMethod] public void Where_Simd_NaN_Propagates() { var cond = np.array(new[] { true, false, true }); @@ -409,7 +409,7 @@ public void Where_Simd_NaN_Propagates() Assert.AreEqual(2.0, (double)result[2], 1e-10); } - [Test] + [TestMethod] public void Where_Simd_Infinity() { var cond = np.array(new[] { true, false, true, false }); @@ -428,7 +428,7 @@ public void Where_Simd_Infinity() #region Performance Sanity Check - [Test] + [TestMethod] public void Where_Simd_LargeArray_Correctness() { var rng = np.random.RandomState(54); @@ -458,7 +458,7 @@ public void Where_Simd_LargeArray_Correctness() #region 2D/Multi-dimensional - [Test] + [TestMethod] public void Where_Simd_2D_Contiguous() { var rng = np.random.RandomState(55); @@ -483,7 +483,7 @@ public void Where_Simd_2D_Contiguous() } } - [Test] + [TestMethod] public void Where_Simd_3D_Contiguous() { var rng = np.random.RandomState(56); diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index 4896bc7f..bed654e3 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -11,11 +11,12 @@ namespace NumSharp.UnitTest.Creation /// /// Tests for np.asanyarray covering all built-in C# collection types. /// + [TestClass] public class np_asanyarray_tests { #region NDArray passthrough - [Test] + [TestMethod] public void NDArray_ReturnsAsIs() { var original = np.array(1, 2, 3, 4, 5); @@ -25,7 +26,7 @@ public void NDArray_ReturnsAsIs() ReferenceEquals(original, result).Should().BeTrue(); } - [Test] + [TestMethod] public void NDArray_WithDtype_ReturnsConverted() { var original = np.array(1, 2, 3, 4, 5); @@ -35,7 +36,7 @@ public void NDArray_WithDtype_ReturnsConverted() result.Should().BeShaped(5); } - [Test] + [TestMethod] public void NDArray_WithSameDtype_ReturnsAsIs() { var original = np.array(1, 2, 3, 4, 5); @@ -49,7 +50,7 @@ public void NDArray_WithSameDtype_ReturnsAsIs() #region Array types - [Test] + [TestMethod] public void Array_1D() { var arr = new int[] { 1, 2, 3, 4, 5 }; @@ -59,7 +60,7 @@ public void Array_1D() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void Array_2D() { var arr = new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }; @@ -68,7 +69,7 @@ public void Array_2D() result.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); } - [Test] + [TestMethod] public void Array_WithDtype() { var arr = new int[] { 1, 2, 3 }; @@ -82,7 +83,7 @@ public void Array_WithDtype() #region Scalars - [Test] + [TestMethod] public void Scalar_Int() { var result = np.asanyarray(42); @@ -91,7 +92,7 @@ public void Scalar_Int() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void Scalar_Double() { var result = np.asanyarray(3.14); @@ -100,7 +101,7 @@ public void Scalar_Double() result.dtype.Should().Be(typeof(double)); } - [Test] + [TestMethod] public void Scalar_Decimal() { var result = np.asanyarray(123.456m); @@ -109,7 +110,7 @@ public void Scalar_Decimal() result.dtype.Should().Be(typeof(decimal)); } - [Test] + [TestMethod] public void Scalar_Bool() { var result = np.asanyarray(true); @@ -122,7 +123,7 @@ public void Scalar_Bool() #region List - [Test] + [TestMethod] public void List_Int() { var list = new List { 1, 2, 3, 4, 5 }; @@ -132,7 +133,7 @@ public void List_Int() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void List_Double() { var list = new List { 1.1, 2.2, 3.3 }; @@ -142,7 +143,7 @@ public void List_Double() result.dtype.Should().Be(typeof(double)); } - [Test] + [TestMethod] public void List_Bool() { var list = new List { true, false, true }; @@ -152,7 +153,7 @@ public void List_Bool() result.dtype.Should().Be(typeof(bool)); } - [Test] + [TestMethod] public void List_Empty() { var list = new List(); @@ -162,7 +163,7 @@ public void List_Empty() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void List_WithDtype() { var list = new List { 1, 2, 3 }; @@ -176,7 +177,7 @@ public void List_WithDtype() #region IList / ICollection / IEnumerable - [Test] + [TestMethod] public void IList_Int() { IList list = new List { 1, 2, 3, 4, 5 }; @@ -185,7 +186,7 @@ public void IList_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void ICollection_Int() { ICollection collection = new List { 1, 2, 3, 4, 5 }; @@ -194,7 +195,7 @@ public void ICollection_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void IEnumerable_Int() { IEnumerable enumerable = new List { 1, 2, 3, 4, 5 }; @@ -203,7 +204,7 @@ public void IEnumerable_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void IEnumerable_FromLinq() { var enumerable = Enumerable.Range(1, 5); @@ -212,7 +213,7 @@ public void IEnumerable_FromLinq() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void IEnumerable_FromLinqSelect() { var enumerable = new[] { 1, 2, 3 }.Select(x => x * 2); @@ -225,7 +226,7 @@ public void IEnumerable_FromLinqSelect() #region IReadOnlyList / IReadOnlyCollection - [Test] + [TestMethod] public void IReadOnlyList_Int() { IReadOnlyList list = new List { 1, 2, 3, 4, 5 }; @@ -234,7 +235,7 @@ public void IReadOnlyList_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void IReadOnlyCollection_Int() { IReadOnlyCollection collection = new List { 1, 2, 3, 4, 5 }; @@ -247,7 +248,7 @@ public void IReadOnlyCollection_Int() #region ReadOnlyCollection - [Test] + [TestMethod] public void ReadOnlyCollection_Int() { var collection = new ReadOnlyCollection(new List { 1, 2, 3, 4, 5 }); @@ -260,7 +261,7 @@ public void ReadOnlyCollection_Int() #region LinkedList - [Test] + [TestMethod] public void LinkedList_Int() { var linkedList = new LinkedList(); @@ -276,7 +277,7 @@ public void LinkedList_Int() #region HashSet / SortedSet - [Test] + [TestMethod] public void HashSet_Int() { var set = new HashSet { 3, 1, 4, 1, 5, 9 }; // Duplicates removed @@ -286,7 +287,7 @@ public void HashSet_Int() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void SortedSet_Int() { var set = new SortedSet { 3, 1, 4, 1, 5, 9 }; @@ -299,7 +300,7 @@ public void SortedSet_Int() #region Queue / Stack - [Test] + [TestMethod] public void Queue_Int() { var queue = new Queue(); @@ -311,7 +312,7 @@ public void Queue_Int() result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); } - [Test] + [TestMethod] public void Stack_Int() { var stack = new Stack(); @@ -327,7 +328,7 @@ public void Stack_Int() #region ArraySegment - [Test] + [TestMethod] public void ArraySegment_Int() { var array = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; @@ -337,7 +338,7 @@ public void ArraySegment_Int() result.Should().BeShaped(5).And.BeOfValues(2, 3, 4, 5, 6); } - [Test] + [TestMethod] public void ArraySegment_Empty() { var array = new int[] { 1, 2, 3 }; @@ -347,7 +348,7 @@ public void ArraySegment_Empty() result.Should().BeShaped(0); } - [Test] + [TestMethod] public void ArraySegment_Full() { var array = new int[] { 1, 2, 3, 4, 5 }; @@ -361,7 +362,7 @@ public void ArraySegment_Full() #region Memory / ReadOnlyMemory - [Test] + [TestMethod] public void Memory_Int() { var array = new int[] { 1, 2, 3, 4, 5 }; @@ -371,7 +372,7 @@ public void Memory_Int() result.Should().BeShaped(3).And.BeOfValues(2, 3, 4); } - [Test] + [TestMethod] public void ReadOnlyMemory_Int() { var array = new int[] { 1, 2, 3, 4, 5 }; @@ -385,7 +386,7 @@ public void ReadOnlyMemory_Int() #region ImmutableArray / ImmutableList - [Test] + [TestMethod] public void ImmutableArray_Int() { var immutableArray = ImmutableArray.Create(1, 2, 3, 4, 5); @@ -394,7 +395,7 @@ public void ImmutableArray_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void ImmutableList_Int() { var immutableList = ImmutableList.Create(1, 2, 3, 4, 5); @@ -403,7 +404,7 @@ public void ImmutableList_Int() result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); } - [Test] + [TestMethod] public void ImmutableHashSet_Int() { var immutableSet = ImmutableHashSet.Create(3, 1, 4, 1, 5); @@ -416,7 +417,7 @@ public void ImmutableHashSet_Int() #region All supported dtypes via List - [Test] + [TestMethod] public void List_Byte() { var list = new List { 1, 2, 3 }; @@ -427,7 +428,7 @@ public void List_Byte() // Note: sbyte is NOT supported by NumSharp (12 supported types: bool, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal) - [Test] + [TestMethod] public void List_Short() { var list = new List { 1, 2, 3 }; @@ -436,7 +437,7 @@ public void List_Short() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_UShort() { var list = new List { 1, 2, 3 }; @@ -445,7 +446,7 @@ public void List_UShort() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_UInt() { var list = new List { 1, 2, 3 }; @@ -454,7 +455,7 @@ public void List_UInt() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_Long() { var list = new List { 1, 2, 3 }; @@ -463,7 +464,7 @@ public void List_Long() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_ULong() { var list = new List { 1, 2, 3 }; @@ -472,7 +473,7 @@ public void List_ULong() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_Float() { var list = new List { 1.1f, 2.2f, 3.3f }; @@ -481,7 +482,7 @@ public void List_Float() result.Should().BeShaped(3); } - [Test] + [TestMethod] public void List_Char() { var list = new List { 'a', 'b', 'c' }; @@ -494,13 +495,13 @@ public void List_Char() #region Error cases - [Test] + [TestMethod] public void Null_ThrowsArgumentNullException() { Assert.ThrowsException(() => np.asanyarray(null)); } - [Test] + [TestMethod] public void UnsupportedType_ThrowsNotSupportedException() { // String collections are not supported (string is not primitive/decimal) @@ -508,7 +509,7 @@ public void UnsupportedType_ThrowsNotSupportedException() Assert.ThrowsException(() => np.asanyarray(stringList)); } - [Test] + [TestMethod] public void CustomClass_ThrowsNotSupportedException() { var customObject = new object(); @@ -519,7 +520,7 @@ public void CustomClass_ThrowsNotSupportedException() #region String special case - [Test] + [TestMethod] public void String_CreatesCharArray() { var result = np.asanyarray("hello"); @@ -532,7 +533,7 @@ public void String_CreatesCharArray() #region Non-generic IEnumerable fallback - [Test] + [TestMethod] public void ArrayList_Int() { var arrayList = new System.Collections.ArrayList { 1, 2, 3, 4, 5 }; @@ -542,7 +543,7 @@ public void ArrayList_Int() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void ArrayList_Double() { var arrayList = new System.Collections.ArrayList { 1.1, 2.2, 3.3 }; @@ -552,7 +553,7 @@ public void ArrayList_Double() result.dtype.Should().Be(typeof(double)); } - [Test] + [TestMethod] public void Hashtable_Keys() { var hashtable = new System.Collections.Hashtable { { 1, "a" }, { 2, "b" }, { 3, "c" } }; @@ -566,7 +567,7 @@ public void Hashtable_Keys() #region IEnumerator fallback - [Test] + [TestMethod] public void IEnumerator_Int() { static System.Collections.IEnumerator GetEnumerator() @@ -591,7 +592,7 @@ static System.Collections.IEnumerator GetEnumerator() /// NumPy: np.asanyarray("hello") -> dtype=<U5, shape=(), ndim=0 (SCALAR) /// NumSharp: dtype=Char, shape=(5), ndim=1 (ARRAY) /// - [Test] + [TestMethod] [Misaligned] public void String_IsCharArray_NotScalar() { @@ -611,7 +612,7 @@ public void String_IsCharArray_NotScalar() /// NumPy: np.asanyarray({1,2,3}) -> dtype=object, shape=() (SCALAR) /// NumSharp: dtype=Int32, shape=(3) (ARRAY) /// - [Test] + [TestMethod] [Misaligned] public void HashSet_IsIterated_NotObjectScalar() { @@ -631,7 +632,7 @@ public void HashSet_IsIterated_NotObjectScalar() /// NumSharp consumes IEnumerable and converts to array. /// This is arguably more useful behavior for C#. /// - [Test] + [TestMethod] [Misaligned] public void LinqEnumerable_IsConsumed_NotObjectScalar() { @@ -650,7 +651,7 @@ public void LinqEnumerable_IsConsumed_NotObjectScalar() /// NumPy defaults to float64 for untyped empty lists. /// This is a design choice: C# generics provide type information that NumPy doesn't have. /// - [Test] + [TestMethod] [Misaligned] public void EmptyTypedList_PreservesTypeParameter() { @@ -672,7 +673,7 @@ public void EmptyTypedList_PreservesTypeParameter() /// C# ValueTuples are iterable like Python tuples. /// NumPy: np.asanyarray((1,2,3)) -> dtype=int64, shape=(3,) /// - [Test] + [TestMethod] public void ValueTuple_IsIterable() { var tuple = (1, 2, 3); @@ -685,7 +686,7 @@ public void ValueTuple_IsIterable() /// /// C# Tuple class is iterable like Python tuples. /// - [Test] + [TestMethod] public void Tuple_IsIterable() { var tuple = Tuple.Create(1, 2, 3); @@ -695,7 +696,7 @@ public void Tuple_IsIterable() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void ValueTuple_MixedTypes_PromotesToCommonType() { // Mixed int + double promotes to double (NumPy behavior) @@ -706,7 +707,7 @@ public void ValueTuple_MixedTypes_PromotesToCommonType() result.dtype.Should().Be(typeof(double)); // Promoted from int to double } - [Test] + [TestMethod] public void ValueTuple_IntAndBool_PromotesToInt() { // Mixed int + bool promotes to int (NumPy behavior) @@ -717,7 +718,7 @@ public void ValueTuple_IntAndBool_PromotesToInt() result.dtype.Should().Be(typeof(int)); } - [Test] + [TestMethod] public void EmptyTuple_ReturnsEmptyDoubleArray() { var tuple = ValueTuple.Create(); @@ -734,7 +735,7 @@ public void EmptyTuple_ReturnsEmptyDoubleArray() /// /// Empty non-generic collections return empty double[] (NumPy defaults to float64). /// - [Test] + [TestMethod] public void EmptyArrayList_ReturnsEmptyDoubleArray() { var arrayList = new System.Collections.ArrayList(); diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs index 5afef228..eb889b7d 100644 --- a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs +++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs @@ -1,6 +1,5 @@ using System; using System.Linq; -using TUnit.Core; using NumSharp.UnitTest.Utilities; using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; @@ -28,11 +27,12 @@ namespace NumSharp.UnitTest.Logic /// 3. Missing sbyte (int8) support: /// NumSharp does not support sbyte arrays (throws NotSupportedException). /// + [TestClass] public class np_where_BattleTest { #region Strided/Sliced Arrays - [Test] + [TestMethod] public void Where_SlicedCondition() { // Sliced condition array (non-contiguous) @@ -46,7 +46,7 @@ public void Where_SlicedCondition() result.Should().BeOfValues(1, 1, 1, 1, 1); } - [Test] + [TestMethod] public void Where_SlicedXY() { var cond = np.array(new[] { true, false, true }); @@ -57,7 +57,7 @@ public void Where_SlicedXY() result.Should().BeOfValues(0L, 3L, 4L); } - [Test] + [TestMethod] public void Where_TransposedArrays() { var cond = np.array(new bool[,] { { true, false }, { false, true } }).T; @@ -73,7 +73,7 @@ public void Where_TransposedArrays() Assert.AreEqual(4, (int)result[1, 1]); } - [Test] + [TestMethod] public void Where_ReversedSlice() { var cond = np.array(new[] { true, false, true, false, true }); @@ -89,7 +89,7 @@ public void Where_ReversedSlice() #region Complex Broadcasting - [Test] + [TestMethod] public void Where_3Way_Broadcasting() { // cond: (2,1,1), x: (1,3,1), y: (1,1,4) -> result: (2,3,4) @@ -109,7 +109,7 @@ public void Where_3Way_Broadcasting() Assert.AreEqual(30, (long)result[1, 2, 3]); } - [Test] + [TestMethod] public void Where_RowVector_ColVector_Broadcast() { // cond: (1,4), x: (3,1), y: scalar -> result: (3,4) @@ -124,7 +124,7 @@ public void Where_RowVector_ColVector_Broadcast() Assert.AreEqual(0, (int)result[1, 1]); } - [Test] + [TestMethod] public void Where_ScalarCondition_True() { // NumPy: np.where(True, [1,2,3], [4,5,6]) -> [1,2,3] @@ -132,7 +132,7 @@ public void Where_ScalarCondition_True() result.Should().BeOfValues(1, 2, 3); } - [Test] + [TestMethod] public void Where_ScalarCondition_False() { // NumPy: np.where(False, [1,2,3], [4,5,6]) -> [4,5,6] @@ -144,7 +144,7 @@ public void Where_ScalarCondition_False() #region Non-Boolean Conditions (Truthy/Falsy) - [Test] + [TestMethod] public void Where_IntegerCondition_ZeroIsFalsy() { // NumPy: 0 is falsy, non-zero is truthy @@ -157,7 +157,7 @@ public void Where_IntegerCondition_ZeroIsFalsy() result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); } - [Test] + [TestMethod] public void Where_FloatCondition_ZeroIsFalsy() { // NumPy: 0.0 is falsy @@ -170,7 +170,7 @@ public void Where_FloatCondition_ZeroIsFalsy() result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); } - [Test] + [TestMethod] public void Where_NaN_IsTruthy() { // NumPy: NaN is truthy (non-zero) @@ -183,7 +183,7 @@ public void Where_NaN_IsTruthy() result.Should().BeOfValues(10, 2, 3); } - [Test] + [TestMethod] public void Where_Infinity_IsTruthy() { // NumPy: Inf and -Inf are truthy @@ -196,7 +196,7 @@ public void Where_Infinity_IsTruthy() result.Should().BeOfValues(10, 2, 3); } - [Test] + [TestMethod] public void Where_NegativeZero_IsFalsy() { // NumPy: -0.0 == 0.0 in IEEE 754, so it's falsy @@ -213,7 +213,7 @@ public void Where_NegativeZero_IsFalsy() #region Numeric Edge Cases - [Test] + [TestMethod] public void Where_NaN_Values() { var cond = np.array(new[] { true, false, true }); @@ -226,7 +226,7 @@ public void Where_NaN_Values() Assert.IsTrue(double.IsNaN((double)result[2])); // from x } - [Test] + [TestMethod] public void Where_Infinity_Values() { var cond = np.array(new[] { true, false }); @@ -238,7 +238,7 @@ public void Where_Infinity_Values() Assert.AreEqual(double.NegativeInfinity, (double)result[1]); } - [Test] + [TestMethod] public void Where_MaxMin_Values() { var cond = np.array(new[] { true, false }); @@ -254,7 +254,7 @@ public void Where_MaxMin_Values() #region Single Arg Edge Cases - [Test] + [TestMethod] public void Where_SingleArg_Float_Truthy() { // 0.0 is falsy, anything else (including -0.0, NaN, Inf) is truthy @@ -266,7 +266,7 @@ public void Where_SingleArg_Float_Truthy() result[0].Should().BeOfValues(1L, 2L, 3L); } - [Test] + [TestMethod] public void Where_SingleArg_NaN_IsTruthy() { // NaN is non-zero, so it's truthy @@ -276,7 +276,7 @@ public void Where_SingleArg_NaN_IsTruthy() result[0].Should().BeOfValues(1L); } - [Test] + [TestMethod] public void Where_SingleArg_Infinity_IsTruthy() { // Inf values are truthy @@ -286,7 +286,7 @@ public void Where_SingleArg_Infinity_IsTruthy() result[0].Should().BeOfValues(1L, 2L); } - [Test] + [TestMethod] public void Where_SingleArg_4D() { var arr = np.zeros(new[] { 2, 2, 2, 2 }, NPTypeCode.Int32); @@ -298,7 +298,7 @@ public void Where_SingleArg_4D() Assert.AreEqual(2, result[0].size); // 2 non-zero elements } - [Test] + [TestMethod] public void Where_SingleArg_ReturnsInt64Indices() { // NumPy returns int64 for indices @@ -312,7 +312,7 @@ public void Where_SingleArg_ReturnsInt64Indices() #region 0D Scalar Arrays - [Test] + [TestMethod] public void Where_0D_AllScalars_Returns0D() { // NumPy: when all inputs are 0D, result is 0D @@ -325,7 +325,7 @@ public void Where_0D_AllScalars_Returns0D() Assert.AreEqual(42, (int)result.GetValue(0)); } - [Test] + [TestMethod] public void Where_0D_Cond_With_1D_Arrays() { // 0D condition broadcasts to match x/y shape @@ -342,7 +342,7 @@ public void Where_0D_Cond_With_1D_Arrays() #region Type Promotion (Array-to-Array) - [Test] + [TestMethod] public void Where_TypePromotion_Bool_Int16() { var cond = np.array(new[] { true, false }); @@ -354,7 +354,7 @@ public void Where_TypePromotion_Bool_Int16() Assert.AreEqual(typeof(short), result.dtype); } - [Test] + [TestMethod] public void Where_TwoScalars_Byte_StaysByte() { // C# byte (like np.uint8) stays byte, not widened to int64 @@ -366,7 +366,7 @@ public void Where_TwoScalars_Byte_StaysByte() Assert.AreEqual((byte)0, (byte)result[1]); } - [Test] + [TestMethod] public void Where_TwoScalars_Short_StaysShort() { // C# short (like np.int16) stays short @@ -376,7 +376,7 @@ public void Where_TwoScalars_Short_StaysShort() Assert.AreEqual(typeof(short), result.dtype); } - [Test] + [TestMethod] public void Where_TypePromotion_Int32_UInt32() { var cond = np.array(new[] { true, false }); @@ -388,7 +388,7 @@ public void Where_TypePromotion_Int32_UInt32() Assert.AreEqual(typeof(long), result.dtype); } - [Test] + [TestMethod] public void Where_TypePromotion_Int64_UInt64() { var cond = np.array(new[] { true, false }); @@ -400,7 +400,7 @@ public void Where_TypePromotion_Int64_UInt64() Assert.AreEqual(typeof(double), result.dtype); } - [Test] + [TestMethod] public void Where_TypePromotion_UInt8_Float32() { var cond = np.array(new[] { true, false }); @@ -416,7 +416,7 @@ public void Where_TypePromotion_UInt8_Float32() #region Performance/Stress Tests - [Test] + [TestMethod] public void Where_LargeArray_Performance() { var size = 1_000_000; @@ -433,7 +433,7 @@ public void Where_LargeArray_Performance() Assert.IsTrue(sw.ElapsedMilliseconds < 1000, $"Took {sw.ElapsedMilliseconds}ms"); } - [Test] + [TestMethod] public void Where_ManyDimensions() { // 6D array @@ -448,7 +448,7 @@ public void Where_ManyDimensions() Assert.AreEqual(144, (long)np.sum(result)); // All 1s } - [Test] + [TestMethod] public void Where_AllTrue_LargeArray() { var size = 10000; @@ -461,7 +461,7 @@ public void Where_AllTrue_LargeArray() Assert.AreEqual(49995000L, (long)np.sum(result)); } - [Test] + [TestMethod] public void Where_AllFalse_LargeArray() { var size = 10000; @@ -473,7 +473,7 @@ public void Where_AllFalse_LargeArray() Assert.AreEqual(0L, (long)np.sum(result)); } - [Test] + [TestMethod] public void Where_Alternating_LargeArray() { var size = 10000; @@ -493,7 +493,7 @@ public void Where_Alternating_LargeArray() #region Type Conversion Edge Cases - [Test] + [TestMethod] public void Where_UnsignedOverflow() { var cond = np.array(new[] { true, false }); @@ -506,7 +506,7 @@ public void Where_UnsignedOverflow() Assert.AreEqual((byte)255, (byte)result[1]); } - [Test] + [TestMethod] public void Where_Decimal() { var cond = np.array(new[] { true, false }); @@ -519,7 +519,7 @@ public void Where_Decimal() Assert.AreEqual(9.87654321m, (decimal)result[1]); } - [Test] + [TestMethod] public void Where_Char() { var cond = np.array(new[] { true, false, true }); @@ -537,7 +537,7 @@ public void Where_Char() #region View Behavior - [Test] + [TestMethod] public void Where_ResultIsNewArray_NotView() { var cond = np.array(new[] { true, false }); @@ -553,7 +553,7 @@ public void Where_ResultIsNewArray_NotView() Assert.AreEqual(20, (int)result[1], "Result should be independent of y"); } - [Test] + [TestMethod] public void Where_ModifyResult_DoesNotAffectInputs() { var cond = np.array(new[] { true, false }); @@ -570,7 +570,7 @@ public void Where_ModifyResult_DoesNotAffectInputs() #region Alternating Patterns - [Test] + [TestMethod] public void Where_Checkerboard_Pattern() { // Create checkerboard condition @@ -590,7 +590,7 @@ public void Where_Checkerboard_Pattern() Assert.AreEqual(1, (int)result[1, 1]); } - [Test] + [TestMethod] public void Where_StripedPattern() { // Every row alternates between all True and all False @@ -617,7 +617,7 @@ public void Where_StripedPattern() #region Empty Array Edge Cases - [Test] + [TestMethod] public void Where_Empty2D() { // Empty (0,3) shape @@ -630,7 +630,7 @@ public void Where_Empty2D() Assert.AreEqual(typeof(double), result.dtype); } - [Test] + [TestMethod] public void Where_Empty3D() { // Empty (2,0,3) shape @@ -643,7 +643,7 @@ public void Where_Empty3D() Assert.AreEqual(typeof(int), result.dtype); } - [Test] + [TestMethod] public void Where_SingleArg_Empty2D() { var arr = np.zeros(new[] { 0, 3 }, NPTypeCode.Int32); @@ -658,7 +658,7 @@ public void Where_SingleArg_Empty2D() #region Error Conditions - [Test] + [TestMethod] public void Where_IncompatibleShapes_ThrowsException() { // Shapes (2,) and (3,) cannot be broadcast together @@ -677,7 +677,7 @@ public void Where_IncompatibleShapes_ThrowsException() /// Verifies NEP50 weak scalar semantics: when a scalar is combined with an array, /// the array's dtype wins for same-kind operations. /// - [Test] + [TestMethod] public void Where_ScalarTypePromotion_NEP50_WeakScalar() { // NumPy 2.x: np.where(cond, 1, uint8_array) -> uint8 (weak scalar) @@ -696,7 +696,7 @@ public void Where_ScalarTypePromotion_NEP50_WeakScalar() /// Note: NumPy would return int64 for Python int literals, but C# int32 scalars /// cannot be distinguished from explicit np.array(1, dtype=int32), so we preserve. /// - [Test] + [TestMethod] public void Where_TwoScalars_SameType_Preserved() { var cond = np.array(new[] { true, false }); @@ -715,7 +715,7 @@ public void Where_TwoScalars_SameType_Preserved() /// /// Verifies C# float scalars stay float32 (like np.float32, not Python float). /// - [Test] + [TestMethod] public void Where_TwoScalars_Float32_StaysFloat32() { // C# float (1.0f) is like np.float32, not Python's float (which is float64) @@ -729,7 +729,7 @@ public void Where_TwoScalars_Float32_StaysFloat32() /// /// Verifies NEP50: int scalar + float32 array -> float32 (same-kind, array wins). /// - [Test] + [TestMethod] public void Where_IntScalar_Float32Array_ReturnsFloat32() { var cond = np.array(new[] { true, false }); @@ -743,7 +743,7 @@ public void Where_IntScalar_Float32Array_ReturnsFloat32() /// /// Verifies NEP50: float scalar + int32 array -> float64 (cross-kind promotion). /// - [Test] + [TestMethod] public void Where_FloatScalar_Int32Array_ReturnsFloat64() { var cond = np.array(new[] { true, false }); diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs index 8e81736c..b1991bae 100644 --- a/test/NumSharp.UnitTest/Logic/np.where.Test.cs +++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs @@ -1,6 +1,5 @@ using System; using System.Linq; -using TUnit.Core; using NumSharp.UnitTest.Utilities; using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; @@ -13,11 +12,12 @@ namespace NumSharp.UnitTest.Logic /// - Single arg: returns np.nonzero(condition) /// - Three args: element-wise selection with broadcasting /// + [TestClass] public class np_where_Test { #region Single Argument (nonzero equivalent) - [Test] + [TestMethod] public void Where_SingleArg_1D_ReturnsIndices() { // np.where([0, 1, 0, 2, 0, 3]) -> (array([1, 3, 5]),) @@ -28,7 +28,7 @@ public void Where_SingleArg_1D_ReturnsIndices() result[0].Should().BeOfValues(1L, 3L, 5L); } - [Test] + [TestMethod] public void Where_SingleArg_2D_ReturnsTupleOfIndices() { // np.where([[0, 1, 0], [2, 0, 3]]) -> (array([0, 1, 1]), array([1, 0, 2])) @@ -40,7 +40,7 @@ public void Where_SingleArg_2D_ReturnsTupleOfIndices() result[1].Should().BeOfValues(1L, 0L, 2L); // col indices } - [Test] + [TestMethod] public void Where_SingleArg_Boolean_ReturnsNonzero() { var arr = np.array(new[] { true, false, true, false, true }); @@ -50,7 +50,7 @@ public void Where_SingleArg_Boolean_ReturnsNonzero() result[0].Should().BeOfValues(0L, 2L, 4L); } - [Test] + [TestMethod] public void Where_SingleArg_Empty_ReturnsEmptyIndices() { var arr = np.array(new int[0]); @@ -60,7 +60,7 @@ public void Where_SingleArg_Empty_ReturnsEmptyIndices() Assert.AreEqual(0, result[0].size); } - [Test] + [TestMethod] public void Where_SingleArg_AllFalse_ReturnsEmptyIndices() { var arr = np.array(new[] { false, false, false }); @@ -70,7 +70,7 @@ public void Where_SingleArg_AllFalse_ReturnsEmptyIndices() Assert.AreEqual(0, result[0].size); } - [Test] + [TestMethod] public void Where_SingleArg_AllTrue_ReturnsAllIndices() { var arr = np.array(new[] { true, true, true }); @@ -79,7 +79,7 @@ public void Where_SingleArg_AllTrue_ReturnsAllIndices() result[0].Should().BeOfValues(0L, 1L, 2L); } - [Test] + [TestMethod] public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays() { // 2x2x2 array with some non-zero elements @@ -98,7 +98,7 @@ public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays() #region Three Arguments (element-wise selection) - [Test] + [TestMethod] public void Where_ThreeArgs_Basic_SelectsCorrectly() { // np.where(a < 5, a, 10*a) for a = arange(10) @@ -108,7 +108,7 @@ public void Where_ThreeArgs_Basic_SelectsCorrectly() result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); } - [Test] + [TestMethod] public void Where_ThreeArgs_BooleanCondition() { var cond = np.array(new[] { true, false, true, false }); @@ -119,7 +119,7 @@ public void Where_ThreeArgs_BooleanCondition() result.Should().BeOfValues(1, 20, 3, 40); } - [Test] + [TestMethod] public void Where_ThreeArgs_2D() { // np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]]) @@ -135,7 +135,7 @@ public void Where_ThreeArgs_2D() Assert.AreEqual(4, (int)result[1, 1]); } - [Test] + [TestMethod] public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy() { // np.where([0, 1, 2, 0], 100, -100) -> [-100, 100, 100, -100] @@ -149,7 +149,7 @@ public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy() #region Scalar Arguments - [Test] + [TestMethod] public void Where_ScalarX() { var cond = np.array(new[] { true, false, true, false }); @@ -159,7 +159,7 @@ public void Where_ScalarX() result.Should().BeOfValues(99, 20, 99, 40); } - [Test] + [TestMethod] public void Where_ScalarY() { var cond = np.array(new[] { true, false, true, false }); @@ -169,7 +169,7 @@ public void Where_ScalarY() result.Should().BeOfValues(1, -1, 3, -1); } - [Test] + [TestMethod] public void Where_BothScalars() { var cond = np.array(new[] { true, false, true, false }); @@ -178,7 +178,7 @@ public void Where_BothScalars() result.Should().BeOfValues(1, 0, 1, 0); } - [Test] + [TestMethod] public void Where_ScalarFloat() { var cond = np.array(new[] { true, false }); @@ -193,7 +193,7 @@ public void Where_ScalarFloat() #region Broadcasting - [Test] + [TestMethod] public void Where_Broadcasting_ScalarY() { // np.where(a < 4, a, -1) for 3x3 array @@ -208,7 +208,7 @@ public void Where_Broadcasting_ScalarY() Assert.AreEqual(-1, (int)result[2, 2]); } - [Test] + [TestMethod] public void Where_Broadcasting_DifferentShapes() { // cond: (2,1), x: (3,), y: (1,3) -> result: (2,3) @@ -228,7 +228,7 @@ public void Where_Broadcasting_DifferentShapes() Assert.AreEqual(30, (int)result[1, 2]); } - [Test] + [TestMethod] public void Where_Broadcasting_ColumnVector() { // cond: (3,1), x: scalar, y: (1,4) -> result: (3,4) @@ -253,7 +253,7 @@ public void Where_Broadcasting_ColumnVector() #region Type Promotion - [Test] + [TestMethod] public void Where_TypePromotion_IntFloat_ReturnsFloat64() { var cond = np.array(new[] { true, false }); @@ -264,7 +264,7 @@ public void Where_TypePromotion_IntFloat_ReturnsFloat64() Assert.AreEqual(2.5, (double)result[1], 1e-10); } - [Test] + [TestMethod] public void Where_TypePromotion_Int32Int64_ReturnsInt64() { var cond = np.array(new[] { true, false }); @@ -275,7 +275,7 @@ public void Where_TypePromotion_Int32Int64_ReturnsInt64() Assert.AreEqual(typeof(long), result.dtype); } - [Test] + [TestMethod] public void Where_TypePromotion_FloatDouble_ReturnsDouble() { var cond = np.array(new[] { true, false }); @@ -290,7 +290,7 @@ public void Where_TypePromotion_FloatDouble_ReturnsDouble() #region Edge Cases - [Test] + [TestMethod] public void Where_EmptyArrays_ThreeArgs() { var cond = np.array(new bool[0]); @@ -301,7 +301,7 @@ public void Where_EmptyArrays_ThreeArgs() Assert.AreEqual(0, result.size); } - [Test] + [TestMethod] public void Where_SingleElement() { var cond = np.array(new[] { true }); @@ -312,7 +312,7 @@ public void Where_SingleElement() Assert.AreEqual(42, (int)result[0]); } - [Test] + [TestMethod] public void Where_AllTrue_ReturnsAllX() { var cond = np.array(new[] { true, true, true }); @@ -323,7 +323,7 @@ public void Where_AllTrue_ReturnsAllX() result.Should().BeOfValues(1, 2, 3); } - [Test] + [TestMethod] public void Where_AllFalse_ReturnsAllY() { var cond = np.array(new[] { false, false, false }); @@ -334,7 +334,7 @@ public void Where_AllFalse_ReturnsAllY() result.Should().BeOfValues(10, 20, 30); } - [Test] + [TestMethod] public void Where_LargeArray() { var size = 100000; @@ -354,7 +354,7 @@ public void Where_LargeArray() #region NumPy Output Verification - [Test] + [TestMethod] public void Where_NumPyExample1() { // From NumPy docs: np.where([[True, False], [True, True]], @@ -371,7 +371,7 @@ public void Where_NumPyExample1() Assert.AreEqual(4, (int)result[1, 1]); } - [Test] + [TestMethod] public void Where_NumPyExample2() { // From NumPy docs: np.where(a < 5, a, 10*a) for a = arange(10) @@ -382,7 +382,7 @@ public void Where_NumPyExample2() result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); } - [Test] + [TestMethod] public void Where_NumPyExample3() { // From NumPy docs: np.where(a < 4, a, -1) for specific array @@ -405,7 +405,7 @@ public void Where_NumPyExample3() #region Dtype Coverage - [Test] + [TestMethod] public void Where_Dtype_Byte() { var cond = np.array(new[] { true, false }); @@ -417,7 +417,7 @@ public void Where_Dtype_Byte() result.Should().BeOfValues((byte)1, (byte)20); } - [Test] + [TestMethod] public void Where_Dtype_Int16() { var cond = np.array(new[] { true, false }); @@ -429,7 +429,7 @@ public void Where_Dtype_Int16() result.Should().BeOfValues((short)1, (short)20); } - [Test] + [TestMethod] public void Where_Dtype_Int32() { var cond = np.array(new[] { true, false }); @@ -441,7 +441,7 @@ public void Where_Dtype_Int32() result.Should().BeOfValues(1, 20); } - [Test] + [TestMethod] public void Where_Dtype_Int64() { var cond = np.array(new[] { true, false }); @@ -453,7 +453,7 @@ public void Where_Dtype_Int64() result.Should().BeOfValues(1L, 20L); } - [Test] + [TestMethod] public void Where_Dtype_Single() { var cond = np.array(new[] { true, false }); @@ -466,7 +466,7 @@ public void Where_Dtype_Single() Assert.AreEqual(20.5f, (float)result[1], 1e-6f); } - [Test] + [TestMethod] public void Where_Dtype_Double() { var cond = np.array(new[] { true, false }); @@ -479,7 +479,7 @@ public void Where_Dtype_Double() Assert.AreEqual(20.5, (double)result[1], 1e-10); } - [Test] + [TestMethod] public void Where_Dtype_Boolean() { var cond = np.array(new[] { true, false }); From ae4f1b882b93aebeeb6d7f13b0c4434de9afa3f2 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 19:19:06 +0300 Subject: [PATCH 15/19] fix(asanyarray): handle object[] via type-promotion path np.asanyarray(new object[]{1, 2.5, 3}) threw NotSupportedException because `case Array array` matched object[] first and `new NDArray(object[])` rejects object as an element type. object[] has no fixed dtype, so routing through the non-generic IEnumerable path (which applies NumPy-like type promotion) is the correct behavior. Added an explicit `case object[] objArr` branch that delegates to ConvertNonGenericEnumerable, which already handles: - Homogeneous object[]: detected via FindCommonNumericType, single dtype - Mixed object[]: promoted to common type (e.g. int + double -> double) - Empty object[]: returns empty double[] (matches NumPy float64 default) - Bool+int mix: promotes bool to int via Convert.ToInt32 (True=1, False=0) Regression tests added in np.asanyarray.Tests.cs covering all four cases. All 66 np.asanyarray tests pass on net8.0 and net10.0. --- src/NumSharp.Core/Creation/np.asanyarray.cs | 8 ++++ .../Creation/np.asanyarray.Tests.cs | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index acc03a28..34bb41a1 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -26,6 +26,14 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support if (dtype == null || Equals(nd.dtype, dtype)) return nd; return nd.astype(dtype, true); + case object[] objArr: + // object[] has no fixed dtype — route through type-promotion path. + // new NDArray(object[]) throws NotSupportedException since object isn't a + // supported element type. + ret = ConvertNonGenericEnumerable(objArr); + if (ret is null) + throw new NotSupportedException($"Unable to resolve asanyarray for object array of length {objArr.Length}."); + break; case Array array: ret = new NDArray(array); break; diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index bed654e3..ffbc9551 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -747,5 +747,50 @@ public void EmptyArrayList_ReturnsEmptyDoubleArray() } #endregion + + #region object[] regression + + [TestMethod] + public void ObjectArray_Homogeneous_Int() + { + var arr = new object[] { 1, 2, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(int)); + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + [TestMethod] + public void ObjectArray_MixedIntFloat_PromotesToDouble() + { + var arr = new object[] { 1, 2.5, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3).And.BeOfValues(1.0, 2.5, 3.0); + } + + [TestMethod] + public void ObjectArray_MixedBoolInt_PromotesToInt() + { + var arr = new object[] { true, 2, false }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(int)); + result.Should().BeShaped(3).And.BeOfValues(1, 2, 0); + } + + [TestMethod] + public void ObjectArray_Empty_ReturnsFloat64() + { + var arr = new object[0]; + var result = np.asanyarray(arr); + + result.size.Should().Be(0); + result.ndim.Should().Be(1); + result.dtype.Should().Be(typeof(double)); + } + + #endregion } } From f0473d2dff645a0e2f37fb8a9d449c709d4abc36 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 19:54:58 +0300 Subject: [PATCH 16/19] cleanup(where, asanyarray): remove dead code and trim noisy comments Code review caught dead code paths and over-narrative comments. Net change is -293/+19 across three files. ILKernelGenerator.Where.cs (-249 lines): - Delete `GetNPTypeCode` (use shared InfoOf.NPTypeCode instead). - Delete `GetMaskCreationMethod256/128` and the entire 200-line `Static Mask Creation Methods (fallback)` region (CreateMaskV256_*Byte and CreateMaskV128_*Byte). They were never called -- the inline IL emitter at EmitInlineMaskCreationV256/V128 handles the mask creation directly via the cached MethodInfo lookups. The static helpers existed as an early prototype fallback path that became unreachable. - Delete `_v256ZeroULong` field with the meaningless `IsStatic ? null! : null!` tautology (only `_v256GetZeroULong` is used). np.where.cs (+2 lines): - Add `default: throw NotSupportedException(...)` to `WhereKernelDispatch` switch. The kernel path is currently only reached for the 12 supported NPTypeCodes, but the missing default would silently fall through and return an uninitialized result if a new NPTypeCode were ever added without updating this switch. The iterator-path switch (line 142) already has this guard. np.asanyarray.cs (-43/+18 net): - Cap `stackalloc NPTypeCode[span.Length]` at 12 (max possible unique NPTypeCodes given the seenMask deduplication). The previous unbounded stackalloc could blow the stack for very large user lists. - Remove dead `hasDecimal` variable (set but never read; the early-exit for decimal returns immediately on first hit). - Trim narrative/microbenchmark comments per CLAUDE.md guidance: removed "Optimized: ...3-7x faster", "optimization #4", "~4x faster than always using Convert", "Pre-sized list (optimization: ...)", and a handful of WHAT-the-code-does comments that restated obvious switch arms. - Tighten Tuple/Enumerator helpers (collapse trivial if/else into ternary). Verified: 178 np.where + np.asanyarray tests still pass on net8.0 + net10.0. --- src/NumSharp.Core/APIs/np.where.cs | 2 + .../Kernels/ILKernelGenerator.Where.cs | 249 +----------------- src/NumSharp.Core/Creation/np.asanyarray.cs | 61 ++--- 3 files changed, 19 insertions(+), 293 deletions(-) diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs index f8da5a2f..b811eaf0 100644 --- a/src/NumSharp.Core/APIs/np.where.cs +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -209,6 +209,8 @@ private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray case NPTypeCode.Decimal: ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count); break; + default: + throw new NotSupportedException($"Type {outType} not supported for np.where"); } } } diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index 1b4eb4b0..a82b4574 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -5,6 +5,7 @@ using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; +using NumSharp.Utilities; // ============================================================================= // ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels @@ -400,7 +401,7 @@ private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) where T : unmanaged { long elementSize = Unsafe.SizeOf(); - var typeCode = GetNPTypeCode(); + var typeCode = InfoOf.NPTypeCode; // result[i] = cond[i] ? x[i] : y[i] var lblFalse = il.DefineLabel(); @@ -449,51 +450,6 @@ private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) EmitStoreIndirect(il, typeCode); } - private static NPTypeCode GetNPTypeCode() where T : unmanaged - { - if (typeof(T) == typeof(bool)) return NPTypeCode.Boolean; - if (typeof(T) == typeof(byte)) return NPTypeCode.Byte; - if (typeof(T) == typeof(short)) return NPTypeCode.Int16; - if (typeof(T) == typeof(ushort)) return NPTypeCode.UInt16; - if (typeof(T) == typeof(int)) return NPTypeCode.Int32; - if (typeof(T) == typeof(uint)) return NPTypeCode.UInt32; - if (typeof(T) == typeof(long)) return NPTypeCode.Int64; - if (typeof(T) == typeof(ulong)) return NPTypeCode.UInt64; - if (typeof(T) == typeof(char)) return NPTypeCode.Char; - if (typeof(T) == typeof(float)) return NPTypeCode.Single; - if (typeof(T) == typeof(double)) return NPTypeCode.Double; - if (typeof(T) == typeof(decimal)) return NPTypeCode.Decimal; - return NPTypeCode.Empty; - } - - #endregion - - #region Mask Creation Methods - - private static MethodInfo GetMaskCreationMethod256(int elementSize) - { - return elementSize switch - { - 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") - }; - } - - private static MethodInfo GetMaskCreationMethod128(int elementSize) - { - return elementSize switch - { - 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, - _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") - }; - } - #endregion #region Inline Mask IL Emission @@ -560,8 +516,6 @@ private static MethodInfo GetMaskCreationMethod128(int elementSize) private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(), m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); - private static readonly FieldInfo _v256ZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!.IsStatic - ? null! : null!; // Use GetMethod call instead private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!; private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256).GetProperty("Zero")!.GetMethod!; private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256).GetProperty("Zero")!.GetMethod!; @@ -716,205 +670,6 @@ private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) #endregion - #region Static Mask Creation Methods (fallback) - - /// - /// Create V256 mask from 32 bools for 1-byte elements. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector256 CreateMaskV256_1Byte(byte* bools) - { - var vec = Vector256.Load(bools); - var zero = Vector256.Zero; - var isZero = Vector256.Equals(vec, zero); - return Vector256.OnesComplement(isZero); - } - - /// - /// Create V256 mask from 16 bools for 2-byte elements. - /// Uses AVX2 vpmovzxbw instruction for single-instruction expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) - { - if (Avx2.IsSupported) - { - // Load 16 bytes into Vector128, zero-extend each byte to 16-bit - // vpmovzxbw: byte -> word (16 bytes -> 16 words) - var bytes128 = Vector128.Load(bools); - var expanded = Avx2.ConvertToVector256Int16(bytes128).AsUInt16(); - // Compare with zero: non-zero becomes 0xFFFF, zero stays 0 - return Vector256.GreaterThan(expanded, Vector256.Zero); - } - - // Scalar fallback for non-AVX2 systems - return Vector256.Create( - bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[7] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[8] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[9] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[10] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[11] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[12] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[13] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[14] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[15] != 0 ? (ushort)0xFFFF : (ushort)0 - ); - } - - /// - /// Create V256 mask from 8 bools for 4-byte elements. - /// Uses AVX2 vpmovzxbd instruction for single-instruction expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) - { - if (Avx2.IsSupported) - { - // Load 8 bytes into low bytes of Vector128, zero-extend each byte to 32-bit - // vpmovzxbd: byte -> dword (8 bytes -> 8 dwords) - var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); - var expanded = Avx2.ConvertToVector256Int32(bytes128).AsUInt32(); - // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 - return Vector256.GreaterThan(expanded, Vector256.Zero); - } - - // Scalar fallback for non-AVX2 systems - return Vector256.Create( - bools[0] != 0 ? 0xFFFFFFFFu : 0u, - bools[1] != 0 ? 0xFFFFFFFFu : 0u, - bools[2] != 0 ? 0xFFFFFFFFu : 0u, - bools[3] != 0 ? 0xFFFFFFFFu : 0u, - bools[4] != 0 ? 0xFFFFFFFFu : 0u, - bools[5] != 0 ? 0xFFFFFFFFu : 0u, - bools[6] != 0 ? 0xFFFFFFFFu : 0u, - bools[7] != 0 ? 0xFFFFFFFFu : 0u - ); - } - - /// - /// Create V256 mask from 4 bools for 8-byte elements. - /// Uses AVX2 vpmovzxbq instruction for single-instruction expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector256 CreateMaskV256_8Byte(byte* bools) - { - if (Avx2.IsSupported) - { - // Load 4 bytes into low bytes of Vector128, zero-extend each byte to 64-bit - // vpmovzxbq: byte -> qword (4 bytes -> 4 qwords) - var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); - var expanded = Avx2.ConvertToVector256Int64(bytes128).AsUInt64(); - // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 - return Vector256.GreaterThan(expanded, Vector256.Zero); - } - - // Scalar fallback for non-AVX2 systems - return Vector256.Create( - bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, - bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, - bools[2] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, - bools[3] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul - ); - } - - /// - /// Create V128 mask from 16 bools for 1-byte elements. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector128 CreateMaskV128_1Byte(byte* bools) - { - var vec = Vector128.Load(bools); - var zero = Vector128.Zero; - var isZero = Vector128.Equals(vec, zero); - return Vector128.OnesComplement(isZero); - } - - /// - /// Create V128 mask from 8 bools for 2-byte elements. - /// Uses SSE4.1 pmovzxbw instruction for efficient expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) - { - if (Sse41.IsSupported) - { - // Load 8 bytes, zero-extend each to 16-bit - // pmovzxbw: byte -> word (8 bytes -> 8 words) - var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); - var expanded = Sse41.ConvertToVector128Int16(bytes128).AsUInt16(); - return Vector128.GreaterThan(expanded, Vector128.Zero); - } - - // Scalar fallback - return Vector128.Create( - bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, - bools[7] != 0 ? (ushort)0xFFFF : (ushort)0 - ); - } - - /// - /// Create V128 mask from 4 bools for 4-byte elements. - /// Uses SSE4.1 pmovzxbd instruction for efficient expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) - { - if (Sse41.IsSupported) - { - // Load 4 bytes, zero-extend each to 32-bit - // pmovzxbd: byte -> dword (4 bytes -> 4 dwords) - var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); - var expanded = Sse41.ConvertToVector128Int32(bytes128).AsUInt32(); - return Vector128.GreaterThan(expanded, Vector128.Zero); - } - - // Scalar fallback - return Vector128.Create( - bools[0] != 0 ? 0xFFFFFFFFu : 0u, - bools[1] != 0 ? 0xFFFFFFFFu : 0u, - bools[2] != 0 ? 0xFFFFFFFFu : 0u, - bools[3] != 0 ? 0xFFFFFFFFu : 0u - ); - } - - /// - /// Create V128 mask from 2 bools for 8-byte elements. - /// Uses SSE4.1 pmovzxbq instruction for efficient expansion. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector128 CreateMaskV128_8Byte(byte* bools) - { - if (Sse41.IsSupported) - { - // Load 2 bytes, zero-extend each to 64-bit - // pmovzxbq: byte -> qword (2 bytes -> 2 qwords) - var bytes128 = Vector128.CreateScalar(*(ushort*)bools).AsByte(); - var expanded = Sse41.ConvertToVector128Int64(bytes128).AsUInt64(); - return Vector128.GreaterThan(expanded, Vector128.Zero); - } - - // Scalar fallback - return Vector128.Create( - bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, - bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul - ); - } - - #endregion - #region Scalar Fallback /// diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 34bb41a1..974811e5 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -41,8 +41,6 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support ret = str; //implicit cast located in NDArray.Implicit.Array break; - // Handle typed IEnumerable for all 12 NumSharp-supported types - // Optimized: Use CopyTo for ICollection (3-7x faster than ToArray for small collections) case IEnumerable e: ret = np.array(ToArrayFast(e)); break; case IEnumerable e: ret = np.array(ToArrayFast(e)); break; case IEnumerable e: ret = np.array(ToArrayFast(e)); break; @@ -58,14 +56,13 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support default: var type = a.GetType(); - // Check if it's a scalar (primitive or decimal) if (type.IsPrimitive || type == typeof(decimal)) { ret = NDArray.Scalar(a); break; } - // Handle Memory and ReadOnlyMemory - they don't implement IEnumerable + // Memory/ReadOnlyMemory do not implement IEnumerable. if (type.IsGenericType) { var genericDef = type.GetGenericTypeDefinition(); @@ -77,7 +74,6 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support } } - // Handle Tuple<> and ValueTuple<> - they implement ITuple if (a is ITuple tuple) { ret = ConvertTuple(tuple); @@ -85,7 +81,6 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support break; } - // Fallback: non-generic IEnumerable (element type detected from first item) if (a is IEnumerable enumerable) { ret = ConvertNonGenericEnumerable(enumerable); @@ -93,7 +88,6 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support break; } - // Fallback: non-generic IEnumerator if (a is IEnumerator enumerator) { ret = ConvertEnumerator(enumerator); @@ -111,34 +105,28 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support } /// - /// Optimized ToArray for IEnumerable<T>. - /// Uses CopyTo for ICollection<T> (3-7x faster for small collections). - /// For List<T>, uses CollectionsMarshal.AsSpan for direct memory access. - /// Uses GC.AllocateUninitializedArray to skip zeroing (4x faster allocation). + /// Copies an into a freshly allocated []. + /// Specialised for List<T> and ICollection<T> to skip the enumerator and to + /// use since we overwrite every slot. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static T[] ToArrayFast(IEnumerable source) { - // Fast path for List - use CollectionsMarshal for direct span access if (source is List list) { var span = CollectionsMarshal.AsSpan(list); - // Use uninitialized array - we're about to overwrite all elements var arr = GC.AllocateUninitializedArray(span.Length); span.CopyTo(arr); return arr; } - // Fast path for ICollection - use CopyTo (avoids enumerator overhead) if (source is ICollection collection) { - // Use uninitialized array - CopyTo will overwrite all elements var arr = GC.AllocateUninitializedArray(collection.Count); collection.CopyTo(arr, 0); return arr; } - // Fallback to LINQ ToArray for other IEnumerable return source.ToArray(); } @@ -151,7 +139,6 @@ private static NDArray ConvertMemory(object a, Type type) var elementType = type.GetGenericArguments()[0]; var isReadOnly = type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); - // Use Span.CopyTo + GC.AllocateUninitializedArray instead of ToArray() if (elementType == typeof(bool)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); if (elementType == typeof(byte)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); if (elementType == typeof(short)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); @@ -189,16 +176,13 @@ private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable) /// /// Converts a non-generic IEnumerator to an NDArray. /// Element type is detected from items with NumPy-like type promotion. - /// Empty collections return empty double[] to match NumPy's behavior. + /// Empty collections return empty double[] to match NumPy's float64 default. /// private static NDArray ConvertEnumerator(IEnumerator enumerator) { - // Pre-size list if count is known (optimization #4) - List items; - if (enumerator is ICollection collection) - items = new List(collection.Count); - else - items = new List(); + List items = enumerator is ICollection collection + ? new List(collection.Count) + : new List(); while (enumerator.MoveNext()) { @@ -207,7 +191,6 @@ private static NDArray ConvertEnumerator(IEnumerator enumerator) items.Add(item); } - // Empty collection: return empty double[] (NumPy defaults to float64) if (items.Count == 0) return np.array(Array.Empty()); @@ -218,34 +201,29 @@ private static NDArray ConvertEnumerator(IEnumerator enumerator) /// /// Finds the common numeric type for a list of objects (NumPy-like promotion). /// Uses existing _FindCommonType_Scalar for consistent type promotion. - /// Early exit when highest-priority types (decimal/double) are found. /// private static Type FindCommonNumericType(List items) { - // Use CollectionsMarshal.AsSpan for faster iteration (no bounds checks) var span = CollectionsMarshal.AsSpan(items); - // Early exit optimization: track highest-priority types seen - bool hasDecimal = false; bool hasDouble = false; bool hasFloat = false; Type firstType = null; - // Collect unique type codes for _FindCommonType_Scalar - Span typeCodes = stackalloc NPTypeCode[span.Length]; + // At most 12 unique NPTypeCode values exist; bound the stackalloc accordingly + // (otherwise large user lists could blow the stack). + Span typeCodes = stackalloc NPTypeCode[12]; int uniqueCount = 0; - uint seenMask = 0; // Bitmask for deduplication (NPTypeCode values are small) + uint seenMask = 0; for (int i = 0; i < span.Length; i++) { var t = span[i].GetType(); firstType ??= t; - // Early exit: decimal wins everything if (t == typeof(decimal)) return typeof(decimal); - // Track floating point for early double detection if (t == typeof(double)) hasDouble = true; else if (t == typeof(float)) hasFloat = true; @@ -258,11 +236,9 @@ private static Type FindCommonNumericType(List items) } } - // Early exit: any floating point promotes to double if (hasDouble || hasFloat) return typeof(double); - // Use existing type promotion logic for remaining cases if (uniqueCount == 1) return firstType ?? typeof(double); @@ -271,16 +247,13 @@ private static Type FindCommonNumericType(List items) } /// - /// Converts a Tuple or ValueTuple to an NDArray. - /// Uses ITuple interface available in .NET Core 2.0+. - /// Optimized: pre-sized List, early exit for decimal/double. + /// Converts a Tuple or ValueTuple to an NDArray via the ITuple interface. /// private static NDArray ConvertTuple(ITuple tuple) { if (tuple.Length == 0) return np.array(Array.Empty()); - // Pre-sized list (optimization: avoid resize for known count) var items = new List(tuple.Length); for (int i = 0; i < tuple.Length; i++) @@ -299,17 +272,13 @@ private static NDArray ConvertTuple(ITuple tuple) /// /// Converts a list of objects to an NDArray of the specified element type. - /// Uses CollectionsMarshal.AsSpan for bounds-check-free iteration. - /// Uses pattern matching for fast direct cast when types match, with Convert fallback. - /// This is ~4x faster than always using Convert for homogeneous collections. + /// The pattern is T v ? v : Convert.ToT(item) takes the direct-cast fast path for + /// homogeneous collections while still handling mixed-type promotion via Convert. /// private static NDArray ConvertObjectListToNDArray(List items, Type elementType) { - // Use CollectionsMarshal.AsSpan for faster iteration (no bounds checks) var span = CollectionsMarshal.AsSpan(items); - // Pattern: `is T v ? v : Convert.ToT(item)` gives direct cast speed for homogeneous - // collections while still handling mixed types correctly if (elementType == typeof(bool)) { var arr = GC.AllocateUninitializedArray(span.Length); From 3811960dadc1694ecf230258faee31d61f380a4c Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 20:41:39 +0300 Subject: [PATCH 17/19] fix(asanyarray,where): pure-float object[] promotion + hot-path short-circuits Second-round code review caught one real bug and several minor efficiency issues. 1. Fix: FindCommonNumericType promoted pure-float object[] to double np.asanyarray(new object[]{1.5f, 2.5f}) returned Double instead of Single. Root cause: the early-exit `if (hasDouble || hasFloat) return typeof(double)` fired before the `uniqueCount == 1` check that preserves the original dtype. Removing the hasFloat arm lets the general path handle it: - Pure float32 -> uniqueCount == 1 -> returns firstType (Single) -- matches NumPy - int + float32 -> _FindCommonType_Scalar -> returns Double -- matches NumPy NEP50 - Pure float64 -> unchanged (still Double) - decimal-wins-everything early exit preserved. Two regression tests added: - ObjectArray_AllFloat_PreservesSingle - ObjectArray_MixedIntAndFloat32_PromotesToDouble 2. Perf: skip type promotion in np.where when x.dtype == y.dtype Previously _FindCommonType(x, y) always ran, even when both operands shared a dtype. Short-circuit to x.GetTypeCode in that case, saving one dict lookup + two astype traversals per call. The NEP50 lookup still runs when dtypes differ, preserving scalar+array promotion semantics. 3. Perf: skip broadcast_arrays when all three shapes already match broadcast_arrays allocates three fresh NDArrays plus helper Shape[]. For the common case of np.where(mask, arr, other_arr) where all three arrays share a shape, this is wasted. Skip it when condition.Shape == x.Shape == y.Shape (Shape == compares by dimensions). 4. Perf: cache Vector256/Vector128 generic MethodInfo EmitWhereV256BodyWithOffset and EmitWhereV128BodyWithOffset did Array.Find(typeof(Vector*).GetMethods(), ...) three times per call, each scanning ~100 methods. Per kernel generation (4-way unrolled + 1 remainder call = 5 calls), that was 15 reflection scans per T, or ~180 on first use across all 12 dtypes. Cached as six static readonly fields; only MakeGenericMethod(typeof(T)) runs per call. 5. Polish: doc + error message - where(NDArray) xmldoc was copy-pasted from the 3-arg overload ("Return elements chosen from x or y"); rewritten to describe nonzero semantics. - object[] NotSupportedException now names the actual problem ("element type is not a supported NumSharp dtype") instead of just reporting the length. Verified: 180 np.where + np.asanyarray tests pass on net8.0 + net10.0. --- src/NumSharp.Core/APIs/np.where.cs | 37 ++++++++++------ .../Kernels/ILKernelGenerator.Where.cs | 42 ++++++++++--------- src/NumSharp.Core/Creation/np.asanyarray.cs | 11 +---- .../Creation/np.asanyarray.Tests.cs | 23 ++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs index b811eaf0..14633e3c 100644 --- a/src/NumSharp.Core/APIs/np.where.cs +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -7,10 +7,11 @@ namespace NumSharp public static partial class np { /// - /// Return elements chosen from `x` or `y` depending on `condition`. + /// Equivalent to : returns the indices where + /// is non-zero. /// - /// Where True, yield `x`, otherwise yield `y`. - /// Tuple of arrays with indices where condition is non-zero (equivalent to np.nonzero). + /// Input array. Non-zero entries yield their indices. + /// Tuple of arrays with indices where condition is non-zero, one per dimension. /// https://numpy.org/doc/stable/reference/generated/numpy.where.html public static NDArray[] where(NDArray condition) { @@ -62,17 +63,29 @@ public static NDArray where(NDArray condition, object x, object y) /// private static NDArray where_internal(NDArray condition, NDArray x, NDArray y) { - // Broadcast all three arrays to common shape - var broadcasted = broadcast_arrays(condition, x, y); - var cond = broadcasted[0]; - var xArr = broadcasted[1]; - var yArr = broadcasted[2]; + // Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three + // already share a shape — the frequent case of np.where(mask, arr, other_arr). + NDArray cond, xArr, yArr; + if (condition.Shape == x.Shape && x.Shape == y.Shape) + { + cond = condition; + xArr = x; + yArr = y; + } + else + { + var broadcasted = broadcast_arrays(condition, x, y); + cond = broadcasted[0]; + xArr = broadcasted[1]; + yArr = broadcasted[2]; + } - // Determine output dtype using existing type promotion system - // _FindCommonType already handles NEP50: scalar+array → array wins - var outType = _FindCommonType(x, y); + // When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to + // _FindCommonType which handles the scalar+array NEP50 rules. + var outType = x.GetTypeCode == y.GetTypeCode + ? x.GetTypeCode + : _FindCommonType(x, y); - // Convert x and y to output type if needed (required for kernel and iterator paths) if (xArr.GetTypeCode != outType) xArr = xArr.astype(outType, copy: false); if (yArr.GetTypeCode != outType) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index a82b4574..c55be24c 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -251,18 +251,27 @@ private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) wher il.MarkLabel(lblVectorLoopEnd); } + // Generic method definitions are cached once at class init; MakeGenericMethod is the + // only per-T work needed during kernel generation. + private static readonly MethodInfo _v256LoadGeneric = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!; + private static readonly MethodInfo _v256StoreGeneric = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!; + private static readonly MethodInfo _v256ConditionalSelectGeneric = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!; + + private static readonly MethodInfo _v128LoadGeneric = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!; + private static readonly MethodInfo _v128StoreGeneric = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!; + private static readonly MethodInfo _v128ConditionalSelectGeneric = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!; + private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - // Get Vector256 methods via reflection - need to find generic method definitions first - var loadMethod = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! - .MakeGenericMethod(typeof(T)); - var storeMethod = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! - .MakeGenericMethod(typeof(T)); - var selectMethod = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! - .MakeGenericMethod(typeof(T)); + var loadMethod = _v256LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = _v256StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = _v256ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); // cond @@ -327,16 +336,9 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - // Get Vector128 methods via reflection - need to find generic method definitions first - var loadMethod = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! - .MakeGenericMethod(typeof(T)); - var storeMethod = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! - .MakeGenericMethod(typeof(T)); - var selectMethod = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! - .MakeGenericMethod(typeof(T)); + var loadMethod = _v128LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = _v128StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = _v128ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 974811e5..e575250c 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -32,7 +32,7 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support // supported element type. ret = ConvertNonGenericEnumerable(objArr); if (ret is null) - throw new NotSupportedException($"Unable to resolve asanyarray for object array of length {objArr.Length}."); + throw new NotSupportedException($"Unable to resolve asanyarray for object[] (length {objArr.Length}): element type is not a supported NumSharp dtype."); break; case Array array: ret = new NDArray(array); @@ -206,8 +206,6 @@ private static Type FindCommonNumericType(List items) { var span = CollectionsMarshal.AsSpan(items); - bool hasDouble = false; - bool hasFloat = false; Type firstType = null; // At most 12 unique NPTypeCode values exist; bound the stackalloc accordingly @@ -221,12 +219,10 @@ private static Type FindCommonNumericType(List items) var t = span[i].GetType(); firstType ??= t; + // decimal wins everything in NumPy promotion if (t == typeof(decimal)) return typeof(decimal); - if (t == typeof(double)) hasDouble = true; - else if (t == typeof(float)) hasFloat = true; - var code = t.GetTypeCode(); var bit = 1u << (int)code; if ((seenMask & bit) == 0) @@ -236,9 +232,6 @@ private static Type FindCommonNumericType(List items) } } - if (hasDouble || hasFloat) - return typeof(double); - if (uniqueCount == 1) return firstType ?? typeof(double); diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs index ffbc9551..09b626fe 100644 --- a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -791,6 +791,29 @@ public void ObjectArray_Empty_ReturnsFloat64() result.dtype.Should().Be(typeof(double)); } + [TestMethod] + public void ObjectArray_AllFloat_PreservesSingle() + { + // Regression: an earlier FindCommonNumericType short-circuit promoted any float + // to double. NumPy preserves float32 for homogeneous float32 inputs. + var arr = new object[] { 1.5f, 2.5f, 3.5f }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3).And.BeOfValues(1.5f, 2.5f, 3.5f); + } + + [TestMethod] + public void ObjectArray_MixedIntAndFloat32_PromotesToDouble() + { + // int + float32 -> float64 per NumPy NEP50. + var arr = new object[] { 1, 2.5f, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3); + } + #endregion } } From 21d7eecec2acc1f47c54c48949309bbab5c7def8 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 20:51:25 +0300 Subject: [PATCH 18/19] refactor(where): consolidate reflection cache into partial CachedMethods np.where's IL kernel had ~35 MethodInfo fields scattered across the file using `Array.Find(...)!` null-forgiveness, which throws NullReferenceException at first use if a framework method ever gets renamed/removed. The existing CachedMethods nested class in ILKernelGenerator.cs follows a fail-fast `?? throw new MissingMethodException(type, name)` pattern, keyed per MethodInfo, and is the project convention for every other kernel partial. Changes: - Make `CachedMethods` a `partial` nested class so Where-specific reflection can live alongside the kernel file it serves. (ILKernelGenerator.cs: 1 line.) - Delete the 35 `_v128*/_v256*/_avx2*/_sse41*` private fields from ILKernelGenerator.Where.cs and move them into a new "Where Kernel Methods" region inside a partial `CachedMethods` declaration at the bottom of that file. Renamed to PascalCase (e.g. _v256LoadByte -> V256LoadByte) to match the existing CachedMethods naming convention. - Introduce three small helpers inside CachedMethods: - FindGenericMethod(Type, string name, int? paramCount) - wraps the `Array.Find(GetMethods(), m => m.IsGenericMethodDefinition && ...)` pattern with a MissingMethodException fail-fast throw. Handles the overload count disambiguation for Load/Store. - FindMethodExact(Type, string name, Type[] argTypes) - wraps GetMethod with a fail-fast throw. Used for Avx2/Sse41 specific overloads. - GetZeroGetter(Type vectorOfT) - wraps Property("Zero").GetMethod with a fail-fast throw. Used for the 8 Vector*.Zero getters. - Update all 41 call sites in EmitInlineMaskCreationV256/V128 and EmitWhereV256/V128BodyWithOffset to use CachedMethods.Xxx. Behaviour unchanged; 180 np.where + np.asanyarray tests still pass on net8.0 + net10.0. The single real benefit is earlier and clearer failure if any of the ~35 framework API names change in a future .NET release. --- .../Kernels/ILKernelGenerator.Where.cs | 269 +++++++++--------- .../Backends/Kernels/ILKernelGenerator.cs | 2 +- 2 files changed, 136 insertions(+), 135 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index c55be24c..3f4f371d 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -251,27 +251,11 @@ private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) wher il.MarkLabel(lblVectorLoopEnd); } - // Generic method definitions are cached once at class init; MakeGenericMethod is the - // only per-T work needed during kernel generation. - private static readonly MethodInfo _v256LoadGeneric = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!; - private static readonly MethodInfo _v256StoreGeneric = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!; - private static readonly MethodInfo _v256ConditionalSelectGeneric = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!; - - private static readonly MethodInfo _v128LoadGeneric = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!; - private static readonly MethodInfo _v128StoreGeneric = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!; - private static readonly MethodInfo _v128ConditionalSelectGeneric = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!; - private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - var loadMethod = _v256LoadGeneric.MakeGenericMethod(typeof(T)); - var storeMethod = _v256StoreGeneric.MakeGenericMethod(typeof(T)); - var selectMethod = _v256ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); + var loadMethod = CachedMethods.V256LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = CachedMethods.V256StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = CachedMethods.V256ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); // cond @@ -336,9 +320,9 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - var loadMethod = _v128LoadGeneric.MakeGenericMethod(typeof(T)); - var storeMethod = _v128StoreGeneric.MakeGenericMethod(typeof(T)); - var selectMethod = _v128ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); + var loadMethod = CachedMethods.V128LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = CachedMethods.V128StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = CachedMethods.V128ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); @@ -456,77 +440,8 @@ private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) #region Inline Mask IL Emission - // Cache reflection lookups for inline emission - private static readonly MethodInfo _v128LoadByte = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); - private static readonly MethodInfo _v256LoadByte = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); - - private static readonly MethodInfo _v128CreateScalarUInt = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); - private static readonly MethodInfo _v128CreateScalarULong = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); - private static readonly MethodInfo _v128CreateScalarUShort = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); - - // AsByte is an extension method on Vector128 static class, not instance method - private static readonly MethodInfo _v128UIntAsByte = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); - private static readonly MethodInfo _v128ULongAsByte = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); - private static readonly MethodInfo _v128UShortAsByte = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); - - private static readonly MethodInfo _avx2ConvertToV256Int64 = typeof(Avx2).GetMethod("ConvertToVector256Int64", new[] { typeof(Vector128) })!; - private static readonly MethodInfo _avx2ConvertToV256Int32 = typeof(Avx2).GetMethod("ConvertToVector256Int32", new[] { typeof(Vector128) })!; - private static readonly MethodInfo _avx2ConvertToV256Int16 = typeof(Avx2).GetMethod("ConvertToVector256Int16", new[] { typeof(Vector128) })!; - - private static readonly MethodInfo _sse41ConvertToV128Int64 = typeof(Sse41).GetMethod("ConvertToVector128Int64", new[] { typeof(Vector128) })!; - private static readonly MethodInfo _sse41ConvertToV128Int32 = typeof(Sse41).GetMethod("ConvertToVector128Int32", new[] { typeof(Vector128) })!; - private static readonly MethodInfo _sse41ConvertToV128Int16 = typeof(Sse41).GetMethod("ConvertToVector128Int16", new[] { typeof(Vector128) })!; - - // As* methods are extension methods on Vector256/Vector128 static classes - private static readonly MethodInfo _v256LongAsULong = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); - private static readonly MethodInfo _v256IntAsUInt = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); - private static readonly MethodInfo _v256ShortAsUShort = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); - - private static readonly MethodInfo _v128LongAsULong = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); - private static readonly MethodInfo _v128IntAsUInt = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); - private static readonly MethodInfo _v128ShortAsUShort = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); - - private static readonly MethodInfo _v256GreaterThanULong = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); - private static readonly MethodInfo _v256GreaterThanUInt = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); - private static readonly MethodInfo _v256GreaterThanUShort = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); - private static readonly MethodInfo _v256GreaterThanByte = Array.Find(typeof(Vector256).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); - - private static readonly MethodInfo _v128GreaterThanULong = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); - private static readonly MethodInfo _v128GreaterThanUInt = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); - private static readonly MethodInfo _v128GreaterThanUShort = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); - private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(), - m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); - - private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v256GetZeroByte = typeof(Vector256).GetProperty("Zero")!.GetMethod!; - - private static readonly MethodInfo _v128GetZeroULong = typeof(Vector128).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v128GetZeroUInt = typeof(Vector128).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v128GetZeroUShort = typeof(Vector128).GetProperty("Zero")!.GetMethod!; - private static readonly MethodInfo _v128GetZeroByte = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + // Vector-related MethodInfos for np.where are cached in the partial CachedMethods class + // below (see "Where Kernel Methods" region at the end of this file). /// /// Emit inline V256 mask creation. Stack: byte* -> Vector256{T} (as mask) @@ -541,56 +456,56 @@ private static void EmitInlineMaskCreationV256(ILGenerator il, int elementSize) // *(uint*)ptr il.Emit(OpCodes.Ldind_U4); // Vector128.CreateScalar(value) - il.Emit(OpCodes.Call, _v128CreateScalarUInt); + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUInt); // .AsByte() - il.Emit(OpCodes.Call, _v128UIntAsByte); + il.Emit(OpCodes.Call, CachedMethods.V128UIntAsByte); // Avx2.ConvertToVector256Int64(bytes) - il.Emit(OpCodes.Call, _avx2ConvertToV256Int64); + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int64); // .AsUInt64() - il.Emit(OpCodes.Call, _v256LongAsULong); + il.Emit(OpCodes.Call, CachedMethods.V256LongAsULong); // Vector256.Zero - il.Emit(OpCodes.Call, _v256GetZeroULong); + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroULong); // Vector256.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v256GreaterThanULong); + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanULong); break; case 4: // float/int: load 8 bytes, expand to 8 dwords // *(ulong*)ptr il.Emit(OpCodes.Ldind_I8); // Vector128.CreateScalar(value) - il.Emit(OpCodes.Call, _v128CreateScalarULong); + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarULong); // .AsByte() - il.Emit(OpCodes.Call, _v128ULongAsByte); + il.Emit(OpCodes.Call, CachedMethods.V128ULongAsByte); // Avx2.ConvertToVector256Int32(bytes) - il.Emit(OpCodes.Call, _avx2ConvertToV256Int32); + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int32); // .AsUInt32() - il.Emit(OpCodes.Call, _v256IntAsUInt); + il.Emit(OpCodes.Call, CachedMethods.V256IntAsUInt); // Vector256.Zero - il.Emit(OpCodes.Call, _v256GetZeroUInt); + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroUInt); // Vector256.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v256GreaterThanUInt); + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanUInt); break; case 2: // short/char: load 16 bytes, expand to 16 words // Vector128.Load(ptr) - il.Emit(OpCodes.Call, _v128LoadByte); + il.Emit(OpCodes.Call, CachedMethods.V128LoadByte); // Avx2.ConvertToVector256Int16(bytes) - il.Emit(OpCodes.Call, _avx2ConvertToV256Int16); + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int16); // .AsUInt16() - il.Emit(OpCodes.Call, _v256ShortAsUShort); + il.Emit(OpCodes.Call, CachedMethods.V256ShortAsUShort); // Vector256.Zero - il.Emit(OpCodes.Call, _v256GetZeroUShort); + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroUShort); // Vector256.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v256GreaterThanUShort); + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanUShort); break; case 1: // byte/bool: load 32 bytes, compare directly // Vector256.Load(ptr) - il.Emit(OpCodes.Call, _v256LoadByte); + il.Emit(OpCodes.Call, CachedMethods.V256LoadByte); // Vector256.Zero - il.Emit(OpCodes.Call, _v256GetZeroByte); + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroByte); // Vector256.GreaterThan(vec, zero) - il.Emit(OpCodes.Call, _v256GreaterThanByte); + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanByte); break; default: @@ -609,60 +524,60 @@ private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) // *(ushort*)ptr il.Emit(OpCodes.Ldind_U2); // Vector128.CreateScalar(value) - il.Emit(OpCodes.Call, _v128CreateScalarUShort); + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUShort); // .AsByte() - il.Emit(OpCodes.Call, _v128UShortAsByte); + il.Emit(OpCodes.Call, CachedMethods.V128UShortAsByte); // Sse41.ConvertToVector128Int64(bytes) - il.Emit(OpCodes.Call, _sse41ConvertToV128Int64); + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int64); // .AsUInt64() - il.Emit(OpCodes.Call, _v128LongAsULong); + il.Emit(OpCodes.Call, CachedMethods.V128LongAsULong); // Vector128.Zero - il.Emit(OpCodes.Call, _v128GetZeroULong); + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroULong); // Vector128.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v128GreaterThanULong); + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanULong); break; case 4: // float/int: load 4 bytes, expand to 4 dwords // *(uint*)ptr il.Emit(OpCodes.Ldind_U4); // Vector128.CreateScalar(value) - il.Emit(OpCodes.Call, _v128CreateScalarUInt); + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUInt); // .AsByte() - il.Emit(OpCodes.Call, _v128UIntAsByte); + il.Emit(OpCodes.Call, CachedMethods.V128UIntAsByte); // Sse41.ConvertToVector128Int32(bytes) - il.Emit(OpCodes.Call, _sse41ConvertToV128Int32); + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int32); // .AsUInt32() - il.Emit(OpCodes.Call, _v128IntAsUInt); + il.Emit(OpCodes.Call, CachedMethods.V128IntAsUInt); // Vector128.Zero - il.Emit(OpCodes.Call, _v128GetZeroUInt); + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroUInt); // Vector128.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v128GreaterThanUInt); + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanUInt); break; case 2: // short/char: load 8 bytes, expand to 8 words // *(ulong*)ptr il.Emit(OpCodes.Ldind_I8); // Vector128.CreateScalar(value) - il.Emit(OpCodes.Call, _v128CreateScalarULong); + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarULong); // .AsByte() - il.Emit(OpCodes.Call, _v128ULongAsByte); + il.Emit(OpCodes.Call, CachedMethods.V128ULongAsByte); // Sse41.ConvertToVector128Int16(bytes) - il.Emit(OpCodes.Call, _sse41ConvertToV128Int16); + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int16); // .AsUInt16() - il.Emit(OpCodes.Call, _v128ShortAsUShort); + il.Emit(OpCodes.Call, CachedMethods.V128ShortAsUShort); // Vector128.Zero - il.Emit(OpCodes.Call, _v128GetZeroUShort); + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroUShort); // Vector128.GreaterThan(expanded, zero) - il.Emit(OpCodes.Call, _v128GreaterThanUShort); + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanUShort); break; case 1: // byte/bool: load 16 bytes, compare directly // Vector128.Load(ptr) - il.Emit(OpCodes.Call, _v128LoadByte); + il.Emit(OpCodes.Call, CachedMethods.V128LoadByte); // Vector128.Zero - il.Emit(OpCodes.Call, _v128GetZeroByte); + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroByte); // Vector128.GreaterThan(vec, zero) - il.Emit(OpCodes.Call, _v128GreaterThanByte); + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanByte); break; default: @@ -687,5 +602,91 @@ private static unsafe void WhereScalar(bool* cond, T* x, T* y, T* result, lon } #endregion + + // Per the CachedMethods pattern in ILKernelGenerator.cs, reflection lookups for np.where + // live alongside the other cached entries. Fail-fast at type init so a renamed API shows + // up immediately instead of NREs at first use. + private static partial class CachedMethods + { + #region Where Kernel Methods + + private static MethodInfo FindGenericMethod(Type container, string name, int? paramCount = null) + { + foreach (var m in container.GetMethods()) + { + if (m.Name == name && m.IsGenericMethodDefinition && + (paramCount is null || m.GetParameters().Length == paramCount.Value)) + return m; + } + throw new MissingMethodException(container.FullName, name); + } + + private static MethodInfo FindMethodExact(Type container, string name, Type[] argTypes) + => container.GetMethod(name, argTypes) + ?? throw new MissingMethodException(container.FullName, name); + + private static MethodInfo GetZeroGetter(Type vectorOfT) + => vectorOfT.GetProperty("Zero")?.GetMethod + ?? throw new MissingMethodException(vectorOfT.FullName, "get_Zero"); + + // Generic definitions — caller must MakeGenericMethod(typeof(T)) before emitting. + public static readonly MethodInfo V256LoadGeneric = FindGenericMethod(typeof(Vector256), "Load", 1); + public static readonly MethodInfo V256StoreGeneric = FindGenericMethod(typeof(Vector256), "Store", 2); + public static readonly MethodInfo V256ConditionalSelectGeneric = FindGenericMethod(typeof(Vector256), "ConditionalSelect"); + + public static readonly MethodInfo V128LoadGeneric = FindGenericMethod(typeof(Vector128), "Load", 1); + public static readonly MethodInfo V128StoreGeneric = FindGenericMethod(typeof(Vector128), "Store", 2); + public static readonly MethodInfo V128ConditionalSelectGeneric = FindGenericMethod(typeof(Vector128), "ConditionalSelect"); + + // Already-specialised generic methods used during mask creation. + public static readonly MethodInfo V256LoadByte = FindGenericMethod(typeof(Vector256), "Load").MakeGenericMethod(typeof(byte)); + public static readonly MethodInfo V128LoadByte = FindGenericMethod(typeof(Vector128), "Load").MakeGenericMethod(typeof(byte)); + + public static readonly MethodInfo V128CreateScalarUInt = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128CreateScalarULong = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128CreateScalarUShort = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(ushort)); + + public static readonly MethodInfo V128UIntAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128ULongAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128UShortAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(ushort)); + + public static readonly MethodInfo V256LongAsULong = FindGenericMethod(typeof(Vector256), "AsUInt64").MakeGenericMethod(typeof(long)); + public static readonly MethodInfo V256IntAsUInt = FindGenericMethod(typeof(Vector256), "AsUInt32").MakeGenericMethod(typeof(int)); + public static readonly MethodInfo V256ShortAsUShort = FindGenericMethod(typeof(Vector256), "AsUInt16").MakeGenericMethod(typeof(short)); + + public static readonly MethodInfo V128LongAsULong = FindGenericMethod(typeof(Vector128), "AsUInt64").MakeGenericMethod(typeof(long)); + public static readonly MethodInfo V128IntAsUInt = FindGenericMethod(typeof(Vector128), "AsUInt32").MakeGenericMethod(typeof(int)); + public static readonly MethodInfo V128ShortAsUShort = FindGenericMethod(typeof(Vector128), "AsUInt16").MakeGenericMethod(typeof(short)); + + public static readonly MethodInfo V256GreaterThanULong = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V256GreaterThanUInt = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V256GreaterThanUShort = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(ushort)); + public static readonly MethodInfo V256GreaterThanByte = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(byte)); + + public static readonly MethodInfo V128GreaterThanULong = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128GreaterThanUInt = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128GreaterThanUShort = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(ushort)); + public static readonly MethodInfo V128GreaterThanByte = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(byte)); + + // Non-generic exact overloads on Avx2/Sse41 for byte-lane sign-extend expansion. + public static readonly MethodInfo Avx2ConvertToV256Int64 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int64", new[] { typeof(Vector128) }); + public static readonly MethodInfo Avx2ConvertToV256Int32 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int32", new[] { typeof(Vector128) }); + public static readonly MethodInfo Avx2ConvertToV256Int16 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int16", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int64 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int64", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int32 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int32", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int16 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int16", new[] { typeof(Vector128) }); + + // Vector*.Zero property getters — emitted as a call, not a field load, so we cache the getter MethodInfo. + public static readonly MethodInfo V256GetZeroULong = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroUInt = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroUShort = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroByte = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V128GetZeroULong = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroUInt = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroUShort = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroByte = GetZeroGetter(typeof(Vector128)); + + #endregion + } } } diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs index 37536cf0..134ae6a0 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs @@ -290,7 +290,7 @@ public static partial class ILKernelGenerator /// Caching these avoids repeated GetMethod() lookups during kernel generation. /// All fields use ?? throw to fail fast at type load if a method is not found. /// - private static class CachedMethods + private static partial class CachedMethods { // Math methods (double versions) public static readonly MethodInfo MathPow = typeof(Math).GetMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }) From a5862bd226f2c4c251541e449f560cd2c294d63c Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 20 Apr 2026 21:11:22 +0300 Subject: [PATCH 19/19] fix(where): gate x86-specific SIMD path on Sse41/Avx2 for ARM64 compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI failure on macos-latest (ARM64/Apple Silicon) reported 31 np.where tests throwing PlatformNotSupportedException at runtime: PlatformNotSupportedException: Operation is not supported on this platform. at System.Runtime.Intrinsics.X86.Sse41.ConvertToVector128Int64(Vector128`1 value) at IL_Where_Int64(...) Root cause: the SIMD-emit path was gated only on `VectorBits >= 128`. On ARM64, `Vector128.IsHardwareAccelerated` is true (maps to Neon), so VectorBits is 128, and the kernel emits calls to Sse41/Avx2 byte-lane expansion intrinsics which are x86-only. Breakdown of the byte-mask expansion path by element size: - 1-byte (byte): portable Vector*.Load/GreaterThan — safe on any SIMD platform - 2-byte: Sse41.ConvertToVector128Int16 / Avx2.ConvertToVector256Int16 - 4-byte: Sse41.ConvertToVector128Int32 / Avx2.ConvertToVector256Int32 - 8-byte: Sse41.ConvertToVector128Int64 / Avx2.ConvertToVector256Int64 Fix: in GenerateWhereKernelIL, compute `useV256`/`useV128` with an additional Sse41.IsSupported / Avx2.IsSupported guard — but only when elementSize > 1, since the 1-byte path is portable. If neither x86 intrinsic set is available for the required lane size, skip SIMD emission entirely; the scalar IL loop that follows handles correctness. Also passes the useV256 decision to EmitWhereSIMDLoop explicitly instead of recomputing it from VectorBits inside the loop, which was both duplicative and ignored the IsSupported guard. Result: on ARM64, byte-typed arrays still use Neon-backed SIMD; int/long/float/ double/short fall back to the scalar IL kernel. On x86 nothing changes. Verified: 180 np.where + np.asanyarray tests pass on Windows x64 (net8.0 + net10.0). ARM path awaits CI verification. --- .../Kernels/ILKernelGenerator.Where.cs | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index 3f4f371d..72678ca7 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -115,8 +115,17 @@ private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmana { int elementSize = Unsafe.SizeOf(); - // Determine if we can use SIMD - bool canSimd = elementSize <= 8 && IsSimdSupported(); + // SIMD eligibility: + // - 1-byte types (byte) only touch portable Vector128/Vector256 APIs, so they work + // on any SIMD-capable platform (including ARM64/Neon). + // - 2/4/8-byte types need Sse41.ConvertToVector128Int* (V128 path) or + // Avx2.ConvertToVector256Int* (V256 path) to expand the bool-mask lanes. + // These x86 intrinsics throw PlatformNotSupportedException on ARM64. + bool canSimdDtype = elementSize <= 8 && IsSimdSupported(); + bool needsX86 = elementSize > 1; + bool useV256 = VectorBits >= 256 && (!needsX86 || Avx2.IsSupported); + bool useV128 = !useV256 && VectorBits >= 128 && (!needsX86 || Sse41.IsSupported); + bool emitSimd = canSimdDtype && (useV256 || useV128); var dm = new DynamicMethod( name: $"IL_Where_{typeof(T).Name}", @@ -139,10 +148,9 @@ private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmana il.Emit(OpCodes.Ldc_I8, 0L); il.Emit(OpCodes.Stloc, locI); - if (canSimd && VectorBits >= 128) + if (emitSimd) { - // Generate SIMD path - EmitWhereSIMDLoop(il, locI); + EmitWhereSIMDLoop(il, locI, useV256); } // Scalar loop for remainder @@ -170,13 +178,12 @@ private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmana return (WhereKernel)dm.CreateDelegate(typeof(WhereKernel)); } - private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) where T : unmanaged + private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI, bool useV256) where T : unmanaged { long elementSize = Unsafe.SizeOf(); - long vectorCount = VectorBits >= 256 ? (32 / elementSize) : (16 / elementSize); + long vectorCount = useV256 ? (32 / elementSize) : (16 / elementSize); long unrollFactor = 4; long unrollStep = vectorCount * unrollFactor; - bool useV256 = VectorBits >= 256; var locUnrollEnd = il.DeclareLocal(typeof(long)); var locVectorEnd = il.DeclareLocal(typeof(long));