[MLIR/XLA] Constant sinking to control flow regions.
This is necessary for exporting to XLA, since XLA expects functional control flow.

PiperOrigin-RevId: 311753796
Change-Id: If4e50a3b2fa668f162c9b30cc80e2bf743a9b641
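For illustration only, a minimal sketch of the transformation on a hypothetical input, written in the same xla_hlo syntax as the new test below. Before the pass, a constant defined in the enclosing function is implicitly captured by the while op's condition region:

  func @example(%arg0: tensor<i64>) -> tensor<i64> {
    // %c0 is defined outside the while op but used inside its condition region.
    %c0 = xla_hlo.constant dense<1> : tensor<i64>
    %0 = "xla_hlo.while"(%arg0) ( {
    ^bb0(%arg1: tensor<i64>):
      %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
      "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    }, {
    ^bb0(%arg1: tensor<i64>):
      "xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
    }) : (tensor<i64>) -> tensor<i64>
    return %0 : tensor<i64>
  }

After the pass, the constant is moved (or cloned, when it is used in several regions) to the top of each region that uses it, so no region references values defined above it:

    %0 = "xla_hlo.while"(%arg0) ( {
    ^bb0(%arg1: tensor<i64>):
      // The constant now lives inside the region that uses it.
      %c0 = xla_hlo.constant dense<1> : tensor<i64>
      %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
      "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    }, {
      ...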
@@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [
    "//tensorflow/compiler/mlir/xla:type_to_shape",
    "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
    "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
    "//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
    "//tensorflow/compiler/tf2xla:common",
    "//tensorflow/compiler/tf2xla:xla_compiler",
    "//tensorflow/core:framework",
@@ -305,6 +305,10 @@ Status ConvertMLIRToXlaComputation(
  // invocation.
  tf2xla.addNestedPass<mlir::FuncOp>(
      mlir::xla_hlo::createLegalizeTFPass(false));
  // In order to export to XLA, we must sink constants to control flow regions,
  // since XLA uses functional control flow.
  tf2xla.addNestedPass<mlir::FuncOp>(
      mlir::xla_hlo::createSinkConstantsToControlFlowPass());

  if (VLOG_IS_ON(1)) {
    // Print the whole module after each pass which requires disabling
@@ -193,6 +193,24 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "xla_sink_constants_to_control_flow",
    srcs = [
        "transforms/sink_constants_to_control_flow.cc",
    ],
    deps = [
        ":hlo",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = 1,
)

cc_library(
    name = "map_xla_to_scalar_op",
    hdrs = ["transforms/map_xla_to_scalar_op.h"],
@@ -873,6 +891,7 @@ cc_library(
        ":xla_legalize_to_standard",
        ":xla_lower",
        ":xla_materialize_broadcasts",
        ":xla_sink_constants_to_control_flow",
        ":xla_test_passes",
    ],
)
@@ -0,0 +1,60 @@
// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail

// Tests sinking constants to a while loop.

// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK-NEXT: xla_hlo.while
  %c0 = xla_hlo.constant dense<1> : tensor<i64>
  %c1 = xla_hlo.constant dense<2> : tensor<i64>
  %0 = "xla_hlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor<i64>):
    // CHECK: %[[ARG1A:.+]]: tensor<i64>
    // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
    // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
    %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg1: tensor<i64>):
    // CHECK: %[[ARG1B:.+]]: tensor<i64>
    // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
    // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
    %2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
    // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
    %3 = xla_hlo.add %c1, %2 : tensor<i64>
    // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
    %4 = xla_hlo.add %c1, %3 : tensor<i64>
    "xla_hlo.return"(%4) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  return %0 : tensor<i64>
}

// Tests sinking constants to a conditional op.

// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
  %c0 = xla_hlo.constant dense<1> : tensor<i64>
  %c1 = xla_hlo.constant dense<2> : tensor<i64>
  %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  %1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
  // CHECK: xla_hlo.conditional
  %2 = "xla_hlo.conditional"(%0, %1, %1) ( {
  ^bb0(%arg1: tuple<tensor<i64>>):
    // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
    %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
    // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
    %4 = xla_hlo.add %c0, %3 : tensor<i64>
    %5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
    "xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
  }, {
  ^bb0(%arg1: tuple<tensor<i64>>):
    // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
    %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
    // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
    %7 = xla_hlo.add %c1, %6 : tensor<i64>
    %8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
    "xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
  }) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
  %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
  return %9 : tensor<i64>
}
@@ -65,6 +65,10 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();

// Sinks constants implicitly captured in control flow regions. This is
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

}  // namespace xla_hlo

namespace xla_lhlo {

@@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

namespace mlir {
namespace xla_hlo {

namespace {

// A pass that sinks constants implicitly captured in control flow regions.
// This is necessary to export to XLA.
class SinkConstantsToControlFlow
    : public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
  void runOnFunction() override {
    getFunction().walk([](Operation* op) {
      if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
        SinkToRegion(&while_op.body());
        SinkToRegion(&while_op.cond());
      } else if (auto cond_op = llvm::dyn_cast<ConditionalOp>(op)) {
        SinkToRegion(&cond_op.true_branch());
        SinkToRegion(&cond_op.false_branch());
      }
    });
  }

 private:
  // Performs constant sinking into a region.
  static void SinkToRegion(Region* region) {
    llvm::DenseMap<Value, ConstOp> sunk_constant;
    visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
      Value constant = use->get();
      auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
      if (!const_op) return;
      auto map_entry = sunk_constant.try_emplace(constant, nullptr);
      if (!map_entry.second) {
        // This constant has already been cloned into the region, reuse it.
        use->set(map_entry.first->getSecond().getResult());
        if (constant.use_empty()) const_op.erase();
        return;
      }
      if (constant.hasOneUse()) {
        const_op.getOperation()->moveBefore(&region->front().front());
        return;
      }
      map_entry.first->getSecond() = const_op.clone();
      region->front().getOperations().insert(region->front().begin(),
                                             map_entry.first->getSecond());
      use->set(map_entry.first->getSecond().getResult());
    });
  }
};

static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
    "xla-hlo-sink-constants-to-control-flow",
    "Sink constants implicitly captured in control flow regions. This is "
    "necessary to export to XLA.");

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
  return std::make_unique<SinkConstantsToControlFlow>();
}

}  // namespace xla_hlo
}  // namespace mlir