diff --git a/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc b/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
index 30bdaa4d848..f32261209ae 100644
--- a/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
+++ b/tensorflow/core/ops/tpu_embedding_load_retrieve_ops.cc
@@ -37,8 +37,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -50,8 +49,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Output("parameters: float32")
@@ -62,8 +60,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -75,8 +72,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Input("parameters: float32")
@@ -86,9 +82,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -99,9 +93,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Output("parameters: float32")
@@ -111,9 +103,7 @@ REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP(
     "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
@@ -125,9 +115,7 @@ REGISTER_OP(
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Input("parameters: float32")
@@ -139,8 +127,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -153,8 +140,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Output("parameters: float32")
@@ -166,8 +152,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -180,8 +165,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Input("parameters: float32")
@@ -193,8 +177,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -207,8 +190,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Output("parameters: float32")
@@ -220,8 +202,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -234,8 +215,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Input("parameters: float32")
@@ -246,8 +226,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -259,8 +238,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Output("parameters: float32")
@@ -271,9 +249,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -285,9 +261,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Input("parameters: float32")
@@ -299,8 +273,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -313,8 +286,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Output("parameters: float32")
@@ -326,8 +298,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -340,8 +311,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Input("parameters: float32")
@@ -354,9 +324,7 @@ REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Output("parameters: float32")
@@ -369,9 +337,7 @@ REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Input("parameters: float32")
@@ -384,9 +350,7 @@ REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Output("parameters: float32")
@@ -399,9 +363,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Input("parameters: float32")
@@ -413,8 +375,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -427,8 +388,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Output("parameters: float32")
@@ -440,9 +400,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -455,9 +413,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Input("parameters: float32")
@@ -468,9 +424,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -482,9 +436,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Output("parameters: float32")
@@ -495,9 +447,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -509,9 +459,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Input("parameters: float32")
@@ -523,9 +471,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -538,9 +484,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Output("parameters: float32")
@@ -552,9 +496,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -567,9 +509,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index 961858665a4..d786b5d1b8f 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -94,8 +94,9 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
-  switch (alg) {
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
+                                      int* count) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad:
       *count = 1;
       return Status::OK();
@@ -141,11 +142,11 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
   return errors::InvalidArgument("No optimization algorithm specified");
 }
 
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                       GradientAccumulationSupport* support) {
   int auxiliary_parameter_count;
   TF_RETURN_IF_ERROR(
-      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
   *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                  ? GradientAccumulationSupport::kSupported
                  : GradientAccumulationSupport::kNotSupported;
@@ -168,7 +169,7 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 }  // namespace
 
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
   // The first parameter set is always the weights themselves.
   state_variables->push_back(
@@ -176,7 +177,7 @@ Status GetOptimizationAlgorithmStateVariables(
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
-  switch (alg) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
       state_variables->push_back(
           MakeStandardStateVariableSpecification("accumulators", 0.1));
@@ -276,7 +277,8 @@ Status GetOptimizationAlgorithmStateVariables(
   }
   if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
     return errors::InvalidArgument(
-        "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+        "Optimization algorithm",
+        GetOptimizationAlgorithmName(params.parameters_case()),
         "does not support gradient accumulation because it "
         "already has too many other accumulators");
   }
@@ -301,21 +303,8 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
   };
 }
 
-LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
-                                         bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status LoadOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -329,52 +318,23 @@ Status LoadOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  const int user_param_count =
-      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
-                    [&](const StateVariableSpecification& sv) {
-                      return sv.has_user_defined() || is_debug_op_;
-                    });
-  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
-  int input_index = 0;
-  for (int i = 0, end = state_variable_specs.size(); i < end; ++i) {
-    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
-      std::vector<shape_inference::ShapeHandle> input_temp;
-      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
-      if (input_temp.size() != 1) {
-        return errors::InvalidArgument("each input to be rank 1");
-      }
-      inputs[input_index] = input_temp[0];
-      ++input_index;
-    }
-  }
+
   // Verify shapes have rank 2 and are compatible when they are
   // required to be valid.
   shape_inference::ShapeHandle parameter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
-  for (int j = 1; j < user_param_count; ++j) {
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
+  for (int j = 1; j < c->num_inputs(); ++j) {
     shape_inference::ShapeHandle accumulator_j_shape;
-    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
     shape_inference::ShapeHandle merged;
     TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
   }
+
   return Status::OK();
 }
 
-RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
-                                                 bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status RetrieveOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -388,14 +348,9 @@ Status RetrieveOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  for (int j = 0, end = state_variable_specs.size(); j < end; ++j) {
-    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
-      auto shape = c->MakeShape(
-          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
-      TF_RETURN_IF_ERROR(
-          c->set_output(state_variable_specs[j].name(),
-                        std::vector<shape_inference::ShapeHandle>(1, shape)));
-    }
+  for (int j = 0; j < c->num_outputs(); ++j) {
+    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
+                         2, c->UnknownDim())));
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
index 3e864c89cf2..8d98864ba6a 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
@@ -48,18 +48,19 @@ enum class GradientAccumulationSupport {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
+                                      int *count);
 
 // Returns whether (and how) an optimization algorithm supports gradient
 // accumulation.
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                       GradientAccumulationSupport *support);
 
 // Returns the parameter specifications for the optimization algorithm (the main
 // parameters first, followed by any auxiliary parameters such as Adagrad
 // accumulators).
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters &params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification> *state_variables);
 
 // Maximum value of auxiliar_parameter_count for any optimization algorithm.
@@ -93,35 +94,15 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
 // Generic shape function for per-optimization-algorithm load ops.
 class LoadOpShapeFunction {
  public:
-  // Constructor.
-  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 // Generic shape function for per-optimization-algorithm retrieve ops.
 class RetrieveOpShapeFunction {
  public:
-  // Constructor.
-  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 }  // namespace tpu