[XLA] Refactor Executable::ExecuteAsyncOnStream.
Change implementations of Executable to always implement the overload that takes a std::vector<ShapeTree<MaybeOwningDeviceMemory>>. Make the non-owning version a wrapper around the maybe-owning version. Simplification in preparation for plumbing buffer donation into JAX. This change is also a necessary preparatory step for implementing buffer donation on CPU and GPU. PiperOrigin-RevId: 283615681 Change-Id: I0d3c65bee506822d23e5827493213e0921b4ef9e
This commit is contained in:
parent
17f3e8ad39
commit
af79ee35f5
@ -242,16 +242,9 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
|
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
|
||||||
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
||||||
"//tensorflow/compiler/xla/service:logical_buffer",
|
"//tensorflow/compiler/xla/service:logical_buffer",
|
||||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/platform:env",
|
|
||||||
"//tensorflow/core/platform:logging",
|
|
||||||
"//tensorflow/core/platform:macros",
|
|
||||||
"//tensorflow/core/platform:mutex",
|
|
||||||
"//tensorflow/core/platform:platform_port",
|
|
||||||
"//tensorflow/core/platform:types",
|
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"//tensorflow/stream_executor:device_memory_allocator",
|
"//tensorflow/stream_executor:device_memory_allocator",
|
||||||
"//tensorflow/stream_executor/host:host_stream",
|
"//tensorflow/stream_executor/host:host_stream",
|
||||||
|
@ -32,7 +32,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -45,7 +44,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/mem.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
|
||||||
#include "tensorflow/stream_executor/host/host_stream.h"
|
#include "tensorflow/stream_executor/host/host_stream.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -75,12 +73,11 @@ CpuExecutable::CpuExecutable(
|
|||||||
<< reinterpret_cast<void*>(compute_function_);
|
<< reinterpret_cast<void*>(compute_function_);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
|
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
|
||||||
std::vector<se::OwningDeviceMemory>,
|
|
||||||
std::vector<se::OwningDeviceMemory>>>
|
std::vector<se::OwningDeviceMemory>>>
|
||||||
CpuExecutable::CreateBufferTable(
|
CpuExecutable::CreateBufferTable(
|
||||||
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
|
absl::Span<const ShapedBuffer* const> arguments) {
|
||||||
std::vector<se::DeviceMemoryBase> unowning_buffers(
|
std::vector<se::DeviceMemoryBase> unowning_buffers(
|
||||||
assignment_->Allocations().size());
|
assignment_->Allocations().size());
|
||||||
std::vector<se::OwningDeviceMemory> owning_buffers(
|
std::vector<se::OwningDeviceMemory> owning_buffers(
|
||||||
@ -94,9 +91,8 @@ CpuExecutable::CreateBufferTable(
|
|||||||
VLOG(3) << allocation.ToString();
|
VLOG(3) << allocation.ToString();
|
||||||
|
|
||||||
if (allocation.is_entry_computation_parameter()) {
|
if (allocation.is_entry_computation_parameter()) {
|
||||||
unowning_buffers[i] = arguments[allocation.parameter_number()]
|
unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
|
||||||
.element(allocation.param_shape_index())
|
allocation.param_shape_index());
|
||||||
.AsDeviceMemoryBase();
|
|
||||||
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
|
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
|
||||||
<< "Size mismatch on param " << allocation.parameter_number()
|
<< "Size mismatch on param " << allocation.parameter_number()
|
||||||
<< " at shape index " << allocation.param_shape_index().ToString();
|
<< " at shape index " << allocation.param_shape_index().ToString();
|
||||||
@ -138,17 +134,7 @@ CpuExecutable::CreateBufferTable(
|
|||||||
assignment_->GetUniqueTopLevelOutputSlice());
|
assignment_->GetUniqueTopLevelOutputSlice());
|
||||||
VLOG(3) << "result index: " << result_slice.index();
|
VLOG(3) << "result index: " << result_slice.index();
|
||||||
|
|
||||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
return {{std::move(unowning_buffers), std::move(owning_buffers)}};
|
||||||
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
|
|
||||||
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
|
|
||||||
auto maybe_owning_buffer = buffer.second.Release();
|
|
||||||
if (maybe_owning_buffer) {
|
|
||||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {{std::move(unowning_buffers), std::move(owning_buffers),
|
|
||||||
std::move(buffers_to_free)}};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CpuExecutable::ExecuteComputeFunction(
|
Status CpuExecutable::ExecuteComputeFunction(
|
||||||
@ -282,9 +268,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
|
|||||||
return std::move(result_buffer);
|
return std::move(result_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) {
|
HloExecutionProfile* hlo_execution_profile) {
|
||||||
if (GetRootValueSet().IsAmbiguous()) {
|
if (GetRootValueSet().IsAmbiguous()) {
|
||||||
return Unimplemented("Points-to set of root instruction is ambiguous");
|
return Unimplemented("Points-to set of root instruction is ambiguous");
|
||||||
@ -297,7 +283,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
|||||||
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
|
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
|
||||||
const Shape& expected_shape =
|
const Shape& expected_shape =
|
||||||
entry_comp->parameter_instruction(i)->shape();
|
entry_comp->parameter_instruction(i)->shape();
|
||||||
const Shape& actual_shape = arguments[i].shape();
|
const Shape& actual_shape = arguments[i]->on_device_shape();
|
||||||
CHECK(expected_shape == actual_shape) << absl::StreamFormat(
|
CHECK(expected_shape == actual_shape) << absl::StreamFormat(
|
||||||
"Shape mismatch on argument %d. Expected %s, but was %s.", i,
|
"Shape mismatch on argument %d. Expected %s, but was %s.", i,
|
||||||
expected_shape.ToString(/*print_layout=*/true),
|
expected_shape.ToString(/*print_layout=*/true),
|
||||||
@ -311,11 +297,10 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
|||||||
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||||
std::vector<se::OwningDeviceMemory> owning_buffers;
|
std::vector<se::OwningDeviceMemory> owning_buffers;
|
||||||
std::vector<se::DeviceMemoryBase> unowning_buffers;
|
std::vector<se::DeviceMemoryBase> unowning_buffers;
|
||||||
std::vector<se::OwningDeviceMemory> buffers_to_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::tie(unowning_buffers, owning_buffers, buffers_to_release),
|
std::tie(unowning_buffers, owning_buffers),
|
||||||
CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
|
CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
|
||||||
std::move(arguments)));
|
arguments));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
ScopedShapedBuffer result,
|
ScopedShapedBuffer result,
|
||||||
@ -354,8 +339,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
|||||||
std::move(owning_buffers)),
|
std::move(owning_buffers)),
|
||||||
hlo_execution_profile});
|
hlo_execution_profile});
|
||||||
|
|
||||||
return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
|
return std::move(result);
|
||||||
se::OwningDeviceMemory());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
|
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||||
|
@ -55,9 +55,9 @@ class CpuExecutable : public Executable {
|
|||||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||||
~CpuExecutable() override {}
|
~CpuExecutable() override {}
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) override;
|
HloExecutionProfile* hlo_execution_profile) override;
|
||||||
|
|
||||||
// This should be called after set_ir_module_string.
|
// This should be called after set_ir_module_string.
|
||||||
@ -96,15 +96,11 @@ class CpuExecutable : public Executable {
|
|||||||
// allocated by this routine. This routine allocates buffers for temporary
|
// allocated by this routine. This routine allocates buffers for temporary
|
||||||
// storage and the live-out buffer into which the computation writes it
|
// storage and the live-out buffer into which the computation writes it
|
||||||
// result.
|
// result.
|
||||||
//
|
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
|
||||||
// - buffers_to_free: buffers whose ownership was donated by the caller that
|
|
||||||
// are to be freed by the caller.
|
|
||||||
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
|
|
||||||
std::vector<se::OwningDeviceMemory>,
|
|
||||||
std::vector<se::OwningDeviceMemory>>>
|
std::vector<se::OwningDeviceMemory>>>
|
||||||
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
||||||
int device_ordinal,
|
int device_ordinal,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
|
absl::Span<const ShapedBuffer* const> arguments);
|
||||||
|
|
||||||
// Calls the generated function performing the computation with the given
|
// Calls the generated function performing the computation with the given
|
||||||
// arguments using the supplied buffers.
|
// arguments using the supplied buffers.
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
#include "tensorflow/compiler/xla/service/dump.h"
|
#include "tensorflow/compiler/xla/service/dump.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -44,36 +43,9 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
|
|
||||||
const ShapedBuffer& shaped_buffer) {
|
|
||||||
ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
|
|
||||||
auto in_it = shaped_buffer.buffers().begin();
|
|
||||||
auto out_it = result.begin();
|
|
||||||
for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
|
|
||||||
DCHECK(out_it != result.end());
|
|
||||||
out_it->second = MaybeOwningDeviceMemory(in_it->second);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
|
|
||||||
const ServiceExecutableRunOptions* run_options,
|
|
||||||
absl::Span<const ShapedBuffer* const> arguments,
|
|
||||||
HloExecutionProfile* hlo_execution_profile) {
|
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
|
|
||||||
auto out_it = args.begin();
|
|
||||||
for (const ShapedBuffer* arg : arguments) {
|
|
||||||
*out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
|
|
||||||
}
|
|
||||||
TF_ASSIGN_OR_RETURN(ExecutionOutput out,
|
|
||||||
ExecuteAsyncOnStream(run_options, std::move(args),
|
|
||||||
hlo_execution_profile));
|
|
||||||
return out.ConsumeResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
|
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) {
|
HloExecutionProfile* hlo_execution_profile) {
|
||||||
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
|
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
|
||||||
run_options, std::move(arguments), hlo_execution_profile);
|
run_options, std::move(arguments), hlo_execution_profile);
|
||||||
@ -83,6 +55,14 @@ StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStream(
|
||||||
|
const ServiceExecutableRunOptions* /*run_options*/,
|
||||||
|
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> /*arguments*/,
|
||||||
|
HloExecutionProfile* /*hlo_execution_profile*/) {
|
||||||
|
return Unimplemented(
|
||||||
|
"MaybeOwningDeviceMemory version of overload is not implemented ");
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
|
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
|
||||||
absl::Span<const ServiceExecutableRunOptions> run_options,
|
absl::Span<const ServiceExecutableRunOptions> run_options,
|
||||||
absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
|
absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
|
||||||
|
@ -160,22 +160,22 @@ class Executable {
|
|||||||
// If the hlo_execution_profile is provided as non-nullptr, profiling will be
|
// If the hlo_execution_profile is provided as non-nullptr, profiling will be
|
||||||
// enabled. Note that profiling is tricky to use correctly, as the profiling
|
// enabled. Note that profiling is tricky to use correctly, as the profiling
|
||||||
// objects (when they exist) must out-live the task.
|
// objects (when they exist) must out-live the task.
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
absl::Span<const ShapedBuffer* const> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile);
|
HloExecutionProfile* hlo_execution_profile) = 0;
|
||||||
|
|
||||||
// Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
|
// Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
|
||||||
// complete.
|
// complete.
|
||||||
StatusOr<ExecutionOutput> ExecuteOnStream(
|
StatusOr<ExecutionOutput> ExecuteOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile);
|
HloExecutionProfile* hlo_execution_profile);
|
||||||
|
|
||||||
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) = 0;
|
HloExecutionProfile* hlo_execution_profile);
|
||||||
|
|
||||||
// Same as ExecuteOnStream(), but runs this executable on multiple
|
// Same as ExecuteOnStream(), but runs this executable on multiple
|
||||||
// streams. arguments[i] contains the arguments to the execution on
|
// streams. arguments[i] contains the arguments to the execution on
|
||||||
|
@ -299,14 +299,11 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
|
|||||||
return &module_globals_.emplace(executor, std::move(globals)).first->second;
|
return &module_globals_.emplace(executor, std::move(globals)).first->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> GpuExecutable::Execute(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) {
|
HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
|
||||||
se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
|
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||||
// Force synchronous execution if the allocator requires it.
|
|
||||||
const bool block_host_until_done =
|
|
||||||
!memory_allocator->AllowsAsynchronousDeallocation();
|
|
||||||
|
|
||||||
if (GetRootValueSet().IsAmbiguous()) {
|
if (GetRootValueSet().IsAmbiguous()) {
|
||||||
return Unimplemented("Points-to set of root instruction is ambiguous");
|
return Unimplemented("Points-to set of root instruction is ambiguous");
|
||||||
@ -337,9 +334,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
|||||||
if (allocation.is_entry_computation_parameter()) {
|
if (allocation.is_entry_computation_parameter()) {
|
||||||
auto param_no = allocation.parameter_number();
|
auto param_no = allocation.parameter_number();
|
||||||
se::DeviceMemoryBase buffer =
|
se::DeviceMemoryBase buffer =
|
||||||
arguments[param_no]
|
arguments[param_no]->buffer(allocation.param_shape_index());
|
||||||
.element(allocation.param_shape_index())
|
|
||||||
.AsDeviceMemoryBase();
|
|
||||||
|
|
||||||
// All top-level buffers and sub-buffers must have an explicit, non-null
|
// All top-level buffers and sub-buffers must have an explicit, non-null
|
||||||
// pointer, except for zero-sized buffers, which may be null.
|
// pointer, except for zero-sized buffers, which may be null.
|
||||||
@ -428,17 +423,19 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
|||||||
}));
|
}));
|
||||||
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
|
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
|
||||||
|
|
||||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
return std::move(shaped_buffer);
|
||||||
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
|
}
|
||||||
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
|
|
||||||
auto maybe_owning_buffer = buffer.second.Release();
|
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
|
||||||
if (maybe_owning_buffer) {
|
const ServiceExecutableRunOptions* run_options,
|
||||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
}
|
HloExecutionProfile* hlo_execution_profile) {
|
||||||
}
|
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||||
}
|
// Force synchronous execution if the allocator requires it.
|
||||||
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
|
bool block_host_until_done =
|
||||||
{}, {});
|
!memory_allocator->AllowsAsynchronousDeallocation();
|
||||||
|
return Execute(run_options, arguments, hlo_execution_profile,
|
||||||
|
block_host_until_done);
|
||||||
}
|
}
|
||||||
|
|
||||||
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
|
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
|
||||||
|
@ -82,9 +82,9 @@ class GpuExecutable : public Executable {
|
|||||||
|
|
||||||
// ExecuteAsyncOnStream will fail if the compute capability of the stream
|
// ExecuteAsyncOnStream will fail if the compute capability of the stream
|
||||||
// doesn't match the compute capability passed to this object's constructor.
|
// doesn't match the compute capability passed to this object's constructor.
|
||||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) override;
|
HloExecutionProfile* hlo_execution_profile) override;
|
||||||
|
|
||||||
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
|
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
|
||||||
@ -92,6 +92,11 @@ class GpuExecutable : public Executable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
StatusOr<ScopedShapedBuffer> Execute(
|
||||||
|
const ServiceExecutableRunOptions* run_options,
|
||||||
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
|
HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
|
||||||
|
|
||||||
// If `block_host_until_done` is false, execution will not block the host
|
// If `block_host_until_done` is false, execution will not block the host
|
||||||
// until the kernels have completed. This is used as an optimization for
|
// until the kernels have completed. This is used as an optimization for
|
||||||
// clients, such as Tensorflow, that use a single stream of execution for
|
// clients, such as Tensorflow, that use a single stream of execution for
|
||||||
|
@ -151,8 +151,7 @@ absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
|
|||||||
absl::optional<HloInputOutputAliasConfig::Alias>
|
absl::optional<HloInputOutputAliasConfig::Alias>
|
||||||
HloInputOutputAliasConfig::GetAliasedParameter(
|
HloInputOutputAliasConfig::GetAliasedParameter(
|
||||||
const ShapeIndex& output_index) const {
|
const ShapeIndex& output_index) const {
|
||||||
CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
|
CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
|
||||||
<< ToString() << " " << alias_.shape().ToString() << " " << output_index;
|
|
||||||
return alias_.element(output_index);
|
return alias_.element(output_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,15 +89,10 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo_evaluator",
|
"//tensorflow/compiler/xla/service:hlo_evaluator",
|
||||||
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
||||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/platform:env",
|
|
||||||
"//tensorflow/core/platform:macros",
|
|
||||||
"//tensorflow/core/platform:mutex",
|
|
||||||
"//tensorflow/core/platform:types",
|
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
|
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -40,39 +39,24 @@ namespace interpreter {
|
|||||||
InterpreterExecutable::InterpreterExecutable(
|
InterpreterExecutable::InterpreterExecutable(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
std::unique_ptr<HloEvaluator> evaluator)
|
std::unique_ptr<HloEvaluator> evaluator)
|
||||||
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
: Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
|
||||||
/*hlo_profile_index_map=*/nullptr),
|
/*hlo_profile_index_map=*/nullptr),
|
||||||
evaluator_(std::move(evaluator)) {}
|
evaluator_(std::move(evaluator)) {}
|
||||||
|
|
||||||
InterpreterExecutable::~InterpreterExecutable() {}
|
InterpreterExecutable::~InterpreterExecutable() {}
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) {
|
HloExecutionProfile* hlo_execution_profile) {
|
||||||
se::Stream* stream = run_options->stream();
|
se::Stream* stream = run_options->stream();
|
||||||
se::StreamExecutor* executor = stream->parent();
|
se::StreamExecutor* executor = stream->parent();
|
||||||
const se::Platform* platform = executor->platform();
|
const se::Platform* platform = executor->platform();
|
||||||
|
|
||||||
// Convert the ShapeTree to a ShapedBuffer. We do this so we can call
|
|
||||||
// TransferManager methods below.
|
|
||||||
std::vector<ShapedBuffer> argument_buffers;
|
|
||||||
argument_buffers.reserve(arguments.size());
|
|
||||||
for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
|
|
||||||
argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
|
|
||||||
/*platform=*/nullptr,
|
|
||||||
/*device_ordinal=*/0));
|
|
||||||
auto in_it = arg.begin();
|
|
||||||
auto out_it = argument_buffers.back().buffers().begin();
|
|
||||||
for (; in_it != arg.end(); ++in_it, ++out_it) {
|
|
||||||
out_it->second = in_it->second.AsDeviceMemoryBase();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
VLOG(1) << "Execute " << module().name();
|
VLOG(1) << "Execute " << module().name();
|
||||||
if (VLOG_IS_ON(2)) {
|
if (VLOG_IS_ON(2)) {
|
||||||
for (const auto& a : argument_buffers) {
|
for (const auto& a : arguments) {
|
||||||
VLOG(2) << "-- argument " << a;
|
VLOG(2) << "-- argument " << *a;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +71,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
|
|||||||
// Check that the args have the right shape.
|
// Check that the args have the right shape.
|
||||||
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
||||||
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
||||||
const auto& actual_shape = argument_buffers[i].on_device_shape();
|
const auto& actual_shape = arguments[i]->on_device_shape();
|
||||||
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
||||||
actual_shape)) {
|
actual_shape)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -106,7 +90,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
|
|||||||
for (int64 p = 0; p < computation->num_parameters(); ++p) {
|
for (int64 p = 0; p < computation->num_parameters(); ++p) {
|
||||||
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
||||||
transfer_manager->TransferLiteralFromDevice(
|
transfer_manager->TransferLiteralFromDevice(
|
||||||
run_options->stream(), argument_buffers[p]));
|
run_options->stream(), *arguments[p]));
|
||||||
arg_literals.push_back(std::move(arg_literal));
|
arg_literals.push_back(std::move(arg_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,16 +119,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
|
|||||||
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
|
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
return std::move(result);
|
||||||
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
|
|
||||||
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
|
|
||||||
auto maybe_owning_buffer = buffer.second.Release();
|
|
||||||
if (maybe_owning_buffer) {
|
|
||||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
|
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||||
|
@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable {
|
|||||||
std::unique_ptr<HloEvaluator> evaluator);
|
std::unique_ptr<HloEvaluator> evaluator);
|
||||||
~InterpreterExecutable() override;
|
~InterpreterExecutable() override;
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
HloExecutionProfile* hlo_execution_profile) override
|
HloExecutionProfile* hlo_execution_profile) override
|
||||||
LOCKS_EXCLUDED(evaluator_lock_);
|
LOCKS_EXCLUDED(evaluator_lock_);
|
||||||
|
|
||||||
|
@ -17,8 +17,7 @@ limitations under the License.
|
|||||||
#include "absl/types/variant.h"
|
#include "absl/types/variant.h"
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
|
tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
|
||||||
const {
|
|
||||||
if (HasOwnership()) {
|
if (HasOwnership()) {
|
||||||
return *absl::get<tensorflow::se::OwningDeviceMemory>(mem_);
|
return *absl::get<tensorflow::se::OwningDeviceMemory>(mem_);
|
||||||
} else {
|
} else {
|
||||||
|
@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory {
|
|||||||
|
|
||||||
// Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
|
// Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
|
||||||
// caller of this function is *not* responsible for freeing the memory.
|
// caller of this function is *not* responsible for freeing the memory.
|
||||||
tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const;
|
tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase();
|
||||||
|
|
||||||
// Release the tensorflow::se::OwningDeviceMemory without freeing it, and
|
// Release the tensorflow::se::OwningDeviceMemory without freeing it, and
|
||||||
// moves the ownership of the memory buffer from the object to the caller.
|
// moves the ownership of the memory buffer from the object to the caller.
|
||||||
|
Loading…
Reference in New Issue
Block a user