Add an overload to check device memory access semantics, akin to CanShapedBufferBeAccessedNow but for a single DeviceMemoryBase.

PiperOrigin-RevId: 307053898
Change-Id: If79426ecd058eae75348f52e1c271702a6247fe1
This commit is contained in:
A. Unique TensorFlower 2020-04-17 09:05:33 -07:00 committed by TensorFlower Gardener
parent a2288e4696
commit 5115a02054
4 changed files with 18 additions and 2 deletions

View File

@ -1175,8 +1175,8 @@ StatusOr<TupleHandle> MakeTupleHelper(
LocalDeviceState::kComputeSynchronized) {
stream->ThenWaitFor(local_device->compute_stream());
} else {
// In principle we would do a DCHECK for CanShapedBufferBeAccessedNow here
// but that call requires a ShapedBuffer which we don't have.
DCHECK(transfer_manager->CanBufferBeAccessedNow(
local_device->compute_stream()->parent(), root_table_memory.cref()));
}
ExecutionInput execution_input(on_device_shape);

View File

@ -1122,6 +1122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
namespace xla {
@ -50,6 +51,12 @@ class CpuTransferManager : public GenericTransferManager {
return true;
}
// Single-buffer access check for this transfer manager: reports that
// `device_buffer` can be accessed immediately. The answer is unconditionally
// true; neither `executor` nor `device_buffer` is inspected.
// NOTE(review): presumably this holds because buffers on this platform live
// in host memory -- confirm against the ShapedBuffer variant of this check.
bool CanBufferBeAccessedNow(
se::StreamExecutor* executor,
const se::DeviceMemoryBase& device_buffer) const override {
return true;
}
private:
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source);

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
namespace xla {
@ -256,6 +257,13 @@ class TransferManager {
return false;
}
// Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
//
// `executor` is the stream executor the query is made for and
// `device_buffer` is the device memory region being asked about. This default
// implementation ignores both arguments and always returns false; the method
// is virtual so platform-specific transfer managers can override it to return
// true when immediate access is permitted.
virtual bool CanBufferBeAccessedNow(
se::StreamExecutor* executor,
const se::DeviceMemoryBase& device_buffer) const {
return false;
}
/////
// The TransferManager class also serves as a point to register objects for
// the various platforms.