Make XRT CPU/GPU use MaybeOwning buffer interface, so the new copy protection CL won't break aliasing.

PiperOrigin-RevId: 317700747
Change-Id: Ie7b5bb1989cd4359b30ad86a450de5bff0962c31
This commit is contained in:
Davide Libenzi 2020-06-22 11:41:43 -07:00 committed by TensorFlower Gardener
parent 79518facb4
commit 44067f0783
11 changed files with 385 additions and 172 deletions

View File

@ -299,12 +299,11 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
const Shape& expected_shape = const Shape& expected_shape =
entry_comp->parameter_instruction(i)->shape(); entry_comp->parameter_instruction(i)->shape();
const Shape& actual_shape = arguments[i].Buffers().shape(); const Shape& actual_shape = arguments[i].Buffers().shape();
CHECK( TF_RET_CHECK(
Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape)) ShapeUtil::DynamicShapeIsCompatible(actual_shape, expected_shape))
<< absl::StreamFormat( << "Shape mismatch on argument " << i << ", "
"Shape mismatch on argument %d. Expected %s, but was %s.", i, << expected_shape.ToString(/*print_layout=*/true) << " vs. "
expected_shape.ToString(/*print_layout=*/true), << actual_shape.ToString(/*print_layout=*/true);
actual_shape.ToString(/*print_layout=*/true));
} }
} }

View File

