Move Device and DeviceFactory to core/framework. https://github.com/tensorflow/tensorflow/pull/43611/ needs to access device registry to get a subtype for a given device type as a part of kernel registration. Kernel registration is a part of framework and should not depend on common_runtime to avoid a circular dependency.

This change was originally rolled back due to "CPU Factory not registered. Did you link in threadpool_device?" on Windows. The error happened because both tensorflow.dll and pywrap_tfe.dll contained device_factory implementation. CPU was registered in tensorflow.dll but queried in pywrap_tfe.dll. This should be fixed now by cleaning up pywrap_tfe.dll dependencies (f5af1da4f8).

PiperOrigin-RevId: 337585576
Change-Id: Idec2c8ad3ad32e02b972ab0438cc3c232b6be532
This commit is contained in:
Anna R 2020-10-16 15:06:20 -07:00 committed by TensorFlower Gardener
parent 0997c7a789
commit c60ced548c
12 changed files with 375 additions and 336 deletions
tensorflow

View File

@ -418,7 +418,9 @@ tf_cuda_library(
"//tensorflow/core/framework:control_flow.h", # TODO(josh11b): Make internal?
"//tensorflow/core/framework:dataset.h",
"//tensorflow/core/framework:dataset_stateful_op_allowlist.h",
"//tensorflow/core/framework:device.h",
"//tensorflow/core/framework:device_base.h",
"//tensorflow/core/framework:device_factory.h",
"//tensorflow/core/framework:function.h",
"//tensorflow/core/framework:function_handle_cache.h",
"//tensorflow/core/framework:graph_def_util.h",
@ -1581,6 +1583,7 @@ filegroup(
"//tensorflow/core/example:feature_util.h",
"//tensorflow/core/framework:framework_internal_private_hdrs",
"//tensorflow/core/graph:framework_internal_private_headers",
"//tensorflow/core/public:session_options.h",
"//tensorflow/core/util:framework_internal_private_hdrs",
"//tensorflow/core/util:memmapped_file_system_hdrs",
"//tensorflow/core/util/sparse:framework_internal_private_headers_group",

View File

@ -499,28 +499,19 @@ cc_library(
cc_library(
name = "device",
srcs = ["device.cc"],
hdrs = ["device.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "device_factory",
srcs = ["device_factory.cc"],
hdrs = ["device_factory.h"],
copts = tf_copts(),
deps = [
":device",
":session_options",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:framework_internal",
],
)

View File

@ -12,191 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A Device is something that can perform computations as part of a
// model. Devices can be local (runs computation on this machine), or
// remote (contacts a device local to another machine using an RPC to
// do the work). Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
//
// Device names
// * Every Device should have a unique name with the format:
// /job:___/replica:___/task:___/(gpu|cpu):___
// An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
#include <memory>
#include <string>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class Device : public DeviceBase {
public:
// Callback type that takes a Status and returns void.
typedef std::function<void(const Status&)> DoneCallback;
Device(Env* env, const DeviceAttributes& device_attributes);
~Device() override;
// Full name of this device (see top comment).
const std::string& name() const override { return device_attributes_.name(); }
// Parsed name of this device
const DeviceNameUtils::ParsedName& parsed_name() const {
return parsed_name_;
}
// Describes what kind of device this is. This is intended to be
// human-readable and not computer-parsed, except that two devices
// with the same device_type() are expected to perform similarly
// (both from a computation and communication perspective).
const std::string& device_type() const {
return device_attributes_.device_type();
}
// Returns an aggregation of device attributes.
const DeviceAttributes& attributes() const override {
return device_attributes_;
}
// Performs the actual compute function.
//
// Subclasses may override this function if they wish to perform
// some initialization before each compute.
virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
op_kernel->Compute(context);
}
// Asynchronous kernel's compute.
virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
op_kernel->ComputeAsync(context, std::move(done));
}
// Blocks until all operations queued on the device at the time of
// the call have completed. Returns any error pending on the device
// at completion.
virtual Status Sync() = 0;
// Calls the given callback when all operations queued on the device at the
// time of the call have completed. The callback is passed any error pending
// on the device at completion.
// TODO(b/112409994): Consolidate these two APIs, removing the synchronous
// version.
virtual void Sync(const DoneCallback& done);
// On session completion, the executor may call Device::Sync() depending on
// flag settings. Override this to return false for devices that don't allow
// such calls. Instead, these devices must use other mechanisms (such as
// num_deferred_ops) to ensure the device has finished processing necessary
// work at session completion. In addition, for these devices, RefreshStatus
// must be called at session completion to retrieve execution result status.
//
// Devices that override this function must also implement RefreshStatus.
virtual bool AllowsSyncOnCompletion() const { return true; }
// This is used in conjunction with AllowsSyncOnCompletion to allow the
// executor to get execution result status at session completion.
//
// For supported devices, this call returns the underlying device stream's
// current status in a non-blocking way, without using blocking calls such as
// Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
// status is also updated with the retrieved stream status.
virtual Status RefreshStatus() {
return errors::Unimplemented(
"RefreshStatus is not supported on this device.");
}
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
// prototyping of TensorFlow device implementations that need to modify
// the GraphDef before execution.
//
// 'graph' supplies the partition of the graph assigned to this
// device.
virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
return Status::OK();
}
// Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
// if the device does not support contexts. Returns an error status if any
// error occurred while trying to create a context, otherwise OK.
//
// The caller takes ownership of one reference on the output DeviceContext*,
// and should call Unref().
virtual Status TryGetDeviceContext(DeviceContext** out_context) {
*out_context = nullptr;
return Status::OK();
}
// Returns the op segment of this device. The caller can reuse op
// kernels registered for the same session running on this device.
OpSegment* op_segment() { return &op_seg_; }
// Returns the resource manager associated w/ this device.
virtual ResourceMgr* resource_manager() { return rmgr_; }
// Summarizes the status of this Device, for debugging.
std::string DebugString() const { return device_attributes_.DebugString(); }
// Assembles the parameter components into a complete DeviceAttributes value.
static DeviceAttributes BuildDeviceAttributes(
const std::string& name, DeviceType device, Bytes memory_limit,
const DeviceLocality& locality, const std::string& physical_device_desc);
static DeviceAttributes BuildDeviceAttributes(
const std::string& name, DeviceType device, Bytes memory_limit,
const DeviceLocality& locality) {
// Pass in an empty string as physical device name.
return BuildDeviceAttributes(name, device, memory_limit, locality, "");
}
// Clears the resource manager associated with this device.
void ClearResourceMgr() { rmgr_->Clear(); }
virtual bool IsLocal() const { return true; }
protected:
void DeleteResourceMgr() {
delete rmgr_;
rmgr_ = nullptr;
}
private:
const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_;
// op_seg_ maps session handle and op name to OpKernel objects.
OpSegment op_seg_;
// Resources associated w/ this device. E.g., shared variables, etc.
ResourceMgr* rmgr_ = nullptr;
TF_DISALLOW_COPY_AND_ASSIGN(Device);
};
} // namespace tensorflow
#include "tensorflow/core/framework/device.h"
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_

