Allow TPU embedding utility ops to take OptimizationParameters rather than OptimizationAlgorithm.

PiperOrigin-RevId: 333120828
Change-Id: I959f8f116293df5e2130654a073581219cca1fec
Authored by Bruce Fontaine on 2020-09-22 11:18:57 -07:00; committed by TensorFlower Gardener.
parent a151f928f3
commit 2c7d978a14
3 changed files with 61 additions and 185 deletions
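
In short: utilities that previously received a bare OptimizationAlgorithm enum now receive the full OptimizationParameters proto and dispatch on its parameters_case() oneof selector, which is what lets the shape functors below drop their per-algorithm state. A minimal sketch of the new dispatch pattern; the helper name CountAuxiliary and the proto header path are illustrative assumptions, and the unchanged case labels in the diff suggest OptimizationAlgorithm is an alias for the proto's ParametersCase enum:

// Sketch only: dispatch on the proto's oneof instead of a separate enum.
// CountAuxiliary is a hypothetical stand-in for GetBaseAuxiliaryParameterCount.
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

using tensorflow::tpu::OptimizationParameters;

int CountAuxiliary(const OptimizationParameters& params) {
  // parameters_case() reports which optimizer sub-message is set, so callers
  // no longer need to thread a matching OptimizationAlgorithm value through.
  switch (params.parameters_case()) {
    case OptimizationParameters::kAdagrad:
      return 1;  // one "accumulators" slot
    case OptimizationParameters::kStochasticGradientDescent:
      return 0;  // no auxiliary state
    default:
      return -1;  // unspecified or unhandled algorithm
  }
}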


@@ -37,8 +37,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -50,8 +49,7 @@ REGISTER_OP("LoadTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Output("parameters: float32")
@@ -62,8 +60,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -75,8 +72,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdagrad,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Input("parameters: float32")
@@ -86,9 +82,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -99,9 +93,7 @@ REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Output("parameters: float32")
@@ -111,9 +103,7 @@ REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP(
     "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug")
@@ -125,9 +115,7 @@ REGISTER_OP(
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kStochasticGradientDescent,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Input("parameters: float32")
@@ -139,8 +127,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -153,8 +140,7 @@ REGISTER_OP("LoadTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Output("parameters: float32")
@@ -166,8 +152,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -180,8 +165,7 @@ REGISTER_OP("RetrieveTPUEmbeddingFTRLParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kFtrl,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Input("parameters: float32")
@@ -193,8 +177,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -207,8 +190,7 @@ REGISTER_OP("LoadTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Output("parameters: float32")
@@ -220,8 +202,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -234,8 +215,7 @@ REGISTER_OP("RetrieveTPUEmbeddingADAMParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdam,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Input("parameters: float32")
@@ -246,8 +226,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -259,8 +238,7 @@ REGISTER_OP("LoadTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kMomentum,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Output("parameters: float32")
@@ -271,9 +249,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -285,9 +261,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMomentumParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMomentum,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Input("parameters: float32")
@@ -299,8 +273,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -313,8 +286,7 @@ REGISTER_OP("LoadTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Output("parameters: float32")
@@ -326,8 +298,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -340,8 +311,7 @@ REGISTER_OP("RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{/*alg=*/OptimizationAlgorithm::kRmsProp,
-                                        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Input("parameters: float32")
@@ -354,9 +324,7 @@ REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Output("parameters: float32")
@@ -369,9 +337,7 @@ REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kCenteredRmsProp,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Input("parameters: float32")
@@ -384,9 +350,7 @@ REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Output("parameters: float32")
@@ -399,9 +363,7 @@ REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kMdlAdagradLight,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Input("parameters: float32")
@@ -413,8 +375,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -427,8 +388,7 @@ REGISTER_OP("LoadTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{/*alg=*/OptimizationAlgorithm::kAdadelta,
-                                    /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Output("parameters: float32")
@@ -440,9 +400,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -455,9 +413,7 @@ REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kAdadelta,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Input("parameters: float32")
@@ -468,9 +424,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
    .Input("parameters: float32")
@@ -482,9 +436,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Output("parameters: float32")
@@ -495,9 +447,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -509,9 +459,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalAdagrad,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Input("parameters: float32")
@@ -523,9 +471,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Input("parameters: float32")
@@ -538,9 +484,7 @@ REGISTER_OP("LoadTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(LoadOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(LoadOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Output("parameters: float32")
@@ -552,9 +496,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/false});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Output("parameters: float32")
@@ -567,9 +509,7 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
     .Attr("shard_id: int")
     .Attr("config: string = \"\"")
     .SetIsStateful()
-    .SetShapeFn(RetrieveOpShapeFunction{
-        /*alg=*/OptimizationAlgorithm::kProximalYogi,
-        /*is_debug_op=*/true});
+    .SetShapeFn(RetrieveOpShapeFunction());
 
 }  // namespace tpu
 }  // namespace tensorflow
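
With the functor now default-constructed, every load/retrieve registration shares the same shape function, and the op's declared signature alone determines how many tensors are checked. A sketch of what a registration looks like under the new scheme; the op name and its exact attribute list are invented for illustration:

// Hypothetical op; only the .SetShapeFn line differs from the old pattern,
// which had to spell out the algorithm enum and a debug flag per op.
REGISTER_OP("LoadTPUEmbeddingExampleParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());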


@@ -94,8 +94,9 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
-  switch (alg) {
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
+                                      int* count) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad:
       *count = 1;
       return Status::OK();
@@ -141,11 +142,11 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
   return errors::InvalidArgument("No optimization algorithm specified");
 }
 
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                       GradientAccumulationSupport* support) {
   int auxiliary_parameter_count;
   TF_RETURN_IF_ERROR(
-      GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
   *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                  ? GradientAccumulationSupport::kSupported
                  : GradientAccumulationSupport::kNotSupported;
@@ -168,7 +169,7 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 }  // namespace
 
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
   // The first parameter set is always the weights themselves.
   state_variables->push_back(
@@ -176,7 +177,7 @@ Status GetOptimizationAlgorithmStateVariables(
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
-  switch (alg) {
+  switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
       state_variables->push_back(
           MakeStandardStateVariableSpecification("accumulators", 0.1));
@@ -276,7 +277,8 @@ Status GetOptimizationAlgorithmStateVariables(
   }
   if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
     return errors::InvalidArgument(
-        "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+        "Optimization algorithm",
+        GetOptimizationAlgorithmName(params.parameters_case()),
         "does not support gradient accumulation because it "
        "already has too many other accumulators");
   }
@@ -301,21 +303,8 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
   };
 }
 
-LoadOpShapeFunction::LoadOpShapeFunction(OptimizationAlgorithm alg,
-                                         bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
-
 Status LoadOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
-
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -329,52 +318,23 @@ Status LoadOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  const int user_param_count =
-      std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
-                    [&](const StateVariableSpecification& sv) {
-                      return sv.has_user_defined() || is_debug_op_;
-                    });
-  std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
-  int input_index = 0;
-  for (int i = 0, end = state_variable_specs.size(); i < end; ++i) {
-    if (state_variable_specs[i].has_user_defined() || is_debug_op_) {
-      std::vector<shape_inference::ShapeHandle> input_temp;
-      TF_RETURN_IF_ERROR(c->input(state_variable_specs[i].name(), &input_temp));
-      if (input_temp.size() != 1) {
-        return errors::InvalidArgument("each input to be rank 1");
-      }
-      inputs[input_index] = input_temp[0];
-      ++input_index;
-    }
-  }
 
   // Verify shapes have rank 2 and are compatible when they are
   // required to be valid.
   shape_inference::ShapeHandle parameter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
-  for (int j = 1; j < user_param_count; ++j) {
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
+  for (int j = 1; j < c->num_inputs(); ++j) {
     shape_inference::ShapeHandle accumulator_j_shape;
-    TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
     shape_inference::ShapeHandle merged;
     TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
   }
 
   return Status::OK();
 }
 
-RetrieveOpShapeFunction::RetrieveOpShapeFunction(OptimizationAlgorithm alg,
-                                                 bool is_debug_op)
-    : alg_(alg), is_debug_op_(is_debug_op) {}
 Status RetrieveOpShapeFunction::operator()(
     shape_inference::InferenceContext* c) const {
-  GradientAccumulationSupport grad_accum_support;
-  TF_CHECK_OK(GetGradientAccumulationSupport(alg_, &grad_accum_support));
-  std::vector<StateVariableSpecification> state_variable_specs;
-  TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
-      alg_,
-      grad_accum_support == GradientAccumulationSupport::kSupported &&
-          is_debug_op_,
-      &state_variable_specs));
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
   string table_name;
@@ -388,14 +348,9 @@ Status RetrieveOpShapeFunction::operator()(
   TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
   int shard_id;
   TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
-  for (int j = 0, end = state_variable_specs.size(); j < end; ++j) {
-    if (state_variable_specs[j].has_user_defined() || is_debug_op_) {
-      auto shape = c->MakeShape(
-          std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
-      TF_RETURN_IF_ERROR(
-          c->set_output(state_variable_specs[j].name(),
-                        std::vector<shape_inference::ShapeHandle>(1, shape)));
-    }
+  for (int j = 0; j < c->num_outputs(); ++j) {
+    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
+                         2, c->UnknownDim())));
   }
   return Status::OK();
 }
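
For callers of the utility functions, the new signatures mean building (or already holding) an OptimizationParameters proto and passing it directly. A hedged usage sketch; mutable_adagrad() assumes the proto's oneof field is named adagrad, matching the kAdagrad case above:

// Sketch: query auxiliary-parameter count and gradient-accumulation support
// from a populated proto rather than from an enum value.
Status DescribeAdagrad() {
  OptimizationParameters params;
  params.mutable_adagrad();  // select Adagrad in the oneof

  int auxiliary_count = 0;
  TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &auxiliary_count));

  GradientAccumulationSupport support;
  TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
  // Adagrad keeps a single accumulator, so accumulation should be reported
  // as kSupported (auxiliary_count + 1 <= kMaxAuxiliaryParameterCount).
  return Status::OK();
}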


@@ -48,18 +48,19 @@ enum class GradientAccumulationSupport {
 // Returns the number of optimization parameter vectors used by the optimization
 // algorithm, excluding the weights themselves and assuming no gradient
 // accumulation.
-Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
+                                      int *count);
 
 // Returns whether (and how) an optimization algorithm supports gradient
 // accumulation.
-Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                       GradientAccumulationSupport *support);
 
 // Returns the parameter specifications for the optimization algorithm (the main
 // parameters first, followed by any auxiliary parameters such as Adagrad
 // accumulators).
 Status GetOptimizationAlgorithmStateVariables(
-    OptimizationAlgorithm alg, bool use_gradient_accumulation,
+    const OptimizationParameters &params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification> *state_variables);
 
 // Maximum value of auxiliar_parameter_count for any optimization algorithm.
@@ -93,35 +94,15 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
 // Generic shape function for per-optimization-algorithm load ops.
 class LoadOpShapeFunction {
  public:
-  // Constructor.
-  LoadOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 // Generic shape function for per-optimization-algorithm retrieve ops.
 class RetrieveOpShapeFunction {
  public:
-  // Constructor.
-  RetrieveOpShapeFunction(OptimizationAlgorithm alg, bool is_debug_op);
-
   // Computes resulting shape and does parameter checking.
   Status operator()(shape_inference::InferenceContext *c) const;
-
- private:
-  // Optimization algorithm.
-  const OptimizationAlgorithm alg_;
-
-  // Whether this op has an extra parameter for the gradient accumulators.
-  const bool is_debug_op_;
 };
 
 }  // namespace tpu
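
Both classes are now stateless, so each is effectively a function object around what could be a free function; keeping the class form preserves the existing SetShapeFn call sites. A minimal functor equivalent to the slimmed-down retrieve shape function, with an invented name:

// Every declared output gets an unknown rank-2 shape, mirroring the new
// RetrieveOpShapeFunction body; Matrix(UnknownDim(), UnknownDim()) is
// shorthand for a 2-D shape with both dimensions unknown.
class ExampleRetrieveShapeFn {
 public:
  Status operator()(shape_inference::InferenceContext* c) const {
    for (int j = 0; j < c->num_outputs(); ++j) {
      c->set_output(j, c->Matrix(c->UnknownDim(), c->UnknownDim()));
    }
    return Status::OK();
  }
};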