Allow TPU embedding utility ops to take OptimizationParameters rather than OptimizationAlgorithm.
PiperOrigin-RevId: 333120828
Change-Id: I959f8f116293df5e2130654a073581219cca1fec
parent a151f928f3
commit 2c7d978a14
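At a high level, the change makes the per-algorithm load/retrieve shape functors stateless: the ops now register a default-constructed LoadOpShapeFunction() or RetrieveOpShapeFunction() instead of one carrying an OptimizationAlgorithm and an is_debug_op flag, and the shape functions read what they need from the InferenceContext (c->num_inputs(), c->num_outputs()). The sketch below is illustrative only and uses no TensorFlow types; FakeContext and LoadShapeCheck are made-up stand-ins for shape_inference::InferenceContext and the new stateless functor.

// Minimal, self-contained sketch (not TensorFlow code) of the stateless
// shape-check pattern adopted by this commit: the functor holds no state
// and derives the number of inputs from the context itself.
#include <cstdio>
#include <stdexcept>
#include <utility>
#include <vector>

struct Shape { int rows = 0; int cols = 0; };  // stand-in for ShapeHandle

class FakeContext {  // stand-in for shape_inference::InferenceContext
 public:
  explicit FakeContext(std::vector<Shape> inputs) : inputs_(std::move(inputs)) {}
  int num_inputs() const { return static_cast<int>(inputs_.size()); }
  const Shape& input(int i) const { return inputs_.at(i); }

 private:
  std::vector<Shape> inputs_;
};

// Stateless functor: no alg_ or is_debug_op_ members, mirroring the new
// LoadOpShapeFunction(); every accumulator must match the parameter shape.
struct LoadShapeCheck {
  void operator()(const FakeContext& c) const {
    const Shape& params = c.input(0);
    for (int j = 1; j < c.num_inputs(); ++j) {
      const Shape& acc = c.input(j);
      if (acc.rows != params.rows || acc.cols != params.cols) {
        throw std::invalid_argument("accumulator shape mismatch");
      }
    }
  }
};

int main() {
  // Parameters plus one auxiliary slot, both rank-2 and the same shape.
  FakeContext c({{1024, 64}, {1024, 64}});
  LoadShapeCheck()(c);  // default-constructed, like SetShapeFn(LoadOpShapeFunction())
  std::puts("shapes compatible");
  return 0;
}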
@@ -37,8 +37,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -50,8 +49,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Output("parameters: float32")
@@ -62,8 +60,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -75,8 +72,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Input("parameters: float32")
@@ -86,9 +82,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -99,9 +93,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Output("parameters: float32")
@@ -111,9 +103,7 @@ REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP(
     "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
@@ -125,9 +115,7 @@ REGISTER_OP(
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Input("parameters: float32")
@@ -139,8 +127,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -153,8 +140,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Output("parameters: float32")
@@ -166,8 +152,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -180,8 +165,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Input("parameters: float32")
@@ -193,8 +177,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -207,8 +190,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Output("parameters: float32")
@@ -220,8 +202,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -234,8 +215,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Input("parameters: float32")
@@ -246,8 +226,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -259,8 +238,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Output("parameters: float32")
@@ -271,9 +249,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -285,9 +261,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Input("parameters: float32")
@@ -299,8 +273,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -313,8 +286,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Output("parameters: float32")
@@ -326,8 +298,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -340,8 +311,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Input("parameters: float32")
@@ -354,9 +324,7 @@ REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Output("parameters: float32")
@@ -369,9 +337,7 @@ REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Input("parameters: float32")
@@ -384,9 +350,7 @@ REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Output("parameters: float32")
@@ -399,9 +363,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Input("parameters: float32")
@@ -413,8 +375,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -427,8 +388,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Output("parameters: float32")
@@ -440,9 +400,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -455,9 +413,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Input("parameters: float32")
@@ -468,9 +424,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -482,9 +436,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Output("parameters: float32")
@@ -495,9 +447,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -509,9 +459,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Input("parameters: float32")
@@ -523,9 +471,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -538,9 +484,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Output("parameters: float32")
@@ -552,9 +496,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());

 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -567,9 +509,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());

 } // namespace tpu
 } // namespace tensorflow
@@ -94,8 +94,9 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
-  switch (alg) {
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
+                                      int* count) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad:
       *count = 1;
       return Status::OK();
@@ -141,11 +142,11 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
   return errors::InvalidArgument("No optimization algorithm specified");
 }

-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                       GradientAccumulationSupport* support) {
   int auxiliary_parameter_count;
   TF_RETURN_IF_ERROR(
-      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
   *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                  ? GradientAccumulationSupport::kSupported
                  : GradientAccumulationSupport::kNotSupported;
@@ -168,7 +169,7 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 } // namespace

 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
   // The first parameter set is always the weights themselves.
   state_variables->push_back(
@@ -176,7 +177,7 @@ Status GetOptimizationAlgorithmStateVariables(
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
-  switch (alg) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
       state_variables->push_back(
           MakeStandardStateVariableSpecification("accumulators", 0.1));
@@ -276,7 +277,8 @@ Status GetOptimizationAlgorithmStateVariables(
   }
   if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
     return errors::InvalidArgument(
-        "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+        "Optimization algorithm",
+        GetOptimizationAlgorithmName(params.parameters_case()),
         "does not support gradient accumulation because it "
         "already has too many other accumulators");
   }
@@ -301,21 +303,8 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
   };
 }

-LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
-                                         bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status LoadOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -329,52 +318,23 @@ Status LoadOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  const int user_param_count =
-      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
-                    [&](const StateVariableSpecification& sv) {
-                      return sv.has_user_defined() || is_debug_op_;
-                    });
-  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
-  int input_index = 0;
-  for (int i = 0, end = state_variable_specs.size(); i < end; ++i) {
-    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
-      std::vector<shape_inference::ShapeHandle> input_temp;
-      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
-      if (input_temp.size() != 1) {
-        return errors::InvalidArgument("each input to be rank 1");
-      }
-      inputs[input_index] = input_temp[0];
-      ++input_index;
-    }
-  }
   // Verify shapes have rank 2 and are compatible when they are
   // required to be valid.
   shape_inference::ShapeHandle parameter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
-  for (int j = 1; j < user_param_count; ++j) {
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
+  for (int j = 1; j < c->num_inputs(); ++j) {
     shape_inference::ShapeHandle accumulator_j_shape;
-    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
     shape_inference::ShapeHandle merged;
     TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
   }

   return Status::OK();
 }

-RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
-                                                 bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status RetrieveOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -388,14 +348,9 @@ Status RetrieveOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  for (int j = 0, end = state_variable_specs.size(); j < end; ++j) {
-    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
-      auto shape = c->MakeShape(
-          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
-      TF_RETURN_IF_ERROR(
-          c->set_output(state_variable_specs[j].name(),
-                        std::vector<shape_inference::ShapeHandle>(1, shape)));
-    }
+  for (int j = 0; j < c->num_outputs(); ++j) {
+    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
+                         2, c->UnknownDim())));
   }
   return Status::OK();
 }
@@ -48,18 +48,19 @@ enum class GradientAccumulationSupport {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
+                                      int *count);

 // Returns whether (and how) an optimization algorithm supports gradient
 // accumulation.
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                       GradientAccumulationSupport *support);

 // Returns the parameter specifications for the optimization algorithm (the main
 // parameters first, followed by any auxiliary parameters such as Adagrad
 // accumulators).
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters &params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification> *state_variables);

 // Maximum value of auxiliar_parameter_count for any optimization algorithm.
@@ -93,35 +94,15 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
 // Generic shape function for per-optimization-algorithm load ops.
 class LoadOpShapeFunction {
  public:
-  // Constructor.
-  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };

 // Generic shape function for per-optimization-algorithm retrieve ops.
 class RetrieveOpShapeFunction {
  public:
-  // Constructor.
-  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };

 } // namespace tpu
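The new utility signatures take the OptimizationParameters proto directly; because the algorithm is encoded as a oneof inside that message, params.parameters_case() plays the role the explicit OptimizationAlgorithm argument used to. Below is a hedged caller sketch; the include paths and the mutable_adagrad() accessor are assumptions about the generated proto API, not something shown in this commit.

// Hypothetical caller of the new signature; header paths and the
// mutable_adagrad() accessor are assumed from the generated proto API.
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

namespace example {

tensorflow::Status CountAdagradAuxSlots(int* count) {
  tensorflow::tpu::OptimizationParameters params;
  params.mutable_adagrad();  // selects the Adagrad case of the oneof
  // Post-change signature: pass the proto, not an OptimizationAlgorithm enum.
  return tensorflow::tpu::GetBaseAuxiliaryParameterCount(params, count);
}

}  // namespace example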