Bug fix in tf.print() with OrderedDict where if an OrderedDict didn't have the keys sorted, the keys and values were not being printed in accordance with their correct mapping.

PiperOrigin-RevId: 335386269
Change-Id: I3c546b40530e863e4f7f7b93ff183f1f3ff43acc
This commit is contained in:
A. Unique TensorFlower 2020-10-05 04:03:55 -07:00 committed by TensorFlower Gardener
parent e8d84bd57a
commit bf63af8828
2 changed files with 22 additions and 2 deletions

View File

@ -1,4 +1,4 @@
# Release 2.4.0
h# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
@ -300,6 +300,12 @@
* Add parameter allow_mixed_precision_on_unconverted_ops to
TrtConversionParams.
* `tf.print`:
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as py_collections
import os
import pprint
import random
@ -305,8 +306,21 @@ def print_v2(*inputs, **kwargs):
# printed input.
templates = []
tensors = []
# If an input to the print function is of type `OrderedDict`, sort its
# elements by the keys for consistency with the ordering of `nest.flatten`.
# This is not needed for `dict` types because `pprint.pformat()` takes care
# of printing the template in a sorted fashion.
inputs_ordered_dicts_sorted = []
for input_ in inputs:
if isinstance(input_, py_collections.OrderedDict):
inputs_ordered_dicts_sorted.append(
py_collections.OrderedDict(sorted(input_.items())))
else:
inputs_ordered_dicts_sorted.append(input_)
tensor_free_structure = nest.map_structure(
lambda x: "" if tensor_util.is_tensor(x) else x, inputs)
lambda x: "" if tensor_util.is_tensor(x) else x,
inputs_ordered_dicts_sorted)
tensor_free_template = " ".join(
pprint.pformat(x) for x in tensor_free_structure)
placeholder = _generate_placeholder_string(tensor_free_template)