Merge pull request #26181 from irenedea:withControlDependencies

PiperOrigin-RevId: 241613647
This commit is contained in:
TensorFlower Gardener 2019-04-02 15:32:37 -07:00
commit a6a0611a63
4 changed files with 87 additions and 4 deletions

View File

@ -243,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
writer->EndLine();
}
}
// Add control dependencies, if any.
writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
writer->EndLine();
for (const AttributeSpec& attribute : op.attributes()) {
WriteSetAttrDirective(attribute, false, writer);
}

View File

@ -54,6 +54,8 @@ import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.WildcardTypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
@ -158,6 +160,13 @@ public final class OperatorProcessor extends AbstractProcessor {
private static final TypeName T_EXEC_ENV =
ClassName.get("org.tensorflow", "ExecutionEnvironment");
private static final TypeName T_STRING = ClassName.get(String.class);
// Operand<?>
private static final TypeName T_OPERAND =
ParameterizedTypeName.get(
ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class));
// Iterable<Operand<?>>
private static final TypeName T_ITERABLE_OPERAND =
ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND);
private Filer filer;
private Messager messager;
@ -393,6 +402,18 @@ public final class OperatorProcessor extends AbstractProcessor {
T_SCOPE)
.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withControlDependencies")
.addModifiers(Modifier.PUBLIC)
.addParameter(T_ITERABLE_OPERAND, "controls")
.returns(T_OPS)
.addStatement("return new Ops(scope.withControlDependencies(controls))")
.addJavadoc(
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
+ "@see {@link $T#withControlDependencies(Iterable<Operand<?>>)}\n",
T_SCOPE)
.build());
opsBuilder.addField(
FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());

View File

@ -16,6 +16,10 @@ limitations under the License.
package org.tensorflow.op;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
import org.tensorflow.OperationBuilder;
import java.util.ArrayList;
/**
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
@ -81,7 +85,7 @@ public final class Scope {
* @param env The execution environment used by the scope.
*/
public Scope(ExecutionEnvironment env) {
this(env, new NameScope());
this(env, new NameScope(), new ArrayList<Operand<?>>());
}
/** Returns the execution environment used by this scope. */
@ -103,7 +107,7 @@ public final class Scope {
* @throws IllegalArgumentException if the name is invalid
*/
public Scope withSubScope(String childScopeName) {
return new Scope(env, nameScope.withSubScope(childScopeName));
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies);
}
/**
@ -119,7 +123,7 @@ public final class Scope {
* @throws IllegalArgumentException if the name is invalid
*/
public Scope withName(String opName) {
return new Scope(env, nameScope.withName(opName));
return new Scope(env, nameScope.withName(opName), controlDependencies);
}
/**
@ -146,11 +150,39 @@ public final class Scope {
return nameScope.makeOpName(defaultName);
}
private Scope(ExecutionEnvironment env, NameScope nameScope) {
private Scope(
ExecutionEnvironment env, NameScope nameScope, Iterable<Operand<?>> controlDependencies) {
this.env = env;
this.nameScope = nameScope;
this.controlDependencies = controlDependencies;
}
/**
* Returns a new scope where added operations will have the provided control dependencies.
*
* <p>Ops created with this scope will have a control edge from each of the provided controls. All
* other properties are inherited from the current scope.
*
* @param controls control dependencies for ops created with the returned scope
* @return a new scope with the provided control dependencies
*/
public Scope withControlDependencies(Iterable<Operand<?>> controls) {
return new Scope(env, nameScope, controls);
}
/**
* Adds each Operand in controlDependencies as a control input to the provided builder.
*
* @param builder OperationBuilder to add control inputs to
*/
public OperationBuilder applyControlDependencies(OperationBuilder builder) {
for (Operand<?> control : controlDependencies) {
builder = builder.addControlInput(control.asOutput().op());
}
return builder;
}
private final ExecutionEnvironment env;
private final Iterable<Operand<?>> controlDependencies;
private final NameScope nameScope;
}

View File

@ -25,6 +25,7 @@ import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
@RunWith(JUnit4.class)
@ -57,4 +58,29 @@ public final class GeneratedOperationsTest {
}
}
}
/**
* Test for Ops.withControlDependencies.
*
* <p>Creates an add node with a control dependency to an assign node. In other words, the assign
* node is a control input to the add node. When the add node is run, the assign node is expected
* to have run beforehand due to the control dependency.
*/
@Test
public void testControlDependencies() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Ops ops = Ops.create(g);
Operand<Integer> variable = ops.variable(Shape.scalar(), Integer.class);
Operand<?> initVariable = ops.assign(variable, ops.constant(0));
ArrayList<Operand<?>> controls = new ArrayList<Operand<?>>();
controls.add(ops.assign(variable, ops.constant(3)));
Operand<Integer> x =
ops.withControlDependencies(controls).math().add(variable, ops.constant(0));
sess.runner().addTarget(initVariable).run();
try (Tensor<Integer> result = sess.runner().fetch(x).run().get(0).expect(Integer.class); ) {
assertEquals(3, result.intValue());
}
}
}
}