[TPU] Don't pass host_shapes across the TPU API boundary. They can be computed from device shapes.
PiperOrigin-RevId: 335996516 Change-Id: I702ff81ce311042399ef87bddb5e0d68b7464331
This commit is contained in:
parent
cc33e1f05b
commit
6a7391ca02
@ -153,6 +153,9 @@ class ExecutionOutput {
|
|||||||
std::vector<se::OwningDeviceMemory> to_be_released)
|
std::vector<se::OwningDeviceMemory> to_be_released)
|
||||||
: result_(std::move(result)),
|
: result_(std::move(result)),
|
||||||
to_be_released_(std::move(to_be_released)) {}
|
to_be_released_(std::move(to_be_released)) {}
|
||||||
|
ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
|
||||||
|
int device_ordinal)
|
||||||
|
: result_(std::move(on_device_shape), allocator, device_ordinal) {}
|
||||||
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
|
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
|
||||||
se::DeviceMemoryAllocator* allocator, int device_ordinal)
|
se::DeviceMemoryAllocator* allocator, int device_ordinal)
|
||||||
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
|
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
|
||||||
|
@ -119,7 +119,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
|
ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
|
||||||
ApiConverter::ToC(arg.host_shape(), &se_args[i]->host_shape);
|
|
||||||
const auto& unowned_indices = arg.unowned_indices();
|
const auto& unowned_indices = arg.unowned_indices();
|
||||||
se_args[i]->unowned_indices_size = unowned_indices.size();
|
se_args[i]->unowned_indices_size = unowned_indices.size();
|
||||||
se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
|
se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
|
||||||
@ -142,7 +141,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
|||||||
for (int i = 0; i < arguments.size(); ++i) {
|
for (int i = 0; i < arguments.size(); ++i) {
|
||||||
ApiConverter::Free(&se_args[i]->shape_tree.shape);
|
ApiConverter::Free(&se_args[i]->shape_tree.shape);
|
||||||
ApiConverter::Free(&se_args[i]->dynamic_shape);
|
ApiConverter::Free(&se_args[i]->dynamic_shape);
|
||||||
ApiConverter::Free(&se_args[i]->host_shape);
|
|
||||||
delete[] se_args[i]->unowned_indices;
|
delete[] se_args[i]->unowned_indices;
|
||||||
delete[] se_args[i]->shape_tree.buffers;
|
delete[] se_args[i]->shape_tree.buffers;
|
||||||
delete se_args[i];
|
delete se_args[i];
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
namespace ApiConverter {
|
namespace ApiConverter {
|
||||||
|
|
||||||
xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
|
xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
|
||||||
xla::Shape xla_on_host_shape = ApiConverter::FromC(&c_buffer->on_host_shape);
|
|
||||||
xla::Shape xla_on_device_shape =
|
xla::Shape xla_on_device_shape =
|
||||||
ApiConverter::FromC(&c_buffer->on_device_shape);
|
ApiConverter::FromC(&c_buffer->on_device_shape);
|
||||||
|
|
||||||
@ -36,7 +35,7 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
xla::ShapedBuffer xla_shaped_buffer(
|
xla::ShapedBuffer xla_shaped_buffer(
|
||||||
xla_on_host_shape, xla_on_device_shape,
|
xla_on_device_shape,
|
||||||
tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(),
|
tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(),
|
||||||
c_buffer->device_ordinal);
|
c_buffer->device_ordinal);
|
||||||
xla_shaped_buffer.set_buffers(xla_shape_tree);
|
xla_shaped_buffer.set_buffers(xla_shape_tree);
|
||||||
@ -199,7 +198,6 @@ xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
|
void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
|
||||||
ApiConverter::ToC(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
|
|
||||||
ApiConverter::ToC(buffer.on_device_shape(),
|
ApiConverter::ToC(buffer.on_device_shape(),
|
||||||
&c_device_buffer->on_device_shape);
|
&c_device_buffer->on_device_shape);
|
||||||
c_device_buffer->device_ordinal = buffer.device_ordinal();
|
c_device_buffer->device_ordinal = buffer.device_ordinal();
|
||||||
@ -226,7 +224,6 @@ void Free(XLA_Literal* c_literal) {
|
|||||||
|
|
||||||
void Free(XLA_ShapedBuffer* c_buffer) {
|
void Free(XLA_ShapedBuffer* c_buffer) {
|
||||||
ApiConverter::Free(&c_buffer->on_device_shape);
|
ApiConverter::Free(&c_buffer->on_device_shape);
|
||||||
ApiConverter::Free(&c_buffer->on_host_shape);
|
|
||||||
delete[] c_buffer->bases;
|
delete[] c_buffer->bases;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,7 +177,6 @@ typedef struct XLA_Shape {
|
|||||||
|
|
||||||
// Represents a leaf node for a XLA shaped buffer.
|
// Represents a leaf node for a XLA shaped buffer.
|
||||||
typedef struct XLA_ShapedBuffer {
|
typedef struct XLA_ShapedBuffer {
|
||||||
XLA_Shape on_host_shape;
|
|
||||||
XLA_Shape on_device_shape;
|
XLA_Shape on_device_shape;
|
||||||
int device_ordinal;
|
int device_ordinal;
|
||||||
|
|
||||||
@ -208,7 +207,6 @@ typedef struct SE_ExecutionInput {
|
|||||||
XLA_ShapeIndex* unowned_indices;
|
XLA_ShapeIndex* unowned_indices;
|
||||||
int unowned_indices_size;
|
int unowned_indices_size;
|
||||||
XLA_Shape dynamic_shape;
|
XLA_Shape dynamic_shape;
|
||||||
XLA_Shape host_shape;
|
|
||||||
} SE_ExecutionInput;
|
} SE_ExecutionInput;
|
||||||
|
|
||||||
typedef struct SE_ExecutionOutput {
|
typedef struct SE_ExecutionOutput {
|
||||||
|
@ -90,8 +90,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutionOutput result(host_shape, std::move(device_shape), allocator,
|
ExecutionOutput result(std::move(device_shape), allocator, device_ordinal);
|
||||||
device_ordinal);
|
|
||||||
// Iterate through and allocate a buffer for each shape index, checking for
|
// Iterate through and allocate a buffer for each shape index, checking for
|
||||||
// possible input buffer reuse.
|
// possible input buffer reuse.
|
||||||
int64 reused_buffer_bytes = 0;
|
int64 reused_buffer_bytes = 0;
|
||||||
|
Loading…
Reference in New Issue
Block a user