diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
index 2119cddaa67..94f82786ed3 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
@@ -92,8 +92,17 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
     TF_Output nativeOutputs =
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
     for (int i = 0; i < outputs.size(); ++i) {
-      var output = outputs.get(i).asOutput();
+      Operand<?> operand = outputs.get(i);
       var nativeOutput = nativeOutputs.getPointer(i);
+
+      // Convention: null Operand => NoGradient
+      if (operand == null) {
+        nativeOutput.oper((TF_Operation) null);
+        nativeOutput.index(0);
+        continue;
+      }
+
+      var output = operand.asOutput();
       nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle());
       nativeOutput.index(output.index());
     }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
index 7eba6d7ce30..a6b9e86dd9e 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
@@ -39,6 +39,7 @@
 import org.tensorflow.internal.c_api.TF_Library;
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.op.CustomGradient;
+import org.tensorflow.op.GradientDispatch;
 import org.tensorflow.op.RawCustomGradient;
 import org.tensorflow.op.RawOpInputs;
 import org.tensorflow.op.annotation.OpInputsMetadata;
@@ -207,7 +208,10 @@ public static synchronized boolean registerCustomGradient(
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient);
+
+    GradientDispatch.putRaw(opType, gradient);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
@@ -255,7 +259,10 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass);
+
+    GradientDispatch.putTyped(opType, gradient, inputClass);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java
index 02acce1cb37..4c3b80a6cad 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java
@@ -17,10 +17,15 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for ops of type {@link T}. Should be registered using {@link
@@ -57,6 +62,31 @@
    */
   static <T extends RawOpInputs<?>> TFJ_GradFuncAdapter adapter(
       CustomGradient<T> gradient, Class<T> opClass) {
-    return new TypedGradientAdapter<>(gradient, opClass);
+
+    final TypedGradientAdapter<T> impl = new TypedGradientAdapter<>(gradient, opClass);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java
new file mode 100644
index 00000000000..80b934460dc
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java
@@ -0,0 +1,123 @@
+/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================
+*/
+package org.tensorflow.op;
+
+import java.lang.reflect.Constructor;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import org.tensorflow.AbstractGradientAdapter;
+import org.tensorflow.Graph;
+import org.tensorflow.GraphOperation;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+
+final class DispatchingGradientAdapter extends AbstractGradientAdapter {
+
+  private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
+  private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();
+
+  private static String dupMsg(String opType, String existingKind, String newKind) {
+    return "A "
+        + existingKind
+        + " gradient is already registered for op type '"
+        + opType
+        + "'. Raw and typed registrations are mutually exclusive; cannot register "
+        + newKind
+        + ".";
+  }
+
+  static final class TypedEntry<T extends RawOpInputs<?>> {
+    final CustomGradient<T> grad;
+    final Class<T> inputClass;
+    final Constructor<T> ctor;
+
+    TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
+      this.grad = grad;
+      this.inputClass = inputClass;
+      try {
+        this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class);
+      } catch (NoSuchMethodException e) {
+        throw new IllegalArgumentException(
+            "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).",
+            e);
+      }
+    }
+  }
+
+  void putRaw(String opType, RawCustomGradient g) {
+    if (typed.containsKey(opType)) {
+      throw new IllegalStateException(dupMsg(opType, "typed", "raw"));
+    }
+    RawCustomGradient prev = raw.putIfAbsent(opType, g);
+    if (prev != null) {
+      throw new IllegalStateException(
+          "A raw gradient is already registered for op type '" + opType + "'.");
+    }
+  }
+
+  <T extends RawOpInputs<?>> void putTyped(
+      String opType, CustomGradient<T> g, Class<T> inputClass) {
+    if (raw.containsKey(opType)) {
+      throw new IllegalStateException(dupMsg(opType, "raw", "typed"));
+    }
+    TypedEntry<?> prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass));
+    if (prev != null) {
+      throw new IllegalStateException(
+          "A typed gradient is already registered for op type '" + opType + "'.");
+    }
+  }
+
+  @Override
+  protected List<Operand<?>> apply(
+      Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {
+
+    final String opType = operation.type();
+
+    RawCustomGradient rg = raw.get(opType);
+    if (rg != null) {
+      // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
+      Scope nativeScope =
+          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
+      return rg.call(new Ops(nativeScope), operation, gradInputs);
+    }
+
+    @SuppressWarnings("rawtypes")
+    TypedEntry te = typed.get(opType);
+    if (te != null) {
+      return applyTyped(graph, scope, operation, gradInputs, te);
+    }
+
+    throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
+  }
+
+  private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
+      Graph graph,
+      TFJ_Scope scope,
+      GraphOperation operation,
+      List<Output<?>> gradInputs,
+      TypedEntry<T> te) {
+    try {
+      T inputs = te.ctor.newInstance(operation);
+      Scope nativeScope =
+          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
+      return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
+    } catch (ReflectiveOperationException e) {
+      throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
+    }
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
new file mode 100644
index 00000000000..441cff5a2fc
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
@@ -0,0 +1,40 @@
+/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================
+*/
+package org.tensorflow.op;
+
+import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+
+/** Public bridge to a single native gradient adapter. */
+public final class GradientDispatch {
+
+  // package-private adapter that can access NativeScope/Ops constructors
+  static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();
+
+  private GradientDispatch() {}
+
+  public static TFJ_GradFuncAdapter adapter() {
+    return ADAPTER;
+  }
+
+  public static void putRaw(String opType, RawCustomGradient gradient) {
+    ADAPTER.putRaw(opType, gradient);
+  }
+
+  public static <T extends RawOpInputs<?>> void putTyped(
+      String opType, CustomGradient<T> gradient, Class<T> inputClass) {
+    ADAPTER.putTyped(opType, gradient, inputClass);
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java
index c2d5496de2a..723d45d58ad 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java
@@ -17,11 +17,16 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.GraphOperation;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -54,6 +59,30 @@
    * TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
    */
   static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) {
-    return new RawGradientAdapter(gradient);
+    final RawGradientAdapter impl = new RawGradientAdapter(gradient);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
new file mode 100644
index 00000000000..6d7e9a098cd
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
@@ -0,0 +1,124 @@
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnOs;
+import org.junit.jupiter.api.condition.OS;
+import org.tensorflow.op.CustomGradient;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.RawCustomGradient;
+import org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt32;
+
+@DisabledOnOs(OS.WINDOWS)
+public class CustomGradientsTest {
+
+  @Test
+  public void noGradientNullIsSupported() {
+    // Register a custom gradient for an op that has NO native gradient in TF core.
+    CustomGradient<SparseSoftmaxCrossEntropyWithLogits.Inputs> grad =
+        (tf, op, gradInputs) -> {
+          @SuppressWarnings("unchecked")
+          Operand<TFloat32> gLoss = (Operand<TFloat32>) gradInputs.get(0); // [B]
+
+          @SuppressWarnings("unchecked")
+          Operand<TFloat32> logits = op.features;
+
+          SparseSoftmaxCrossEntropyWithLogits<TFloat32> xent =
+              SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, op.labels);
+
+          Operand<TFloat32> backprop = xent.backprop(); // [B,C]
+          Operand<TFloat32> gLossE = tf.expandDims(gLoss, tf.constant(1)); // [B,1]
+          Operand<TFloat32> dLogits = tf.math.mul(gLossE, backprop); // [B,C]
+
+          // labels: NoGradient
+          return java.util.Arrays.asList(dLogits, null);
+        };
+
+    assertTrue(
+        TensorFlow.registerCustomGradient(SparseSoftmaxCrossEntropyWithLogits.Inputs.class, grad));
+
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+
+      // Small fixed shapes so we can pass an explicit seed (avoids OnesLike in addGradients).
+      Operand<TFloat32> logits = tf.constant(new float[][] {{1f, 2f, 3f}, {3f, 2f, 1f}});
+      Operand<TInt32> labels = tf.constant(new int[] {2, 0});
+
+      SparseSoftmaxCrossEntropyWithLogits<TFloat32> xent =
+          SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, labels);
+
+      Output<TFloat32> loss = xent.loss(); // [2]
+      Operand<TFloat32> seed = tf.constant(new float[] {1f, 1f}); // same shape as loss
+
+      Output<?>[] grads =
+          g.addGradients(
+              "seed",
+              new Output<?>[] {loss},
+              new Output<?>[] {logits.asOutput(), labels.asOutput()},
+              new Output<?>[] {seed.asOutput()});
+
+      // logits grad exists, labels grad must be "NoGradient" (represented as a CLOSED Output)
+      assertNotNull(grads);
+      assertEquals(2, grads.length);
+      assertNotNull(grads[0], "Expected gradient for logits");
+      assertNotNull(grads[1], "Expected an Output placeholder for labels gradient");
+      assertTrue(grads[1].isClosed(), "Expected closed gradient (NoGradient) for labels");
+    }
+  }
+
+  @Test
+  public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() {
+    // Register a custom gradient for SigmoidGrad (if already registered, this returns false,
+    // but the test can still pass because the gradient exists in the current process).
+    TensorFlow.registerCustomGradient(
+        "SigmoidGrad",
+        (RawCustomGradient)
+            (tf, op, gradInputs) -> {
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> y = (Operand<TFloat32>) op.input(0); // sigmoid(x)
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> dy = (Operand<TFloat32>) op.input(1); // upstream into SigmoidGrad
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> upstream = (Operand<TFloat32>) gradInputs.get(0);
+
+              Operand<TFloat32> one = tf.constant(1.0f);
+              Operand<TFloat32> yTimesOneMinusY = tf.math.mul(y, tf.math.sub(one, y));
+
+              // dL/d(dy) = upstream * y*(1-y)
+              Operand<TFloat32> dDy = tf.math.mul(upstream, yTimesOneMinusY);
+
+              // dL/d(y) not needed for this test; return zeros to keep it non-null.
+              Operand<TFloat32> dY = tf.zerosLike(y);
+
+              return java.util.Arrays.asList(dY, dDy);
+            });
+
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+
+      Operand<TFloat32> x = tf.placeholder(TFloat32.class);
+      Operand<TFloat32> y = tf.math.sigmoid(x);
+
+      // Provide an explicit seed dy to avoid Graph.addGradients defaulting to OnesLike(y)
+      Operand<TFloat32> seed = tf.fill(tf.shape(y), tf.constant(1.0f));
+
+      Output<?>[] grads =
+          g.addGradients(
+              "seed",
+              new Output<?>[] {y.asOutput()},
+              new Output<?>[] {x.asOutput()},
+              new Output<?>[] {seed.asOutput()});
+
+      assertNotNull(grads);
+      assertEquals(1, grads.length);
+      assertNotNull(grads[0], "Expected a non-null gradient for sigmoid(x) wrt x.");
+      assertFalse(grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx.");
+    }
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
index cd3be39fde2..01f8fe59b4b 100644
--- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
+++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java
@@ -213,6 +213,19 @@ public void map(InfoMap infoMap) {
 
     // Skip C++ classes
     infoMap.put(new Info("tsl::StatusGroup").skip());
+
+    // Force correct marshalling of the TFJ_RegisterCustomGradient callback argument.
+    // Without an explicit cast, JavaCPP may pass a NULL function pointer for some FunctionPointer
+    // instances.
+    infoMap.put(
+        new Info("TFJ_RegisterCustomGradient")
+            .javaText(
+                "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient("
+                    + "@Cast(\"const char*\") BytePointer op_type, "
+                    + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n"
+                    + "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient("
+                    + "@Cast(\"const char*\") String op_type, "
+                    + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n"));
   }
 
   @Override
diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc
index 6882cfa704c..ad68e1e5c05 100644
--- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc
+++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc
@@ -18,6 +18,10 @@ limitations under the License.
 #include <string>
 #include <unordered_map>
 
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+
 #include "tfj_graph.h"
 #include "tsl/platform/errors.h"
 #include "tensorflow/c/c_api.h"
@@ -32,7 +36,7 @@ namespace tensorflow {
   unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
 
   /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit
-
+  ///
   /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing:
   /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658
   template <typename T, typename U> T* struct_cast(U* ptr) {
@@ -53,16 +57,31 @@
     if (found_adapter == g_grad_func_adapters.end()) {
       return errors::NotFound("No gradient adapter found for operation ", op_type);
     }
+
+    TFJ_GradFuncAdapter adapter = found_adapter->second;
+    if (adapter == NULL) {
+      return errors::Unknown("Null Java gradient adapter for operation ", op_type);
+    }
+
     int num_inputs = grad_inputs.size();
-    TF_Output* inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output));
+    TF_Output* inputs = NULL;
+    if (num_inputs > 0) {
+      inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output));
+      if (inputs == NULL) {
+        return errors::ResourceExhausted(
+            "Out of memory allocating inputs for custom gradient of op ", op_type);
+      }
+    }
+
     for (int i = 0; i < num_inputs; ++i) {
       Output grad_input = grad_inputs[i];
       inputs[i].oper = struct_cast<TF_Operation>(grad_input.node());
       inputs[i].index = grad_input.index();
     }
+
     TF_Output* outputs = NULL;
     LOG(INFO) << "Calling Java gradient function for operation of type " << op_type;
-    int num_outputs = found_adapter->second(
+    int num_outputs = adapter(
         static_cast<TFJ_GraphId>(scope.graph()),
         struct_cast<TFJ_Scope>(const_cast<Scope*>(&scope)),
         struct_cast<TF_Operation>(op.node()),
@@ -70,12 +89,39 @@
         num_inputs,
         &outputs
     );
+
+    if (inputs != NULL) {
+      free(inputs);
+    }
+
+    if (num_outputs < 0) {
+      if (outputs != NULL) {
+        free(outputs);
+      }
+      return errors::Unknown("Java custom gradient adapter failed for operation ", op_type,
+                             " (num_outputs=", num_outputs, ")");
+    }
+
+    if (num_outputs > 0 && outputs == NULL) {
+      return errors::Unknown("Java custom gradient adapter returned null outputs for operation ",
+                             op_type, " with num_outputs=", num_outputs);
+    }
+
     for (int i = 0; i < num_outputs; ++i) {
       TF_Output output = outputs[i];
-      grad_outputs->push_back(Output(struct_cast<Node>(output.oper), output.index));
+
+      // Convention: output.oper == NULL => NoGradient
+      if (output.oper == NULL) {
+        grad_outputs->push_back(Output());
+      } else {
+        grad_outputs->push_back(Output(struct_cast<Node>(output.oper), output.index));
+      }
     }
-    free(inputs);
-    free(outputs); // outputs are allocated from Java but must be freed here
+
+    if (outputs != NULL) {
+      free(outputs); // outputs are allocated from Java but must be freed here
+    }
+
     return OkStatus();
   }
 }
@@ -91,6 +137,11 @@ bool TFJ_HasGradient(const char* op_type) {
 }
 
 bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) {
+  if (grad_func_adapter == NULL) {
+    LOG(ERROR) << "Refusing to register NULL Java gradient adapter for operation " << op_type;
+    return false;
+  }
+
   if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash
     LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type
                  << ", which has already a registered function";
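
Usage sketch (illustrative only, not part of the patch): the raw registration path above, end to end. The op type "MyIdentityWithTag" and its assumed two-input shape are hypothetical stand-ins for this example; only TensorFlow.registerCustomGradient, RawCustomGradient, and the null-slot NoGradient convention come from the changes in this diff.

import java.util.Arrays;
import java.util.List;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.RawCustomGradient;

public final class CustomGradientUsage {

  public static void main(String[] args) {
    // Raw form: keyed by the op type string. registerCustomGradient stores the function via
    // GradientDispatch.putRaw(...) and hands the single shared adapter thunk to the native
    // library; it returns false if a gradient already exists for this op type.
    RawCustomGradient grad =
        (Ops tf, GraphOperation op, List<Output<?>> gradInputs) -> {
          // Pass the upstream gradient through unchanged for the op's first input, and mark
          // its (hypothetical) second input as NoGradient by returning null in that slot,
          // which AbstractGradientAdapter.toNativeOutputs encodes as a NULL TF_Output.oper.
          return Arrays.<Operand<?>>asList(gradInputs.get(0), null);
        };

    boolean registered = TensorFlow.registerCustomGradient("MyIdentityWithTag", grad);
    System.out.println("registered = " + registered);
  }
}

The typed path is symmetric: TensorFlow.registerCustomGradient(SomeOp.Inputs.class, gradient) lands in GradientDispatch.putTyped, and both forms are routed by op type through the same shared native thunk, as exercised by CustomGradientsTest above.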