Merge pull request #26181 from irenedea:withControlDependencies
PiperOrigin-RevId: 241613647
This commit is contained in:
commit
a6a0611a63
@ -243,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
|
||||
writer->EndLine();
|
||||
}
|
||||
}
|
||||
// Add control dependencies, if any.
|
||||
writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
|
||||
writer->EndLine();
|
||||
|
||||
for (const AttributeSpec& attribute : op.attributes()) {
|
||||
WriteSetAttrDirective(attribute, false, writer);
|
||||
}
|
||||
|
@ -54,6 +54,8 @@ import com.squareup.javapoet.JavaFile;
|
||||
import com.squareup.javapoet.MethodSpec;
|
||||
import com.squareup.javapoet.ParameterSpec;
|
||||
import com.squareup.javapoet.TypeName;
|
||||
import com.squareup.javapoet.ParameterizedTypeName;
|
||||
import com.squareup.javapoet.WildcardTypeName;
|
||||
import com.squareup.javapoet.TypeSpec;
|
||||
import com.squareup.javapoet.TypeVariableName;
|
||||
|
||||
@ -158,6 +160,13 @@ public final class OperatorProcessor extends AbstractProcessor {
|
||||
private static final TypeName T_EXEC_ENV =
|
||||
ClassName.get("org.tensorflow", "ExecutionEnvironment");
|
||||
private static final TypeName T_STRING = ClassName.get(String.class);
|
||||
// Operand<?>
|
||||
private static final TypeName T_OPERAND =
|
||||
ParameterizedTypeName.get(
|
||||
ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class));
|
||||
// Iterable<Operand<?>>
|
||||
private static final TypeName T_ITERABLE_OPERAND =
|
||||
ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND);
|
||||
|
||||
private Filer filer;
|
||||
private Messager messager;
|
||||
@ -393,6 +402,18 @@ public final class OperatorProcessor extends AbstractProcessor {
|
||||
T_SCOPE)
|
||||
.build());
|
||||
|
||||
opsBuilder.addMethod(
|
||||
MethodSpec.methodBuilder("withControlDependencies")
|
||||
.addModifiers(Modifier.PUBLIC)
|
||||
.addParameter(T_ITERABLE_OPERAND, "controls")
|
||||
.returns(T_OPS)
|
||||
.addStatement("return new Ops(scope.withControlDependencies(controls))")
|
||||
.addJavadoc(
|
||||
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
|
||||
+ "@see {@link $T#withControlDependencies(Iterable<Operand<?>>)}\n",
|
||||
T_SCOPE)
|
||||
.build());
|
||||
|
||||
opsBuilder.addField(
|
||||
FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
|
||||
|
||||
|
@ -16,6 +16,10 @@ limitations under the License.
|
||||
package org.tensorflow.op;
|
||||
|
||||
import org.tensorflow.ExecutionEnvironment;
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.OperationBuilder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
|
||||
@ -81,7 +85,7 @@ public final class Scope {
|
||||
* @param env The execution environment used by the scope.
|
||||
*/
|
||||
public Scope(ExecutionEnvironment env) {
|
||||
this(env, new NameScope());
|
||||
this(env, new NameScope(), new ArrayList<Operand<?>>());
|
||||
}
|
||||
|
||||
/** Returns the execution environment used by this scope. */
|
||||
@ -103,7 +107,7 @@ public final class Scope {
|
||||
* @throws IllegalArgumentException if the name is invalid
|
||||
*/
|
||||
public Scope withSubScope(String childScopeName) {
|
||||
return new Scope(env, nameScope.withSubScope(childScopeName));
|
||||
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -119,7 +123,7 @@ public final class Scope {
|
||||
* @throws IllegalArgumentException if the name is invalid
|
||||
*/
|
||||
public Scope withName(String opName) {
|
||||
return new Scope(env, nameScope.withName(opName));
|
||||
return new Scope(env, nameScope.withName(opName), controlDependencies);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -146,11 +150,39 @@ public final class Scope {
|
||||
return nameScope.makeOpName(defaultName);
|
||||
}
|
||||
|
||||
private Scope(ExecutionEnvironment env, NameScope nameScope) {
|
||||
private Scope(
|
||||
ExecutionEnvironment env, NameScope nameScope, Iterable<Operand<?>> controlDependencies) {
|
||||
this.env = env;
|
||||
this.nameScope = nameScope;
|
||||
this.controlDependencies = controlDependencies;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new scope where added operations will have the provided control dependencies.
|
||||
*
|
||||
* <p>Ops created with this scope will have a control edge from each of the provided controls. All
|
||||
* other properties are inherited from the current scope.
|
||||
*
|
||||
* @param controls control dependencies for ops created with the returned scope
|
||||
* @return a new scope with the provided control dependencies
|
||||
*/
|
||||
public Scope withControlDependencies(Iterable<Operand<?>> controls) {
|
||||
return new Scope(env, nameScope, controls);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds each Operand in controlDependencies as a control input to the provided builder.
|
||||
*
|
||||
* @param builder OperationBuilder to add control inputs to
|
||||
*/
|
||||
public OperationBuilder applyControlDependencies(OperationBuilder builder) {
|
||||
for (Operand<?> control : controlDependencies) {
|
||||
builder = builder.addControlInput(control.asOutput().op());
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
private final ExecutionEnvironment env;
|
||||
private final Iterable<Operand<?>> controlDependencies;
|
||||
private final NameScope nameScope;
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ import org.tensorflow.Graph;
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.Shape;
|
||||
import org.tensorflow.op.Ops;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
@ -57,4 +58,29 @@ public final class GeneratedOperationsTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test for Ops.withControlDependencies.
|
||||
*
|
||||
* <p>Creates an add node with a control dependency to an assign node. In other words, the assign
|
||||
* node is a control input to the add node. When the add node is run, the assign node is expected
|
||||
* to have run beforehand due to the control dependency.
|
||||
*/
|
||||
@Test
|
||||
public void testControlDependencies() {
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Ops ops = Ops.create(g);
|
||||
Operand<Integer> variable = ops.variable(Shape.scalar(), Integer.class);
|
||||
Operand<?> initVariable = ops.assign(variable, ops.constant(0));
|
||||
ArrayList<Operand<?>> controls = new ArrayList<Operand<?>>();
|
||||
controls.add(ops.assign(variable, ops.constant(3)));
|
||||
Operand<Integer> x =
|
||||
ops.withControlDependencies(controls).math().add(variable, ops.constant(0));
|
||||
sess.runner().addTarget(initVariable).run();
|
||||
try (Tensor<Integer> result = sess.runner().fetch(x).run().get(0).expect(Integer.class); ) {
|
||||
assertEquals(3, result.intValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user