diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 8d2028128d6..d5ecbf3e292 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -474,7 +474,7 @@ namespace { ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { SmallVector op_infos; SmallVector types; - if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types)) + if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 1) return parser.emitError(parser.getNameLoc()) @@ -486,12 +486,15 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { // type). if (types.front().isa()) { FunctionType type = types.front().cast(); - if (type.getNumInputs() != 2) + if (type.getNumInputs() < 2) return parser.emitError(parser.getNameLoc()) << " expects a single data type and a predicate"; result.types.assign(type.getResults().begin(), type.getResults().end()); types.assign(type.getInputs().begin(), type.getInputs().end()); } else { + if (op_infos.size() < 2) + return parser.emitError(parser.getNameLoc()) + << " expects a single data type and a predicate"; Type control_type = ControlType::get(parser.getBuilder().getContext()); result.types.append(2, types[0]); result.types.push_back(control_type); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 6282ab17f17..c048db5a5ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -187,6 +187,26 @@ func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> te return %result : tensor<*xf32> } +// CHECK-LABEL: func @switch_with_control_inputs( +func @switch_with_control_inputs(%arg0: tensor, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor + %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor + tf_executor.fetch %1#0 : tensor + } + return %result : tensor +} + +// CHECK-LABEL: func @switch_with_control_inputs_functional( +func @switch_with_control_inputs_functional(%arg0: tensor, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor + %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor, tensor, !tf_executor.control, !tf_executor.control) -> (tensor, tensor, !tf_executor.control) + tf_executor.fetch %1#0 : tensor + } + return %result : tensor +} + // CHECK-LABEL: func @switchN( func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index a249090a3cf..1fdc99d1ec8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -333,7 +333,7 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { // ----- -// Check that a switch always takes two arguments. +// Check that a switch always needs at least two arguments. func @invalid_switch(%arg0: tensor<*xf32>) { tf_executor.graph { %true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) @@ -344,6 +344,17 @@ func @invalid_switch(%arg0: tensor<*xf32>) { // ----- +// Check that a switch always needs at least two arguments. +func @invalid_switch(%arg0: tensor<*xf32>) { + tf_executor.graph { + %true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32> +// expected-error@-1 {{custom op 'tf_executor.Switch' expects a single data type and a predicate}} + } + return +} + +// ----- + // Check that a switch second argument must be a valid predicate (i1). func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph {