Removing the init_scope from table initialization except for control flow
cases. Earlier we tried to remove it in all cases but that doesn't work because if we create a table in a control flow setting then we can't run the initializer outside anymore. Fixes #29872 and #27086. PiperOrigin-RevId: 258792706
This commit is contained in:
parent
00bba04955
commit
f6e0ec468a
@ -37,6 +37,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver
|
||||
@ -368,6 +369,31 @@ class StaticHashTableTest(BaseLookupTableTest):
|
||||
result = lookup_table_func(constant_op.constant([2, -1, 1]))
|
||||
self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)
|
||||
|
||||
def testTwoTablesInControlFlow(self):
|
||||
keys = constant_op.constant([1, 2, 3], dtypes.int32)
|
||||
values = constant_op.constant([5, 10, 15], dtypes.int32)
|
||||
|
||||
def table_func1(x):
|
||||
table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
|
||||
keys, values), -1)
|
||||
return table.lookup(x)
|
||||
|
||||
elems = np.array([2, 4, 1], dtype=np.int32)
|
||||
result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32)
|
||||
|
||||
def table_func2(x):
|
||||
table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
|
||||
keys, values), -1)
|
||||
return table.lookup(x)
|
||||
|
||||
elems = np.array([2, 4, 1], dtype=np.int32)
|
||||
result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32)
|
||||
|
||||
self.evaluate(lookup_ops.tables_initializer())
|
||||
|
||||
self.assertAllEqual([10, -1, 5], self.evaluate(result1))
|
||||
self.assertAllEqual([10, -1, 5], self.evaluate(result2))
|
||||
|
||||
|
||||
class KeyValueTensorInitializerTest(BaseLookupTableTest):
|
||||
|
||||
|
||||
@ -171,6 +171,11 @@ class InitializableLookupTableBase(LookupInterface):
|
||||
self._initializer = self._track_trackable(initializer, "_initializer")
|
||||
with ops.init_scope():
|
||||
self._resource_handle = self._create_resource()
|
||||
if (not context.executing_eagerly() and
|
||||
ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access
|
||||
with ops.init_scope():
|
||||
self._init_op = self._initialize()
|
||||
else:
|
||||
self._init_op = self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
@ -420,7 +425,13 @@ class KeyValueTensorInitializer(TableInitializerBase):
|
||||
value_dtype: The `values` data type. Used when `values` is a python array.
|
||||
name: A name for the operation (optional).
|
||||
"""
|
||||
with ops.init_scope():
|
||||
if (not context.executing_eagerly() and
|
||||
ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access
|
||||
with ops.init_scope():
|
||||
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
|
||||
self._values = ops.convert_to_tensor(
|
||||
values, dtype=value_dtype, name="values")
|
||||
else:
|
||||
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
|
||||
self._values = ops.convert_to_tensor(
|
||||
values, dtype=value_dtype, name="values")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user