From e7530cd06f08a622046962f41962759f3f070c9a Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Tue, 21 Jan 2020 09:17:59 -0800 Subject: [PATCH] Print functional type Switch op if predicate is unranked. For the short form parsing of Switch op, the assumption is that all data input and outputs have the same type, and the predicate is tensor. If the predicate is tensor<*xi1>, print the functional type format. PiperOrigin-RevId: 290757098 Change-Id: I6e07ee46012428e5ae2eec7188dbfd99bdf38452 --- tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc | 6 ++++-- .../mlir/tensorflow/tests/tf_executor_ops.mlir | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 13dc2993371..08ced93f6eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -475,7 +475,8 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { // Support parsing either a functional type (in which case all the types are // fully qualified) or a short form with a single type (in which case the data - // input and the outputs are all using this type). + // input and the outputs are all using this type and predicate is tensor + // type). if (types.front().isa()) { FunctionType type = types.front().cast(); if (type.getNumInputs() != 2) @@ -508,7 +509,8 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) { // else print the shorter single type. p << " : "; if (switch_op.trueOutput().getType() != data_operand_ty || - switch_op.falseOutput().getType() != data_operand_ty) { + switch_op.falseOutput().getType() != data_operand_ty || + switch_op.predicate().getType().isa()) { p.printFunctionalType(switch_op.getOperation()); } else { p << switch_op.getType(0); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 03184ff6de8..6282ab17f17 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -177,6 +177,16 @@ func @switch_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor< return %result : tensor<*xf32> } +// CHECK-LABEL: func @switch_with_unranked_pred(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<*xi1>) -> tensor<*xf32> { +func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> tensor<*xf32> { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{.*}}, %{{.*}} : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + %true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : (tensor<*xf32>, tensor<*xi1>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + tf_executor.fetch %true : tensor<*xf32> + } + return %result : tensor<*xf32> +} + // CHECK-LABEL: func @switchN( func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph {