Updated tf.nest to be more consistent in its treatment of namdetuples. Prior to this, some tf.nest functions treated a type T as a namedtuple only when it was a *direct* subclass of tuple, while others treated T as a namedtuple if as long as it was a direct or indirect subclass of tuple. (E.g., tf.nest.assert_same_structure required direct subclasses, but tf.nest.assert_shallow_structure allowed indirect subclasses.)

Now, all tf.nest methods treat T as a namedtuple if it is a direct or indirect subclass of tuple (along with a few other conditions, such as having a _fields attribute).

Removed a check in function deserialization that was made redundant by this change.

PiperOrigin-RevId: 305689861
Change-Id: Iacc463bd40aef13aa883c61c77360200f36068c5
This commit is contained in:
Edward Loper 2020-04-09 08:35:28 -07:00 committed by TensorFlower Gardener
parent c3e12bcd6d
commit 15c0cb71fe
2 changed files with 2 additions and 12 deletions

View File

@ -93,16 +93,6 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
flatten_inputs = nest.flatten_up_to(expected_structure, inputs) flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
try:
# Verify that no input elements were dropped during flattening.
repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
# TODO(b/129422719): Namedtuple subclasses re-created through
# saved_model.load don't compare equal in type to the original in
# assert_same_structure. Fix that and we can take out check_types=False
# here.
nest.assert_same_structure(inputs, repacked, check_types=False)
except (TypeError, ValueError):
return False
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
if isinstance(expected, tensor_spec.TensorSpec): if isinstance(expected, tensor_spec.TensorSpec):

View File

@ -730,9 +730,9 @@ bool AssertSameStructureHelper(
// We treat two different namedtuples with identical name and fields // We treat two different namedtuples with identical name and fields
// as having the same type. // as having the same type.
const PyObject* o1_tuple = IsNamedtuple(o1, true); const PyObject* o1_tuple = IsNamedtuple(o1, false);
if (o1_tuple == nullptr) return false; if (o1_tuple == nullptr) return false;
const PyObject* o2_tuple = IsNamedtuple(o2, true); const PyObject* o2_tuple = IsNamedtuple(o2, false);
if (o2_tuple == nullptr) { if (o2_tuple == nullptr) {
Py_DECREF(o1_tuple); Py_DECREF(o1_tuple);
return false; return false;