Change DeviceFactory functions that create devices to propagate
Statuses, so that failures to initialize devices don't crash the program. Changes swig for device_lib to be a lot simpler, thanks to mrry@ and keveman@'s help. Change allocation of eigen scratch memory to go through the allocator. Re-enable test for local devices now that python3 issue is fixed. Change: 129678132
This commit is contained in:
parent
0204fbd5fe
commit
2d0d126749
@ -1,3 +1,9 @@
|
||||
# Changes since last release
|
||||
|
||||
## Breaking Changes to the API
|
||||
|
||||
* DeviceFactory's AddDevices and CreateDevices functions now return
|
||||
a Status instead of void.
|
||||
|
||||
# Release 0.10.0
|
||||
|
||||
|
@ -173,9 +173,9 @@ Device* GetCPUDevice() {
|
||||
mutex_lock l(mu);
|
||||
if (!device) {
|
||||
std::vector<Device*> devices;
|
||||
DeviceFactory::GetFactory(DEVICE_CPU)
|
||||
->CreateDevices(SessionOptions{}, "", &devices);
|
||||
if (devices.size() > 0) {
|
||||
Status s = DeviceFactory::GetFactory(DEVICE_CPU)
|
||||
->CreateDevices(SessionOptions{}, "", &devices);
|
||||
if (s.ok() && devices.size() > 0) {
|
||||
device = devices[0];
|
||||
}
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
@ -74,25 +75,26 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
|
||||
return it->second.factory.get();
|
||||
}
|
||||
|
||||
void DeviceFactory::AddDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
Status DeviceFactory::AddDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
// CPU first.
|
||||
auto cpu_factory = GetFactory("CPU");
|
||||
if (!cpu_factory) {
|
||||
LOG(FATAL)
|
||||
<< "CPU Factory not registered. Did you link in threadpool_device?";
|
||||
return errors::NotFound(
|
||||
"CPU Factory not registered. Did you link in threadpool_device?");
|
||||
}
|
||||
size_t init_size = devices->size();
|
||||
cpu_factory->CreateDevices(options, name_prefix, devices);
|
||||
if (devices->size() == init_size) {
|
||||
LOG(FATAL) << "No CPU devices are available in this process";
|
||||
return errors::NotFound("No CPU devices are available in this process");
|
||||
}
|
||||
|
||||
// Then GPU.
|
||||
auto gpu_factory = GetFactory("GPU");
|
||||
if (gpu_factory) {
|
||||
gpu_factory->CreateDevices(options, name_prefix, devices);
|
||||
TF_RETURN_IF_ERROR(
|
||||
gpu_factory->CreateDevices(options, name_prefix, devices));
|
||||
}
|
||||
|
||||
// Then the rest.
|
||||
@ -100,9 +102,11 @@ void DeviceFactory::AddDevices(const SessionOptions& options,
|
||||
for (auto& p : device_factories()) {
|
||||
auto factory = p.second.factory.get();
|
||||
if (factory != cpu_factory && factory != gpu_factory) {
|
||||
factory->CreateDevices(options, name_prefix, devices);
|
||||
TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Device* DeviceFactory::NewDevice(const string& type,
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -36,9 +38,9 @@ class DeviceFactory {
|
||||
// any device type specific properties/counts listed in "options".
|
||||
//
|
||||
// CPU devices are added first.
|
||||
static void AddDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices);
|
||||
static Status AddDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices);
|
||||
|
||||
// Helper for tests. Create a single device of type "type". The
|
||||
// returned device is always numbered zero, so if creating multiple
|
||||
@ -47,9 +49,9 @@ class DeviceFactory {
|
||||
const string& name_prefix);
|
||||
|
||||
// Most clients should call AddDevices() instead.
|
||||
virtual void CreateDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) = 0;
|
||||
virtual Status CreateDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) = 0;
|
||||
};
|
||||
|
||||
namespace dfactory {
|
||||
|
@ -1045,8 +1045,13 @@ class DirectSessionFactory : public SessionFactory {
|
||||
EnableCPUAllocatorFullStats(true);
|
||||
}
|
||||
std::vector<Device*> devices;
|
||||
DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
|
||||
&devices);
|
||||
Status s = DeviceFactory::AddDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << s;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new DirectSession(options, new DeviceMgr(devices));
|
||||
}
|
||||
};
|
||||
|
@ -229,22 +229,38 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
|
||||
gpu_allocator_(gpu_allocator),
|
||||
cpu_allocator_(cpu_allocator),
|
||||
gpu_id_(gpu_id),
|
||||
sync_every_op_(sync_every_op) {
|
||||
sync_every_op_(sync_every_op),
|
||||
max_streams_(max_streams) {
|
||||
ProcessState::singleton()->EnableGPUDevice();
|
||||
}
|
||||
|
||||
executor_ = GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
|
||||
if (!executor_) {
|
||||
LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_;
|
||||
return;
|
||||
BaseGPUDevice::~BaseGPUDevice() {
|
||||
delete gpu_device_info_;
|
||||
for (auto ctx : device_contexts_) ctx->Unref();
|
||||
for (auto& stream_group : streams_) {
|
||||
delete stream_group.compute;
|
||||
delete stream_group.host_to_device;
|
||||
delete stream_group.device_to_host;
|
||||
delete stream_group.device_to_device;
|
||||
}
|
||||
}
|
||||
|
||||
Status BaseGPUDevice::Init(const SessionOptions& options) {
|
||||
auto executor_status = GPUMachineManager()->ExecutorForDevice(gpu_id_);
|
||||
if (!executor_status.status().ok()) {
|
||||
return errors::Internal("Failed to get StreamExecutor for device ",
|
||||
gpu_id_);
|
||||
}
|
||||
|
||||
executor_ = executor_status.ValueOrDie();
|
||||
em_.reset(new EventMgr(executor_, options.config.gpu_options()));
|
||||
|
||||
if (max_streams < 1) {
|
||||
LOG(FATAL) << "Invalid value for max_streams.";
|
||||
if (max_streams_ < 1) {
|
||||
return errors::InvalidArgument("Invalid value for max_streams.");
|
||||
}
|
||||
|
||||
// Create the specified number of GPU streams
|
||||
for (int i = 0; i < max_streams; i++) {
|
||||
for (int i = 0; i < max_streams_; i++) {
|
||||
auto stream = new gpu::Stream(executor_);
|
||||
stream->Init();
|
||||
VLOG(2) << "Created stream[" << i << "] = " << stream;
|
||||
@ -267,14 +283,24 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
|
||||
streams_.push_back({stream, host_to_device_stream, device_to_host_stream,
|
||||
device_to_device_stream});
|
||||
|
||||
perftools::gputools::DeviceMemory<char> mem =
|
||||
executor_->AllocateArray<char>(Eigen::kCudaScratchSize +
|
||||
sizeof(unsigned int));
|
||||
scratch_.push_back(static_cast<char*>(mem.opaque()));
|
||||
size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
|
||||
void* scratch_buffer = gpu_allocator_->AllocateRaw(
|
||||
Allocator::kAllocatorAlignment, scratch_buffer_size);
|
||||
if (scratch_buffer == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to allocate scratch buffer for device ", gpu_id_);
|
||||
}
|
||||
scratch_.push_back(static_cast<char*>(scratch_buffer));
|
||||
|
||||
perftools::gputools::DeviceMemory<char> mem(
|
||||
perftools::gputools::DeviceMemoryBase(scratch_buffer,
|
||||
scratch_buffer_size));
|
||||
|
||||
bool ok = executor_->SynchronousMemZero(
|
||||
&mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
|
||||
if (!ok) {
|
||||
LOG(FATAL) << "Failed to initialize device " << gpu_id;
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to memcopy into scratch buffer for device ", gpu_id_);
|
||||
}
|
||||
|
||||
device_contexts_.push_back(
|
||||
@ -286,17 +312,8 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
|
||||
gpu_device_info_->default_context = device_contexts_[0];
|
||||
gpu_device_info_->event_mgr = em_.get();
|
||||
set_tensorflow_gpu_device_info(gpu_device_info_);
|
||||
}
|
||||
|
||||
BaseGPUDevice::~BaseGPUDevice() {
|
||||
delete gpu_device_info_;
|
||||
for (auto ctx : device_contexts_) ctx->Unref();
|
||||
for (auto& stream_group : streams_) {
|
||||
delete stream_group.compute;
|
||||
delete stream_group.host_to_device;
|
||||
delete stream_group.device_to_host;
|
||||
delete stream_group.device_to_device;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
|
||||
@ -571,9 +588,9 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
|
||||
}
|
||||
}
|
||||
|
||||
void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
std::vector<Device*>* devices) {
|
||||
int n = INT_MAX;
|
||||
auto iter = options.config.device_count().find("GPU");
|
||||
if (iter != options.config.device_count().end()) {
|
||||
@ -585,9 +602,15 @@ void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
n = valid_gpu_ids.size();
|
||||
}
|
||||
for (int i = 0; i < n; i++) {
|
||||
devices->push_back(CreateGPUDevice(
|
||||
options, strings::StrCat(name_prefix, "/gpu:", i), valid_gpu_ids[i]));
|
||||
BaseGPUDevice* gpu_device;
|
||||
TF_RETURN_IF_ERROR(CreateGPUDevice(options,
|
||||
strings::StrCat(name_prefix, "/gpu:", i),
|
||||
valid_gpu_ids[i], &gpu_device));
|
||||
TF_RETURN_IF_ERROR(gpu_device->Init(options));
|
||||
devices->push_back(gpu_device);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -615,8 +638,9 @@ static string GetShortDeviceDescription(int device_id,
|
||||
", pci bus id: ", desc.pci_bus_id());
|
||||
}
|
||||
|
||||
LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
|
||||
const SessionOptions& options, const string& name, int gpu_id) {
|
||||
Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, int gpu_id,
|
||||
BaseGPUDevice** out_device) {
|
||||
CHECK_GE(gpu_id, 0);
|
||||
|
||||
// Look up the device, to see its attributes.
|
||||
@ -675,12 +699,14 @@ LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
|
||||
<< " numa: " << numa_node << " pci: " << desc.pci_bus_id();
|
||||
|
||||
ProcessState* process_state = ProcessState::singleton();
|
||||
return CreateGPUDevice(
|
||||
*out_device = CreateGPUDevice(
|
||||
options, name, allocated_bytes, bus_adjacency, gpu_id,
|
||||
GetShortDeviceDescription(gpu_id, desc),
|
||||
process_state->GetGPUAllocator(options.config.gpu_options(), gpu_id,
|
||||
allocated_memory),
|
||||
process_state->GetCPUAllocator(numa_node));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static int GetDefaultMinGPUMultiprocessorCount(gpu::Platform* gpu_manager) {
|
||||
|
@ -48,6 +48,9 @@ class BaseGPUDevice : public LocalDevice {
|
||||
|
||||
~BaseGPUDevice() override;
|
||||
|
||||
// Initialize the device and return the status of initialization.
|
||||
Status Init(const SessionOptions& options);
|
||||
|
||||
// GPU devices require the Op Compute method to save a reference to
|
||||
// any temporary tensors that are allocated until the Op execution
|
||||
// completes.
|
||||
@ -97,6 +100,7 @@ class BaseGPUDevice : public LocalDevice {
|
||||
mutex trace_mu_;
|
||||
int gpu_id_ = -1;
|
||||
const bool sync_every_op_ = false;
|
||||
const int32 max_streams_;
|
||||
std::unique_ptr<EventMgr> em_;
|
||||
|
||||
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
|
||||
@ -105,19 +109,19 @@ class BaseGPUDevice : public LocalDevice {
|
||||
|
||||
class BaseGPUDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
void CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override;
|
||||
|
||||
private:
|
||||
LocalDevice* CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, int gpu_id);
|
||||
Status CreateGPUDevice(const SessionOptions& options, const string& name,
|
||||
int gpu_id, BaseGPUDevice** out_device);
|
||||
|
||||
virtual LocalDevice* CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, Bytes memory_limit,
|
||||
BusAdjacency bus_adjacency, int gpu_id,
|
||||
const string& physical_device_desc,
|
||||
Allocator* gpu_allocator,
|
||||
Allocator* cpu_allocator) = 0;
|
||||
virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, Bytes memory_limit,
|
||||
BusAdjacency bus_adjacency, int gpu_id,
|
||||
const string& physical_device_desc,
|
||||
Allocator* gpu_allocator,
|
||||
Allocator* cpu_allocator) = 0;
|
||||
|
||||
void GetValidDeviceIds(std::vector<int>* ids);
|
||||
};
|
||||
|
@ -49,12 +49,12 @@ class GPUDevice : public BaseGPUDevice {
|
||||
|
||||
class GPUDeviceFactory : public BaseGPUDeviceFactory {
|
||||
private:
|
||||
LocalDevice* CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, Bytes memory_limit,
|
||||
BusAdjacency bus_adjacency, int gpu_id,
|
||||
const string& physical_device_desc,
|
||||
Allocator* gpu_allocator,
|
||||
Allocator* cpu_allocator) override {
|
||||
BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
|
||||
const string& name, Bytes memory_limit,
|
||||
BusAdjacency bus_adjacency, int gpu_id,
|
||||
const string& physical_device_desc,
|
||||
Allocator* gpu_allocator,
|
||||
Allocator* cpu_allocator) override {
|
||||
return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
|
||||
physical_device_desc, gpu_allocator, cpu_allocator);
|
||||
}
|
||||
@ -89,8 +89,8 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
|
||||
// The associated factory.
|
||||
class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
void CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override {
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override {
|
||||
int n = 1;
|
||||
auto iter = options.config.device_count().find("CPU");
|
||||
if (iter != options.config.device_count().end()) {
|
||||
@ -101,6 +101,8 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
|
||||
devices->push_back(new GPUCompatibleCPUDevice(
|
||||
options, name, Bytes(256 << 20), BUS_ANY, cpu_allocator()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 50);
|
||||
|
@ -26,8 +26,8 @@ namespace tensorflow {
|
||||
// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
|
||||
class ThreadPoolDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
void CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override {
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override {
|
||||
// TODO(zhifengc/tucker): Figure out the number of available CPUs
|
||||
// and/or NUMA configuration.
|
||||
int n = 1;
|
||||
@ -40,6 +40,8 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
|
||||
devices->push_back(new ThreadPoolDevice(options, name, Bytes(256 << 20),
|
||||
BUS_ANY, cpu_allocator()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory);
|
||||
|
@ -86,7 +86,8 @@ Status GrpcServer::Init() {
|
||||
string name_prefix =
|
||||
strings::StrCat("/job:", server_def_.job_name(), "/replica:0", "/task:",
|
||||
server_def_.task_index());
|
||||
DeviceFactory::AddDevices(sess_opts, name_prefix, &master_env_.local_devices);
|
||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
|
||||
&master_env_.local_devices));
|
||||
worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices);
|
||||
string unused;
|
||||
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
|
||||
|
@ -19,69 +19,63 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
%}
|
||||
|
||||
%typemap(in, numinputs=0) const tensorflow::SessionOptions& options (
|
||||
tensorflow::SessionOptions temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
%typemap(in, numinputs=0) std::vector<tensorflow::Device*>* devices (
|
||||
std::vector<tensorflow::Device*> temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
// Handle string input into AddDevices
|
||||
%typemap(in, numinputs=0) const string& name_prefix (
|
||||
string temp) {
|
||||
// Always pass an empty name_prefix.
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(argout) std::vector<tensorflow::Device*>* devices {
|
||||
std::vector< std::unique_ptr<tensorflow::Device> > safe_devices;
|
||||
for (auto* device : *$1) safe_devices.emplace_back(device);
|
||||
|
||||
auto temp_string_list = tensorflow::make_safe(PyList_New(0));
|
||||
if (!temp_string_list) {
|
||||
SWIG_fail;
|
||||
static std::vector<string> ListDevices(TF_Status* out_status) {
|
||||
std::vector<string> output;
|
||||
SessionOptions options;
|
||||
std::vector<Device*> devices;
|
||||
Status status = DeviceFactory::AddDevices(
|
||||
options, "" /* name_prefix */, &devices);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
}
|
||||
|
||||
for (const auto& device : safe_devices) {
|
||||
const tensorflow::DeviceAttributes& attr = device->attributes();
|
||||
for (const Device* device : devices) {
|
||||
const DeviceAttributes& attr = device->attributes();
|
||||
string attr_serialized;
|
||||
if (!attr.SerializeToString(&attr_serialized)) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"Unable to serialize DeviceAttributes");
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
tensorflow::Safe_PyObjectPtr safe_attr_string = tensorflow::make_safe(
|
||||
%#if PY_MAJOR_VERSION < 3
|
||||
PyString_FromStringAndSize(
|
||||
%#else
|
||||
PyBytes_FromStringAndSize(
|
||||
%#endif
|
||||
reinterpret_cast<const char*>(
|
||||
attr_serialized.data()), attr_serialized.size()));
|
||||
|
||||
if (PyList_Append(temp_string_list.get(), safe_attr_string.get()) == -1) {
|
||||
SWIG_fail;
|
||||
Set_TF_Status_from_Status(
|
||||
out_status,
|
||||
errors::Internal("Could not serialize device string"));
|
||||
output.clear();
|
||||
return output;
|
||||
}
|
||||
output.push_back(attr_serialized);
|
||||
}
|
||||
|
||||
$result = temp_string_list.release();
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::DeviceFactory;
|
||||
%unignore tensorflow::DeviceFactory::AddDevices;
|
||||
%unignore tensorflow::swig;
|
||||
%unignore tensorflow::swig::ListDevices;
|
||||
|
||||
%include "tensorflow/core/common_runtime/device_factory.h"
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
std::vector<string> ListDevices(TF_Status* out_status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%insert("python") %{
|
||||
def list_devices():
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return ListDevices(status)
|
||||
%}
|
||||
|
||||
%unignoreall
|
||||
|
||||
|
@ -34,4 +34,5 @@ def list_local_devices():
|
||||
m = device_attributes_pb2.DeviceAttributes()
|
||||
m.ParseFromString(pb_str)
|
||||
return m
|
||||
return [_convert(s) for s in pywrap_tensorflow.DeviceFactory_AddDevices()]
|
||||
|
||||
return [_convert(s) for s in pywrap_tensorflow.list_devices()]
|
||||
|
@ -27,8 +27,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class DeviceLibTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# TODO(ebrevdo): fix python3 compatibility: b/27727661
|
||||
def _testListLocalDevices(self):
|
||||
def testListLocalDevices(self):
|
||||
devices = device_lib.list_local_devices()
|
||||
self.assertGreater(len(devices), 0)
|
||||
self.assertEqual(devices[0].device_type, "CPU")
|
||||
|
@ -80,7 +80,7 @@ bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) {
|
||||
}
|
||||
}
|
||||
|
||||
// Converts a C++ string vector to a Python string list.
|
||||
// Converts a C++ string vector to a a list of Python bytes objects.
|
||||
%typemap(out) std::vector<string> {
|
||||
const int size = $1.size();
|
||||
auto temp_string_list = tensorflow::make_safe(PyList_New(size));
|
||||
@ -90,11 +90,9 @@ bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) {
|
||||
tensorflow::Safe_PyObjectVector converted;
|
||||
converted.reserve(size);
|
||||
for (const string& op : $1) {
|
||||
%#if PY_MAJOR_VERSION >= 3
|
||||
PyObject* py_str = PyUnicode_FromStringAndSize(op.data(), op.size());
|
||||
%#else
|
||||
PyObject* py_str = PyString_FromStringAndSize(op.data(), op.size());
|
||||
%#endif
|
||||
// Always treat strings as bytes, consistent with the typemap
|
||||
// for string.
|
||||
PyObject* py_str = PyBytes_FromStringAndSize(op.data(), op.size());
|
||||
if (!py_str) {
|
||||
SWIG_fail;
|
||||
}
|
||||
|
@ -99,7 +99,11 @@ def get_matching_files(filename):
|
||||
errors.OpError: If there are filesystem / directory listing errors.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return pywrap_tensorflow.GetMatchingFiles(compat.as_bytes(filename), status)
|
||||
# Convert each element to string, since the return values of the
|
||||
# vector of string should be interpreted as strings, not bytes.
|
||||
return [compat.as_str(matching_filename)
|
||||
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
|
||||
compat.as_bytes(filename), status)]
|
||||
|
||||
|
||||
def create_dir(dirname):
|
||||
|
Loading…
Reference in New Issue
Block a user