First phase of nest.py migration to cc: move flatten_dict_items implementation from python to C++

PiperOrigin-RevId: 357815150
Change-Id: I3689d5f7b80a621ea3696d44b7b00116bcff6ee9
This commit is contained in:
A. Unique TensorFlower 2021-02-16 15:15:31 -08:00 committed by TensorFlower Gardener
parent 1b00e4d951
commit 63a277f28a
8 changed files with 265 additions and 26 deletions

View File

@ -237,6 +237,7 @@ py_library(
"//tensorflow/python/util",
"//tensorflow/python/util:_pywrap_checkpoint_reader",
"//tensorflow/python/util:_pywrap_kernel_registry",
"//tensorflow/python/util:_pywrap_nest",
"//tensorflow/python/util:_pywrap_stat_summarizer",
"//tensorflow/python/util:_pywrap_tfprof",
"//tensorflow/python/util:_pywrap_transform_graph",
@ -751,6 +752,7 @@ py_library(
deps = [
":_pywrap_debug_events_writer",
":_pywrap_events_writer",
"//tensorflow/python/util:_pywrap_nest",
"//tensorflow/python/util:_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
"//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
@ -5212,6 +5214,7 @@ pywrap_tensorflow_macro(
":model_analyzer_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/python/util:cpp_nest",
"//tensorflow/python/util:cpp_python_util",
"//tensorflow/python/util:function_parameter_canonicalizer",
"//tensorflow/python/util:kernel_registry",
@ -5289,6 +5292,7 @@ filegroup(
srcs = [
":bfloat16_lib", # bfloat16
":cost_analyzer_lib", # cost_analyzer
"//tensorflow/python/util:cpp_nest",
"//tensorflow/python/util:cpp_python_util",
"//tensorflow/python/util:kernel_registry",
":model_analyzer_lib", # model_analyzer

View File

@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1451,6 +1451,14 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
self._run(fn, 100000)
# Micro-benchmark for nest.flatten_dict_items. The input is a dict with a
# single entry whose key and value are nested tuples with four leaves each.
def benchmark_tf_flatten_dict_items(self):
nested = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
# The call under measurement; its result is discarded.
def fn():
nest.flatten_dict_items(nested)
# 100000 is presumably the iteration count -- confirm against self._run.
self._run(fn, 100000)
def benchmark_tf_nn_convolution_overhead(self):
inputs = array_ops.ones((1, 1, 1, 1))
filters = array_ops.ones((1, 1, 1, 1))

View File

@ -89,6 +89,34 @@ tf_python_pybind_extension(
],
)
# Python extension module `_pywrap_nest`, built from nest_wrapper.cc.
# Only the nest.h header is listed here; nest.cc itself is compiled by the
# :cpp_nest library, not by this target.
tf_python_pybind_extension(
name = "_pywrap_nest",
srcs = ["nest_wrapper.cc"],
hdrs = ["nest.h"],
module_name = "_pywrap_nest",
deps = [
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
# C++ implementation of the nest utilities declared in nest.h (nest.cc).
cc_library(
name = "cpp_nest",
srcs = ["nest.cc"],
hdrs = ["nest.h"],
deps = [
":cpp_python_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:stringpiece",
"//tensorflow/python/lib/core:safe_pyobject_ptr",
"//third_party/python_runtime:headers",
],
# Bazel: link every object file of this library into dependents, even when
# no symbol is referenced directly by them.
alwayslink = 1,
)
tf_python_pybind_extension(
name = "_pywrap_kernel_registry",
srcs = ["kernel_registry_wrapper.cc"],

View File

@ -0,0 +1,146 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/util/nest.h"
#include <utility>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace tensorflow {
namespace {
// Gets a string representation of the input object, i.e. what Python's
// `str(o)` would produce, encoded as UTF-8.
//
// Args:
//   o: a python object.
//   length: If set to negative, the whole string is returned. Otherwise, the
//     string gets clipped to at most 'length' bytes and "..." is appended.
//
// Returns:
//   A string representation. If `str(o)` fails, a placeholder is returned and
//   the Python error raised by the failed conversion is cleared; this helper
//   is only used to build error messages, so callers set their own exception
//   right afterwards.
std::string PyObject_ToString(PyObject* o, int length = -1) {
  auto str_o = make_safe(PyObject_Str(o));
  // Both PyObject_Str() and PyUnicode_AsUTF8() return NULL on failure;
  // constructing a std::string from a null pointer is undefined behaviour.
  const char* utf8 = str_o ? PyUnicode_AsUTF8(str_o.get()) : nullptr;
  if (utf8 == nullptr) {
    PyErr_Clear();
    return "<unprintable object>";
  }
  std::string str = utf8;
  if (length < 0 || str.size() <= static_cast<size_t>(length)) {
    return str;
  }
  // Do not cut through a multi-byte UTF-8 sequence: step back while the byte
  // at the cut position is a continuation byte (0b10xxxxxx).
  size_t cut = static_cast<size_t>(length);
  while (cut > 0 && (static_cast<unsigned char>(str[cut]) & 0xC0) == 0x80) {
    --cut;
  }
  tensorflow::StringPiece str_piece(str);
  // Keep the prefix [0, cut); substr(length) would keep the tail instead.
  return tensorflow::strings::StrCat(str_piece.substr(0, cut), "...");
}
// Collects the keys of a dict or of any object implementing the mapping
// protocol.
//
// Args:
//   o: a dictionary or mapping type object.
//
// Returns:
//   A new reference to a list of keys, or nullptr with a Python exception set.
//
// Raises:
//   TypeError: if `o` is not a dict or mapping type object.
PyObject* GetKeysFromDictOrMapping(PyObject* o) {
  // Dicts (and dict subclasses) go through the dedicated dict API.
  if (PyDict_Check(o)) return PyDict_Keys(o);

  // Anything else must at least implement the mapping protocol.
  if (PyMapping_Check(o)) return PyMapping_Keys(o);

  const std::string message = tensorflow::strings::StrCat(
      "Expecting a type compatible with dict or mapping, got '",
      Py_TYPE(o)->tp_name, "'");
  PyErr_SetString(PyExc_TypeError, message.c_str());
  return nullptr;
}
} // namespace
PyObject* FlattenDictItems(PyObject* dict) {
if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
PyErr_SetString(PyExc_TypeError,
tensorflow::strings::StrCat(
"FlattenDictItems: 'dict' must be a dictionary or ",
"collection.Mapping type object, instead of '",
Py_TYPE(dict)->tp_name, "'.")
.c_str());
return nullptr;
}
PyObject* flat_dictionary = PyDict_New();
auto keys = make_safe(GetKeysFromDictOrMapping(dict));
for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
auto* key = PyList_GetItem(keys.get(), i);
// We use a general approach in case 'dict' is a PyMapping type,
// but not a PyDict type.
auto* value = PyObject_GetItem(dict, key);
if (swig::IsSequence(key)) {
// The dict might contain list - list pairs.
auto flat_keys = make_safe(swig::Flatten(key, false));
auto flat_values = make_safe(swig::Flatten(value, false));
size_t flat_keys_sz = PyList_Size(flat_keys.get());
size_t flat_values_sz = PyList_Size(flat_values.get());
if (flat_keys_sz != flat_values_sz) {
PyErr_SetString(
PyExc_ValueError,
tensorflow::strings::StrCat(
"Could not flatten dictionary. Key had ", flat_keys_sz,
" elements, but value had ", flat_values_sz,
" elements. Key: ", PyObject_ToString(flat_keys.get()),
", value: ", PyObject_ToString(flat_values.get()), ".")
.c_str());
Py_DecRef(flat_dictionary);
return nullptr;
}
for (size_t i = 0; i < flat_keys_sz; ++i) {
auto* flat_key = PyList_GetItem(flat_keys.get(), i);
auto* flat_value = PyList_GetItem(flat_values.get(), i);
if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
PyErr_SetString(
PyExc_ValueError,
tensorflow::strings::StrCat(
"Cannot flatten dict because this key is not unique: ",
PyObject_ToString(flat_key))
.c_str());
Py_DecRef(flat_dictionary);
return nullptr;
}
PyDict_SetItem(flat_dictionary, flat_key, flat_value);
}
} else {
if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
PyErr_SetString(
PyExc_ValueError,
tensorflow::strings::StrCat(
"Cannot flatten dict because this key is not unique: ",
PyObject_ToString(key))
.c_str());
Py_DecRef(flat_dictionary);
return nullptr;
}
PyDict_SetItem(flat_dictionary, key, value);
}
// Manually decrease because PyObject_GetItem() returns a new reference.
Py_DECREF(value);
}
return flat_dictionary;
}
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_UTIL_NEST_H_
#define TENSORFLOW_PYTHON_UTIL_NEST_H_

#include <Python.h>

namespace tensorflow {
// Returns a dictionary with flattened keys and values.
//
// Args:
//   dict: the dictionary (or mapping type object) to zip
//
// Returns:
//   A new reference to the zipped dictionary, or nullptr with a Python
//   exception set on failure.
//
// Raises:
//   TypeError: If the input is not a dictionary or mapping type object.
//   ValueError: If any key and value do not have the same structure layout, or
//   if keys are not unique.
PyObject* FlattenDictItems(PyObject* dict);
}  // namespace tensorflow

#endif  // TENSORFLOW_PYTHON_UTIL_NEST_H_

View File

@ -1,4 +1,4 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -48,6 +48,7 @@ import six as _six
import wrapt as _wrapt
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import _pywrap_nest
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export
@ -562,30 +563,7 @@ def flatten_dict_items(dictionary):
ValueError: If any key and value do not have the same structure layout, or
if keys are not unique.
"""
if not isinstance(dictionary, (dict, _collections_abc.Mapping)):
raise TypeError("input must be a dictionary")
flat_dictionary = {}
for i, v in _six.iteritems(dictionary):
if not is_sequence(i):
if i in flat_dictionary:
raise ValueError(
"Could not flatten dictionary: key %s is not unique." % i)
flat_dictionary[i] = v
else:
flat_i = flatten(i)
flat_v = flatten(v)
if len(flat_i) != len(flat_v):
raise ValueError(
"Could not flatten dictionary. Key had %d elements, but value had "
"%d elements. Key: %s, value: %s."
% (len(flat_i), len(flat_v), flat_i, flat_v))
for new_i, new_v in zip(flat_i, flat_v):
if new_i in flat_dictionary:
raise ValueError(
"Could not flatten dictionary: key %s is not unique."
% (new_i))
flat_dictionary[new_i] = new_v
return flat_dictionary
return _pywrap_nest.FlattenDictItems(dictionary)
def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):

View File

@ -0,0 +1,35 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/util/nest.h"
namespace py = pybind11;
// Defines the `_pywrap_nest` Python extension module.
PYBIND11_MODULE(_pywrap_nest, m) {
  m.doc() = R"pbdoc(
_pywrap_nest
-----
)pbdoc";

  // Calls the C++ implementation on the raw PyObject* and hands the result to
  // PyoOrThrow to produce the value returned to Python.
  auto flatten_dict_items = [](const py::handle& dict) {
    PyObject* flattened = tensorflow::FlattenDictItems(dict.ptr());
    return tensorflow::PyoOrThrow(flattened);
  };

  m.def("FlattenDictItems", flatten_dict_items,
        R"pbdoc(
Returns a dictionary with flattened keys and values.
)pbdoc");
}

View File

@ -22,6 +22,9 @@ tensorflow::swig::RegisterType
tensorflow::swig::IsEagerTensorSlow
tensorflow::swig::GetRegisteredPyObject
[//tensorflow/python/util:cpp_nest] # nest
tensorflow::FlattenDictItems
[//tensorflow/core/util:port] # util_port
tensorflow::IsGoogleCudaEnabled
tensorflow::IsBuiltWithROCm