diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h index f49438bde85..77e5ddb406c 100644 --- a/tensorflow/core/tpu/tpu_ops_c_api.h +++ b/tensorflow/core/tpu/tpu_ops_c_api.h @@ -122,6 +122,7 @@ typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { SE_DeviceMemoryBase* arguments; size_t arguments_len; SE_DeviceMemoryBase* result; + bool has_cross_program_prefetch_addr; SE_DeviceMemoryBase* cross_program_prefetch_addr; int32_t rng_seed; XLA_DeviceAssignment* device_assignment; diff --git a/tensorflow/stream_executor/tpu/tpu_executable.cc b/tensorflow/stream_executor/tpu/tpu_executable.cc index a251f6711c6..6408d37b990 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable.cc +++ b/tensorflow/stream_executor/tpu/tpu_executable.cc @@ -84,6 +84,8 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream( params.arguments = arguments_bases; params.arguments_len = arguments.size(); params.result = &result_base; + params.has_cross_program_prefetch_addr = + cross_program_prefetch_addr.has_value(); params.cross_program_prefetch_addr = cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr; params.rng_seed = rng_seed;