View File

@ -12,140 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
#include <string>
#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class Device;
struct SessionOptions;
class DeviceFactory {
public:
virtual ~DeviceFactory() {}
static void Register(const std::string& device_type, DeviceFactory* factory,
int priority);
static DeviceFactory* GetFactory(const std::string& device_type);
// Append to "*devices" all suitable devices, respecting
// any device type specific properties/counts listed in "options".
//
// CPU devices are added first.
static Status AddDevices(const SessionOptions& options,
const std::string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices);
// Helper for tests. Create a single device of type "type". The
// returned device is always numbered zero, so if creating multiple
// devices of the same type, supply distinct name_prefix arguments.
static std::unique_ptr<Device> NewDevice(const string& type,
const SessionOptions& options,
const string& name_prefix);
// Iterate through all device factories and build a list of all of the
// possible physical devices.
//
// CPU devices are added first.
static Status ListAllPhysicalDevices(std::vector<string>* devices);
// Get details for a specific device among all device factories.
// 'device_index' indexes into devices from ListAllPhysicalDevices.
static Status GetAnyDeviceDetails(
int device_index, std::unordered_map<string, string>* details);
// For a specific device factory list all possible physical devices.
virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
// Get details for a specific device for a specific factory. Subclasses
// can store arbitrary device information in the map. 'device_index' indexes
// into devices from ListPhysicalDevices.
virtual Status GetDeviceDetails(int device_index,
std::unordered_map<string, string>* details) {
return Status::OK();
}
// Most clients should call AddDevices() instead.
virtual Status CreateDevices(
const SessionOptions& options, const std::string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) = 0;
// Return the device priority number for a "device_type" string.
//
// Higher number implies higher priority.
//
// In standard TensorFlow distributions, GPU device types are
// preferred over CPU, and by default, custom devices that don't set
// a custom priority during registration will be prioritized lower
// than CPU. Custom devices that want a higher priority can set the
// 'priority' field when registering their device to something
// higher than the packaged devices. See calls to
// REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
// for built-in devices.
static int32 DevicePriority(const std::string& device_type);
};
namespace dfactory {
template <class Factory>
class Registrar {
public:
// Multiple registrations for the same device type with different priorities
// are allowed. Priorities are used in two different ways:
//
// 1) When choosing which factory (that is, which device
// implementation) to use for a specific 'device_type', the
// factory registered with the highest priority will be chosen.
// For example, if there are two registrations:
//
// Registrar<CPUFactory1>("CPU", 125);
// Registrar<CPUFactory2>("CPU", 150);
//
// then CPUFactory2 will be chosen when
// DeviceFactory::GetFactory("CPU") is called.
//
// 2) When choosing which 'device_type' is preferred over other
// DeviceTypes in a DeviceSet, the ordering is determined
// by the 'priority' set during registration. For example, if there
// are two registrations:
//
// Registrar<CPUFactory>("CPU", 100);
// Registrar<GPUFactory>("GPU", 200);
//
// then DeviceType("GPU") will be prioritized higher than
// DeviceType("CPU").
//
// The default priority values for built-in devices are:
// GPU: 210
// GPUCompatibleCPU: 70
// ThreadPoolDevice: 60
// Default: 50
explicit Registrar(const std::string& device_type, int priority = 50) {
DeviceFactory::Register(device_type, new Factory(), priority);
}
};
} // namespace dfactory
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
__COUNTER__, ##__VA_ARGS__)
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
ctr, ...) \
static ::tensorflow::dfactory::Registrar<device_factory> \
INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \
##__VA_ARGS__)
// __COUNTER__ must go through another macro to be properly expanded
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
} // namespace tensorflow
#include "tensorflow/core/framework/device_factory.h"
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_

