First phase of nest.py migration to cc: move flatten_dict_items implementation from python to C++
PiperOrigin-RevId: 357815150 Change-Id: I3689d5f7b80a621ea3696d44b7b00116bcff6ee9
This commit is contained in:
parent
1b00e4d951
commit
63a277f28a
@ -237,6 +237,7 @@ py_library(
|
||||
"//tensorflow/python/util",
|
||||
"//tensorflow/python/util:_pywrap_checkpoint_reader",
|
||||
"//tensorflow/python/util:_pywrap_kernel_registry",
|
||||
"//tensorflow/python/util:_pywrap_nest",
|
||||
"//tensorflow/python/util:_pywrap_stat_summarizer",
|
||||
"//tensorflow/python/util:_pywrap_tfprof",
|
||||
"//tensorflow/python/util:_pywrap_transform_graph",
|
||||
@ -751,6 +752,7 @@ py_library(
|
||||
deps = [
|
||||
":_pywrap_debug_events_writer",
|
||||
":_pywrap_events_writer",
|
||||
"//tensorflow/python/util:_pywrap_nest",
|
||||
"//tensorflow/python/util:_pywrap_kernel_registry",
|
||||
":_pywrap_py_exception_registry",
|
||||
"//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
|
||||
@ -5212,6 +5214,7 @@ pywrap_tensorflow_macro(
|
||||
":model_analyzer_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
"//tensorflow/python/util:cpp_nest",
|
||||
"//tensorflow/python/util:cpp_python_util",
|
||||
"//tensorflow/python/util:function_parameter_canonicalizer",
|
||||
"//tensorflow/python/util:kernel_registry",
|
||||
@ -5289,6 +5292,7 @@ filegroup(
|
||||
srcs = [
|
||||
":bfloat16_lib", # bfloat16
|
||||
":cost_analyzer_lib", # cost_analyzer
|
||||
"//tensorflow/python/util:cpp_nest",
|
||||
"//tensorflow/python/util:cpp_python_util",
|
||||
"//tensorflow/python/util:kernel_registry",
|
||||
":model_analyzer_lib", # model_analyzer
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -1451,6 +1451,14 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
|
||||
self._run(fn, 100000)
|
||||
|
||||
def benchmark_tf_flatten_dict_items(self):
|
||||
nested = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
|
||||
|
||||
def fn():
|
||||
nest.flatten_dict_items(nested)
|
||||
|
||||
self._run(fn, 100000)
|
||||
|
||||
def benchmark_tf_nn_convolution_overhead(self):
|
||||
inputs = array_ops.ones((1, 1, 1, 1))
|
||||
filters = array_ops.ones((1, 1, 1, 1))
|
||||
|
@ -89,6 +89,34 @@ tf_python_pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_nest",
|
||||
srcs = ["nest_wrapper.cc"],
|
||||
hdrs = ["nest.h"],
|
||||
module_name = "_pywrap_nest",
|
||||
deps = [
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpp_nest",
|
||||
srcs = ["nest.cc"],
|
||||
hdrs = ["nest.h"],
|
||||
deps = [
|
||||
":cpp_python_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:stringpiece",
|
||||
"//tensorflow/python/lib/core:safe_pyobject_ptr",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_kernel_registry",
|
||||
srcs = ["kernel_registry_wrapper.cc"],
|
||||
|
146
tensorflow/python/util/nest.cc
Normal file
146
tensorflow/python/util/nest.cc
Normal file
@ -0,0 +1,146 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/python/util/nest.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Gets a string representation of the input object.
|
||||
//
|
||||
// Args:
|
||||
// o: a python object.
|
||||
// length: If set to negative, the whole string is returned. Otherwise, the
|
||||
// string gets clipped to 'length' in size.
|
||||
//
|
||||
// Returns:
|
||||
// A string representation.
|
||||
std::string PyObject_ToString(PyObject* o, int length = -1) {
|
||||
auto str_o = make_safe(PyObject_Str(o));
|
||||
std::string str = PyUnicode_AsUTF8(str_o.get());
|
||||
if (length < 0 || str.size() <= length) {
|
||||
return str;
|
||||
}
|
||||
tensorflow::StringPiece str_piece(str);
|
||||
return tensorflow::strings::StrCat(str_piece.substr(length), "...");
|
||||
}
|
||||
|
||||
// Gets a list of keys from a dict or mapping type object.
|
||||
//
|
||||
// Args:
|
||||
// o: a dictionary or mapping type object.
|
||||
//
|
||||
// Returns:
|
||||
// A new reference to a list.
|
||||
//
|
||||
// Raises:
|
||||
// TypeError: if `o` is not a dict or mapping type object.
|
||||
PyObject* GetKeysFromDictOrMapping(PyObject* o) {
|
||||
if (PyDict_Check(o)) {
|
||||
return PyDict_Keys(o);
|
||||
} else if (PyMapping_Check(o)) {
|
||||
return PyMapping_Keys(o);
|
||||
} else {
|
||||
auto* o_type = Py_TYPE(o);
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expecting a type compatible with dict or mapping, got '",
|
||||
o_type->tp_name, "'")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject* FlattenDictItems(PyObject* dict) {
|
||||
if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"FlattenDictItems: 'dict' must be a dictionary or ",
|
||||
"collection.Mapping type object, instead of '",
|
||||
Py_TYPE(dict)->tp_name, "'.")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
PyObject* flat_dictionary = PyDict_New();
|
||||
auto keys = make_safe(GetKeysFromDictOrMapping(dict));
|
||||
for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
|
||||
auto* key = PyList_GetItem(keys.get(), i);
|
||||
// We use a general approach in case 'dict' is a PyMapping type,
|
||||
// but not a PyDict type.
|
||||
auto* value = PyObject_GetItem(dict, key);
|
||||
if (swig::IsSequence(key)) {
|
||||
// The dict might contain list - list pairs.
|
||||
auto flat_keys = make_safe(swig::Flatten(key, false));
|
||||
auto flat_values = make_safe(swig::Flatten(value, false));
|
||||
size_t flat_keys_sz = PyList_Size(flat_keys.get());
|
||||
size_t flat_values_sz = PyList_Size(flat_values.get());
|
||||
if (flat_keys_sz != flat_values_sz) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Could not flatten dictionary. Key had ", flat_keys_sz,
|
||||
" elements, but value had ", flat_values_sz,
|
||||
" elements. Key: ", PyObject_ToString(flat_keys.get()),
|
||||
", value: ", PyObject_ToString(flat_values.get()), ".")
|
||||
.c_str());
|
||||
Py_DecRef(flat_dictionary);
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < flat_keys_sz; ++i) {
|
||||
auto* flat_key = PyList_GetItem(flat_keys.get(), i);
|
||||
auto* flat_value = PyList_GetItem(flat_values.get(), i);
|
||||
if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Cannot flatten dict because this key is not unique: ",
|
||||
PyObject_ToString(flat_key))
|
||||
.c_str());
|
||||
Py_DecRef(flat_dictionary);
|
||||
return nullptr;
|
||||
}
|
||||
PyDict_SetItem(flat_dictionary, flat_key, flat_value);
|
||||
}
|
||||
} else {
|
||||
if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Cannot flatten dict because this key is not unique: ",
|
||||
PyObject_ToString(key))
|
||||
.c_str());
|
||||
Py_DecRef(flat_dictionary);
|
||||
return nullptr;
|
||||
}
|
||||
PyDict_SetItem(flat_dictionary, key, value);
|
||||
}
|
||||
// Manually decrease because PyObject_GetItem() returns a new reference.
|
||||
Py_DECREF(value);
|
||||
}
|
||||
return flat_dictionary;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
37
tensorflow/python/util/nest.h
Normal file
37
tensorflow/python/util/nest.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_COMPAT_NEST_H_
|
||||
#define TENSORFLOW_PYTHON_COMPAT_NEST_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
// Returns a dictionary with flattened keys and values.
|
||||
//
|
||||
// Args:
|
||||
// dict: the dictionary to zip
|
||||
//
|
||||
// Returns:
|
||||
// An new reference to the zipped dictionary.
|
||||
//
|
||||
// Raises:
|
||||
// TypeError: If the input is not a dictionary.
|
||||
// ValueError: If any key and value do not have the same structure layout, or
|
||||
// if keys are not unique.
|
||||
PyObject* FlattenDictItems(PyObject* dict);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_COMPAT_NEST_H_
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -48,6 +48,7 @@ import six as _six
|
||||
import wrapt as _wrapt
|
||||
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import _pywrap_nest
|
||||
from tensorflow.python.util import _pywrap_utils
|
||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -562,30 +563,7 @@ def flatten_dict_items(dictionary):
|
||||
ValueError: If any key and value do not have the same structure layout, or
|
||||
if keys are not unique.
|
||||
"""
|
||||
if not isinstance(dictionary, (dict, _collections_abc.Mapping)):
|
||||
raise TypeError("input must be a dictionary")
|
||||
flat_dictionary = {}
|
||||
for i, v in _six.iteritems(dictionary):
|
||||
if not is_sequence(i):
|
||||
if i in flat_dictionary:
|
||||
raise ValueError(
|
||||
"Could not flatten dictionary: key %s is not unique." % i)
|
||||
flat_dictionary[i] = v
|
||||
else:
|
||||
flat_i = flatten(i)
|
||||
flat_v = flatten(v)
|
||||
if len(flat_i) != len(flat_v):
|
||||
raise ValueError(
|
||||
"Could not flatten dictionary. Key had %d elements, but value had "
|
||||
"%d elements. Key: %s, value: %s."
|
||||
% (len(flat_i), len(flat_v), flat_i, flat_v))
|
||||
for new_i, new_v in zip(flat_i, flat_v):
|
||||
if new_i in flat_dictionary:
|
||||
raise ValueError(
|
||||
"Could not flatten dictionary: key %s is not unique."
|
||||
% (new_i))
|
||||
flat_dictionary[new_i] = new_v
|
||||
return flat_dictionary
|
||||
return _pywrap_nest.FlattenDictItems(dictionary)
|
||||
|
||||
|
||||
def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):
|
||||
|
35
tensorflow/python/util/nest_wrapper.cc
Normal file
35
tensorflow/python/util/nest_wrapper.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/util/nest.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_nest, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_nest
|
||||
-----
|
||||
)pbdoc";
|
||||
m.def(
|
||||
"FlattenDictItems",
|
||||
[](const py::handle& dict) {
|
||||
return tensorflow::PyoOrThrow(tensorflow::FlattenDictItems(dict.ptr()));
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns a dictionary with flattened keys and values.
|
||||
)pbdoc");
|
||||
}
|
@ -22,6 +22,9 @@ tensorflow::swig::RegisterType
|
||||
tensorflow::swig::IsEagerTensorSlow
|
||||
tensorflow::swig::GetRegisteredPyObject
|
||||
|
||||
[//tensorflow/python/util:cpp_nest] # nest
|
||||
tensorflow::FlattenDictItems
|
||||
|
||||
[//tensorflow/core/util:port] # util_port
|
||||
tensorflow::IsGoogleCudaEnabled
|
||||
tensorflow::IsBuiltWithROCm
|
||||
|
Loading…
Reference in New Issue
Block a user