Allow both Enter(data, control) as well as Enter(data).
PiperOrigin-RevId: 315584317 Change-Id: I50e8651ccbf0957a7edf1ef958aa398235867292
This commit is contained in:
parent
b8de8f444e
commit
894f1324dd
@ -811,11 +811,13 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
|
||||
// fully qualified) or a short form with a single type (in which case the data
|
||||
// input and the outputs are all using this type).
|
||||
if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
|
||||
if (type.getNumInputs() != 1)
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< " expects a single data type";
|
||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
// One data input, and any number of control inputs.
|
||||
if (type.getNumInputs() >= 1) {
|
||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
} else {
|
||||
return parser.emitError(parser.getNameLoc()) << " expects a data input";
|
||||
}
|
||||
} else {
|
||||
Type control_type = ControlType::get(context);
|
||||
types.append(op_infos.size() - 1, control_type);
|
||||
|
@ -416,6 +416,17 @@ func @enter_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<i1>) -> tensor<*xf32> {
|
||||
func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
|
||||
// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32>
|
||||
%res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control)
|
||||
tf_executor.fetch %res#0 : tensor<*xf32>
|
||||
}
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> {
|
||||
func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
|
Loading…
x
Reference in New Issue
Block a user