[XLA/GPU] Add an LMHLO -> Execution result test.

PiperOrigin-RevId: 350848311
Change-Id: Ic31b57d7a9adb2dc7b7d63e450e1d377b3deae1e
Tim Shen 2021-01-08 15:32:38 -08:00 committed by TensorFlower Gardener
parent e9b3d33ccf
commit 70502be4a5
12 changed files with 531 additions and 26 deletions

View File

@@ -140,6 +140,8 @@ class BufferAllocation {
// be live out of the entry computation.
bool maybe_live_out() const { return maybe_live_out_; }
void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
// Returns the size of the allocation. Necessarily this must be at least as
// large as any LogicalBuffer assigned to this allocation.
int64 size() const { return size_; }
@@ -272,14 +274,6 @@ class BufferAllocation {
return index() < other.index();
}
private:
// Only BufferAssigner and BufferAssignment can modify BufferAllocation.
friend class BufferAssigner;
friend class BufferAssignment;
// Adds a LogicalBuffer to the set assigned to this buffer.
void AddAssignment(const HloValue& buffer, int64 offset, int64 size);
void set_entry_computation_parameter(int64 parameter_number,
ShapeIndex param_shape_index,
bool parameter_aliased_with_output) {
@@ -289,8 +283,15 @@ class BufferAllocation {
param_shape_index_ = std::move(param_shape_index);
}
private:
// Only BufferAssigner and BufferAssignment can modify BufferAllocation.
friend class BufferAssigner;
friend class BufferAssignment;
// Adds a LogicalBuffer to the set assigned to this buffer.
void AddAssignment(const HloValue& buffer, int64 offset, int64 size);
void set_constant(bool is_constant) { is_constant_ = is_constant; }
void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
void set_index(Index index) { index_ = index; }
void set_size(int64 size) { size_ = size; }
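The net effect of these two hunks is that set_entry_computation_parameter and set_maybe_live_out move out of the private section, so the MLIR-based allocation reconstruction added later in this commit can call them without being a friend of BufferAllocation. A minimal sketch of that usage, mirroring GetMlirAllocationInfo in gpu_compiler.cc below; the index, size, and parameter number are placeholders:
// Sketch only: rebuild an allocation for entry parameter 0 outside of
// BufferAssigner. The constructor arguments are (index, size, color), as in
// the new GetMlirAllocationInfo code.
BufferAllocation alloc(/*index=*/0, /*size=*/16, /*color=*/0);
alloc.set_entry_computation_parameter(/*parameter_number=*/0,
                                      /*param_shape_index=*/{},
                                      /*parameter_aliased_with_output=*/false);
alloc.set_maybe_live_out(true);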

View File

@@ -1412,6 +1412,7 @@ cc_library(
":tree_reduction_rewriter",
":variadic_op_splitter",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",

View File

@@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/Transforms/Utils/SplitModule.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
@@ -690,7 +691,11 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
return result;
}
if (DumpingEnabledForHloModule(*debug_module)) {
const bool should_dump =
DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
module_config.debug_options());
if (should_dump) {
if (debug_module) {
if (shard_number.has_value()) {
llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
@@ -713,7 +718,7 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
}
// Write PTX to IR dump directory, if IR dumping was requested.
if (DumpingEnabledForHloModule(*debug_module)) {
if (should_dump) {
absl::string_view ptx = result->first;
if (debug_module) {
if (shard_number.has_value()) {
@@ -843,16 +848,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
llvm::LLVMContext llvm_context;
GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit =
stream_exec->GetDeviceDescription().threads_per_block_limit();
gpu_device_info.threads_per_warp =
stream_exec->GetDeviceDescription().threads_per_warp();
gpu_device_info.shared_memory_per_block =
stream_exec->GetDeviceDescription().shared_memory_per_block();
gpu_device_info.threads_per_core_limit =
stream_exec->GetDeviceDescription().threads_per_core_limit();
gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
absl::optional<CudaComputeCapability> cuda_compute_capability =
[&]() -> absl::optional<CudaComputeCapability> {
@@ -948,6 +944,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
return std::unique_ptr<Executable>(gpu_executable);
}
GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit =
stream_exec->GetDeviceDescription().threads_per_block_limit();
gpu_device_info.threads_per_warp =
stream_exec->GetDeviceDescription().threads_per_warp();
gpu_device_info.shared_memory_per_block =
stream_exec->GetDeviceDescription().shared_memory_per_block();
gpu_device_info.threads_per_core_limit =
stream_exec->GetDeviceDescription().threads_per_core_limit();
gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
return gpu_device_info;
}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& options) {
@@ -971,5 +981,160 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
&buffer_assignment, &thunk_schedule, nullptr));
return llvm_module;
}
// Analyze the function signature to reconstruct a vector of BufferAllocation
// objects, as well as other output information.
//
// This function also serves as a half-baked verifier for function arg
// attributes, since a full verifier doesn't exist yet.
static Status GetMlirAllocationInfo(
mlir::FuncOp func, std::vector<BufferAllocation>* allocations,
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
Shape* output_shape) {
std::vector<absl::optional<BufferAllocation>> maybe_allocations;
for (int i = 0; i < func.getNumArguments(); i++) {
auto allocation_index_attr =
func.getArgAttr(i, "lmhlo.alloc").dyn_cast_or_null<mlir::IntegerAttr>();
TF_RET_CHECK(allocation_index_attr);
int index = allocation_index_attr.getInt();
if (index >= maybe_allocations.size()) {
maybe_allocations.resize(index + 1);
}
mlir::BlockArgument arg = func.getArgument(i);
TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
size_t size = arg.getType().cast<mlir::ShapedType>().getSizeInBits() / 8;
maybe_allocations[index].emplace(index, size, 0);
}
allocations->reserve(maybe_allocations.size());
for (auto& maybe_alloc : maybe_allocations) {
if (maybe_alloc.has_value()) {
allocations->push_back(*maybe_alloc);
} else {
return InvalidArgument("Allocation indices should range in [0, n)");
}
}
for (int i = 0; i < func.getNumArguments(); i++) {
for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
TF_RET_CHECK(attr.first == "lmhlo.alloc" ||
attr.first == "lmhlo.params" ||
attr.first == "lmhlo.output_index");
}
}
std::vector<Shape> output_shapes;
absl::optional<int> rank;
for (int i = 0; i < func.getNumArguments(); i++) {
auto index =
func.getArgAttr(i, "lmhlo.alloc").cast<mlir::IntegerAttr>().getInt();
if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
allocations->at(index).set_entry_computation_parameter(
param_attr.cast<mlir::IntegerAttr>().getInt(), {},
static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
}
if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
allocations->at(index).set_maybe_live_out(true);
// Reconstruct a shape index from output_index.
ShapeIndex shape_index;
for (const llvm::APInt& i :
output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
shape_index.push_back(i.getSExtValue());
}
if (rank.has_value()) {
if (*rank != shape_index.size()) {
return InvalidArgument("Expect output_index to have the same ranks");
}
} else {
rank.emplace(shape_index.size());
}
auto& o = (*output_info)[shape_index];
o.allocation_index = index;
if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
ShapeIndex{});
}
if (shape_index.size() > 1) {
return Unimplemented("Expect array type or 1-level tuple type");
}
mlir::BlockArgument arg = func.getArgument(i);
if (shape_index.empty()) {
output_shapes.push_back(TypeToShape(arg.getType()));
} else {
if (shape_index[0] >= output_shapes.size()) {
output_shapes.resize(shape_index[0] + 1);
}
output_shapes[shape_index[0]] = TypeToShape(arg.getType());
}
}
}
*output_shape = ShapeUtil::MakeTupleShape(output_shapes);
return Status::OK();
}
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
const HloModuleConfig& module_config,
const Compiler::CompileOptions& options,
absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
std::unique_ptr<llvm::Module> llvm_module,
IrEmitterContext* ir_emitter_context) {
mlir::FuncOp entry_function = mlir::cast<mlir::FuncOp>(module.lookupSymbol(
llvm::StringRef(entry_function_name.data(), entry_function_name.size())));
std::vector<BufferAllocation> allocations;
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
Shape output_shape;
absl::flat_hash_map<ShapeIndex, int> output_to_argnum_map;
TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
&output_info, &output_shape));
CHECK(!allocations.empty());
ir_emitter_context->set_allocations(allocations);
TF_ASSIGN_OR_RETURN(
auto ir_emitter,
IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
ir_emitter_context));
ThunkSequence thunk_sequence;
for (mlir::Operation& op : entry_function.getBody().front()) {
// Omit the terminator.
if (&op == &entry_function.getBody().front().back()) {
continue;
}
MlirEmitterInput input;
input.op = &op;
TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
TF_RET_CHECK(thunks->size() <= 1);
if (!thunks->empty()) {
auto thunk = std::move(thunks->front());
thunk_sequence.push_back(std::move(thunk));
}
}
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
std::make_unique<ThunkSequence>(std::move(thunk_sequence)));
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
compiler->CompileToTargetBinary(
module_config, std::move(llvm_module), stream_exec,
options, /*debug_module=*/nullptr));
GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
auto* gpu_executable = new GpuExecutable(
{std::move(backend_result.first), std::move(backend_result.second),
gpu_version, std::move(thunk_schedule),
std::move(ir_emitter_context->constants()), std::move(output_info),
module_name, output_shape, std::move(allocations)});
return std::unique_ptr<Executable>(gpu_executable);
}
} // namespace gpu
} // namespace xla
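For reference, the argument convention that GetMlirAllocationInfo above parses and half-verifies is the one exercised by the sorting test added later in this commit: every entry-function argument names its allocation slot via lmhlo.alloc, and optionally the entry parameter it is fed from (lmhlo.params) and/or the output tuple index it is returned at (lmhlo.output_index). A minimal hypothetical signature, written as a C++ raw string in the same style as the test; shapes and indices are illustrative:
// Sketch only: an entry function with one parameter and one output. A real
// body would contain lmhlo ops writing into %arg1.
const char* kEntrySignature = R"(
func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
           %arg1: memref<4xf32> {lmhlo.alloc = 1 : index,
                                 lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
  "std.return" () : () -> ()
})";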

View File

@@ -20,9 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -144,6 +146,8 @@ class GpuCompiler : public LLVMCompiler {
TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
};
GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec);
// Compile `hlo_module` using XLA GPU and return the LLVM module thus generated.
// The GpuExecutable (and the Thunks that are part of it) are not returned.
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
@@ -153,6 +157,21 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
absl::optional<CudaComputeCapability> cuda_compute_capability,
int pointer_size);
// Compiles the given LMHLO module to an executable.
// `ir_emitter_context` should be partially populated: neither
// buffer_assignment nor buffer_allocations should be set, while all other
// fields should be populated (or left empty if optional).
//
// NOTE: buffer_assignment will be removed from ir_emitter_context once the
// LMHLO transition is done.
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
const HloModuleConfig& module_config,
const Compiler::CompileOptions& options,
absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
std::unique_ptr<llvm::Module> llvm_module,
IrEmitterContext* ir_emitter_context);
} // namespace gpu
} // namespace xla
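A condensed usage sketch of CompileLmhloToExecutable, assuming a parsed LMHLO module, a StreamExecutor, and an IrEmitterContext prepared as the comment above describes. It mirrors MlirGpuTestBase::RunMlirModule added later in this commit; gpu_compiler, lmhlo_module, platform_name, gpu_device_info, cuda_compute_capability, module_config, and stream_exec are illustrative names, and the call assumes an enclosing function that returns a Status:
// Sketch only: compile an LMHLO module to a GpuExecutable.
llvm::LLVMContext llvm_context;
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
IrEmitterContext ir_emitter_context(
    /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr, platform_name,
    gpu_device_info, cuda_compute_capability, /*profile_index_map=*/nullptr,
    /*mlir_context=*/nullptr, llvm_module.get());
TF_ASSIGN_OR_RETURN(
    std::unique_ptr<Executable> executable,
    CompileLmhloToExecutable(gpu_compiler, lmhlo_module, "SketchModule",
                             module_config, Compiler::CompileOptions(), "main",
                             stream_exec, std::move(llvm_module),
                             &ir_emitter_context));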

View File

@@ -431,6 +431,9 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
for (auto& p : result.MutableResult()->buffers()) {
const ShapeIndex& index = p.first;
if (!output_info_.contains(index)) {
continue;
}
const OutputInfo& output_info = output_info_.at(index);
const BufferAllocation* allocation =
&allocations_[output_info.allocation_index];

View File

@@ -56,12 +56,12 @@ class GpuExecutable : public Executable {
};
struct OutputInfo {
// Output is passed-through from a parameter.
bool passthrough;
// Corresponding allocation index.
int allocation_index;
// Output is passed-through from a parameter.
bool passthrough = false;
// Whether this output is hinted to alias a parameter (BufferAllocation*
// would indicate the aliased parameter), and what kind of alias it is.
absl::optional<HloInputOutputAliasConfig::Alias> alias_config;
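To make the fields concrete: for a hypothetical result tuple whose element 0 lives in the argument annotated with lmhlo.alloc = 2 and is hinted to alias entry parameter 0, GetMlirAllocationInfo (earlier in this commit) fills the map roughly as follows; the indices are illustrative:
// Sketch only: one OutputInfo entry for output tuple index {0}.
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
ShapeIndex index;
index.push_back(0);
GpuExecutable::OutputInfo& info = output_info[index];
info.allocation_index = 2;  // argument with lmhlo.alloc = 2
info.alias_config.emplace(/*parameter_number=*/0,
                          /*parameter_index=*/ShapeIndex{});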

View File

@@ -56,6 +56,36 @@ cc_library(
],
)
cc_library(
name = "mlir_gpu_test_base",
testonly = True,
srcs = ["mlir_gpu_test_base.cc"],
hdrs = ["mlir_gpu_test_base.h"],
deps = [
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service/gpu:gpu_compiler",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
],
)
tf_cc_test(
name = "mlir_sorting_test",
srcs = ["mlir_sorting_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":mlir_gpu_test_base",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "gemm_rewrite_test",
srcs = [

View File

@@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
#include "llvm/IR/LLVMContext.h"
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
namespace xla {
namespace gpu {
MlirGpuTestBase::MlirGpuTestBase() {
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
BackendOptions options;
options.set_platform(platform);
backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
}
StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
mlir::ModuleOp module, se::Stream* stream,
absl::Span<const se::DeviceMemoryBase> arguments) {
llvm::LLVMContext llvm_context;
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
llvm_module->setTargetTriple("nvptx");
se::StreamExecutor* stream_exec = stream->parent();
GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
absl::optional<CudaComputeCapability> cuda_compute_capability =
[&]() -> absl::optional<CudaComputeCapability> {
CudaComputeCapability cuda_compute_capability;
stream_exec->GetDeviceDescription().cuda_compute_capability(
&cuda_compute_capability.cc_major, &cuda_compute_capability.cc_minor);
if (cuda_compute_capability.cc_major == -1) {
return absl::nullopt;
}
return cuda_compute_capability;
}();
IrEmitterContext ir_emitter_context(
/*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr,
backend_->platform()->Name(), gpu_device_info, cuda_compute_capability,
/*profile_index_map=*/nullptr, /*mlir_context=*/nullptr,
llvm_module.get());
HloModuleConfig module_config;
module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags());
TF_ASSIGN_OR_RETURN(
auto executable,
CompileLmhloToExecutable(static_cast<GpuCompiler*>(backend_->compiler()),
module, "TestModule", module_config,
Compiler::CompileOptions(), "main", stream_exec,
std::move(llvm_module), &ir_emitter_context));
ExecutableRunOptions executable_run_options;
executable_run_options.set_stream(stream);
executable_run_options.set_allocator(backend_->memory_allocator());
ServiceExecutableRunOptions run_options(executable_run_options);
std::vector<ExecutionInput> execution_inputs;
for (auto arg : arguments) {
Shape shape =
ShapeUtil::MakeShape(xla::U8, {static_cast<int64>(arg.size())});
execution_inputs.emplace_back(shape);
execution_inputs.back().SetBuffer({}, MaybeOwningDeviceMemory(arg));
}
TF_ASSIGN_OR_RETURN(auto output,
executable->ExecuteAsyncOnStream(
&run_options, std::move(execution_inputs),
/*hlo_execution_profile=*/nullptr));
TF_CHECK_OK(stream->BlockHostUntilDone());
return std::move(output);
}
StatusOr<std::vector<std::vector<uint8>>>
MlirGpuTestBase::RunMlirModuleWithHostBuffers(
mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments) {
auto* allocator = backend_->memory_allocator();
std::vector<se::OwningDeviceMemory> owning_memory;
owning_memory.reserve(arguments.size());
for (auto host_buffer : arguments) {
owning_memory.push_back(
allocator
->Allocate(backend_->default_device_ordinal(), host_buffer.size())
.ConsumeValueOrDie());
}
auto stream = backend_->BorrowStream(backend_->default_device_ordinal())
.ConsumeValueOrDie();
std::vector<se::DeviceMemoryBase> args;
for (int i = 0; i < owning_memory.size(); i++) {
se::DeviceMemoryBase memory(*owning_memory[i]);
stream->ThenMemcpy(&memory, static_cast<void*>(arguments[i].data()),
memory.size());
args.push_back(memory);
}
TF_ASSIGN_OR_RETURN(ExecutionOutput output,
RunMlirModule(module, stream.get(), args));
std::vector<std::vector<uint8>> host_outputs;
for (const auto& result : output.Result().buffers().leaves()) {
host_outputs.emplace_back();
host_outputs.back().resize(result.second.size());
stream->ThenMemcpy(static_cast<void*>(host_outputs.back().data()),
result.second, result.second.size());
}
TF_CHECK_OK(stream->BlockHostUntilDone());
return host_outputs;
}
StatusOr<std::vector<std::vector<uint8>>>
MlirGpuTestBase::RunMlirTextWithHostBuffers(
absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
mlir::MLIRContext context;
context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
mlir::StandardOpsDialect,
mlir::lmhlo_gpu::LmhloGpuDialect>();
mlir::OwningModuleRef module = parseSourceString(
llvm::StringRef(module_text.data(), module_text.size()), &context);
CHECK(module);
return RunMlirModuleWithHostBuffers(*module, arguments);
}
} // namespace gpu
} // namespace xla

View File

@@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace gpu {
class MlirGpuTestBase : public HloTestBase {
public:
MlirGpuTestBase();
StatusOr<std::vector<std::vector<uint8>>> RunMlirTextWithHostBuffers(
absl::string_view module_text, std::vector<absl::Span<uint8>> arguments);
template <typename T>
static absl::Span<uint8> ToUint8Span(std::vector<T>* v) {
return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
v->size() * sizeof(T));
}
template <typename T>
static absl::Span<const T> FromUint8Span(absl::Span<const uint8> span) {
CHECK_EQ(0, span.size() % sizeof(T));
return absl::Span<const T>(reinterpret_cast<const T*>(span.data()),
span.size() / sizeof(T));
}
private:
StatusOr<std::vector<std::vector<uint8>>> RunMlirModuleWithHostBuffers(
mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments);
StatusOr<ExecutionOutput> RunMlirModule(
mlir::ModuleOp module, se::Stream* stream,
absl::Span<const se::DeviceMemoryBase> arguments);
std::unique_ptr<xla::Backend> backend_;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_MLIR_GPU_TEST_BASE_H_
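The two span helpers are meant to be used in pairs around RunMlirTextWithHostBuffers, as the sorting test below does. A minimal round-trip sketch, with mlir_text and the values being illustrative:
// Sketch only: feed a float vector in as raw bytes and read the result back.
std::vector<float> input = {3, 1, 2, 4};
auto outputs = RunMlirTextWithHostBuffers(mlir_text, {ToUint8Span(&input)})
                   .ConsumeValueOrDie();
absl::Span<const float> result = FromUint8Span<float>(outputs[0]);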

View File

@@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace gpu {
using ::testing::ElementsAreArray;
class SortingTest : public MlirGpuTestBase {};
TEST_F(SortingTest, SimpleCase1) {
const char* mlir_text = R"(
func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
%arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 1 : index},
%arg2: memref<4xf32> {lmhlo.alloc = 2 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>},
%arg3: memref<4xf32> {lmhlo.alloc = 3 : index, lmhlo.output_index = dense<[1]> : tensor<1xindex>},
%arg4: memref<4xf32> {lmhlo.alloc = 4 : index, lmhlo.output_index = dense<[2]> : tensor<1xindex>},
%arg5: memref<4xf32> {lmhlo.alloc = 5 : index, lmhlo.output_index = dense<[3]> : tensor<1xindex>}) -> () {
"lmhlo.sort"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f32>, %d: tensor<f32>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 0 : i64, is_stable = true} : (memref<4xf32>, memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
"lmhlo.sort"(%arg0, %arg1, %arg4, %arg5) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f32>, %d: tensor<f32>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 0 : i64, is_stable = true} : (memref<4xf32>, memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
"std.return" () : () -> ()
})";
std::vector<float> arg0 = {3, 1, 2, 4};
std::vector<float> arg1 = {13, 12, 14, 11};
auto outputs = RunMlirTextWithHostBuffers(
mlir_text, {ToUint8Span(&arg0), ToUint8Span(&arg1)})
.ConsumeValueOrDie();
ASSERT_EQ(4, outputs.size());
EXPECT_THAT(FromUint8Span<float>(outputs[0]),
ElementsAreArray<float>({4, 3, 2, 1}));
EXPECT_THAT(FromUint8Span<float>(outputs[1]),
ElementsAreArray<float>({11, 13, 14, 12}));
EXPECT_THAT(FromUint8Span<float>(outputs[2]),
ElementsAreArray<float>({1, 2, 3, 4}));
EXPECT_THAT(FromUint8Span<float>(outputs[3]),
ElementsAreArray<float>({12, 14, 13, 11}));
}
} // namespace gpu
} // namespace xla

View File

@@ -81,6 +81,13 @@ ThunkSchedule::ThunkSchedule(
}
}
ThunkSchedule::ThunkSchedule(std::unique_ptr<ThunkSequence> thunks)
: thunks_(std::move(thunks)) {
for (auto& thunk : *thunks_) {
thunk_total_order_.push_back(thunk.get());
}
}
void ThunkSchedule::RemoveRedundantDependencyEdges() {
std::unordered_map<const Thunk*, int> thunk_to_total_order;
for (int i = 0; i < thunk_total_order_.size(); ++i) {

View File

@@ -55,6 +55,9 @@ class ThunkSchedule {
std::unique_ptr<StreamAssignment> stream_assignment,
absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo);
// Single stream, trivial schedule in the ThunkSequence order.
explicit ThunkSchedule(std::unique_ptr<ThunkSequence> thunks);
// Returns the total order of executing all the thunks.
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
@@ -66,10 +69,18 @@ class ThunkSchedule {
}
// Delegates to StreamAssignment.
int StreamCount() const { return stream_assignment_->StreamCount(); }
int StreamCount() const {
if (stream_assignment_) {
return stream_assignment_->StreamCount();
}
return 1;
}
int StreamNumberForThunk(const Thunk* thunk) const {
if (stream_assignment_) {
return stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(thunk));
}
return 0;
}
string ToString() const;
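A short sketch of the new trivial-schedule path, as used by CompileLmhloToExecutable earlier in this commit: without a StreamAssignment, every thunk reports stream 0 and StreamCount() is 1. The thunk_sequence variable is illustrative:
// Sketch only: single-stream schedule built directly from a ThunkSequence.
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
    std::make_unique<ThunkSequence>(std::move(thunk_sequence)));
CHECK_EQ(1, thunk_schedule->StreamCount());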