[MLIR/XLA] Constant sinking to control flow regions.

This is necessary for exporting to XLA since functional control flow is expected.

PiperOrigin-RevId: 311753796
Change-Id: If4e50a3b2fa668f162c9b30cc80e2bf743a9b641
This commit is contained in:
Yuanzhong Xu 2020-05-15 10:17:05 -07:00 committed by TensorFlower Gardener
parent 9957cb60a2
commit 53c634a6c1
6 changed files with 173 additions and 0 deletions

View File

@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",

View File

@ -305,6 +305,10 @@ Status ConvertMLIRToXlaComputation(
// invocation.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createLegalizeTFPass(false));
// In order to export to XLA, we must sink constants to control flow regions,
// since XLA uses functional control flow.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createSinkConstantsToControlFlowPass());
if (VLOG_IS_ON(1)) {
// Print the whole module after each pass which requires disabling

View File

@ -193,6 +193,24 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_sink_constants_to_control_flow",
srcs = [
"transforms/sink_constants_to_control_flow.cc",
],
deps = [
":hlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library(
name = "map_xla_to_scalar_op",
hdrs = ["transforms/map_xla_to_scalar_op.h"],
@ -873,6 +891,7 @@ cc_library(
":xla_legalize_to_standard",
":xla_lower",
":xla_materialize_broadcasts",
":xla_sink_constants_to_control_flow",
":xla_test_passes",
],
)

View File

@ -0,0 +1,60 @@
// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail
// Tests sinking constants to a while loop.
// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: xla_hlo.while
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
%3 = xla_hlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
%4 = xla_hlo.add %c1, %3 : tensor<i64>
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// Tests sinking constants to a conditional op.
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: xla_hlo.conditional
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
%4 = xla_hlo.add %c0, %3 : tensor<i64>
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
%7 = xla_hlo.add %c1, %6 : tensor<i64>
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

View File

@ -65,6 +65,10 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
// Sinks constants implicitly captured in control flow regions. This is
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
} // namespace xla_hlo
namespace xla_lhlo {

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace {
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
class SinkConstantsToControlFlow
: public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
void runOnFunction() override {
getFunction().walk([](Operation* op) {
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
SinkToRegion(&while_op.body());
SinkToRegion(&while_op.cond());
} else if (auto cond_op = llvm::dyn_cast<ConditionalOp>(op)) {
SinkToRegion(&cond_op.true_branch());
SinkToRegion(&cond_op.false_branch());
}
});
}
private:
// Performs constant sinking into a region.
static void SinkToRegion(Region* region) {
llvm::DenseMap<Value, ConstOp> sunk_constant;
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get();
auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
if (!const_op) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it.
use->set(map_entry.first->getSecond().getResult());
if (constant.use_empty()) const_op.erase();
return;
}
if (constant.hasOneUse()) {
const_op.getOperation()->moveBefore(&region->front().front());
return;
}
map_entry.first->getSecond() = const_op.clone();
region->front().getOperations().insert(region->front().begin(),
map_entry.first->getSecond());
use->set(map_entry.first->getSecond().getResult());
});
}
};
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
"xla-hlo-sink-constants-to-control-flow",
"Sink constants implicitly captured in control flow regions. This is "
"necessary to export to XLA.");
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>();
}
} // namespace xla_hlo
} // namespace mlir