Improve ASYNC resource variable creation

Checking the boolean tensor from the VarIsInitializedOp requires we
evaluate the tensor in python. This is undesirable for ASYNC execution
as it results in the python code being blocked, waiting for all
queued operations to be completed before synchronously evaluating the
tensor.

PiperOrigin-RevId: 249113005
This commit is contained in:
Gaurav Jain 2019-05-20 13:23:04 -07:00 committed by TensorFlower Gardener
parent cb76a97ac1
commit 4dd726e547

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
@ -150,10 +151,13 @@ def variable_handle_from_shape_and_dtype(
# When in eager mode, explicitly ensure so here. When in graph mode, it's
# ensured by always generating different variable names.
exists = gen_resource_variable_ops.var_is_initialized_op(handle)
if exists:
raise ValueError("variable object with name '%s' already created. Use "
"get_variable() if reuse is desired." %
shared_name)
# We create an assert Op instead of checking right away in order to be
# compatible with ASYNC execution mode. Further, since not all devices
# support string tensors, we encode the assertion string in the Op name
gen_logging_ops._assert( # pylint: disable=protected-access
math_ops.logical_not(exists), [exists], name="EagerVariableNameReuse")
with context.graph_mode(), ops.Graph().as_default() as graph:
h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
shared_name=shared_name,
@ -1781,5 +1785,6 @@ def copy_to_graph_uninitialized(var):
# pylint: enable=protected-access
return new_variable
ops.NotDifferentiable("Assert")
ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")