track persistent memory for temporary memory op.
Change: 155302379
This commit is contained in:
parent
93572de9a1
commit
24a61445f4
@ -127,6 +127,16 @@ class TemporaryVariableOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
|
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
|
||||||
var_name_, tmp_var));
|
var_name_, tmp_var));
|
||||||
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
|
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
|
||||||
|
if (context->track_allocations()) {
|
||||||
|
AllocatorAttributes attr;
|
||||||
|
if (context->allocate_on_host(attr)) {
|
||||||
|
context->record_host_persistent_memory_allocation(
|
||||||
|
tmp_var->val.AllocatedBytes());
|
||||||
|
} else {
|
||||||
|
context->record_device_persistent_memory_allocation(
|
||||||
|
tmp_var->val.AllocatedBytes());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Loading…
Reference in New Issue
Block a user