Updated tf.nest to be more consistent in its treatment of namdetuples. Prior to this, some tf.nest functions treated a type T as a namedtuple only when it was a *direct* subclass of tuple
, while others treated T as a namedtuple if as long as it was a direct or indirect subclass of tuple
. (E.g., tf.nest.assert_same_structure required direct subclasses, but tf.nest.assert_shallow_structure allowed indirect subclasses.)
Now, all tf.nest methods treat T as a namedtuple if it is a direct or indirect subclass of tuple (along with a few other conditions, such as having a _fields attribute). Removed a check in function deserialization that was made redundant by this change. PiperOrigin-RevId: 305689861 Change-Id: Iacc463bd40aef13aa883c61c77360200f36068c5
This commit is contained in:
parent
c3e12bcd6d
commit
15c0cb71fe
@ -93,16 +93,6 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
|
||||
flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
try:
|
||||
# Verify that no input elements were dropped during flattening.
|
||||
repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
|
||||
# TODO(b/129422719): Namedtuple subclasses re-created through
|
||||
# saved_model.load don't compare equal in type to the original in
|
||||
# assert_same_structure. Fix that and we can take out check_types=False
|
||||
# here.
|
||||
nest.assert_same_structure(inputs, repacked, check_types=False)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
|
||||
if isinstance(expected, tensor_spec.TensorSpec):
|
||||
|
@ -730,9 +730,9 @@ bool AssertSameStructureHelper(
|
||||
|
||||
// We treat two different namedtuples with identical name and fields
|
||||
// as having the same type.
|
||||
const PyObject* o1_tuple = IsNamedtuple(o1, true);
|
||||
const PyObject* o1_tuple = IsNamedtuple(o1, false);
|
||||
if (o1_tuple == nullptr) return false;
|
||||
const PyObject* o2_tuple = IsNamedtuple(o2, true);
|
||||
const PyObject* o2_tuple = IsNamedtuple(o2, false);
|
||||
if (o2_tuple == nullptr) {
|
||||
Py_DECREF(o1_tuple);
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user