Support alternative Evaluators in the XLA Interpreter Executable
PiperOrigin-RevId: 301206533 Change-Id: I8aee5751d2714f1c88bbff0d883a8482e63dd52e
This commit is contained in:
parent
09e0b6cea4
commit
a749d9dfa8
@ -71,11 +71,36 @@ cc_library(
|
||||
),
|
||||
)
|
||||
|
||||
# Library holding InterpreterExecutableBase (executable_base.h / executable_base.cc).
# The ":executable" target in this package lists it as a dependency.
# NOTE(review): executable_base.h uses absl::Span in the Evaluate() signature;
# confirm whether "@com_google_absl//absl/types:span" should be listed in deps.
cc_library(
|
||||
name = "executable_base",
|
||||
srcs = ["executable_base.cc"],
|
||||
hdrs = ["executable_base.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:dynamic_dimension_inference",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "executable",
|
||||
srcs = ["executable.cc"],
|
||||
hdrs = ["executable.h"],
|
||||
deps = [
|
||||
":executable_base",
|
||||
":executor",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
||||
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable_base.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
@ -41,8 +42,7 @@ InterpreterExecutable::InterpreterExecutable(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloEvaluator> evaluator,
|
||||
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
|
||||
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
||||
/*hlo_profile_index_map=*/nullptr),
|
||||
: InterpreterExecutableBase(std::move(hlo_module)),
|
||||
evaluator_(std::move(evaluator)),
|
||||
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
|
||||
if (dynamic_dimension_inference_.has_value()) {
|
||||
@ -51,107 +51,12 @@ InterpreterExecutable::InterpreterExecutable(
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing to release by hand; the members clean up after themselves.
InterpreterExecutable::~InterpreterExecutable() = default;
|
||||
|
||||
StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
se::Stream* stream = run_options->stream();
|
||||
se::StreamExecutor* executor = stream->parent();
|
||||
const se::Platform* platform = executor->platform();
|
||||
|
||||
// Convert the ShapeTree to a ShapedBuffer. We do this so we can call
|
||||
// TransferManager methods below.
|
||||
std::vector<ShapedBuffer> argument_buffers;
|
||||
argument_buffers.reserve(arguments.size());
|
||||
for (auto& argument : arguments) {
|
||||
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
|
||||
argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
|
||||
/*platform=*/nullptr,
|
||||
/*device_ordinal=*/0));
|
||||
auto in_it = buffers.begin();
|
||||
auto out_it = argument_buffers.back().buffers().begin();
|
||||
for (; in_it != buffers.end(); ++in_it, ++out_it) {
|
||||
out_it->second = in_it->second.AsDeviceMemoryBase();
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "Execute " << module().name();
|
||||
if (VLOG_IS_ON(2)) {
|
||||
for (const auto& a : argument_buffers) {
|
||||
VLOG(2) << "-- argument " << a;
|
||||
}
|
||||
}
|
||||
|
||||
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
|
||||
|
||||
const HloComputation* computation = module().entry_computation();
|
||||
if (computation->num_parameters() != arguments.size()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Mismatch between argument count and graph parameter count.");
|
||||
}
|
||||
|
||||
// Check that the args have the right shape.
|
||||
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
||||
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
||||
const auto& actual_shape = argument_buffers[i].on_device_shape();
|
||||
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
||||
actual_shape)) {
|
||||
return InvalidArgument(
|
||||
"Shape mismatch on parameter %d. Expected %s, but was %s.", i,
|
||||
ShapeUtil::HumanStringWithLayout(expected_shape),
|
||||
ShapeUtil::HumanStringWithLayout(actual_shape));
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
|
||||
TransferManager::GetForPlatform(platform));
|
||||
|
||||
// Transform the ShapedBuffer arguments into literals which the evaluator
|
||||
// consumes.
|
||||
std::vector<Literal> arg_literals;
|
||||
for (int64 p = 0; p < computation->num_parameters(); ++p) {
|
||||
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
||||
transfer_manager->TransferLiteralFromDevice(
|
||||
run_options->stream(), argument_buffers[p]));
|
||||
arg_literals.push_back(std::move(arg_literal));
|
||||
}
|
||||
|
||||
StatusOr<Literal> InterpreterExecutable::Evaluate(
|
||||
const HloComputation& computation, absl::Span<const Literal> arg_literals) {
|
||||
// Execute the graph using the HloEvaluator.
|
||||
Literal result_literal;
|
||||
{
|
||||
tensorflow::mutex_lock lock(evaluator_lock_);
|
||||
evaluator_->ResetVisitStates();
|
||||
TF_ASSIGN_OR_RETURN(result_literal,
|
||||
evaluator_->Evaluate(*computation, arg_literals));
|
||||
}
|
||||
|
||||
// Transform the result literal back into a ShapedBuffer.
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
result_literal.shape(), run_options->allocator(),
|
||||
executor->device_ordinal()));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
|
||||
run_options->stream(), result_literal, result_buffers));
|
||||
ExecutionOutput result(std::move(result_buffers));
|
||||
|
||||
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
|
||||
|
||||
ExecutionProfile* profile = run_options->run_options().execution_profile();
|
||||
if (profile) {
|
||||
const double nanoseconds = (end_micros - start_micros) * 1000.0;
|
||||
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
|
||||
}
|
||||
for (auto& argument : arguments) {
|
||||
for (auto& index_buffer : *argument.MutableBuffers()) {
|
||||
auto maybe_owning_buffer = index_buffer.second.Release();
|
||||
if (maybe_owning_buffer) {
|
||||
result.AddToBeReleased(std::move(*maybe_owning_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::move(result);
|
||||
tensorflow::mutex_lock lock(evaluator_lock_);
|
||||
evaluator_->ResetVisitStates();
|
||||
return evaluator_->Evaluate(computation, arg_literals);
|
||||
}
|
||||
|
||||
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||
|
||||
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable_base.h"
|
||||
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -40,23 +41,20 @@ namespace interpreter {
|
||||
|
||||
// Responsible for running an HLO graph through the HloEvaluator and output
|
||||
// buffer allocation. Refer to interpreter/README.md for more.
|
||||
class InterpreterExecutable : public Executable {
|
||||
class InterpreterExecutable : public InterpreterExecutableBase {
|
||||
public:
|
||||
InterpreterExecutable(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloEvaluator> evaluator,
|
||||
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
|
||||
~InterpreterExecutable() override;
|
||||
|
||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override
|
||||
TF_LOCKS_EXCLUDED(evaluator_lock_);
|
||||
|
||||
static int64 ShapeSizeBytes(const Shape& shape);
|
||||
|
||||
protected:
|
||||
StatusOr<Literal> Evaluate(const HloComputation& computation,
|
||||
absl::Span<const Literal> arg_literals) override
|
||||
TF_LOCKS_EXCLUDED(evaluator_lock_);
|
||||
|
||||
// The interpreter interprets executables with an HloEvaluator.
|
||||
std::unique_ptr<HloEvaluator> evaluator_ TF_PT_GUARDED_BY(evaluator_lock_);
|
||||
mutable tensorflow::mutex evaluator_lock_;
|
||||
|
||||
137
tensorflow/compiler/xla/service/interpreter/executable_base.cc
Normal file
137
tensorflow/compiler/xla/service/interpreter/executable_base.cc
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable_base.h"
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
|
||||
|
||||
namespace xla {
|
||||
namespace interpreter {
|
||||
|
||||
// Constructs the base executable by forwarding ownership of `hlo_module` to
// Executable. No HLO profile printer data and no profile index map are
// supplied: both are passed as nullptr.
InterpreterExecutableBase::InterpreterExecutableBase(
|
||||
std::unique_ptr<HloModule> hlo_module)
|
||||
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
||||
/*hlo_profile_index_map=*/nullptr) {}
|
||||
|
||||
// Runs the module's entry computation.
//
// Steps: wrap every argument's buffers in a ShapedBuffer, verify the argument
// count and shapes against the entry computation's parameters, transfer the
// arguments into literals, call the subclass-provided Evaluate(), and transfer
// the resulting literal into a freshly allocated result buffer.
//
// Args:
//   run_options: supplies the stream, the allocator and the (optional)
//     execution profile used for this run.
//   arguments: one ExecutionInput per entry-computation parameter. Buffers
//     released by the inputs are attached to the returned ExecutionOutput.
//   hlo_execution_profile: not used by this implementation.
//
// Returns the ExecutionOutput holding the result buffers; an Internal error on
// an argument-count mismatch; an InvalidArgument error on a shape mismatch; or
// the error reported by a transfer, an allocation or Evaluate().
StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  se::Stream* stream = run_options->stream();
  se::StreamExecutor* executor = stream->parent();
  const se::Platform* platform = executor->platform();

  // Convert the ShapeTree to a ShapedBuffer. We do this so we can call
  // TransferManager methods below.
  std::vector<ShapedBuffer> argument_buffers;
  argument_buffers.reserve(arguments.size());
  for (auto& argument : arguments) {
    const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
    argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
                                            /*platform=*/nullptr,
                                            /*device_ordinal=*/0));
    // Both trees were built from the same shape, so they are walked in
    // lockstep; only the source iterator is checked against its end.
    auto in_it = buffers.begin();
    auto out_it = argument_buffers.back().buffers().begin();
    for (; in_it != buffers.end(); ++in_it, ++out_it) {
      out_it->second = in_it->second.AsDeviceMemoryBase();
    }
  }

  VLOG(1) << "Execute " << module().name();
  if (VLOG_IS_ON(2)) {
    for (const auto& a : argument_buffers) {
      VLOG(2) << "-- argument " << a;
    }
  }

  const uint64 start_micros = tensorflow::Env::Default()->NowMicros();

  const HloComputation* computation = module().entry_computation();
  // Query the parameter count once and compare it in a single signed type
  // instead of mixing a signed count with an unsigned container size.
  const int64 num_parameters = computation->num_parameters();
  if (num_parameters != static_cast<int64>(arguments.size())) {
    return tensorflow::errors::Internal(
        "Mismatch between argument count and graph parameter count.");
  }

  // Check that the args have the right shape.
  for (int64 i = 0; i < num_parameters; ++i) {
    const auto& expected_shape = computation->parameter_instruction(i)->shape();
    const auto& actual_shape = argument_buffers[i].on_device_shape();
    if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
                                                   actual_shape)) {
      return InvalidArgument(
          "Shape mismatch on parameter %d. Expected %s, but was %s.", i,
          ShapeUtil::HumanStringWithLayout(expected_shape),
          ShapeUtil::HumanStringWithLayout(actual_shape));
    }
  }

  TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
                      TransferManager::GetForPlatform(platform));

  // Transform the ShapedBuffer arguments into literals which the evaluator
  // consumes. The final size is known (one literal per argument), so reserve
  // it up front to avoid reallocations while filling the vector.
  std::vector<Literal> arg_literals;
  arg_literals.reserve(arguments.size());
  for (int64 p = 0; p < num_parameters; ++p) {
    TF_ASSIGN_OR_RETURN(Literal arg_literal,
                        transfer_manager->TransferLiteralFromDevice(
                            stream, argument_buffers[p]));
    arg_literals.push_back(std::move(arg_literal));
  }

  // The evaluation itself is delegated to the subclass.
  TF_ASSIGN_OR_RETURN(Literal result_literal,
                      Evaluate(*computation, arg_literals));

  // Transform the result literal back into a ShapedBuffer.
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
                      transfer_manager->AllocateScopedShapedBuffer(
                          result_literal.shape(), run_options->allocator(),
                          executor->device_ordinal()));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream, result_literal, result_buffers));
  ExecutionOutput result(std::move(result_buffers));

  const uint64 end_micros = tensorflow::Env::Default()->NowMicros();

  ExecutionProfile* profile = run_options->run_options().execution_profile();
  if (profile) {
    // Report at least 1ns so the recorded compute time is never zero.
    const double nanoseconds = (end_micros - start_micros) * 1000.0;
    profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
  }

  // Hand every buffer the arguments release over to the result.
  for (auto& argument : arguments) {
    for (auto& index_buffer : *argument.MutableBuffers()) {
      auto maybe_owning_buffer = index_buffer.second.Release();
      if (maybe_owning_buffer) {
        result.AddToBeReleased(std::move(*maybe_owning_buffer));
      }
    }
  }
  // NOTE(review): std::move is kept from the original; presumably it is needed
  // for the conversion of `result` into the StatusOr return type - confirm
  // before simplifying.
  return std::move(result);
}
|
||||
|
||||
} // namespace interpreter
|
||||
} // namespace xla
|
||||
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
namespace xla {
|
||||
namespace interpreter {
|
||||
|
||||
// Common base class for executables in the xla::interpreter namespace. It
|
||||
// implements ExecuteAsyncOnStream() (argument checking, transfers to and from
// literals, result buffer allocation) and leaves the evaluation of the entry
// computation to the pure-virtual Evaluate() hook, so that subclasses can plug
// in different evaluators. Refer to interpreter/README.md for more.
|
||||
class InterpreterExecutableBase : public Executable {
|
||||
public:
|
||||
// Takes ownership of `hlo_module` and passes it on to Executable.
explicit InterpreterExecutableBase(std::unique_ptr<HloModule> hlo_module);
|
||||
|
||||
// Checks `arguments` against the parameters of the module's entry
// computation, transfers them into literals, calls Evaluate() and transfers
// the resulting literal into a newly allocated result buffer.
// `hlo_execution_profile` is not used by the implementation.
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override;
|
||||
|
||||
protected:
|
||||
// Evaluation hook implemented by subclasses: evaluates `computation` with
// `arg_literals` as its arguments and returns the result literal or an error.
// ExecuteAsyncOnStream() calls it with the module's entry computation.
virtual StatusOr<Literal> Evaluate(
|
||||
const HloComputation& computation,
|
||||
absl::Span<const Literal> arg_literals) = 0;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutableBase);
|
||||
};
|
||||
|
||||
} // namespace interpreter
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_
|
||||
@ -130,19 +130,19 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
|
||||
std::function<port::Status()> callback) override;
|
||||
|
||||
port::Status AllocateEvent(Event *event) override {
|
||||
return port::Status{port::error::UNIMPLEMENTED, ""};
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::Status DeallocateEvent(Event *event) override {
|
||||
return port::Status{port::error::UNIMPLEMENTED, ""};
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::Status RecordEvent(Stream *stream, Event *event) override {
|
||||
return port::Status{port::error::UNIMPLEMENTED, ""};
|
||||
return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"};
|
||||
}
|
||||
|
||||
port::Status WaitForEvent(Stream *stream, Event *event) override {
|
||||
return port::Status{port::error::UNIMPLEMENTED, ""};
|
||||
return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"};
|
||||
}
|
||||
|
||||
Event::Status PollForEventStatus(Event *event) override {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user