Add overload to check for device memory access semantics, akin to
CanShapedBufferBeAccessedNow but for a single DeviceMemoryBase. PiperOrigin-RevId: 307053898 Change-Id: If79426ecd058eae75348f52e1c271702a6247fe1
This commit is contained in:
parent
a2288e4696
commit
5115a02054
@ -1175,8 +1175,8 @@ StatusOr<TupleHandle> MakeTupleHelper(
|
||||
LocalDeviceState::kComputeSynchronized) {
|
||||
stream->ThenWaitFor(local_device->compute_stream());
|
||||
} else {
|
||||
// In principle we would do a DCHECK for CanShapedBufferBeAccessedNow here
|
||||
// but that call requires a ShapedBuffer which we don't have.
|
||||
DCHECK(transfer_manager->CanBufferBeAccessedNow(
|
||||
local_device->compute_stream()->parent(), root_table_memory.cref()));
|
||||
}
|
||||
|
||||
ExecutionInput execution_input(on_device_shape);
|
||||
|
@ -1122,6 +1122,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -50,6 +51,12 @@ class CpuTransferManager : public GenericTransferManager {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CanBufferBeAccessedNow(
|
||||
se::StreamExecutor* executor,
|
||||
const se::DeviceMemoryBase& device_buffer) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
|
||||
const void* source);
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -256,6 +257,13 @@ class TransferManager {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
|
||||
virtual bool CanBufferBeAccessedNow(
|
||||
se::StreamExecutor* executor,
|
||||
const se::DeviceMemoryBase& device_buffer) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/////
|
||||
// The TransferManager class also serves as a point to register objects for
|
||||
// the various platforms.
|
||||
|
Loading…
x
Reference in New Issue
Block a user