Changed TPU embedding load and retrieve ops to checked-in generated code.
PiperOrigin-RevId: 301887553 Change-Id: Ib6e042e73cd4a0214239175a4e86b090a0817f12
This commit is contained in:
parent f29c62f405
commit 4c0d6b7d51

Changed paths (tensorflow/):
  cc
  core
    BUILD
    api_def/base_api
      api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt
      api_def_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt
    ops
    tpu
  python
  tools/api/golden
@@ -632,6 +632,7 @@ tf_gen_op_wrappers_cc(
        "tpu_configuration_ops",
        "tpu_cross_replica_ops",
        "tpu_embedding_ops",
        "tpu_embedding_load_retrieve_ops",
        "tpu_functional_ops",
        "tpu_heartbeat_ops",
        "tpu_host_compute_ops",
@@ -723,6 +723,7 @@ tf_gen_op_libs(
        "tpu_configuration_ops",
        "tpu_cross_replica_ops",
        "tpu_embedding_ops",
        "tpu_embedding_load_retrieve_ops",
        "tpu_functional_ops",
        "tpu_heartbeat_ops",
        "tpu_host_compute_ops",
@@ -735,6 +736,7 @@ tf_gen_op_libs(
        ":lib",
        ":lib_proto_parsing",
        ":protos_all_cc",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils",
        "//tensorflow/core/tpu:tpu_embedding_output_layout_utils",
@@ -894,6 +896,7 @@ cc_library(
        ":tpu_configuration_ops_op_lib",
        ":tpu_cross_replica_ops_op_lib",
        ":tpu_embedding_ops_op_lib",
        ":tpu_embedding_load_retrieve_ops_op_lib",
        ":tpu_functional_ops_op_lib",
        ":tpu_heartbeat_ops_op_lib",
        ":tpu_host_compute_ops_op_lib",
tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt (new file)
@@ -0,0 +1,24 @@
op {
  graph_op_name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
  visibility: HIDDEN
  in_arg {
    name: "parameters"
    description: <<END
Value of parameters used in the stochastic gradient descent optimization algorithm.
END
  }
  in_arg {
    name: "gradient_accumulators"
    description: <<END
Value of gradient_accumulators used in the Adadelta optimization algorithm.
END
  }
  summary: "Load SGD embedding parameters."
  description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}
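These raw ops are surfaced through the generated Python wrappers, as the tf.raw_ops golden updates later in this diff show. A minimal usage sketch (not part of this commit; the table name, shapes, and values are hypothetical, and the call only succeeds on a TPU worker whose embedding tables were already set up by ConfigureTPUEmbeddingHost):

import tensorflow as tf

# Hypothetical shard of an embedding table plus its debug gradient accumulators.
params = tf.zeros([1024, 64], dtype=tf.float32)
grad_accums = tf.zeros_like(params)

# Install checkpointed values into TPU HBM before the training loop runs.
tf.raw_ops.LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(
    parameters=params,
    gradient_accumulators=grad_accums,
    num_shards=1,
    shard_id=0,
    table_name="my_table")  # exactly one of table_name / table_id may be non-default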
tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.pbtxt (new file)
@@ -0,0 +1,23 @@
op {
  graph_op_name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
  visibility: HIDDEN
  out_arg {
    name: "parameters"
    description: <<END
Parameter parameters updated by the stochastic gradient descent optimization algorithm.
END
  }
  out_arg {
    name: "gradient_accumulators"
    description: <<END
Parameter gradient_accumulators updated by the Adadelta optimization algorithm.
END
  }
  summary: "Retrieve SGD embedding parameters with debug support."
  description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}
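The matching retrieve op copies the updated values back to host memory, typically right before a checkpoint is written. A sketch under the same hypothetical setup as the load example above:

# Returns one float32 tensor per state variable, shaped [vocab_size, embedding_dim].
result = tf.raw_ops.RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(
    num_shards=1,
    shard_id=0,
    table_name="my_table")
parameters, gradient_accumulators = result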
tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc (new file, 575 lines)
@@ -0,0 +1,575 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Produced by generate_tpu_embedding_load_retrieve_ops.py (Google-internal).

#include <string>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

namespace tensorflow {
namespace tpu {

using OptimizationAlgorithm = OptimizationParameters::ParametersCase;

REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
                                        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
                                        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
    .Input("parameters: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
        /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
        /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
    .Output("parameters: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
        /*is_debug_op=*/false});

REGISTER_OP(
    "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("linears: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("linears: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("linears: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
                                        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("linears: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
                                        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingADAMParameters")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Input("velocities: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Input("velocities: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Output("velocities: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
                                        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Output("velocities: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
                                        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kMomentum,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kMomentum,
        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
    .Input("parameters: float32")
    .Input("ms: float32")
    .Input("mom: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("ms: float32")
    .Input("mom: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
    .Output("parameters: float32")
    .Output("ms: float32")
    .Output("mom: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
                                        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("ms: float32")
    .Output("mom: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
                                        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
    .Input("parameters: float32")
    .Input("ms: float32")
    .Input("mom: float32")
    .Input("mg: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
    .Output("parameters: float32")
    .Output("ms: float32")
    .Output("mom: float32")
    .Output("mg: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
        /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("weights: float32")
    .Input("benefits: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("weights: float32")
    .Output("benefits: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
        /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("updates: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
                                    /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("updates: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
                                    /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("updates: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kAdadelta,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("updates: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kAdadelta,
        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
        /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
        /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
        /*is_debug_op=*/true});

REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
    .Input("parameters: float32")
    .Input("v: float32")
    .Input("m: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalYogi,
        /*is_debug_op=*/false});

REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("v: float32")
    .Input("m: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalYogi,
        /*is_debug_op=*/true});

REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
    .Output("parameters: float32")
    .Output("v: float32")
    .Output("m: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalYogi,
        /*is_debug_op=*/false});

REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("v: float32")
    .Output("m: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction{
        /*alg=*/OptimizationAlgorithm::kProximalYogi,
        /*is_debug_op=*/true});

}  // namespace tpu
}  // namespace tensorflow
@@ -58,69 +58,6 @@ namespace tensorflow {
// saving a checkpoint, the model must Retrieve the parameters back into the
// host CPU memory.

namespace {

void RegisterPerTableLoadAndRetrieveOps();

class RegisterPerTableLoadAndRetrieveOpsOnConstruction {
 public:
  RegisterPerTableLoadAndRetrieveOpsOnConstruction() {
    RegisterPerTableLoadAndRetrieveOps();
  }
};

// Object whose constructor does registrations.
RegisterPerTableLoadAndRetrieveOpsOnConstruction
    register_per_table_load_and_retrieve_ops_var;

void RegisterPerTableLoadAndRetrieveOps() {
  // Load ops
  for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
    bool internal;
    TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal));
    if (!internal) {
      OpRegistry::Global()->Register(
          [alg](OpRegistrationData* op_reg_data) -> Status {
            return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, false,
                                                                op_reg_data);
          });
      tpu::GradientAccumulationSupport grad_accum_support;
      TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
      if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
        OpRegistry::Global()->Register(
            [alg](OpRegistrationData* op_reg_data) -> Status {
              return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
                                                                  op_reg_data);
            });
      }
    }
  }

  // Retrieve ops
  for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
    bool internal;
    TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal));
    if (!internal) {
      OpRegistry::Global()->Register(
          [alg](OpRegistrationData* op_reg_data) -> Status {
            return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody(
                alg, false, op_reg_data);
          });
      tpu::GradientAccumulationSupport grad_accum_support;
      TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
      if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
        OpRegistry::Global()->Register(
            [alg](OpRegistrationData* op_reg_data) -> Status {
              return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody(
                  alg, true, op_reg_data);
            });
      }
    }
  }
}

}  // namespace

REGISTER_OP("RecvTPUEmbeddingActivations")
    .Output("outputs: num_outputs * float32")
    .Attr("num_outputs: int >= 1")
@@ -143,24 +143,15 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {

Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
                                      GradientAccumulationSupport* support) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      *support = GradientAccumulationSupport::kSupported;
      return Status::OK();
    case OptimizationAlgorithm::kStochasticGradientDescent:
      *support = GradientAccumulationSupport::kUnnecessary;
      return Status::OK();
    default: {
      int auxiliary_parameter_count;
      TF_RETURN_IF_ERROR(
          GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
      *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                     ? GradientAccumulationSupport::kSupported
                     : GradientAccumulationSupport::kNotSupported;
      return Status::OK();
    }
  }
  int auxiliary_parameter_count;
  TF_RETURN_IF_ERROR(
      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
  *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                 ? GradientAccumulationSupport::kSupported
                 : GradientAccumulationSupport::kNotSupported;
  return Status::OK();
}

namespace {
// Make a normal state variable specification. Please refer to
// //tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -310,227 +301,102 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
  };
}

Status RegisterPerTableLoadOpsForAlgorithmBody(
    OptimizationAlgorithm alg, bool is_debug_op,
    OpRegistrationData* op_reg_data) {
LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
                                         bool is_debug_op)
    : alg_(alg), is_debug_op_(is_debug_op) {}

Status LoadOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  GradientAccumulationSupport grad_accum_support;
  TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));

  std::vector<StateVariableSpecification> state_variable_specs;
  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
      alg,
      alg_,
      grad_accum_support == GradientAccumulationSupport::kSupported &&
          is_debug_op,
          is_debug_op_,
      &state_variable_specs));
  auto* op_def = &op_reg_data->op_def;
  op_def->set_name(
      strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg),
                      "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
  // It is important for the order of the inputs to the op defined here
  // to match the order in input_names because the indexes are used in
  // the combining transformation.
  for (const auto& parameter : state_variable_specs) {
    if (parameter.has_user_defined() || is_debug_op) {
      auto* arg = op_def->add_input_arg();
      arg->set_name(parameter.name());
      arg->set_type(DT_FLOAT);
    }
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  {
    auto* table_id_attr = op_def->add_attr();
    table_id_attr->set_name("table_id");
    table_id_attr->set_type("int");
    table_id_attr->set_has_minimum(true);
    table_id_attr->set_minimum(-1);
    table_id_attr->mutable_default_value()->set_i(-1);
  }
  {
    auto* table_name_attr = op_def->add_attr();
    table_name_attr->set_name("table_name");
    table_name_attr->set_type("string");
    table_name_attr->mutable_default_value()->set_s("");
  }
  {
    auto* num_shards_attr = op_def->add_attr();
    num_shards_attr->set_name("num_shards");
    num_shards_attr->set_type("int");
  }
  {
    auto* shard_id_attr = op_def->add_attr();
    shard_id_attr->set_name("shard_id");
    shard_id_attr->set_type("int");
  }
  {
    auto* embedding_config_attr = op_def->add_attr();
    embedding_config_attr->set_name("config");
    embedding_config_attr->set_type("string");
    embedding_config_attr->mutable_default_value()->set_s("");
  }
  string parameter_descriptions;
  for (const auto& parameter : state_variable_specs) {
    if (parameter.has_user_defined() || is_debug_op) {
      strings::Appendf(&parameter_descriptions,
                       R"(
%s: A tensor containing the initial embedding table %s to use in embedding
lookups using the %s optimization algorithm.)",
                       parameter.name().c_str(), parameter.name().c_str(),
                       GetOptimizationAlgorithmFriendlyName(alg).c_str());
    }
  }
  op_def->set_is_commutative(false);
  op_def->set_is_aggregate(false);
  op_def->set_is_stateful(true);
  auto shape_inference_function =
      [state_variable_specs,
       is_debug_op](shape_inference::InferenceContext* c) -> Status {
    int table_id;
    TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
    string table_name;
    TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
    // Exactly one must be non-default.
    if ((table_id >= 0) == (!table_name.empty())) {
      return errors::InvalidArgument(
          "exactly one of table_id or table_name must be non-default");
    }
    int num_shards;
    TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
    int shard_id;
    TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
    const int user_param_count =
        std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
                      [&](const StateVariableSpecification& sv) {
                        return sv.has_user_defined() || is_debug_op;
                      });
    std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
    int input_index = 0;
    for (int i = 0; i < state_variable_specs.size(); ++i) {
      if (state_variable_specs[i].has_user_defined() || is_debug_op) {
        std::vector<shape_inference::ShapeHandle> input_temp;
        TF_RETURN_IF_ERROR(
            c->input(state_variable_specs[i].name(), &input_temp));
        if (input_temp.size() != 1) {
          return errors::InvalidArgument("each input to be rank 1");
        }
        inputs[input_index] = input_temp[0];
        ++input_index;
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  const int user_param_count =
      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
                    [&](const StateVariableSpecification& sv) {
                      return sv.has_user_defined() || is_debug_op_;
                    });
  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
  int input_index = 0;
  for (int i = 0; i < state_variable_specs.size(); ++i) {
    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
      std::vector<shape_inference::ShapeHandle> input_temp;
      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
      if (input_temp.size() != 1) {
        return errors::InvalidArgument("each input to be rank 1");
      }
      inputs[input_index] = input_temp[0];
      ++input_index;
    }
    // Verify shapes have rank 2 and are compatible when they are
    // required to be valid.
    shape_inference::ShapeHandle parameter_shape;
    TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
    for (int j = 1; j < user_param_count; ++j) {
      shape_inference::ShapeHandle accumulator_j_shape;
      TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
      shape_inference::ShapeHandle merged;
      TF_RETURN_IF_ERROR(
          c->Merge(parameter_shape, accumulator_j_shape, &merged));
    }
    return Status::OK();
  };
  op_reg_data->shape_inference_fn = shape_inference_function;
  }
  // Verify shapes have rank 2 and are compatible when they are
  // required to be valid.
  shape_inference::ShapeHandle parameter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
  for (int j = 1; j < user_param_count; ++j) {
    shape_inference::ShapeHandle accumulator_j_shape;
    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
  }
  return Status::OK();
}

Status RegisterPerTableRetrieveOpsForAlgorithmBody(
    OptimizationAlgorithm alg, bool is_debug_op,
    OpRegistrationData* op_reg_data) {
RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
                                                 bool is_debug_op)
    : alg_(alg), is_debug_op_(is_debug_op) {}

Status RetrieveOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  GradientAccumulationSupport grad_accum_support;
  TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));

  std::vector<StateVariableSpecification> state_variable_specs;
  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
      alg,
      alg_,
      grad_accum_support == GradientAccumulationSupport::kSupported &&
          is_debug_op,
          is_debug_op_,
      &state_variable_specs));

  auto* op_def = &op_reg_data->op_def;
  op_def->set_name(
      strings::StrCat("RetrieveTPUEmbedding", GetOptimizationAlgorithmName(alg),
                      "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
  // It is important for the order of the outputs of the op defined here
  // to match the order in output_names because the indexes are used in
  // the combining transformation.
  for (const auto& parameter : state_variable_specs) {
    if (parameter.has_user_defined() || is_debug_op) {
      auto* arg = op_def->add_output_arg();
      arg->set_name(parameter.name());
      arg->set_type(DT_FLOAT);
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  for (int j = 0; j < state_variable_specs.size(); ++j) {
    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
      auto shape = c->MakeShape(
          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
      TF_RETURN_IF_ERROR(
          c->set_output(state_variable_specs[j].name(),
                        std::vector<shape_inference::ShapeHandle>(1, shape)));
    }
  }
  {
    auto* table_id_attr = op_def->add_attr();
    table_id_attr->set_name("table_id");
    table_id_attr->set_type("int");
    table_id_attr->set_has_minimum(true);
    table_id_attr->set_minimum(-1);
    table_id_attr->mutable_default_value()->set_i(-1);
  }
  {
    auto* table_name_attr = op_def->add_attr();
    table_name_attr->set_name("table_name");
    table_name_attr->set_type("string");
    table_name_attr->mutable_default_value()->set_s("");
  }
  {
    auto* num_shards_attr = op_def->add_attr();
    num_shards_attr->set_name("num_shards");
    num_shards_attr->set_type("int");
  }
  {
    auto* shard_id_attr = op_def->add_attr();
    shard_id_attr->set_name("shard_id");
    shard_id_attr->set_type("int");
  }
  {
    auto* embedding_config_attr = op_def->add_attr();
    embedding_config_attr->set_name("config");
    embedding_config_attr->set_type("string");
    embedding_config_attr->mutable_default_value()->set_s("");
  }
  string parameter_descriptions;
  for (const auto& param : state_variable_specs) {
    if (param.has_user_defined() || is_debug_op) {
      strings::Appendf(&parameter_descriptions,
                       R"(
%s: A tensor containing the embedding table %s to store with the
parameters from embedding updates using the %s optimization algorithm.)",
                       param.name().c_str(), param.name().c_str(),
                       GetOptimizationAlgorithmFriendlyName(alg).c_str());
    }
  }
  op_def->set_is_commutative(false);
  op_def->set_is_aggregate(false);
  op_def->set_is_stateful(true);
  auto shape_inference_function =
      [state_variable_specs,
       is_debug_op](shape_inference::InferenceContext* c) -> Status {
    int table_id;
    TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
    string table_name;
    TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
    // Exactly one must be non-default.
    if ((table_id >= 0) == (!table_name.empty())) {
      return errors::InvalidArgument(
          "exactly one of table_id or table_name must be non-default");
    }
    int num_shards;
    TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
    int shard_id;
    TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
    for (int j = 0; j < state_variable_specs.size(); ++j) {
      if (state_variable_specs[j].has_user_defined() || is_debug_op) {
        auto shape = c->MakeShape(
            std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
        TF_RETURN_IF_ERROR(
            c->set_output(state_variable_specs[j].name(),
                          std::vector<shape_inference::ShapeHandle>(1, shape)));
      }
    }
    return Status::OK();
  };
  op_reg_data->shape_inference_fn = shape_inference_function;
  return Status::OK();
}

@@ -41,9 +41,6 @@ enum class GradientAccumulationSupport {
  // Accumulation cannot be used with this optimizer.
  kNotSupported,

  // Accumulation is unnecessary because optimizer application is commutative.
  kUnnecessary,

  // Accumulation is allowed and changes optimizer behavior.
  kSupported,
};
@@ -88,23 +85,45 @@ inline float GradientAccumulatorInitialValue() {
  return absl::bit_cast<float, uint32>(1);
}

// Computes registration data for per table load Op. Each load Op transfers
// the embedding parameters from the host memory to the TPU memory.
Status RegisterPerTableLoadOpsForAlgorithmBody(OptimizationAlgorithm alg,
                                               bool is_debug_op,
                                               OpRegistrationData *op_reg_data);

// Computes registration data for per table retrieve Op. Each retrieve Op
// transfers the embedding parameters from the TPU memory to the host memory.
Status RegisterPerTableRetrieveOpsForAlgorithmBody(
    OptimizationAlgorithm alg, bool is_debug_op,
    OpRegistrationData *op_reg_data);

// Returns whether an optimization algorithm is only supported internally.
// Returns an error if the algorithm is not recognized at all.
Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
                                       bool *internal);

// Generic shape function for per-optimization-algorithm load ops.
class LoadOpShapeFunction {
 public:
  // Constructor.
  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);

  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;

 private:
  // Optimization algorithm.
  const OptimizationAlgorithm alg_;

  // Whether this op has an extra parameter for the gradient accumulators.
  const bool is_debug_op_;
};

// Generic shape function for per-optimization-algorithm retrieve ops.
class RetrieveOpShapeFunction {
 public:
  // Constructor.
  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);

  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;

 private:
  // Optimization algorithm.
  const OptimizationAlgorithm alg_;

  // Whether this op has an extra parameter for the gradient accumulators.
  const bool is_debug_op_;
};

}  // namespace tpu
}  // namespace tensorflow
@@ -2992,6 +2992,7 @@ tf_gen_op_wrapper_private_py(
    deps = [
        "//tensorflow/core:tpu_configuration_ops_op_lib",
        "//tensorflow/core:tpu_cross_replica_ops_op_lib",
        "//tensorflow/core:tpu_embedding_load_retrieve_ops_op_lib",
        "//tensorflow/core:tpu_embedding_ops_op_lib",
        "//tensorflow/core:tpu_functional_ops_op_lib",
        "//tensorflow/core:tpu_heartbeat_ops_op_lib",
@@ -119,6 +119,7 @@ from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import gen_rnn_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gen_tpu_ops

# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train
@@ -548,10 +548,6 @@ tf_module {
    name: "python_io"
    mtype: "<type \'module\'>"
  }
  member {
    name: "pywrap_tensorflow"
    mtype: "<type \'module\'>"
  }
  member {
    name: "qint16"
    mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
@@ -2092,6 +2092,10 @@ tf_module {
    name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
    argspec: "args=[\'parameters\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "Log"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -3636,6 +3640,10 @@ tf_module {
    name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "Reverse"
    argspec: "args=[\'tensor\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -2092,6 +2092,10 @@ tf_module {
    name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
    argspec: "args=[\'parameters\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "Log"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -3636,6 +3640,10 @@ tf_module {
    name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "Reverse"
    argspec: "args=[\'tensor\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "