Distinguish between not having a cross-program prefetch address and having a nullptr cross-program prefetch address.

PiperOrigin-RevId: 345334049
Change-Id: I150066c6ae18e762231b55388159c479ef0fb84f
This commit is contained in:
A. Unique TensorFlower 2020-12-02 16:11:57 -08:00 committed by TensorFlower Gardener
parent 254c5b9da2
commit fc52e64328
2 changed files with 3 additions and 0 deletions

View File

@ -122,6 +122,7 @@ typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
SE_DeviceMemoryBase* arguments; SE_DeviceMemoryBase* arguments;
size_t arguments_len; size_t arguments_len;
SE_DeviceMemoryBase* result; SE_DeviceMemoryBase* result;
bool has_cross_program_prefetch_addr;
SE_DeviceMemoryBase* cross_program_prefetch_addr; SE_DeviceMemoryBase* cross_program_prefetch_addr;
int32_t rng_seed; int32_t rng_seed;
XLA_DeviceAssignment* device_assignment; XLA_DeviceAssignment* device_assignment;

View File

@ -84,6 +84,8 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
params.arguments = arguments_bases; params.arguments = arguments_bases;
params.arguments_len = arguments.size(); params.arguments_len = arguments.size();
params.result = &result_base; params.result = &result_base;
params.has_cross_program_prefetch_addr =
cross_program_prefetch_addr.has_value();
params.cross_program_prefetch_addr = params.cross_program_prefetch_addr =
cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr; cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr;
params.rng_seed = rng_seed; params.rng_seed = rng_seed;