diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 25b17b26c8d..f0b5796d04a 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -127,6 +127,16 @@ class TemporaryVariableOp : public OpKernel { OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(), var_name_, tmp_var)); context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); + if (context->track_allocations()) { + AllocatorAttributes attr; + if (context->allocate_on_host(attr)) { + context->record_host_persistent_memory_allocation( + tmp_var->val.AllocatedBytes()); + } else { + context->record_device_persistent_memory_allocation( + tmp_var->val.AllocatedBytes()); + } + } } private: