Bug fix in tf.print()
with OrderedDict
where if an OrderedDict
didn't have the keys sorted, the keys and values were not being printed in accordance with their correct mapping.
PiperOrigin-RevId: 335386269 Change-Id: I3c546b40530e863e4f7f7b93ff183f1f3ff43acc
This commit is contained in:
parent
e8d84bd57a
commit
bf63af8828
@ -1,4 +1,4 @@
|
||||
# Release 2.4.0
|
||||
h# Release 2.4.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
||||
@ -300,6 +300,12 @@
|
||||
* Add parameter allow_mixed_precision_on_unconverted_ops to
|
||||
TrtConversionParams.
|
||||
|
||||
* `tf.print`:
|
||||
|
||||
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||
didn't have the keys sorted, the keys and values were not being printed
|
||||
in accordance with their correct mapping.
|
||||
|
||||
* Other:
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections as py_collections
|
||||
import os
|
||||
import pprint
|
||||
import random
|
||||
@ -305,8 +306,21 @@ def print_v2(*inputs, **kwargs):
|
||||
# printed input.
|
||||
templates = []
|
||||
tensors = []
|
||||
# If an input to the print function is of type `OrderedDict`, sort its
|
||||
# elements by the keys for consistency with the ordering of `nest.flatten`.
|
||||
# This is not needed for `dict` types because `pprint.pformat()` takes care
|
||||
# of printing the template in a sorted fashion.
|
||||
inputs_ordered_dicts_sorted = []
|
||||
for input_ in inputs:
|
||||
if isinstance(input_, py_collections.OrderedDict):
|
||||
inputs_ordered_dicts_sorted.append(
|
||||
py_collections.OrderedDict(sorted(input_.items())))
|
||||
else:
|
||||
inputs_ordered_dicts_sorted.append(input_)
|
||||
tensor_free_structure = nest.map_structure(
|
||||
lambda x: "" if tensor_util.is_tensor(x) else x, inputs)
|
||||
lambda x: "" if tensor_util.is_tensor(x) else x,
|
||||
inputs_ordered_dicts_sorted)
|
||||
|
||||
tensor_free_template = " ".join(
|
||||
pprint.pformat(x) for x in tensor_free_structure)
|
||||
placeholder = _generate_placeholder_string(tensor_free_template)
|
||||
|
Loading…
Reference in New Issue
Block a user