From bcf9e2ecc4787c4e97115250338f493854090f19 Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Sun, 8 Feb 2026 13:56:33 +0100
Subject: [PATCH 1/7] Add GradientDispatch bridge for custom gradient adapter dispatch

---
 .../op/DispatchingGradientAdapter.java        | 76 +++++++++++++++++++
 .../org/tensorflow/op/GradientDispatch.java   | 25 ++++++
 2 files changed, 101 insertions(+)
 create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java
 create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java
new file mode 100644
index 00000000000..413dd7afb42
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java
@@ -0,0 +1,76 @@
+package org.tensorflow.op;
+
+import java.lang.reflect.Constructor;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import org.tensorflow.AbstractGradientAdapter;
+import org.tensorflow.Graph;
+import org.tensorflow.GraphOperation;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+
+final class DispatchingGradientAdapter extends AbstractGradientAdapter {
+
+  private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
+  private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();
+
+  static final class TypedEntry<T extends RawOpInputs<?>> {
+    final CustomGradient<T> grad;
+    final Class<T> inputClass;
+    final Constructor<T> ctor;
+
+    TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
+      this.grad = grad;
+      this.inputClass = inputClass;
+      try {
+        this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class);
+      } catch (NoSuchMethodException e) {
+        throw new IllegalArgumentException(
+            "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", e);
+      }
+    }
+  }
+
+  void putRaw(String opType, RawCustomGradient g) {
+    raw.put(opType, g);
+  }
+
+  <T extends RawOpInputs<?>> void putTyped(String opType, CustomGradient<T> g, Class<T> inputClass) {
+    typed.put(opType, new TypedEntry<>(g, inputClass));
+  }
+
+  @Override
+  protected List<Operand<?>> apply(
+      Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {
+
+    final String opType = operation.type();
+
+    RawCustomGradient rg = raw.get(opType);
+    if (rg != null) {
+      // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
+      Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
+      return rg.call(new Ops(nativeScope), operation, gradInputs);
+    }
+
+    @SuppressWarnings("rawtypes")
+    TypedEntry te = typed.get(opType);
+    if (te != null) {
+      return applyTyped(graph, scope, operation, gradInputs, te);
+    }
+
+    throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
+  }
+
+  private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
+      Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs, TypedEntry<T> te) {
+    try {
+      T inputs = te.ctor.newInstance(operation);
+      Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
+      return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
+    } catch (ReflectiveOperationException e) {
+      throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
+    }
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
new file mode 100644
index 00000000000..64504121677
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
@@ -0,0 +1,25 @@
+package org.tensorflow.op;
+
+import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+
+/** Public bridge to a single native gradient adapter. */
+public final class GradientDispatch {
+
+  // package-private adapter that can access NativeScope/Ops constructors
+  static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();
+
+  private GradientDispatch() {}
+
+  public static TFJ_GradFuncAdapter adapter() {
+    return ADAPTER;
+  }
+
+  public static void putRaw(String opType, RawCustomGradient gradient) {
+    ADAPTER.putRaw(opType, gradient);
+  }
+
+  public static <T extends RawOpInputs<?>> void putTyped(
+      String opType, CustomGradient<T> gradient, Class<T> inputClass) {
+    ADAPTER.putTyped(opType, gradient, inputClass);
+  }
+}

From 7c7dc54b57555bb5cc49526ce0bf1b8732eb74d9 Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Sun, 8 Feb 2026 13:56:40 +0100
Subject: [PATCH 2/7] Fix custom gradient registration scalability and support NoGradient

---
 .../tensorflow/AbstractGradientAdapter.java   | 12 ++++++-
 .../main/java/org/tensorflow/TensorFlow.java  | 10 ++++--
 .../org/tensorflow/op/CustomGradient.java     | 32 ++++++++++++++++++-
 .../op/DispatchingGradientAdapter.java        | 18 ++++++---
 .../org/tensorflow/op/RawCustomGradient.java  | 31 +++++++++++++++++-
 5 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
index 2119cddaa67..18f95d2197d 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
@@ -92,8 +92,18 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
 
     for (int i = 0; i < outputs.size(); ++i) {
-      var output = outputs.get(i).asOutput();
       var nativeOutput = nativeOutputs.getPointer(i);
+
+      Operand<?> operand = outputs.get(i);
+      if (operand == null) {
+        // "NoGradient" sentinel: null oper + index 0.
+        // Native side must tolerate TF_Output.oper == nullptr.
+ nativeOutput.oper((org.tensorflow.internal.c_api.TF_Operation) null); + nativeOutput.index(0); + continue; + } + + var output = operand.asOutput(); nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle()); nativeOutput.index(output.index()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 7eba6d7ce30..76c0f168eb6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -207,7 +207,10 @@ public static synchronized boolean registerCustomGradient( if (hasGradient(opType)) { return false; } - TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient); + + org.tensorflow.op.GradientDispatch.putRaw(opType, gradient); + TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + if (!TFJ_RegisterCustomGradient(opType, g)) { return false; } @@ -255,7 +258,10 @@ public static synchronized > boolean registerCustomGrad if (hasGradient(opType)) { return false; } - TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass); + + org.tensorflow.op.GradientDispatch.putTyped(opType, gradient, inputClass); + TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + if (!TFJ_RegisterCustomGradient(opType, g)) { return false; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java index 02acce1cb37..4c3b80a6cad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java @@ -17,10 +17,15 @@ package org.tensorflow.op; import java.util.List; +import org.bytedeco.javacpp.PointerPointer; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.TensorFlow; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; +import org.tensorflow.internal.c_api.TFJ_GraphId; +import org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; /** * A custom gradient for ops of type {@link T}. Should be registered using {@link @@ -57,6 +62,31 @@ public interface CustomGradient { */ static > TFJ_GradFuncAdapter adapter( CustomGradient gradient, Class opClass) { - return new TypedGradientAdapter(gradient, opClass); + + final TypedGradientAdapter impl = new TypedGradientAdapter(gradient, opClass); + + // IMPORTANT: + // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function + // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper + // subclass. 
+ return new TFJ_GradFuncAdapter() { + @Override + public int call( + TFJ_GraphId nativeGraphId, + TFJ_Scope nativeScope, + TF_Operation nativeOperation, + TF_Output nativeGradInputs, + int nativeGradInputsLength, + PointerPointer nativeGradOutputsPtr) { + + return impl.call( + nativeGraphId, + nativeScope, + nativeOperation, + nativeGradInputs, + nativeGradInputsLength, + nativeGradOutputsPtr); + } + }; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java index 413dd7afb42..380dd6b555d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -28,7 +28,8 @@ static final class TypedEntry> { this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class); } catch (NoSuchMethodException e) { throw new IllegalArgumentException( - "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", e); + "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", + e); } } } @@ -37,7 +38,8 @@ void putRaw(String opType, RawCustomGradient g) { raw.put(opType, g); } - > void putTyped(String opType, CustomGradient g, Class inputClass) { + > void putTyped( + String opType, CustomGradient g, Class inputClass) { typed.put(opType, new TypedEntry<>(g, inputClass)); } @@ -50,7 +52,8 @@ protected List> apply( RawCustomGradient rg = raw.get(opType); if (rg != null) { // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op - Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + Scope nativeScope = + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); return rg.call(new Ops(nativeScope), operation, gradInputs); } @@ -64,10 +67,15 @@ protected List> apply( } private > List> applyTyped( - Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs, TypedEntry te) { + Graph graph, + TFJ_Scope scope, + GraphOperation operation, + List> gradInputs, + TypedEntry te) { try { T inputs = te.ctor.newInstance(operation); - Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + Scope nativeScope = + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); return te.grad.call(new Ops(nativeScope), inputs, gradInputs); } catch (ReflectiveOperationException e) { throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java index c2d5496de2a..723d45d58ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java @@ -17,11 +17,16 @@ package org.tensorflow.op; import java.util.List; +import org.bytedeco.javacpp.PointerPointer; import org.tensorflow.GraphOperation; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.TensorFlow; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; +import org.tensorflow.internal.c_api.TFJ_GraphId; +import 
org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; /** * A custom gradient for an op of unspecified type. Should be registered using {@link @@ -54,6 +59,30 @@ public interface RawCustomGradient { * TensorFlow#registerCustomGradient(String, RawCustomGradient)}. */ static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) { - return new RawGradientAdapter(gradient); + final RawGradientAdapter impl = new RawGradientAdapter(gradient); + + // IMPORTANT: + // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function + // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper + // subclass. + return new TFJ_GradFuncAdapter() { + @Override + public int call( + TFJ_GraphId nativeGraphId, + TFJ_Scope nativeScope, + TF_Operation nativeOperation, + TF_Output nativeGradInputs, + int nativeGradInputsLength, + PointerPointer nativeGradOutputsPtr) { + + return impl.call( + nativeGraphId, + nativeScope, + nativeOperation, + nativeGradInputs, + nativeGradInputsLength, + nativeGradOutputsPtr); + } + }; } } From e2fa04b5729159308162bbebd8caa24d314f977b Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Sun, 8 Feb 2026 13:56:51 +0100 Subject: [PATCH 3/7] Handle NoGradient in native custom gradient bridge --- .../internal/c_api/presets/tensorflow.java | 13 ++++ .../internal/c_api/tfj_gradients_impl.cc | 78 ++++++++++++++++--- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index cd3be39fde2..01f8fe59b4b 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -213,6 +213,19 @@ public void map(InfoMap infoMap) { // Skip C++ classes infoMap.put(new Info("tsl::StatusGroup").skip()); + + // Force correct marshalling of TFJ_RegisterCustomGradient callback argument. + // Without an explicit cast, JavaCPP may pass a NULL function pointer for some FunctionPointer + // instances. + infoMap.put( + new Info("TFJ_RegisterCustomGradient") + .javaText( + "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient(" + + "@Cast(\"const char*\") BytePointer op_type, " + + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n" + + "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient(" + + "@Cast(\"const char*\") String op_type, " + + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n")); } @Override diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 6882cfa704c..0e30623dbc3 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -18,6 +18,12 @@ limitations under the License. 
#include #include +// IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, +// so we must not rely on transitive includes from other headers). +#include +#include +#include + #include "tfj_graph.h" #include "tsl/platform/errors.h" #include "tensorflow/c/c_api.h" @@ -32,7 +38,6 @@ namespace tensorflow { unordered_map g_grad_func_adapters; /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit - /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 template T* struct_cast(U* ptr) { @@ -53,16 +58,34 @@ namespace tensorflow { if (found_adapter == g_grad_func_adapters.end()) { return errors::NotFound("No gradient adapter found for operation ", op_type); } - int num_inputs = grad_inputs.size(); - TF_Output* inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output)); + + const int num_inputs = static_cast(grad_inputs.size()); + + TF_Output* inputs = nullptr; + if (num_inputs > 0) { + inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); + if (inputs == nullptr) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + for (int i = 0; i < num_inputs; ++i) { - Output grad_input = grad_inputs[i]; + const Output& grad_input = grad_inputs[i]; inputs[i].oper = struct_cast(grad_input.node()); inputs[i].index = grad_input.index(); } - TF_Output* outputs = NULL; + + TF_Output* outputs = nullptr; LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - int num_outputs = found_adapter->second( + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == nullptr) { + if (inputs != nullptr) free(inputs); + return errors::Unknown("Null Java gradient adapter for op ", op_type); + } + LOG(INFO) << "Adapter ptr for " << op_type << " = " << reinterpret_cast(found_adapter->second); + const int num_outputs = adapter( static_cast(scope.graph()), struct_cast(const_cast(&scope)), struct_cast(op.node()), @@ -70,12 +93,39 @@ namespace tensorflow { num_inputs, &outputs ); + + // Always free inputs, even on error paths. + if (inputs != nullptr) free(inputs); + + // Adapter contract hardening: + // - On Java exception / failure, adapter should return negative or outputs==nullptr. + if (num_outputs < 0) { + if (outputs != nullptr) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for op ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == nullptr) { + return errors::Unknown("Java custom gradient adapter returned null outputs for op ", + op_type, " with num_outputs=", num_outputs); + } + + grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); + for (int i = 0; i < num_outputs; ++i) { - TF_Output output = outputs[i]; - grad_outputs->push_back(Output(struct_cast(output.oper), output.index)); + const TF_Output out = outputs[i]; + + // "NoGradient" sentinel from Java: TF_Output.oper == nullptr + if (out.oper == nullptr) { + // Represent "no gradient" as an empty Output. + // TF's gradient builder should tolerate missing gradients for non-differentiable inputs. 
+ grad_outputs->push_back(Output()); + continue; + } + + grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); } - free(inputs); - free(outputs); // outputs are allocated from Java but must be freed here + + if (outputs != nullptr) free(outputs); return OkStatus(); } } @@ -91,6 +141,14 @@ bool TFJ_HasGradient(const char* op_type) { } bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { + LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" + << reinterpret_cast(grad_func_adapter); + + if (grad_func_adapter == nullptr) { + LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type; + return false; + } + if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type << ", which has already a registered function"; From 1d43d4cd9542d632071b512cd5c690d4cd25f5f1 Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Tue, 10 Feb 2026 04:53:09 +0100 Subject: [PATCH 4/7] Add tests for NoGradient support in Java custom gradients --- .../org/tensorflow/CustomGradientsTest.java | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java new file mode 100644 index 00000000000..102623ffc5b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java @@ -0,0 +1,124 @@ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.tensorflow.op.CustomGradient; +import org.tensorflow.op.Ops; +import org.tensorflow.op.RawCustomGradient; +import org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +@DisabledOnOs(OS.WINDOWS) +public class CustomGradientsTest { + + @Test + public void noGradientNullIsSupported() { + // Register a custom gradient for an op that has NO native gradient in TF core. + CustomGradient grad = + (tf, op, gradInputs) -> { + @SuppressWarnings("unchecked") + Operand gLoss = (Operand) gradInputs.get(0); // [B] + + @SuppressWarnings("unchecked") + Operand logits = op.features; + + SparseSoftmaxCrossEntropyWithLogits xent = + SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, op.labels); + + Operand backprop = xent.backprop(); // [B,C] + Operand gLossE = tf.expandDims(gLoss, tf.constant(1)); // [B,1] + Operand dLogits = tf.math.mul(gLossE, backprop); // [B,C] + + // labels: NoGradient + return java.util.Arrays.asList(dLogits, null); + }; + + assertTrue( + TensorFlow.registerCustomGradient(SparseSoftmaxCrossEntropyWithLogits.Inputs.class, grad)); + + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + // Small fixed shapes to be able to create an explicit seed (avoid OnesLike in addGradients). 
+ Operand logits = tf.constant(new float[][] {{1f, 2f, 3f}, {3f, 2f, 1f}}); + Operand labels = tf.constant(new int[] {2, 0}); + + SparseSoftmaxCrossEntropyWithLogits xent = + SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, labels); + + Output loss = xent.loss(); // [2] + Operand seed = tf.constant(new float[] {1f, 1f}); // same shape as loss + + Output[] grads = + g.addGradients( + "seed", + new Output[] {loss}, + new Output[] {logits.asOutput(), labels.asOutput()}, + new Output[] {seed.asOutput()}); + + // logits grad exists, labels grad must be "NoGradient" (represented as a CLOSED Output) + assertNotNull(grads); + assertEquals(2, grads.length); + assertNotNull(grads[0], "Expected gradient for logits"); + assertNotNull(grads[1], "Expected an Output placeholder for labels gradient"); + assertTrue(grads[1].isClosed(), "Expected closed gradient (NoGradient) for labels"); + } + } + + @Test + public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() { + // Register custom gradient for SigmoidGrad (if already registered, it will return false, + // but the test can still pass because the gradient exists in the current process). + TensorFlow.registerCustomGradient( + "SigmoidGrad", + (RawCustomGradient) + (tf, op, gradInputs) -> { + @SuppressWarnings("unchecked") + Operand y = (Operand) op.input(0); // sigmoid(x) + @SuppressWarnings("unchecked") + Operand dy = (Operand) op.input(1); // upstream into SigmoidGrad + @SuppressWarnings("unchecked") + Operand upstream = (Operand) gradInputs.get(0); + + Operand one = tf.constant(1.0f); + Operand yTimesOneMinusY = tf.math.mul(y, tf.math.sub(one, y)); + + // dL/d(dy) = upstream * y*(1-y) + Operand dDy = tf.math.mul(upstream, yTimesOneMinusY); + + // dL/d(y) not needed for this test; return zeros to keep it non-null. 
+ Operand dY = tf.zerosLike(y); + + return List.of(dY, dDy); + }); + + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Operand x = tf.placeholder(TFloat32.class); + Operand y = tf.math.sigmoid(x); + + // Provide an explicit seed dy to avoid Graph.addGradients defaulting to OnesLike(y) + Operand seed = tf.fill(tf.shape(y), tf.constant(1.0f)); + + Output[] grads = + g.addGradients( + "seed", + new Output[] {y.asOutput()}, + new Output[] {x.asOutput()}, + new Output[] {seed.asOutput()}); + + assertNotNull(grads); + assertEquals(1, grads.length); + assertNotNull(grads[0], "Expected a non-null gradient for sigmoid(x) wrt x."); + assertTrue(!grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx."); + } + } +} From 69a9a3672760188a25229eb57380d388000acafb Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Tue, 10 Feb 2026 09:58:12 +0100 Subject: [PATCH 5/7] Fix custom gradients: support NoGradient and stabilize adapter --- .../tensorflow/AbstractGradientAdapter.java | 17 +- .../org/tensorflow/CustomGradientsTest.java | 5 +- .../internal/c_api/tfj_gradients_impl.cc | 243 +++++++++--------- 3 files changed, 130 insertions(+), 135 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java index 18f95d2197d..ef581db7ae0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java @@ -80,11 +80,11 @@ private static List> fromNativeOutputs(Graph g, TF_Output nativeOutput } /** - * Put the Java outputs into the array of native outputs, resizing it to the necessary size. - * - * @param outputs the outputs to put - * @return pointer to the native array of outputs - */ + * Put the Java outputs into the array of native outputs, resizing it to the necessary size. + * + * @param outputs the outputs to put + * @return pointer to the native array of outputs + */ private static TF_Output toNativeOutputs(List> outputs) { // Use malloc to allocate native outputs, as they will be freed by the native layer and we do // not want JavaCPP to deallocate them @@ -92,13 +92,12 @@ private static TF_Output toNativeOutputs(List> outputs) { new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class))); for (int i = 0; i < outputs.size(); ++i) { + Operand operand = outputs.get(i); var nativeOutput = nativeOutputs.getPointer(i); - Operand operand = outputs.get(i); + // Convention: null Operand => NoGradient if (operand == null) { - // "NoGradient" sentinel: null oper + index 0. - // Native side must tolerate TF_Output.oper == nullptr. 
- nativeOutput.oper((org.tensorflow.internal.c_api.TF_Operation) null); + nativeOutput.oper((TF_Operation) null); nativeOutput.index(0); continue; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java index 102623ffc5b..baaa7bdb742 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; import java.util.List; import org.junit.jupiter.api.Test; @@ -96,7 +97,7 @@ public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() { // dL/d(y) not needed for this test; return zeros to keep it non-null. Operand dY = tf.zerosLike(y); - return List.of(dY, dDy); + return java.util.Arrays.asList(dY, dDy); }); try (Graph g = new Graph()) { @@ -118,7 +119,7 @@ public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() { assertNotNull(grads); assertEquals(1, grads.length); assertNotNull(grads[0], "Expected a non-null gradient for sigmoid(x) wrt x."); - assertTrue(!grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx."); + assertFalse(grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx."); } } } diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 0e30623dbc3..32e6043b1c3 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -21,8 +21,8 @@ limitations under the License. // IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, // so we must not rely on transitive includes from other headers). #include -#include #include +#include #include "tfj_graph.h" #include "tsl/platform/errors.h" @@ -31,141 +31,136 @@ limitations under the License. #include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { - namespace java { - using namespace tsl; - using namespace std; - - unordered_map g_grad_func_adapters; - - /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit - /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: - /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 - template T* struct_cast(U* ptr) { - return static_cast(static_cast(ptr)); - } - - /// This function is called by the TensorFlow runtime when it is time to add gradient operations of `op` to the - /// graph using the given `scope`. - /// We use it as a bridge between the C++ signature in TensorFlow (tensorflow::op::GradFunc) and our custom - /// "C" version (TFJ_GradFuncAdapter). 
- Status CustomGradFunc(const Scope& scope, - const Operation& op, - const vector& grad_inputs, - vector* grad_outputs) - { - const string& op_type = op.node()->type_string(); - auto found_adapter = g_grad_func_adapters.find(op_type); - if (found_adapter == g_grad_func_adapters.end()) { - return errors::NotFound("No gradient adapter found for operation ", op_type); - } - - const int num_inputs = static_cast(grad_inputs.size()); - - TF_Output* inputs = nullptr; - if (num_inputs > 0) { - inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); - if (inputs == nullptr) { - return errors::ResourceExhausted( - "Out of memory allocating inputs for custom gradient of op ", op_type); - } - } - - for (int i = 0; i < num_inputs; ++i) { - const Output& grad_input = grad_inputs[i]; - inputs[i].oper = struct_cast(grad_input.node()); - inputs[i].index = grad_input.index(); - } - - TF_Output* outputs = nullptr; - LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - - TFJ_GradFuncAdapter adapter = found_adapter->second; - if (adapter == nullptr) { - if (inputs != nullptr) free(inputs); - return errors::Unknown("Null Java gradient adapter for op ", op_type); - } - LOG(INFO) << "Adapter ptr for " << op_type << " = " << reinterpret_cast(found_adapter->second); - const int num_outputs = adapter( - static_cast(scope.graph()), - struct_cast(const_cast(&scope)), - struct_cast(op.node()), - inputs, - num_inputs, - &outputs - ); - - // Always free inputs, even on error paths. - if (inputs != nullptr) free(inputs); - - // Adapter contract hardening: - // - On Java exception / failure, adapter should return negative or outputs==nullptr. - if (num_outputs < 0) { - if (outputs != nullptr) free(outputs); - return errors::Unknown("Java custom gradient adapter failed for op ", op_type, - " (num_outputs=", num_outputs, ")"); - } - if (num_outputs > 0 && outputs == nullptr) { - return errors::Unknown("Java custom gradient adapter returned null outputs for op ", - op_type, " with num_outputs=", num_outputs); - } - - grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); - - for (int i = 0; i < num_outputs; ++i) { - const TF_Output out = outputs[i]; - - // "NoGradient" sentinel from Java: TF_Output.oper == nullptr - if (out.oper == nullptr) { - // Represent "no gradient" as an empty Output. - // TF's gradient builder should tolerate missing gradients for non-differentiable inputs. 
- grad_outputs->push_back(Output()); - continue; - } - - grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); - } - - if (outputs != nullptr) free(outputs); - return OkStatus(); - } +namespace java { + +using namespace tsl; +using namespace std; + +unordered_map g_grad_func_adapters; + +// Cast helper (inspired by TF C-API) +template +T* struct_cast(U* ptr) { + return static_cast(static_cast(ptr)); +} + +// Bridge called by TF runtime when building gradients for op +Status CustomGradFunc(const Scope& scope, + const Operation& op, + const vector& grad_inputs, + vector* grad_outputs) { + const string& op_type = op.node()->type_string(); + auto found_adapter = g_grad_func_adapters.find(op_type); + if (found_adapter == g_grad_func_adapters.end()) { + return errors::NotFound("No gradient adapter found for operation ", op_type); + } + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == nullptr) { + return errors::Unknown("Null Java gradient adapter for op ", op_type); + } + + const int num_inputs = static_cast(grad_inputs.size()); + + TF_Output* inputs = nullptr; + if (num_inputs > 0) { + inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); + if (inputs == nullptr) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + + for (int i = 0; i < num_inputs; ++i) { + const Output& grad_input = grad_inputs[i]; + inputs[i].oper = struct_cast(grad_input.node()); + inputs[i].index = grad_input.index(); + } + + TF_Output* outputs = nullptr; + + LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; + const int num_outputs = adapter( + static_cast(scope.graph()), + struct_cast(const_cast(&scope)), + struct_cast(op.node()), + inputs, + num_inputs, + &outputs); + + if (inputs != nullptr) free(inputs); + + // Adapter contract: + // - num_outputs < 0 indicates failure + // - num_outputs == 0: OK, outputs may be nullptr + // - num_outputs > 0: outputs must be non-null + if (num_outputs < 0) { + if (outputs != nullptr) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for op ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == nullptr) { + return errors::Unknown("Java custom gradient adapter returned null outputs for op ", + op_type, " with num_outputs=", num_outputs); + } + + grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); + + for (int i = 0; i < num_outputs; ++i) { + const TF_Output out = outputs[i]; + + // Convention: out.oper == nullptr => NoGradient + if (out.oper == nullptr) { + grad_outputs->push_back(Output()); // TF interprets empty Output as "no grad" + continue; } + + grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); + } + + if (outputs != nullptr) free(outputs); // allocated from Java via malloc + return OkStatus(); } +} // namespace java +} // namespace tensorflow + using namespace tensorflow::ops; using namespace tensorflow::java; bool TFJ_HasGradient(const char* op_type) { - GradFunc dummy; - tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); - return status.ok(); + GradFunc dummy; + tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); + return status.ok(); } bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" - << reinterpret_cast(grad_func_adapter); - - if (grad_func_adapter == nullptr) { - LOG(ERROR) << 
"Refusing to register NULL Java gradient adapter for op " << op_type; - return false; - } - - if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash - LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type - << ", which has already a registered function"; - return false; - } - bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); - if (registered) { - g_grad_func_adapters.insert({op_type, grad_func_adapter}); - } - return registered; + LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" + << reinterpret_cast(grad_func_adapter); + + if (grad_func_adapter == nullptr) { + LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type; + return false; + } + + if (TFJ_HasGradient(op_type)) { + LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type + << ", which has already a registered function"; + return false; + } + + bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); + if (registered) { + g_grad_func_adapters.insert({op_type, grad_func_adapter}); + } + return registered; } -#else // #ifndef _WIN32 - -/* This extension is not available on Windows */ +#else // _WIN32 bool TFJ_HasGradient(const char* op_type) { return true; } -bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; } +bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { + return false; +} -#endif // #ifndef _WIN32 +#endif // _WIN32 From 8d80312b8c0b0efd53132fa5ac8f1d2cb5d5f50c Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Tue, 10 Feb 2026 17:55:07 +0100 Subject: [PATCH 6/7] apply mvn spotless --- .../tensorflow/AbstractGradientAdapter.java | 10 +- .../org/tensorflow/CustomGradientsTest.java | 3 +- .../internal/c_api/tfj_gradients_impl.cc | 230 +++++++++--------- 3 files changed, 116 insertions(+), 127 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java index ef581db7ae0..94f82786ed3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java @@ -80,11 +80,11 @@ private static List> fromNativeOutputs(Graph g, TF_Output nativeOutput } /** - * Put the Java outputs into the array of native outputs, resizing it to the necessary size. - * - * @param outputs the outputs to put - * @return pointer to the native array of outputs - */ + * Put the Java outputs into the array of native outputs, resizing it to the necessary size. 
+ * + * @param outputs the outputs to put + * @return pointer to the native array of outputs + */ private static TF_Output toNativeOutputs(List> outputs) { // Use malloc to allocate native outputs, as they will be freed by the native layer and we do // not want JavaCPP to deallocate them diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java index baaa7bdb742..6d7e9a098cd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java @@ -1,11 +1,10 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.assertFalse; -import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledOnOs; import org.junit.jupiter.api.condition.OS; diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 32e6043b1c3..9c5ecd75e07 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include -// IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, -// so we must not rely on transitive includes from other headers). #include #include #include @@ -31,136 +29,128 @@ limitations under the License. 
#include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { -namespace java { - -using namespace tsl; -using namespace std; - -unordered_map g_grad_func_adapters; - -// Cast helper (inspired by TF C-API) -template -T* struct_cast(U* ptr) { - return static_cast(static_cast(ptr)); -} - -// Bridge called by TF runtime when building gradients for op -Status CustomGradFunc(const Scope& scope, - const Operation& op, - const vector& grad_inputs, - vector* grad_outputs) { - const string& op_type = op.node()->type_string(); - auto found_adapter = g_grad_func_adapters.find(op_type); - if (found_adapter == g_grad_func_adapters.end()) { - return errors::NotFound("No gradient adapter found for operation ", op_type); - } - - TFJ_GradFuncAdapter adapter = found_adapter->second; - if (adapter == nullptr) { - return errors::Unknown("Null Java gradient adapter for op ", op_type); - } - - const int num_inputs = static_cast(grad_inputs.size()); - - TF_Output* inputs = nullptr; - if (num_inputs > 0) { - inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); - if (inputs == nullptr) { - return errors::ResourceExhausted( - "Out of memory allocating inputs for custom gradient of op ", op_type); - } - } - - for (int i = 0; i < num_inputs; ++i) { - const Output& grad_input = grad_inputs[i]; - inputs[i].oper = struct_cast(grad_input.node()); - inputs[i].index = grad_input.index(); - } - - TF_Output* outputs = nullptr; - - LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - const int num_outputs = adapter( - static_cast(scope.graph()), - struct_cast(const_cast(&scope)), - struct_cast(op.node()), - inputs, - num_inputs, - &outputs); - - if (inputs != nullptr) free(inputs); - - // Adapter contract: - // - num_outputs < 0 indicates failure - // - num_outputs == 0: OK, outputs may be nullptr - // - num_outputs > 0: outputs must be non-null - if (num_outputs < 0) { - if (outputs != nullptr) free(outputs); - return errors::Unknown("Java custom gradient adapter failed for op ", op_type, - " (num_outputs=", num_outputs, ")"); - } - if (num_outputs > 0 && outputs == nullptr) { - return errors::Unknown("Java custom gradient adapter returned null outputs for op ", - op_type, " with num_outputs=", num_outputs); - } - - grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); - - for (int i = 0; i < num_outputs; ++i) { - const TF_Output out = outputs[i]; - - // Convention: out.oper == nullptr => NoGradient - if (out.oper == nullptr) { - grad_outputs->push_back(Output()); // TF interprets empty Output as "no grad" - continue; + namespace java { + using namespace tsl; + using namespace std; + + unordered_map g_grad_func_adapters; + + /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit + /// + /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: + /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 + template T* struct_cast(U* ptr) { + return static_cast(static_cast(ptr)); + } + + /// This function is called by the TensorFlow runtime when it is time to add gradient operations of `op` to the + /// graph using the given `scope`. + /// We use it as a bridge between the C++ signature in TensorFlow (tensorflow::op::GradFunc) and our custom + /// "C" version (TFJ_GradFuncAdapter). 
+ Status CustomGradFunc(const Scope& scope, + const Operation& op, + const vector& grad_inputs, + vector* grad_outputs) + { + const string& op_type = op.node()->type_string(); + auto found_adapter = g_grad_func_adapters.find(op_type); + if (found_adapter == g_grad_func_adapters.end()) { + return errors::NotFound("No gradient adapter found for operation ", op_type); + } + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == NULL) { + return errors::Unknown("Null Java gradient adapter for operation ", op_type); + } + + int num_inputs = grad_inputs.size(); + TF_Output* inputs = NULL; + if (num_inputs > 0) { + inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output)); + if (inputs == NULL) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + + for (int i = 0; i < num_inputs; ++i) { + Output grad_input = grad_inputs[i]; + inputs[i].oper = struct_cast(grad_input.node()); + inputs[i].index = grad_input.index(); + } + + TF_Output* outputs = NULL; + LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; + int num_outputs = adapter( + static_cast(scope.graph()), + struct_cast(const_cast(&scope)), + struct_cast(op.node()), + inputs, + num_inputs, + &outputs + ); + + if (inputs != NULL) free(inputs); + + if (num_outputs < 0) { + if (outputs != NULL) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for operation ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == NULL) { + return errors::Unknown("Java custom gradient adapter returned null outputs for operation ", + op_type, " with num_outputs=", num_outputs); + } + + for (int i = 0; i < num_outputs; ++i) { + TF_Output output = outputs[i]; + + // Convention: output.oper == NULL => NoGradient + if (output.oper == NULL) { + grad_outputs->push_back(Output()); + } else { + grad_outputs->push_back(Output(struct_cast(output.oper), output.index)); + } + } + + if (outputs != NULL) free(outputs); // outputs are allocated from Java but must be freed here + return OkStatus(); + } } - - grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); - } - - if (outputs != nullptr) free(outputs); // allocated from Java via malloc - return OkStatus(); } -} // namespace java -} // namespace tensorflow - using namespace tensorflow::ops; using namespace tensorflow::java; bool TFJ_HasGradient(const char* op_type) { - GradFunc dummy; - tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); - return status.ok(); + GradFunc dummy; + tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); + return status.ok(); } bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" - << reinterpret_cast(grad_func_adapter); - - if (grad_func_adapter == nullptr) { - LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type; - return false; - } - - if (TFJ_HasGradient(op_type)) { - LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type - << ", which has already a registered function"; - return false; - } - - bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); - if (registered) { - g_grad_func_adapters.insert({op_type, grad_func_adapter}); - } - return registered; + if (grad_func_adapter == NULL) { + LOG(ERROR) << "Refusing to register NULL Java gradient adapter for operation " << 
op_type; + return false; + } + + if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash + LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type + << ", which has already a registered function"; + return false; + } + bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); + if (registered) { + g_grad_func_adapters.insert({op_type, grad_func_adapter}); + } + return registered; } -#else // _WIN32 +#else // #ifndef _WIN32 + +/* This extension is not available on Windows */ bool TFJ_HasGradient(const char* op_type) { return true; } -bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - return false; -} +bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; } -#endif // _WIN32 +#endif // #ifndef _WIN32 From 1240293e6799ca0caef6f439349f1aef304f7785 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Thu, 12 Feb 2026 17:29:28 +0100 Subject: [PATCH 7/7] Fix review comments: enforce mutual exclusion for raw/typed gradients, remove inline ifs, add license headers and imports - Prevent dual registration of raw and typed gradients for the same op type - Use putIfAbsent and explicit exceptions to avoid silent overwrites - Replace inline if statements in tfj_gradients_impl.cc with brace blocks - Add Apache 2.0 headers to new files - Replace fully-qualified GradientDispatch reference with import --- .../main/java/org/tensorflow/TensorFlow.java | 9 ++-- .../op/DispatchingGradientAdapter.java | 43 ++++++++++++++++++- .../org/tensorflow/op/GradientDispatch.java | 15 +++++++ .../internal/c_api/tfj_gradients_impl.cc | 14 ++++-- 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 76c0f168eb6..a6b9e86dd9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -39,6 +39,7 @@ import org.tensorflow.internal.c_api.TF_Library; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.CustomGradient; +import org.tensorflow.op.GradientDispatch; import org.tensorflow.op.RawCustomGradient; import org.tensorflow.op.RawOpInputs; import org.tensorflow.op.annotation.OpInputsMetadata; @@ -208,8 +209,8 @@ public static synchronized boolean registerCustomGradient( return false; } - org.tensorflow.op.GradientDispatch.putRaw(opType, gradient); - TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + GradientDispatch.putRaw(opType, gradient); + TFJ_GradFuncAdapter g = GradientDispatch.adapter(); if (!TFJ_RegisterCustomGradient(opType, g)) { return false; @@ -259,8 +260,8 @@ public static synchronized > boolean registerCustomGrad return false; } - org.tensorflow.op.GradientDispatch.putTyped(opType, gradient, inputClass); - TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + GradientDispatch.putTyped(opType, gradient, inputClass); + TFJ_GradFuncAdapter g = GradientDispatch.adapter(); if (!TFJ_RegisterCustomGradient(opType, g)) { return false; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java index 
380dd6b555d..80b934460dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -1,3 +1,18 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op; import java.lang.reflect.Constructor; @@ -16,6 +31,16 @@ final class DispatchingGradientAdapter extends AbstractGradientAdapter { private final ConcurrentMap raw = new ConcurrentHashMap<>(); private final ConcurrentMap> typed = new ConcurrentHashMap<>(); + private static String dupMsg(String opType, String existingKind, String newKind) { + return "A " + + existingKind + + " gradient is already registered for op type '" + + opType + + "'. Raw and typed registrations are mutually exclusive; cannot register " + + newKind + + "."; + } + static final class TypedEntry> { final CustomGradient grad; final Class inputClass; @@ -35,12 +60,26 @@ static final class TypedEntry> { } void putRaw(String opType, RawCustomGradient g) { - raw.put(opType, g); + if (typed.containsKey(opType)) { + throw new IllegalStateException(dupMsg(opType, "typed", "raw")); + } + RawCustomGradient prev = raw.putIfAbsent(opType, g); + if (prev != null) { + throw new IllegalStateException( + "A raw gradient is already registered for op type '" + opType + "'."); + } } > void putTyped( String opType, CustomGradient g, Class inputClass) { - typed.put(opType, new TypedEntry<>(g, inputClass)); + if (raw.containsKey(opType)) { + throw new IllegalStateException(dupMsg(opType, "raw", "typed")); + } + TypedEntry prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass)); + if (prev != null) { + throw new IllegalStateException( + "A typed gradient is already registered for op type '" + opType + "'."); + } } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java index 64504121677..441cff5a2fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java @@ -1,3 +1,18 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+======================================================================= +*/ package org.tensorflow.op; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 9c5ecd75e07..ad68e1e5c05 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -90,13 +90,18 @@ namespace tensorflow { &outputs ); - if (inputs != NULL) free(inputs); + if (inputs != NULL) { + free(inputs); + } if (num_outputs < 0) { - if (outputs != NULL) free(outputs); + if (outputs != NULL) { + free(outputs); + } return errors::Unknown("Java custom gradient adapter failed for operation ", op_type, " (num_outputs=", num_outputs, ")"); } + if (num_outputs > 0 && outputs == NULL) { return errors::Unknown("Java custom gradient adapter returned null outputs for operation ", op_type, " with num_outputs=", num_outputs); @@ -113,7 +118,10 @@ namespace tensorflow { } } - if (outputs != NULL) free(outputs); // outputs are allocated from Java but must be freed here + if (outputs != NULL) { + free(outputs); // outputs are allocated from Java but must be freed here + } + return OkStatus(); } }
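
Usage sketch (not part of the patch series): the condensed example below illustrates the registration path and the null-means-NoGradient convention introduced above. The op type "SomeTwoInputOp" and the gradient body are purely illustrative; only TensorFlow.registerCustomGradient, the RawCustomGradient signature, and the null-return convention come from the patches themselves.

import java.util.Arrays;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.RawCustomGradient;
import org.tensorflow.types.TFloat32;

public final class NoGradientUsageSketch {

  public static void main(String[] args) {
    // Gradient for a hypothetical op with two inputs, where input 1 is not differentiable.
    RawCustomGradient grad =
        (tf, op, gradInputs) -> {
          // Upstream gradient flowing into this op's single output.
          @SuppressWarnings("unchecked")
          Operand<TFloat32> upstream = (Operand<TFloat32>) gradInputs.get(0);

          // Illustrative gradient for input 0 only.
          Operand<TFloat32> dInput0 = tf.math.mul(upstream, tf.constant(2.0f));

          // null at position 1 => "NoGradient": the Java adapter emits TF_Output{oper=nullptr},
          // which the native bridge converts into an empty Output for that input.
          return Arrays.asList(dInput0, null);
        };

    // Registration goes through the shared GradientDispatch adapter; it returns false if a
    // gradient (native or Java) is already registered for this op type in the current process.
    boolean registered = TensorFlow.registerCustomGradient("SomeTwoInputOp", grad);

    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // Build a graph containing "SomeTwoInputOp" and call g.addGradients(...) with an
      // explicit seed, as in CustomGradientsTest above.
    }
  }
}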