Expanding documentation for tf.nest.assert_same_structure with examples.
PiperOrigin-RevId: 346251102 Change-Id: Id59890c4ceee299b6f63dc851dd2997020cf8ccb
This commit is contained in:
parent
55691acd36
commit
0939f0b7a8
@ -438,15 +438,62 @@ def assert_same_structure(nest1, nest2, check_types=True,
|
||||
expand_composites=False):
|
||||
"""Asserts that two structures are nested in the same way.
|
||||
|
||||
Note that namedtuples with identical name and fields are always considered
|
||||
to have the same shallow structure (even with `check_types=True`).
|
||||
For instance, this code will print `True`:
|
||||
Note the method does not check the types of data inside the structures.
|
||||
|
||||
```python
|
||||
def nt(a, b):
|
||||
return collections.namedtuple('foo', 'a b')(a, b)
|
||||
print(assert_same_structure(nt(0, 1), nt(2, 3)))
|
||||
```
|
||||
Examples:
|
||||
|
||||
* These scalar vs. scalar comparisons will pass:
|
||||
|
||||
>>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
|
||||
>>> tf.nest.assert_same_structure("abc", np.array([1, 2]))
|
||||
|
||||
* These sequence vs. sequence comparisons will pass:
|
||||
|
||||
>>> structure1 = (((1, 2), 3), 4, (5, 6))
|
||||
>>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
|
||||
>>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
|
||||
>>> tf.nest.assert_same_structure(structure1, structure2)
|
||||
>>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)
|
||||
|
||||
>>> import collections
|
||||
>>> tf.nest.assert_same_structure(
|
||||
... collections.namedtuple("bar", "a b")(1, 2),
|
||||
... collections.namedtuple("foo", "a b")(2, 3),
|
||||
... check_types=False)
|
||||
|
||||
>>> tf.nest.assert_same_structure(
|
||||
... collections.namedtuple("bar", "a b")(1, 2),
|
||||
... { "a": 1, "b": 2 },
|
||||
... check_types=False)
|
||||
|
||||
>>> tf.nest.assert_same_structure(
|
||||
... { "a": 1, "b": 2, "c": 3 },
|
||||
... { "c": 6, "b": 5, "a": 4 })
|
||||
|
||||
>>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
|
||||
... values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||
... row_splits=[0, 4, 4, 7, 8, 8])
|
||||
>>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
|
||||
... values=[3, 1, 4],
|
||||
... row_splits=[0, 3])
|
||||
>>> tf.nest.assert_same_structure(
|
||||
... ragged_tensor1,
|
||||
... ragged_tensor2,
|
||||
... expand_composites=True)
|
||||
|
||||
* These examples will raise exceptions:
|
||||
|
||||
>>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: The two structures don't have the same nested structure
|
||||
|
||||
>>> tf.nest.assert_same_structure(
|
||||
... collections.namedtuple('bar', 'a b')(1, 2),
|
||||
... collections.namedtuple('foo', 'a b')(2, 3))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: The two structures don't have the same nested structure
|
||||
|
||||
Args:
|
||||
nest1: an arbitrarily nested structure.
|
||||
|
Loading…
x
Reference in New Issue
Block a user