Merge pull request #26181 from irenedea:withControlDependencies
PiperOrigin-RevId: 241613647
This commit is contained in:
commit
a6a0611a63
@ -243,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
|
|||||||
writer->EndLine();
|
writer->EndLine();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Add control dependencies, if any.
|
||||||
|
writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
|
||||||
|
writer->EndLine();
|
||||||
|
|
||||||
for (const AttributeSpec& attribute : op.attributes()) {
|
for (const AttributeSpec& attribute : op.attributes()) {
|
||||||
WriteSetAttrDirective(attribute, false, writer);
|
WriteSetAttrDirective(attribute, false, writer);
|
||||||
}
|
}
|
||||||
|
@ -54,6 +54,8 @@ import com.squareup.javapoet.JavaFile;
|
|||||||
import com.squareup.javapoet.MethodSpec;
|
import com.squareup.javapoet.MethodSpec;
|
||||||
import com.squareup.javapoet.ParameterSpec;
|
import com.squareup.javapoet.ParameterSpec;
|
||||||
import com.squareup.javapoet.TypeName;
|
import com.squareup.javapoet.TypeName;
|
||||||
|
import com.squareup.javapoet.ParameterizedTypeName;
|
||||||
|
import com.squareup.javapoet.WildcardTypeName;
|
||||||
import com.squareup.javapoet.TypeSpec;
|
import com.squareup.javapoet.TypeSpec;
|
||||||
import com.squareup.javapoet.TypeVariableName;
|
import com.squareup.javapoet.TypeVariableName;
|
||||||
|
|
||||||
@ -158,6 +160,13 @@ public final class OperatorProcessor extends AbstractProcessor {
|
|||||||
private static final TypeName T_EXEC_ENV =
|
private static final TypeName T_EXEC_ENV =
|
||||||
ClassName.get("org.tensorflow", "ExecutionEnvironment");
|
ClassName.get("org.tensorflow", "ExecutionEnvironment");
|
||||||
private static final TypeName T_STRING = ClassName.get(String.class);
|
private static final TypeName T_STRING = ClassName.get(String.class);
|
||||||
|
// Operand<?>
|
||||||
|
private static final TypeName T_OPERAND =
|
||||||
|
ParameterizedTypeName.get(
|
||||||
|
ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class));
|
||||||
|
// Iterable<Operand<?>>
|
||||||
|
private static final TypeName T_ITERABLE_OPERAND =
|
||||||
|
ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND);
|
||||||
|
|
||||||
private Filer filer;
|
private Filer filer;
|
||||||
private Messager messager;
|
private Messager messager;
|
||||||
@ -393,6 +402,18 @@ public final class OperatorProcessor extends AbstractProcessor {
|
|||||||
T_SCOPE)
|
T_SCOPE)
|
||||||
.build());
|
.build());
|
||||||
|
|
||||||
|
opsBuilder.addMethod(
|
||||||
|
MethodSpec.methodBuilder("withControlDependencies")
|
||||||
|
.addModifiers(Modifier.PUBLIC)
|
||||||
|
.addParameter(T_ITERABLE_OPERAND, "controls")
|
||||||
|
.returns(T_OPS)
|
||||||
|
.addStatement("return new Ops(scope.withControlDependencies(controls))")
|
||||||
|
.addJavadoc(
|
||||||
|
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
|
||||||
|
+ "@see {@link $T#withControlDependencies(Iterable<Operand<?>>)}\n",
|
||||||
|
T_SCOPE)
|
||||||
|
.build());
|
||||||
|
|
||||||
opsBuilder.addField(
|
opsBuilder.addField(
|
||||||
FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
|
FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
|
||||||
|
|
||||||
|
@ -16,6 +16,10 @@ limitations under the License.
|
|||||||
package org.tensorflow.op;
|
package org.tensorflow.op;
|
||||||
|
|
||||||
import org.tensorflow.ExecutionEnvironment;
|
import org.tensorflow.ExecutionEnvironment;
|
||||||
|
import org.tensorflow.Operand;
|
||||||
|
import org.tensorflow.OperationBuilder;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
|
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
|
||||||
@ -81,7 +85,7 @@ public final class Scope {
|
|||||||
* @param env The execution environment used by the scope.
|
* @param env The execution environment used by the scope.
|
||||||
*/
|
*/
|
||||||
public Scope(ExecutionEnvironment env) {
|
public Scope(ExecutionEnvironment env) {
|
||||||
this(env, new NameScope());
|
this(env, new NameScope(), new ArrayList<Operand<?>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the execution environment used by this scope. */
|
/** Returns the execution environment used by this scope. */
|
||||||
@ -103,7 +107,7 @@ public final class Scope {
|
|||||||
* @throws IllegalArgumentException if the name is invalid
|
* @throws IllegalArgumentException if the name is invalid
|
||||||
*/
|
*/
|
||||||
public Scope withSubScope(String childScopeName) {
|
public Scope withSubScope(String childScopeName) {
|
||||||
return new Scope(env, nameScope.withSubScope(childScopeName));
|
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -119,7 +123,7 @@ public final class Scope {
|
|||||||
* @throws IllegalArgumentException if the name is invalid
|
* @throws IllegalArgumentException if the name is invalid
|
||||||
*/
|
*/
|
||||||
public Scope withName(String opName) {
|
public Scope withName(String opName) {
|
||||||
return new Scope(env, nameScope.withName(opName));
|
return new Scope(env, nameScope.withName(opName), controlDependencies);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -146,11 +150,39 @@ public final class Scope {
|
|||||||
return nameScope.makeOpName(defaultName);
|
return nameScope.makeOpName(defaultName);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Scope(ExecutionEnvironment env, NameScope nameScope) {
|
private Scope(
|
||||||
|
ExecutionEnvironment env, NameScope nameScope, Iterable<Operand<?>> controlDependencies) {
|
||||||
this.env = env;
|
this.env = env;
|
||||||
this.nameScope = nameScope;
|
this.nameScope = nameScope;
|
||||||
|
this.controlDependencies = controlDependencies;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new scope where added operations will have the provided control dependencies.
|
||||||
|
*
|
||||||
|
* <p>Ops created with this scope will have a control edge from each of the provided controls. All
|
||||||
|
* other properties are inherited from the current scope.
|
||||||
|
*
|
||||||
|
* @param controls control dependencies for ops created with the returned scope
|
||||||
|
* @return a new scope with the provided control dependencies
|
||||||
|
*/
|
||||||
|
public Scope withControlDependencies(Iterable<Operand<?>> controls) {
|
||||||
|
return new Scope(env, nameScope, controls);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds each Operand in controlDependencies as a control input to the provided builder.
|
||||||
|
*
|
||||||
|
* @param builder OperationBuilder to add control inputs to
|
||||||
|
*/
|
||||||
|
public OperationBuilder applyControlDependencies(OperationBuilder builder) {
|
||||||
|
for (Operand<?> control : controlDependencies) {
|
||||||
|
builder = builder.addControlInput(control.asOutput().op());
|
||||||
|
}
|
||||||
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
private final ExecutionEnvironment env;
|
private final ExecutionEnvironment env;
|
||||||
|
private final Iterable<Operand<?>> controlDependencies;
|
||||||
private final NameScope nameScope;
|
private final NameScope nameScope;
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,7 @@ import org.tensorflow.Graph;
|
|||||||
import org.tensorflow.Operand;
|
import org.tensorflow.Operand;
|
||||||
import org.tensorflow.Session;
|
import org.tensorflow.Session;
|
||||||
import org.tensorflow.Tensor;
|
import org.tensorflow.Tensor;
|
||||||
|
import org.tensorflow.Shape;
|
||||||
import org.tensorflow.op.Ops;
|
import org.tensorflow.op.Ops;
|
||||||
|
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
@ -57,4 +58,29 @@ public final class GeneratedOperationsTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test for Ops.withControlDependencies.
|
||||||
|
*
|
||||||
|
* <p>Creates an add node with a control dependency to an assign node. In other words, the assign
|
||||||
|
* node is a control input to the add node. When the add node is run, the assign node is expected
|
||||||
|
* to have run beforehand due to the control dependency.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testControlDependencies() {
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Ops ops = Ops.create(g);
|
||||||
|
Operand<Integer> variable = ops.variable(Shape.scalar(), Integer.class);
|
||||||
|
Operand<?> initVariable = ops.assign(variable, ops.constant(0));
|
||||||
|
ArrayList<Operand<?>> controls = new ArrayList<Operand<?>>();
|
||||||
|
controls.add(ops.assign(variable, ops.constant(3)));
|
||||||
|
Operand<Integer> x =
|
||||||
|
ops.withControlDependencies(controls).math().add(variable, ops.constant(0));
|
||||||
|
sess.runner().addTarget(initVariable).run();
|
||||||
|
try (Tensor<Integer> result = sess.runner().fetch(x).run().get(0).expect(Integer.class); ) {
|
||||||
|
assertEquals(3, result.intValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user