Converted device_lib.i to pybind11
This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 271564267
This commit is contained in:
parent
a0aa739277
commit
7623d23a63
tensorflow
python
tools/def_file_filter
@ -4682,6 +4682,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_device_lib",
|
||||
":errors",
|
||||
":framework",
|
||||
":framework_for_generated_wrappers",
|
||||
@ -4952,7 +4953,23 @@ py_library(
|
||||
srcs = ["client/device_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
":_pywrap_device_lib",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_device_lib",
|
||||
srcs = ["client/device_lib_wrapper.cc"],
|
||||
module_name = "_pywrap_device_lib",
|
||||
deps = [
|
||||
":pybind11_proto",
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework_internal_headers_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
@ -5029,7 +5046,6 @@ tf_py_wrap_cc(
|
||||
name = "pywrap_tensorflow_internal",
|
||||
srcs = ["tensorflow.i"],
|
||||
swig_includes = [
|
||||
"client/device_lib.i",
|
||||
"client/tf_session.i",
|
||||
"client/tf_sessionrun_wrapper.i",
|
||||
"framework/python_op_gen.i",
|
||||
@ -5154,6 +5170,8 @@ genrule(
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
"//tensorflow/core:framework_internal_impl", # op_def_registry
|
||||
"//tensorflow/core/lib/core:status", # device_lib
|
||||
"//tensorflow/core:core_cpu_impl", # device_lib
|
||||
],
|
||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||
cmd = select({
|
||||
|
@ -1,111 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The ConfigProto could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
static std::vector<string> ListDevicesWithSessionConfig(
|
||||
const tensorflow::ConfigProto& config, TF_Status* status) {
|
||||
std::vector<string> output;
|
||||
SessionOptions options;
|
||||
options.config = config;
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status s = DeviceFactory::AddDevices(options, "" /* name_prefix */, &devices);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<Device>& device : devices) {
|
||||
const DeviceAttributes& attr = device->attributes();
|
||||
string attr_serialized;
|
||||
if (!attr.SerializeToString(&attr_serialized)) {
|
||||
Set_TF_Status_from_Status(
|
||||
status,
|
||||
errors::Internal("Could not serialize device string"));
|
||||
output.clear();
|
||||
return output;
|
||||
}
|
||||
output.push_back(attr_serialized);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<string> ListDevices(TF_Status* status) {
|
||||
tensorflow::ConfigProto session_config;
|
||||
return ListDevicesWithSessionConfig(session_config, status);
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::swig;
|
||||
%unignore tensorflow::swig::ListDevicesWithSessionConfig;
|
||||
%unignore tensorflow::swig::ListDevices;
|
||||
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
std::vector<string> ListDevices(TF_Status* status);
|
||||
static std::vector<string> ListDevicesWithSessionConfig(
|
||||
const tensorflow::ConfigProto& config, TF_Status* status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%insert("python") %{
|
||||
def list_devices(session_config=None):
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
if session_config:
|
||||
return ListDevicesWithSessionConfig(session_config.SerializeToString())
|
||||
else:
|
||||
return ListDevices()
|
||||
%}
|
||||
|
||||
%unignoreall
|
||||
|
||||
%newobject tensorflow::SessionOptions;
|
@ -19,7 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import device_attributes_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_device_lib
|
||||
|
||||
|
||||
def list_local_devices(session_config=None):
|
||||
@ -36,7 +36,9 @@ def list_local_devices(session_config=None):
|
||||
m.ParseFromString(pb_str)
|
||||
return m
|
||||
|
||||
serialized_config = None
|
||||
if session_config is not None:
|
||||
serialized_config = session_config.SerializeToString()
|
||||
return [
|
||||
_convert(s)
|
||||
for s in pywrap_tensorflow.list_devices(session_config=session_config)
|
||||
_convert(s) for s in _pywrap_device_lib.list_devices(serialized_config)
|
||||
]
|
||||
|
58
tensorflow/python/client/device_lib_wrapper.cc
Normal file
58
tensorflow/python/client/device_lib_wrapper.cc
Normal file
@ -0,0 +1,58 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_proto.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_device_lib, m) {
|
||||
m.def("list_devices", [](py::object serialized_config) {
|
||||
tensorflow::ConfigProto config;
|
||||
if (!serialized_config.is_none()) {
|
||||
config.ParseFromString(
|
||||
static_cast<std::string>(serialized_config.cast<py::bytes>()));
|
||||
}
|
||||
|
||||
tensorflow::SessionOptions options;
|
||||
options.config = config;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
tensorflow::MaybeRaiseFromStatus(tensorflow::DeviceFactory::AddDevices(
|
||||
options, /*name_prefix=*/"", &devices));
|
||||
|
||||
py::list results;
|
||||
std::string serialized_attr;
|
||||
for (const auto& device : devices) {
|
||||
if (!device->attributes().SerializeToString(&serialized_attr)) {
|
||||
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
|
||||
"Could not serialize DeviceAttributes to bytes"));
|
||||
}
|
||||
|
||||
// The default type caster for std::string assumes its contents
|
||||
// is UTF8-encoded.
|
||||
results.append(py::bytes(serialized_attr));
|
||||
}
|
||||
return results;
|
||||
});
|
||||
}
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
%include "tensorflow/python/lib/io/py_record_writer.i"
|
||||
|
||||
%include "tensorflow/python/client/tf_session.i"
|
||||
%include "tensorflow/python/client/device_lib.i"
|
||||
|
||||
%include "tensorflow/python/lib/core/bfloat16.i"
|
||||
|
||||
|
@ -60,3 +60,17 @@ tensorflow::OpRegistry::Global
|
||||
tensorflow::OpRegistry::LookUpOpDef
|
||||
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
|
||||
|
||||
[status] # device_lib
|
||||
tensorflow::Status::code
|
||||
tensorflow::Status::error_message
|
||||
tensorflow::Status::ok()
|
||||
|
||||
[core_cpu_impl] # device_lib
|
||||
tensorflow::Device::attributes
|
||||
tensorflow::DeviceFactory::AddDevices
|
||||
tensorflow::SessionOptions::SessionOptions
|
||||
|
||||
[protos_all] # device_lib
|
||||
tensorflow::ConfigProto::ConfigProto
|
||||
tensorflow::ConfigProto::ParseFromString
|
||||
tensorflow::DeviceAttributes::SerializeToString
|
||||
|
Loading…
Reference in New Issue
Block a user