From 0939f0b7a8aed30559b54ae0dd5b9e6b312de1c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 Dec 2020 22:11:05 -0800 Subject: [PATCH] Expanding documentation for tf.nest.assert_same_structure with examples. PiperOrigin-RevId: 346251102 Change-Id: Id59890c4ceee299b6f63dc851dd2997020cf8ccb --- tensorflow/python/util/nest.py | 63 +++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index db3ad27310c..21a61e47d50 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -438,15 +438,62 @@ def assert_same_structure(nest1, nest2, check_types=True, expand_composites=False): """Asserts that two structures are nested in the same way. - Note that namedtuples with identical name and fields are always considered - to have the same shallow structure (even with `check_types=True`). - For instance, this code will print `True`: + Note the method does not check the types of data inside the structures. - ```python - def nt(a, b): - return collections.namedtuple('foo', 'a b')(a, b) - print(assert_same_structure(nt(0, 1), nt(2, 3))) - ``` + Examples: + + * These scalar vs. scalar comparisons will pass: + + >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32)) + >>> tf.nest.assert_same_structure("abc", np.array([1, 2])) + + * These sequence vs. sequence comparisons will pass: + + >>> structure1 = (((1, 2), 3), 4, (5, 6)) + >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]] + >>> tf.nest.assert_same_structure(structure1, structure2) + >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False) + + >>> import collections + >>> tf.nest.assert_same_structure( + ... collections.namedtuple("bar", "a b")(1, 2), + ... collections.namedtuple("foo", "a b")(2, 3), + ... check_types=False) + + >>> tf.nest.assert_same_structure( + ... collections.namedtuple("bar", "a b")(1, 2), + ... { "a": 1, "b": 2 }, + ... check_types=False) + + >>> tf.nest.assert_same_structure( + ... { "a": 1, "b": 2, "c": 3 }, + ... { "c": 6, "b": 5, "a": 4 }) + + >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_splits=[0, 4, 4, 7, 8, 8]) + >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits( + ... values=[3, 1, 4], + ... row_splits=[0, 3]) + >>> tf.nest.assert_same_structure( + ... ragged_tensor1, + ... ragged_tensor2, + ... expand_composites=True) + + * These examples will raise exceptions: + + >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1])) + Traceback (most recent call last): + ... + ValueError: The two structures don't have the same nested structure + + >>> tf.nest.assert_same_structure( + ... collections.namedtuple('bar', 'a b')(1, 2), + ... collections.namedtuple('foo', 'a b')(2, 3)) + Traceback (most recent call last): + ... + TypeError: The two structures don't have the same nested structure Args: nest1: an arbitrarily nested structure.