Allow TPU embedding utility ops to take OptimizationParameters rather than OptimizationAlgorithm.
PiperOrigin-RevId: 333120828
Change-Id: I959f8f116293df5e2130654a073581219cca1fec

Commit: 2c7d978a14
Parent: a151f928f3
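The same mechanical change repeats across every registration in the diff below: the shape functors lose their per-algorithm constructor arguments and are default-constructed, since shape checking now works purely from the op's declared inputs and outputs. A condensed before/after sketch of that pattern (Adagrad load op; unrelated inputs and attrs elided):

// Before: each registration spelled out its algorithm and debug variant.
REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
    // ... inputs and attrs elided ...
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
                                    /*is_debug_op=*/false});

// After: the functor is default-constructed; its operator() validates shapes
// through the InferenceContext (c->input(j), c->num_inputs()) instead.
REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
    // ... inputs and attrs elided ...
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());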
@@ -37,8 +37,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -50,8 +49,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Output("parameters: float32")
@@ -62,8 +60,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -75,8 +72,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Input("parameters: float32")
@@ -86,9 +82,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -99,9 +93,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Output("parameters: float32")
@@ -111,9 +103,7 @@ REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP(
     "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
@@ -125,9 +115,7 @@ REGISTER_OP(
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Input("parameters: float32")
@@ -139,8 +127,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -153,8 +140,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Output("parameters: float32")
@@ -166,8 +152,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -180,8 +165,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Input("parameters: float32")
@@ -193,8 +177,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -207,8 +190,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Output("parameters: float32")
@@ -220,8 +202,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -234,8 +215,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Input("parameters: float32")
@@ -246,8 +226,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -259,8 +238,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Output("parameters: float32")
@@ -271,9 +249,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -285,9 +261,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Input("parameters: float32")
@@ -299,8 +273,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -313,8 +286,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Output("parameters: float32")
@@ -326,8 +298,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -340,8 +311,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Input("parameters: float32")
@@ -354,9 +324,7 @@ REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Output("parameters: float32")
@@ -369,9 +337,7 @@ REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Input("parameters: float32")
@@ -384,9 +350,7 @@ REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Output("parameters: float32")
@@ -399,9 +363,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Input("parameters: float32")
@@ -413,8 +375,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -427,8 +388,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Output("parameters: float32")
@@ -440,9 +400,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -455,9 +413,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Input("parameters: float32")
@@ -468,9 +424,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -482,9 +436,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Output("parameters: float32")
@@ -495,9 +447,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -509,9 +459,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Input("parameters: float32")
@@ -523,9 +471,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -538,9 +484,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Output("parameters: float32")
@@ -552,9 +496,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -567,9 +509,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 } // namespace tpu
 } // namespace tensorflow
@@ -94,8 +94,9 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
-  switch (alg) {
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
+                                      int* count) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad:
       *count = 1;
       return Status::OK();
@@ -141,11 +142,11 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
   return errors::InvalidArgument("No optimization algorithm specified");
 }
 
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                       GradientAccumulationSupport* support) {
   int auxiliary_parameter_count;
   TF_RETURN_IF_ERROR(
-      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
   *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                  ? GradientAccumulationSupport::kSupported
                  : GradientAccumulationSupport::kNotSupported;
@@ -168,7 +169,7 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 } // namespace
 
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
   // The first parameter set is always the weights themselves.
   state_variables->push_back(
@@ -176,7 +177,7 @@ Status GetOptimizationAlgorithmStateVariables(
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
-  switch (alg) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
       state_variables->push_back(
           MakeStandardStateVariableSpecification("accumulators", 0.1));
@@ -276,7 +277,8 @@ Status GetOptimizationAlgorithmStateVariables(
   }
   if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
     return errors::InvalidArgument(
-        "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+        "Optimization algorithm",
+        GetOptimizationAlgorithmName(params.parameters_case()),
         "does not support gradient accumulation because it "
        "already has too many other accumulators");
   }
@@ -301,21 +303,8 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
   };
 }
 
-LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
-                                         bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status LoadOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -329,52 +318,23 @@ Status LoadOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  const int user_param_count =
-      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
-                    [&](const StateVariableSpecification& sv) {
-                      return sv.has_user_defined() || is_debug_op_;
-                    });
-  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
-  int input_index = 0;
-  for (int i = 0, end = state_variable_specs.size(); i < end; ++i) {
-    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
-      std::vector<shape_inference::ShapeHandle> input_temp;
-      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
-      if (input_temp.size() != 1) {
-        return errors::InvalidArgument("each input to be rank 1");
-      }
-      inputs[input_index] = input_temp[0];
-      ++input_index;
-    }
-  }
+
   // Verify shapes have rank 2 and are compatible when they are
   // required to be valid.
   shape_inference::ShapeHandle parameter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
-  for (int j = 1; j < user_param_count; ++j) {
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
+  for (int j = 1; j < c->num_inputs(); ++j) {
     shape_inference::ShapeHandle accumulator_j_shape;
-    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
     shape_inference::ShapeHandle merged;
     TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
   }
+
   return Status::OK();
 }
 
-RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
-                                                 bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status RetrieveOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -388,14 +348,9 @@ Status RetrieveOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  for (int j = 0, end = state_variable_specs.size(); j < end; ++j) {
-    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
-      auto shape = c->MakeShape(
-          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
-      TF_RETURN_IF_ERROR(
-          c->set_output(state_variable_specs[j].name(),
-                        std::vector<shape_inference::ShapeHandle>(1, shape)));
-    }
+  for (int j = 0; j < c->num_outputs(); ++j) {
+    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
+                         2, c->UnknownDim())));
   }
   return Status::OK();
 }
@@ -48,18 +48,19 @@ enum class GradientAccumulationSupport {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
+                                      int *count);
 
 // Returns whether (and how) an optimization algorithm supports gradient
 // accumulation.
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                       GradientAccumulationSupport *support);
 
 // Returns the parameter specifications for the optimization algorithm (the main
 // parameters first, followed by any auxiliary parameters such as Adagrad
 // accumulators).
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters &params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification> *state_variables);
 
 // Maximum value of auxiliar_parameter_count for any optimization algorithm.
@@ -93,35 +94,15 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
 // Generic shape function for per-optimization-algorithm load ops.
 class LoadOpShapeFunction {
  public:
-  // Constructor.
-  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 // Generic shape function for per-optimization-algorithm retrieve ops.
 class RetrieveOpShapeFunction {
  public:
-  // Constructor.
-  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 } // namespace tpu
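For reference, with the signatures in the header hunk above, a hypothetical caller now hands over the OptimizationParameters proto itself and the utilities switch on params.parameters_case() internally. A minimal sketch (ExampleCaller and the mutable_adagrad() generated proto accessor are illustrative assumptions, not part of this change):

Status ExampleCaller() {
  OptimizationParameters params;
  params.mutable_adagrad();  // selects the Adagrad case of the proto's oneof

  // Ask whether the configured algorithm leaves room for a gradient
  // accumulation buffer.
  GradientAccumulationSupport support;
  TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));

  // Enumerate the weights plus auxiliary state variables for this algorithm.
  std::vector<StateVariableSpecification> state_variables;
  TF_RETURN_IF_ERROR(GetOptimizationAlgorithmStateVariables(
      params,
      /*use_gradient_accumulation=*/support ==
          GradientAccumulationSupport::kSupported,
      &state_variables));
  return Status::OK();
}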