diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 0db7d27e7e7..2d249be1314 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver @@ -368,6 +369,31 @@ class StaticHashTableTest(BaseLookupTableTest): result = lookup_table_func(constant_op.constant([2, -1, 1])) self.assertAllEqual([b"surgery", b"n/a", b"salad"], result) + def testTwoTablesInControlFlow(self): + keys = constant_op.constant([1, 2, 3], dtypes.int32) + values = constant_op.constant([5, 10, 15], dtypes.int32) + + def table_func1(x): + table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer( + keys, values), -1) + return table.lookup(x) + + elems = np.array([2, 4, 1], dtype=np.int32) + result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32) + + def table_func2(x): + table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer( + keys, values), -1) + return table.lookup(x) + + elems = np.array([2, 4, 1], dtype=np.int32) + result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32) + + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllEqual([10, -1, 5], self.evaluate(result1)) + self.assertAllEqual([10, -1, 5], self.evaluate(result2)) + class KeyValueTensorInitializerTest(BaseLookupTableTest): diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 3b726a611fa..802a5b2d261 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -171,6 +171,11 @@ class InitializableLookupTableBase(LookupInterface): self._initializer = self._track_trackable(initializer, "_initializer") with ops.init_scope(): self._resource_handle = self._create_resource() + if (not context.executing_eagerly() and + ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access + with ops.init_scope(): + self._init_op = self._initialize() + else: self._init_op = self._initialize() def _initialize(self): @@ -420,7 +425,13 @@ class KeyValueTensorInitializer(TableInitializerBase): value_dtype: The `values` data type. Used when `values` is a python array. name: A name for the operation (optional). """ - with ops.init_scope(): + if (not context.executing_eagerly() and + ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access + with ops.init_scope(): + self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") + self._values = ops.convert_to_tensor( + values, dtype=value_dtype, name="values") + else: self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") self._values = ops.convert_to_tensor( values, dtype=value_dtype, name="values")