Changed TPU embedding load and retrieve ops to checked-in generated code.

PiperOrigin-RevId: 301887553
Change-Id: Ib6e042e73cd4a0214239175a4e86b090a0817f12
This commit is contained in:
A. Unique TensorFlower 2020-03-19 13:47:56 -07:00 committed by TensorFlower Gardener
parent f29c62f405
commit 4c0d6b7d51
13 changed files with 762 additions and 300 deletions

View File

@@ -632,6 +632,7 @@ tf_gen_op_wrappers_cc(
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_embedding_load_retrieve_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",

View File

@@ -723,6 +723,7 @@ tf_gen_op_libs(
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_embedding_load_retrieve_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",
@@ -735,6 +736,7 @@ tf_gen_op_libs(
":lib",
":lib_proto_parsing",
":protos_all_cc",
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
"//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils",
"//tensorflow/core/tpu:tpu_embedding_output_layout_utils",
@@ -894,6 +896,7 @@ cc_library(
":tpu_configuration_ops_op_lib",
":tpu_cross_replica_ops_op_lib",
":tpu_embedding_ops_op_lib",
":tpu_embedding_load_retrieve_ops_op_lib",
":tpu_functional_ops_op_lib",
":tpu_heartbeat_ops_op_lib",
":tpu_host_compute_ops_op_lib",

View File

@@ -0,0 +1,24 @@
op {
graph_op_name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
visibility: HIDDEN
in_arg {
name: "parameters"
description: <<END
Value of parameters used in the stochastic gradient descent optimization algorithm.
END
}
in_arg {
name: "gradient_accumulators"
description: <<END
Value of gradient_accumulators used in the stochastic gradient descent optimization algorithm.
END
}
summary: "Load SGD embedding parameters."
description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}
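For orientation, here is a minimal sketch of how this hidden op could be invoked through tf.raw_ops once the TPU embedding system has been configured; the shapes, shard layout, and table name below are illustrative assumptions, not part of this change.

import tensorflow as tf

# Illustrative only: requires a TPU system whose embedding tables have
# already been set up (e.g. via ConfigureTPUEmbeddingHost); the op has
# no CPU/GPU kernel and will fail outside that environment.
vocab_size, dim = 1024, 64
parameters = tf.random.normal([vocab_size, dim])      # rank-2: [vocab, dim]
gradient_accumulators = tf.zeros([vocab_size, dim])   # same shape as parameters

tf.raw_ops.LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(
    parameters=parameters,
    gradient_accumulators=gradient_accumulators,
    num_shards=1,            # total number of parameter shards
    shard_id=0,              # shard owned by this host
    table_name="my_table",   # exactly one of table_name / table_id may be set
)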

View File

@@ -0,0 +1,23 @@
op {
graph_op_name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
visibility: HIDDEN
out_arg {
name: "parameters"
description: <<END
Parameter parameters updated by the stochastic gradient descent optimization algorithm.
END
}
out_arg {
name: "gradient_accumulators"
description: <<END
Parameter gradient_accumulators updated by the stochastic gradient descent optimization algorithm.
END
}
summary: "Retrieve SGD embedding parameters with debug support."
description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}
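As above, a hedged sketch of the retrieve side; the table name and single-shard layout are assumptions. The op returns the parameters and gradient accumulators, which can then be written into host-side variables before saving a checkpoint.

import tensorflow as tf

# Illustrative only: assumes the corresponding table was configured and
# loaded earlier on a TPU system; there is no CPU/GPU kernel for this op.
retrieved = tf.raw_ops.RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(
    num_shards=1,
    shard_id=0,
    table_name="my_table",
)
parameters, gradient_accumulators = retrieved
# Assign these to host variables here, then write the checkpoint.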

View File

@@ -0,0 +1,575 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Produced by generate_tpu_embedding_load_retrieve_ops.py (Google-internal).
#include <string>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
namespace tensorflow {
namespace tpu {
using OptimizationAlgorithm = OptimizationParameters::ParametersCase;
REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
.Input("parameters: float32")
.Input("accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
.Output("parameters: float32")
.Output("accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
.Input("parameters: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
.Input("parameters: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
.Output("parameters: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
/*is_debug_op=*/false});
REGISTER_OP(
"RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
.Output("parameters: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("linears: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("linears: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("linears: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("linears: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingADAMParameters")
.Input("parameters: float32")
.Input("momenta: float32")
.Input("velocities: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
.Input("parameters: float32")
.Input("momenta: float32")
.Input("velocities: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
.Output("parameters: float32")
.Output("momenta: float32")
.Output("velocities: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
.Output("parameters: float32")
.Output("momenta: float32")
.Output("velocities: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
.Input("parameters: float32")
.Input("momenta: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
.Input("parameters: float32")
.Input("momenta: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
.Output("parameters: float32")
.Output("momenta: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kMomentum,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
.Output("parameters: float32")
.Output("momenta: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kMomentum,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
.Input("parameters: float32")
.Input("ms: float32")
.Input("mom: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
.Input("parameters: float32")
.Input("ms: float32")
.Input("mom: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
.Output("parameters: float32")
.Output("ms: float32")
.Output("mom: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
.Output("parameters: float32")
.Output("ms: float32")
.Output("mom: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
.Input("parameters: float32")
.Input("ms: float32")
.Input("mom: float32")
.Input("mg: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
.Output("parameters: float32")
.Output("ms: float32")
.Output("mom: float32")
.Output("mg: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("weights: float32")
.Input("benefits: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("weights: float32")
.Output("benefits: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("updates: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("updates: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("updates: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kAdadelta,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("updates: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kAdadelta,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
.Input("parameters: float32")
.Input("accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalAdagrad,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
.Input("parameters: float32")
.Input("accumulators: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalAdagrad,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
.Output("parameters: float32")
.Output("accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalAdagrad,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
.Output("parameters: float32")
.Output("accumulators: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalAdagrad,
/*is_debug_op=*/true});
REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
.Input("parameters: float32")
.Input("v: float32")
.Input("m: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalYogi,
/*is_debug_op=*/false});
REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
.Input("parameters: float32")
.Input("v: float32")
.Input("m: float32")
.Input("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(LoadOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalYogi,
/*is_debug_op=*/true});
REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
.Output("parameters: float32")
.Output("v: float32")
.Output("m: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalYogi,
/*is_debug_op=*/false});
REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
.Output("parameters: float32")
.Output("v: float32")
.Output("m: float32")
.Output("gradient_accumulators: float32")
.Attr("table_id: int = -1")
.Attr("table_name: string = \"\"")
.Attr("num_shards: int")
.Attr("shard_id: int")
.Attr("config: string = \"\"")
.SetIsStateful()
.SetShapeFn(RetrieveOpShapeFunction{
/*alg=*/OptimizationAlgorithm::kProximalYogi,
/*is_debug_op=*/true});
} // namespace tpu
} // namespace tensorflow
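All of the registrations above share the same attribute set (table_id/table_name, num_shards, shard_id, config), so every per-table load follows the same calling pattern. A hedged per-shard sketch for the Adagrad variant, with made-up shapes and table name, might look like the following.

import tensorflow as tf

NUM_SHARDS = 4                   # illustrative; must match the embedding config
vocab_per_shard, dim = 256, 64

for shard_id in range(NUM_SHARDS):
    # In practice each call runs on the host that owns the shard and passes
    # that shard's slice of the checkpointed values; this loop is schematic.
    tf.raw_ops.LoadTPUEmbeddingAdagradParameters(
        parameters=tf.zeros([vocab_per_shard, dim]),
        accumulators=tf.fill([vocab_per_shard, dim], 0.1),
        num_shards=NUM_SHARDS,
        shard_id=shard_id,
        table_name="my_table",   # table_id is left at its default of -1
    )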

View File

@@ -58,69 +58,6 @@ namespace tensorflow {
// saving a checkpoint, the model must Retrieve the parameters back into the
// host CPU memory.
namespace {
void RegisterPerTableLoadAndRetrieveOps();
class RegisterPerTableLoadAndRetrieveOpsOnConstruction {
public:
RegisterPerTableLoadAndRetrieveOpsOnConstruction() {
RegisterPerTableLoadAndRetrieveOps();
}
};
// Object whose constructor does registrations.
RegisterPerTableLoadAndRetrieveOpsOnConstruction
register_per_table_load_and_retrieve_ops_var;
void RegisterPerTableLoadAndRetrieveOps() {
// Load ops
for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
bool internal;
TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal));
if (!internal) {
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, false,
op_reg_data);
});
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
op_reg_data);
});
}
}
}
// Retrieve ops
for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
bool internal;
TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal));
if (!internal) {
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody(
alg, false, op_reg_data);
});
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody(
alg, true, op_reg_data);
});
}
}
}
}
} // namespace
REGISTER_OP("RecvTPUEmbeddingActivations")
.Output("outputs: num_outputs * float32")
.Attr("num_outputs: int >= 1")

View File

@@ -143,24 +143,15 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
GradientAccumulationSupport* support) {
switch (alg) {
case OptimizationAlgorithm::kAdagrad:
*support = GradientAccumulationSupport::kSupported;
return Status::OK();
case OptimizationAlgorithm::kStochasticGradientDescent:
*support = GradientAccumulationSupport::kUnnecessary;
return Status::OK();
default: {
int auxiliary_parameter_count;
TF_RETURN_IF_ERROR(
GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
*support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
? GradientAccumulationSupport::kSupported
: GradientAccumulationSupport::kNotSupported;
return Status::OK();
}
}
int auxiliary_parameter_count;
TF_RETURN_IF_ERROR(
GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
*support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
? GradientAccumulationSupport::kSupported
: GradientAccumulationSupport::kNotSupported;
return Status::OK();
}
namespace {
// Make a normal state variable specification. Please refer to
// //tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -310,227 +301,102 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
};
}
Status RegisterPerTableLoadOpsForAlgorithmBody(
OptimizationAlgorithm alg, bool is_debug_op,
OpRegistrationData* op_reg_data) {
LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
bool is_debug_op)
: alg_(alg), is_debug_op_(is_debug_op) {}
Status LoadOpShapeFunction::operator()(
shape_inference::InferenceContext* c) const {
GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
std::vector<StateVariableSpecification> state_variable_specs;
TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
alg,
alg_,
grad_accum_support == GradientAccumulationSupport::kSupported &&
is_debug_op,
is_debug_op_,
&state_variable_specs));
auto* op_def = &op_reg_data->op_def;
op_def->set_name(
strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg),
"Parameters", (is_debug_op ? "GradAccumDebug" : "")));
// It is important for the order of the inputs to the op defined here
// to match the order in input_names because the indexes are used in
// the combining transformation.
for (const auto& parameter : state_variable_specs) {
if (parameter.has_user_defined() || is_debug_op) {
auto* arg = op_def->add_input_arg();
arg->set_name(parameter.name());
arg->set_type(DT_FLOAT);
}
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
{
auto* table_id_attr = op_def->add_attr();
table_id_attr->set_name("table_id");
table_id_attr->set_type("int");
table_id_attr->set_has_minimum(true);
table_id_attr->set_minimum(-1);
table_id_attr->mutable_default_value()->set_i(-1);
}
{
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
num_shards_attr->set_name("num_shards");
num_shards_attr->set_type("int");
}
{
auto* shard_id_attr = op_def->add_attr();
shard_id_attr->set_name("shard_id");
shard_id_attr->set_type("int");
}
{
auto* embedding_config_attr = op_def->add_attr();
embedding_config_attr->set_name("config");
embedding_config_attr->set_type("string");
embedding_config_attr->mutable_default_value()->set_s("");
}
string parameter_descriptions;
for (const auto& parameter : state_variable_specs) {
if (parameter.has_user_defined() || is_debug_op) {
strings::Appendf(&parameter_descriptions,
R"(
%s: A tensor containing the initial embedding table %s to use in embedding
lookups using the %s optimization algorithm.)",
parameter.name().c_str(), parameter.name().c_str(),
GetOptimizationAlgorithmFriendlyName(alg).c_str());
}
}
op_def->set_is_commutative(false);
op_def->set_is_aggregate(false);
op_def->set_is_stateful(true);
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
const int user_param_count =
std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
[&](const StateVariableSpecification& sv) {
return sv.has_user_defined() || is_debug_op;
});
std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
int input_index = 0;
for (int i = 0; i < state_variable_specs.size(); ++i) {
if (state_variable_specs[i].has_user_defined() || is_debug_op) {
std::vector<shape_inference::ShapeHandle> input_temp;
TF_RETURN_IF_ERROR(
c->input(state_variable_specs[i].name(), &input_temp));
if (input_temp.size() != 1) {
return errors::InvalidArgument("each input to be rank 1");
}
inputs[input_index] = input_temp[0];
++input_index;
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
const int user_param_count =
std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
[&](const StateVariableSpecification& sv) {
return sv.has_user_defined() || is_debug_op_;
});
std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
int input_index = 0;
for (int i = 0; i < state_variable_specs.size(); ++i) {
if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
std::vector<shape_inference::ShapeHandle> input_temp;
TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
if (input_temp.size() != 1) {
return errors::InvalidArgument("each input to be rank 1");
}
inputs[input_index] = input_temp[0];
++input_index;
}
// Verify shapes have rank 2 and are compatible when they are
// required to be valid.
shape_inference::ShapeHandle parameter_shape;
TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
for (int j = 1; j < user_param_count; ++j) {
shape_inference::ShapeHandle accumulator_j_shape;
TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
shape_inference::ShapeHandle merged;
TF_RETURN_IF_ERROR(
c->Merge(parameter_shape, accumulator_j_shape, &merged));
}
return Status::OK();
};
op_reg_data->shape_inference_fn = shape_inference_function;
}
// Verify shapes have rank 2 and are compatible when they are
// required to be valid.
shape_inference::ShapeHandle parameter_shape;
TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
for (int j = 1; j < user_param_count; ++j) {
shape_inference::ShapeHandle accumulator_j_shape;
TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
shape_inference::ShapeHandle merged;
TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
}
return Status::OK();
}
Status RegisterPerTableRetrieveOpsForAlgorithmBody(
OptimizationAlgorithm alg, bool is_debug_op,
OpRegistrationData* op_reg_data) {
RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
bool is_debug_op)
: alg_(alg), is_debug_op_(is_debug_op) {}
Status RetrieveOpShapeFunction::operator()(
shape_inference::InferenceContext* c) const {
GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
std::vector<StateVariableSpecification> state_variable_specs;
TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
alg,
alg_,
grad_accum_support == GradientAccumulationSupport::kSupported &&
is_debug_op,
is_debug_op_,
&state_variable_specs));
auto* op_def = &op_reg_data->op_def;
op_def->set_name(
strings::StrCat("RetrieveTPUEmbedding", GetOptimizationAlgorithmName(alg),
"Parameters", (is_debug_op ? "GradAccumDebug" : "")));
// It is important for the order of the outputs of the op defined here
// to match the order in output_names because the indexes are used in
// the combining transformation.
for (const auto& parameter : state_variable_specs) {
if (parameter.has_user_defined() || is_debug_op) {
auto* arg = op_def->add_output_arg();
arg->set_name(parameter.name());
arg->set_type(DT_FLOAT);
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
for (int j = 0; j < state_variable_specs.size(); ++j) {
if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
auto shape = c->MakeShape(
std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
TF_RETURN_IF_ERROR(
c->set_output(state_variable_specs[j].name(),
std::vector<shape_inference::ShapeHandle>(1, shape)));
}
}
{
auto* table_id_attr = op_def->add_attr();
table_id_attr->set_name("table_id");
table_id_attr->set_type("int");
table_id_attr->set_has_minimum(true);
table_id_attr->set_minimum(-1);
table_id_attr->mutable_default_value()->set_i(-1);
}
{
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
num_shards_attr->set_name("num_shards");
num_shards_attr->set_type("int");
}
{
auto* shard_id_attr = op_def->add_attr();
shard_id_attr->set_name("shard_id");
shard_id_attr->set_type("int");
}
{
auto* embedding_config_attr = op_def->add_attr();
embedding_config_attr->set_name("config");
embedding_config_attr->set_type("string");
embedding_config_attr->mutable_default_value()->set_s("");
}
string parameter_descriptions;
for (const auto& param : state_variable_specs) {
if (param.has_user_defined() || is_debug_op) {
strings::Appendf(&parameter_descriptions,
R"(
%s: A tensor containing the embedding table %s to store with the
parameters from embedding updates using the %s optimization algorithm.)",
param.name().c_str(), param.name().c_str(),
GetOptimizationAlgorithmFriendlyName(alg).c_str());
}
}
op_def->set_is_commutative(false);
op_def->set_is_aggregate(false);
op_def->set_is_stateful(true);
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
// Exactly one must be non-default.
if ((table_id >= 0) == (!table_name.empty())) {
return errors::InvalidArgument(
"exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
int shard_id;
TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
for (int j = 0; j < state_variable_specs.size(); ++j) {
if (state_variable_specs[j].has_user_defined() || is_debug_op) {
auto shape = c->MakeShape(
std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
TF_RETURN_IF_ERROR(
c->set_output(state_variable_specs[j].name(),
std::vector<shape_inference::ShapeHandle>(1, shape)));
}
}
return Status::OK();
};
op_reg_data->shape_inference_fn = shape_inference_function;
return Status::OK();
}
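Two checks in the shape functions above are easy to trip over from Python: exactly one of table_id/table_name may be non-default, and every tensor passed to a load op must be rank 2 with a shape compatible with the parameters. A hedged illustration (the op choice and table name are only examples):

import tensorflow as tf

params = tf.zeros([1024, 64])    # rank-2: [vocab, dim]
accums = tf.zeros([1024, 64])    # must be rank 2 and match params

# Valid: only table_name is set; both inputs are rank-2 and compatible.
tf.raw_ops.LoadTPUEmbeddingAdagradParameters(
    parameters=params, accumulators=accums,
    num_shards=1, shard_id=0, table_name="my_table")

# Invalid: setting both table_id and table_name fails shape inference with
# "exactly one of table_id or table_name must be non-default".
# tf.raw_ops.LoadTPUEmbeddingAdagradParameters(
#     parameters=params, accumulators=accums,
#     num_shards=1, shard_id=0, table_id=0, table_name="my_table")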

View File

@@ -41,9 +41,6 @@ enum class GradientAccumulationSupport {
// Accumulation cannot be used with this optimizer.
kNotSupported,
// Accumulation is unnecessary because optimizer application is commutative.
kUnnecessary,
// Accumulation is allowed and changes optimizer behavior.
kSupported,
};
@@ -88,23 +85,45 @@ inline float GradientAccumulatorInitialValue() {
return absl::bit_cast<float, uint32>(1);
}
// Computes registration data for per table load Op. Each load Op transfers
// the embedding parameters from the host memory to the TPU memory.
Status RegisterPerTableLoadOpsForAlgorithmBody(OptimizationAlgorithm alg,
bool is_debug_op,
OpRegistrationData *op_reg_data);
// Computes registration data for per table retrieve Op. Each retrieve Op
// transfers the embedding parameters from the TPU memory to the host memory.
Status RegisterPerTableRetrieveOpsForAlgorithmBody(
OptimizationAlgorithm alg, bool is_debug_op,
OpRegistrationData *op_reg_data);
// Returns whether an optimization algorithm is only supported internally.
// Returns an error if the algorithm is not recognized at all.
Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
bool *internal);
// Generic shape function for per-optimization-algorithm load ops.
class LoadOpShapeFunction {
public:
// Constructor.
LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
// Computes resulting shape and does parameter checking.
Status operator()(shape_inference::InferenceContext *c) const;
private:
// Optimization algorithm.
const OptimizationAlgorithm alg_;
// Whether this op has an extra parameter for the gradient accumulators.
const bool is_debug_op_;
};
// Generic shape function for per-optimization-algorithm retrieve ops.
class RetrieveOpShapeFunction {
public:
// Constructor.
RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
// Computes resulting shape and does parameter checking.
Status operator()(shape_inference::InferenceContext *c) const;
private:
// Optimization algorithm.
const OptimizationAlgorithm alg_;
// Whether this op has an extra parameter for the gradient accumulators.
const bool is_debug_op_;
};
} // namespace tpu
} // namespace tensorflow

View File

@@ -2992,6 +2992,7 @@ tf_gen_op_wrapper_private_py(
deps = [
"//tensorflow/core:tpu_configuration_ops_op_lib",
"//tensorflow/core:tpu_cross_replica_ops_op_lib",
"//tensorflow/core:tpu_embedding_load_retrieve_ops_op_lib",
"//tensorflow/core:tpu_embedding_ops_op_lib",
"//tensorflow/core:tpu_functional_ops_op_lib",
"//tensorflow/core:tpu_heartbeat_ops_op_lib",

View File

@@ -119,6 +119,7 @@ from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import gen_rnn_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gen_tpu_ops
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train

View File

@@ -548,10 +548,6 @@ tf_module {
name: "python_io"
mtype: "<type \'module\'>"
}
member {
name: "pywrap_tensorflow"
mtype: "<type \'module\'>"
}
member {
name: "qint16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"

View File

@@ -2092,6 +2092,10 @@ tf_module {
name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
argspec: "args=[\'parameters\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "Log"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -3636,6 +3640,10 @@ tf_module {
name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "Reverse"
argspec: "args=[\'tensor\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -2092,6 +2092,10 @@ tf_module {
name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
argspec: "args=[\'parameters\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "Log"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -3636,6 +3640,10 @@ tf_module {
name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "Reverse"
argspec: "args=[\'tensor\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "