diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index d6455e012d0..294f7a21557 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -39,6 +39,7 @@ const char* const DEVICE_DEFAULT = "DEFAULT"; const char* const DEVICE_CPU = "CPU"; const char* const DEVICE_GPU = "GPU"; const char* const DEVICE_SYCL = "SYCL"; +const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM"; const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU; #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 61575a7b735..fe52f8b2b59 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -71,10 +71,11 @@ class DeviceType { std::ostream& operator<<(std::ostream& os, const DeviceType& d); // Convenient constants that can be passed to a DeviceType constructor -TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT" -TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" -TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" -TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL" +TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT" +TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" +TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" +TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL" +TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM" template <typename Device> struct DeviceName {}; diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index dc178d17d49..376effc6535 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -97,6 +97,7 @@ void ConstantOp::Compute(OpKernelContext* ctx) { ConstantOp::~ConstantOp() {} REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp); +REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_TPU_SYSTEM), ConstantOp); #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 1a0082c6a3b..f886235a3f7 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -54,6 +54,8 @@ void SwitchNOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER( Name("Switch").Device(DEVICE_DEFAULT).HostMemory("pred"), SwitchOp); +REGISTER_KERNEL_BUILDER( + Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp); REGISTER_KERNEL_BUILDER( Name("_SwitchN").Device(DEVICE_DEFAULT).HostMemory("output_index"), @@ -285,6 +287,8 @@ void MergeOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); REGISTER_KERNEL_BUILDER( Name("Merge").Device(DEVICE_DEFAULT).HostMemory("value_index"), MergeOp); +REGISTER_KERNEL_BUILDER( + Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp); REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp); #define REGISTER_GPU_KERNEL(type) \ @@ -393,6 +397,7 @@ void EnterOp::Compute(OpKernelContext* context) { } REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_DEFAULT), EnterOp); +REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp); REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); #define REGISTER_GPU_KERNEL(type) \ @@ -489,6 +494,7 @@ void ExitOp::Compute(OpKernelContext* context) { } REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_DEFAULT), ExitOp); +REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp); REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); #define REGISTER_GPU_KERNEL(type) \ @@ -571,6 +577,8 @@ void NextIterationOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_DEFAULT), NextIterationOp); +REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM), + NextIterationOp); REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU), NextIterationOp); @@ -665,10 +673,17 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond") .HostMemory("input") .HostMemory("output"), LoopCondOp); +REGISTER_KERNEL_BUILDER(Name("LoopCond") + .Device(DEVICE_TPU_SYSTEM) + .HostMemory("input") + .HostMemory("output"), + LoopCondOp); // ControlTrigger kernel REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_DEFAULT), ControlTriggerOp); +REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_TPU_SYSTEM), + ControlTriggerOp); // When called, abort op will abort the current process. This can be used to // abort remote PSs when needed. diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index dd312fbf3e6..d69292082bc 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -89,6 +89,11 @@ REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp); REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp); REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp); +// TPU ops are only registered when they are required as part of the larger +// TPU runtime, and does not need to be registered when selective registration +// is turned on. +REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp); + #if TENSORFLOW_USE_SYCL #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ diff --git a/tensorflow/core/kernels/identity_n_op.cc b/tensorflow/core/kernels/identity_n_op.cc index 746a29bf5aa..eed372630de 100644 --- a/tensorflow/core/kernels/identity_n_op.cc +++ b/tensorflow/core/kernels/identity_n_op.cc @@ -25,6 +25,9 @@ limitations under the License. namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp); +REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_TPU_SYSTEM), + IdentityNOp); + // Do not worry about colocating IdentityN op with its resource inputs since // it just forwards it's inputs anyway. This is needed because we create // IdentityN nodes to club "all" outputs of functional ops while lowering to diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index 4b226dd72d4..aee7b545f79 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -24,6 +24,8 @@ limitations under the License. namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp); +REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_TPU_SYSTEM), IdentityOp); + // StopGradient does the same thing as Identity, but has a different // gradient registered. REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp); diff --git a/tensorflow/core/kernels/no_op.cc b/tensorflow/core/kernels/no_op.cc index d1a0d240f30..dbbd806c275 100644 --- a/tensorflow/core/kernels/no_op.cc +++ b/tensorflow/core/kernels/no_op.cc @@ -18,5 +18,6 @@ limitations under the License. namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_DEFAULT), NoOp); +REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_TPU_SYSTEM), NoOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index f82d4645f13..f4c4fae2910 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -123,6 +123,8 @@ string SendOp::TraceString(OpKernelContext* ctx, bool verbose) { REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp); +REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_TPU_SYSTEM), SendOp); +REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_TPU_SYSTEM), SendOp); // Public alias. Added for use in Lingvo. REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_CPU), SendOp); @@ -215,6 +217,8 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp); REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_DEFAULT), RecvOp); +REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_TPU_SYSTEM), RecvOp); +REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_TPU_SYSTEM), RecvOp); // Public alias. Added for use in Lingvo. REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_CPU), RecvOp); diff --git a/tensorflow/core/tpu/tpu_defs.cc b/tensorflow/core/tpu/tpu_defs.cc index dc370ea2ba7..ad7f02a3d95 100644 --- a/tensorflow/core/tpu/tpu_defs.cc +++ b/tensorflow/core/tpu/tpu_defs.cc @@ -20,7 +20,6 @@ namespace tensorflow { const char* const DEVICE_TPU_NODE = "TPU"; const char* const TPU_FAST_MEM_ATTR = "_TPU_FAST_MEM"; const char* const DEVICE_TPU_REPLICATED_CORE = "TPU_REPLICATED_CORE"; -const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM"; const char* const DEVICE_TPU_XLA_JIT = "XLA_TPU_JIT"; const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR = "_mirrored_variable_indices"; diff --git a/tensorflow/core/tpu/tpu_defs.h b/tensorflow/core/tpu/tpu_defs.h index 497afb5c392..294b4253ee0 100644 --- a/tensorflow/core/tpu/tpu_defs.h +++ b/tensorflow/core/tpu/tpu_defs.h @@ -32,7 +32,7 @@ extern const char* const DEVICE_TPU_NODE; // "TPU"; // TPUReplicate computation. extern const char* const DEVICE_TPU_REPLICATED_CORE; -extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM"; +// DEVICE_TPU_SYSTEM is now defined in tensorflow/core/framework/types.h/.cc // Name of the XLA_TPU_JIT compilation device, which is an internal device to // compile graphs for TPU. Not registered as a device; no operators can be