Add a set of dynamic embedding optimizers directly taking an HloModule.
PiperOrigin-RevId: 334507326 Change-Id: I6cec860e7ce56d77fe1c429945323691928b3c7a
This commit is contained in:
parent
aeeb7d93d8
commit
84db547203
@ -30,6 +30,7 @@ tf_proto_library(
|
||||
"optimization_parameters.proto",
|
||||
],
|
||||
cc_api_version = 2,
|
||||
protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
@ -3,6 +3,7 @@ syntax = "proto3";
|
||||
package tensorflow.tpu;
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
import "tensorflow/compiler/xla/service/hlo.proto";
|
||||
|
||||
message ClippingLimits {
|
||||
google.protobuf.FloatValue lower = 1; // -inf if not set
|
||||
@ -317,6 +318,34 @@ message FrequencyEstimatorParameters {
|
||||
float weight_exponent = 4;
|
||||
}
|
||||
|
||||
// A user-defined optimizer.
|
||||
// The contained HLO program must take the following arguments in the following
|
||||
// order:
|
||||
// 1. gradients
|
||||
// 2. table weights
|
||||
// 3. slot variables
|
||||
// 4. an optional scalar input that is passed in via the dynamic learning
|
||||
// rate mechanism.
|
||||
//
|
||||
// It must return/end in a tuple op that contains the following values in the
|
||||
// following order:
|
||||
// 1. new table values
|
||||
// 2. new slot variable value
|
||||
//
|
||||
// The program must have shape (1,1) with dtype float32 throughout and only use
|
||||
// HLO that operate elementwise (e.g., no reduce, no variables, no control flow
|
||||
// and no broadcasting outside of the single scalar input).
|
||||
// The HLO program should be written as if it were a dense update. It will be
|
||||
// called on each row that needs an update and will applied elementwise.
|
||||
message UserDefinedProgramParameters {
|
||||
xla.HloModuleProto program = 1;
|
||||
// Padding values for the parameter and the slots, see
|
||||
// StateVariableSpecification.padding_initial_value below for more details on
|
||||
// how this should be set. One value is needed for the weights and one for
|
||||
// each slot.
|
||||
repeated float padding_values = 2;
|
||||
}
|
||||
|
||||
// Status of using gradient accumulation (doing two passes over the input
|
||||
// gradients: one to accumulate them into a temporary array and another to apply
|
||||
// them using the actual optimization algorithm). The extra message is to wrap
|
||||
@ -395,6 +424,7 @@ message OptimizationParameters {
|
||||
OnlineYogiParameters online_yogi = 20;
|
||||
ProximalYogiParameters proximal_yogi = 21;
|
||||
FrequencyEstimatorParameters frequency_estimator = 23;
|
||||
UserDefinedProgramParameters user_defined_program = 24;
|
||||
}
|
||||
|
||||
reserved 15; // Old use_gradient_accumulation.
|
||||
|
@ -27,6 +27,9 @@ cc_library(
|
||||
hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
|
||||
return "ProximalYogi";
|
||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||
return "FrequencyEstimator";
|
||||
case OptimizationAlgorithm::kUserDefinedProgram:
|
||||
return "UserDefinedProgram";
|
||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||
return "*** Not set ***";
|
||||
}
|
||||
@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
|
||||
return "proximal Yogi";
|
||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||
return "frequency estimator";
|
||||
case OptimizationAlgorithm::kUserDefinedProgram:
|
||||
return "UserDefinedProgram";
|
||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||
return "unknown (not specified)";
|
||||
}
|
||||
@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
|
||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||
*count = 1;
|
||||
return Status::OK();
|
||||
case OptimizationAlgorithm::kUserDefinedProgram: {
|
||||
const xla::ProgramShapeProto& program_shape =
|
||||
params.user_defined_program().program().host_program_shape();
|
||||
|
||||
const int num_inputs = program_shape.parameters_size();
|
||||
const int num_outputs = program_shape.result().tuple_shapes_size();
|
||||
|
||||
if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
|
||||
(num_inputs != num_outputs + 2))) {
|
||||
return errors::InvalidArgument(
|
||||
"User-defined TPU embedding optimizer program must have at least "
|
||||
"two inputs and the number of outputs must be 1 or 2 less than the "
|
||||
"number of inputs. Received ",
|
||||
num_inputs, " input(s) and ", num_outputs, "output(s).");
|
||||
}
|
||||
|
||||
*count = num_outputs - 1;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||
return errors::InvalidArgument("No optimization algorithm specified");
|
||||
}
|
||||
@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
|
||||
Status GetOptimizationAlgorithmStateVariables(
|
||||
const OptimizationParameters& params, bool use_gradient_accumulation,
|
||||
std::vector<StateVariableSpecification>* state_variables) {
|
||||
// The first parameter set is always the weights themselves.
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("parameters", 0.0));
|
||||
// The order of the returned parameters needs to match the offsets used by
|
||||
// the algorithm implementations in test_util.cc and
|
||||
// address_handler_program_creator.cc.
|
||||
// The first parameter set is always the weights themselves.
|
||||
auto add_state_variable = [&](const std::string& name, float value) {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification(name, value));
|
||||
};
|
||||
switch (params.parameters_case()) {
|
||||
case OptimizationAlgorithm::kAdagrad: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.1);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kBoundedAdagrad: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.1);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kStochasticGradientDescent: {
|
||||
// None.
|
||||
add_state_variable("parameters", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kFtrl: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("linears", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.1);
|
||||
add_state_variable("linears", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kAdam: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("momenta", 0.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("velocities", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("momenta", 0.0);
|
||||
add_state_variable("velocities", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kMomentum: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("momenta", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("momenta", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kRmsProp: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("ms", 1.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("mom", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("ms", 1.0);
|
||||
add_state_variable("mom", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kCenteredRmsProp: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("ms", 1.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("mom", 0.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("mg", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("ms", 1.0);
|
||||
add_state_variable("mom", 0.0);
|
||||
add_state_variable("mg", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kMdlAdagradLight: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("weights", 0.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("benefits", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.1);
|
||||
add_state_variable("weights", 0.0);
|
||||
add_state_variable("benefits", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kAdadelta: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.0));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("updates", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.0);
|
||||
add_state_variable("updates", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kProximalAdagrad: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("accumulators", 0.1);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kOnlineYogi: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("vs", 0.1));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("linears", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("vs", 0.1);
|
||||
add_state_variable("linears", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kProximalYogi: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("v", 0.1));
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("m", 0.0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("v", 0.1);
|
||||
add_state_variable("m", 0.0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kFrequencyEstimator: {
|
||||
state_variables->push_back(
|
||||
MakeStandardStateVariableSpecification("last_hit_step", 0));
|
||||
add_state_variable("parameters", 0.0);
|
||||
add_state_variable("last_hit_step", 0);
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::kUserDefinedProgram: {
|
||||
add_state_variable("parameters",
|
||||
params.user_defined_program().padding_values(0));
|
||||
int num_slots = -1;
|
||||
TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
|
||||
if (num_slots + 1 !=
|
||||
params.user_defined_program().padding_values_size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Number of slots does not agree with the number of padding values "
|
||||
"specified.");
|
||||
}
|
||||
for (int i = 0; i < num_slots; ++i) {
|
||||
add_state_variable(absl::StrCat("Slot_", i),
|
||||
params.user_defined_program().padding_values(i + 1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
|
||||
@ -313,6 +349,7 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
|
||||
OptimizationAlgorithm::kOnlineYogi,
|
||||
OptimizationAlgorithm::kProximalYogi,
|
||||
OptimizationAlgorithm::kFrequencyEstimator,
|
||||
OptimizationAlgorithm::kUserDefinedProgram,
|
||||
};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user