From d153f264c52d7056b0cf1a1933ceeedde4ded753 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 9 Sep 2019 19:42:50 -0700 Subject: [PATCH] Eliminate TPUReplicatedInput and TPUReplicatedOutput when TPURewrite is complete These operations are markers in the graph for each input/output of the TPU cluster. They don't have a kernel and can't be executed: they are only present to drive the rewrite and must be eliminated. PiperOrigin-RevId: 268133032 --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 18 ++++++++++++++++++ .../tensorflow/transforms/tpu_rewrite_pass.cc | 13 +++++++++++++ 2 files changed, 31 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 3b324a75027..e91e772d47f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -418,3 +418,21 @@ func @tpu0_func(%arg0: tensor) -> tensor { %0 = "tf.B"(%arg0) : (tensor) -> tensor return %0 : tensor } + + +// ----- + +// Tests that TPUReplicatedInput and TPUReplicatedOutput operations are properly rewritten + +func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> { + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1 + %0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + %1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + %2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + %3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + return %3 : tensor<0xf32> +} +func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + %0 = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32> + return %0 : tensor<0xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 5b453aab192..91fc073e1f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir @@ -22,6 +23,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -267,6 +269,17 @@ void TPURewritePass::runOnModule() { Rewrite(op, &builder); }); + // Eliminate TPUReplicatedInput and TPUReplicatedOutput now that the rewrite + // is complete. + getModule().walk([&](Operation* op) { + auto op_name = op->getName().getStringRef(); + if (op_name != "tf.TPUReplicatedInput" && + op_name != "tf.TPUReplicatedOutput") + return; + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + }); + // TODO(b/139377366): Remove functions that are no longer needed. }