@@ -92,8 +92,17 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
 
     for (int i = 0; i < outputs.size(); ++i) {
-      var output = outputs.get(i).asOutput();
+      Operand<?> operand = outputs.get(i);
       var nativeOutput = nativeOutputs.getPointer(i);
+
+      // Convention: null Operand => NoGradient
+      if (operand == null) {
+        nativeOutput.oper((TF_Operation) null);
+        nativeOutput.index(0);
+        continue;
+      }
+
+      var output = operand.asOutput();
       nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle());
       nativeOutput.index(output.index());
     }
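For context, a minimal sketch of how this null-entry convention might be used from a gradient function; the op type, class name, and pass-through gradient below are hypothetical, not part of this PR. A null entry in the returned list is encoded as a null TF_Operation with index 0, which TensorFlow treats as NoGradient for that input.

import java.util.Arrays;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.RawCustomGradient;

final class NoGradientExample { // hypothetical example class
  static void register() {
    // "MyTwoInputOp" is a made-up op type whose second input is non-differentiable.
    RawCustomGradient grad =
        (tf, op, gradInputs) ->
            Arrays.asList(
                gradInputs.get(0), // illustrative pass-through gradient for input 0
                null); // null Operand => NoGradient for input 1
    TensorFlow.registerCustomGradient("MyTwoInputOp", grad);
  }
}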
@@ -39,6 +39,7 @@
 import org.tensorflow.internal.c_api.TF_Library;
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.op.CustomGradient;
+import org.tensorflow.op.GradientDispatch;
 import org.tensorflow.op.RawCustomGradient;
 import org.tensorflow.op.RawOpInputs;
 import org.tensorflow.op.annotation.OpInputsMetadata;
@@ -207,7 +208,10 @@ public static synchronized boolean registerCustomGradient(
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient);
+
+    GradientDispatch.putRaw(opType, gradient);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
@@ -255,7 +259,10 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass);
+
+    GradientDispatch.putTyped(opType, gradient, inputClass);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
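A sketch of what the reworked flow implies for callers, assuming hasGradient(opType) reflects the native registry once the first registration succeeds (the op type and gradient bodies are hypothetical): duplicates are refused up front, before GradientDispatch is consulted.

import java.util.Arrays;
import org.tensorflow.TensorFlow;

final class DuplicateRegistrationExample { // hypothetical example class
  static void demo() {
    boolean first =
        TensorFlow.registerCustomGradient(
            "MyOp", (tf, op, gradInputs) -> Arrays.asList(gradInputs.get(0)));
    boolean second =
        TensorFlow.registerCustomGradient(
            "MyOp", (tf, op, gradInputs) -> Arrays.asList(gradInputs.get(0)));
    // first == true, second == false: hasGradient(opType) short-circuits the
    // repeat before GradientDispatch.putRaw() would throw.
  }
}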
@@ -17,10 +17,15 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for ops of type {@link T}. Should be registered using {@link
@@ -57,6 +62,31 @@ public interface CustomGradient<T extends RawOpInputs> {
    */
   static <T extends RawOpInputs<?>> TFJ_GradFuncAdapter adapter(
       CustomGradient<T> gradient, Class<T> opClass) {
-    return new TypedGradientAdapter<T>(gradient, opClass);
+
+    final TypedGradientAdapter<T> impl = new TypedGradientAdapter<T>(gradient, opClass);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }
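For reference, a sketch of exercising this adapter directly; MyOpInputs is a hypothetical stand-in for a generated inputs class (it needs a public MyOpInputs(GraphOperation) constructor, which the new DispatchingGradientAdapter resolves reflectively), and the gradient body is illustrative only.

import java.util.Arrays;
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
import org.tensorflow.op.CustomGradient;

final class TypedAdapterExample { // hypothetical example class
  static TFJ_GradFuncAdapter build() {
    CustomGradient<MyOpInputs> grad = // MyOpInputs: hypothetical RawOpInputs subclass
        (tf, inputs, gradInputs) -> Arrays.asList(gradInputs.get(0));
    // The returned object is an anonymous *direct* subclass of TFJ_GradFuncAdapter,
    // so JavaCPP can materialize a native function pointer for it.
    return CustomGradient.adapter(grad, MyOpInputs.class);
  }
}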
@@ -0,0 +1,123 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow.op;

Review comment (Collaborator): Can you add the copyright and license header to the new files?

import java.lang.reflect.Constructor;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.tensorflow.AbstractGradientAdapter;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.internal.c_api.TFJ_Scope;

final class DispatchingGradientAdapter extends AbstractGradientAdapter {

  private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
  private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();

  private static String dupMsg(String opType, String existingKind, String newKind) {
    return "A "
        + existingKind
        + " gradient is already registered for op type '"
        + opType
        + "'. Raw and typed registrations are mutually exclusive; cannot register "
        + newKind
        + ".";
  }

  static final class TypedEntry<T extends RawOpInputs<?>> {
    final CustomGradient<T> grad;
    final Class<T> inputClass;
    final Constructor<T> ctor;

    TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
      this.grad = grad;
      this.inputClass = inputClass;
      try {
        this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class);
      } catch (NoSuchMethodException e) {
        throw new IllegalArgumentException(
            "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).",
            e);
      }
    }
  }

  void putRaw(String opType, RawCustomGradient g) {
    if (typed.containsKey(opType)) {
      throw new IllegalStateException(dupMsg(opType, "typed", "raw"));
    }
    RawCustomGradient prev = raw.putIfAbsent(opType, g);
    if (prev != null) {
      throw new IllegalStateException(
          "A raw gradient is already registered for op type '" + opType + "'.");
    }
  }

  <T extends RawOpInputs<?>> void putTyped(
      String opType, CustomGradient<T> g, Class<T> inputClass) {
    if (raw.containsKey(opType)) {
      throw new IllegalStateException(dupMsg(opType, "raw", "typed"));
    }
    TypedEntry<?> prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass));
    if (prev != null) {
      throw new IllegalStateException(
          "A typed gradient is already registered for op type '" + opType + "'.");
    }
  }
  @Override
  protected List<Operand<?>> apply(
      Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {

    final String opType = operation.type();

    RawCustomGradient rg = raw.get(opType);

Review comment (Collaborator): This logic prefers raw gradients over typed ones, but there isn't anything documented about why it prefers them or if it makes sense to add both raw and typed gradients for the same op. It would be good to clarify this, and if it doesn't make sense to have both kinds of gradients the adapter should reject them in the puts.

    if (rg != null) {
      // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
      Scope nativeScope =
          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
      return rg.call(new Ops(nativeScope), operation, gradInputs);
    }

    @SuppressWarnings("rawtypes")
    TypedEntry te = typed.get(opType);
    if (te != null) {
      return applyTyped(graph, scope, operation, gradInputs, te);
    }

    throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
  }

  private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
      Graph graph,
      TFJ_Scope scope,
      GraphOperation operation,
      List<Output<?>> gradInputs,
      TypedEntry<T> te) {
    try {
      T inputs = te.ctor.newInstance(operation);
      Scope nativeScope =
          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
      return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
    }
  }
}
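On the review point above: the puts now do reject mixed registrations, so the raw-before-typed lookup order in apply() can only ever observe one kind of gradient per op type. A sketch of that behavior, with a hypothetical op type and inputs class:

import java.util.Arrays;
import org.tensorflow.op.GradientDispatch;

final class MixedRegistrationExample { // hypothetical example class
  static void demo() {
    GradientDispatch.putRaw(
        "MyOp", (tf, op, gradInputs) -> Arrays.asList(gradInputs.get(0)));
    try {
      GradientDispatch.putTyped(
          "MyOp",
          (tf, inputs, gradInputs) -> Arrays.asList(gradInputs.get(0)),
          MyOpInputs.class); // hypothetical inputs class
    } catch (IllegalStateException expected) {
      // "A raw gradient is already registered for op type 'MyOp'. Raw and typed
      // registrations are mutually exclusive; cannot register typed."
    }
  }
}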
@@ -0,0 +1,40 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow.op;

import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;

/** Public bridge to a single native gradient adapter. */
public final class GradientDispatch {

  // package-private adapter that can access NativeScope/Ops constructors
  static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();

  private GradientDispatch() {}

  public static TFJ_GradFuncAdapter adapter() {
    return ADAPTER;
  }

  public static void putRaw(String opType, RawCustomGradient gradient) {
    ADAPTER.putRaw(opType, gradient);
  }

  public static <T extends RawOpInputs<?>> void putTyped(
      String opType, CustomGradient<T> gradient, Class<T> inputClass) {
    ADAPTER.putTyped(opType, gradient, inputClass);
  }
}
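One design note (my reading, not stated in the PR): since adapter() always returns the same ADAPTER singleton, every registered op type shares a single native function-pointer thunk, and per-op dispatch happens on the Java side keyed by operation.type(). A trivial check of that property:

import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
import org.tensorflow.op.GradientDispatch;

final class SharedAdapterCheck { // hypothetical example class
  static void check() {
    TFJ_GradFuncAdapter a = GradientDispatch.adapter();
    TFJ_GradFuncAdapter b = GradientDispatch.adapter();
    assert a == b; // one shared instance => one stable native thunk
  }
}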
@@ -17,11 +17,16 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.GraphOperation;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -54,6 +59,30 @@ public interface RawCustomGradient {
    * TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
    */
   static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) {
-    return new RawGradientAdapter(gradient);
+    final RawGradientAdapter impl = new RawGradientAdapter(gradient);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }