track persistent memory for temporary memory op.
Change: 155302379
This commit is contained in:
parent
93572de9a1
commit
24a61445f4
@ -127,6 +127,16 @@ class TemporaryVariableOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
|
||||
var_name_, tmp_var));
|
||||
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
|
||||
if (context->track_allocations()) {
|
||||
AllocatorAttributes attr;
|
||||
if (context->allocate_on_host(attr)) {
|
||||
context->record_host_persistent_memory_allocation(
|
||||
tmp_var->val.AllocatedBytes());
|
||||
} else {
|
||||
context->record_device_persistent_memory_allocation(
|
||||
tmp_var->val.AllocatedBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user