Improve ASYNC resource variable creation
Checking the boolean tensor from the VarIsInitializedOp requires we evaluate the tensor in python. This is undesirable for ASYNC execution as it results in the python code being blocked, waiting for all queued operations to be completed before synchronously evaluating the tensor. PiperOrigin-RevId: 249113005
This commit is contained in:
parent
cb76a97ac1
commit
4dd726e547
@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_logging_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -150,10 +151,13 @@ def variable_handle_from_shape_and_dtype(
|
||||
# When in eager mode, explicitly ensure so here. When in graph mode, it's
|
||||
# ensured by always generating different variable names.
|
||||
exists = gen_resource_variable_ops.var_is_initialized_op(handle)
|
||||
if exists:
|
||||
raise ValueError("variable object with name '%s' already created. Use "
|
||||
"get_variable() if reuse is desired." %
|
||||
shared_name)
|
||||
|
||||
# We create an assert Op instead of checking right away in order to be
|
||||
# compatible with ASYNC execution mode. Further, since not all devices
|
||||
# support string tensors, we encode the assertion string in the Op name
|
||||
gen_logging_ops._assert( # pylint: disable=protected-access
|
||||
math_ops.logical_not(exists), [exists], name="EagerVariableNameReuse")
|
||||
|
||||
with context.graph_mode(), ops.Graph().as_default() as graph:
|
||||
h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
||||
shared_name=shared_name,
|
||||
@ -1781,5 +1785,6 @@ def copy_to_graph_uninitialized(var):
|
||||
# pylint: enable=protected-access
|
||||
return new_variable
|
||||
|
||||
ops.NotDifferentiable("Assert")
|
||||
ops.NotDifferentiable("VarIsInitializedOp")
|
||||
ops.NotDifferentiable("VariableShape")
|
||||
|
Loading…
x
Reference in New Issue
Block a user