Add required TPU_SYSTEM device kernel registrations for TPUs
PiperOrigin-RevId: 319265707 Change-Id: I135e2607b4a212ccc51a3327cf913503547121f3
This commit is contained in:
parent
296c323683
commit
5e17b5a12c
@ -39,6 +39,7 @@ const char* const DEVICE_DEFAULT = "DEFAULT";
|
||||
const char* const DEVICE_CPU = "CPU";
|
||||
const char* const DEVICE_GPU = "GPU";
|
||||
const char* const DEVICE_SYCL = "SYCL";
|
||||
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
|
||||
|
||||
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
|
@ -75,6 +75,7 @@ TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT"
|
||||
TF_EXPORT extern const char* const DEVICE_CPU; // "CPU"
|
||||
TF_EXPORT extern const char* const DEVICE_GPU; // "GPU"
|
||||
TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL"
|
||||
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM"
|
||||
|
||||
template <typename Device>
|
||||
struct DeviceName {};
|
||||
|
@ -97,6 +97,7 @@ void ConstantOp::Compute(OpKernelContext* ctx) {
|
||||
ConstantOp::~ConstantOp() {}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_TPU_SYSTEM), ConstantOp);
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
|
@ -54,6 +54,8 @@ void SwitchNOp::Compute(OpKernelContext* context) {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Switch").Device(DEVICE_DEFAULT).HostMemory("pred"), SwitchOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_SwitchN").Device(DEVICE_DEFAULT).HostMemory("output_index"),
|
||||
@ -285,6 +287,8 @@ void MergeOp::Compute(OpKernelContext* context) {
|
||||
REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Merge").Device(DEVICE_DEFAULT).HostMemory("value_index"), MergeOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
|
||||
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
@ -393,6 +397,7 @@ void EnterOp::Compute(OpKernelContext* context) {
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_DEFAULT), EnterOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
|
||||
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
@ -489,6 +494,7 @@ void ExitOp::Compute(OpKernelContext* context) {
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_DEFAULT), ExitOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
|
||||
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
@ -571,6 +577,8 @@ void NextIterationOp::Compute(OpKernelContext* context) {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_DEFAULT),
|
||||
NextIterationOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM),
|
||||
NextIterationOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
|
||||
NextIterationOp);
|
||||
|
||||
@ -665,10 +673,17 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond")
|
||||
.HostMemory("input")
|
||||
.HostMemory("output"),
|
||||
LoopCondOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("LoopCond")
|
||||
.Device(DEVICE_TPU_SYSTEM)
|
||||
.HostMemory("input")
|
||||
.HostMemory("output"),
|
||||
LoopCondOp);
|
||||
|
||||
// ControlTrigger kernel
|
||||
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_DEFAULT),
|
||||
ControlTriggerOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_TPU_SYSTEM),
|
||||
ControlTriggerOp);
|
||||
|
||||
// When called, abort op will abort the current process. This can be used to
|
||||
// abort remote PSs when needed.
|
||||
|
@ -89,6 +89,11 @@ REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
|
||||
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
|
||||
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);
|
||||
|
||||
// TPU ops are only registered when they are required as part of the larger
|
||||
// TPU runtime, and do not need to be registered when selective registration
|
||||
// is turned on.
|
||||
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -25,6 +25,9 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_TPU_SYSTEM),
|
||||
IdentityNOp);
|
||||
|
||||
// Do not worry about colocating IdentityN op with its resource inputs since
|
||||
// it just forwards its inputs anyway. This is needed because we create
|
||||
// IdentityN nodes to club "all" outputs of functional ops while lowering to
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_TPU_SYSTEM), IdentityOp);
|
||||
|
||||
// StopGradient does the same thing as Identity, but has a different
|
||||
// gradient registered.
|
||||
REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp);
|
||||
|
@ -18,5 +18,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_DEFAULT), NoOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_TPU_SYSTEM), NoOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -123,6 +123,8 @@ string SendOp::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_TPU_SYSTEM), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_TPU_SYSTEM), SendOp);
|
||||
|
||||
// Public alias. Added for use in Lingvo.
|
||||
REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_CPU), SendOp);
|
||||
@ -215,6 +217,8 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_DEFAULT), RecvOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_TPU_SYSTEM), RecvOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_TPU_SYSTEM), RecvOp);
|
||||
|
||||
// Public alias. Added for use in Lingvo.
|
||||
REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_CPU), RecvOp);
|
||||
|
@ -20,7 +20,6 @@ namespace tensorflow {
|
||||
const char* const DEVICE_TPU_NODE = "TPU";
|
||||
const char* const TPU_FAST_MEM_ATTR = "_TPU_FAST_MEM";
|
||||
const char* const DEVICE_TPU_REPLICATED_CORE = "TPU_REPLICATED_CORE";
|
||||
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
|
||||
const char* const DEVICE_TPU_XLA_JIT = "XLA_TPU_JIT";
|
||||
const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
|
||||
"_mirrored_variable_indices";
|
||||
|
@ -32,7 +32,7 @@ extern const char* const DEVICE_TPU_NODE; // "TPU";
|
||||
// TPUReplicate computation.
|
||||
extern const char* const DEVICE_TPU_REPLICATED_CORE;
|
||||
|
||||
extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM";
|
||||
// DEVICE_TPU_SYSTEM is now defined in tensorflow/core/framework/types.h/.cc
|
||||
|
||||
// Name of the XLA_TPU_JIT compilation device, which is an internal device to
|
||||
// compile graphs for TPU. Not registered as a device; no operators can be
|
||||
|
Loading…
Reference in New Issue
Block a user