Add required TPU_SYSTEM device kernel registrations for TPUs

PiperOrigin-RevId: 319265707
Change-Id: I135e2607b4a212ccc51a3327cf913503547121f3
This commit is contained in:
Frank Chen 2020-07-01 11:40:26 -07:00 committed by TensorFlower Gardener
parent 296c323683
commit 5e17b5a12c
11 changed files with 38 additions and 6 deletions

View File

@ -39,6 +39,7 @@ const char* const DEVICE_DEFAULT = "DEFAULT";
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \

View File

@ -75,6 +75,7 @@ TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT"
TF_EXPORT extern const char* const DEVICE_CPU; // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU; // "GPU"
TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL"
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM"
template <typename Device>
struct DeviceName {};

View File

@ -97,6 +97,7 @@ void ConstantOp::Compute(OpKernelContext* ctx) {
ConstantOp::~ConstantOp() {}
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_TPU_SYSTEM), ConstantOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

View File

@ -54,6 +54,8 @@ void SwitchNOp::Compute(OpKernelContext* context) {
REGISTER_KERNEL_BUILDER(
Name("Switch").Device(DEVICE_DEFAULT).HostMemory("pred"), SwitchOp);
REGISTER_KERNEL_BUILDER(
Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp);
REGISTER_KERNEL_BUILDER(
Name("_SwitchN").Device(DEVICE_DEFAULT).HostMemory("output_index"),
@ -285,6 +287,8 @@ void MergeOp::Compute(OpKernelContext* context) {
REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(
Name("Merge").Device(DEVICE_DEFAULT).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(
Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
#define REGISTER_GPU_KERNEL(type) \
@ -393,6 +397,7 @@ void EnterOp::Compute(OpKernelContext* context) {
}
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_DEFAULT), EnterOp);
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
#define REGISTER_GPU_KERNEL(type) \
@ -489,6 +494,7 @@ void ExitOp::Compute(OpKernelContext* context) {
}
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_DEFAULT), ExitOp);
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
#define REGISTER_GPU_KERNEL(type) \
@ -571,6 +577,8 @@ void NextIterationOp::Compute(OpKernelContext* context) {
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_DEFAULT),
NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM),
NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
NextIterationOp);
@ -665,10 +673,17 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond")
.HostMemory("input")
.HostMemory("output"),
LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
.Device(DEVICE_TPU_SYSTEM)
.HostMemory("input")
.HostMemory("output"),
LoopCondOp);
// ControlTrigger kernel
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_DEFAULT),
ControlTriggerOp);
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_TPU_SYSTEM),
ControlTriggerOp);
// When called, abort op will abort the current process. This can be used to
// abort remote PSs when needed.

View File

@ -89,6 +89,11 @@ REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);
// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and does not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);
#if TENSORFLOW_USE_SYCL
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \

View File

@ -25,6 +25,9 @@ limitations under the License.
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_TPU_SYSTEM),
IdentityNOp);
// Do not worry about colocating IdentityN op with its resource inputs since
// it just forwards it's inputs anyway. This is needed because we create
// IdentityN nodes to club "all" outputs of functional ops while lowering to

View File

@ -24,6 +24,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp);
REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_TPU_SYSTEM), IdentityOp);
// StopGradient does the same thing as Identity, but has a different
// gradient registered.
REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp);

View File

@ -18,5 +18,6 @@ limitations under the License.
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_DEFAULT), NoOp);
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_TPU_SYSTEM), NoOp);
} // namespace tensorflow

View File

@ -123,6 +123,8 @@ string SendOp::TraceString(OpKernelContext* ctx, bool verbose) {
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_TPU_SYSTEM), SendOp);
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_TPU_SYSTEM), SendOp);
// Public alias. Added for use in Lingvo.
REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_CPU), SendOp);
@ -215,6 +217,8 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_DEFAULT), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_TPU_SYSTEM), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_TPU_SYSTEM), RecvOp);
// Public alias. Added for use in Lingvo.
REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_CPU), RecvOp);

View File

@ -20,7 +20,6 @@ namespace tensorflow {
const char* const DEVICE_TPU_NODE = "TPU";
const char* const TPU_FAST_MEM_ATTR = "_TPU_FAST_MEM";
const char* const DEVICE_TPU_REPLICATED_CORE = "TPU_REPLICATED_CORE";
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
const char* const DEVICE_TPU_XLA_JIT = "XLA_TPU_JIT";
const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
"_mirrored_variable_indices";

View File

@ -32,7 +32,7 @@ extern const char* const DEVICE_TPU_NODE; // "TPU";
// TPUReplicate computation.
extern const char* const DEVICE_TPU_REPLICATED_CORE;
extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM";
// DEVICE_TPU_SYSTEM is now defined in tensorflow/core/framework/types.h/.cc
// Name of the XLA_TPU_JIT compilation device, which is an internal device to
// compile graphs for TPU. Not registered as a device; no operators can be