Prepare convert_to_constants.py for tensor equality changes.

PiperOrigin-RevId: 262655851
This commit is contained in:
Saurabh Saxena 2019-08-09 16:23:10 -07:00 committed by Goldie Gadde
parent 360e0db035
commit 643d109b3f
3 changed files with 38 additions and 6 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.util import object_identity
from tensorflow.python.training.saver import export_meta_graph
@ -177,10 +178,12 @@ def _get_tensor_data(func):
Dict
"""
tensor_data = {}
map_index_to_variable = {
func.captured_inputs.index(var.handle): var
for var in func.graph.variables
}
map_index_to_variable = {}
for var in func.graph.variables:
for idx, captured_input in enumerate(func.captured_inputs):
if var.handle is captured_input: # pylint: disable=protected-access
map_index_to_variable[idx] = var
break
# Iterates through all captures which are represented as Placeholders.
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
@ -353,9 +356,10 @@ def _construct_concrete_function(func, output_graph_def,
"""
# Create a ConcreteFunction from the new GraphDef.
input_tensors = func.graph.internal_captures
converted_inputs = set(
converted_inputs = object_identity.ObjectIdentitySet(
[input_tensors[index] for index in converted_input_indices])
not_converted_inputs = set(func.inputs).difference(converted_inputs)
not_converted_inputs = object_identity.ObjectIdentitySet(
func.inputs).difference(converted_inputs)
not_converted_inputs_map = {
tensor.name: tensor for tensor in not_converted_inputs
}

View File

@ -156,6 +156,12 @@ class ObjectIdentitySet(collections_abc.MutableSet):
def __init__(self, *args):
self._storage = set([self._wrap_key(obj) for obj in list(*args)])
@staticmethod
def _from_storage(storage):
result = ObjectIdentitySet()
result._storage = storage # pylint: disable=protected-access
return result
def _wrap_key(self, key):
return _ObjectIdentityWrapper(key)
@ -174,6 +180,10 @@ class ObjectIdentitySet(collections_abc.MutableSet):
def intersection(self, items):
return self._storage.intersection([self._wrap_key(item) for item in items])
def difference(self, items):
return ObjectIdentitySet._from_storage(
self._storage.difference([self._wrap_key(item) for item in items]))
def __len__(self):
return len(self._storage)

View File

@ -30,5 +30,23 @@ class ObjectIdentityWrapperTest(test.TestCase):
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
class ObjectIdentitySetTest(test.TestCase):
def testDifference(self):
class Element(object):
pass
a = Element()
b = Element()
c = Element()
set1 = object_identity.ObjectIdentitySet([a, b])
set2 = object_identity.ObjectIdentitySet([b, c])
diff_set = set1.difference(set2)
self.assertIn(a, diff_set)
self.assertNotIn(b, diff_set)
self.assertNotIn(c, diff_set)
if __name__ == '__main__':
test.main()