View File

@ -50,7 +50,9 @@ exports_files(
"control_flow.h",
"dataset.h",
"dataset_stateful_op_allowlist.h",
"device.h",
"device_base.h",
"device_factory.h",
"function.h",
"function_handle_cache.h",
"graph_def_util.h",
@ -178,7 +180,9 @@ filegroup(
"control_flow.h",
"dataset.h",
"dataset_stateful_op_allowlist.h",
"device.h",
"device_base.h",
"device_factory.h",
"function.h",
"function_handle_cache.h",
"graph_def_util.h",
@ -251,7 +255,9 @@ filegroup(
"cancellation.cc",
"collective.cc",
"dataset.cc",
"device.cc",
"device_base.cc",
"device_factory.cc",
"function.cc",
"function_handle_cache.cc",
"graph_def_util.cc",
@ -342,8 +348,12 @@ filegroup(
"dataset.cc",
"dataset.h",
"dataset_stateful_op_allowlist.h",
"device.cc",
"device.h",
"device_base.cc",
"device_base.h",
"device_factory.cc",
"device_factory.h",
"function.cc",
"function.h",
"function_handle_cache.cc",

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {

View File

@ -0,0 +1,202 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A Device is something that can perform computations as part of a
// model. Devices can be local (runs computation on this machine), or
// remote (contacts a device local to another machine using an RPC to
// do the work). Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
//
// Device names
// * Every Device should have a unique name with the format:
// /job:___/replica:___/task:___/(gpu|cpu):___
// An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas.
#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
#include <memory>
#include <string>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Device extends DeviceBase with a name, device attributes, overridable
// hooks for running kernels, an op segment, and a resource manager.
class Device : public DeviceBase {
 public:
  // Callback type that takes a Status and returns void.
  using DoneCallback = std::function<void(const Status&)>;

  Device(Env* env, const DeviceAttributes& device_attributes);
  ~Device() override;

  // Full name of this device (see top comment).
  const std::string& name() const override { return device_attributes_.name(); }

  // Parsed name of this device.
  const DeviceNameUtils::ParsedName& parsed_name() const {
    return parsed_name_;
  }

  // Describes what kind of device this is.  This is intended to be
  // human-readable and not computer-parsed, except that two devices
  // with the same device_type() are expected to perform similarly
  // (both from a computation and communication perspective).
  const std::string& device_type() const {
    return device_attributes_.device_type();
  }

  // Returns an aggregation of device attributes.
  const DeviceAttributes& attributes() const override {
    return device_attributes_;
  }

  // Performs the actual compute function.
  //
  // Subclasses may override this function if they wish to perform
  // some initialization before each compute. The default implementation
  // simply forwards to the kernel.
  virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
    op_kernel->Compute(context);
  }

  // Asynchronous kernel's compute. The default implementation forwards to
  // the kernel, handing over `done`.
  virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                            AsyncOpKernel::DoneCallback done) {
    op_kernel->ComputeAsync(context, std::move(done));
  }

  // Blocks until all operations queued on the device at the time of
  // the call have completed.  Returns any error pending on the device
  // at completion.
  virtual Status Sync() = 0;

  // Calls the given callback when all operations queued on the device at the
  // time of the call have completed. The callback is passed any error pending
  // on the device at completion.
  // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
  // version.
  virtual void Sync(const DoneCallback& done);

  // On session completion, the executor may call Device::Sync() depending on
  // flag settings. Override this to return false for devices that don't allow
  // such calls. Instead, these devices must use other mechanisms (such as
  // num_deferred_ops) to ensure the device has finished processing necessary
  // work at session completion. In addition, for these devices, RefreshStatus
  // must be called at session completion to retrieve execution result status.
  //
  // Devices that override this function must also implement RefreshStatus.
  virtual bool AllowsSyncOnCompletion() const { return true; }

  // This is used in conjunction with AllowsSyncOnCompletion to allow the
  // executor to get execution result status at session completion.
  //
  // For supported devices, this call returns the underlying device stream's
  // current status in a non-blocking way, without using blocking calls such as
  // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
  // status is also updated with the retrieved stream status.
  //
  // The default implementation returns an Unimplemented error.
  virtual Status RefreshStatus() {
    return errors::Unimplemented(
        "RefreshStatus is not supported on this device.");
  }

  // Optionally modify the device's GraphDef before execution.
  //
  // This method should be considered experimental and is supplied to enable
  // prototyping of TensorFlow device implementations that need to modify
  // the GraphDef before execution.
  //
  // 'graph' supplies the partition of the graph assigned to this
  // device. The default implementation leaves it untouched.
  virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
    return Status::OK();
  }

  // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
  // if the device does not support contexts. Returns an error status if any
  // error occurred while trying to create a context, otherwise OK.
  //
  // The caller takes ownership of one reference on the output DeviceContext*,
  // and should call Unref().
  virtual Status TryGetDeviceContext(DeviceContext** out_context) {
    *out_context = nullptr;
    return Status::OK();
  }

  // Returns the op segment of this device.  The caller can reuse op
  // kernels registered for the same session running on this device.
  OpSegment* op_segment() { return &op_seg_; }

  // Returns the resource manager associated w/ this device. This is nullptr
  // after DeleteResourceMgr() has been called.
  virtual ResourceMgr* resource_manager() { return rmgr_; }

  // Summarizes the status of this Device, for debugging.
  std::string DebugString() const { return device_attributes_.DebugString(); }

  // Assembles the parameter components into a complete DeviceAttributes value.
  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality, const std::string& physical_device_desc);

  // Convenience overload that leaves the physical device description empty.
  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality) {
    // Pass in an empty string as physical device name.
    return BuildDeviceAttributes(name, device, memory_limit, locality, "");
  }

  // Clears the resource manager associated with this device. This is a no-op
  // when the resource manager has already been released through
  // DeleteResourceMgr(), instead of dereferencing a null pointer.
  void ClearResourceMgr() {
    if (rmgr_ != nullptr) {
      rmgr_->Clear();
    }
  }

  virtual bool IsLocal() const { return true; }

 protected:
  // Destroys the resource manager and resets the pointer so that later
  // accesses observe nullptr rather than a dangling pointer.
  void DeleteResourceMgr() {
    delete rmgr_;
    rmgr_ = nullptr;
  }

 private:
  const DeviceAttributes device_attributes_;
  DeviceNameUtils::ParsedName parsed_name_;

  // op_seg_ maps session handle and op name to OpKernel objects.
  OpSegment op_seg_;

  // Resources associated w/ this device. E.g., shared variables, etc.
  ResourceMgr* rmgr_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(Device);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -0,0 +1,151 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class Device;
struct SessionOptions;
// Interface implemented by per-device-type factories, together with the
// static entry points used to register factories and to enumerate or create
// devices through the registered factories.
class DeviceFactory {
 public:
  virtual ~DeviceFactory() = default;

  // Registers `factory` for `device_type` with the given `priority`. See
  // dfactory::Registrar below for how priorities are interpreted.
  static void Register(const std::string& device_type, DeviceFactory* factory,
                       int priority);

  // Looks up the factory registered for `device_type`.
  static DeviceFactory* GetFactory(const std::string& device_type);

  // Append to "*devices" all suitable devices, respecting
  // any device type specific properties/counts listed in "options".
  //
  // CPU devices are added first.
  static Status AddDevices(const SessionOptions& options,
                           const std::string& name_prefix,
                           std::vector<std::unique_ptr<Device>>* devices);

  // Helper for tests.  Create a single device of type "type".  The
  // returned device is always numbered zero, so if creating multiple
  // devices of the same type, supply distinct name_prefix arguments.
  static std::unique_ptr<Device> NewDevice(const std::string& type,
                                           const SessionOptions& options,
                                           const std::string& name_prefix);

  // Iterate through all device factories and build a list of all of the
  // possible physical devices.
  //
  // CPU devices are added first.
  static Status ListAllPhysicalDevices(std::vector<std::string>* devices);

  // Get details for a specific device among all device factories.
  // 'device_index' indexes into devices from ListAllPhysicalDevices.
  static Status GetAnyDeviceDetails(
      int device_index,
      std::unordered_map<std::string, std::string>* details);

  // For a specific device factory list all possible physical devices.
  virtual Status ListPhysicalDevices(std::vector<std::string>* devices) = 0;

  // Get details for a specific device for a specific factory. Subclasses
  // can store arbitrary device information in the map. 'device_index' indexes
  // into devices from ListPhysicalDevices.
  //
  // The default implementation reports no details and returns OK.
  virtual Status GetDeviceDetails(
      int /*device_index*/,
      std::unordered_map<std::string, std::string>* /*details*/) {
    return Status::OK();
  }

  // Most clients should call AddDevices() instead.
  virtual Status CreateDevices(
      const SessionOptions& options, const std::string& name_prefix,
      std::vector<std::unique_ptr<Device>>* devices) = 0;

  // Return the device priority number for a "device_type" string.
  //
  // Higher number implies higher priority.
  //
  // In standard TensorFlow distributions, GPU device types are
  // preferred over CPU, and by default, custom devices that don't set
  // a custom priority during registration will be prioritized lower
  // than CPU.  Custom devices that want a higher priority can set the
  // 'priority' field when registering their device to something
  // higher than the packaged devices.  See calls to
  // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
  // for built-in devices.
  static int32 DevicePriority(const std::string& device_type);
};
namespace dfactory {
template <class Factory>
class Registrar {
public:
// Multiple registrations for the same device type with different priorities
// are allowed. Priorities are used in two different ways:
//
// 1) When choosing which factory (that is, which device
// implementation) to use for a specific 'device_type', the
// factory registered with the highest priority will be chosen.
// For example, if there are two registrations:
//
// Registrar<CPUFactory1>("CPU", 125);
// Registrar<CPUFactory2>("CPU", 150);
//
// then CPUFactory2 will be chosen when
// DeviceFactory::GetFactory("CPU") is called.
//
// 2) When choosing which 'device_type' is preferred over other
// DeviceTypes in a DeviceSet, the ordering is determined
// by the 'priority' set during registration. For example, if there
// are two registrations:
//
// Registrar<CPUFactory>("CPU", 100);
// Registrar<GPUFactory>("GPU", 200);
//
// then DeviceType("GPU") will be prioritized higher than
// DeviceType("CPU").
//
// The default priority values for built-in devices is:
// GPU: 210
// GPUCompatibleCPU: 70
// ThreadPoolDevice: 60
// Default: 50
explicit Registrar(const std::string& device_type, int priority = 50) {
DeviceFactory::Register(device_type, new Factory(), priority);
}
};
} // namespace dfactory
// Registers `device_factory` (a type that dfactory::Registrar can instantiate
// and pass to DeviceFactory::Register) for `device_type` by defining a static
// Registrar object. Any additional arguments, e.g. a priority, are forwarded
// to the Registrar constructor after `device_type`.
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
__COUNTER__, ##__VA_ARGS__)
// Implementation detail of REGISTER_LOCAL_DEVICE_FACTORY: `ctr` receives the
// __COUNTER__ value and is pasted into the name of the static Registrar
// object, so that separate expansions get distinct object names.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
ctr, ...) \
static ::tensorflow::dfactory::Registrar<device_factory> \
INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \
##__VA_ARGS__)
// __COUNTER__ must go through another macro to be properly expanded
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_

View File

@ -64,6 +64,7 @@ filegroup(
"graph_node_util.h",
"node_builder.h",
"tensor_id.h",
"types.h",
],
)

View File

@ -100,8 +100,6 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:ops_testutil",
],
@ -119,8 +117,6 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:ops_testutil",
],

View File

@ -6125,8 +6125,6 @@ filegroup(
"//tensorflow/compiler/jit:get_compiler_ir", #tfe
"//tensorflow/compiler/jit:flags", #tfe
"//tensorflow/compiler/mlir/python:mlir", # mlir
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
"//tensorflow/core/common_runtime:device_factory", # device_lib, tfe, tf_session
"//tensorflow/core/common_runtime:graph_constructor", # tf_session
"//tensorflow/core/common_runtime:quantize_training", # quantize_training
"//tensorflow/core/common_runtime:session_options", # device_lib, tfe, tf_session