diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 80f8738d30d..3eb8c5c7129 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -243,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, writer->EndLine(); } } + // Add control dependencies, if any. + writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);"); + writer->EndLine(); + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 289b0f37666..e3e40a17df5 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -54,6 +54,8 @@ import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterSpec; import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.WildcardTypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; @@ -158,6 +160,13 @@ public final class OperatorProcessor extends AbstractProcessor { private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment"); private static final TypeName T_STRING = ClassName.get(String.class); + // Operand + private static final TypeName T_OPERAND = + ParameterizedTypeName.get( + ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class)); + // Iterable> + private static final TypeName T_ITERABLE_OPERAND = + ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND); private Filer filer; private Messager messager; @@ -393,6 +402,18 @@ public final class OperatorProcessor extends AbstractProcessor { T_SCOPE) .build()); + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencies") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_ITERABLE_OPERAND, "controls") + .returns(T_OPS) + .addStatement("return new Ops(scope.withControlDependencies(controls))") + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencies(Iterable>)}\n", + T_SCOPE) + .build()); + opsBuilder.addField( FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 99026648f9a..ccbf776cbe8 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -16,6 +16,10 @@ limitations under the License. package org.tensorflow.op; import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.OperationBuilder; + +import java.util.ArrayList; /** * Manages groups of related properties when creating Tensorflow Operations, such as a common name @@ -81,7 +85,7 @@ public final class Scope { * @param env The execution environment used by the scope. */ public Scope(ExecutionEnvironment env) { - this(env, new NameScope()); + this(env, new NameScope(), new ArrayList>()); } /** Returns the execution environment used by this scope. */ @@ -103,7 +107,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withSubScope(String childScopeName) { - return new Scope(env, nameScope.withSubScope(childScopeName)); + return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies); } /** @@ -119,7 +123,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withName(String opName) { - return new Scope(env, nameScope.withName(opName)); + return new Scope(env, nameScope.withName(opName), controlDependencies); } /** @@ -146,11 +150,39 @@ public final class Scope { return nameScope.makeOpName(defaultName); } - private Scope(ExecutionEnvironment env, NameScope nameScope) { + private Scope( + ExecutionEnvironment env, NameScope nameScope, Iterable> controlDependencies) { this.env = env; this.nameScope = nameScope; + this.controlDependencies = controlDependencies; + } + + /** + * Returns a new scope where added operations will have the provided control dependencies. + * + *

Ops created with this scope will have a control edge from each of the provided controls. All + * other properties are inherited from the current scope. + * + * @param controls control dependencies for ops created with the returned scope + * @return a new scope with the provided control dependencies + */ + public Scope withControlDependencies(Iterable> controls) { + return new Scope(env, nameScope, controls); + } + + /** + * Adds each Operand in controlDependencies as a control input to the provided builder. + * + * @param builder OperationBuilder to add control inputs to + */ + public OperationBuilder applyControlDependencies(OperationBuilder builder) { + for (Operand control : controlDependencies) { + builder = builder.addControlInput(control.asOutput().op()); + } + return builder; } private final ExecutionEnvironment env; + private final Iterable> controlDependencies; private final NameScope nameScope; } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index 49c4ff639ec..daafd6b9503 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -25,6 +25,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.Shape; import org.tensorflow.op.Ops; @RunWith(JUnit4.class) @@ -57,4 +58,29 @@ public final class GeneratedOperationsTest { } } } + + /** + * Test for Ops.withControlDependencies. + * + *

Creates an add node with a control dependency to an assign node. In other words, the assign + * node is a control input to the add node. When the add node is run, the assign node is expected + * to have run beforehand due to the control dependency. + */ + @Test + public void testControlDependencies() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Ops ops = Ops.create(g); + Operand variable = ops.variable(Shape.scalar(), Integer.class); + Operand initVariable = ops.assign(variable, ops.constant(0)); + ArrayList> controls = new ArrayList>(); + controls.add(ops.assign(variable, ops.constant(3))); + Operand x = + ops.withControlDependencies(controls).math().add(variable, ops.constant(0)); + sess.runner().addTarget(initVariable).run(); + try (Tensor result = sess.runner().fetch(x).run().get(0).expect(Integer.class); ) { + assertEquals(3, result.intValue()); + } + } + } }