Converted device_lib.i to pybind11

This is part of a larger effort to deprecate swig and eventually with
modularization break pywrap_tensorflow into smaller components.
Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md
for more information.

PiperOrigin-RevId: 271564267
This commit is contained in:
Sergei Lebedev 2019-09-27 07:12:47 -07:00 committed by TensorFlower Gardener
parent a0aa739277
commit 7623d23a63
6 changed files with 97 additions and 117 deletions

View File

@ -4682,6 +4682,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":_pywrap_device_lib",
":errors",
":framework",
":framework_for_generated_wrappers",
@ -4952,7 +4953,23 @@ py_library(
srcs = ["client/device_lib.py"],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
":_pywrap_device_lib",
"//tensorflow/core:protos_all_py",
],
)
tf_python_pybind_extension(
name = "_pywrap_device_lib",
srcs = ["client/device_lib_wrapper.cc"],
module_name = "_pywrap_device_lib",
deps = [
":pybind11_proto",
":pybind11_status",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework_internal_headers_lib",
"//tensorflow/core:protos_all_cc",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
@ -5029,7 +5046,6 @@ tf_py_wrap_cc(
name = "pywrap_tensorflow_internal",
srcs = ["tensorflow.i"],
swig_includes = [
"client/device_lib.i",
"client/tf_session.i",
"client/tf_sessionrun_wrapper.i",
"framework/python_op_gen.i",
@ -5154,6 +5170,8 @@ genrule(
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
"//tensorflow/core:framework_internal_impl", # op_def_registry
"//tensorflow/core/lib/core:status", # device_lib
"//tensorflow/core:core_cpu_impl", # device_lib
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({

View File

@ -1,111 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The ConfigProto could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace swig {
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* status) {
std::vector<string> output;
SessionOptions options;
options.config = config;
std::vector<std::unique_ptr<Device>> devices;
Status s = DeviceFactory::AddDevices(options, "" /* name_prefix */, &devices);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
for (const std::unique_ptr<Device>& device : devices) {
const DeviceAttributes& attr = device->attributes();
string attr_serialized;
if (!attr.SerializeToString(&attr_serialized)) {
Set_TF_Status_from_Status(
status,
errors::Internal("Could not serialize device string"));
output.clear();
return output;
}
output.push_back(attr_serialized);
}
return output;
}
std::vector<string> ListDevices(TF_Status* status) {
tensorflow::ConfigProto session_config;
return ListDevicesWithSessionConfig(session_config, status);
}
} // namespace swig
} // namespace tensorflow
%}
%ignoreall
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ListDevicesWithSessionConfig;
%unignore tensorflow::swig::ListDevices;
// Wrap this function
namespace tensorflow {
namespace swig {
std::vector<string> ListDevices(TF_Status* status);
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
def list_devices(session_config=None):
from tensorflow.python.framework import errors
if session_config:
return ListDevicesWithSessionConfig(session_config.SerializeToString())
else:
return ListDevices()
%}
%unignoreall
%newobject tensorflow::SessionOptions;

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import device_attributes_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_device_lib
def list_local_devices(session_config=None):
@ -36,7 +36,9 @@ def list_local_devices(session_config=None):
m.ParseFromString(pb_str)
return m
serialized_config = None
if session_config is not None:
serialized_config = session_config.SerializeToString()
return [
_convert(s)
for s in pywrap_tensorflow.list_devices(session_config=session_config)
_convert(s) for s in _pywrap_device_lib.list_devices(serialized_config)
]

View File

@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/python/lib/core/pybind11_proto.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = ::pybind11;
PYBIND11_MODULE(_pywrap_device_lib, m) {
m.def("list_devices", [](py::object serialized_config) {
tensorflow::ConfigProto config;
if (!serialized_config.is_none()) {
config.ParseFromString(
static_cast<std::string>(serialized_config.cast<py::bytes>()));
}
tensorflow::SessionOptions options;
options.config = config;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
tensorflow::MaybeRaiseFromStatus(tensorflow::DeviceFactory::AddDevices(
options, /*name_prefix=*/"", &devices));
py::list results;
std::string serialized_attr;
for (const auto& device : devices) {
if (!device->attributes().SerializeToString(&serialized_attr)) {
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
"Could not serialize DeviceAttributes to bytes"));
}
// The default type caster for std::string assumes its contents
// is UTF8-encoded.
results.append(py::bytes(serialized_attr));
}
return results;
});
}

View File

@ -27,7 +27,6 @@ limitations under the License.
%include "tensorflow/python/lib/io/py_record_writer.i"
%include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/client/device_lib.i"
%include "tensorflow/python/lib/core/bfloat16.i"

View File

@ -60,3 +60,17 @@ tensorflow::OpRegistry::Global
tensorflow::OpRegistry::LookUpOpDef
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
[status] # device_lib
tensorflow::Status::code
tensorflow::Status::error_message
tensorflow::Status::ok()
[core_cpu_impl] # device_lib
tensorflow::Device::attributes
tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions
[protos_all] # device_lib
tensorflow::ConfigProto::ConfigProto
tensorflow::ConfigProto::ParseFromString
tensorflow::DeviceAttributes::SerializeToString