diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index df1f5fa1432..3f6019a059a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -244,7 +244,8 @@ Status TransferManager::WriteTupleIndexTablesAsync( return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { - if (device_subshape.IsTuple()) { + if (device_subshape.IsTuple() && + ShapeUtil::TupleElementCount(device_subshape) > 0) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); @@ -268,6 +269,9 @@ Status TransferManager::WriteTupleIndexTablesAsync( Status TransferManager::WriteRootTupleIndexTable( se::Stream* stream, const ShapedBuffer& device_buffer) { TF_RET_CHECK(device_buffer.on_device_shape().IsTuple()); + if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) { + return Status::OK(); + } se::DeviceMemoryBase device_memory = device_buffer.buffer({}); TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) == device_memory.size());