[mlir] Add FusionOp to XLA HLO -> LMHLO

Also refactor the cache to take (HloInstruction*, ShapeIndex) as the key. It makes tuple handling simpler.
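
As a rough standalone sketch of the new cache shape (stand-in types only, not the real XLA/MLIR classes): keying by an (instruction, shape index) pair lets every leaf of a tuple-shaped result be looked up directly, instead of keeping one flattened vector of views per instruction.

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

struct Instr {};                          // stand-in for xla::HloInstruction
using ShapeIndex = std::vector<int64_t>;  // stand-in for xla::ShapeIndex
using View = int;                         // stand-in for mlir::Value

// Old scheme: map<const Instr*, vector<View>> (one flat vector per instruction).
// New scheme: each tuple leaf is addressed directly by its shape index.
std::map<std::pair<const Instr*, ShapeIndex>, View> slices;

View GetOrCreate(const Instr* instr, const ShapeIndex& index) {
  View& slot = slices[{instr, index}];
  if (slot == 0) slot = static_cast<View>(slices.size());  // create on miss
  return slot;
}

int main() {
  Instr fusion;
  // Leaves of a tuple-shaped result ((f32[], f32[]), f32[]).
  return GetOrCreate(&fusion, {0, 0}) == GetOrCreate(&fusion, {0, 1}) ? 1 : 0;
}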

PiperOrigin-RevId: 336951382
Change-Id: I6e86870e00a364b46ee0f8ae21bad3d19486bf24
Tim Shen 2020-10-13 14:06:18 -07:00 committed by TensorFlower Gardener
parent 72028307fd
commit 37da1f0ee1
14 changed files with 303 additions and 74 deletions

View File

@@ -136,6 +136,7 @@ cc_library(
":hlo_module_importer",
":hlo_utils",
":mlir_hlo_to_hlo",
":translate_cl_options",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/xla:debug_options_flags",

View File

@@ -140,31 +140,42 @@ tensorflow::Status HloFunctionImporter::ImportAsRegion(
return ImportInstructions(computation, block);
}
tensorflow::Status HloFunctionImporter::ImportInstructions(
const HloComputation& computation, mlir::Block* block) {
StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
// Setup the input parameters.
const int num_parameters = computation.num_parameters();
if (arguments.size() != num_parameters)
return InvalidArgument("Caller vs callee argument sizes do not match");
for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation.parameter_instruction(i);
instruction_value_map_[hlo_parameter] = block->getArgument(i);
instruction_value_map_[hlo_parameter] = arguments[i];
}
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
for (auto instruction : computation.MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, &builder));
ImportInstruction(instruction, builder));
if (new_operation) {
instruction_value_map_[instruction] = new_operation->getResult(0);
}
}
// Setup the return type (HLO only supports a single return value).
return GetMlirValue(computation.root_instruction());
}
Status HloFunctionImporter::ImportInstructions(
const HloComputation& computation, mlir::Block* block) {
llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
TF_ASSIGN_OR_RETURN(Value result,
ImportInstructionsImpl(computation, arguments, &builder));
// TODO(suderman): Add location tracking details.
mlir::Location loc = builder.getUnknownLoc();
// Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation.root_instruction()));
// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, result);
@@ -174,6 +185,19 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
return tensorflow::Status::OK();
}
StatusOr<Value> HloFunctionImporter::ImportInstructions(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
mlir::Block* block = builder->getBlock();
if (block == nullptr)
return InvalidArgument(
"ImportInstructions requires a valid block in the builder");
HloFunctionImporter importer(
block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
return importer.ImportInstructionsImpl(computation, arguments, builder);
}
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));

View File

@@ -55,6 +55,13 @@ class HloFunctionImporter {
static Status ImportAsRegion(const xla::HloComputation& computation,
mlir::Region* region, mlir::Builder* builder);
// Imports the given computation at the place specified by `builder`.
// `arguments` contains values for all parameters.
static StatusOr<mlir::Value> ImportInstructions(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<mlir::Value>& arguments,
mlir::OpBuilder* builder);
private:
HloFunctionImporter(mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
@@ -80,6 +87,10 @@ class HloFunctionImporter {
// Assumes that the block already has correct arguments populated.
tensorflow::Status ImportInstructions(const HloComputation& computation,
mlir::Block* block);
StatusOr<mlir::Value> ImportInstructionsImpl(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<mlir::Value>& arguments,
mlir::OpBuilder* builder);
// Imports an instruction.
StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
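
As a usage sketch (assumptions: the includes and `xla`/`mlir` namespaces from hlo_function_importer.h are available; `ImportFusedBody` is a hypothetical helper, not part of this change), the new static overload is driven the same way EmitFusionOp does further down in this commit: one mlir::Value per computation parameter, with the builder already positioned inside the target region.

static xla::StatusOr<mlir::Value> ImportFusedBody(
    const xla::HloComputation& computation,
    llvm::ArrayRef<mlir::Value> parameter_values, mlir::OpBuilder* builder) {
  llvm::SmallVector<mlir::Value, 4> arguments(parameter_values.begin(),
                                              parameter_values.end());
  // Returns the mlir::Value produced for the computation's root instruction.
  return xla::HloFunctionImporter::ImportInstructions(computation, arguments,
                                                      builder);
}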

View File

@@ -0,0 +1,16 @@
// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
HloModule TestModule
// CHECK: func @TestComputation
FusedComputation {
// CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
x = f32[3, 2]{0,1} parameter(0)
ROOT y = f32[3, 2]{0,1} add(x, x)
}
ENTRY TestComputation {
x = f32[3, 2]{0,1} parameter(0)
ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
}

View File

@@ -325,3 +325,52 @@ func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>,
return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<f32> {{.*}}lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<f32> {{.*}}lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
// CHECK: "lmhlo.fusion"() ( {
// CHECK: %[[VAR0:.*]] = tensor_load %[[ARG0]] : memref<f32>
// CHECK: %[[VAR1:.*]] = tensor_load %[[ARG1]] : memref<f32>
// CHECK: %[[VAR2:.*]] = mhlo.add %[[VAR0]], %[[VAR1]] : tensor<f32>
// CHECK: tensor_store %[[VAR2]], %[[MEMREF:.*]] : memref<f32>
// CHECK: "lmhlo.terminator"() : () -> ()
// CHECK: }) : () -> ()
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%result = "mhlo.fusion"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%result = "mhlo.add"(%arg2, %arg3): (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%result) : (tensor<f32>) -> ()
}) { fusion_kind = "kLoop" } : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %result : tensor<f32>
}
// -----
// CHECK-LABEL: func @main
// CHECK: "lmhlo.fusion"() ( {
// CHECK: %[[VAL0:.*]] = tensor_load %{{.*}} : memref<f32>
// CHECK: %[[VAL1:.*]] = tensor_load %{{.*}} : memref<f32>
// CHECK: %[[VAL2:.*]] = tensor_load %{{.*}} : memref<f32>
// CHECK: tensor_store %[[VAL0]], %{{.*}} : memref<f32>
// CHECK: tensor_store %[[VAL1]], %{{.*}} : memref<f32>
// CHECK: tensor_store %[[VAL2]], %{{.*}} : memref<f32>
// CHECK: "lmhlo.terminator"() : () -> ()
// CHECK: }) : () -> ()
func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
%result = "mhlo.fusion"(%arg0, %arg1) ( {
^bb0(%arg2: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg3: tuple<tensor<f32>>):
%0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tuple<tensor<f32>>
%1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<tuple<tensor<f32>>, tensor<f32>>) -> tensor<f32>
%3 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%4 = "mhlo.tuple"(%1, %2, %3) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
"mhlo.return"(%4) : (tuple<tensor<f32>, tensor<f32>, tensor<f32>>) -> ()
}) { fusion_kind = "kLoop" } : (tuple<tuple<tensor<f32>>, tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
}

View File

@@ -29,7 +29,9 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -110,7 +113,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
// Run all HLO passes to produce an optimized module.
auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
std::move(hlo_module), backend->default_stream_executor(),
backend->memory_allocator());
backend->memory_allocator(), optimize_xla_hlo);
TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
"running XLA pass pipeline");
std::unique_ptr<HloModule> optimized_hlo_module =
@@ -276,27 +279,138 @@ Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {
return EmitSortOp(instr).status();
}
Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
const Shape& current_shape,
::xla::ShapeIndex* current_shape_index,
SmallVectorImpl<Value>* values) {
if (current_shape.IsTuple()) {
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
current_shape_index->push_back(i);
TF_RETURN_IF_ERROR(CreateView(instr, current_shape.tuple_shapes(i),
current_shape_index, values));
current_shape_index->pop_back();
// Recursively walks an mhlo::TupleOp tree, invoking `visitor` on each
// non-tuple leaf value.
Status WalkTuplePostOrder(Value v,
const std::function<Status(Value)>& visitor) {
if (auto* op = v.getDefiningOp()) {
if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
for (Value sub_v : tuple.val()) {
TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
}
return Status::OK();
}
return Status::OK();
}
return visitor(v);
}
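
To make the ordering guarantee concrete, here is a standalone toy (not part of the change; plain C++ with stand-in types): both WalkTuplePostOrder above and the shape-index recursion in GetOrCreateViewImpl further down enumerate tuple leaves left to right, which is why EmitFusionOp below can pair the i-th visited result value with output[i].

#include <cstdio>
#include <vector>

struct Node {                  // stand-in for a tuple value or a leaf
  std::vector<Node> elements;  // empty => leaf
};

void VisitLeaves(const Node& n, int* next_leaf) {
  if (n.elements.empty()) {
    std::printf("leaf -> output[%d]\n", (*next_leaf)++);
    return;
  }
  for (const Node& e : n.elements) VisitLeaves(e, next_leaf);
}

int main() {
  // A result shaped ((a, b), c): leaves visit as a, b, c -> outputs 0, 1, 2.
  Node tuple{{Node{{Node{}, Node{}}}, Node{}}};
  int next_leaf = 0;
  VisitLeaves(tuple, &next_leaf);
  return 0;
}
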
// This function removes all uses of a fused region argument and rewires those
// uses to a `tensor_load %memref`, where %memref is the corresponding caller
// argument.
//
// It also flattens all input/output tuples into more region arguments /
// results.
StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
const HloInstruction* root, const Shape& shape,
::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
if (shape.IsTuple()) {
llvm::SmallVector<Value, 4> values;
for (int i = 0; i < shape.tuple_shapes_size(); i++) {
shape_index->push_back(i);
TF_ASSIGN_OR_RETURN(
auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
b, loc));
values.push_back(v);
shape_index->pop_back();
}
return Value(b->create<mhlo::TupleOp>(loc, values));
}
TF_ASSIGN_OR_RETURN(Value memref,
GetOrCreateArrayView(root, shape, *shape_index));
auto load = b->create<TensorLoadOp>(loc, memref);
if (shape.layout() !=
xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
llvm::SmallVector<int64_t, 4> minor_to_major(
shape.layout().minor_to_major().begin(),
shape.layout().minor_to_major().end());
load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
}
return load.getResult();
}
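
A standalone illustration (not part of this change), using the concrete shape from the new hlo-text-to-lhlo test above: for `f32[3,2]{0,1}` the default descending layout of rank 2 is `{1, 0}`, so the generated tensor_load is tagged with `minor_to_major = dense<[0, 1]>`, which is exactly what the test checks.

#include <cstdio>
#include <vector>

int main() {
  // Layout of the f32[3,2]{0,1} parameter vs. the default descending layout
  // that LayoutUtil::MakeDescendingLayout(2) would produce.
  std::vector<int> actual = {0, 1};
  std::vector<int> descending = {1, 0};
  // RewriteFusionOperand only attaches a minor_to_major attribute when the
  // two differ, hence the `dense<[0, 1]>` expectation in the test.
  std::printf("attach minor_to_major attribute: %s\n",
              actual != descending ? "yes" : "no");
  return 0;
}
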
StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
HloInstruction* instr) {
Location loc = getLocation(instr);
auto* fusion_instr = ::xla::Cast<::xla::HloFusionInstruction>(instr);
auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr),
ArrayRef<NamedAttribute>{});
auto after_fusion = builder_.saveInsertionPoint();
builder_ = mlir::OpBuilder(fusion);
auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
llvm::SmallVector<Value, 8> arguments;
for (int i = 0; i < instr->operands().size(); i++) {
const HloInstruction* operand = instr->operand(i);
xla::ShapeIndex shape_index;
TF_ASSIGN_OR_RETURN(
auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
&region_builder, loc));
arguments.push_back(arg);
}
TF_ASSIGN_OR_RETURN(Value result,
::xla::HloFunctionImporter::ImportInstructions(
*fusion_instr->fused_instructions_computation(),
arguments, &region_builder));
{
int i = 0;
llvm::SmallVector<Value, 4> output;
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
region_builder.create<TensorStoreOp>(loc, v, output[i++]);
return Status::OK();
}));
if (i != output.size()) {
return ::xla::InternalError("output sizes don't match");
}
}
// Fold GTE/Tuple pairs.
//
// Since the fused region refers to values in its parent region, we can't
// call applyPatternsAndFoldGreedily. We optimize it manually.
//
// Only walk once, because post-ordering is exactly what we need for GTE
// optimizations.
fusion.region().walk([](mhlo::GetTupleElementOp gte) {
SmallVector<Value, 4> folded_values;
if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
gte.replaceAllUsesWith(folded_values[0]);
}
});
// Effectively a DCE on the region.
{
llvm::SmallVector<mlir::Operation*, 4> ops;
fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
// Visit the user first.
std::reverse(ops.begin(), ops.end());
for (auto op : ops) {
if (isOpTriviallyDead(op)) op->erase();
}
}
LOG(ERROR) << instr->GetModule()->ToString();
builder_.restoreInsertionPoint(after_fusion);
return fusion;
}
Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
return EmitFusionOp(instr).status();
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {
TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
current_shape, builder_));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(instr, *current_shape_index));
assignment_.GetUniqueSlice(instr, shape_index));
Value alloc = allocations_[slice.allocation()];
if (alloc.getType() == out_type && slice.offset() == 0) {
values->push_back(alloc);
return Status::OK();
return alloc;
}
auto out_memref_type = out_type.dyn_cast<MemRefType>();
@@ -304,6 +418,13 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
return tensorflow::errors::Internal(
"Expected memref type when creating a view for leaf type of a tuple.");
// Cache generated ViewOp and StaticMemRefCastOp by (instruction,
// shape_index).
auto& cached_value = slices_[std::make_pair(instr, shape_index)];
if (cached_value) {
return cached_value;
}
Value byte_shift =
builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
@@ -327,7 +448,24 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
if (physical_out_type != out_type)
result = builder_.create<lmhlo::StaticMemRefCastOp>(loc, out_memref_type,
result);
values->push_back(result);
return cached_value = result;
}
Status LhloDialectEmitter::GetOrCreateViewImpl(
const HloInstruction* instr, const Shape& current_shape,
::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
if (current_shape.IsTuple()) {
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
current_shape_index->push_back(i);
TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
instr, current_shape.tuple_shapes(i), current_shape_index, values));
current_shape_index->pop_back();
}
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
values->push_back(v);
return Status::OK();
}
@@ -336,25 +474,8 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
SmallVectorImpl<Value>* values) {
// Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have
// gone fancier to do the following caching:
// %slice = ViewOp(%allocation, %offset) : memref<i8xSIZE>
// %typed_slice = ViewOp(%slice) : memref<f32x...>
//
// where %slice is cached. In theory this gives alias analysis an easier time,
// since the identity of %slice defines aliasing. However, %typed_slice can't
// be cached, as different buffers with different types and shapes may still
// alias. Creating two ViewOps doesn't seem worth the effort for slightly
// easier aliasing, so we don't over-optimize here.
auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 1>{});
llvm::SmallVectorImpl<Value>& new_values = result.first->second;
if (result.second) {
::xla::ShapeIndex shape_index;
TF_RETURN_IF_ERROR(
CreateView(instr, instr->shape(), &shape_index, &new_values));
}
values->insert(values->end(), new_values.begin(), new_values.end());
return Status::OK();
::xla::ShapeIndex shape_index;
return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
}
Status LhloDialectEmitter::Initialize() {

View File

@@ -43,6 +43,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
i8_type_(builder_.getIntegerType(8)) {}
::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(::xla::HloInstruction* instr);
private:
template <typename OpType>
@@ -57,21 +58,31 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
}
tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
// Helper function that recursively visits the tuple structure in
// `current_shape` and reconstructs a matching lmhlo::TupleOp.
// Each leaf node is converted to a std.view op with corresponding offsets.
// If no tuple is present, it simply returns a view of the buffer.
tensorflow::Status CreateView(const ::xla::HloInstruction* instr,
const ::xla::Shape& current_shape,
::xla::ShapeIndex* current_shape_index,
SmallVectorImpl<Value>* values);
tensorflow::Status GetOrCreateViewImpl(const ::xla::HloInstruction* instr,
const ::xla::Shape& current_shape,
::xla::ShapeIndex* current_shape_index,
SmallVectorImpl<Value>* values);
// Helper function to create view/tuple of views to a buffer for a given
// instruction result.
tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
SmallVectorImpl<Value>* values);
::xla::StatusOr<Value> GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& current_shape_index);
::xla::StatusOr<Value> RewriteFusionOperand(const ::xla::HloInstruction* root,
const ::xla::Shape& shape,
::xla::ShapeIndex* shape_index,
OpBuilder* b, Location loc);
// Return an MLIR location for an HLO instruction.
Location getLocation(::xla::HloInstruction* inst) {
return NameLoc::get(builder_.getIdentifier(inst->name()),
@@ -102,7 +113,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
//
// `slices_` is populated lazily in the `GetOrCreateView()` helper as we
// process every instruction.
llvm::DenseMap<const xla::HloInstruction*, llvm::SmallVector<Value, 1>>
absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
Value>
slices_;
// The BufferAssignment computed by XLA ahead of time.

View File

@@ -27,3 +27,9 @@ llvm::cl::opt<bool> emit_return_tuple(
"emit-return-tuple",
llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
llvm::cl::init(false));
// NOLINTNEXTLINE
llvm::cl::opt<bool> optimize_xla_hlo(
"optimize-xla-hlo",
llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"),
llvm::cl::init(true));

View File

@@ -24,5 +24,6 @@ limitations under the License.
extern llvm::cl::opt<bool> emit_use_tuple_arg;
extern llvm::cl::opt<bool> emit_return_tuple;
extern llvm::cl::opt<bool> optimize_xla_hlo;
#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_

View File

@@ -28,14 +28,6 @@ namespace xla {
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
tensorflow::LINKER_INITIALIZED);
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
Compiler::RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return Unimplemented("This compiler does not support this method");
}
std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
se::StreamExecutor* executor) const {

View File

@@ -188,7 +188,10 @@ class Compiler {
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator);
se::DeviceMemoryAllocator* device_allocator,
bool optimize) {
return Unimplemented("This compiler does not support this method");
}
// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. No HLO passes are

View File

@@ -562,9 +562,11 @@ StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
CpuCompiler::RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
TF_ASSIGN_OR_RETURN(
module, RunHloPasses(std::move(module), executor, device_allocator));
se::DeviceMemoryAllocator* device_allocator, bool optimize) {
if (optimize) {
TF_ASSIGN_OR_RETURN(
module, RunHloPasses(std::move(module), executor, device_allocator));
}
// Select an order for emitting the HLO instructions for each computation.
// Using this sequence enables tighter buffer liveness analysis and reduced

View File

@@ -138,9 +138,10 @@ class CpuCompiler : public LLVMCompiler {
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) override;
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator,
bool optimize) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,

View File

@@ -276,16 +276,6 @@ class TpuCompiler : public Compiler {
return HloModule::CreateFromProto(result_proto, module->config());
}
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> module,
stream_executor::StreamExecutor* executor,
stream_executor::DeviceMemoryAllocator* device_allocator) override {
return Unimplemented(
"This compiler does not support RunHloPassesAndBufferAssignment.");
}
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
stream_executor::StreamExecutor* executor,