diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 13dc2993371..08ced93f6eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -475,7 +475,8 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { // Support parsing either a functional type (in which case all the types are // fully qualified) or a short form with a single type (in which case the data - // input and the outputs are all using this type). + // input and the outputs are all using this type and predicate is tensor + // type). if (types.front().isa()) { FunctionType type = types.front().cast(); if (type.getNumInputs() != 2) @@ -508,7 +509,8 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) { // else print the shorter single type. p << " : "; if (switch_op.trueOutput().getType() != data_operand_ty || - switch_op.falseOutput().getType() != data_operand_ty) { + switch_op.falseOutput().getType() != data_operand_ty || + switch_op.predicate().getType().isa()) { p.printFunctionalType(switch_op.getOperation()); } else { p << switch_op.getType(0); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 03184ff6de8..6282ab17f17 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -177,6 +177,16 @@ func @switch_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor< return %result : tensor<*xf32> } +// CHECK-LABEL: func @switch_with_unranked_pred(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<*xi1>) -> tensor<*xf32> { +func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> tensor<*xf32> { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{.*}}, %{{.*}} : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + %true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + tf_executor.fetch %true : tensor<*xf32> + } + return %result : tensor<*xf32> +} + // CHECK-LABEL: func @switchN( func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph {