[mlir] Add FusionOp to XLA HLO -> LMHLO
Also refactor the cache to take (HloInstruction*, ShapeIndex) as the key. It makes tuple handling simpler.

PiperOrigin-RevId: 336951382
Change-Id: I6e86870e00a364b46ee0f8ae21bad3d19486bf24
parent 72028307fd
commit 37da1f0ee1
Changed files:
  tensorflow/compiler/mlir/xla/BUILD
  tensorflow/compiler/mlir/xla/hlo_function_importer.cc
  tensorflow/compiler/mlir/xla/hlo_function_importer.h
  tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/
  tensorflow/compiler/mlir/xla/transforms/
  tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc
  tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h
  tensorflow/compiler/xla/service/
  tensorflow/core/tpu/
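The cache refactor mentioned in the commit message keys buffer views by an (instruction, shape index) pair instead of mapping each instruction to a flat vector of views, so every leaf of a tuple-shaped value gets its own cache slot. A minimal standalone sketch of the keying scheme, with std::map and std::vector<int64_t> standing in for the real absl::flat_hash_map and xla::ShapeIndex (the stub types below are invented for illustration):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Stand-ins for xla::HloInstruction and xla::ShapeIndex.
    struct InstructionStub { std::string name; };
    using ShapeIndexStub = std::vector<int64_t>;

    int main() {
      std::map<std::pair<const InstructionStub*, ShapeIndexStub>, int> slices;
      InstructionStub fusion{"fusion"};
      // {} addresses the whole value; {i} addresses tuple element i, and
      // nested elements extend the index ({i, j}, ...).
      slices[{&fusion, {}}] = 0;
      slices[{&fusion, {0}}] = 1;
      slices[{&fusion, {1}}] = 2;
      std::cout << slices.size() << " cached views\n";  // 3 distinct slots
      return 0;
    }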
tensorflow/compiler/mlir/xla/BUILD
@@ -136,6 +136,7 @@ cc_library(
         ":hlo_module_importer",
         ":hlo_utils",
         ":mlir_hlo_to_hlo",
+        ":translate_cl_options",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/xla:debug_options_flags",
tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -140,31 +140,42 @@ tensorflow::Status HloFunctionImporter::ImportAsRegion(
   return ImportInstructions(computation, block);
 }
 
-tensorflow::Status HloFunctionImporter::ImportInstructions(
-    const HloComputation& computation, mlir::Block* block) {
+StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
+    const xla::HloComputation& computation,
+    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
   // Setup the input parameters.
   const int num_parameters = computation.num_parameters();
 
+  if (arguments.size() != num_parameters)
+    return InvalidArgument("Caller vs callee argument sizes do not match");
+
   for (int i = 0; i < num_parameters; i++) {
     auto hlo_parameter = computation.parameter_instruction(i);
-    instruction_value_map_[hlo_parameter] = block->getArgument(i);
+    instruction_value_map_[hlo_parameter] = arguments[i];
   }
 
-  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
   for (auto instruction : computation.MakeInstructionPostOrder()) {
     TF_ASSIGN_OR_RETURN(auto new_operation,
-                        ImportInstruction(instruction, &builder));
+                        ImportInstruction(instruction, builder));
     if (new_operation) {
       instruction_value_map_[instruction] = new_operation->getResult(0);
     }
   }
 
+  // Setup the return type (HLO only supports a single return value).
+  return GetMlirValue(computation.root_instruction());
+}
+
+Status HloFunctionImporter::ImportInstructions(
+    const HloComputation& computation, mlir::Block* block) {
+  llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
+  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
+  TF_ASSIGN_OR_RETURN(Value result,
+                      ImportInstructionsImpl(computation, arguments, &builder));
+
   // TODO(suderman): Add location tracking details.
   mlir::Location loc = builder.getUnknownLoc();
 
-  // Setup the return type (HLO only supports a single return value).
-  TF_ASSIGN_OR_RETURN(auto result,
-                      GetMlirValue(computation.root_instruction()));
-
   // Create terminator op depending on the parent op of this region.
   if (llvm::isa<FuncOp>(block->getParentOp())) {
     builder.create<mlir::ReturnOp>(loc, result);
@@ -174,6 +185,19 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
   return tensorflow::Status::OK();
 }
 
+StatusOr<Value> HloFunctionImporter::ImportInstructions(
+    const xla::HloComputation& computation,
+    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
+  mlir::Block* block = builder->getBlock();
+  if (block == nullptr)
+    return InvalidArgument(
+        "ImportInstructions requires a valid block in the builder");
+
+  HloFunctionImporter importer(
+      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
+  return importer.ImportInstructionsImpl(computation, arguments, builder);
+}
+
 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     HloInstruction* instruction, mlir::OpBuilder* func_builder) {
   TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -55,6 +55,13 @@ class HloFunctionImporter {
   static Status ImportAsRegion(const xla::HloComputation& computation,
                                mlir::Region* region, mlir::Builder* builder);
 
+  // Imports the given computation at the place specified by `builder`.
+  // `arguments` contains values for all parameters.
+  static StatusOr<mlir::Value> ImportInstructions(
+      const xla::HloComputation& computation,
+      const llvm::SmallVectorImpl<mlir::Value>& arguments,
+      mlir::OpBuilder* builder);
+
  private:
  HloFunctionImporter(mlir::ModuleOp module,
                      std::unordered_map<const xla::HloComputation*,
@@ -80,6 +87,10 @@ class HloFunctionImporter {
   // Assumes that the block already has correct arguments populated.
   tensorflow::Status ImportInstructions(const HloComputation& computation,
                                         mlir::Block* block);
+  StatusOr<mlir::Value> ImportInstructionsImpl(
+      const xla::HloComputation& computation,
+      const llvm::SmallVectorImpl<mlir::Value>& arguments,
+      mlir::OpBuilder* builder);
 
   // Imports an instruction.
   StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla (new file)
@@ -0,0 +1,16 @@
+// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
+
+HloModule TestModule
+
+// CHECK: func @TestComputation
+
+FusedComputation {
+  // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} add(x, x)
+}
+
+ENTRY TestComputation {
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
+}
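The RUN line above disables XLA's optimization pipeline so the non-default layout survives translation: f32[3,2]{0,1} lists dimension 0 as minor (column-major), whereas XLA's default for rank 2 is the descending layout {1,0}, which is why the emitted tensor_load carries the minor_to_major attribute checked above. A standalone sketch of that default-layout comparison, with a hand-rolled helper instead of xla::LayoutUtil:

    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Default (descending) layout for a rank-n shape: {n-1, ..., 1, 0}.
    std::vector<int64_t> MakeDescendingLayout(int64_t rank) {
      std::vector<int64_t> minor_to_major(rank);
      std::iota(minor_to_major.rbegin(), minor_to_major.rend(), int64_t{0});
      return minor_to_major;
    }

    int main() {
      std::vector<int64_t> layout = {0, 1};  // f32[3,2]{0,1}: dim 0 is minor
      // Only a non-default layout needs an explicit attribute on the load.
      std::cout << (layout != MakeDescendingLayout(2)
                        ? "attach minor_to_major\n"
                        : "default layout, no attribute\n");
      return 0;
    }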
tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla
@@ -325,3 +325,52 @@ func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>,
 
   return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<f32> {{.*}}lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<f32> {{.*}}lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+// CHECK: "lmhlo.fusion"() ( {
+// CHECK: %[[VAR0:.*]] = tensor_load %[[ARG0]] : memref<f32>
+// CHECK: %[[VAR1:.*]] = tensor_load %[[ARG1]] : memref<f32>
+// CHECK: %[[VAR2:.*]] = mhlo.add %[[VAR0]], %[[VAR1]] : tensor<f32>
+// CHECK: tensor_store %[[VAR2]], %[[MEMREF:.*]] : memref<f32>
+// CHECK: "lmhlo.terminator"() : () -> ()
+// CHECK: }) : () -> ()
+func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  %result = "mhlo.fusion"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %result = "mhlo.add"(%arg2, %arg3): (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "mhlo.return"(%result) : (tensor<f32>) -> ()
+  }) { fusion_kind = "kLoop" } : (tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  return %result : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo.fusion"() ( {
+// CHECK: %[[VAL0:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK: %[[VAL1:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK: %[[VAL2:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK: tensor_store %[[VAL0]], %{{.*}} : memref<f32>
+// CHECK: tensor_store %[[VAL1]], %{{.*}} : memref<f32>
+// CHECK: tensor_store %[[VAL2]], %{{.*}} : memref<f32>
+// CHECK: "lmhlo.terminator"() : () -> ()
+// CHECK: }) : () -> ()
+func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
+  %result = "mhlo.fusion"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg3: tuple<tensor<f32>>):
+    %0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tuple<tensor<f32>>
+    %1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
+    %2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tensor<f32>
+    %3 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
+    %4 = "mhlo.tuple"(%1, %2, %3) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+    "mhlo.return"(%4) : (tuple<tensor<f32>, tensor<f32>, tensor<f32>>) -> ()
+  }) { fusion_kind = "kLoop" } : (tuple<tuple<tensor<f32>>, tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+
+  return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+}
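The second test exercises the tuple flattening: each nested tuple operand becomes one tensor_load per leaf inside the region, and each leaf of the tuple result is paired with one tensor_store, hence the three loads and three stores above. A toy sketch of the leaf recursion, with an invented shape type standing in for xla::Shape:

    #include <iostream>
    #include <vector>

    // Toy stand-in for xla::Shape: either a leaf array or a tuple.
    struct ToyShape {
      bool is_tuple = false;
      std::vector<ToyShape> elements;  // only populated for tuples
    };

    // Mirrors how RewriteFusionOperand recurses into tuple_shapes(i) and
    // emits one tensor_load per leaf, in post order.
    int CountLeaves(const ToyShape& shape) {
      if (!shape.is_tuple) return 1;
      int n = 0;
      for (const ToyShape& element : shape.elements) n += CountLeaves(element);
      return n;
    }

    int main() {
      // tuple<tuple<tensor<f32>>, tensor<f32>> from the test above.
      ToyShape arg0{true, {ToyShape{true, {ToyShape{}}}, ToyShape{}}};
      std::cout << CountLeaves(arg0) << " leaves -> 2 tensor_loads\n";
      return 0;
    }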
tensorflow/compiler/mlir/xla/transforms
@@ -29,7 +29,9 @@ limitations under the License.
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -40,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
+#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -110,7 +113,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
   // Run all HLO passes to produce an optimized module.
   auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
       std::move(hlo_module), backend->default_stream_executor(),
-      backend->memory_allocator());
+      backend->memory_allocator(), optimize_xla_hlo);
   TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                   "running XLA pass pipeline");
   std::unique_ptr<HloModule> optimized_hlo_module =
@@ -276,27 +279,138 @@ Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {
   return EmitSortOp(instr).status();
 }
 
-Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
-                                      const Shape& current_shape,
-                                      ::xla::ShapeIndex* current_shape_index,
-                                      SmallVectorImpl<Value>* values) {
-  if (current_shape.IsTuple()) {
-    for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
-      current_shape_index->push_back(i);
-      TF_RETURN_IF_ERROR(CreateView(instr, current_shape.tuple_shapes(i),
-                                    current_shape_index, values));
-      current_shape_index->pop_back();
-    }
-    return Status::OK();
-  }
+// Walks MHLO::TupleOp recursively.
+Status WalkTuplePostOrder(Value v,
+                          const std::function<Status(Value)>& visitor) {
+  if (auto* op = v.getDefiningOp()) {
+    if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
+      for (Value sub_v : tuple.val()) {
+        TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
+      }
+      return Status::OK();
+    }
+  }
+  return visitor(v);
+}
+
+// This function removes all uses of a fused region argument, and rewires those
+// uses to a `tensor_load %memref`, where %memref is the caller argument.
+//
+// It also flattens all input/output tuples into more region arguments /
+// results.
+StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
+    const HloInstruction* root, const Shape& shape,
+    ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
+  if (shape.IsTuple()) {
+    llvm::SmallVector<Value, 4> values;
+    for (int i = 0; i < shape.tuple_shapes_size(); i++) {
+      shape_index->push_back(i);
+      TF_ASSIGN_OR_RETURN(
+          auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
+                                       b, loc));
+      values.push_back(v);
+      shape_index->pop_back();
+    }
+    return Value(b->create<mhlo::TupleOp>(loc, values));
+  }
+  TF_ASSIGN_OR_RETURN(Value memref,
+                      GetOrCreateArrayView(root, shape, *shape_index));
+  auto load = b->create<TensorLoadOp>(loc, memref);
+  if (shape.layout() !=
+      xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
+    llvm::SmallVector<int64_t, 4> minor_to_major(
+        shape.layout().minor_to_major().begin(),
+        shape.layout().minor_to_major().end());
+    load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
+  }
+  return load.getResult();
+}
+
+StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
+    HloInstruction* instr) {
+  Location loc = getLocation(instr);
+
+  auto* fusion_instr = ::xla::Cast<::xla::HloFusionInstruction>(instr);
+
+  auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr),
+                                                 ArrayRef<NamedAttribute>{});
+  auto after_fusion = builder_.saveInsertionPoint();
+  builder_ = mlir::OpBuilder(fusion);
+
+  auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
+
+  llvm::SmallVector<Value, 8> arguments;
+  for (int i = 0; i < instr->operands().size(); i++) {
+    const HloInstruction* operand = instr->operand(i);
+    xla::ShapeIndex shape_index;
+    TF_ASSIGN_OR_RETURN(
+        auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
+                                       &region_builder, loc));
+    arguments.push_back(arg);
+  }
+
+  TF_ASSIGN_OR_RETURN(Value result,
+                      ::xla::HloFunctionImporter::ImportInstructions(
+                          *fusion_instr->fused_instructions_computation(),
+                          arguments, &region_builder));
+
+  {
+    int i = 0;
+    llvm::SmallVector<Value, 4> output;
+    TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
+    TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
+      region_builder.create<TensorStoreOp>(loc, v, output[i++]);
+      return Status::OK();
+    }));
+    if (i != output.size()) {
+      return ::xla::InternalError("output sizes don't match");
+    }
+  }
+
+  // Fold GTE/Tuple pairs.
+  //
+  // Since the fused region refers to values in its parent region, we can't
+  // call applyPatternsAndFoldGreedily. We optimize it manually.
+  //
+  // Only walk once, because post-ordering is exactly what we need for GTE
+  // optimizations.
+  fusion.region().walk([](mhlo::GetTupleElementOp gte) {
+    SmallVector<Value, 4> folded_values;
+    if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
+      gte.replaceAllUsesWith(folded_values[0]);
+    }
+  });
+
+  // Effectively a DCE on the region.
+  {
+    llvm::SmallVector<mlir::Operation*, 4> ops;
+    fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
+    // Visit the users first.
+    std::reverse(ops.begin(), ops.end());
+    for (auto op : ops) {
+      if (isOpTriviallyDead(op)) op->erase();
+    }
+  }
+
+  LOG(ERROR) << instr->GetModule()->ToString();
+  builder_.restoreInsertionPoint(after_fusion);
+  return fusion;
+}
+
+Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
+  return EmitFusionOp(instr).status();
+}
+
+StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
+    const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
+    const ::xla::ShapeIndex& shape_index) {
   TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
                           current_shape, builder_));
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
-                      assignment_.GetUniqueSlice(instr, *current_shape_index));
+                      assignment_.GetUniqueSlice(instr, shape_index));
   Value alloc = allocations_[slice.allocation()];
   if (alloc.getType() == out_type && slice.offset() == 0) {
-    values->push_back(alloc);
-    return Status::OK();
+    return alloc;
   }
 
   auto out_memref_type = out_type.dyn_cast<MemRefType>();
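The cleanup at the end of EmitFusionOp relies on ordering twice: the single walk folds get_tuple_element(tuple(...)) pairs because ops in the freshly built region appear with operands before users, and reversing the collected op list lets one sweep erase dead ops users-first, so deadness cascades down to operands. A toy illustration of the users-first sweep; ToyOp is invented for the sketch and is not the MLIR Operation API:

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Toy op: dead when it has no side effects and no remaining users.
    struct ToyOp {
      const char* name;
      int num_users = 0;
      bool has_side_effects = false;
      ToyOp* operand = nullptr;  // single operand, for simplicity
      bool erased = false;
    };

    int main() {
      ToyOp load{"tensor_load", 1};
      ToyOp gte{"get_tuple_element", 0, false, &load};
      std::vector<ToyOp*> ops = {&load, &gte};  // post order: operand first

      // Visit users first (reverse post order). Erasing a dead user
      // decrements its operand's use count, which may kill the operand too.
      std::reverse(ops.begin(), ops.end());
      for (ToyOp* op : ops) {
        if (!op->has_side_effects && op->num_users == 0) {
          op->erased = true;
          if (op->operand) --op->operand->num_users;
        }
      }
      for (const ToyOp* op : {&load, &gte})
        std::cout << op->name << (op->erased ? ": erased\n" : ": kept\n");
      return 0;
    }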
@@ -304,6 +418,13 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
     return tensorflow::errors::Internal(
         "Expected memref type when creating a view for leaf type of a tuple.");
 
+  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
+  // shape_index).
+  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
+  if (cached_value) {
+    return cached_value;
+  }
+
   Value byte_shift =
       builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
 
@@ -327,7 +448,24 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
   if (physical_out_type != out_type)
     result = builder_.create<lmhlo::StaticMemRefCastOp>(loc, out_memref_type,
                                                         result);
-  values->push_back(result);
+  return cached_value = result;
 }
 
+Status LhloDialectEmitter::GetOrCreateViewImpl(
+    const HloInstruction* instr, const Shape& current_shape,
+    ::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
+  if (current_shape.IsTuple()) {
+    for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
+      current_shape_index->push_back(i);
+      TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
+          instr, current_shape.tuple_shapes(i), current_shape_index, values));
+      current_shape_index->pop_back();
+    }
+    return Status::OK();
+  }
+  TF_ASSIGN_OR_RETURN(
+      auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
+  values->push_back(v);
+  return Status::OK();
+}
+
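The caching in GetOrCreateArrayView (`auto& cached_value = slices_[...]; ... return cached_value = result;`) leans on a common map idiom: operator[] default-constructs a null Value on a miss, so a single lookup serves as both the probe and the slot to assign into. The same idiom in standalone form, with std::map and int (0 playing the role of the null mlir::Value):

    #include <iostream>
    #include <map>
    #include <string>

    std::map<std::string, int> cache;

    int Compute(const std::string& key) {
      int& cached = cache[key];   // value-initialized to 0 on first access
      if (cached) return cached;  // hit: one lookup, no second find()
      std::cout << "computing " << key << "\n";
      return cached = 42;         // miss: compute once, store through the ref
    }

    int main() {
      Compute("view");  // prints "computing view"
      Compute("view");  // cache hit: no output
      return 0;
    }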
@@ -336,25 +474,8 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
 // create another view to adjust the slice for the shape of the instruction.
 Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                            SmallVectorImpl<Value>* values) {
-  // Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have
-  // gone fancier to do the following caching:
-  //   %slice = ViewOp(%allocation, %offset) : memref<i8xSIZE>
-  //   %typed_slice = ViewOp(%slice) : memref<f32x...>
-  //
-  // where %slice is cached. This in theory gives easier time for alias
-  // analysis, since the identity of %slice defines alias. However,
-  // %typed_slice can't be cached, as different buffers with different types and
-  // shapes may still alias. Creating two ViewOps doesn't seem to worth the
-  // effort for a slightly easier aliasing, so we don't over optimize here.
-  auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 1>{});
-  llvm::SmallVectorImpl<Value>& new_values = result.first->second;
-  if (result.second) {
-    ::xla::ShapeIndex shape_index;
-    TF_RETURN_IF_ERROR(
-        CreateView(instr, instr->shape(), &shape_index, &new_values));
-  }
-  values->insert(values->end(), new_values.begin(), new_values.end());
-  return Status::OK();
+  ::xla::ShapeIndex shape_index;
+  return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
 }
 
 Status LhloDialectEmitter::Initialize() {
@@ -43,6 +43,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
         i8_type_(builder_.getIntegerType(8)) {}
 
   ::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(::xla::HloInstruction* instr);
 
  private:
   template <typename OpType>
@@ -57,21 +58,31 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   }
 
   tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
+  tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
 
   // Helper function that recursively visits the tuple structure in
   // `current_shape`, and reconstruct a matching lmhlo::TupleOp.
   // Each leaf node is converted to an std.view op with corresponding offsets.
   // If no tuple presents, it simply returns a view of the buffer.
-  tensorflow::Status CreateView(const ::xla::HloInstruction* instr,
-                                const ::xla::Shape& current_shape,
-                                ::xla::ShapeIndex* current_shape_index,
-                                SmallVectorImpl<Value>* values);
+  tensorflow::Status GetOrCreateViewImpl(const ::xla::HloInstruction* instr,
+                                         const ::xla::Shape& current_shape,
+                                         ::xla::ShapeIndex* current_shape_index,
+                                         SmallVectorImpl<Value>* values);
 
   // Helper function to create view/tuple of views to a buffer for a given
   // instruction result.
   tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
                                      SmallVectorImpl<Value>* values);
 
+  ::xla::StatusOr<Value> GetOrCreateArrayView(
+      const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
+      const ::xla::ShapeIndex& current_shape_index);
+
+  ::xla::StatusOr<Value> RewriteFusionOperand(const ::xla::HloInstruction* root,
+                                              const ::xla::Shape& shape,
+                                              ::xla::ShapeIndex* shape_index,
+                                              OpBuilder* b, Location loc);
+
   // Return an MLIR location for an HLO instruction.
   Location getLocation(::xla::HloInstruction* inst) {
     return NameLoc::get(builder_.getIdentifier(inst->name()),
@@ -102,7 +113,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   //
   // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
   // process every instruction.
-  llvm::DenseMap<const xla::HloInstruction*, llvm::SmallVector<Value, 1>>
-      slices_;
+  absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
+                      Value>
+      slices_;
 
   // The BufferAssignment computed by XLA ahead of time.
tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc
@@ -27,3 +27,9 @@ llvm::cl::opt<bool> emit_return_tuple(
     "emit-return-tuple",
     llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
     llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+llvm::cl::opt<bool> optimize_xla_hlo(
+    "optimize-xla-hlo",
+    llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"),
+    llvm::cl::init(true));
tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h
@@ -24,5 +24,6 @@ limitations under the License.
 
 extern llvm::cl::opt<bool> emit_use_tuple_arg;
 extern llvm::cl::opt<bool> emit_return_tuple;
+extern llvm::cl::opt<bool> optimize_xla_hlo;
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_
tensorflow/compiler/xla/service
@@ -28,14 +28,6 @@ namespace xla {
 /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
     tensorflow::LINKER_INITIALIZED);
 
-StatusOr<
-    std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-Compiler::RunHloPassesAndBufferAssignement(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator) {
-  return Unimplemented("This compiler does not support this method");
-}
-
 std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
 Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
                                 se::StreamExecutor* executor) const {
@@ -188,7 +188,10 @@ class Compiler {
   std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
                                    se::StreamExecutor* executor,
-                                   se::DeviceMemoryAllocator* device_allocator);
+                                   se::DeviceMemoryAllocator* device_allocator,
+                                   bool optimize) {
+    return Unimplemented("This compiler does not support this method");
+  }
 
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
@@ -562,9 +562,11 @@ StatusOr<
     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 CpuCompiler::RunHloPassesAndBufferAssignement(
     std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator) {
-  TF_ASSIGN_OR_RETURN(
-      module, RunHloPasses(std::move(module), executor, device_allocator));
+    se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+  if (optimize) {
+    TF_ASSIGN_OR_RETURN(
+        module, RunHloPasses(std::move(module), executor, device_allocator));
+  }
 
   // Select an order for emitting the HLO instructions for each computation.
   // Using this sequence enables tighter buffer liveness analysis and reduced
@@ -138,9 +138,10 @@ class CpuCompiler : public LLVMCompiler {
 
   StatusOr<
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-  RunHloPassesAndBufferAssignement(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator) override;
+  RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
+                                   se::StreamExecutor* executor,
+                                   se::DeviceMemoryAllocator* device_allocator,
+                                   bool optimize) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
tensorflow/core/tpu
@@ -276,16 +276,6 @@ class TpuCompiler : public Compiler {
     return HloModule::CreateFromProto(result_proto, module->config());
   }
 
-  StatusOr<
-      std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-  RunHloPassesAndBufferAssignement(
-      std::unique_ptr<HloModule> module,
-      stream_executor::StreamExecutor* executor,
-      stream_executor::DeviceMemoryAllocator* device_allocator) override {
-    return Unimplemented(
-        "This compiler does not support RunHloPassesAndBufferAssignment.");
-  }
-
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
       stream_executor::StreamExecutor* executor,