Make XRT CPU/GPU use MaybeOwning buffer interface, so the new copy protection CL won't break aliasing.
PiperOrigin-RevId: 317700747
Change-Id: Ie7b5bb1989cd4359b30ad86a450de5bff0962c31
Parent: 79518facb4
Commit: 44067f0783
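For readers skimming the diff below: the change moves XRT's CPU/GPU execution path off raw ShapedBuffer pointers and onto xla::ExecutionInput trees of MaybeOwningDeviceMemory, so buffers that are merely aliased from an XRTTupleAllocation stay unowned while reallocated dynamic-shape buffers become owned. The sketch below is illustrative only, not part of the commit; the helper name and shape indices are hypothetical, and it assumes the ExecutionInput constructors and setters shown in the executable.h hunk further down.

// Hypothetical helper, for illustration only -- not part of this commit.
xla::ExecutionInput MakeInput(const xla::Shape& on_device_shape,
                              se::DeviceMemoryBase aliased,
                              se::OwningDeviceMemory reallocated) {
  xla::ExecutionInput input(
      xla::ShapeTree<xla::MaybeOwningDeviceMemory>(on_device_shape));
  // Unowned buffer: ownership stays with the XRT allocation. The index is
  // recorded in unowned_indices_, so ~ExecutionInput() releases (rather than
  // frees) it if enqueuing the computation fails.
  input.SetUnownedBuffer(/*index=*/{0},
                         xla::MaybeOwningDeviceMemory(aliased));
  // Owned buffer: ownership moves into the ExecutionInput, e.g. a buffer
  // reallocated by UpdateDynamicInputs(); XLA may donate or free it.
  input.SetBuffer(/*index=*/{1},
                  xla::MaybeOwningDeviceMemory(std::move(reallocated)));
  return input;
}

In the XRT kernels changed below, GetArgumentsBuffers() makes this owned/unowned decision per parameter through XRTTupleAllocation::ToExecutionInput(), aliasing only buffers the caller explicitly marked as user aliases.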
@@ -299,12 +299,11 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
       const Shape& actual_shape = arguments[i].Buffers().shape();
-      CHECK(
-          Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
-          << absl::StreamFormat(
-                 "Shape mismatch on argument %d. Expected %s, but was %s.", i,
-                 expected_shape.ToString(/*print_layout=*/true),
-                 actual_shape.ToString(/*print_layout=*/true));
+      TF_RET_CHECK(
+          ShapeUtil::DynamicShapeIsCompatible(actual_shape, expected_shape))
+          << "Shape mismatch on argument " << i << ", "
+          << expected_shape.ToString(/*print_layout=*/true) << " vs. "
+          << actual_shape.ToString(/*print_layout=*/true);
     }
   }
 
@@ -28,10 +28,57 @@ limitations under the License.
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/stream_executor/device_description.h"
 
 namespace xla {
 
+ExecutionInput::~ExecutionInput() {
+  for (auto& index : unowned_indices_) {
+    auto buffer = buffers_.mutable_element(index)->Release();
+    if (buffer) {
+      buffer->Release();
+    }
+  }
+}
+
+Status ExecutionInput::SetDynamicShape(Shape dynamic_shape) {
+  const Shape& input_shape = shape();
+  if (!ShapeUtil::DynamicShapeIsCompatible(input_shape, dynamic_shape)) {
+    return tensorflow::errors::InvalidArgument(
+        "Cannot set dynamic shape: ", input_shape.DebugString(), " vs. ",
+        dynamic_shape.DebugString());
+  }
+  dynamic_shape_ = absl::make_unique<Shape>(std::move(dynamic_shape));
+  return Status::OK();
+}
+
+void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
+                                      MaybeOwningDeviceMemory buffer) {
+  *buffers_.mutable_element(index) = std::move(buffer);
+  unowned_indices_.insert(index);
+}
+
+xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
+    se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
+  const Shape& input_shape = shape();
+  xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
+                                  allocator->platform(), device_ordinal);
+  for (const auto& index_buffer : Buffers()) {
+    const tensorflow::se::OwningDeviceMemory* mem =
+        index_buffer.second.AsOwningDeviceMemory();
+    if (mem != nullptr && (mem->allocator() != allocator ||
+                           mem->device_ordinal() != device_ordinal)) {
+      return tensorflow::errors::InvalidArgument(
+          "Device buffer at index ", index_buffer.first.ToString(),
+          " has mismatching allocator/device");
+    }
+    shaped_buffer.set_buffer(index_buffer.second.AsDeviceMemoryBase(),
+                             index_buffer.first);
+  }
+  return std::move(shaped_buffer);
+}
+
 StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_
 
 #include <memory>
+#include <set>
 #include <utility>
 #include <vector>
 
@@ -65,31 +66,32 @@ class ExecutionInput {
       : buffers_(std::move(buffers)) {}
   ExecutionInput(ExecutionInput&&) = default;
 
-  ~ExecutionInput() {
-    for (auto& index : unowned_indices_) {
-      auto buffer = buffers_.mutable_element(index)->Release();
-      if (buffer) {
-        buffer->Release();
-      }
-    }
-  }
+  ~ExecutionInput();
 
   ExecutionInput& operator=(ExecutionInput&&) = default;
 
-  const Shape& shape() const { return buffers_.shape(); }
+  const Shape& shape() const {
+    return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape();
+  }
+
+  Status SetDynamicShape(Shape dynamic_shape);
+
+  xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer(
+      se::DeviceMemoryAllocator* allocator, int device_ordinal) const;
 
   void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
     *buffers_.mutable_element(index) = std::move(buffer);
   }
 
   void SetUnownedBuffer(const ShapeIndex& index,
-                        MaybeOwningDeviceMemory buffer) {
-    *buffers_.mutable_element(index) = std::move(buffer);
-    unowned_indices_.push_back(index);
-  }
+                        MaybeOwningDeviceMemory buffer);
 
   void SetUnownedIndex(const ShapeIndex& index) {
-    unowned_indices_.push_back(index);
+    unowned_indices_.insert(index);
+  }
+
+  void ClearUnownedIndex(const ShapeIndex& index) {
+    unowned_indices_.erase(index);
   }
 
   const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
@@ -106,9 +108,10 @@ class ExecutionInput {
 
  private:
   ShapeTree<MaybeOwningDeviceMemory> buffers_;
-  // (Unordered) set of indices of buffers that should be returned to the
-  // caller if an error occurs when enqueuing the computation.
-  std::vector<ShapeIndex> unowned_indices_;
+  // Set of indices of buffers that should be returned to the caller if an error
+  // occurs when enqueuing the computation.
+  std::set<ShapeIndex> unowned_indices_;
+  std::unique_ptr<Shape> dynamic_shape_;
 };
 
 // ExecutionOutput encapsulates the output buffers of a execution and the
@@ -145,7 +148,6 @@ class ExecutionOutput {
     to_be_released_.push_back(std::move(mem));
   }
 
-
   // Should be called once it is known that the execute operation succeeded,
   // before returning the ExecutionOutput to the caller.
   ExecutionOutput& Commit() {
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 
 #include "absl/types/variant.h"
 
 namespace xla {
 
 tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase()
@@ -38,4 +40,10 @@ MaybeOwningDeviceMemory::Release() {
   return std::move(absl::get<tensorflow::se::OwningDeviceMemory>(mem_));
 }
 
+const tensorflow::se::OwningDeviceMemory*
+MaybeOwningDeviceMemory::AsOwningDeviceMemory() const {
+  return HasOwnership() ? &absl::get<tensorflow::se::OwningDeviceMemory>(mem_)
+                        : nullptr;
+}
+
 }  // namespace xla
@@ -57,6 +57,10 @@ class MaybeOwningDeviceMemory {
   // A nullopt is returned if the HasOwnership() == false;
   absl::optional<tensorflow::se::OwningDeviceMemory> Release();
 
+  // If the device memory is owned, returns a pointer to the internal
+  // OwningDeviceMemory, otherwise nullptr is returned.
+  const tensorflow::se::OwningDeviceMemory* AsOwningDeviceMemory() const;
+
   // Returns true if the device_memory has ownership over underlying memory.
   bool HasOwnership() const;
 
@@ -1461,7 +1461,7 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
   return shape;
 }
 
-/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
+/* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
   if (dynamic_shape.rank() != bounded_shape.rank()) {
     return false;
@@ -1474,6 +1474,36 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
   return true;
 }
 
+/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
+    const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
+  bool compatible = true;
+  xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
+                                                     const ShapeIndex& index) {
+    if (compatible) {
+      auto subshape_result = TryGetSubshape(bounded_shape, index);
+      if (subshape_result.ok()) {
+        const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
+        if (sub_shape.IsTuple()) {
+          if (!bounded_sub_shape->IsTuple()) {
+            compatible = false;
+          }
+        } else {
+          if (bounded_sub_shape->IsTuple()) {
+            compatible = false;
+          } else if (!sub_shape.is_static() &&
+                     !DynamicArrayShapeIsCompatible(sub_shape,
+                                                    *bounded_sub_shape)) {
+            compatible = false;
+          }
+        }
+      } else {
+        compatible = false;
+      }
+    }
+  });
+  return compatible;
+}
+
 /* static */ Shape ShapeUtil::FilterDimensions(
     const std::function<bool(int64)>& p, Shape shape) {
   CHECK(shape.IsArray());
@@ -657,7 +657,11 @@ class ShapeUtil {
       Shape shape);
 
   // Returns true if `dynamic_shape` has dimensions that are less-equal to the
-  // "bounded_shape".
+  // "bounded_shape". Shapes must be arrays.
+  static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
+                                            const xla::Shape& bounded_shape);
+
+  // Same as DynamicArrayShapeIsCompatible() but supports tuples.
   static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                                        const xla::Shape& bounded_shape);
 
@@ -74,6 +74,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:backend",
         "//tensorflow/compiler/xla/service:executable",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@@ -51,12 +51,6 @@ namespace tensorflow {
 
 namespace {
 
-struct InputBuffers {
-  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
-  std::vector<xla::ShapedBuffer> input_allocations;
-  std::vector<xla::ShapedBuffer*> input_pointers;
-};
-
 uint32 InitialRandomSeed() {
   // Support plumbing the TF seed through to XLA is being worked on.
   // If a user wants deterministic behavior, their best option
@@ -80,75 +74,51 @@ uint32 GetXLARandomSeed() {
   return counter.fetch_add(2);
 }
 
-xla::StatusOr<InputBuffers> GetInputBuffers(
-    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
-    const std::vector<InputCoords>& input_coords, bool release_inputs) {
-  InputBuffers input_buffers;
-  input_buffers.input_tuples.reserve(input_coords.size());
-  input_buffers.input_allocations.reserve(input_coords.size());
-  input_buffers.input_pointers.reserve(input_coords.size());
-  for (size_t i = 0; i < input_coords.size(); ++i) {
-    TF_RETURN_IF_ERROR(
-        working_set->LookupAndPin(backend, input_coords[i].handle));
-    auto tuple = working_set->PinnedTuples().back();
-    input_buffers.input_tuples.emplace_back(tuple);
-    if (release_inputs) {
-      // We are holding a reference to the tuple, so we can safely delete it
-      // from the resource manager here.
-      TF_RETURN_IF_ERROR(
-          working_set->MemoryManager()->Release(input_coords[i].handle));
-      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
-    }
-    if (input_coords[i].index.empty()) {
-      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
-                          tuple->ToShapedBuffer());
-      input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
-    } else {
-      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
-                          tuple->ToShapedBuffer());
-      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer,
-                          shaped_buffer.SubShapedBuffer(input_coords[i].index));
-      input_buffers.input_allocations.emplace_back(
-          std::move(sub_shaped_buffer));
-    }
-  }
-  for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) {
-    input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]);
-  }
-  return std::move(input_buffers);
+std::vector<bool> GetDynamicInputInfo(
+    const xla::ComputationLayout& computation_layout) {
+  std::vector<bool> input_is_dynamic;
+  input_is_dynamic.reserve(computation_layout.parameter_count());
+  for (int64 i = 0; i < computation_layout.parameter_count(); ++i) {
+    input_is_dynamic.push_back(
+        !computation_layout.parameter_shape(i).is_static());
+  }
+  return input_is_dynamic;
 }
 
-xla::StatusOr<InputBuffers> GetChainedOpInputs(
+xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTuples(
+    xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set,
+    xla::Backend* backend, const std::vector<InputCoords>& input_coords,
+    bool release_inputs) {
+  const xla::ComputationLayout& computation_layout =
+      executable->executable()->module_config().entry_computation_layout();
+
+  return GetInputTupleAllocations(
+      input_coords, working_set, backend, computation_layout.parameter_count(),
+      [&](int64 i) { return computation_layout.parameter_shape(i); },
+      release_inputs);
+}
+
+xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputTuples(
     const xrt::XRTChainedExecuteOp& op,
     absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) {
-  InputBuffers input_buffers;
-  input_buffers.input_tuples.reserve(op.inputs_size());
-  input_buffers.input_allocations.reserve(op.inputs_size());
-  input_buffers.input_pointers.reserve(op.inputs_size());
+  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
+  input_tuples.reserve(op.inputs_size());
   for (int i = 0; i < op.inputs_size(); ++i) {
     auto& input = op.inputs(i);
-    input_buffers.input_tuples.emplace_back(op_inputs[i]);
     // Thanks to the greatness of proto3, there is no way to query for
     // explicitly set fields, so the default for output_index (zero) means no
     // sub-index. As consequence, the real index is output_index - 1.
     if (input.output_index() == 0) {
-      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
-                          input_buffers.input_tuples.back()->ToShapedBuffer());
-      input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
+      input_tuples.emplace_back(op_inputs[i]);
     } else {
-      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
-                          input_buffers.input_tuples.back()->ToShapedBuffer());
-      TF_ASSIGN_OR_RETURN(
-          xla::ShapedBuffer sub_shaped_buffer,
-          shaped_buffer.SubShapedBuffer({input.output_index() - 1}));
-      input_buffers.input_allocations.emplace_back(
-          std::move(sub_shaped_buffer));
+      XRTTupleAllocation* sub_tuple;
+      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+          op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple,
+          /*alias_parent_allocation=*/true));
+      input_tuples.emplace_back(sub_tuple);
     }
   }
-  for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) {
-    input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]);
-  }
-  return std::move(input_buffers);
+  return input_tuples;
 }
 
 // Given a shape, returns a byte array representing the shape metadata of the
@@ -228,12 +198,11 @@ Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
 // As we can't expand the size of an existing memory allocation, a reallocation
 // is required. A list of new allocations are returned after this function. The
 // caller is reponsible for maintaining those allocations.
-xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs(
+Status UpdateDynamicInputs(
     se::Stream* stream, se::DeviceMemoryAllocator* allocator,
-    std::vector<xla::ShapedBuffer*> runtime_inputs,
+    std::vector<xla::ExecutionInput>* execution_inputs,
     const std::vector<xla::ShapeLayout>& compile_time_shapes) {
-  std::vector<se::OwningDeviceMemory> new_allocations;
-  TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size());
+  TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size());
   TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                          stream->parent()->platform()));
   auto shape_size_fn = compiler->ShapeSizeBytesFunction();
@@ -242,57 +211,61 @@ xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs(
     if (compile_time_shape.is_static()) {
       continue;
     }
-    auto* runtime_input = runtime_inputs[i];
+    xla::ExecutionInput* execution_input = &(*execution_inputs)[i];
 
     bool element_modified = false;
     TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
         compile_time_shape,
-        [&](const xla::Shape& compile_time_shape,
+        [&](const xla::Shape& sub_shape,
             const xla::ShapeIndex& index) -> Status {
-          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
+          if (sub_shape.IsTuple() || sub_shape.is_static()) {
             return Status::OK();
           }
-          const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape(
-              runtime_input->on_device_shape(), index);
-          TF_RET_CHECK(!runtime_shape.IsTuple());
-          TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(
-              runtime_shape, compile_time_shape));
-          se::DeviceMemoryBase* static_input =
-              runtime_input->buffers().mutable_element(index);
           TF_ASSIGN_OR_RETURN(
-              auto dynamic_input,
+              const xla::Shape* runtime_shape,
+              xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index));
+          TF_RET_CHECK(!runtime_shape->IsTuple());
+          TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible(
+              *runtime_shape, sub_shape));
+          TF_ASSIGN_OR_RETURN(
+              se::OwningDeviceMemory dynamic_input,
               allocator->Allocate(stream->parent()->device_ordinal(),
-                                  shape_size_fn(compile_time_shape)));
-          new_allocations.emplace_back(std::move(dynamic_input));
-          se::DeviceMemory<uint8>* dynamic_input_base =
-              new_allocations.back().ptr();
+                                  shape_size_fn(sub_shape)));
+          se::DeviceMemoryBase static_input =
+              execution_input->Buffer(index).AsDeviceMemoryBase();
+          se::DeviceMemory<uint8>* dynamic_input_base = dynamic_input.ptr();
           // Send the original data to the new location.
-          stream->ThenMemcpyD2D(dynamic_input_base, *static_input,
-                                static_input->size());
+          stream->ThenMemcpyD2D(dynamic_input_base, static_input,
+                                static_input.size());
           TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
-                                            compile_time_shape, runtime_shape));
+                                            sub_shape, *runtime_shape));
           // Modify the memory location in the input shape tree to point to the
           // new input.
-          runtime_input->set_buffer(*dynamic_input_base, index);
+          execution_input->SetBuffer(
+              index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input)));
+          execution_input->ClearUnownedIndex(index);
           element_modified = true;
           return Status::OK();
         }));
     if (element_modified) {
-      runtime_input->set_shapes(compile_time_shape, compile_time_shape);
+      TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape));
+      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
+                          execution_input->ToShapedBuffer(
+                              allocator, stream->parent()->device_ordinal()));
      // The input location has been modified, need to fix tuple table to
      // point to the correct address.
       TF_ASSIGN_OR_RETURN(
           auto transfer_manager,
           xla::TransferManager::GetForPlatform(stream->parent()->platform()));
       TF_RETURN_IF_ERROR(
-          transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input));
+          transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer));
     }
   }
-  return std::move(new_allocations);
+  return Status::OK();
 }
 
 xla::StatusOr<xla::Literal> ReadMetadataLiteral(
-    se::Stream* stream, se::DeviceMemoryBase* buffer,
+    se::Stream* stream, se::DeviceMemoryBase buffer,
     const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
   TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                          stream->parent()->platform()));
@@ -302,7 +275,7 @@ xla::StatusOr<xla::Literal> ReadMetadataLiteral(
   const int64 offset = shape_size_fn(buffer_shape_static);
   int64 metadata_size = shape_size_fn(buffer_shape) - offset;
   TF_RET_CHECK(metadata_size != 0);
-  auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
+  auto buffer_8 = se::DeviceMemory<uint8>(buffer);
   auto metadata_buffer =
       stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
   return transfer_manager->TransferArrayFromDevice(
@@ -315,7 +288,7 @@ xla::StatusOr<xla::Literal> ReadMetadataLiteral(
 // dimension sizes from the metadata, and update output shapes. The result shape
 // is a static and concrete shape.
 xla::Status UpdateDynamicOutputs(se::Stream* stream,
-                                 xla::ShapedBuffer* shaped_buffer,
+                                 const xla::ShapedBuffer& shaped_buffer,
                                  xla::Shape* output_host_shape,
                                  xla::Shape* output_device_shape) {
   DCHECK(output_device_shape->is_dynamic());
@@ -323,8 +296,8 @@ xla::Status UpdateDynamicOutputs(se::Stream* stream,
       auto transfer_manager,
       xla::TransferManager::GetForPlatform(stream->parent()->platform()));
   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-  TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus(
-      [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+  TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
+      [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
         const xla::Shape& buffer_shape =
             xla::ShapeUtil::GetSubshape(*output_device_shape, index);
         if (buffer_shape.IsTuple()) {
@@ -352,19 +325,18 @@ xla::Status UpdateDynamicOutputs(se::Stream* stream,
   return Status::OK();
 }
 
-// Create output tuple from run_result.
 xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
-    se::Stream* stream, xla::ScopedShapedBuffer run_result,
-    xla::Backend* backend, int device_ordinal) {
+    se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
+    int device_ordinal) {
   XRTTupleAllocation* output_tuple;
-  xla::ShapedBuffer shaped_buffer = run_result.release();
+  const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
   if (shaped_buffer.on_device_shape().is_dynamic()) {
     // Update dynamic shapes from output buffer, and create a XRT tensor with
     // dimension sizes read from metadata.
     xla::Shape output_host_shape = shaped_buffer.on_host_shape();
     xla::Shape output_device_shape = shaped_buffer.on_device_shape();
     TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
-        stream, &shaped_buffer, &output_host_shape, &output_device_shape));
+        stream, shaped_buffer, &output_host_shape, &output_device_shape));
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
         shaped_buffer, output_host_shape, output_device_shape, backend,
         device_ordinal, &output_tuple));
@@ -373,15 +345,27 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
         shaped_buffer, backend, device_ordinal, &output_tuple));
   }
+  // After the output tuple is created, we can release the output result
+  // buffers, to make sure they won't be cleared by its destructor.
+  (void)run_result.ConsumeResult().release();
   return RefPtr<XRTTupleAllocation>(output_tuple);
 }
 
 xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
     OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
-    xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed,
+    xla::LocalExecutable* executable,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    bool release_inputs, se::Stream* stream, int rng_seed,
     const xrt::CommonExecutionConfig& config) {
-  VLOG(2) << "Executing computation.";
+  const xla::ComputationLayout& computation_layout =
+      executable->executable()->module_config().entry_computation_layout();
+  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(computation_layout);
+  TF_ASSIGN_OR_RETURN(
+      std::vector<xla::ExecutionInput> execution_inputs,
+      GetArgumentsBuffers(
+          executable->executable()->module().input_output_alias_config(),
+          input_tuples, input_is_dynamic, release_inputs));
+
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(device_ref->backend()->memory_allocator());
@@ -419,51 +403,28 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
   }
   run_options.set_gpu_executable_run_options(&gpu_options);
 
-  Env* env = Env::Default();
-  auto start_time = env->NowMicros();
   const std::vector<xla::ShapeLayout>& shape_layouts =
       executable->executable()
           ->module_config()
          .entry_computation_layout()
          .parameter_layouts();
-  TF_ASSIGN_OR_RETURN(auto new_allocations,
-                      UpdateDynamicInputs(stream, run_options.allocator(),
-                                          input_buffers.input_pointers,
-                                          shape_layouts));
-  auto new_allocations_ptr =
-      std::make_shared<std::vector<se::OwningDeviceMemory>>(
-          std::move(new_allocations));
+  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(),
+                                         &execution_inputs, shape_layouts));
   TF_ASSIGN_OR_RETURN(
-      xla::ScopedShapedBuffer run_result,
-      executable->Run(input_buffers.input_pointers, run_options));
-  // Retain the new allocation for input memory until the end of execution.
-  stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); });
-
-  auto elapsed = env->NowMicros() - start_time;
-  VLOG(2) << "Elapsed time: " << elapsed << "us";
+      xla::ExecutionOutput run_result,
+      executable->Run(std::move(execution_inputs), run_options));
 
   TF_ASSIGN_OR_RETURN(
       RefPtr<XRTTupleAllocation> output_tuple_ptr,
       CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
                         device_ref->device_ordinal()));
 
   // The ScopedShapedBuffer returned by the executable Run() API, in case of
   // input/output buffer aliasing, might have holes in it, which need to be
   // filled using the proper input tuples buffers which are the source of
   // aliasing.
-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
-  auto alias_function =
-      [&](const xla::ShapeIndex& output_index,
-          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
-    TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size());
-    return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
-               ? output_tuple_ptr->AliasBufferFrom(
-                     *input_buffers.input_tuples[alias.parameter_number],
-                     alias.parameter_index, output_index)
-               : Status::OK();
-  };
-  TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function));
+  TF_RETURN_IF_ERROR(RebuildOutputAliases(
+      output_tuple_ptr, input_tuples,
+      executable->executable()->module().input_output_alias_config()));
 
   return std::move(output_tuple_ptr);
 }
@@ -471,12 +432,13 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
 xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     OpKernelContext* context, XRTMemoryManager* memory_manager,
     XRTGenericDeviceAccessor::ScopedRef* device_ref,
-    xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed,
+    xla::LocalExecutable* executable,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    bool release_inputs, se::Stream* stream, int rng_seed,
     const xrt::CommonExecutionConfig& config) {
   auto runfn = [&]() {
-    return RunExecutable(context, device_ref, executable, input_buffers, stream,
-                         rng_seed, config);
+    return RunExecutable(context, device_ref, executable, input_tuples,
+                         release_inputs, stream, rng_seed, config);
   };
 
   // We pass zero as requested_free_size as there is no simple way to get the
@@ -495,12 +457,13 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     se::Stream* stream, int rng_seed,
     const xrt::CommonExecutionConfig& config) {
   XRTMemoryManager::WorkingSet working_set(memory_manager);
-  TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
-                      GetInputBuffers(&working_set, device_ref->backend(),
+  TF_ASSIGN_OR_RETURN(
+      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
+      GetInputTuples(executable, &working_set, device_ref->backend(),
                      input_coords, release_inputs));
   return ExecuteComputation(context, memory_manager.get(), device_ref,
-                            executable, input_buffers, stream, rng_seed,
-                            config);
+                            executable, input_tuples, release_inputs, stream,
+                            rng_seed, config);
 }
 
 // XRTExecuteOp
@@ -653,16 +616,16 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
   auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                         absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
       -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
-    TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
-                        GetChainedOpInputs(op, op_inputs));
-
     std::unique_ptr<XRTCompilationCacheEntryRef> entry;
     TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry));
     xla::LocalExecutable* executable = entry->get().get_executable();
 
-    return ExecuteComputation(context, memory_manager.get(), &device_ref,
-                              executable, input_buffers, stream, rng_seed,
-                              config.common_config());
+    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
+                        GetChainedOpInputTuples(op, op_inputs));
+
+    return ExecuteComputation(
+        context, memory_manager.get(), &device_ref, executable, input_tuples,
+        /*release_inputs=*/false, stream, rng_seed, config.common_config());
   };
 
   return ExecuteChained(context, memory_manager, device_ref.backend(),
@@ -221,6 +221,140 @@ xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
   return std::move(input_coords);
 }
 
+bool InputShapeMatches(const xla::Shape& parameter_shape,
+                       const xla::Shape& input_shape) {
+  auto shape_checker = [&](const xla::Shape& pshape,
+                           const xla::ShapeIndex& index) {
+    if (pshape.IsArray()) {
+      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
+                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
+      if (pshape.rank() != ishape->rank() ||
+          pshape.element_type() != ishape->element_type()) {
+        return errors::InvalidArgument("Mismatching shapes");
+      }
+      if (pshape.is_static() && pshape.layout() != ishape->layout()) {
+        return errors::InvalidArgument("Mismatching layouts");
+      }
+      for (int64 dim = 0; dim < pshape.rank(); ++dim) {
+        if (pshape.is_dynamic_dimension(dim)) {
+          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
+            return errors::InvalidArgument("Mismatching shapes");
+          }
+        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
+          return errors::InvalidArgument("Mismatching shapes");
+        }
+      }
+    }
+    return Status::OK();
+  };
+  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
+                                                   shape_checker)
+      .ok();
+}
+
+xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
+    const std::vector<InputCoords>& input_coords,
+    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
+    int64 num_input_shapes,
+    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs) {
+  if (input_coords.size() != num_input_shapes) {
+    return errors::InvalidArgument(
+        "Number of inputs does not match executable proto input shapes: ",
+        input_coords.size(), " vs. ", num_input_shapes);
+  }
+  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
+  input_tuples.reserve(input_coords.size());
+  for (size_t i = 0; i < input_coords.size(); ++i) {
+    TF_RETURN_IF_ERROR(
+        working_set->LookupAndPin(backend, input_coords[i].handle));
+    auto tuple = working_set->PinnedTuples().back();
+    if (release_inputs) {
+      // We are holding a reference to the tuple, so we can safely delete it
+      // from the resource manager here.
+      TF_RETURN_IF_ERROR(
+          working_set->MemoryManager()->Release(input_coords[i].handle));
+      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
+    }
+    xla::Shape input_shape = shape_getter(i);
+    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
+      return errors::InvalidArgument(
+          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
+          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
+          "; got ", tuple->on_host_shape().DebugString());
+    }
+    if (input_coords[i].index.empty()) {
+      input_tuples.emplace_back(std::move(tuple));
+    } else {
+      XRTTupleAllocation* sub_tuple;
+      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+          tuple.get(), input_coords[i].index, &sub_tuple,
+          /*alias_parent_allocation=*/true));
+      input_tuples.emplace_back(sub_tuple);
+    }
+  }
+  return std::move(input_tuples);
+}
+
+Status RebuildOutputAliases(
+    const RefPtr<XRTTupleAllocation>& output_tuple,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const xla::HloInputOutputAliasConfig& input_output_alias) {
+  auto alias_function =
+      [&](const xla::ShapeIndex& output_index,
+          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
+    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
+    return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
+               ? output_tuple->AliasBufferFrom(
+                     *input_tuples[alias.parameter_number],
+                     alias.parameter_index, output_index)
+               : Status::OK();
+  };
+  return input_output_alias.ForEachAliasWithStatus(alias_function);
+}
+
+xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
+    const xla::HloInputOutputAliasConfig& input_output_alias,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
+  auto is_dynamic = [&](size_t arg) {
+    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
+  };
+  std::vector<xla::ExecutionInput> arguments;
+  // Don't alias dynamic input -- Due to the underlying implementation,
+  // aliased inputs have two owners: XRTAllocation and return value of
+  // this function. If an argument is dynamic and the ownership is
+  // released to output of this function, TPUExecute will free it and
+  // reallocate a new one, which creates a double freeing issue where
+  // XRTAllocation also attempts to release the buffer.
+  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
+                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
+  arguments.reserve(input_tuples.size());
+  for (int64 i = 0; i < input_tuples.size(); ++i) {
+    auto alias_checker =
+        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
+      // Only the buffers which the caller explicitly marked as aliased
+      // (kUserAlias), should create aliases.
+      // The XLA compiler might create opportunistic aliases (kSystemAlias)
+      // which need a different handling. With a system alias we know that XLA
+      // is going to reuse a given input parameter buffer for a given output, so
+      // unless it is known at call site that the input buffer has no more uses,
+      // a copy needs to be made at call site. With user specified alias the
+      // caller tells us that he expects a given output to land over the buffers
+      // of a given parametter.
      if (input_output_alias.ParameterAliasKind(i, index) ==
          xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
+    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
+                        input_tuples[i]->ToExecutionInput(alias_checker));
+    arguments.emplace_back(std::move(exec_input));
+  }
+  return std::move(arguments);
+}
+
 Status CreateExecuteOutput(OpKernelContext* context,
                            XRTMemoryManager* memory_manager,
                            RefPtr<XRTTupleAllocation> output_tuple,
@@ -23,6 +23,8 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
@@ -69,6 +71,25 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);
 xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
     OpKernelContext* context, const char* input_name);
 
+bool InputShapeMatches(const xla::Shape& parameter_shape,
+                       const xla::Shape& input_shape);
+
+xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
+    const std::vector<InputCoords>& input_coords,
+    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
+    int64 num_input_shapes,
+    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs);
+
+Status RebuildOutputAliases(
+    const RefPtr<XRTTupleAllocation>& output_tuple,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const xla::HloInputOutputAliasConfig& input_output_alias);
+
+xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
+    const xla::HloInputOutputAliasConfig& input_output_alias,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const std::vector<bool>& input_is_dynamic, bool release_inputs);
+
 // Create the XRT execute output tensor given the computation result
 // (output_tuple). The return_exploded_tuple tells whether a tuple result should
 // be returned as vector of handles representing each tuple child.