Allow both Enter(data, control) as well as Enter(data).

PiperOrigin-RevId: 315584317
Change-Id: I50e8651ccbf0957a7edf1ef958aa398235867292
This commit is contained in:
A. Unique TensorFlower 2020-06-09 16:30:50 -07:00 committed by TensorFlower Gardener
parent b8de8f444e
commit 894f1324dd
2 changed files with 18 additions and 5 deletions

View File

@ -811,11 +811,13 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
// fully qualified) or a short form with a single type (in which case the data
// input and the outputs are all using this type).
if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
if (type.getNumInputs() != 1)
return parser.emitError(parser.getNameLoc())
<< " expects a single data type";
result.types.assign(type.getResults().begin(), type.getResults().end());
types.assign(type.getInputs().begin(), type.getInputs().end());
// One data input, and any number of control inputs.
if (type.getNumInputs() >= 1) {
result.types.assign(type.getResults().begin(), type.getResults().end());
types.assign(type.getInputs().begin(), type.getInputs().end());
} else {
return parser.emitError(parser.getNameLoc()) << " expects a data input";
}
} else {
Type control_type = ControlType::get(context);
types.append(op_infos.size() - 1, control_type);

View File

@ -416,6 +416,17 @@ func @enter_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<i1>) -> tensor<*xf32> {
func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32>
%res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control)
tf_executor.fetch %res#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> {
func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%0 = tf_executor.graph {