Prepare convert_to_constants.py for tensor equality changes.
PiperOrigin-RevId: 262655851
This commit is contained in:
parent
360e0db035
commit
643d109b3f
@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.grappler import tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.util import object_identity
|
||||||
from tensorflow.python.training.saver import export_meta_graph
|
from tensorflow.python.training.saver import export_meta_graph
|
||||||
|
|
||||||
|
|
||||||
@ -177,10 +178,12 @@ def _get_tensor_data(func):
|
|||||||
Dict
|
Dict
|
||||||
"""
|
"""
|
||||||
tensor_data = {}
|
tensor_data = {}
|
||||||
map_index_to_variable = {
|
map_index_to_variable = {}
|
||||||
func.captured_inputs.index(var.handle): var
|
for var in func.graph.variables:
|
||||||
for var in func.graph.variables
|
for idx, captured_input in enumerate(func.captured_inputs):
|
||||||
}
|
if var.handle is captured_input: # pylint: disable=protected-access
|
||||||
|
map_index_to_variable[idx] = var
|
||||||
|
break
|
||||||
|
|
||||||
# Iterates through all captures which are represented as Placeholders.
|
# Iterates through all captures which are represented as Placeholders.
|
||||||
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
|
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
|
||||||
@ -353,9 +356,10 @@ def _construct_concrete_function(func, output_graph_def,
|
|||||||
"""
|
"""
|
||||||
# Create a ConcreteFunction from the new GraphDef.
|
# Create a ConcreteFunction from the new GraphDef.
|
||||||
input_tensors = func.graph.internal_captures
|
input_tensors = func.graph.internal_captures
|
||||||
converted_inputs = set(
|
converted_inputs = object_identity.ObjectIdentitySet(
|
||||||
[input_tensors[index] for index in converted_input_indices])
|
[input_tensors[index] for index in converted_input_indices])
|
||||||
not_converted_inputs = set(func.inputs).difference(converted_inputs)
|
not_converted_inputs = object_identity.ObjectIdentitySet(
|
||||||
|
func.inputs).difference(converted_inputs)
|
||||||
not_converted_inputs_map = {
|
not_converted_inputs_map = {
|
||||||
tensor.name: tensor for tensor in not_converted_inputs
|
tensor.name: tensor for tensor in not_converted_inputs
|
||||||
}
|
}
|
||||||
|
@ -156,6 +156,12 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
|||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
self._storage = set([self._wrap_key(obj) for obj in list(*args)])
|
self._storage = set([self._wrap_key(obj) for obj in list(*args)])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _from_storage(storage):
|
||||||
|
result = ObjectIdentitySet()
|
||||||
|
result._storage = storage # pylint: disable=protected-access
|
||||||
|
return result
|
||||||
|
|
||||||
def _wrap_key(self, key):
|
def _wrap_key(self, key):
|
||||||
return _ObjectIdentityWrapper(key)
|
return _ObjectIdentityWrapper(key)
|
||||||
|
|
||||||
@ -174,6 +180,10 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
|||||||
def intersection(self, items):
|
def intersection(self, items):
|
||||||
return self._storage.intersection([self._wrap_key(item) for item in items])
|
return self._storage.intersection([self._wrap_key(item) for item in items])
|
||||||
|
|
||||||
|
def difference(self, items):
|
||||||
|
return ObjectIdentitySet._from_storage(
|
||||||
|
self._storage.difference([self._wrap_key(item) for item in items]))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._storage)
|
return len(self._storage)
|
||||||
|
|
||||||
|
@ -30,5 +30,23 @@ class ObjectIdentityWrapperTest(test.TestCase):
|
|||||||
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
|
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectIdentitySetTest(test.TestCase):
|
||||||
|
|
||||||
|
def testDifference(self):
|
||||||
|
|
||||||
|
class Element(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
a = Element()
|
||||||
|
b = Element()
|
||||||
|
c = Element()
|
||||||
|
set1 = object_identity.ObjectIdentitySet([a, b])
|
||||||
|
set2 = object_identity.ObjectIdentitySet([b, c])
|
||||||
|
diff_set = set1.difference(set2)
|
||||||
|
self.assertIn(a, diff_set)
|
||||||
|
self.assertNotIn(b, diff_set)
|
||||||
|
self.assertNotIn(c, diff_set)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user