Move Device and DeviceFactory to core/framework.

https://github.com/tensorflow/tensorflow/pull/43611/ needs to access the device registry to get a subtype for a given device type as part of kernel registration. Kernel registration is part of the framework and should not depend on common_runtime, to avoid a circular dependency.

This change was originally rolled back because of a "CPU Factory not registered. Did you link in threadpool_device?" error on Windows. The error happened because both tensorflow.dll and pywrap_tfe.dll contained the device_factory implementation: the CPU factory was registered in tensorflow.dll but queried in pywrap_tfe.dll. This should now be fixed by cleaning up the pywrap_tfe.dll dependencies (f5af1da4f8).
PiperOrigin-RevId: 337585576
Change-Id: Idec2c8ad3ad32e02b972ab0438cc3c232b6be532
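A minimal sketch of what the move unblocks, assuming framework-level registration code only needs the DeviceFactory::GetFactory() entry point that appears later in this diff; the helper name below is hypothetical and not part of this commit:

#include "tensorflow/core/framework/device_factory.h"

namespace tensorflow {

// Hypothetical helper: with device_factory.h now in core/framework, kernel
// registration code can look up the factory registered for a device type
// without depending on common_runtime.
DeviceFactory* LookUpRegisteredFactory(const std::string& device_type) {
  return DeviceFactory::GetFactory(device_type);
}

}  // namespace tensorflow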
parent 0997c7a789
commit c60ced548c
@@ -418,7 +418,9 @@ tf_cuda_library(
        "//tensorflow/core/framework:control_flow.h", # TODO(josh11b): Make internal?
        "//tensorflow/core/framework:dataset.h",
        "//tensorflow/core/framework:dataset_stateful_op_allowlist.h",
        "//tensorflow/core/framework:device.h",
        "//tensorflow/core/framework:device_base.h",
        "//tensorflow/core/framework:device_factory.h",
        "//tensorflow/core/framework:function.h",
        "//tensorflow/core/framework:function_handle_cache.h",
        "//tensorflow/core/framework:graph_def_util.h",
@@ -1581,6 +1583,7 @@ filegroup(
        "//tensorflow/core/example:feature_util.h",
        "//tensorflow/core/framework:framework_internal_private_hdrs",
        "//tensorflow/core/graph:framework_internal_private_headers",
        "//tensorflow/core/public:session_options.h",
        "//tensorflow/core/util:framework_internal_private_hdrs",
        "//tensorflow/core/util:memmapped_file_system_hdrs",
        "//tensorflow/core/util/sparse:framework_internal_private_headers_group",
@@ -499,28 +499,19 @@ cc_library(

cc_library(
    name = "device",
    srcs = ["device.cc"],
    hdrs = ["device.h"],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "device_factory",
    srcs = ["device_factory.cc"],
    hdrs = ["device_factory.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        ":session_options",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:framework_internal",
    ],
)

@@ -12,191 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Device is a something that can perform computations as part of a
// model. Devices can be local (runs computation on this machine), or
// remote (contacts a device local to another machine using an RPC to
// do the work). Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
//
// Device names
// * Every Device should have a unique name with the format:
//     /job:___/replica:___/task:___/(gpu|cpu):___
//   An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
//   many "task zeros" as replicas.

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_

#include <memory>
#include <string>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class Device : public DeviceBase {
 public:
  // Callback type that takes a Status and returns void.
  typedef std::function<void(const Status&)> DoneCallback;

  Device(Env* env, const DeviceAttributes& device_attributes);
  ~Device() override;

  // Full name of this device (see top comment).
  const std::string& name() const override { return device_attributes_.name(); }

  // Parsed name of this device
  const DeviceNameUtils::ParsedName& parsed_name() const {
    return parsed_name_;
  }

  // Describes what kind of device this is. This is intended to be
  // human-readable and not computer-parsed, except that two devices
  // with the same device_type() are expected to perform similarly
  // (both from a computation and communication perspective).
  const std::string& device_type() const {
    return device_attributes_.device_type();
  }

  // Returns an aggregation of device attributes.
  const DeviceAttributes& attributes() const override {
    return device_attributes_;
  }

  // Performs the actual compute function.
  //
  // Subclasses may override this function if they wish to perform
  // some initialization before each compute.
  virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
    op_kernel->Compute(context);
  }

  // Asynchronous kernel's compute.
  virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                            AsyncOpKernel::DoneCallback done) {
    op_kernel->ComputeAsync(context, std::move(done));
  }

  // Blocks until all operations queued on the device at the time of
  // the call have completed. Returns any error pending on the device
  // at completion.
  virtual Status Sync() = 0;

  // Calls the given callback when all operations queued on the device at the
  // time of the call have completed. The callback is passed any error pending
  // on the device at completion.
  // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
  // version.
  virtual void Sync(const DoneCallback& done);

  // On session completion, the executor may call Device::Sync() depending on
  // flag settings. Override this to return false for devices that don't allow
  // such calls. Instead, these devices must use other mechanisms (such as
  // num_deferred_ops) to ensure the device has finished processing necessary
  // work at session completion. In addition, for these devices, RefreshStatus
  // must be called at session completion to retrieve execution result status.
  //
  // Devices that override this function must also implement RefreshStatus.
  virtual bool AllowsSyncOnCompletion() const { return true; }

  // This is used in conjunction with AllowsSyncOnCompletion to allow the
  // executor to get execution result status at session completion.
  //
  // For supported devices, this call returns the underlying device stream's
  // current status in a non-blocking way, without using blocking calls such as
  // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
  // status is also updated with the retrieved stream status.
  virtual Status RefreshStatus() {
    return errors::Unimplemented(
        "RefreshStatus is not supported on this device.");
  }

  // Optionally modify the device's GraphDef before execution.
  //
  // This method should be considered experimental and is supplied to enable
  // prototyping of TensorFlow device implementations that need to modify
  // the GraphDef before execution.
  //
  // 'graph' supplies the partition of the graph assigned to this
  // device.
  virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
    return Status::OK();
  }

  // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
  // if the device does not support contexts. Returns an error status if any
  // error occurred while trying to create a context, otherwise OK.
  //
  // The caller takes ownership of one reference on the output DeviceContext*,
  // and should call Unref().
  virtual Status TryGetDeviceContext(DeviceContext** out_context) {
    *out_context = nullptr;
    return Status::OK();
  }

  // Returns the op segment of this device. The caller can reuse op
  // kernels registered for the same session running on this device.
  OpSegment* op_segment() { return &op_seg_; }

  // Returns the resource manager associated w/ this device.
  virtual ResourceMgr* resource_manager() { return rmgr_; }

  // Summarizes the status of this Device, for debugging.
  std::string DebugString() const { return device_attributes_.DebugString(); }

  // Assembles the parameter components into a complete DeviceAttributes value.
  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality, const std::string& physical_device_desc);

  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality) {
    // Pass in an empty string as physical device name.
    return BuildDeviceAttributes(name, device, memory_limit, locality, "");
  }

  // Clears the resource manager associated with this device.
  void ClearResourceMgr() { rmgr_->Clear(); }

  virtual bool IsLocal() const { return true; }

 protected:
  void DeleteResourceMgr() {
    delete rmgr_;
    rmgr_ = nullptr;
  }

 private:
  const DeviceAttributes device_attributes_;
  DeviceNameUtils::ParsedName parsed_name_;

  // op_seg_ maps session handle and op name to OpKernel objects.
  OpSegment op_seg_;

  // Resources associated w/ this device. E.g., shared variables, etc.
  ResourceMgr* rmgr_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(Device);
};

}  // namespace tensorflow
#include "tensorflow/core/framework/device.h"

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
@@ -12,140 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_

#include <string>
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Device;
struct SessionOptions;

class DeviceFactory {
 public:
  virtual ~DeviceFactory() {}
  static void Register(const std::string& device_type, DeviceFactory* factory,
                       int priority);
  static DeviceFactory* GetFactory(const std::string& device_type);

  // Append to "*devices" all suitable devices, respecting
  // any device type specific properties/counts listed in "options".
  //
  // CPU devices are added first.
  static Status AddDevices(const SessionOptions& options,
                           const std::string& name_prefix,
                           std::vector<std::unique_ptr<Device>>* devices);

  // Helper for tests. Create a single device of type "type". The
  // returned device is always numbered zero, so if creating multiple
  // devices of the same type, supply distinct name_prefix arguments.
  static std::unique_ptr<Device> NewDevice(const string& type,
                                           const SessionOptions& options,
                                           const string& name_prefix);

  // Iterate through all device factories and build a list of all of the
  // possible physical devices.
  //
  // CPU is are added first.
  static Status ListAllPhysicalDevices(std::vector<string>* devices);

  // Get details for a specific device among all device factories.
  // 'device_index' indexes into devices from ListAllPhysicalDevices.
  static Status GetAnyDeviceDetails(
      int device_index, std::unordered_map<string, string>* details);

  // For a specific device factory list all possible physical devices.
  virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;

  // Get details for a specific device for a specific factory. Subclasses
  // can store arbitrary device information in the map. 'device_index' indexes
  // into devices from ListPhysicalDevices.
  virtual Status GetDeviceDetails(int device_index,
                                  std::unordered_map<string, string>* details) {
    return Status::OK();
  }

  // Most clients should call AddDevices() instead.
  virtual Status CreateDevices(
      const SessionOptions& options, const std::string& name_prefix,
      std::vector<std::unique_ptr<Device>>* devices) = 0;

  // Return the device priority number for a "device_type" string.
  //
  // Higher number implies higher priority.
  //
  // In standard TensorFlow distributions, GPU device types are
  // preferred over CPU, and by default, custom devices that don't set
  // a custom priority during registration will be prioritized lower
  // than CPU. Custom devices that want a higher priority can set the
  // 'priority' field when registering their device to something
  // higher than the packaged devices. See calls to
  // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
  // for built-in devices.
  static int32 DevicePriority(const std::string& device_type);
};

namespace dfactory {

template <class Factory>
class Registrar {
 public:
  // Multiple registrations for the same device type with different priorities
  // are allowed. Priorities are used in two different ways:
  //
  // 1) When choosing which factory (that is, which device
  //    implementation) to use for a specific 'device_type', the
  //    factory registered with the highest priority will be chosen.
  //    For example, if there are two registrations:
  //
  //      Registrar<CPUFactory1>("CPU", 125);
  //      Registrar<CPUFactory2>("CPU", 150);
  //
  //    then CPUFactory2 will be chosen when
  //    DeviceFactory::GetFactory("CPU") is called.
  //
  // 2) When choosing which 'device_type' is preferred over other
  //    DeviceTypes in a DeviceSet, the ordering is determined
  //    by the 'priority' set during registration. For example, if there
  //    are two registrations:
  //
  //      Registrar<CPUFactory>("CPU", 100);
  //      Registrar<GPUFactory>("GPU", 200);
  //
  //    then DeviceType("GPU") will be prioritized higher than
  //    DeviceType("CPU").
  //
  // The default priority values for built-in devices is:
  // GPU: 210
  // GPUCompatibleCPU: 70
  // ThreadPoolDevice: 60
  // Default: 50
  explicit Registrar(const std::string& device_type, int priority = 50) {
    DeviceFactory::Register(device_type, new Factory(), priority);
  }
};

}  // namespace dfactory

#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// __COUNTER__ must go through another macro to be properly expanded
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_

}  // namespace tensorflow
#include "tensorflow/core/framework/device_factory.h"

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
@@ -50,7 +50,9 @@ exports_files(
        "control_flow.h",
        "dataset.h",
        "dataset_stateful_op_allowlist.h",
        "device.h",
        "device_base.h",
        "device_factory.h",
        "function.h",
        "function_handle_cache.h",
        "graph_def_util.h",
@@ -178,7 +180,9 @@ filegroup(
        "control_flow.h",
        "dataset.h",
        "dataset_stateful_op_allowlist.h",
        "device.h",
        "device_base.h",
        "device_factory.h",
        "function.h",
        "function_handle_cache.h",
        "graph_def_util.h",
@@ -251,7 +255,9 @@ filegroup(
        "cancellation.cc",
        "collective.cc",
        "dataset.cc",
        "device.cc",
        "device_base.cc",
        "device_factory.cc",
        "function.cc",
        "function_handle_cache.cc",
        "graph_def_util.cc",
@@ -342,8 +348,12 @@ filegroup(
        "dataset.cc",
        "dataset.h",
        "dataset_stateful_op_allowlist.h",
        "device.cc",
        "device.h",
        "device_base.cc",
        "device_base.h",
        "device_factory.cc",
        "device_factory.h",
        "function.cc",
        "function.h",
        "function_handle_cache.cc",
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device.h"

#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
tensorflow/core/framework/device.h (new file, 202 lines)
@@ -0,0 +1,202 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Device is a something that can perform computations as part of a
// model. Devices can be local (runs computation on this machine), or
// remote (contacts a device local to another machine using an RPC to
// do the work). Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
//
// Device names
// * Every Device should have a unique name with the format:
//     /job:___/replica:___/task:___/(gpu|cpu):___
//   An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
//   many "task zeros" as replicas.

#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_

#include <memory>
#include <string>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class Device : public DeviceBase {
 public:
  // Callback type that takes a Status and returns void.
  typedef std::function<void(const Status&)> DoneCallback;

  Device(Env* env, const DeviceAttributes& device_attributes);
  ~Device() override;

  // Full name of this device (see top comment).
  const std::string& name() const override { return device_attributes_.name(); }

  // Parsed name of this device
  const DeviceNameUtils::ParsedName& parsed_name() const {
    return parsed_name_;
  }

  // Describes what kind of device this is. This is intended to be
  // human-readable and not computer-parsed, except that two devices
  // with the same device_type() are expected to perform similarly
  // (both from a computation and communication perspective).
  const std::string& device_type() const {
    return device_attributes_.device_type();
  }

  // Returns an aggregation of device attributes.
  const DeviceAttributes& attributes() const override {
    return device_attributes_;
  }

  // Performs the actual compute function.
  //
  // Subclasses may override this function if they wish to perform
  // some initialization before each compute.
  virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
    op_kernel->Compute(context);
  }

  // Asynchronous kernel's compute.
  virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                            AsyncOpKernel::DoneCallback done) {
    op_kernel->ComputeAsync(context, std::move(done));
  }

  // Blocks until all operations queued on the device at the time of
  // the call have completed. Returns any error pending on the device
  // at completion.
  virtual Status Sync() = 0;

  // Calls the given callback when all operations queued on the device at the
  // time of the call have completed. The callback is passed any error pending
  // on the device at completion.
  // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
  // version.
  virtual void Sync(const DoneCallback& done);

  // On session completion, the executor may call Device::Sync() depending on
  // flag settings. Override this to return false for devices that don't allow
  // such calls. Instead, these devices must use other mechanisms (such as
  // num_deferred_ops) to ensure the device has finished processing necessary
  // work at session completion. In addition, for these devices, RefreshStatus
  // must be called at session completion to retrieve execution result status.
  //
  // Devices that override this function must also implement RefreshStatus.
  virtual bool AllowsSyncOnCompletion() const { return true; }

  // This is used in conjunction with AllowsSyncOnCompletion to allow the
  // executor to get execution result status at session completion.
  //
  // For supported devices, this call returns the underlying device stream's
  // current status in a non-blocking way, without using blocking calls such as
  // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
  // status is also updated with the retrieved stream status.
  virtual Status RefreshStatus() {
    return errors::Unimplemented(
        "RefreshStatus is not supported on this device.");
  }

  // Optionally modify the device's GraphDef before execution.
  //
  // This method should be considered experimental and is supplied to enable
  // prototyping of TensorFlow device implementations that need to modify
  // the GraphDef before execution.
  //
  // 'graph' supplies the partition of the graph assigned to this
  // device.
  virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
    return Status::OK();
  }

  // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
  // if the device does not support contexts. Returns an error status if any
  // error occurred while trying to create a context, otherwise OK.
  //
  // The caller takes ownership of one reference on the output DeviceContext*,
  // and should call Unref().
  virtual Status TryGetDeviceContext(DeviceContext** out_context) {
    *out_context = nullptr;
    return Status::OK();
  }

  // Returns the op segment of this device. The caller can reuse op
  // kernels registered for the same session running on this device.
  OpSegment* op_segment() { return &op_seg_; }

  // Returns the resource manager associated w/ this device.
  virtual ResourceMgr* resource_manager() { return rmgr_; }

  // Summarizes the status of this Device, for debugging.
  std::string DebugString() const { return device_attributes_.DebugString(); }

  // Assembles the parameter components into a complete DeviceAttributes value.
  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality, const std::string& physical_device_desc);

  static DeviceAttributes BuildDeviceAttributes(
      const std::string& name, DeviceType device, Bytes memory_limit,
      const DeviceLocality& locality) {
    // Pass in an empty string as physical device name.
    return BuildDeviceAttributes(name, device, memory_limit, locality, "");
  }

  // Clears the resource manager associated with this device.
  void ClearResourceMgr() { rmgr_->Clear(); }

  virtual bool IsLocal() const { return true; }

 protected:
  void DeleteResourceMgr() {
    delete rmgr_;
    rmgr_ = nullptr;
  }

 private:
  const DeviceAttributes device_attributes_;
  DeviceNameUtils::ParsedName parsed_name_;

  // op_seg_ maps session handle and op name to OpKernel objects.
  OpSegment op_seg_;

  // Resources associated w/ this device. E.g., shared variables, etc.
  ResourceMgr* rmgr_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(Device);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
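For orientation, a minimal Device subclass built only from the header above might look like the sketch below. ToyDevice, its "TOY" device type, and the 256 MB memory limit are invented for illustration; a real device would also override Compute() and wire up allocators.

#include "tensorflow/core/framework/device.h"

namespace tensorflow {

class ToyDevice : public Device {
 public:
  ToyDevice(Env* env, const string& name)
      : Device(env, Device::BuildDeviceAttributes(name, DeviceType("TOY"),
                                                  Bytes(256 << 20),
                                                  DeviceLocality())) {}

  // Sync() is the only pure-virtual member; a device with no queued work can
  // report success immediately.
  Status Sync() override { return Status::OK(); }
};

}  // namespace tensorflow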
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_factory.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
tensorflow/core/framework/device_factory.h (new file, 151 lines)
@@ -0,0 +1,151 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_

#include <string>
#include <vector>

#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Device;
struct SessionOptions;

class DeviceFactory {
 public:
  virtual ~DeviceFactory() {}
  static void Register(const std::string& device_type, DeviceFactory* factory,
                       int priority);
  static DeviceFactory* GetFactory(const std::string& device_type);

  // Append to "*devices" all suitable devices, respecting
  // any device type specific properties/counts listed in "options".
  //
  // CPU devices are added first.
  static Status AddDevices(const SessionOptions& options,
                           const std::string& name_prefix,
                           std::vector<std::unique_ptr<Device>>* devices);

  // Helper for tests. Create a single device of type "type". The
  // returned device is always numbered zero, so if creating multiple
  // devices of the same type, supply distinct name_prefix arguments.
  static std::unique_ptr<Device> NewDevice(const string& type,
                                           const SessionOptions& options,
                                           const string& name_prefix);

  // Iterate through all device factories and build a list of all of the
  // possible physical devices.
  //
  // CPU is are added first.
  static Status ListAllPhysicalDevices(std::vector<string>* devices);

  // Get details for a specific device among all device factories.
  // 'device_index' indexes into devices from ListAllPhysicalDevices.
  static Status GetAnyDeviceDetails(
      int device_index, std::unordered_map<string, string>* details);

  // For a specific device factory list all possible physical devices.
  virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;

  // Get details for a specific device for a specific factory. Subclasses
  // can store arbitrary device information in the map. 'device_index' indexes
  // into devices from ListPhysicalDevices.
  virtual Status GetDeviceDetails(int device_index,
                                  std::unordered_map<string, string>* details) {
    return Status::OK();
  }

  // Most clients should call AddDevices() instead.
  virtual Status CreateDevices(
      const SessionOptions& options, const std::string& name_prefix,
      std::vector<std::unique_ptr<Device>>* devices) = 0;

  // Return the device priority number for a "device_type" string.
  //
  // Higher number implies higher priority.
  //
  // In standard TensorFlow distributions, GPU device types are
  // preferred over CPU, and by default, custom devices that don't set
  // a custom priority during registration will be prioritized lower
  // than CPU. Custom devices that want a higher priority can set the
  // 'priority' field when registering their device to something
  // higher than the packaged devices. See calls to
  // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
  // for built-in devices.
  static int32 DevicePriority(const std::string& device_type);
};

namespace dfactory {

template <class Factory>
class Registrar {
 public:
  // Multiple registrations for the same device type with different priorities
  // are allowed. Priorities are used in two different ways:
  //
  // 1) When choosing which factory (that is, which device
  //    implementation) to use for a specific 'device_type', the
  //    factory registered with the highest priority will be chosen.
  //    For example, if there are two registrations:
  //
  //      Registrar<CPUFactory1>("CPU", 125);
  //      Registrar<CPUFactory2>("CPU", 150);
  //
  //    then CPUFactory2 will be chosen when
  //    DeviceFactory::GetFactory("CPU") is called.
  //
  // 2) When choosing which 'device_type' is preferred over other
  //    DeviceTypes in a DeviceSet, the ordering is determined
  //    by the 'priority' set during registration. For example, if there
  //    are two registrations:
  //
  //      Registrar<CPUFactory>("CPU", 100);
  //      Registrar<GPUFactory>("GPU", 200);
  //
  //    then DeviceType("GPU") will be prioritized higher than
  //    DeviceType("CPU").
  //
  // The default priority values for built-in devices is:
  // GPU: 210
  // GPUCompatibleCPU: 70
  // ThreadPoolDevice: 60
  // Default: 50
  explicit Registrar(const std::string& device_type, int priority = 50) {
    DeviceFactory::Register(device_type, new Factory(), priority);
  }
};

}  // namespace dfactory

#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// __COUNTER__ must go through another macro to be properly expanded
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
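And a sketch of how a factory for a hypothetical device might be registered against this header; ToyDeviceFactory and the priority value 55 are made up, while REGISTER_LOCAL_DEVICE_FACTORY and the two required overrides come straight from the file above.

#include <memory>
#include <vector>

#include "tensorflow/core/framework/device_factory.h"

namespace tensorflow {

class ToyDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override {
    devices->push_back("/physical_device:TOY:0");
    return Status::OK();
  }

  Status CreateDevices(const SessionOptions& options,
                       const std::string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override {
    // A real factory would construct and append concrete Device objects here.
    return Status::OK();
  }
};

// Priority 55 ranks above the default of 50 but below ThreadPoolDevice (60),
// per the priority table documented on Registrar above.
REGISTER_LOCAL_DEVICE_FACTORY("TOY", ToyDeviceFactory, 55);

}  // namespace tensorflow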
@@ -64,6 +64,7 @@ filegroup(
        "graph_node_util.h",
        "node_builder.h",
        "tensor_id.h",
        "types.h",
    ],
)

@@ -100,8 +100,6 @@ tf_cuda_cc_test(
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/common_runtime:device",
        "//tensorflow/core/common_runtime:device_factory",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:ops_testutil",
    ],
@@ -119,8 +117,6 @@ tf_cuda_cc_test(
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/common_runtime:device",
        "//tensorflow/core/common_runtime:device_factory",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:ops_testutil",
    ],
@@ -6125,8 +6125,6 @@ filegroup(
        "//tensorflow/compiler/jit:get_compiler_ir", #tfe
        "//tensorflow/compiler/jit:flags", #tfe
        "//tensorflow/compiler/mlir/python:mlir", # mlir
        "//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
        "//tensorflow/core/common_runtime:device_factory", # device_lib, tfe, tf_session
        "//tensorflow/core/common_runtime:graph_constructor", # tf_session
        "//tensorflow/core/common_runtime:quantize_training", # quantize_training
        "//tensorflow/core/common_runtime:session_options", # device_lib, tfe, tf_session