Print functional type Switch op if predicate is unranked.
For the short form parsing of Switch op, the assumption is that all data input and outputs have the same type, and the predicate is tensor<i1>. If the predicate is tensor<*xi1>, print the functional type format. PiperOrigin-RevId: 290757098 Change-Id: I6e07ee46012428e5ae2eec7188dbfd99bdf38452
This commit is contained in:
parent
6a75d9fb25
commit
e7530cd06f
@ -475,7 +475,8 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
// Support parsing either a functional type (in which case all the types are
|
||||
// fully qualified) or a short form with a single type (in which case the data
|
||||
// input and the outputs are all using this type).
|
||||
// input and the outputs are all using this type and predicate is tensor<i1>
|
||||
// type).
|
||||
if (types.front().isa<FunctionType>()) {
|
||||
FunctionType type = types.front().cast<FunctionType>();
|
||||
if (type.getNumInputs() != 2)
|
||||
@ -508,7 +509,8 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) {
|
||||
// else print the shorter single type.
|
||||
p << " : ";
|
||||
if (switch_op.trueOutput().getType() != data_operand_ty ||
|
||||
switch_op.falseOutput().getType() != data_operand_ty) {
|
||||
switch_op.falseOutput().getType() != data_operand_ty ||
|
||||
switch_op.predicate().getType().isa<UnrankedTensorType>()) {
|
||||
p.printFunctionalType(switch_op.getOperation());
|
||||
} else {
|
||||
p << switch_op.getType(0);
|
||||
|
@ -177,6 +177,16 @@ func @switch_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @switch_with_unranked_pred(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<*xi1>) -> tensor<*xf32> {
|
||||
func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> tensor<*xf32> {
|
||||
%result = tf_executor.graph {
|
||||
// CHECK: tf_executor.Switch %{{.*}}, %{{.*}} : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
|
||||
%true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
|
||||
tf_executor.fetch %true : tensor<*xf32>
|
||||
}
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @switchN(
|
||||
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%fetches = tf_executor.graph {
|
||||
|
Loading…
x
Reference in New Issue
Block a user