From a86557b09724bf0a1d4b5539848ca5d7475c985b Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 27 Feb 2019 09:00:53 -0800 Subject: [PATCH] [TF Java] Adds withcontrolDependencies to org.tensorflow.op.Ops Co-authored-by: Melissa Grueter Co-authored-by: Samantha Andow --- tensorflow/java/src/gen/cc/op_generator.cc | 4 ++ .../processor/OperatorProcessor.java | 19 +++++++++ .../main/java/org/tensorflow/op/Scope.java | 39 +++++++++++++++++-- .../op/core/GeneratedOperationsTest.java | 27 +++++++++++++ 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 80f8738d30d..3eb8c5c7129 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -243,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, writer->EndLine(); } } + // Add control dependencies, if any. + writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);"); + writer->EndLine(); + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index a9e61a34397..5993cdedecd 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -54,6 +54,8 @@ import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterSpec; import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.WildcardTypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; @@ -158,6 +160,11 @@ public final class OperatorProcessor extends AbstractProcessor { private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment"); private static final TypeName T_STRING = ClassName.get(String.class); + // Operand + private static final TypeName T_OPERAND = + ParameterizedTypeName.get(ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class)); + // Iterable> + private static final TypeName T_ITERABLE_OPERAND = ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND); private Filer filer; private Messager messager; @@ -393,6 +400,18 @@ public final class OperatorProcessor extends AbstractProcessor { T_SCOPE) .build()); + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencies") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_ITERABLE_OPERAND, "controls") + .returns(T_OPS) + .addStatement("return new Ops(scope.withControlDependencies(controls))") + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencies(Iterable>)}\n", + T_SCOPE) + .build()); + opsBuilder.addField( FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 99026648f9a..fa9da79614e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -16,6 +16,10 @@ limitations under the License. package org.tensorflow.op; import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.OperationBuilder; + +import java.util.ArrayList; /** * Manages groups of related properties when creating Tensorflow Operations, such as a common name @@ -81,7 +85,7 @@ public final class Scope { * @param env The execution environment used by the scope. */ public Scope(ExecutionEnvironment env) { - this(env, new NameScope()); + this(env, new NameScope(), new ArrayList>()); } /** Returns the execution environment used by this scope. */ @@ -103,7 +107,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withSubScope(String childScopeName) { - return new Scope(env, nameScope.withSubScope(childScopeName)); + return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies); } /** @@ -119,7 +123,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withName(String opName) { - return new Scope(env, nameScope.withName(opName)); + return new Scope(env, nameScope.withName(opName), controlDependencies); } /** @@ -146,11 +150,38 @@ public final class Scope { return nameScope.makeOpName(defaultName); } - private Scope(ExecutionEnvironment env, NameScope nameScope) { + private Scope(ExecutionEnvironment env, NameScope nameScope, Iterable> controlDependencies) { this.env = env; this.nameScope = nameScope; + this.controlDependencies = controlDependencies; + } + + /** + * Returns a new scope where added operations will have the provided control dependencies. + * + *

Ops created with this scope will have a control edge from each of the provided controls. + * All other properties are inherited from the current scope. + * + * @param controls control dependencies for ops created with the returned scope + * @return a new scope with the provided control dependencies + */ + public Scope withControlDependencies(Iterable> controls) { + return new Scope(env, nameScope, controls); + } + + /** + * Adds each Operand in controlDependencies as a control input to the provided builder. + * + * @param builder OperationBuilder to add control inputs to + */ + public OperationBuilder applyControlDependencies(OperationBuilder builder) { + for (Operand control : controlDependencies) { + builder = builder.addControlInput(control.asOutput().op()); + } + return builder; } private final ExecutionEnvironment env; + private final Iterable> controlDependencies; private final NameScope nameScope; } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index 49c4ff639ec..cbeebd85fbe 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -25,6 +25,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.Shape; import org.tensorflow.op.Ops; @RunWith(JUnit4.class) @@ -57,4 +58,30 @@ public final class GeneratedOperationsTest { } } } + + /** + * Test for Ops.withControlDependencies. + * + * Creates an add node with a control dependency to an assign node. In other words, + * the assign node is a control input to the add node. + * When the add node is run, the assign node is expected to have run beforehand + * due to the control dependency. + */ + @Test + public void testControlDependencies() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Ops ops = Ops.create(g); + Operand variable = ops.variable(Shape.scalar(), Integer.class); + Operand initVariable = ops.assign(variable, ops.constant(0)); + ArrayList> controls = new ArrayList>(); + controls.add(ops.assign(variable, ops.constant(3))); + Operand x = ops.withControlDependencies(controls).math().add(variable, ops.constant(0)); + sess.runner().addTarget(initVariable).run(); + try (Tensor result = sess.runner().fetch(x).run().get(0).expect(Integer.class); + ) { + assertEquals(3, result.intValue()); + } + } + } }