[XLA/GPU] Add an LMHLO -> Execution result test.
PiperOrigin-RevId: 350848311
Change-Id: Ic31b57d7a9adb2dc7b7d63e450e1d377b3deae1e
@@ -140,6 +140,8 @@ class BufferAllocation {
  // be live out of the entry computation.
  bool maybe_live_out() const { return maybe_live_out_; }

  void set_maybe_live_out(bool value) { maybe_live_out_ = value; }

  // Returns the size of the allocation. Necessarily this must be at least as
  // large as any LogicalBuffer assigned to this allocation.
  int64 size() const { return size_; }
@@ -272,14 +274,6 @@ class BufferAllocation {
    return index() < other.index();
  }

 private:
  // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
  friend class BufferAssigner;
  friend class BufferAssignment;

  // Adds a LogicalBuffer to the set assigned to this buffer.
  void AddAssignment(const HloValue& buffer, int64 offset, int64 size);

  void set_entry_computation_parameter(int64 parameter_number,
                                       ShapeIndex param_shape_index,
                                       bool parameter_aliased_with_output) {
@@ -289,8 +283,15 @@ class BufferAllocation {
    param_shape_index_ = std::move(param_shape_index);
  }

 private:
  // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
  friend class BufferAssigner;
  friend class BufferAssignment;

  // Adds a LogicalBuffer to the set assigned to this buffer.
  void AddAssignment(const HloValue& buffer, int64 offset, int64 size);

  void set_constant(bool is_constant) { is_constant_ = is_constant; }
  void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
  void set_index(Index index) { index_ = index; }
  void set_size(int64 size) { size_ = size; }
@@ -1412,6 +1412,7 @@ cc_library(
        ":tree_reduction_rewriter",
        ":variadic_op_splitter",
        "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
        "//tensorflow/compiler/mlir/xla:type_to_shape",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
@@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/Transforms/Utils/SplitModule.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/InitAllDialects.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
@@ -690,7 +691,11 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
    return result;
  }

  if (DumpingEnabledForHloModule(*debug_module)) {
  const bool should_dump =
      DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
                                 module_config.debug_options());

  if (should_dump) {
    if (debug_module) {
      if (shard_number.has_value()) {
        llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
@@ -713,7 +718,7 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
  }

  // Write PTX to IR dump directory, if IR dumping was requested.
  if (DumpingEnabledForHloModule(*debug_module)) {
  if (should_dump) {
    absl::string_view ptx = result->first;
    if (debug_module) {
      if (shard_number.has_value()) {
@@ -843,16 +848,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(

  llvm::LLVMContext llvm_context;

  GpuDeviceInfo gpu_device_info;
  gpu_device_info.threads_per_block_limit =
      stream_exec->GetDeviceDescription().threads_per_block_limit();
  gpu_device_info.threads_per_warp =
      stream_exec->GetDeviceDescription().threads_per_warp();
  gpu_device_info.shared_memory_per_block =
      stream_exec->GetDeviceDescription().shared_memory_per_block();
  gpu_device_info.threads_per_core_limit =
      stream_exec->GetDeviceDescription().threads_per_core_limit();
  gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
  GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);

  absl::optional<CudaComputeCapability> cuda_compute_capability =
      [&]() -> absl::optional<CudaComputeCapability> {
@@ -948,6 +944,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
  return std::unique_ptr<Executable>(gpu_executable);
}

GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
  GpuDeviceInfo gpu_device_info;
  gpu_device_info.threads_per_block_limit =
      stream_exec->GetDeviceDescription().threads_per_block_limit();
  gpu_device_info.threads_per_warp =
      stream_exec->GetDeviceDescription().threads_per_warp();
  gpu_device_info.shared_memory_per_block =
      stream_exec->GetDeviceDescription().shared_memory_per_block();
  gpu_device_info.threads_per_core_limit =
      stream_exec->GetDeviceDescription().threads_per_core_limit();
  gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
  return gpu_device_info;
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& options) {
@@ -971,5 +981,160 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
      &buffer_assignment, &thunk_schedule, nullptr));
  return llvm_module;
}

// Analyze the function signature to reconstruct a vector of BufferAllocation
// objects, as well as other output information.
//
// This function also serves as a half-baked verifier for function arg
// attributes, since a full verifier doesn't exist yet.
static Status GetMlirAllocationInfo(
    mlir::FuncOp func, std::vector<BufferAllocation>* allocations,
    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
    Shape* output_shape) {
  std::vector<absl::optional<BufferAllocation>> maybe_allocations;

  for (int i = 0; i < func.getNumArguments(); i++) {
    auto allocation_index_attr =
        func.getArgAttr(i, "lmhlo.alloc").dyn_cast_or_null<mlir::IntegerAttr>();
    TF_RET_CHECK(allocation_index_attr);
    int index = allocation_index_attr.getInt();
    if (index >= maybe_allocations.size()) {
      maybe_allocations.resize(index + 1);
    }
    mlir::BlockArgument arg = func.getArgument(i);
    TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
    size_t size = arg.getType().cast<mlir::ShapedType>().getSizeInBits() / 8;
    maybe_allocations[index].emplace(index, size, 0);
  }

  allocations->reserve(maybe_allocations.size());
  for (auto& maybe_alloc : maybe_allocations) {
    if (maybe_alloc.has_value()) {
      allocations->push_back(*maybe_alloc);
    } else {
      return InvalidArgument("Allocation indices should range in [0, n)");
    }
  }

  for (int i = 0; i < func.getNumArguments(); i++) {
    for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
      TF_RET_CHECK(attr.first == "lmhlo.alloc" ||
                   attr.first == "lmhlo.params" ||
                   attr.first == "lmhlo.output_index");
    }
  }

  std::vector<Shape> output_shapes;
  absl::optional<int> rank;
  for (int i = 0; i < func.getNumArguments(); i++) {
    auto index =
        func.getArgAttr(i, "lmhlo.alloc").cast<mlir::IntegerAttr>().getInt();
    if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
      allocations->at(index).set_entry_computation_parameter(
          param_attr.cast<mlir::IntegerAttr>().getInt(), {},
          static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
    }
    if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
      allocations->at(index).set_maybe_live_out(true);

      // Reconstruct a shape index from output_index.
      ShapeIndex shape_index;
      for (const llvm::APInt& i :
           output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
        shape_index.push_back(i.getSExtValue());
      }
      if (rank.has_value()) {
        if (*rank != shape_index.size()) {
          return InvalidArgument("Expect output_index to have the same ranks");
        }
      } else {
        rank.emplace(shape_index.size());
      }
      auto& o = (*output_info)[shape_index];
      o.allocation_index = index;
      if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
        o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
                               ShapeIndex{});
      }

      if (shape_index.size() > 1) {
        return Unimplemented("Expect array type or 1-level tuple type");
      }

      mlir::BlockArgument arg = func.getArgument(i);
      if (shape_index.empty()) {
        output_shapes.push_back(TypeToShape(arg.getType()));
      } else {
        if (shape_index[0] >= output_shapes.size()) {
          output_shapes.resize(shape_index[0] + 1);
        }
        output_shapes[shape_index[0]] = TypeToShape(arg.getType());
      }
    }
  }
  *output_shape = ShapeUtil::MakeTupleShape(output_shapes);

  return Status::OK();
}

StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
    GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
    const HloModuleConfig& module_config,
    const Compiler::CompileOptions& options,
    absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
    std::unique_ptr<llvm::Module> llvm_module,
    IrEmitterContext* ir_emitter_context) {
  mlir::FuncOp entry_function = mlir::cast<mlir::FuncOp>(module.lookupSymbol(
      llvm::StringRef(entry_function_name.data(), entry_function_name.size())));

  std::vector<BufferAllocation> allocations;
  absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
  Shape output_shape;
  absl::flat_hash_map<ShapeIndex, int> output_to_argnum_map;
  TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
                                           &output_info, &output_shape));

  CHECK(!allocations.empty());

  ir_emitter_context->set_allocations(allocations);

  TF_ASSIGN_OR_RETURN(
      auto ir_emitter,
      IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
                                ir_emitter_context));
  ThunkSequence thunk_sequence;
  for (mlir::Operation& op : entry_function.getBody().front()) {
    // Omit the terminator.
    if (&op == &entry_function.getBody().front().back()) {
      continue;
    }
    MlirEmitterInput input;
    input.op = &op;
    TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
    std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
    TF_RET_CHECK(thunks->size() <= 1);
    if (!thunks->empty()) {
      auto thunk = std::move(thunks->front());
      thunk_sequence.push_back(std::move(thunk));
    }
  }
  auto thunk_schedule = absl::make_unique<ThunkSchedule>(
      std::make_unique<ThunkSequence>(std::move(thunk_sequence)));

  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
                      compiler->CompileToTargetBinary(
                          module_config, std::move(llvm_module), stream_exec,
                          options, /*debug_module=*/nullptr));

  GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
  auto* gpu_executable = new GpuExecutable(
      {std::move(backend_result.first), std::move(backend_result.second),
       gpu_version, std::move(thunk_schedule),
       std::move(ir_emitter_context->constants()), std::move(output_info),
       module_name, output_shape, std::move(allocations)});
  return std::unique_ptr<Executable>(gpu_executable);
}

}  // namespace gpu
}  // namespace xla
@@ -20,9 +20,11 @@ limitations under the License.
#include <string>
#include <vector>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -144,6 +146,8 @@ class GpuCompiler : public LLVMCompiler {
  TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
};

GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec);

// Compile `hlo_module` using XLA GPU and return the LLVM module thus generated.
// The GpuExecutable (and the Thunks that are part of it) are not returned.
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
@@ -153,6 +157,21 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
    absl::optional<CudaComputeCapability> cuda_compute_capability,
    int pointer_size);

// Compiles the given LMHLO module to an executable.
// ir_emitter_context should be partially populated: buffer_assignment
// or buffer_allocations should not be populated, while other fields should be
// populated (or left empty if that field is optional).
//
// NOTE: buffer_assignment will be gone from ir_emitter_context once LMHLO
// transition is done.
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
    GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
    const HloModuleConfig& module_config,
    const Compiler::CompileOptions& options,
    absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
    std::unique_ptr<llvm::Module> llvm_module,
    IrEmitterContext* ir_emitter_context);

}  // namespace gpu
}  // namespace xla
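To make the calling contract above concrete, here is a minimal usage sketch; it mirrors the MlirGpuTestBase harness added later in this change, and the names `backend`, `stream_exec`, `module`, and `ir_emitter_context` are assumed to be prepared by the caller.

// Sketch only (not part of the change): compile a parsed LMHLO module with an
// entry function named "main" into a GpuExecutable.
llvm::LLVMContext llvm_context;
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
HloModuleConfig module_config;
module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags());
TF_ASSIGN_OR_RETURN(
    auto executable,
    CompileLmhloToExecutable(static_cast<GpuCompiler*>(backend->compiler()),
                             module, "TestModule", module_config,
                             Compiler::CompileOptions(), "main", stream_exec,
                             std::move(llvm_module), &ir_emitter_context));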
@@ -431,6 +431,9 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(

  for (auto& p : result.MutableResult()->buffers()) {
    const ShapeIndex& index = p.first;
    if (!output_info_.contains(index)) {
      continue;
    }
    const OutputInfo& output_info = output_info_.at(index);
    const BufferAllocation* allocation =
        &allocations_[output_info.allocation_index];
@@ -56,12 +56,12 @@ class GpuExecutable : public Executable {
  };

  struct OutputInfo {
    // Output is passed-through from a parameter.
    bool passthrough;

    // Corresponding allocation index.
    int allocation_index;

    // Output is passed-through from a parameter.
    bool passthrough = false;

    // Whether this output is hinted to alias a parameter (BufferAllocation*
    // would indicate the aliased parameter), and what kind of alias it is.
    absl::optional<HloInputOutputAliasConfig::Alias> alias_config;
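As a reference for how these fields are filled, a hedged sketch of recording one output leaf follows; the concrete values are illustrative only, and GetMlirAllocationInfo in this change does the real population.

// Illustrative values only: a tuple leaf at ShapeIndex {0} that lives in
// allocation 2 and is hinted to alias parameter 0.
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
GpuExecutable::OutputInfo& leaf = output_info[ShapeIndex{0}];
leaf.allocation_index = 2;
leaf.alias_config.emplace(/*parameter_number=*/0,
                          /*parameter_index=*/ShapeIndex{});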
@@ -56,6 +56,36 @@ cc_library(
    ],
)

cc_library(
    name = "mlir_gpu_test_base",
    testonly = True,
    srcs = ["mlir_gpu_test_base.cc"],
    hdrs = ["mlir_gpu_test_base.h"],
    deps = [
        "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
        "//tensorflow/compiler/mlir/xla:type_to_shape",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla/service:gpu_plugin",
        "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "@llvm-project//llvm:Core",
        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
    ],
)

tf_cc_test(
    name = "mlir_sorting_test",
    srcs = ["mlir_sorting_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        ":mlir_gpu_test_base",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "gemm_rewrite_test",
    srcs = [
tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc (new file, 146 lines)
@@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"

#include "llvm/IR/LLVMContext.h"
#include "mlir/InitAllDialects.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"

namespace xla {
namespace gpu {

MlirGpuTestBase::MlirGpuTestBase() {
  se::Platform* platform =
      se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
  BackendOptions options;
  options.set_platform(platform);
  backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
}

StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
    mlir::ModuleOp module, se::Stream* stream,
    absl::Span<const se::DeviceMemoryBase> arguments) {
  llvm::LLVMContext llvm_context;
  auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
  llvm_module->setTargetTriple("nvptx");

  se::StreamExecutor* stream_exec = stream->parent();
  GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);

  absl::optional<CudaComputeCapability> cuda_compute_capability =
      [&]() -> absl::optional<CudaComputeCapability> {
    CudaComputeCapability cuda_compute_capability;
    stream_exec->GetDeviceDescription().cuda_compute_capability(
        &cuda_compute_capability.cc_major, &cuda_compute_capability.cc_minor);
    if (cuda_compute_capability.cc_major == -1) {
      return absl::nullopt;
    }
    return cuda_compute_capability;
  }();

  IrEmitterContext ir_emitter_context(
      /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr,
      backend_->platform()->Name(), gpu_device_info, cuda_compute_capability,
      /*profile_index_map=*/nullptr, /*mlir_context=*/nullptr,
      llvm_module.get());

  HloModuleConfig module_config;
  module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags());
  TF_ASSIGN_OR_RETURN(
      auto executable,
      CompileLmhloToExecutable(static_cast<GpuCompiler*>(backend_->compiler()),
                               module, "TestModule", module_config,
                               Compiler::CompileOptions(), "main", stream_exec,
                               std::move(llvm_module), &ir_emitter_context));

  ExecutableRunOptions executable_run_options;
  executable_run_options.set_stream(stream);
  executable_run_options.set_allocator(backend_->memory_allocator());
  ServiceExecutableRunOptions run_options(executable_run_options);
  std::vector<ExecutionInput> execution_inputs;

  for (auto arg : arguments) {
    Shape shape =
        ShapeUtil::MakeShape(xla::U8, {static_cast<int64>(arg.size())});
    execution_inputs.emplace_back(shape);
    execution_inputs.back().SetBuffer({}, MaybeOwningDeviceMemory(arg));
  }

  TF_ASSIGN_OR_RETURN(auto output,
                      executable->ExecuteAsyncOnStream(
                          &run_options, std::move(execution_inputs),
                          /*hlo_execution_profile=*/nullptr));

  TF_CHECK_OK(stream->BlockHostUntilDone());

  return std::move(output);
}

StatusOr<std::vector<std::vector<uint8>>>
MlirGpuTestBase::RunMlirModuleWithHostBuffers(
    mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments) {
  auto* allocator = backend_->memory_allocator();
  std::vector<se::OwningDeviceMemory> owning_memory;
  owning_memory.reserve(arguments.size());
  for (auto host_buffer : arguments) {
    owning_memory.push_back(
        allocator
            ->Allocate(backend_->default_device_ordinal(), host_buffer.size())
            .ConsumeValueOrDie());
  }
  auto stream = backend_->BorrowStream(backend_->default_device_ordinal())
                    .ConsumeValueOrDie();
  std::vector<se::DeviceMemoryBase> args;
  for (int i = 0; i < owning_memory.size(); i++) {
    se::DeviceMemoryBase memory(*owning_memory[i]);
    stream->ThenMemcpy(&memory, static_cast<void*>(arguments[i].data()),
                       memory.size());
    args.push_back(memory);
  }
  TF_ASSIGN_OR_RETURN(ExecutionOutput output,
                      RunMlirModule(module, stream.get(), args));

  std::vector<std::vector<uint8>> host_outputs;
  for (const auto& result : output.Result().buffers().leaves()) {
    host_outputs.emplace_back();
    host_outputs.back().resize(result.second.size());
    stream->ThenMemcpy(static_cast<void*>(host_outputs.back().data()),
                       result.second, result.second.size());
  }
  TF_CHECK_OK(stream->BlockHostUntilDone());
  return host_outputs;
}

StatusOr<std::vector<std::vector<uint8>>>
MlirGpuTestBase::RunMlirTextWithHostBuffers(
    absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
                      mlir::StandardOpsDialect,
                      mlir::lmhlo_gpu::LmhloGpuDialect>();

  mlir::OwningModuleRef module = parseSourceString(
      llvm::StringRef(module_text.data(), module_text.size()), &context);
  CHECK(module);
  return RunMlirModuleWithHostBuffers(*module, arguments);
}

}  // namespace gpu
}  // namespace xla
tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h (new file, 59 lines)
@@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {
namespace gpu {

class MlirGpuTestBase : public HloTestBase {
 public:
  MlirGpuTestBase();

  StatusOr<std::vector<std::vector<uint8>>> RunMlirTextWithHostBuffers(
      absl::string_view module_text, std::vector<absl::Span<uint8>> arguments);

  template <typename T>
  static absl::Span<uint8> ToUint8Span(std::vector<T>* v) {
    return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
                             v->size() * sizeof(T));
  }

  template <typename T>
  static absl::Span<const T> FromUint8Span(absl::Span<const uint8> span) {
    CHECK_EQ(0, span.size() % sizeof(T));
    return absl::Span<const T>(reinterpret_cast<const T*>(span.data()),
                               span.size() / sizeof(T));
  }

 private:
  StatusOr<std::vector<std::vector<uint8>>> RunMlirModuleWithHostBuffers(
      mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments);

  StatusOr<ExecutionOutput> RunMlirModule(
      mlir::ModuleOp module, se::Stream* stream,
      absl::Span<const se::DeviceMemoryBase> arguments);

  std::unique_ptr<xla::Backend> backend_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_
tensorflow/compiler/xla/service/gpu/tests/mlir_sorting_test.cc (new file, 63 lines)
@@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace gpu {

using ::testing::ElementsAreArray;

class SortingTest : public MlirGpuTestBase {};

TEST_F(SortingTest, SimpleCase1) {
  const char* mlir_text = R"(
    func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
               %arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 1 : index},
               %arg2: memref<4xf32> {lmhlo.alloc = 2 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>},
               %arg3: memref<4xf32> {lmhlo.alloc = 3 : index, lmhlo.output_index = dense<[1]> : tensor<1xindex>},
               %arg4: memref<4xf32> {lmhlo.alloc = 4 : index, lmhlo.output_index = dense<[2]> : tensor<1xindex>},
               %arg5: memref<4xf32> {lmhlo.alloc = 5 : index, lmhlo.output_index = dense<[3]> : tensor<1xindex>}) -> () {
      "lmhlo.sort"(%arg0, %arg1, %arg2, %arg3) ( {
      ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f32>, %d: tensor<f32>):
        %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
        "mhlo.return"(%7) : (tensor<i1>) -> ()
      }) {dimension = 0 : i64, is_stable = true} : (memref<4xf32>, memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
      "lmhlo.sort"(%arg0, %arg1, %arg4, %arg5) ( {
      ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f32>, %d: tensor<f32>):
        %7 = "mhlo.compare"(%a, %b) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
        "mhlo.return"(%7) : (tensor<i1>) -> ()
      }) {dimension = 0 : i64, is_stable = true} : (memref<4xf32>, memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
      "std.return" () : () -> ()
    })";
  std::vector<float> arg0 = {3, 1, 2, 4};
  std::vector<float> arg1 = {13, 12, 14, 11};
  auto outputs = RunMlirTextWithHostBuffers(
                     mlir_text, {ToUint8Span(&arg0), ToUint8Span(&arg1)})
                     .ConsumeValueOrDie();
  ASSERT_EQ(4, outputs.size());
  EXPECT_THAT(FromUint8Span<float>(outputs[0]),
              ElementsAreArray<float>({4, 3, 2, 1}));
  EXPECT_THAT(FromUint8Span<float>(outputs[1]),
              ElementsAreArray<float>({11, 13, 14, 12}));
  EXPECT_THAT(FromUint8Span<float>(outputs[2]),
              ElementsAreArray<float>({1, 2, 3, 4}));
  EXPECT_THAT(FromUint8Span<float>(outputs[3]),
              ElementsAreArray<float>({12, 14, 13, 11}));
}

}  // namespace gpu
}  // namespace xla
@@ -81,6 +81,13 @@ ThunkSchedule::ThunkSchedule(
  }
}

ThunkSchedule::ThunkSchedule(std::unique_ptr<ThunkSequence> thunks)
    : thunks_(std::move(thunks)) {
  for (auto& thunk : *thunks_) {
    thunk_total_order_.push_back(thunk.get());
  }
}

void ThunkSchedule::RemoveRedundantDependencyEdges() {
  std::unordered_map<const Thunk*, int> thunk_to_total_order;
  for (int i = 0; i < thunk_total_order_.size(); ++i) {
@@ -55,6 +55,9 @@ class ThunkSchedule {
                std::unique_ptr<StreamAssignment> stream_assignment,
                absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo);

  // Single stream, trivial schedule in the ThunkSequence order.
  explicit ThunkSchedule(std::unique_ptr<ThunkSequence> thunks);

  // Returns the total order of executing all the thunks.
  const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }

@@ -66,10 +69,18 @@ class ThunkSchedule {
  }

  // Delegates to StreamAssignment.
  int StreamCount() const { return stream_assignment_->StreamCount(); }
  int StreamCount() const {
    if (stream_assignment_) {
      return stream_assignment_->StreamCount();
    }
    return 1;
  }
  int StreamNumberForThunk(const Thunk* thunk) const {
    if (stream_assignment_) {
      return stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(thunk));
    }
    return 0;
  }

  string ToString() const;
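The new constructor plus the null-checked StreamCount and StreamNumberForThunk accessors let a schedule be built without a StreamAssignment. A short sketch of how this change exercises it, assuming `thunk_sequence` is a ThunkSequence filled by the IR emitter:

// Sketch only: single-stream schedule over an existing ThunkSequence, as done
// by CompileLmhloToExecutable in this change.
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
    std::make_unique<ThunkSequence>(std::move(thunk_sequence)));
// With no StreamAssignment, the schedule keeps the original thunk order and
// reports a single stream.
CHECK_EQ(thunk_schedule->StreamCount(), 1);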