Add a set of dynamic embedding optimizers directly taking an HloModule.
PiperOrigin-RevId: 334507326 Change-Id: I6cec860e7ce56d77fe1c429945323691928b3c7a
This commit is contained in:
parent
aeeb7d93d8
commit
84db547203
@ -30,6 +30,7 @@ tf_proto_library(
|
|||||||
"optimization_parameters.proto",
|
"optimization_parameters.proto",
|
||||||
],
|
],
|
||||||
cc_api_version = 2,
|
cc_api_version = 2,
|
||||||
|
protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ syntax = "proto3";
|
|||||||
package tensorflow.tpu;
|
package tensorflow.tpu;
|
||||||
|
|
||||||
import "google/protobuf/wrappers.proto";
|
import "google/protobuf/wrappers.proto";
|
||||||
|
import "tensorflow/compiler/xla/service/hlo.proto";
|
||||||
|
|
||||||
message ClippingLimits {
|
message ClippingLimits {
|
||||||
google.protobuf.FloatValue lower = 1; // -inf if not set
|
google.protobuf.FloatValue lower = 1; // -inf if not set
|
||||||
@ -317,6 +318,34 @@ message FrequencyEstimatorParameters {
|
|||||||
float weight_exponent = 4;
|
float weight_exponent = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A user-defined optimizer.
|
||||||
|
// The contained HLO program must take the following arguments in the following
|
||||||
|
// order:
|
||||||
|
// 1. gradients
|
||||||
|
// 2. table weights
|
||||||
|
// 3. slot variables
|
||||||
|
// 4. an optional scalar input that is passed in via the dynamic learning
|
||||||
|
// rate mechanism.
|
||||||
|
//
|
||||||
|
// It must return/end in a tuple op that contains the following values in the
|
||||||
|
// following order:
|
||||||
|
// 1. new table values
|
||||||
|
// 2. new slot variable value
|
||||||
|
//
|
||||||
|
// The program must have shape (1,1) with dtype float32 throughout and only use
|
||||||
|
// HLO that operates elementwise (e.g., no reduce, no variables, no control flow
|
||||||
|
// and no broadcasting outside of the single scalar input).
|
||||||
|
// The HLO program should be written as if it were a dense update. It will be
|
||||||
|
// called on each row that needs an update and will be applied elementwise.
|
||||||
|
message UserDefinedProgramParameters {
|
||||||
|
xla.HloModuleProto program = 1;
|
||||||
|
// Padding values for the parameter and the slots, see
|
||||||
|
// StateVariableSpecification.padding_initial_value below for more details on
|
||||||
|
// how this should be set. One value is needed for the weights and one for
|
||||||
|
// each slot.
|
||||||
|
repeated float padding_values = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Status of using gradient accumulation (doing two passes over the input
|
// Status of using gradient accumulation (doing two passes over the input
|
||||||
// gradients: one to accumulate them into a temporary array and another to apply
|
// gradients: one to accumulate them into a temporary array and another to apply
|
||||||
// them using the actual optimization algorithm). The extra message is to wrap
|
// them using the actual optimization algorithm). The extra message is to wrap
|
||||||
@ -395,6 +424,7 @@ message OptimizationParameters {
|
|||||||
OnlineYogiParameters online_yogi = 20;
|
OnlineYogiParameters online_yogi = 20;
|
||||||
ProximalYogiParameters proximal_yogi = 21;
|
ProximalYogiParameters proximal_yogi = 21;
|
||||||
FrequencyEstimatorParameters frequency_estimator = 23;
|
FrequencyEstimatorParameters frequency_estimator = 23;
|
||||||
|
UserDefinedProgramParameters user_defined_program = 24;
|
||||||
}
|
}
|
||||||
|
|
||||||
reserved 15; // Old use_gradient_accumulation.
|
reserved 15; // Old use_gradient_accumulation.
|
||||||
|
@ -27,6 +27,9 @@ cc_library(
|
|||||||
hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
|
hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
|
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
|
|||||||
return "ProximalYogi";
|
return "ProximalYogi";
|
||||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||||
return "FrequencyEstimator";
|
return "FrequencyEstimator";
|
||||||
|
case OptimizationAlgorithm::kUserDefinedProgram:
|
||||||
|
return "UserDefinedProgram";
|
||||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||||
return "*** Not set ***";
|
return "*** Not set ***";
|
||||||
}
|
}
|
||||||
@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
|
|||||||
return "proximal Yogi";
|
return "proximal Yogi";
|
||||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||||
return "frequency estimator";
|
return "frequency estimator";
|
||||||
|
case OptimizationAlgorithm::kUserDefinedProgram:
|
||||||
|
return "UserDefinedProgram";
|
||||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||||
return "unknown (not specified)";
|
return "unknown (not specified)";
|
||||||
}
|
}
|
||||||
@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
|
|||||||
case OptimizationAlgorithm::kFrequencyEstimator:
|
case OptimizationAlgorithm::kFrequencyEstimator:
|
||||||
*count = 1;
|
*count = 1;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
case OptimizationAlgorithm::kUserDefinedProgram: {
|
||||||
|
const xla::ProgramShapeProto& program_shape =
|
||||||
|
params.user_defined_program().program().host_program_shape();
|
||||||
|
|
||||||
|
const int num_inputs = program_shape.parameters_size();
|
||||||
|
const int num_outputs = program_shape.result().tuple_shapes_size();
|
||||||
|
|
||||||
|
if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
|
||||||
|
(num_inputs != num_outputs + 2))) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"User-defined TPU embedding optimizer program must have at least "
|
||||||
|
"two inputs and the number of outputs must be 1 or 2 less than the "
|
||||||
|
"number of inputs. Received ",
|
||||||
|
num_inputs, " input(s) and ", num_outputs, "output(s).");
|
||||||
|
}
|
||||||
|
|
||||||
|
*count = num_outputs - 1;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
case OptimizationAlgorithm::PARAMETERS_NOT_SET:
|
||||||
return errors::InvalidArgument("No optimization algorithm specified");
|
return errors::InvalidArgument("No optimization algorithm specified");
|
||||||
}
|
}
|
||||||
@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
|
|||||||
Status GetOptimizationAlgorithmStateVariables(
|
Status GetOptimizationAlgorithmStateVariables(
|
||||||
const OptimizationParameters& params, bool use_gradient_accumulation,
|
const OptimizationParameters& params, bool use_gradient_accumulation,
|
||||||
std::vector<StateVariableSpecification>* state_variables) {
|
std::vector<StateVariableSpecification>* state_variables) {
|
||||||
// The first parameter set is always the weights themselves.
|
|
||||||
state_variables->push_back(
|
|
||||||
MakeStandardStateVariableSpecification("parameters", 0.0));
|
|
||||||
// The order of the returned parameters needs to match the offsets used by
|
// The order of the returned parameters needs to match the offsets used by
|
||||||
// the algorithm implementations in test_util.cc and
|
// the algorithm implementations in test_util.cc and
|
||||||
// address_handler_program_creator.cc.
|
// address_handler_program_creator.cc.
|
||||||
|
// The first parameter set is always the weights themselves.
|
||||||
|
auto add_state_variable = [&](const std::string& name, float value) {
|
||||||
|
state_variables->push_back(
|
||||||
|
MakeStandardStateVariableSpecification(name, value));
|
||||||
|
};
|
||||||
switch (params.parameters_case()) {
|
switch (params.parameters_case()) {
|
||||||
case OptimizationAlgorithm::kAdagrad: {
|
case OptimizationAlgorithm::kAdagrad: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
add_state_variable("accumulators", 0.1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kBoundedAdagrad: {
|
case OptimizationAlgorithm::kBoundedAdagrad: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
add_state_variable("accumulators", 0.1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kStochasticGradientDescent: {
|
case OptimizationAlgorithm::kStochasticGradientDescent: {
|
||||||
// None.
|
add_state_variable("parameters", 0.0);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kFtrl: {
|
case OptimizationAlgorithm::kFtrl: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
add_state_variable("accumulators", 0.1);
|
||||||
state_variables->push_back(
|
add_state_variable("linears", 0.0);
|
||||||
MakeStandardStateVariableSpecification("linears", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kAdam: {
|
case OptimizationAlgorithm::kAdam: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("momenta", 0.0));
|
add_state_variable("momenta", 0.0);
|
||||||
state_variables->push_back(
|
add_state_variable("velocities", 0.0);
|
||||||
MakeStandardStateVariableSpecification("velocities", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kMomentum: {
|
case OptimizationAlgorithm::kMomentum: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("momenta", 0.0));
|
add_state_variable("momenta", 0.0);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kRmsProp: {
|
case OptimizationAlgorithm::kRmsProp: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("ms", 1.0));
|
add_state_variable("ms", 1.0);
|
||||||
state_variables->push_back(
|
add_state_variable("mom", 0.0);
|
||||||
MakeStandardStateVariableSpecification("mom", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kCenteredRmsProp: {
|
case OptimizationAlgorithm::kCenteredRmsProp: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("ms", 1.0));
|
add_state_variable("ms", 1.0);
|
||||||
state_variables->push_back(
|
add_state_variable("mom", 0.0);
|
||||||
MakeStandardStateVariableSpecification("mom", 0.0));
|
add_state_variable("mg", 0.0);
|
||||||
state_variables->push_back(
|
|
||||||
MakeStandardStateVariableSpecification("mg", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kMdlAdagradLight: {
|
case OptimizationAlgorithm::kMdlAdagradLight: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
add_state_variable("accumulators", 0.1);
|
||||||
state_variables->push_back(
|
add_state_variable("weights", 0.0);
|
||||||
MakeStandardStateVariableSpecification("weights", 0.0));
|
add_state_variable("benefits", 0.0);
|
||||||
state_variables->push_back(
|
|
||||||
MakeStandardStateVariableSpecification("benefits", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kAdadelta: {
|
case OptimizationAlgorithm::kAdadelta: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.0));
|
add_state_variable("accumulators", 0.0);
|
||||||
state_variables->push_back(
|
add_state_variable("updates", 0.0);
|
||||||
MakeStandardStateVariableSpecification("updates", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kProximalAdagrad: {
|
case OptimizationAlgorithm::kProximalAdagrad: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("accumulators", 0.1));
|
add_state_variable("accumulators", 0.1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kOnlineYogi: {
|
case OptimizationAlgorithm::kOnlineYogi: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("vs", 0.1));
|
add_state_variable("vs", 0.1);
|
||||||
state_variables->push_back(
|
add_state_variable("linears", 0.0);
|
||||||
MakeStandardStateVariableSpecification("linears", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kProximalYogi: {
|
case OptimizationAlgorithm::kProximalYogi: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("v", 0.1));
|
add_state_variable("v", 0.1);
|
||||||
state_variables->push_back(
|
add_state_variable("m", 0.0);
|
||||||
MakeStandardStateVariableSpecification("m", 0.0));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::kFrequencyEstimator: {
|
case OptimizationAlgorithm::kFrequencyEstimator: {
|
||||||
state_variables->push_back(
|
add_state_variable("parameters", 0.0);
|
||||||
MakeStandardStateVariableSpecification("last_hit_step", 0));
|
add_state_variable("last_hit_step", 0);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OptimizationAlgorithm::kUserDefinedProgram: {
|
||||||
|
add_state_variable("parameters",
|
||||||
|
params.user_defined_program().padding_values(0));
|
||||||
|
int num_slots = -1;
|
||||||
|
TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
|
||||||
|
if (num_slots + 1 !=
|
||||||
|
params.user_defined_program().padding_values_size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Number of slots does not agree with the number of padding values "
|
||||||
|
"specified.");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_slots; ++i) {
|
||||||
|
add_state_variable(absl::StrCat("Slot_", i),
|
||||||
|
params.user_defined_program().padding_values(i + 1));
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
|
case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
|
||||||
@ -313,6 +349,7 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
|
|||||||
OptimizationAlgorithm::kOnlineYogi,
|
OptimizationAlgorithm::kOnlineYogi,
|
||||||
OptimizationAlgorithm::kProximalYogi,
|
OptimizationAlgorithm::kProximalYogi,
|
||||||
OptimizationAlgorithm::kFrequencyEstimator,
|
OptimizationAlgorithm::kFrequencyEstimator,
|
||||||
|
OptimizationAlgorithm::kUserDefinedProgram,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user