Removing the init_scope from table initialization except for control flow

cases. Earlier we tried to remove it in all cases but that doesn't work because
if we create a table in a control flow setting then we can't run the
initializer outside anymore.

Fixes #29872 and
#27086.

PiperOrigin-RevId: 258792706
This commit is contained in:
Rohan Jain 2019-07-18 10:05:52 -07:00 committed by TensorFlower Gardener
parent 00bba04955
commit f6e0ec468a
2 changed files with 38 additions and 1 deletions

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@ -368,6 +369,31 @@ class StaticHashTableTest(BaseLookupTableTest):
result = lookup_table_func(constant_op.constant([2, -1, 1]))
self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)
def testTwoTablesInControlFlow(self):
keys = constant_op.constant([1, 2, 3], dtypes.int32)
values = constant_op.constant([5, 10, 15], dtypes.int32)
def table_func1(x):
table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
keys, values), -1)
return table.lookup(x)
elems = np.array([2, 4, 1], dtype=np.int32)
result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32)
def table_func2(x):
table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
keys, values), -1)
return table.lookup(x)
elems = np.array([2, 4, 1], dtype=np.int32)
result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32)
self.evaluate(lookup_ops.tables_initializer())
self.assertAllEqual([10, -1, 5], self.evaluate(result1))
self.assertAllEqual([10, -1, 5], self.evaluate(result2))
class KeyValueTensorInitializerTest(BaseLookupTableTest):

View File

@ -171,6 +171,11 @@ class InitializableLookupTableBase(LookupInterface):
self._initializer = self._track_trackable(initializer, "_initializer")
with ops.init_scope():
self._resource_handle = self._create_resource()
if (not context.executing_eagerly() and
ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access
with ops.init_scope():
self._init_op = self._initialize()
else:
self._init_op = self._initialize()
def _initialize(self):
@ -420,7 +425,13 @@ class KeyValueTensorInitializer(TableInitializerBase):
value_dtype: The `values` data type. Used when `values` is a python array.
name: A name for the operation (optional).
"""
with ops.init_scope():
if (not context.executing_eagerly() and
ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access
with ops.init_scope():
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
self._values = ops.convert_to_tensor(
values, dtype=value_dtype, name="values")
else:
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
self._values = ops.convert_to_tensor(
values, dtype=value_dtype, name="values")