Add a set of dynamic embedding optimizers directly taking an HloModule.

PiperOrigin-RevId: 334507326
Change-Id: I6cec860e7ce56d77fe1c429945323691928b3c7a
Bruce Fontaine 2020-09-29 18:46:53 -07:00 committed by TensorFlower Gardener
parent aeeb7d93d8
commit 84db547203
4 changed files with 121 additions and 50 deletions

tensorflow/core/protobuf/tpu/BUILD

@@ -30,6 +30,7 @@ tf_proto_library(
"optimization_parameters.proto",
],
cc_api_version = 2,
protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"],
visibility = ["//visibility:public"],
)

tensorflow/core/protobuf/tpu/optimization_parameters.proto

@@ -3,6 +3,7 @@ syntax = "proto3";
package tensorflow.tpu;
import "google/protobuf/wrappers.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
message ClippingLimits {
google.protobuf.FloatValue lower = 1; // -inf if not set
@@ -317,6 +318,34 @@ message FrequencyEstimatorParameters {
float weight_exponent = 4;
}
// A user-defined optimizer.
// The contained HLO program must take the following arguments in the following
// order:
// 1. gradients
// 2. table weights
// 3. slot variables
// 4. an optional scalar input that is passed in via the dynamic learning
// rate mechanism.
//
// It must return/end in a tuple op that contains the following values in the
// following order:
// 1. new table values
// 2. new slot variable values
//
// Every tensor in the program must have shape (1,1) and dtype float32, and
// the program may only use HLO ops that operate elementwise (e.g., no reduce,
// no variables, no control flow and no broadcasting outside of the single
// scalar input).
// The HLO program should be written as if it were a dense update. It will be
// called on each row that needs an update and will be applied elementwise.
message UserDefinedProgramParameters {
xla.HloModuleProto program = 1;
// Padding values for the parameters and the slots; see
// StateVariableSpecification.padding_initial_value below for more details on
// how this should be set. One value is needed for the weights and one for
// each slot.
repeated float padding_values = 2;
}
// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
@@ -395,6 +424,7 @@ message OptimizationParameters {
OnlineYogiParameters online_yogi = 20;
ProximalYogiParameters proximal_yogi = 21;
FrequencyEstimatorParameters frequency_estimator = 23;
UserDefinedProgramParameters user_defined_program = 24;
}
reserved 15; // Old use_gradient_accumulation.
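
As an illustration of the contract above (not part of this commit), a slot-free SGD-style update could be built with the XLA client builder roughly as follows. This is a minimal sketch: the function and variable names are hypothetical, and it assumes the StatusOr-based Build()/ValueOrDie() client API of this era.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

// Hypothetical example: builds w_new = w - lr * g as an elementwise HLO
// program over (1,1) float32 tensors and wires it into the new proto message.
tensorflow::tpu::OptimizationParameters MakeSgdUserDefinedProgram() {
  xla::XlaBuilder builder("sgd_optimizer");
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {1, 1});
  // Inputs, in the required order: gradients, table weights, then (with no
  // slot variables here) the optional dynamic learning rate scalar.
  xla::XlaOp grad = xla::Parameter(&builder, 0, scalar, "gradient");
  xla::XlaOp weight = xla::Parameter(&builder, 1, scalar, "table_weight");
  xla::XlaOp lr = xla::Parameter(&builder, 2, scalar, "learning_rate");
  xla::XlaOp new_weight = xla::Sub(weight, xla::Mul(lr, grad));
  // The program must end in a tuple; with no slots this is a 1-tuple holding
  // only the new table values.
  xla::Tuple(&builder, {new_weight});
  xla::XlaComputation computation = builder.Build().ValueOrDie();

  tensorflow::tpu::OptimizationParameters params;
  *params.mutable_user_defined_program()->mutable_program() =
      computation.proto();
  // One padding value is required for the weights; there are no slots.
  params.mutable_user_defined_program()->add_padding_values(0.0);
  return params;
}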

tensorflow/core/tpu/BUILD

@@ -27,6 +27,9 @@ cc_library(
hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",

tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc

@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
return "ProximalYogi";
case OptimizationAlgorithm::kFrequencyEstimator:
return "FrequencyEstimator";
case OptimizationAlgorithm::kUserDefinedProgram:
return "UserDefinedProgram";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "*** Not set ***";
}
@@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
return "proximal Yogi";
case OptimizationAlgorithm::kFrequencyEstimator:
return "frequency estimator";
case OptimizationAlgorithm::kUserDefinedProgram:
return "UserDefinedProgram";
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return "unknown (not specified)";
}
@@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
case OptimizationAlgorithm::kFrequencyEstimator:
*count = 1;
return Status::OK();
case OptimizationAlgorithm::kUserDefinedProgram: {
const xla::ProgramShapeProto& program_shape =
params.user_defined_program().program().host_program_shape();
const int num_inputs = program_shape.parameters_size();
const int num_outputs = program_shape.result().tuple_shapes_size();
if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
(num_inputs != num_outputs + 2))) {
return errors::InvalidArgument(
"User-defined TPU embedding optimizer program must have at least "
"two inputs and the number of outputs must be 1 or 2 less than the "
"number of inputs. Received ",
num_inputs, " input(s) and ", num_outputs, "output(s).");
}
*count = num_outputs - 1;
return Status::OK();
}
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
return errors::InvalidArgument("No optimization algorithm specified");
}
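
For example (a sketch, not part of this commit): an Adagrad-style program with inputs (gradient, weight, accumulator, learning_rate) and outputs (new_weight, new_accumulator) has num_inputs == 4 and num_outputs == 2, satisfies num_inputs == num_outputs + 2, and so is accepted with a slot count of num_outputs - 1 == 1. The hypothetical helper below shows how the check can be driven from a built xla::XlaComputation.

#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

// Hypothetical helper: wraps a built computation in OptimizationParameters
// and asks the validator above for the implied number of slot variables.
tensorflow::Status CountUserDefinedSlots(
    const xla::XlaComputation& computation, int* num_slots) {
  tensorflow::tpu::OptimizationParameters params;
  *params.mutable_user_defined_program()->mutable_program() =
      computation.proto();
  return tensorflow::tpu::GetBaseAuxiliaryParameterCount(params, num_slots);
}
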
@@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
Status GetOptimizationAlgorithmStateVariables(
const OptimizationParameters& params, bool use_gradient_accumulation,
std::vector<StateVariableSpecification>* state_variables) {
// The order of the returned parameters needs to match the offsets used by
// the algorithm implementations in test_util.cc and
// address_handler_program_creator.cc.
// The first parameter set is always the weights themselves.
auto add_state_variable = [&](const std::string& name, float value) {
state_variables->push_back(
MakeStandardStateVariableSpecification(name, value));
};
switch (params.parameters_case()) {
case OptimizationAlgorithm::kAdagrad: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.1);
break;
}
case OptimizationAlgorithm::kBoundedAdagrad: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.1);
break;
}
case OptimizationAlgorithm::kStochasticGradientDescent: {
add_state_variable("parameters", 0.0);
break;
}
case OptimizationAlgorithm::kFtrl: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.1);
add_state_variable("linears", 0.0);
break;
}
case OptimizationAlgorithm::kAdam: {
add_state_variable("parameters", 0.0);
add_state_variable("momenta", 0.0);
add_state_variable("velocities", 0.0);
break;
}
case OptimizationAlgorithm::kMomentum: {
add_state_variable("parameters", 0.0);
add_state_variable("momenta", 0.0);
break;
}
case OptimizationAlgorithm::kRmsProp: {
add_state_variable("parameters", 0.0);
add_state_variable("ms", 1.0);
add_state_variable("mom", 0.0);
break;
}
case OptimizationAlgorithm::kCenteredRmsProp: {
add_state_variable("parameters", 0.0);
add_state_variable("ms", 1.0);
add_state_variable("mom", 0.0);
add_state_variable("mg", 0.0);
break;
}
case OptimizationAlgorithm::kMdlAdagradLight: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.1);
add_state_variable("weights", 0.0);
add_state_variable("benefits", 0.0);
break;
}
case OptimizationAlgorithm::kAdadelta: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.0);
add_state_variable("updates", 0.0);
break;
}
case OptimizationAlgorithm::kProximalAdagrad: {
add_state_variable("parameters", 0.0);
add_state_variable("accumulators", 0.1);
break;
}
case OptimizationAlgorithm::kOnlineYogi: {
add_state_variable("parameters", 0.0);
add_state_variable("vs", 0.1);
add_state_variable("linears", 0.0);
break;
}
case OptimizationAlgorithm::kProximalYogi: {
add_state_variable("parameters", 0.0);
add_state_variable("v", 0.1);
add_state_variable("m", 0.0);
break;
}
case OptimizationAlgorithm::kFrequencyEstimator: {
add_state_variable("parameters", 0.0);
add_state_variable("last_hit_step", 0);
break;
}
case OptimizationAlgorithm::kUserDefinedProgram: {
add_state_variable("parameters",
params.user_defined_program().padding_values(0));
int num_slots = -1;
TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
if (num_slots + 1 !=
params.user_defined_program().padding_values_size()) {
return errors::InvalidArgument(
"Number of slots does not agree with the number of padding values "
"specified.");
}
for (int i = 0; i < num_slots; ++i) {
add_state_variable(absl::StrCat("Slot_", i),
params.user_defined_program().padding_values(i + 1));
}
break;
}
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
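
Continuing the sketch from above: for a one-slot user-defined program with padding_values = {0.0, 0.1}, the kUserDefinedProgram case produces a "parameters" spec followed by a "Slot_0" spec, per the loop above. The wrapper below is hypothetical, and `params` is assumed to be built as in the earlier sketches.

// Sketch: enumerate the state variables for a one-slot user-defined program.
tensorflow::Status ListUserDefinedStateVariables(
    const tensorflow::tpu::OptimizationParameters& params) {
  std::vector<tensorflow::tpu::StateVariableSpecification> state_variables;
  TF_RETURN_IF_ERROR(tensorflow::tpu::GetOptimizationAlgorithmStateVariables(
      params, /*use_gradient_accumulation=*/false, &state_variables));
  // With padding_values = {0.0, 0.1}: state_variables[0].name() is
  // "parameters" and state_variables[1].name() is "Slot_0".
  return tensorflow::Status::OK();
}
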
@@ -313,6 +349,7 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
OptimizationAlgorithm::kOnlineYogi,
OptimizationAlgorithm::kProximalYogi,
OptimizationAlgorithm::kFrequencyEstimator,
OptimizationAlgorithm::kUserDefinedProgram,
};
}