From 37da1f0ee19568ffddbb1b58d0f37bf8844427a6 Mon Sep 17 00:00:00 2001
From: Tim Shen
Date: Tue, 13 Oct 2020 14:06:18 -0700
Subject: [PATCH] [mlir] Add FusionOp to XLA HLO -> LMHLO

Also refactor the cache to take (HloInstruction*, ShapeIndex) as the key.
This makes tuple handling simpler.

PiperOrigin-RevId: 336951382
Change-Id: I6e86870e00a364b46ee0f8ae21bad3d19486bf24
---
 tensorflow/compiler/mlir/xla/BUILD            |   1 +
 .../mlir/xla/hlo_function_importer.cc         |  42 +++-
 .../compiler/mlir/xla/hlo_function_importer.h |  11 +
 .../fusion_layouts.hlotxt                     |  16 ++
 .../xla/tests/hlo_to_lhlo_with_xla/ops.mlir   |  49 +++++
 .../xla/transforms/mhlo_to_lhlo_with_xla.cc   | 190 ++++++++++++++----
 .../xla/transforms/mhlo_to_lhlo_with_xla.h    |  22 +-
 .../mlir/xla/xla_mlir_translate_cl.cc         |   6 +
 .../compiler/mlir/xla/xla_mlir_translate_cl.h |   1 +
 tensorflow/compiler/xla/service/compiler.cc   |   8 -
 tensorflow/compiler/xla/service/compiler.h    |   5 +-
 .../compiler/xla/service/cpu/cpu_compiler.cc  |   8 +-
 .../compiler/xla/service/cpu/cpu_compiler.h   |   7 +-
 tensorflow/core/tpu/tpu_on_demand_compiler.cc |  10 -
 14 files changed, 302 insertions(+), 74 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt
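A note on the cache refactoring described above, before the per-file diffs: keying the view cache by the pair (HloInstruction*, ShapeIndex) yields one cache entry per leaf of a possibly tuple-shaped result, instead of one vector of views per instruction, which is what makes the tuple handling in this patch simpler. A minimal, runnable sketch of that lookup pattern, with simplified stand-in types (an int for the instruction, std::string for mlir::Value) rather than the real XLA/MLIR classes:

// Standalone sketch of the new cache keying: one entry per
// (instruction, shape_index) leaf rather than one vector per instruction.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ShapeIndex = std::vector<int64_t>;  // stand-in for xla::ShapeIndex
using Value = std::string;                // stand-in for mlir::Value

struct Emitter {
  // Keyed by (instruction id, path into the tuple shape).
  std::map<std::pair<int, ShapeIndex>, Value> slices_;

  Value GetOrCreateArrayView(int instr, const ShapeIndex& index) {
    auto& cached = slices_[{instr, index}];
    if (cached.empty()) cached = "view(" + std::to_string(instr) + ")";
    return cached;
  }
};

int main() {
  Emitter e;
  Value v1 = e.GetOrCreateArrayView(/*instr=*/7, /*index=*/{0, 1});
  Value v2 = e.GetOrCreateArrayView(/*instr=*/7, /*index=*/{0, 1});
  // Same leaf, same cache entry: one view is created, not two.
  std::cout << v1 << " " << v2 << " entries=" << e.slices_.size() << "\n";
}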
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 689eb14e4af..1919446a365 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -136,6 +136,7 @@ cc_library(
         ":hlo_module_importer",
         ":hlo_utils",
         ":mlir_hlo_to_hlo",
+        ":translate_cl_options",
        "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/xla:debug_options_flags",
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 209a7dfa7fe..253156b44a5 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -140,31 +140,42 @@ tensorflow::Status HloFunctionImporter::ImportAsRegion(
   return ImportInstructions(computation, block);
 }
 
-tensorflow::Status HloFunctionImporter::ImportInstructions(
-    const HloComputation& computation, mlir::Block* block) {
+StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
+    const xla::HloComputation& computation,
+    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
   // Setup the input parameters.
   const int num_parameters = computation.num_parameters();
+
+  if (arguments.size() != num_parameters)
+    return InvalidArgument("Caller vs callee argument sizes do not match");
+
   for (int i = 0; i < num_parameters; i++) {
     auto hlo_parameter = computation.parameter_instruction(i);
-    instruction_value_map_[hlo_parameter] = block->getArgument(i);
+    instruction_value_map_[hlo_parameter] = arguments[i];
   }
 
-  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
   for (auto instruction : computation.MakeInstructionPostOrder()) {
     TF_ASSIGN_OR_RETURN(auto new_operation,
-                        ImportInstruction(instruction, &builder));
+                        ImportInstruction(instruction, builder));
     if (new_operation) {
       instruction_value_map_[instruction] = new_operation->getResult(0);
     }
   }
 
+  // Return the root value (HLO only supports a single return value).
+  return GetMlirValue(computation.root_instruction());
+}
+
+Status HloFunctionImporter::ImportInstructions(
+    const HloComputation& computation, mlir::Block* block) {
+  llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
+  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
+  TF_ASSIGN_OR_RETURN(Value result,
+                      ImportInstructionsImpl(computation, arguments, &builder));
+
   // TODO(suderman): Add location tracking details.
   mlir::Location loc = builder.getUnknownLoc();
 
-  // Setup the return type (HLO only supports a single return value).
-  TF_ASSIGN_OR_RETURN(auto result,
-                      GetMlirValue(computation.root_instruction()));
-
   // Create terminator op depending on the parent op of this region.
   if (llvm::isa<FuncOp>(block->getParentOp())) {
     builder.create<mlir::ReturnOp>(loc, result);
@@ -174,6 +185,19 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
   return tensorflow::Status::OK();
 }
 
+StatusOr<Value> HloFunctionImporter::ImportInstructions(
+    const xla::HloComputation& computation,
+    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
+  mlir::Block* block = builder->getBlock();
+  if (block == nullptr)
+    return InvalidArgument(
+        "ImportInstructions requires a valid block in the builder");
+
+  HloFunctionImporter importer(
+      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
+  return importer.ImportInstructionsImpl(computation, arguments, builder);
+}
+
 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     HloInstruction* instruction, mlir::OpBuilder* func_builder) {
   TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index f925f7f471b..4a75b079d76 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -55,6 +55,13 @@ class HloFunctionImporter {
   static Status ImportAsRegion(const xla::HloComputation& computation,
                                mlir::Region* region, mlir::Builder* builder);
 
+  // Imports the given computation into the location specified by `builder`.
+  // `arguments` contains values for all parameters.
+  static StatusOr<mlir::Value> ImportInstructions(
+      const xla::HloComputation& computation,
+      const llvm::SmallVectorImpl<mlir::Value>& arguments,
+      mlir::OpBuilder* builder);
+
  private:
   HloFunctionImporter(mlir::ModuleOp module,
                       std::unordered_map<const xla::HloComputation*,
@@ -64,6 +71,10 @@ class HloFunctionImporter {
         builder_(builder),
         function_map_(function_map) {}
 
+  StatusOr<mlir::Value> ImportInstructionsImpl(
+      const xla::HloComputation& computation,
+      const llvm::SmallVectorImpl<mlir::Value>& arguments,
+      mlir::OpBuilder* builder);
   // Imports an instruction.
   StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
                                                mlir::OpBuilder* func_builder);
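The new ImportInstructions overload above takes explicit argument values instead of reading them off a block's arguments, which is what lets EmitFusionOp (later in this patch) inline a fused computation into a region whose operands live in the parent region. A standalone sketch of the contract, with stand-in types rather than the real MLIR/XLA API: arguments bind to parameters positionally and must match in number, and the caller gets back the value of the root instruction:

// Stand-in model of the ImportInstructions contract (not the real API).
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using Value = std::string;  // stand-in for mlir::Value

struct Computation {
  int num_parameters = 0;
  // The "root" combines parameter values; stands in for the HLO graph.
  std::function<Value(const std::vector<Value>&)> root;
};

Value ImportInstructions(const Computation& computation,
                         const std::vector<Value>& arguments) {
  // Mirrors the "Caller vs callee argument sizes do not match" check.
  if (static_cast<int>(arguments.size()) != computation.num_parameters)
    throw std::invalid_argument("argument sizes do not match");
  return computation.root(arguments);  // value of the root instruction
}

int main() {
  Computation add{2, [](const std::vector<Value>& a) {
                    return "add(" + a[0] + ", " + a[1] + ")";
                  }};
  std::cout << ImportInstructions(add, {"%arg0", "%arg1"}) << "\n";
}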
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt
new file mode 100644
index 00000000000..781e203510b
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt
@@ -0,0 +1,16 @@
+// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
+
+HloModule TestModule
+
+// CHECK: func @TestComputation
+
+FusedComputation {
+  // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} add(x, x)
+}
+
+ENTRY TestComputation {
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
index 5ce78c2dfa3..e7312e2114c 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
@@ -325,3 +325,52 @@ func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) {
 
   return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<f32> {{.*}}lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<f32> {{.*}}lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+// CHECK: "lmhlo.fusion"() ( {
+// CHECK:   %[[VAR0:.*]] = tensor_load %[[ARG0]] : memref<f32>
+// CHECK:   %[[VAR1:.*]] = tensor_load %[[ARG1]] : memref<f32>
+// CHECK:   %[[VAR2:.*]] = mhlo.add %[[VAR0]], %[[VAR1]] : tensor<f32>
+// CHECK:   tensor_store %[[VAR2]], %[[MEMREF:.*]] : memref<f32>
+// CHECK:   "lmhlo.terminator"() : () -> ()
+// CHECK: }) : () -> ()
+func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  %result = "mhlo.fusion"(%arg0, %arg1) ( {
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %result = "mhlo.add"(%arg2, %arg3): (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "mhlo.return"(%result) : (tensor<f32>) -> ()
+  }) { fusion_kind = "kLoop" } : (tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  return %result : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo.fusion"() ( {
+// CHECK:   %[[VAL0:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK:   %[[VAL1:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK:   %[[VAL2:.*]] = tensor_load %{{.*}} : memref<f32>
+// CHECK:   tensor_store %[[VAL0]], %{{.*}} : memref<f32>
+// CHECK:   tensor_store %[[VAL1]], %{{.*}} : memref<f32>
+// CHECK:   tensor_store %[[VAL2]], %{{.*}} : memref<f32>
+// CHECK:   "lmhlo.terminator"() : () -> ()
+// CHECK: }) : () -> ()
+func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
+  %result = "mhlo.fusion"(%arg0, %arg1) ( {
+    ^bb0(%arg2: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg3: tuple<tensor<f32>>):
+      %0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tuple<tensor<f32>>
+      %1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
+      %2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tensor<f32>
+      %3 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
+      %4 = "mhlo.tuple"(%1, %2, %3) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+      "mhlo.return"(%4) : (tuple<tensor<f32>, tensor<f32>, tensor<f32>>) -> ()
+  }) { fusion_kind = "kLoop" } : (tuple<tuple<tensor<f32>>, tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+
+  return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
+}
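The fusion_layouts.hlotxt test above pins down the layout rule implemented by RewriteFusionOperand later in this patch: a minor_to_major attribute is attached to a tensor_load only when the operand's layout differs from the default descending layout, as with {0,1} on the rank-2 shape f32[3,2]. A small runnable sketch of just that check, in plain C++ standing in for xla::LayoutUtil and the MLIR attribute plumbing:

// Sketch of the "attach minor_to_major only for non-default layouts" rule.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Default layout: dimensions ordered major-to-minor as {rank-1, ..., 1, 0}.
std::vector<int64_t> MakeDescendingLayout(int64_t rank) {
  std::vector<int64_t> layout(rank);
  std::iota(layout.rbegin(), layout.rend(), 0);
  return layout;
}

int main() {
  std::vector<int64_t> layout = {0, 1};  // from f32[3, 2]{0,1}
  bool needs_attribute = layout != MakeDescendingLayout(layout.size());
  std::cout << (needs_attribute ? "attach minor_to_major" : "default layout")
            << "\n";  // prints: attach minor_to_major
}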
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index b60d95d1ddf..9cf161fb2ae 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -29,7 +29,9 @@ limitations under the License.
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -40,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
+#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -110,7 +113,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
   // Run all HLO passes to produce an optimized module.
   auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
       std::move(hlo_module), backend->default_stream_executor(),
-      backend->memory_allocator());
+      backend->memory_allocator(), optimize_xla_hlo);
   TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                   "running XLA pass pipeline");
   std::unique_ptr<HloModule> optimized_hlo_module =
@@ -276,27 +279,137 @@ Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {
   return EmitSortOp(instr).status();
 }
 
-Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
-                                      const Shape& current_shape,
-                                      ::xla::ShapeIndex* current_shape_index,
-                                      SmallVectorImpl<Value>* values) {
-  if (current_shape.IsTuple()) {
-    for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
-      current_shape_index->push_back(i);
-      TF_RETURN_IF_ERROR(CreateView(instr, current_shape.tuple_shapes(i),
-                                    current_shape_index, values));
-      current_shape_index->pop_back();
+// Walks an mhlo::TupleOp tree recursively, visiting leaf values post-order.
+Status WalkTuplePostOrder(Value v,
+                          const std::function<Status(Value)>& visitor) {
+  if (auto* op = v.getDefiningOp()) {
+    if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
+      for (Value sub_v : tuple.val()) {
+        TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
+      }
+      return Status::OK();
     }
-    return Status::OK();
   }
+  return visitor(v);
+}
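WalkTuplePostOrder above expands nested mhlo.tuple results so that the visitor sees only leaf values, left to right. A standalone sketch of the same traversal over a plain tree, standing in for chains of mhlo::TupleOp defining ops:

// Runnable model of the post-order tuple walk (stand-in types, not MLIR).
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string leaf;                          // set when this is a leaf value
  std::vector<std::shared_ptr<Node>> tuple;  // set when this is a tuple
};

void WalkTuplePostOrder(const std::shared_ptr<Node>& v,
                        const std::function<void(const std::string&)>& visit) {
  if (!v->tuple.empty()) {
    for (const auto& sub : v->tuple) WalkTuplePostOrder(sub, visit);
    return;
  }
  visit(v->leaf);
}

int main() {
  auto leaf = [](std::string s) {
    auto n = std::make_shared<Node>();
    n->leaf = std::move(s);
    return n;
  };
  auto tup = std::make_shared<Node>();
  tup->tuple = {leaf("%0"), leaf("%1")};
  auto root = std::make_shared<Node>();
  root->tuple = {tup, leaf("%2")};  // tuple<tuple<%0, %1>, %2>
  WalkTuplePostOrder(root, [](const std::string& s) { std::cout << s << " "; });
  // prints: %0 %1 %2
}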
+
+// This function removes all uses of a fused region argument and rewires those
+// uses to a `tensor_load %memref`, where %memref is the caller argument.
+//
+// It also flattens all input/output tuples into more region arguments /
+// results.
+StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
+    const HloInstruction* root, const Shape& shape,
+    ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
+  if (shape.IsTuple()) {
+    llvm::SmallVector<Value, 4> values;
+    for (int i = 0; i < shape.tuple_shapes_size(); i++) {
+      shape_index->push_back(i);
+      TF_ASSIGN_OR_RETURN(
+          auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
+                                       b, loc));
+      values.push_back(v);
+      shape_index->pop_back();
+    }
+    return Value(b->create<mhlo::TupleOp>(loc, values));
+  }
+  TF_ASSIGN_OR_RETURN(Value memref,
+                      GetOrCreateArrayView(root, shape, *shape_index));
+  auto load = b->create<TensorLoadOp>(loc, memref);
+  if (shape.layout() !=
+      xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
+    llvm::SmallVector<int64_t, 4> minor_to_major(
+        shape.layout().minor_to_major().begin(),
+        shape.layout().minor_to_major().end());
+    load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
+  }
+  return load.getResult();
+}
+
+StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
+    HloInstruction* instr) {
+  Location loc = getLocation(instr);
+
+  auto* fusion_instr = ::xla::Cast<::xla::HloFusionInstruction>(instr);
+
+  auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr),
+                                                 ArrayRef<NamedAttribute>{});
+  auto after_fusion = builder_.saveInsertionPoint();
+  builder_ = mlir::OpBuilder(fusion);
+
+  auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
+
+  llvm::SmallVector<Value, 8> arguments;
+  for (int i = 0; i < instr->operands().size(); i++) {
+    const HloInstruction* operand = instr->operand(i);
+    xla::ShapeIndex shape_index;
+    TF_ASSIGN_OR_RETURN(
+        auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
+                                       &region_builder, loc));
+    arguments.push_back(arg);
+  }
+
+  TF_ASSIGN_OR_RETURN(Value result,
+                      ::xla::HloFunctionImporter::ImportInstructions(
+                          *fusion_instr->fused_instructions_computation(),
+                          arguments, &region_builder));
+
+  {
+    int i = 0;
+    llvm::SmallVector<Value, 4> output;
+    TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
+    TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
+      region_builder.create<TensorStoreOp>(loc, v, output[i++]);
+      return Status::OK();
+    }));
+    if (i != output.size()) {
+      return ::xla::InternalError("output sizes don't match");
+    }
+  }
+
+  // Fold GTE/Tuple pairs.
+  //
+  // Since the fused region refers to values in its parent region, we can't
+  // call applyPatternAndFoldGreedily. We optimize it manually.
+  //
+  // Only walk once, because post-ordering is exactly what we need for GTE
+  // optimizations.
+  fusion.region().walk([](mhlo::GetTupleElementOp gte) {
+    SmallVector<Value, 4> folded_values;
+    if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
+      gte.replaceAllUsesWith(folded_values[0]);
+    }
+  });
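The cleanup that follows is a single reverse walk: because the ops were collected in region (program) order, scanning them back to front visits each op's users before the op itself, so erasing trivially dead ops cascades in one pass. A runnable model of that idea, with a use count and side-effect flag standing in for mlir::isOpTriviallyDead:

// Sketch of the one-pass reverse DCE (stand-in types, not the MLIR API).
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Region in program order; "uses" maps an op to how many ops consume it.
  std::vector<std::string> ops = {"tuple", "get_tuple_element", "terminator"};
  std::map<std::string, int> uses = {{"tuple", 1}, {"get_tuple_element", 0}};
  std::map<std::string, std::string> operand_of = {
      {"get_tuple_element", "tuple"}};  // erasing an op releases its operand

  std::reverse(ops.begin(), ops.end());  // visit users first
  for (const auto& op : ops) {
    if (op == "terminator") continue;  // has side effects, never dead
    if (uses[op] == 0) {
      std::cout << "erase " << op << "\n";
      if (operand_of.count(op)) uses[operand_of[op]]--;
    }
  }
  // Output: erase get_tuple_element, then erase tuple.
}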
+
+  // Effectively a DCE on the region.
+  {
+    llvm::SmallVector<mlir::Operation*, 8> ops;
+    fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
+    // Visit users first.
+    std::reverse(ops.begin(), ops.end());
+    for (auto op : ops) {
+      if (isOpTriviallyDead(op)) op->erase();
+    }
+  }
+
+  builder_.restoreInsertionPoint(after_fusion);
+  return fusion;
+}
+
+Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
+  return EmitFusionOp(instr).status();
+}
+
+StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
+    const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
+    const ::xla::ShapeIndex& shape_index) {
   TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
                                          current_shape, builder_));
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
-                      assignment_.GetUniqueSlice(instr, *current_shape_index));
+                      assignment_.GetUniqueSlice(instr, shape_index));
   Value alloc = allocations_[slice.allocation()];
   if (alloc.getType() == out_type && slice.offset() == 0) {
-    values->push_back(alloc);
-    return Status::OK();
+    return alloc;
   }
 
   auto out_memref_type = out_type.dyn_cast<MemRefType>();
@@ -304,6 +417,13 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
   if (!out_memref_type)
     return tensorflow::errors::Internal(
         "Expected memref type when creating a view for leaf type of a tuple.");
 
+  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
+  // shape_index).
+  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
+  if (cached_value) {
+    return cached_value;
+  }
+
   Value byte_shift =
       builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
@@ -327,7 +447,24 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
   if (physical_out_type != out_type)
     result = builder_.create<lmhlo::StaticMemRefCastOp>(loc, out_memref_type,
                                                         result);
-  values->push_back(result);
+  return cached_value = result;
+}
+
+Status LhloDialectEmitter::GetOrCreateViewImpl(
+    const HloInstruction* instr, const Shape& current_shape,
+    ::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
+  if (current_shape.IsTuple()) {
+    for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
+      current_shape_index->push_back(i);
+      TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
+          instr, current_shape.tuple_shapes(i), current_shape_index, values));
+      current_shape_index->pop_back();
+    }
+    return Status::OK();
+  }
+  TF_ASSIGN_OR_RETURN(
+      auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
+  values->push_back(v);
   return Status::OK();
 }
 
@@ -336,25 +473,8 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
 // create another view to adjust the slice for the shape of the instruction.
 Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                            SmallVectorImpl<Value>* values) {
-  // Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have
-  // gone fancier to do the following caching:
-  //   %slice = ViewOp(%allocation, %offset) : memref<i8>
-  //   %typed_slice = ViewOp(%slice) : memref<f32>
-  //
-  // where %slice is cached. This in theory gives easier time for alias
-  // analysis, since the identity of %slice defines alias. However,
-  // %typed_slice can't be cached, as different buffers with different types and
-  // shapes may still alias. Creating two ViewOps doesn't seem to worth the
-  // effort for a slightly easier aliasing, so we don't over optimize here.
-  auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 1>{});
-  llvm::SmallVectorImpl<Value>& new_values = result.first->second;
-  if (result.second) {
-    ::xla::ShapeIndex shape_index;
-    TF_RETURN_IF_ERROR(
-        CreateView(instr, instr->shape(), &shape_index, &new_values));
-  }
-  values->insert(values->end(), new_values.begin(), new_values.end());
-  return Status::OK();
+  ::xla::ShapeIndex shape_index;
+  return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
 }
 
 Status LhloDialectEmitter::Initialize() {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 89514116254..a57db3cb67e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -43,6 +43,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
         i8_type_(builder_.getIntegerType(8)) {}
 
   ::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(::xla::HloInstruction* instr);
 
  private:
   template <typename OpType>
@@ -57,21 +58,31 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   }
 
   tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
+  tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
 
   // Helper function that recursively visits the tuple structure in
   // `current_shape`, and reconstruct a matching lmhlo::TupleOp.
   // Each leaf node is converted to an std.view op with corresponding offsets.
   // If no tuple presents, it simply returns a view of the buffer.
-  tensorflow::Status CreateView(const ::xla::HloInstruction* instr,
-                                const ::xla::Shape& current_shape,
-                                ::xla::ShapeIndex* current_shape_index,
-                                SmallVectorImpl<Value>* values);
+  tensorflow::Status GetOrCreateViewImpl(const ::xla::HloInstruction* instr,
+                                         const ::xla::Shape& current_shape,
+                                         ::xla::ShapeIndex* current_shape_index,
+                                         SmallVectorImpl<Value>* values);
 
   // Helper function to create view/tuple of views to a buffer for a given
   // instruction result.
   tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
                                      SmallVectorImpl<Value>* values);
 
+  ::xla::StatusOr<Value> GetOrCreateArrayView(
+      const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
+      const ::xla::ShapeIndex& current_shape_index);
+
+  ::xla::StatusOr<Value> RewriteFusionOperand(const ::xla::HloInstruction* root,
+                                              const ::xla::Shape& shape,
+                                              ::xla::ShapeIndex* shape_index,
+                                              OpBuilder* b, Location loc);
+
   // Return an MLIR location for an HLO instruction.
   Location getLocation(::xla::HloInstruction* inst) {
     return NameLoc::get(builder_.getIdentifier(inst->name()),
@@ -102,7 +113,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   //
   // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
   // process every instruction.
-  llvm::DenseMap<const ::xla::HloInstruction*, llvm::SmallVector<Value, 1>>
+  absl::flat_hash_map<std::pair<const ::xla::HloInstruction*, ::xla::ShapeIndex>,
+                      Value>
       slices_;
 
   // The BufferAssignment computed by XLA ahead of time.
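The GetOrCreateViewImpl declaration above recurses over a tuple shape while growing and shrinking a ShapeIndex, emitting one leaf view per index path; those paths are exactly the keys of the refactored slices_ map. A runnable sketch of the recursion with stand-in types, not the real xla::Shape machinery:

// Sketch of the ShapeIndex push/pop recursion over a tuple shape.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Shape {
  bool is_tuple = false;
  std::vector<Shape> tuple_shapes;  // children when is_tuple
};

using ShapeIndex = std::vector<int64_t>;

void GetOrCreateViewImpl(const Shape& shape, ShapeIndex* index,
                         std::vector<std::string>* views) {
  if (shape.is_tuple) {
    for (int64_t i = 0; i < (int64_t)shape.tuple_shapes.size(); ++i) {
      index->push_back(i);
      GetOrCreateViewImpl(shape.tuple_shapes[i], index, views);
      index->pop_back();
    }
    return;
  }
  std::string path;
  for (int64_t i : *index) path += "." + std::to_string(i);
  views->push_back("view@" + path);  // one leaf view per ShapeIndex
}

int main() {
  Shape leaf;
  Shape inner;
  inner.is_tuple = true;
  inner.tuple_shapes = {leaf};
  Shape root;
  root.is_tuple = true;
  root.tuple_shapes = {inner, leaf};  // tuple<tuple<leaf>, leaf>
  ShapeIndex index;
  std::vector<std::string> views;
  GetOrCreateViewImpl(root, &index, &views);
  for (const auto& v : views) std::cout << v << "\n";  // view@.0.0, view@.1
}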
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc
index bfe4ed3844f..7eb1fb40f5e 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc
@@ -27,3 +27,9 @@ llvm::cl::opt<bool> emit_return_tuple(
     "emit-return-tuple",
     llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
     llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+llvm::cl::opt<bool> optimize_xla_hlo(
+    "optimize-xla-hlo",
+    llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"),
+    llvm::cl::init(true));
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h
index 1d5a29a5fdb..14a2878dff8 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h
@@ -24,5 +24,6 @@ limitations under the License.
 
 extern llvm::cl::opt<bool> emit_use_tuple_arg;
 extern llvm::cl::opt<bool> emit_return_tuple;
+extern llvm::cl::opt<bool> optimize_xla_hlo;
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index f03b27cdcc7..653f4555a77 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -28,14 +28,6 @@ namespace xla {
 /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
     tensorflow::LINKER_INITIALIZED);
 
-StatusOr<
-    std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-Compiler::RunHloPassesAndBufferAssignement(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator) {
-  return Unimplemented("This compiler does not support this method");
-}
-
 std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
 Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
                                 se::StreamExecutor* executor) const {
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 312a068ba65..253caac195c 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -188,7 +188,10 @@ class Compiler {
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
                                    se::StreamExecutor* executor,
-                                   se::DeviceMemoryAllocator* device_allocator);
+                                   se::DeviceMemoryAllocator* device_allocator,
+                                   bool optimize) {
+    return Unimplemented("This compiler does not support this method");
+  }
 
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
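The compiler-interface change above moves the default RunHloPassesAndBufferAssignement body inline into the header and threads through a new bool optimize parameter; the CPU backend below runs its HLO pass pipeline only when the flag is set. A stand-in sketch of that shape of API, with simplified types rather than the real signatures:

// Sketch of a virtual with an inline unimplemented default plus a backend
// override that honors the optimize flag. Stand-in types only.
#include <iostream>
#include <memory>
#include <stdexcept>

struct Module {};

class Compiler {
 public:
  virtual ~Compiler() = default;
  // Default now lives inline in the header instead of a .cc file.
  virtual std::unique_ptr<Module> RunHloPassesAndBufferAssignement(
      std::unique_ptr<Module> module, bool optimize) {
    throw std::logic_error("This compiler does not support this method");
  }
};

class CpuLikeCompiler : public Compiler {
 public:
  std::unique_ptr<Module> RunHloPassesAndBufferAssignement(
      std::unique_ptr<Module> module, bool optimize) override {
    if (optimize) {
      std::cout << "running HLO passes\n";
    }
    // Scheduling and buffer assignment would happen unconditionally here.
    return module;
  }
};

int main() {
  CpuLikeCompiler compiler;
  auto module = compiler.RunHloPassesAndBufferAssignement(
      std::make_unique<Module>(), /*optimize=*/false);  // skips the passes
}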
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 0260b3926c7..1ffafd37a27 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -562,9 +562,11 @@ StatusOr<
     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 CpuCompiler::RunHloPassesAndBufferAssignement(
     std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-    se::DeviceMemoryAllocator* device_allocator) {
-  TF_ASSIGN_OR_RETURN(
-      module, RunHloPasses(std::move(module), executor, device_allocator));
+    se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+  if (optimize) {
+    TF_ASSIGN_OR_RETURN(
+        module, RunHloPasses(std::move(module), executor, device_allocator));
+  }
 
   // Select an order for emitting the HLO instructions for each computation.
   // Using this sequence enables tighter buffer liveness analysis and reduced
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index d28ccd985a3..5c056fcacaa 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -138,9 +138,10 @@ class CpuCompiler : public LLVMCompiler {
 
   StatusOr<
       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-  RunHloPassesAndBufferAssignement(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator) override;
+  RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
+                                   se::StreamExecutor* executor,
+                                   se::DeviceMemoryAllocator* device_allocator,
+                                   bool optimize) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
index 01ea9f5848a..c34a13a45dc 100644
--- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc
+++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
@@ -276,16 +276,6 @@ class TpuCompiler : public Compiler {
     return HloModule::CreateFromProto(result_proto, module->config());
   }
 
-  StatusOr<
-      std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-  RunHloPassesAndBufferAssignement(
-      std::unique_ptr<HloModule> module,
-      stream_executor::StreamExecutor* executor,
-      stream_executor::DeviceMemoryAllocator* device_allocator) override {
-    return Unimplemented(
-        "This compiler does not support RunHloPassesAndBufferAssignment.");
-  }
-
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
       stream_executor::StreamExecutor* executor,