track persistent memory for temporary memory op.

Change: 155302379
This commit is contained in:
Yuefeng Zhou 2017-05-06 18:45:33 -08:00 committed by TensorFlower Gardener
parent 93572de9a1
commit 24a61445f4

View File

@ -127,6 +127,16 @@ class TemporaryVariableOp : public OpKernel {
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
var_name_, tmp_var));
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
if (context->track_allocations()) {
AllocatorAttributes attr;
if (context->allocate_on_host(attr)) {
context->record_host_persistent_memory_allocation(
tmp_var->val.AllocatedBytes());
} else {
context->record_device_persistent_memory_allocation(
tmp_var->val.AllocatedBytes());
}
}
}
private: