From 15c0cb71fe1db995a878fabd0a32fb5ce9c24d7b Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Thu, 9 Apr 2020 08:35:28 -0700 Subject: [PATCH] Updated tf.nest to be more consistent in its treatment of namdetuples. Prior to this, some tf.nest functions treated a type T as a namedtuple only when it was a *direct* subclass of `tuple`, while others treated T as a namedtuple if as long as it was a direct or indirect subclass of `tuple`. (E.g., tf.nest.assert_same_structure required direct subclasses, but tf.nest.assert_shallow_structure allowed indirect subclasses.) Now, all tf.nest methods treat T as a namedtuple if it is a direct or indirect subclass of tuple (along with a few other conditions, such as having a _fields attribute). Removed a check in function deserialization that was made redundant by this change. PiperOrigin-RevId: 305689861 Change-Id: Iacc463bd40aef13aa883c61c77360200f36068c5 --- .../python/saved_model/function_deserialization.py | 10 ---------- tensorflow/python/util/util.cc | 4 ++-- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index e8a9514dd56..9fcffc8ccdf 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -93,16 +93,6 @@ def _concrete_function_callable_with(function, inputs, allow_conversion): flatten_inputs = nest.flatten_up_to(expected_structure, inputs) except (TypeError, ValueError): return False - try: - # Verify that no input elements were dropped during flattening. - repacked = nest.pack_sequence_as(expected_structure, flatten_inputs) - # TODO(b/129422719): Namedtuple subclasses re-created through - # saved_model.load don't compare equal in type to the original in - # assert_same_structure. Fix that and we can take out check_types=False - # here. - nest.assert_same_structure(inputs, repacked, check_types=False) - except (TypeError, ValueError): - return False for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): if isinstance(expected, tensor_spec.TensorSpec): diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 6da3fdbf945..1d0dd695d74 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -730,9 +730,9 @@ bool AssertSameStructureHelper( // We treat two different namedtuples with identical name and fields // as having the same type. - const PyObject* o1_tuple = IsNamedtuple(o1, true); + const PyObject* o1_tuple = IsNamedtuple(o1, false); if (o1_tuple == nullptr) return false; - const PyObject* o2_tuple = IsNamedtuple(o2, true); + const PyObject* o2_tuple = IsNamedtuple(o2, false); if (o2_tuple == nullptr) { Py_DECREF(o1_tuple); return false;