Converted device_lib.i to pybind11
This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 271564267
This commit is contained in:
parent
a0aa739277
commit
7623d23a63
@ -4682,6 +4682,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_device_lib",
|
||||||
":errors",
|
":errors",
|
||||||
":framework",
|
":framework",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
@ -4952,7 +4953,23 @@ py_library(
|
|||||||
srcs = ["client/device_lib.py"],
|
srcs = ["client/device_lib.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":pywrap_tensorflow",
|
":_pywrap_device_lib",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_device_lib",
|
||||||
|
srcs = ["client/device_lib_wrapper.cc"],
|
||||||
|
module_name = "_pywrap_device_lib",
|
||||||
|
deps = [
|
||||||
|
":pybind11_proto",
|
||||||
|
":pybind11_status",
|
||||||
|
"//tensorflow/core:core_cpu_headers_lib",
|
||||||
|
"//tensorflow/core:framework_internal_headers_lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -5029,7 +5046,6 @@ tf_py_wrap_cc(
|
|||||||
name = "pywrap_tensorflow_internal",
|
name = "pywrap_tensorflow_internal",
|
||||||
srcs = ["tensorflow.i"],
|
srcs = ["tensorflow.i"],
|
||||||
swig_includes = [
|
swig_includes = [
|
||||||
"client/device_lib.i",
|
|
||||||
"client/tf_session.i",
|
"client/tf_session.i",
|
||||||
"client/tf_sessionrun_wrapper.i",
|
"client/tf_sessionrun_wrapper.i",
|
||||||
"framework/python_op_gen.i",
|
"framework/python_op_gen.i",
|
||||||
@ -5154,6 +5170,8 @@ genrule(
|
|||||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||||
"//tensorflow/core:framework_internal_impl", # op_def_registry
|
"//tensorflow/core:framework_internal_impl", # op_def_registry
|
||||||
|
"//tensorflow/core/lib/core:status", # device_lib
|
||||||
|
"//tensorflow/core:core_cpu_impl", # device_lib
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
|
@ -1,111 +0,0 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) {
|
|
||||||
char* c_string;
|
|
||||||
Py_ssize_t py_size;
|
|
||||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
|
||||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
|
||||||
PyErr_SetString(
|
|
||||||
PyExc_TypeError,
|
|
||||||
"The ConfigProto could not be parsed as a valid protocol buffer");
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
$1 = &temp;
|
|
||||||
}
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
|
||||||
#include "tensorflow/core/public/session_options.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace swig {
|
|
||||||
|
|
||||||
static std::vector<string> ListDevicesWithSessionConfig(
|
|
||||||
const tensorflow::ConfigProto& config, TF_Status* status) {
|
|
||||||
std::vector<string> output;
|
|
||||||
SessionOptions options;
|
|
||||||
options.config = config;
|
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
|
||||||
Status s = DeviceFactory::AddDevices(options, "" /* name_prefix */, &devices);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Set_TF_Status_from_Status(status, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const std::unique_ptr<Device>& device : devices) {
|
|
||||||
const DeviceAttributes& attr = device->attributes();
|
|
||||||
string attr_serialized;
|
|
||||||
if (!attr.SerializeToString(&attr_serialized)) {
|
|
||||||
Set_TF_Status_from_Status(
|
|
||||||
status,
|
|
||||||
errors::Internal("Could not serialize device string"));
|
|
||||||
output.clear();
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
output.push_back(attr_serialized);
|
|
||||||
}
|
|
||||||
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<string> ListDevices(TF_Status* status) {
|
|
||||||
tensorflow::ConfigProto session_config;
|
|
||||||
return ListDevicesWithSessionConfig(session_config, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace swig
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
|
|
||||||
%unignore tensorflow;
|
|
||||||
%unignore tensorflow::swig;
|
|
||||||
%unignore tensorflow::swig::ListDevicesWithSessionConfig;
|
|
||||||
%unignore tensorflow::swig::ListDevices;
|
|
||||||
|
|
||||||
// Wrap this function
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace swig {
|
|
||||||
std::vector<string> ListDevices(TF_Status* status);
|
|
||||||
static std::vector<string> ListDevicesWithSessionConfig(
|
|
||||||
const tensorflow::ConfigProto& config, TF_Status* status);
|
|
||||||
} // namespace swig
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
%insert("python") %{
|
|
||||||
def list_devices(session_config=None):
|
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
|
|
||||||
if session_config:
|
|
||||||
return ListDevicesWithSessionConfig(session_config.SerializeToString())
|
|
||||||
else:
|
|
||||||
return ListDevices()
|
|
||||||
%}
|
|
||||||
|
|
||||||
%unignoreall
|
|
||||||
|
|
||||||
%newobject tensorflow::SessionOptions;
|
|
@ -19,7 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.framework import device_attributes_pb2
|
from tensorflow.core.framework import device_attributes_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_device_lib
|
||||||
|
|
||||||
|
|
||||||
def list_local_devices(session_config=None):
|
def list_local_devices(session_config=None):
|
||||||
@ -36,7 +36,9 @@ def list_local_devices(session_config=None):
|
|||||||
m.ParseFromString(pb_str)
|
m.ParseFromString(pb_str)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
serialized_config = None
|
||||||
|
if session_config is not None:
|
||||||
|
serialized_config = session_config.SerializeToString()
|
||||||
return [
|
return [
|
||||||
_convert(s)
|
_convert(s) for s in _pywrap_device_lib.list_devices(serialized_config)
|
||||||
for s in pywrap_tensorflow.list_devices(session_config=session_config)
|
|
||||||
]
|
]
|
||||||
|
58
tensorflow/python/client/device_lib_wrapper.cc
Normal file
58
tensorflow/python/client/device_lib_wrapper.cc
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_proto.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
|
||||||
|
namespace py = ::pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_device_lib, m) {
|
||||||
|
m.def("list_devices", [](py::object serialized_config) {
|
||||||
|
tensorflow::ConfigProto config;
|
||||||
|
if (!serialized_config.is_none()) {
|
||||||
|
config.ParseFromString(
|
||||||
|
static_cast<std::string>(serialized_config.cast<py::bytes>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::SessionOptions options;
|
||||||
|
options.config = config;
|
||||||
|
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||||
|
tensorflow::MaybeRaiseFromStatus(tensorflow::DeviceFactory::AddDevices(
|
||||||
|
options, /*name_prefix=*/"", &devices));
|
||||||
|
|
||||||
|
py::list results;
|
||||||
|
std::string serialized_attr;
|
||||||
|
for (const auto& device : devices) {
|
||||||
|
if (!device->attributes().SerializeToString(&serialized_attr)) {
|
||||||
|
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
|
||||||
|
"Could not serialize DeviceAttributes to bytes"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The default type caster for std::string assumes its contents
|
||||||
|
// is UTF8-encoded.
|
||||||
|
results.append(py::bytes(serialized_attr));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
});
|
||||||
|
}
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
%include "tensorflow/python/lib/io/py_record_writer.i"
|
%include "tensorflow/python/lib/io/py_record_writer.i"
|
||||||
|
|
||||||
%include "tensorflow/python/client/tf_session.i"
|
%include "tensorflow/python/client/tf_session.i"
|
||||||
%include "tensorflow/python/client/device_lib.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/bfloat16.i"
|
%include "tensorflow/python/lib/core/bfloat16.i"
|
||||||
|
|
||||||
|
@ -60,3 +60,17 @@ tensorflow::OpRegistry::Global
|
|||||||
tensorflow::OpRegistry::LookUpOpDef
|
tensorflow::OpRegistry::LookUpOpDef
|
||||||
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
|
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
|
||||||
|
|
||||||
|
[status] # device_lib
|
||||||
|
tensorflow::Status::code
|
||||||
|
tensorflow::Status::error_message
|
||||||
|
tensorflow::Status::ok()
|
||||||
|
|
||||||
|
[core_cpu_impl] # device_lib
|
||||||
|
tensorflow::Device::attributes
|
||||||
|
tensorflow::DeviceFactory::AddDevices
|
||||||
|
tensorflow::SessionOptions::SessionOptions
|
||||||
|
|
||||||
|
[protos_all] # device_lib
|
||||||
|
tensorflow::ConfigProto::ConfigProto
|
||||||
|
tensorflow::ConfigProto::ParseFromString
|
||||||
|
tensorflow::DeviceAttributes::SerializeToString
|
||||||
|
Loading…
x
Reference in New Issue
Block a user