From 9c510da34b40bf0ba1865e52ea32305313f143bd Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 4 Nov 2020 13:59:11 -0800 Subject: [PATCH] Expose nest related function to tf.__internal__ API. PiperOrigin-RevId: 340722550 Change-Id: Ic5f7451346dbd0543e0cc25c90bcd16edfff212a --- tensorflow/python/keras/utils/tf_utils.py | 2 +- .../tools/api/generator/api_init_files.bzl | 1 + tensorflow/python/util/nest.py | 18 ++++++++++ .../v2/tensorflow.__internal__.nest.pbtxt | 35 +++++++++++++++++++ .../golden/v2/tensorflow.__internal__.pbtxt | 4 +++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 02974467b97..9d716d09176 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -128,7 +128,7 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested): raise ValueError( 'Received non-atomic and non-sequence element: {}'.format(nested)) if nest._is_mapping(nested): - values = [nested[k] for k in nest._sorted(nested)] + values = [nested[k] for k in sorted(nested.keys())] elif nest._is_attrs(nested): values = _astuple(nested) else: diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 3daf9eedb41..8dac33cec21 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES = [ "__internal__/distribute/__init__.py", "__internal__/distribute/combinations/__init__.py", "__internal__/distribute/multi_process_runner/__init__.py", + "__internal__/nest/__init__.py", "__internal__/test/__init__.py", "__internal__/test/combinations/__init__.py", "__internal__/tf2/__init__.py", diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 1a547b89d35..cdd6d0cbbad 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -126,6 +126,19 @@ _is_mutable_mapping = _pywrap_utils.IsMutableMapping _is_mapping = _pywrap_utils.IsMapping +@tf_export("__internal__.nest.is_attrs", v1=[]) +def is_attrs(obj): + """Returns a true if its input is an instance of an attr.s decorated class.""" + return _is_attrs(obj) + + +@tf_export("__internal__.nest.is_mapping", v1=[]) +def is_mapping(obj): + """Returns a true if its input is a collections.Mapping.""" + return is_mapping(obj) + + +@tf_export("__internal__.nest.sequence_like", v1=[]) def _sequence_like(instance, args): """Converts the sequence `args` to the same type as `instance`. @@ -894,6 +907,7 @@ def assert_shallow_structure(shallow_tree, expand_composites=expand_composites) +@tf_export("__internal__.nest.flatten_up_to", v1=[]) def flatten_up_to(shallow_tree, input_tree, check_types=True, expand_composites=False): """Flattens `input_tree` up to `shallow_tree`. @@ -1082,6 +1096,7 @@ def flatten_with_tuple_paths_up_to(shallow_tree, return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq)) +@tf_export("__internal__.nest.map_structure_up_to", v1=[]) def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): """Applies a function or op to a number of partially flattened inputs. @@ -1261,6 +1276,7 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): expand_composites=expand_composites) +@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[]) def get_traverse_shallow_structure(traverse_fn, structure, expand_composites=False): """Generates a shallow structure from a `traverse_fn` and `structure`. @@ -1331,6 +1347,7 @@ def get_traverse_shallow_structure(traverse_fn, structure, return _sequence_like(structure, level_traverse) +@tf_export("__internal__.nest.yield_flat_paths", v1=[]) def yield_flat_paths(nest, expand_composites=False): """Yields paths for some nested structure. @@ -1425,6 +1442,7 @@ def flatten_with_tuple_paths(structure, expand_composites=False): flatten(structure, expand_composites=expand_composites))) +@tf_export("__internal__.nest.list_to_tuple", v1=[]) def list_to_tuple(structure): """Replace all lists with tuples. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt new file mode 100644 index 00000000000..feca4a08adb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.__internal__.nest" +tf_module { + member_method { + name: "flatten_up_to" + argspec: "args=[\'shallow_tree\', \'input_tree\', \'check_types\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], " + } + member_method { + name: "get_traverse_shallow_structure" + argspec: "args=[\'traverse_fn\', \'structure\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "is_attrs" + argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_mapping" + argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "list_to_tuple" + argspec: "args=[\'structure\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map_structure_up_to" + argspec: "args=[\'shallow_tree\', \'func\'], varargs=inputs, keywords=kwargs, defaults=None" + } + member_method { + name: "sequence_like" + argspec: "args=[\'instance\', \'args\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yield_flat_paths" + argspec: "args=[\'nest\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'False\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt index b6a385783c6..35d23f01a51 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "distribute" mtype: "" } + member { + name: "nest" + mtype: "" + } member { name: "test" mtype: ""