From 894f1324dd2abad859b7ac9e74d9a70d363d9a57 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Jun 2020 16:30:50 -0700 Subject: [PATCH] Allow both Enter(data, control) as well as Enter(data). PiperOrigin-RevId: 315584317 Change-Id: I50e8651ccbf0957a7edf1ef958aa398235867292 --- .../compiler/mlir/tensorflow/ir/tf_executor.cc | 12 +++++++----- .../mlir/tensorflow/tests/tf_executor_ops.mlir | 11 +++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 9daebc22ba1..3403651eef8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -811,11 +811,13 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) { // fully qualified) or a short form with a single type (in which case the data // input and the outputs are all using this type). if (FunctionType type = types.front().dyn_cast()) { - if (type.getNumInputs() != 1) - return parser.emitError(parser.getNameLoc()) - << " expects a single data type"; - result.types.assign(type.getResults().begin(), type.getResults().end()); - types.assign(type.getInputs().begin(), type.getInputs().end()); + // One data input, and any number of control inputs. + if (type.getNumInputs() >= 1) { + result.types.assign(type.getResults().begin(), type.getResults().end()); + types.assign(type.getInputs().begin(), type.getInputs().end()); + } else { + return parser.emitError(parser.getNameLoc()) << " expects a data input"; + } } else { Type control_type = ControlType::get(context); types.append(op_infos.size() - 1, control_type); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index c048db5a5ee..27b84724b4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -416,6 +416,17 @@ func @enter_control(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor) -> tensor<*xf32> { +func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { + %0 = tf_executor.graph { + %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> +// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32> + %res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control) + tf_executor.fetch %res#0 : tensor<*xf32> + } + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> { func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %0 = tf_executor.graph {