Print functional type Switch op if predicate is unranked.

For the short form parsing of Switch op, the assumption is that all data input and outputs have the same type, and the predicate is tensor<i1>. If the predicate is tensor<*xi1>, print the functional type format.

PiperOrigin-RevId: 290757098
Change-Id: I6e07ee46012428e5ae2eec7188dbfd99bdf38452
This commit is contained in:
Prakalp Srivastava 2020-01-21 09:17:59 -08:00 committed by TensorFlower Gardener
parent 6a75d9fb25
commit e7530cd06f
2 changed files with 14 additions and 2 deletions

View File

@ -475,7 +475,8 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
// Support parsing either a functional type (in which case all the types are
// fully qualified) or a short form with a single type (in which case the data
// input and the outputs are all using this type).
// input and the outputs are all using this type and predicate is tensor<i1>
// type).
if (types.front().isa<FunctionType>()) {
FunctionType type = types.front().cast<FunctionType>();
if (type.getNumInputs() != 2)
@ -508,7 +509,8 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) {
// else print the shorter single type.
p << " : ";
if (switch_op.trueOutput().getType() != data_operand_ty ||
switch_op.falseOutput().getType() != data_operand_ty) {
switch_op.falseOutput().getType() != data_operand_ty ||
switch_op.predicate().getType().isa<UnrankedTensorType>()) {
p.printFunctionalType(switch_op.getOperation());
} else {
p << switch_op.getType(0);

View File

@ -177,6 +177,16 @@ func @switch_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<
return %result : tensor<*xf32>
}
// CHECK-LABEL: func @switch_with_unranked_pred(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<*xi1>) -> tensor<*xf32> {
func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> tensor<*xf32> {
%result = tf_executor.graph {
// CHECK: tf_executor.Switch %{{.*}}, %{{.*}} : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
%true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
tf_executor.fetch %true : tensor<*xf32>
}
return %result : tensor<*xf32>
}
// CHECK-LABEL: func @switchN(
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {