From 4dd726e54751411f40fb1ff4ec98c7f8aca4939b Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Mon, 20 May 2019 13:23:04 -0700 Subject: [PATCH] Improve ASYNC resource variable creation Checking the boolean tensor from the VarIsInitializedOp requires we evaluate the tensor in python. This is undesirable for ASYNC execution as it results in the python code being blocked, waiting for all queued operations to be completed before synchronously evaluating the tensor. PiperOrigin-RevId: 249113005 --- tensorflow/python/ops/resource_variable_ops.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 5f753041f41..8edac3c8cf1 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops @@ -150,10 +151,13 @@ def variable_handle_from_shape_and_dtype( # When in eager mode, explicitly ensure so here. When in graph mode, it's # ensured by always generating different variable names. exists = gen_resource_variable_ops.var_is_initialized_op(handle) - if exists: - raise ValueError("variable object with name '%s' already created. Use " - "get_variable() if reuse is desired." % - shared_name) + + # We create an assert Op instead of checking right away in order to be + # compatible with ASYNC execution mode. Further, since not all devices + # support string tensors, we encode the assertion string in the Op name + gen_logging_ops._assert( # pylint: disable=protected-access + math_ops.logical_not(exists), [exists], name="EagerVariableNameReuse") + with context.graph_mode(), ops.Graph().as_default() as graph: h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, @@ -1781,5 +1785,6 @@ def copy_to_graph_uninitialized(var): # pylint: enable=protected-access return new_variable +ops.NotDifferentiable("Assert") ops.NotDifferentiable("VarIsInitializedOp") ops.NotDifferentiable("VariableShape")