diff --git a/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc b/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
index 30bdaa4d848..f32261209ae 100644
--- a/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
+++ b/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
@@ -37,8 +37,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -50,8 +49,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Output("parameters: float32")
@@ -62,8 +60,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -75,8 +72,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Input("parameters: float32")
@@ -86,9 +82,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -99,9 +93,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Output("parameters: float32")
@@ -111,9 +103,7 @@ REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP(
     "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
@@ -125,9 +115,7 @@ REGISTER_OP(
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Input("parameters: float32")
@@ -139,8 +127,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -153,8 +140,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Output("parameters: float32")
@@ -166,8 +152,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -180,8 +165,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Input("parameters: float32")
@@ -193,8 +177,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -207,8 +190,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Output("parameters: float32")
@@ -220,8 +202,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -234,8 +215,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Input("parameters: float32")
@@ -246,8 +226,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -259,8 +238,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Output("parameters: float32")
@@ -271,9 +249,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -285,9 +261,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Input("parameters: float32")
@@ -299,8 +273,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -313,8 +286,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Output("parameters: float32")
@@ -326,8 +298,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -340,8 +311,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Input("parameters: float32")
@@ -354,9 +324,7 @@ REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Output("parameters: float32")
@@ -369,9 +337,7 @@ REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Input("parameters: float32")
@@ -384,9 +350,7 @@ REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Output("parameters: float32")
@@ -399,9 +363,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Input("parameters: float32")
@@ -413,8 +375,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -427,8 +388,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Output("parameters: float32")
@@ -440,9 +400,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -455,9 +413,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Input("parameters: float32")
@@ -468,9 +424,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -482,9 +436,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Output("parameters: float32")
@@ -495,9 +447,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -509,9 +459,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Input("parameters: float32")
@@ -523,9 +471,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -538,9 +484,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Output("parameters: float32")
@@ -552,9 +496,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -567,9 +509,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index 961858665a4..d786b5d1b8f 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -94,8 +94,9 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
-  switch (alg) {
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
+                                      int* count) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad:
       *count = 1;
       return Status::OK();
@@ -141,11 +142,11 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
   return errors::InvalidArgument("No optimization algorithm specified");
 }
 
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                       GradientAccumulationSupport* support) {
   int auxiliary_parameter_count;
   TF_RETURN_IF_ERROR(
-      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
   *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                  ? GradientAccumulationSupport::kSupported
                  : GradientAccumulationSupport::kNotSupported;
@@ -168,7 +169,7 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 }  // namespace
 
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
   // The first parameter set is always the weights themselves.
   state_variables->push_back(
@@ -176,7 +177,7 @@ Status GetOptimizationAlgorithmStateVariables(
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
-  switch (alg) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
       state_variables->push_back(
           MakeStandardStateVariableSpecification("accumulators", 0.1));
@@ -276,7 +277,8 @@ Status GetOptimizationAlgorithmStateVariables(
   }
   if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
     return errors::InvalidArgument(
-        "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+        "Optimization algorithm",
+        GetOptimizationAlgorithmName(params.parameters_case()),
         "does not support gradient accumulation because it "
         "already has too many other accumulators");
   }
@@ -301,21 +303,8 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
   };
 }
 
-LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
-                                         bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status LoadOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -329,52 +318,23 @@ Status LoadOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  const int user_param_count =
-      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
-                    [&](const StateVariableSpecification& sv) {
-                      return sv.has_user_defined() || is_debug_op_;
-                    });
-  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
-  int input_index = 0;
-  for (int i = 0, end = state_variable_specs.size(); i < end; ++i) {
-    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
-      std::vector<shape_inference::ShapeHandle> input_temp;
-      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
-      if (input_temp.size() != 1) {
-        return errors::InvalidArgument("each input to be rank 1");
-      }
-      inputs[input_index] = input_temp[0];
-      ++input_index;
-    }
-  }
+
   // Verify shapes have rank 2 and are compatible when they are
   // required to be valid.
   shape_inference::ShapeHandle parameter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
-  for (int j = 1; j < user_param_count; ++j) {
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
+  for (int j = 1; j < c->num_inputs(); ++j) {
     shape_inference::ShapeHandle accumulator_j_shape;
-    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
     shape_inference::ShapeHandle merged;
     TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
   }
+
   return Status::OK();
 }
 
-RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
-                                                 bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status RetrieveOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -388,14 +348,9 @@ Status RetrieveOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  for (int j = 0, end = state_variable_specs.size(); j < end; ++j) {
-    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
-      auto shape = c->MakeShape(
-          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
-      TF_RETURN_IF_ERROR(
-          c->set_output(state_variable_specs[j].name(),
-                        std::vector<shape_inference::ShapeHandle>(1, shape)));
-    }
+  for (int j = 0; j < c->num_outputs(); ++j) {
+    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
+                         2, c->UnknownDim())));
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
index 3e864c89cf2..8d98864ba6a 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
@@ -48,18 +48,19 @@ enum class GradientAccumulationSupport {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
+                                      int *count);
 
 // Returns whether (and how) an optimization algorithm supports gradient
 // accumulation.
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                       GradientAccumulationSupport *support);
 
 // Returns the parameter specifications for the optimization algorithm (the main
 // parameters first, followed by any auxiliary parameters such as Adagrad
 // accumulators).
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters &params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification> *state_variables);
 
 // Maximum value of auxiliar_parameter_count for any optimization algorithm.
@@ -93,35 +94,15 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
 // Generic shape function for per-optimization-algorithm load ops.
 class LoadOpShapeFunction {
  public:
-  // Constructor.
-  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 // Generic shape function for per-optimization-algorithm retrieve ops.
 class RetrieveOpShapeFunction {
  public:
-  // Constructor.
-  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 }  // namespace tpu