[TPU] Don't pass host_shapes across the TPU API boundary. They can be computed from device shapes.
PiperOrigin-RevId: 335996516 Change-Id: I702ff81ce311042399ef87bddb5e0d68b7464331
This commit is contained in:
parent
cc33e1f05b
commit
6a7391ca02
@ -153,6 +153,9 @@ class ExecutionOutput {
|
||||
std::vector<se::OwningDeviceMemory> to_be_released)
|
||||
: result_(std::move(result)),
|
||||
to_be_released_(std::move(to_be_released)) {}
|
||||
// Constructs an ExecutionOutput from a device shape only: result_ is
// initialized with the (moved) on_device_shape, the allocator, and the device
// ordinal. Unlike the overload below, no separate host shape is supplied.
ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal)
|
||||
: result_(std::move(on_device_shape), allocator, device_ordinal) {}
|
||||
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
|
||||
se::DeviceMemoryAllocator* allocator, int device_ordinal)
|
||||
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
|
||||
|
@ -119,7 +119,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
||||
}
|
||||
|
||||
ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
|
||||
ApiConverter::ToC(arg.host_shape(), &se_args[i]->host_shape);
|
||||
const auto& unowned_indices = arg.unowned_indices();
|
||||
se_args[i]->unowned_indices_size = unowned_indices.size();
|
||||
se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
|
||||
@ -142,7 +141,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
||||
for (int i = 0; i < arguments.size(); ++i) {
|
||||
ApiConverter::Free(&se_args[i]->shape_tree.shape);
|
||||
ApiConverter::Free(&se_args[i]->dynamic_shape);
|
||||
ApiConverter::Free(&se_args[i]->host_shape);
|
||||
delete[] se_args[i]->unowned_indices;
|
||||
delete[] se_args[i]->shape_tree.buffers;
|
||||
delete se_args[i];
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
namespace ApiConverter {
|
||||
|
||||
xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
|
||||
xla::Shape xla_on_host_shape = ApiConverter::FromC(&c_buffer->on_host_shape);
|
||||
xla::Shape xla_on_device_shape =
|
||||
ApiConverter::FromC(&c_buffer->on_device_shape);
|
||||
|
||||
@ -36,7 +35,7 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
|
||||
}
|
||||
|
||||
xla::ShapedBuffer xla_shaped_buffer(
|
||||
xla_on_host_shape, xla_on_device_shape,
|
||||
xla_on_device_shape,
|
||||
tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(),
|
||||
c_buffer->device_ordinal);
|
||||
xla_shaped_buffer.set_buffers(xla_shape_tree);
|
||||
@ -199,7 +198,6 @@ xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal) {
|
||||
}
|
||||
|
||||
void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
|
||||
ApiConverter::ToC(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
|
||||
ApiConverter::ToC(buffer.on_device_shape(),
|
||||
&c_device_buffer->on_device_shape);
|
||||
c_device_buffer->device_ordinal = buffer.device_ordinal();
|
||||
@ -226,7 +224,6 @@ void Free(XLA_Literal* c_literal) {
|
||||
|
||||
// Releases the resources held inside a C-API XLA_ShapedBuffer: the shape
// members are released through ApiConverter::Free and the `bases` array is
// released with delete[]. The XLA_ShapedBuffer struct itself is not deleted
// here.
// NOTE(review): this text comes from a diff view with the +/- markers lost;
// the hunk header (7 lines -> 6 lines) suggests the on_host_shape line is the
// one removed by this commit -- confirm against the actual file.
void Free(XLA_ShapedBuffer* c_buffer) {
|
||||
ApiConverter::Free(&c_buffer->on_device_shape);
|
||||
ApiConverter::Free(&c_buffer->on_host_shape);
|
||||
delete[] c_buffer->bases;
|
||||
}
|
||||
|
||||
|
@ -177,7 +177,6 @@ typedef struct XLA_Shape {
|
||||
|
||||
// Represents a leaf node for a XLA shaped buffer.
|
||||
typedef struct XLA_ShapedBuffer {
|
||||
XLA_Shape on_host_shape;
|
||||
XLA_Shape on_device_shape;
|
||||
int device_ordinal;
|
||||
|
||||
@ -208,7 +207,6 @@ typedef struct SE_ExecutionInput {
|
||||
XLA_ShapeIndex* unowned_indices;
|
||||
int unowned_indices_size;
|
||||
XLA_Shape dynamic_shape;
|
||||
XLA_Shape host_shape;
|
||||
} SE_ExecutionInput;
|
||||
|
||||
typedef struct SE_ExecutionOutput {
|
||||
|
@ -90,8 +90,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionOutput result(host_shape, std::move(device_shape), allocator,
|
||||
device_ordinal);
|
||||
ExecutionOutput result(std::move(device_shape), allocator, device_ordinal);
|
||||
// Iterate through and allocate a buffer for each shape index, checking for
|
||||
// possible input buffer reuse.
|
||||
int64 reused_buffer_bytes = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user