From 53c634a6c150da732dcd6305478ffecd6a887668 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Fri, 15 May 2020 10:17:05 -0700 Subject: [PATCH] [MLIR/XLA] Constant sinking to control flow regions. This is necessary for exporting to XLA since functional control flow is expected. PiperOrigin-RevId: 311753796 Change-Id: If4e50a3b2fa668f162c9b30cc80e2bf743a9b641 --- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../tensorflow/utils/compile_mlir_util.cc | 4 + tensorflow/compiler/mlir/xla/BUILD | 19 +++++ .../tests/sink-constants-to-control-flow.mlir | 60 +++++++++++++ .../compiler/mlir/xla/transforms/passes.h | 4 + .../sink_constants_to_control_flow.cc | 85 +++++++++++++++++++ 6 files changed, 173 insertions(+) create mode 100644 tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir create mode 100644 tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 54b560ed6ce..eb220a31f80 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index e8ca691f961..03283da0112 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -305,6 +305,10 @@ Status ConvertMLIRToXlaComputation( // invocation. tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. + tf2xla.addNestedPass( + mlir::xla_hlo::createSinkConstantsToControlFlowPass()); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 12334e463fa..179a637ec7b 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -193,6 +193,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_sink_constants_to_control_flow", + srcs = [ + "transforms/sink_constants_to_control_flow.cc", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "map_xla_to_scalar_op", hdrs = ["transforms/map_xla_to_scalar_op.h"], @@ -873,6 +891,7 @@ cc_library( ":xla_legalize_to_standard", ":xla_lower", ":xla_materialize_broadcasts", + ":xla_sink_constants_to_control_flow", ":xla_test_passes", ], ) diff --git a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 00000000000..c2fbad2faec --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail + +// Tests sinking constants to a while loop. + +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor) -> tensor { + // CHECK-NEXT: xla_hlo.while + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1A:.+]]: tensor + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1B:.+]]: tensor + // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] + %2 = xla_hlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] + %3 = xla_hlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] + %4 = xla_hlo.add %c1, %3 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// Tests sinking constants to a conditional op. + +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor) -> tensor { + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: xla_hlo.conditional + %2 = "xla_hlo.conditional"(%0, %1, %1) ( { + ^bb0(%arg1: tuple>): + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], + %4 = xla_hlo.add %c0, %3 : tensor + %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> + "xla_hlo.return"(%5) : (tuple>) -> () + }, { + ^bb0(%arg1: tuple>): + // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], + %7 = xla_hlo.add %c1, %6 : tensor + %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> + "xla_hlo.return"(%8) : (tuple>) -> () + }) : (tensor, tuple>, tuple>) -> tuple> + %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + return %9 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 39375e210d5..b148eac4286 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -65,6 +65,10 @@ std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); +// Sinks constants implicitly captured in control flow regions. This is +// necessary to export to XLA. +std::unique_ptr> createSinkConstantsToControlFlowPass(); + } // namespace xla_hlo namespace xla_lhlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc new file mode 100644 index 00000000000..29646465acd --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// A pass that sinks constants implicitly captured in control flow regions. This +// is necessary to export to XLA. +class SinkConstantsToControlFlow + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](Operation* op) { + if (auto while_op = llvm::dyn_cast(op)) { + SinkToRegion(&while_op.body()); + SinkToRegion(&while_op.cond()); + } else if (auto cond_op = llvm::dyn_cast(op)) { + SinkToRegion(&cond_op.true_branch()); + SinkToRegion(&cond_op.false_branch()); + } + }); + } + + private: + // Performs constant sinking into a region. + static void SinkToRegion(Region* region) { + llvm::DenseMap sunk_constant; + visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { + Value constant = use->get(); + auto const_op = dyn_cast_or_null(constant.getDefiningOp()); + if (!const_op) return; + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. + use->set(map_entry.first->getSecond().getResult()); + if (constant.use_empty()) const_op.erase(); + return; + } + if (constant.hasOneUse()) { + const_op.getOperation()->moveBefore(®ion->front().front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + region->front().getOperations().insert(region->front().begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + }); + } +}; + +static mlir::PassRegistration pass( + "xla-hlo-sink-constants-to-control-flow", + "Sink constants implicitly captured in control flow regions. This is " + "necessary to export to XLA."); + +} // anonymous namespace + +std::unique_ptr> createSinkConstantsToControlFlowPass() { + return std::make_unique(); +} + +} // namespace xla_hlo +} // namespace mlir