@ -28,10 +28,57 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/device_description.h" #include "tensorflow/stream_executor/device_description.h"
namespace xla { namespace xla {
ExecutionInput::~ExecutionInput() {
for (auto& index : unowned_indices_) {
auto buffer = buffers_.mutable_element(index)->Release();
if (buffer) {
buffer->Release();
}
}
}
Status ExecutionInput::SetDynamicShape(Shape dynamic_shape) {
const Shape& input_shape = shape();
if (!ShapeUtil::DynamicShapeIsCompatible(input_shape, dynamic_shape)) {
return tensorflow::errors::InvalidArgument(
"Cannot set dynamic shape: ", input_shape.DebugString(), " vs. ",
dynamic_shape.DebugString());
}
dynamic_shape_ = absl::make_unique<Shape>(std::move(dynamic_shape));
return Status::OK();
}
void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer);
unowned_indices_.insert(index);
}
xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
const Shape& input_shape = shape();
xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
allocator->platform(), device_ordinal);
for (const auto& index_buffer : Buffers()) {
const tensorflow::se::OwningDeviceMemory* mem =
index_buffer.second.AsOwningDeviceMemory();
if (mem != nullptr && (mem->allocator() != allocator ||
mem->device_ordinal() != device_ordinal)) {
return tensorflow::errors::InvalidArgument(
"Device buffer at index ", index_buffer.first.ToString(),
" has mismatching allocator/device");
}
shaped_buffer.set_buffer(index_buffer.second.AsDeviceMemoryBase(),
index_buffer.first);
}
return std::move(shaped_buffer);
}
StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream( StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments, absl::Span<const ShapedBuffer* const> arguments,

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_
#include <memory> #include <memory>
#include <set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -65,31 +66,32 @@ class ExecutionInput {
: buffers_(std::move(buffers)) {} : buffers_(std::move(buffers)) {}
ExecutionInput(ExecutionInput&&) = default; ExecutionInput(ExecutionInput&&) = default;
~ExecutionInput() { ~ExecutionInput();
for (auto& index : unowned_indices_) {
auto buffer = buffers_.mutable_element(index)->Release();
if (buffer) {
buffer->Release();
}
}
}
ExecutionInput& operator=(ExecutionInput&&) = default; ExecutionInput& operator=(ExecutionInput&&) = default;
const Shape& shape() const { return buffers_.shape(); } const Shape& shape() const {
return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape();
}
Status SetDynamicShape(Shape dynamic_shape);
xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer(
se::DeviceMemoryAllocator* allocator, int device_ordinal) const;
void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) { void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer); *buffers_.mutable_element(index) = std::move(buffer);
} }
void SetUnownedBuffer(const ShapeIndex& index, void SetUnownedBuffer(const ShapeIndex& index,
MaybeOwningDeviceMemory buffer) { MaybeOwningDeviceMemory buffer);
*buffers_.mutable_element(index) = std::move(buffer);
unowned_indices_.push_back(index);
}
void SetUnownedIndex(const ShapeIndex& index) { void SetUnownedIndex(const ShapeIndex& index) {
unowned_indices_.push_back(index); unowned_indices_.insert(index);
}
void ClearUnownedIndex(const ShapeIndex& index) {
unowned_indices_.erase(index);
} }
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; } const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
@ -106,9 +108,10 @@ class ExecutionInput {
private: private:
ShapeTree<MaybeOwningDeviceMemory> buffers_; ShapeTree<MaybeOwningDeviceMemory> buffers_;
// (Unordered) set of indices of buffers that should be returned to the // Set of indices of buffers that should be returned to the caller if an error
// caller if an error occurs when enqueuing the computation. // occurs when enqueuing the computation.
std::vector<ShapeIndex> unowned_indices_; std::set<ShapeIndex> unowned_indices_;
std::unique_ptr<Shape> dynamic_shape_;
}; };
// ExecutionOutput encapsulates the output buffers of a execution and the // ExecutionOutput encapsulates the output buffers of a execution and the
@ -145,7 +148,6 @@ class ExecutionOutput {
to_be_released_.push_back(std::move(mem)); to_be_released_.push_back(std::move(mem));
} }
// Should be called once it is known that the execute operation succeeded, // Should be called once it is known that the execute operation succeeded,
// before returning the ExecutionOutput to the caller. // before returning the ExecutionOutput to the caller.
ExecutionOutput& Commit() { ExecutionOutput& Commit() {

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "absl/types/variant.h" #include "absl/types/variant.h"
namespace xla { namespace xla {
tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
@ -38,4 +40,10 @@ MaybeOwningDeviceMemory::Release() {
return std::move(absl::get<tensorflow::se::OwningDeviceMemory>(mem_)); return std::move(absl::get<tensorflow::se::OwningDeviceMemory>(mem_));
} }
const tensorflow::se::OwningDeviceMemory*
MaybeOwningDeviceMemory::AsOwningDeviceMemory() const {
return HasOwnership() ? &absl::get<tensorflow::se::OwningDeviceMemory>(mem_)
: nullptr;
}
} // namespace xla } // namespace xla

View File

@ -57,6 +57,10 @@ class MaybeOwningDeviceMemory {
// A nullopt is returned if the HasOwnership() == false; // A nullopt is returned if the HasOwnership() == false;
absl::optional<tensorflow::se::OwningDeviceMemory> Release(); absl::optional<tensorflow::se::OwningDeviceMemory> Release();
// If the device memory is owned, returns a pointer to the internal
// OwningDeviceMemory, otherwise nullptr is returned.
const tensorflow::se::OwningDeviceMemory* AsOwningDeviceMemory() const;
// Returns true if the device_memory has ownership over underlying memory. // Returns true if the device_memory has ownership over underlying memory.
bool HasOwnership() const; bool HasOwnership() const;

View File

@ -1461,7 +1461,7 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
return shape; return shape;
} }
/* static */ bool ShapeUtil::DynamicShapeIsCompatible( /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
if (dynamic_shape.rank() != bounded_shape.rank()) { if (dynamic_shape.rank() != bounded_shape.rank()) {
return false; return false;
@ -1474,6 +1474,36 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
return true; return true;
} }
/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
bool compatible = true;
xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
const ShapeIndex& index) {
if (compatible) {
auto subshape_result = TryGetSubshape(bounded_shape, index);
if (subshape_result.ok()) {
const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
if (sub_shape.IsTuple()) {
if (!bounded_sub_shape->IsTuple()) {
compatible = false;
}
} else {
if (bounded_sub_shape->IsTuple()) {
compatible = false;
} else if (!sub_shape.is_static() &&
!DynamicArrayShapeIsCompatible(sub_shape,
*bounded_sub_shape)) {
compatible = false;
}
}
} else {
compatible = false;
}
}
});
return compatible;
}
/* static */ Shape ShapeUtil::FilterDimensions( /* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) { const std::function<bool(int64)>& p, Shape shape) {
CHECK(shape.IsArray()); CHECK(shape.IsArray());

View File

@ -657,7 +657,11 @@ class ShapeUtil {
Shape shape); Shape shape);
// Returns true if `dynamic_shape` has dimensions that are less-equal to the // Returns true if `dynamic_shape` has dimensions that are less-equal to the
// "bounded_shape". // "bounded_shape". Shapes must be arrays.
static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
const xla::Shape& bounded_shape);
// Same as DynamicArrayShapeIsCompatible() but supports tuples.
static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
const xla::Shape& bounded_shape); const xla::Shape& bounded_shape);

View File

@ -74,6 +74,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -51,12 +51,6 @@ namespace tensorflow {
namespace { namespace {
struct InputBuffers {
std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
std::vector<xla::ShapedBuffer> input_allocations;
std::vector<xla::ShapedBuffer*> input_pointers;
};
uint32 InitialRandomSeed() { uint32 InitialRandomSeed() {
// Support plumbing the TF seed through to XLA is being worked on. // Support plumbing the TF seed through to XLA is being worked on.
// If a user wants deterministic behavior, their best option // If a user wants deterministic behavior, their best option
@ -80,75 +74,51 @@ uint32 GetXLARandomSeed() {
return counter.fetch_add(2); return counter.fetch_add(2);
} }
xla::StatusOr<InputBuffers> GetInputBuffers( std::vector<bool> GetDynamicInputInfo(
XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, const xla::ComputationLayout& computation_layout) {
const std::vector<InputCoords>& input_coords, bool release_inputs) { std::vector<bool> input_is_dynamic;
InputBuffers input_buffers; input_is_dynamic.reserve(computation_layout.parameter_count());
input_buffers.input_tuples.reserve(input_coords.size()); for (int64 i = 0; i < computation_layout.parameter_count(); ++i) {
input_buffers.input_allocations.reserve(input_coords.size()); input_is_dynamic.push_back(
input_buffers.input_pointers.reserve(input_coords.size()); !computation_layout.parameter_shape(i).is_static());
for (size_t i = 0; i < input_coords.size(); ++i) {
TF_RETURN_IF_ERROR(
working_set->LookupAndPin(backend, input_coords[i].handle));
auto tuple = working_set->PinnedTuples().back();
input_buffers.input_tuples.emplace_back(tuple);
if (release_inputs) {
// We are holding a reference to the tuple, so we can safely delete it
// from the resource manager here.
TF_RETURN_IF_ERROR(
working_set->MemoryManager()->Release(input_coords[i].handle));
VLOG(2) << "Released allocation handle " << input_coords[i].handle;
}
if (input_coords[i].index.empty()) {
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
tuple->ToShapedBuffer());
input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
} else {
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
tuple->ToShapedBuffer());
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer,
shaped_buffer.SubShapedBuffer(input_coords[i].index));
input_buffers.input_allocations.emplace_back(
std::move(sub_shaped_buffer));
}
} }
for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { return input_is_dynamic;
input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]);
}
return std::move(input_buffers);
} }
xla::StatusOr<InputBuffers> GetChainedOpInputs( xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTuples(
xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set,
xla::Backend* backend, const std::vector<InputCoords>& input_coords,
bool release_inputs) {
const xla::ComputationLayout& computation_layout =
executable->executable()->module_config().entry_computation_layout();
return GetInputTupleAllocations(
input_coords, working_set, backend, computation_layout.parameter_count(),
[&](int64 i) { return computation_layout.parameter_shape(i); },
release_inputs);
}
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputTuples(
const xrt::XRTChainedExecuteOp& op, const xrt::XRTChainedExecuteOp& op,
absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) { absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) {
InputBuffers input_buffers; std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
input_buffers.input_tuples.reserve(op.inputs_size()); input_tuples.reserve(op.inputs_size());
input_buffers.input_allocations.reserve(op.inputs_size());
input_buffers.input_pointers.reserve(op.inputs_size());
for (int i = 0; i < op.inputs_size(); ++i) { for (int i = 0; i < op.inputs_size(); ++i) {
auto& input = op.inputs(i); auto& input = op.inputs(i);
input_buffers.input_tuples.emplace_back(op_inputs[i]);
// Thanks to the greatness of proto3, there is no way to query for // Thanks to the greatness of proto3, there is no way to query for
// explicitly set fields, so the default for output_index (zero) means no // explicitly set fields, so the default for output_index (zero) means no
// sub-index. As consequence, the real index is output_index - 1. // sub-index. As consequence, the real index is output_index - 1.
if (input.output_index() == 0) { if (input.output_index() == 0) {
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, input_tuples.emplace_back(op_inputs[i]);
input_buffers.input_tuples.back()->ToShapedBuffer());
input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
} else { } else {
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, XRTTupleAllocation* sub_tuple;
input_buffers.input_tuples.back()->ToShapedBuffer()); TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
TF_ASSIGN_OR_RETURN( op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple,
xla::ShapedBuffer sub_shaped_buffer, /*alias_parent_allocation=*/true));
shaped_buffer.SubShapedBuffer({input.output_index() - 1})); input_tuples.emplace_back(sub_tuple);
input_buffers.input_allocations.emplace_back(
std::move(sub_shaped_buffer));
} }
} }
for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { return input_tuples;
input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]);
}
return std::move(input_buffers);
} }
// Given a shape, returns a byte array representing the shape metadata of the // Given a shape, returns a byte array representing the shape metadata of the
@ -228,12 +198,11 @@ Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
// As we can't expand the size of an existing memory allocation, a reallocation // As we can't expand the size of an existing memory allocation, a reallocation
// is required. A list of new allocations are returned after this function. The // is required. A list of new allocations are returned after this function. The
// caller is reponsible for maintaining those allocations. // caller is reponsible for maintaining those allocations.
xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs( Status UpdateDynamicInputs(
se::Stream* stream, se::DeviceMemoryAllocator* allocator, se::Stream* stream, se::DeviceMemoryAllocator* allocator,
std::vector<xla::ShapedBuffer*> runtime_inputs, std::vector<xla::ExecutionInput>* execution_inputs,
const std::vector<xla::ShapeLayout>& compile_time_shapes) { const std::vector<xla::ShapeLayout>& compile_time_shapes) {
std::vector<se::OwningDeviceMemory> new_allocations; TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size());
TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size());
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform())); stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction(); auto shape_size_fn = compiler->ShapeSizeBytesFunction();
@ -242,57 +211,61 @@ xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs(
if (compile_time_shape.is_static()) { if (compile_time_shape.is_static()) {
continue; continue;
} }
auto* runtime_input = runtime_inputs[i]; xla::ExecutionInput* execution_input = &(*execution_inputs)[i];
bool element_modified = false; bool element_modified = false;
TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
compile_time_shape, compile_time_shape,
[&](const xla::Shape& compile_time_shape, [&](const xla::Shape& sub_shape,
const xla::ShapeIndex& index) -> Status { const xla::ShapeIndex& index) -> Status {
if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { if (sub_shape.IsTuple() || sub_shape.is_static()) {
return Status::OK(); return Status::OK();
} }
const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape(
runtime_input->on_device_shape(), index);
TF_RET_CHECK(!runtime_shape.IsTuple());
TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(
runtime_shape, compile_time_shape));
se::DeviceMemoryBase* static_input =
runtime_input->buffers().mutable_element(index);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto dynamic_input, const xla::Shape* runtime_shape,
xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index));
TF_RET_CHECK(!runtime_shape->IsTuple());
TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible(
*runtime_shape, sub_shape));
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory dynamic_input,
allocator->Allocate(stream->parent()->device_ordinal(), allocator->Allocate(stream->parent()->device_ordinal(),
shape_size_fn(compile_time_shape))); shape_size_fn(sub_shape)));
new_allocations.emplace_back(std::move(dynamic_input));
se::DeviceMemory<uint8>* dynamic_input_base = se::DeviceMemoryBase static_input =
new_allocations.back().ptr(); execution_input->Buffer(index).AsDeviceMemoryBase();
se::DeviceMemory<uint8>* dynamic_input_base = dynamic_input.ptr();
// Send the original data to the new location. // Send the original data to the new location.
stream->ThenMemcpyD2D(dynamic_input_base, *static_input, stream->ThenMemcpyD2D(dynamic_input_base, static_input,
static_input->size()); static_input.size());
TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
compile_time_shape, runtime_shape)); sub_shape, *runtime_shape));
// Modify the memory location in the input shape tree to point to the // Modify the memory location in the input shape tree to point to the
// new input. // new input.
runtime_input->set_buffer(*dynamic_input_base, index); execution_input->SetBuffer(
index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input)));
execution_input->ClearUnownedIndex(index);
element_modified = true; element_modified = true;
return Status::OK(); return Status::OK();
})); }));
if (element_modified) { if (element_modified) {
runtime_input->set_shapes(compile_time_shape, compile_time_shape); TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape));
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
execution_input->ToShapedBuffer(
allocator, stream->parent()->device_ordinal()));
// The input location has been modified, need to fix tuple table to // The input location has been modified, need to fix tuple table to
// point to the correct address. // point to the correct address.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto transfer_manager, auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform())); xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer));
} }
} }
return std::move(new_allocations); return Status::OK();
} }
xla::StatusOr<xla::Literal> ReadMetadataLiteral( xla::StatusOr<xla::Literal> ReadMetadataLiteral(
se::Stream* stream, se::DeviceMemoryBase* buffer, se::Stream* stream, se::DeviceMemoryBase buffer,
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform())); stream->parent()->platform()));
@ -302,7 +275,7 @@ xla::StatusOr<xla::Literal> ReadMetadataLiteral(
const int64 offset = shape_size_fn(buffer_shape_static); const int64 offset = shape_size_fn(buffer_shape_static);
int64 metadata_size = shape_size_fn(buffer_shape) - offset; int64 metadata_size = shape_size_fn(buffer_shape) - offset;
TF_RET_CHECK(metadata_size != 0); TF_RET_CHECK(metadata_size != 0);
auto buffer_8 = se::DeviceMemory<uint8>(*buffer); auto buffer_8 = se::DeviceMemory<uint8>(buffer);
auto metadata_buffer = auto metadata_buffer =
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
return transfer_manager->TransferArrayFromDevice( return transfer_manager->TransferArrayFromDevice(
@ -315,7 +288,7 @@ xla::StatusOr<xla::Literal> ReadMetadataLiteral(
// dimension sizes from the metadata, and update output shapes. The result shape // dimension sizes from the metadata, and update output shapes. The result shape
// is a static and concrete shape. // is a static and concrete shape.
xla::Status UpdateDynamicOutputs(se::Stream* stream, xla::Status UpdateDynamicOutputs(se::Stream* stream,
xla::ShapedBuffer* shaped_buffer, const xla::ShapedBuffer& shaped_buffer,
xla::Shape* output_host_shape, xla::Shape* output_host_shape,
xla::Shape* output_device_shape) { xla::Shape* output_device_shape) {
DCHECK(output_device_shape->is_dynamic()); DCHECK(output_device_shape->is_dynamic());
@ -323,8 +296,8 @@ xla::Status UpdateDynamicOutputs(se::Stream* stream,
auto transfer_manager, auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform())); xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
const xla::Shape& buffer_shape = const xla::Shape& buffer_shape =
xla::ShapeUtil::GetSubshape(*output_device_shape, index); xla::ShapeUtil::GetSubshape(*output_device_shape, index);
if (buffer_shape.IsTuple()) { if (buffer_shape.IsTuple()) {
@ -352,19 +325,18 @@ xla::Status UpdateDynamicOutputs(se::Stream* stream,
return Status::OK(); return Status::OK();
} }
// Create output tuple from run_result.
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple( xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
se::Stream* stream, xla::ScopedShapedBuffer run_result, se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
xla::Backend* backend, int device_ordinal) { int device_ordinal) {
XRTTupleAllocation* output_tuple; XRTTupleAllocation* output_tuple;
xla::ShapedBuffer shaped_buffer = run_result.release(); const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
if (shaped_buffer.on_device_shape().is_dynamic()) { if (shaped_buffer.on_device_shape().is_dynamic()) {
// Update dynamic shapes from output buffer, and create a XRT tensor with // Update dynamic shapes from output buffer, and create a XRT tensor with
// dimension sizes read from metadata. // dimension sizes read from metadata.
xla::Shape output_host_shape = shaped_buffer.on_host_shape(); xla::Shape output_host_shape = shaped_buffer.on_host_shape();
xla::Shape output_device_shape = shaped_buffer.on_device_shape(); xla::Shape output_device_shape = shaped_buffer.on_device_shape();
TF_RETURN_IF_ERROR(UpdateDynamicOutputs( TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
stream, &shaped_buffer, &output_host_shape, &output_device_shape)); stream, shaped_buffer, &output_host_shape, &output_device_shape));
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, output_host_shape, output_device_shape, backend, shaped_buffer, output_host_shape, output_device_shape, backend,
device_ordinal, &output_tuple)); device_ordinal, &output_tuple));
@ -373,15 +345,27 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, backend, device_ordinal, &output_tuple)); shaped_buffer, backend, device_ordinal, &output_tuple));
} }
// After the output tuple is created, we can release the output result
// buffers, to make sure they won't be cleared by its destructor.
(void)run_result.ConsumeResult().release();
return RefPtr<XRTTupleAllocation>(output_tuple); return RefPtr<XRTTupleAllocation>(output_tuple);
} }
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable( xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers, xla::LocalExecutable* executable,
se::Stream* stream, int rng_seed, absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
bool release_inputs, se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) { const xrt::CommonExecutionConfig& config) {
VLOG(2) << "Executing computation."; const xla::ComputationLayout& computation_layout =
executable->executable()->module_config().entry_computation_layout();
std::vector<bool> input_is_dynamic = GetDynamicInputInfo(computation_layout);
TF_ASSIGN_OR_RETURN(
std::vector<xla::ExecutionInput> execution_inputs,
GetArgumentsBuffers(
executable->executable()->module().input_output_alias_config(),
input_tuples, input_is_dynamic, release_inputs));
xla::ExecutableRunOptions run_options; xla::ExecutableRunOptions run_options;
run_options.set_stream(stream); run_options.set_stream(stream);
run_options.set_allocator(device_ref->backend()->memory_allocator()); run_options.set_allocator(device_ref->backend()->memory_allocator());
@ -419,51 +403,28 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
} }
run_options.set_gpu_executable_run_options(&gpu_options); run_options.set_gpu_executable_run_options(&gpu_options);
Env* env = Env::Default();
auto start_time = env->NowMicros();
const std::vector<xla::ShapeLayout>& shape_layouts = const std::vector<xla::ShapeLayout>& shape_layouts =
executable->executable() executable->executable()
->module_config() ->module_config()
.entry_computation_layout() .entry_computation_layout()
.parameter_layouts(); .parameter_layouts();
TF_ASSIGN_OR_RETURN(auto new_allocations, TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(),
UpdateDynamicInputs(stream, run_options.allocator(), &execution_inputs, shape_layouts));
input_buffers.input_pointers,
shape_layouts));
auto new_allocations_ptr =
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(new_allocations));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
xla::ScopedShapedBuffer run_result, xla::ExecutionOutput run_result,
executable->Run(input_buffers.input_pointers, run_options)); executable->Run(std::move(execution_inputs), run_options));
// Retain the new allocation for input memory until the end of execution.
stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); });
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
RefPtr<XRTTupleAllocation> output_tuple_ptr, RefPtr<XRTTupleAllocation> output_tuple_ptr,
CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
device_ref->device_ordinal())); device_ref->device_ordinal()));
// The ScopedShapedBuffer returned by the executable Run() API, in case of // The ScopedShapedBuffer returned by the executable Run() API, in case of
// input/output buffer aliasing, might have holes in it, which need to be // input/output buffer aliasing, might have holes in it, which need to be
// filled using the proper input tuples buffers which are the source of // filled using the proper input tuples buffers which are the source of
// aliasing. // aliasing.
const xla::HloInputOutputAliasConfig& input_output_alias = TF_RETURN_IF_ERROR(RebuildOutputAliases(
executable->executable()->module().input_output_alias_config(); output_tuple_ptr, input_tuples,
auto alias_function = executable->executable()->module().input_output_alias_config()));
[&](const xla::ShapeIndex& output_index,
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size());
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
? output_tuple_ptr->AliasBufferFrom(
*input_buffers.input_tuples[alias.parameter_number],
alias.parameter_index, output_index)
: Status::OK();
};
TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function));
return std::move(output_tuple_ptr); return std::move(output_tuple_ptr);
} }
@ -471,12 +432,13 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation( xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
OpKernelContext* context, XRTMemoryManager* memory_manager, OpKernelContext* context, XRTMemoryManager* memory_manager,
XRTGenericDeviceAccessor::ScopedRef* device_ref, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers, xla::LocalExecutable* executable,
se::Stream* stream, int rng_seed, absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
bool release_inputs, se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) { const xrt::CommonExecutionConfig& config) {
auto runfn = [&]() { auto runfn = [&]() {
return RunExecutable(context, device_ref, executable, input_buffers, stream, return RunExecutable(context, device_ref, executable, input_tuples,
rng_seed, config); release_inputs, stream, rng_seed, config);
}; };
// We pass zero as requested_free_size as there is no simple way to get the // We pass zero as requested_free_size as there is no simple way to get the
@ -495,12 +457,13 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
se::Stream* stream, int rng_seed, se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) { const xrt::CommonExecutionConfig& config) {
XRTMemoryManager::WorkingSet working_set(memory_manager); XRTMemoryManager::WorkingSet working_set(memory_manager);
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, TF_ASSIGN_OR_RETURN(
GetInputBuffers(&working_set, device_ref->backend(), std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
input_coords, release_inputs)); GetInputTuples(executable, &working_set, device_ref->backend(),
input_coords, release_inputs));
return ExecuteComputation(context, memory_manager.get(), device_ref, return ExecuteComputation(context, memory_manager.get(), device_ref,
executable, input_buffers, stream, rng_seed, executable, input_tuples, release_inputs, stream,
config); rng_seed, config);
} }
// XRTExecuteOp // XRTExecuteOp
@ -653,16 +616,16 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
auto execute_op = [&](const xrt::XRTChainedExecuteOp& op, auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
-> xla::StatusOr<RefPtr<XRTTupleAllocation>> { -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
GetChainedOpInputs(op, op_inputs));
std::unique_ptr<XRTCompilationCacheEntryRef> entry; std::unique_ptr<XRTCompilationCacheEntryRef> entry;
TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry)); TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry));
xla::LocalExecutable* executable = entry->get().get_executable(); xla::LocalExecutable* executable = entry->get().get_executable();
return ExecuteComputation(context, memory_manager.get(), &device_ref, TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
executable, input_buffers, stream, rng_seed, GetChainedOpInputTuples(op, op_inputs));
config.common_config());
return ExecuteComputation(
context, memory_manager.get(), &device_ref, executable, input_tuples,
/*release_inputs=*/false, stream, rng_seed, config.common_config());
}; };
return ExecuteChained(context, memory_manager, device_ref.backend(), return ExecuteChained(context, memory_manager, device_ref.backend(),

View File

@ -221,6 +221,140 @@ xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
return std::move(input_coords); return std::move(input_coords);
} }
bool InputShapeMatches(const xla::Shape& parameter_shape,
const xla::Shape& input_shape) {
auto shape_checker = [&](const xla::Shape& pshape,
const xla::ShapeIndex& index) {
if (pshape.IsArray()) {
TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
xla::ShapeUtil::TryGetSubshape(input_shape, index));
if (pshape.rank() != ishape->rank() ||
pshape.element_type() != ishape->element_type()) {
return errors::InvalidArgument("Mismatching shapes");
}
if (pshape.is_static() && pshape.layout() != ishape->layout()) {
return errors::InvalidArgument("Mismatching layouts");
}
for (int64 dim = 0; dim < pshape.rank(); ++dim) {
if (pshape.is_dynamic_dimension(dim)) {
if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
return errors::InvalidArgument("Mismatching shapes");
}
} else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
return errors::InvalidArgument("Mismatching shapes");
}
}
}
return Status::OK();
};
return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
shape_checker)
.ok();
}
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
const std::vector<InputCoords>& input_coords,
XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
int64 num_input_shapes,
const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs) {
if (input_coords.size() != num_input_shapes) {
return errors::InvalidArgument(
"Number of inputs does not match executable proto input shapes: ",
input_coords.size(), " vs. ", num_input_shapes);
}
std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
input_tuples.reserve(input_coords.size());
for (size_t i = 0; i < input_coords.size(); ++i) {
TF_RETURN_IF_ERROR(
working_set->LookupAndPin(backend, input_coords[i].handle));
auto tuple = working_set->PinnedTuples().back();
if (release_inputs) {
// We are holding a reference to the tuple, so we can safely delete it
// from the resource manager here.
TF_RETURN_IF_ERROR(
working_set->MemoryManager()->Release(input_coords[i].handle));
VLOG(2) << "Released allocation handle " << input_coords[i].handle;
}
xla::Shape input_shape = shape_getter(i);
if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
return errors::InvalidArgument(
"Run-time shape mismatch for XRTExecute argument[", i, "] (",
input_coords[i].handle, "). Expected ", input_shape.DebugString(),
"; got ", tuple->on_host_shape().DebugString());
}
if (input_coords[i].index.empty()) {
input_tuples.emplace_back(std::move(tuple));
} else {
XRTTupleAllocation* sub_tuple;
TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
tuple.get(), input_coords[i].index, &sub_tuple,
/*alias_parent_allocation=*/true));
input_tuples.emplace_back(sub_tuple);
}
}
return std::move(input_tuples);
}
Status RebuildOutputAliases(
const RefPtr<XRTTupleAllocation>& output_tuple,
absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
const xla::HloInputOutputAliasConfig& input_output_alias) {
auto alias_function =
[&](const xla::ShapeIndex& output_index,
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
TF_RET_CHECK(alias.parameter_number < input_tuples.size());
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
? output_tuple->AliasBufferFrom(
*input_tuples[alias.parameter_number],
alias.parameter_index, output_index)
: Status::OK();
};
return input_output_alias.ForEachAliasWithStatus(alias_function);
}
xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
const std::vector<bool>& input_is_dynamic, bool release_inputs) {
auto is_dynamic = [&](size_t arg) {
return arg < input_is_dynamic.size() && input_is_dynamic[arg];
};
std::vector<xla::ExecutionInput> arguments;
// Don't alias dynamic input -- Due to the underlying implementation,
// aliased inputs have two owners: XRTAllocation and return value of
// this function. If an argument is dynamic and the ownership is
// released to output of this function, TPUExecute will free it and
// reallocate a new one, which creates a double freeing issue where
// XRTAllocation also attempts to release the buffer.
bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
arguments.reserve(input_tuples.size());
for (int64 i = 0; i < input_tuples.size(); ++i) {
auto alias_checker =
[&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
// Only the buffers which the caller explicitly marked as aliased
// (kUserAlias), should create aliases.
// The XLA compiler might create opportunistic aliases (kSystemAlias)
// which need a different handling. With a system alias we know that XLA
// is going to reuse a given input parameter buffer for a given output, so
// unless it is known at call site that the input buffer has no more uses,
// a copy needs to be made at call site. With user specified alias the
// caller tells us that he expects a given output to land over the buffers
// of a given parametter.
if (input_output_alias.ParameterAliasKind(i, index) ==
xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) {
TF_RET_CHECK(!is_dynamic(i));
return true;
}
return alias_outputs;
};
TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
input_tuples[i]->ToExecutionInput(alias_checker));
arguments.emplace_back(std::move(exec_input));
}
return std::move(arguments);
}
Status CreateExecuteOutput(OpKernelContext* context, Status CreateExecuteOutput(OpKernelContext* context,
XRTMemoryManager* memory_manager, XRTMemoryManager* memory_manager,
RefPtr<XRTTupleAllocation> output_tuple, RefPtr<XRTTupleAllocation> output_tuple,

View File

@ -23,6 +23,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla.pb.h"
@ -69,6 +71,25 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs( xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
OpKernelContext* context, const char* input_name); OpKernelContext* context, const char* input_name);
bool InputShapeMatches(const xla::Shape& parameter_shape,
const xla::Shape& input_shape);
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
const std::vector<InputCoords>& input_coords,
XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
int64 num_input_shapes,
const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs);
Status RebuildOutputAliases(
const RefPtr<XRTTupleAllocation>& output_tuple,
absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
const xla::HloInputOutputAliasConfig& input_output_alias);
xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
const std::vector<bool>& input_is_dynamic, bool release_inputs);
// Create the XRT execute output tensor given the computation result // Create the XRT execute output tensor given the computation result
// (output_tuple). The return_exploded_tuple tells whether a tuple result should // (output_tuple). The return_exploded_tuple tells whether a tuple result should
// be returned as vector of handles representing each tuple child. // be returned as vector of handles representing each tuple child.