diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0e20815f95a..89c3c6300e0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -237,6 +237,7 @@ py_library( "//tensorflow/python/util", "//tensorflow/python/util:_pywrap_checkpoint_reader", "//tensorflow/python/util:_pywrap_kernel_registry", + "//tensorflow/python/util:_pywrap_nest", "//tensorflow/python/util:_pywrap_stat_summarizer", "//tensorflow/python/util:_pywrap_tfprof", "//tensorflow/python/util:_pywrap_transform_graph", @@ -751,6 +752,7 @@ py_library( deps = [ ":_pywrap_debug_events_writer", ":_pywrap_events_writer", + "//tensorflow/python/util:_pywrap_nest", "//tensorflow/python/util:_pywrap_kernel_registry", ":_pywrap_py_exception_registry", "//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. @@ -5212,6 +5214,7 @@ pywrap_tensorflow_macro( ":model_analyzer_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/python/util:cpp_nest", "//tensorflow/python/util:cpp_python_util", "//tensorflow/python/util:function_parameter_canonicalizer", "//tensorflow/python/util:kernel_registry", diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 3287a1548ac..19760269a6e 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -1451,6 +1451,14 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
 
     self._run(fn, 100000)
 
+  def benchmark_tf_flatten_dict_items(self):
+    nested = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}  # Key/value nest alike.
+
+    def fn():
+      nest.flatten_dict_items(nested)
+
+    self._run(fn, 100000)
+
   def benchmark_tf_nn_convolution_overhead(self):
     inputs = array_ops.ones((1, 1, 1, 1))
     filters = array_ops.ones((1, 1, 1, 1))
diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD
index dfe44e2bfb9..e0eb8d00522 100644
--- a/tensorflow/python/util/BUILD
+++ b/tensorflow/python/util/BUILD
@@ -89,6 +89,34 @@ tf_python_pybind_extension(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_nest",  # pybind11 module built from nest_wrapper.cc; imported by nest.py.
+    srcs = ["nest_wrapper.cc"],
+    hdrs = ["nest.h"],
+    module_name = "_pywrap_nest",
+    deps = [
+        "//tensorflow/python:pybind11_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+cc_library(
+    name = "cpp_nest",  # C++ implementation (nest.cc) of the function declared in nest.h.
+    srcs = ["nest.cc"],
+    hdrs = ["nest.h"],
+    deps = [
+        ":cpp_python_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/python/lib/core:safe_pyobject_ptr",
+        "//third_party/python_runtime:headers",
+    ],
+    alwayslink = 1,
+)
+
 tf_python_pybind_extension(
     name = "_pywrap_kernel_registry",
     srcs = ["kernel_registry_wrapper.cc"],
diff --git a/tensorflow/python/util/nest.cc b/tensorflow/python/util/nest.cc
new file mode 100644
index 00000000000..63d6ab29771
--- /dev/null
+++ b/tensorflow/python/util/nest.cc
@@ -0,0 +1,146 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/util/nest.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+#include "tensorflow/python/util/util.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Gets a string representation of the input object.
+//
+// Args:
+//   o: a python object.
+//   length: if negative, the whole string is returned; otherwise the string
+//     is clipped to its first 'length' bytes and "..." is appended.
+// Returns: str(o) as UTF-8, or "<unprintable object>" if str(o) fails.
+std::string PyObject_ToString(PyObject* o, int length = -1) {
+  auto str_o = make_safe(PyObject_Str(o));
+  const char* utf8 = str_o ? PyUnicode_AsUTF8(str_o.get()) : nullptr;
+  if (utf8 == nullptr) {
+    PyErr_Clear();  // str() itself failed; callers report their own error.
+    return "<unprintable object>";
+  }
+  std::string str(utf8);
+  if (length < 0 || str.size() <= static_cast<size_t>(length)) return str;
+  return tensorflow::strings::StrCat(str.substr(0, length), "...");
+}
+
+// Gets a list of keys from a dict or mapping type object.
+//
+// Args:
+//   o: a dictionary or mapping type object.
+//
+// Returns:
+//   A new reference to a list.
+//
+// Raises:
+//   TypeError: if `o` is not a dict or mapping type object.
+PyObject* GetKeysFromDictOrMapping(PyObject* o) {
+  if (PyDict_Check(o)) return PyDict_Keys(o);
+  if (PyMapping_Check(o)) return PyMapping_Keys(o);
+  PyErr_SetString(
+      PyExc_TypeError,
+      tensorflow::strings::StrCat(
+          "Expecting a type compatible with dict or mapping, got '",
+          Py_TYPE(o)->tp_name, "'")
+          .c_str());
+  return nullptr;
+}
+
+// Inserts `key` -> `value` into `dict`, refusing to overwrite existing keys.
+// Returns false, with a Python exception set, when `key` is already present
+// (ValueError) or a dict operation fails; null arguments also return false.
+bool InsertUniqueItem(PyObject* dict, PyObject* key, PyObject* value) {
+  if (key == nullptr || value == nullptr) return false;
+  const int contains = PyDict_Contains(dict, key);
+  if (contains < 0) return false;
+  if (contains > 0) {
+    PyErr_SetString(
+        PyExc_ValueError,
+        tensorflow::strings::StrCat(
+            "Cannot flatten dict because this key is not unique: ",
+            PyObject_ToString(key))
+            .c_str());
+    return false;
+  }
+  return PyDict_SetItem(dict, key, value) == 0;
+}
+
+}  // namespace
+
+// Contract: see nest.h. Every failure path returns nullptr.
+PyObject* FlattenDictItems(PyObject* dict) {
+  if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
+    PyErr_SetString(PyExc_TypeError,
+                    tensorflow::strings::StrCat(
+                        "FlattenDictItems: 'dict' must be a dictionary or ",
+                        "collection.Mapping type object, instead of '",
+                        Py_TYPE(dict)->tp_name, "'.")
+                        .c_str());
+    return nullptr;
+  }
+  auto flat_dictionary = make_safe(PyDict_New());  // Freed on early return.
+  auto keys = make_safe(GetKeysFromDictOrMapping(dict));
+  if (flat_dictionary == nullptr || keys == nullptr) return nullptr;
+  const Py_ssize_t num_keys = PyList_Size(keys.get());
+  if (num_keys < 0) return nullptr;
+  for (Py_ssize_t i = 0; i < num_keys; ++i) {
+    PyObject* key = PyList_GetItem(keys.get(), i);  // Borrowed reference.
+    if (key == nullptr) return nullptr;
+    // Generic item lookup: 'dict' may be a mapping that is not a PyDict.
+    auto value = make_safe(PyObject_GetItem(dict, key));
+    if (value == nullptr) return nullptr;
+    if (!swig::IsSequence(key)) {
+      if (InsertUniqueItem(flat_dictionary.get(), key, value.get())) continue;
+      return nullptr;
+    }
+    // The dict might contain list - list pairs.
+    auto flat_keys = make_safe(swig::Flatten(key, false));
+    auto flat_values = make_safe(swig::Flatten(value.get(), false));
+    if (flat_keys == nullptr || flat_values == nullptr) return nullptr;
+    const Py_ssize_t flat_keys_sz = PyList_Size(flat_keys.get());
+    const Py_ssize_t flat_values_sz = PyList_Size(flat_values.get());
+    if (flat_keys_sz < 0 || flat_values_sz < 0) return nullptr;
+    if (flat_keys_sz != flat_values_sz) {
+      PyErr_SetString(
+          PyExc_ValueError,
+          tensorflow::strings::StrCat(
+              "Could not flatten dictionary. Key had ", flat_keys_sz,
+              " elements, but value had ", flat_values_sz,
+              " elements. Key: ", PyObject_ToString(flat_keys.get()),
+              ", value: ", PyObject_ToString(flat_values.get()), ".")
+              .c_str());
+      return nullptr;
+    }
+    for (Py_ssize_t j = 0; j < flat_keys_sz; ++j) {
+      if (!InsertUniqueItem(flat_dictionary.get(),
+                            PyList_GetItem(flat_keys.get(), j),
+                            PyList_GetItem(flat_values.get(), j))) {
+        return nullptr;
+      }
+    }
+  }
+  return flat_dictionary.release();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/util/nest.h b/tensorflow/python/util/nest.h
new file mode 100644
index 00000000000..43829f44b14
--- /dev/null
+++ b/tensorflow/python/util/nest.h
@@ -0,0 +1,37 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_UTIL_NEST_H_
+#define TENSORFLOW_PYTHON_UTIL_NEST_H_
+
+#include <Python.h>
+
+namespace tensorflow {
+// Returns a dictionary with flattened keys and values.
+//
+// Args:
+//   dict: the dictionary (or mapping) to zip
+//
+// Returns:
+//   A new reference to the zipped dictionary.
+//
+// Raises:
+//   TypeError: If the input is not a dictionary or mapping.
+//   ValueError: If any key and value do not have the same structure layout, or
+//     if keys are not unique.
+PyObject* FlattenDictItems(PyObject* dict);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_UTIL_NEST_H_
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 910da988149..a3b1530ce06 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -1,4 +1,4 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@ import six as _six
 import wrapt as _wrapt
 
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import _pywrap_nest
 from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util.compat import collections_abc as _collections_abc
 from tensorflow.python.util.tf_export import tf_export
@@ -562,30 +563,7 @@ def flatten_dict_items(dictionary):
     ValueError: If any key and value do not have the same structure layout, or
     if keys are not unique.
   """
-  if not isinstance(dictionary, (dict, _collections_abc.Mapping)):
-    raise TypeError("input must be a dictionary")
-  flat_dictionary = {}
-  for i, v in _six.iteritems(dictionary):
-    if not is_sequence(i):
-      if i in flat_dictionary:
-        raise ValueError(
-            "Could not flatten dictionary: key %s is not unique." % i)
-      flat_dictionary[i] = v
-    else:
-      flat_i = flatten(i)
-      flat_v = flatten(v)
-      if len(flat_i) != len(flat_v):
-        raise ValueError(
-            "Could not flatten dictionary. Key had %d elements, but value had "
-            "%d elements. Key: %s, value: %s."
-            % (len(flat_i), len(flat_v), flat_i, flat_v))
-      for new_i, new_v in zip(flat_i, flat_v):
-        if new_i in flat_dictionary:
-          raise ValueError(
-              "Could not flatten dictionary: key %s is not unique."
-              % (new_i))
-        flat_dictionary[new_i] = new_v
-  return flat_dictionary
+  return _pywrap_nest.FlattenDictItems(dictionary)
 
 
 def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):
diff --git a/tensorflow/python/util/nest_wrapper.cc b/tensorflow/python/util/nest_wrapper.cc
new file mode 100644
index 00000000000..6b87caa7619
--- /dev/null
+++ b/tensorflow/python/util/nest_wrapper.cc
@@ -0,0 +1,35 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/util/nest.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_nest, m) {  // Python extension module `_pywrap_nest`.
+  m.doc() = R"pbdoc(
+    _pywrap_nest
+    -----
+  )pbdoc";
+  m.def(
+      "FlattenDictItems",  // Python-visible name; nest.py calls this.
+      [](const py::handle& dict) {  // Forwards the raw PyObject* to C++.
+        return tensorflow::PyoOrThrow(tensorflow::FlattenDictItems(dict.ptr()));
+      },  // NOTE(review): PyoOrThrow presumably throws on nullptr - confirm.
+      R"pbdoc(
+      Returns a dictionary with flattened keys and values.
+      )pbdoc");
+}