Prepare convert_to_constants.py for tensor equality changes.
PiperOrigin-RevId: 262655851
This commit is contained in:
parent
360e0db035
commit
643d109b3f
@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.training.saver import export_meta_graph
|
||||
|
||||
|
||||
@ -177,10 +178,12 @@ def _get_tensor_data(func):
|
||||
Dict
|
||||
"""
|
||||
tensor_data = {}
|
||||
map_index_to_variable = {
|
||||
func.captured_inputs.index(var.handle): var
|
||||
for var in func.graph.variables
|
||||
}
|
||||
map_index_to_variable = {}
|
||||
for var in func.graph.variables:
|
||||
for idx, captured_input in enumerate(func.captured_inputs):
|
||||
if var.handle is captured_input: # pylint: disable=protected-access
|
||||
map_index_to_variable[idx] = var
|
||||
break
|
||||
|
||||
# Iterates through all captures which are represented as Placeholders.
|
||||
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
|
||||
@ -353,9 +356,10 @@ def _construct_concrete_function(func, output_graph_def,
|
||||
"""
|
||||
# Create a ConcreteFunction from the new GraphDef.
|
||||
input_tensors = func.graph.internal_captures
|
||||
converted_inputs = set(
|
||||
converted_inputs = object_identity.ObjectIdentitySet(
|
||||
[input_tensors[index] for index in converted_input_indices])
|
||||
not_converted_inputs = set(func.inputs).difference(converted_inputs)
|
||||
not_converted_inputs = object_identity.ObjectIdentitySet(
|
||||
func.inputs).difference(converted_inputs)
|
||||
not_converted_inputs_map = {
|
||||
tensor.name: tensor for tensor in not_converted_inputs
|
||||
}
|
||||
|
@ -156,6 +156,12 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
||||
def __init__(self, *args):
|
||||
self._storage = set([self._wrap_key(obj) for obj in list(*args)])
|
||||
|
||||
@staticmethod
|
||||
def _from_storage(storage):
|
||||
result = ObjectIdentitySet()
|
||||
result._storage = storage # pylint: disable=protected-access
|
||||
return result
|
||||
|
||||
def _wrap_key(self, key):
|
||||
return _ObjectIdentityWrapper(key)
|
||||
|
||||
@ -174,6 +180,10 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
||||
def intersection(self, items):
|
||||
return self._storage.intersection([self._wrap_key(item) for item in items])
|
||||
|
||||
def difference(self, items):
|
||||
return ObjectIdentitySet._from_storage(
|
||||
self._storage.difference([self._wrap_key(item) for item in items]))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
|
@ -30,5 +30,23 @@ class ObjectIdentityWrapperTest(test.TestCase):
|
||||
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
|
||||
|
||||
|
||||
class ObjectIdentitySetTest(test.TestCase):
|
||||
|
||||
def testDifference(self):
|
||||
|
||||
class Element(object):
|
||||
pass
|
||||
|
||||
a = Element()
|
||||
b = Element()
|
||||
c = Element()
|
||||
set1 = object_identity.ObjectIdentitySet([a, b])
|
||||
set2 = object_identity.ObjectIdentitySet([b, c])
|
||||
diff_set = set1.difference(set2)
|
||||
self.assertIn(a, diff_set)
|
||||
self.assertNotIn(b, diff_set)
|
||||
self.assertNotIn(c, diff_set)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user