From 84db54720300ba31f6a73a64b95ea466ffa5eef7 Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Tue, 29 Sep 2020 18:46:53 -0700 Subject: [PATCH] Add a set of dynamic embedding optimizers directly taking an HloModule. PiperOrigin-RevId: 334507326 Change-Id: I6cec860e7ce56d77fe1c429945323691928b3c7a --- tensorflow/core/protobuf/tpu/BUILD | 1 + .../tpu/optimization_parameters.proto | 30 ++++ tensorflow/core/tpu/BUILD | 3 + ...embedding_optimization_parameters_utils.cc | 137 +++++++++++------- 4 files changed, 121 insertions(+), 50 deletions(-) diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD index 2cd25b4272e..2ce2b132c64 100644 --- a/tensorflow/core/protobuf/tpu/BUILD +++ b/tensorflow/core/protobuf/tpu/BUILD @@ -30,6 +30,7 @@ tf_proto_library( "optimization_parameters.proto", ], cc_api_version = 2, + protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto index b95574f7827..76c817957df 100644 --- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto +++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package tensorflow.tpu; import "google/protobuf/wrappers.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; message ClippingLimits { google.protobuf.FloatValue lower = 1; // -inf if not set @@ -317,6 +318,34 @@ message FrequencyEstimatorParameters { float weight_exponent = 4; } +// A user-defined optimizer. +// The contained HLO program must take the following arguments in the following +// order: +// 1. gradients +// 2. table weights +// 3. slot variables +// 4. an optional scalar input that is passed in via the dynamic learning +// rate mechanism. +// +// It must return/end in a tuple op that contains the following values in the +// following order: +// 1. new table values +// 2. 
new slot variable value + // The program must have shape (1,1) with dtype float32 throughout and only use + // HLO that operate elementwise (e.g., no reduce, no variables, no control flow + // and no broadcasting outside of the single scalar input). + // The HLO program should be written as if it were a dense update. It will be + // called on each row that needs an update and will be applied elementwise. +message UserDefinedProgramParameters { + xla.HloModuleProto program = 1; + // Padding values for the parameter and the slots, see + // StateVariableSpecification.padding_initial_value below for more details on + // how this should be set. One value is needed for the weights and one for + // each slot. + repeated float padding_values = 2; +} + // Status of using gradient accumulation (doing two passes over the input // gradients: one to accumulate them into a temporary array and another to apply // them using the actual optimization algorithm). The extra message is to wrap @@ -395,6 +424,7 @@ message OptimizationParameters { OnlineYogiParameters online_yogi = 20; ProximalYogiParameters proximal_yogi = 21; FrequencyEstimatorParameters frequency_estimator = 23; + UserDefinedProgramParameters user_defined_program = 24; } reserved 15; // Old use_gradient_accumulation. 
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 5022ad6228d..03b37ea1918 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -27,6 +27,9 @@ cc_library( hdrs = ["tpu_embedding_optimization_parameters_utils.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc index 46633d22f9d..84d8ea70308 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" @@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) { return "ProximalYogi"; case OptimizationAlgorithm::kFrequencyEstimator: return "FrequencyEstimator"; + case OptimizationAlgorithm::kUserDefinedProgram: + return "UserDefinedProgram"; case OptimizationAlgorithm::PARAMETERS_NOT_SET: return "*** Not set ***"; } @@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) { return "proximal Yogi"; case OptimizationAlgorithm::kFrequencyEstimator: return "frequency estimator"; + case OptimizationAlgorithm::kUserDefinedProgram: + return "UserDefinedProgram"; case OptimizationAlgorithm::PARAMETERS_NOT_SET: return "unknown 
(not specified)"; } @@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params, case OptimizationAlgorithm::kFrequencyEstimator: *count = 1; return Status::OK(); + case OptimizationAlgorithm::kUserDefinedProgram: { + const xla::ProgramShapeProto& program_shape = + params.user_defined_program().program().host_program_shape(); + + const int num_inputs = program_shape.parameters_size(); + const int num_outputs = program_shape.result().tuple_shapes_size(); + + if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) && + (num_inputs != num_outputs + 2))) { + return errors::InvalidArgument( + "User-defined TPU embedding optimizer program must have at least " + "two inputs and the number of outputs must be 1 or 2 less than the " + "number of inputs. Received ", + num_inputs, " input(s) and ", num_outputs, " output(s)."); + } + + *count = num_outputs - 1; + + return Status::OK(); + } case OptimizationAlgorithm::PARAMETERS_NOT_SET: return errors::InvalidArgument("No optimization algorithm specified"); } @@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification( Status GetOptimizationAlgorithmStateVariables( const OptimizationParameters& params, bool use_gradient_accumulation, std::vector<StateVariableSpecification>* state_variables) { - // The first parameter set is always the weights themselves. - state_variables->push_back( - MakeStandardStateVariableSpecification("parameters", 0.0)); // The order of the returned parameters needs to match the offsets used by // the algorithm implementations in test_util.cc and // address_handler_program_creator.cc. + // The first parameter set is always the weights themselves. 
+ auto add_state_variable = [&](const std::string& name, float value) { + state_variables->push_back( + MakeStandardStateVariableSpecification(name, value)); + }; switch (params.parameters_case()) { case OptimizationAlgorithm::kAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kBoundedAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kStochasticGradientDescent: { - // None. + add_state_variable("parameters", 0.0); break; } case OptimizationAlgorithm::kFtrl: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("linears", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); + add_state_variable("linears", 0.0); break; } case OptimizationAlgorithm::kAdam: { - state_variables->push_back( - MakeStandardStateVariableSpecification("momenta", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("velocities", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("momenta", 0.0); + add_state_variable("velocities", 0.0); break; } case OptimizationAlgorithm::kMomentum: { - state_variables->push_back( - MakeStandardStateVariableSpecification("momenta", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("momenta", 0.0); break; } case OptimizationAlgorithm::kRmsProp: { - state_variables->push_back( - MakeStandardStateVariableSpecification("ms", 1.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mom", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("ms", 1.0); + 
add_state_variable("mom", 0.0); break; } case OptimizationAlgorithm::kCenteredRmsProp: { - state_variables->push_back( - MakeStandardStateVariableSpecification("ms", 1.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mom", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mg", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("ms", 1.0); + add_state_variable("mom", 0.0); + add_state_variable("mg", 0.0); break; } case OptimizationAlgorithm::kMdlAdagradLight: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("weights", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("benefits", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); + add_state_variable("weights", 0.0); + add_state_variable("benefits", 0.0); break; } case OptimizationAlgorithm::kAdadelta: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("updates", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.0); + add_state_variable("updates", 0.0); break; } case OptimizationAlgorithm::kProximalAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kOnlineYogi: { - state_variables->push_back( - MakeStandardStateVariableSpecification("vs", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("linears", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("vs", 0.1); + add_state_variable("linears", 0.0); break; } case OptimizationAlgorithm::kProximalYogi: { - state_variables->push_back( - 
MakeStandardStateVariableSpecification("v", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("m", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("v", 0.1); + add_state_variable("m", 0.0); break; } case OptimizationAlgorithm::kFrequencyEstimator: { - state_variables->push_back( - MakeStandardStateVariableSpecification("last_hit_step", 0)); + add_state_variable("parameters", 0.0); + add_state_variable("last_hit_step", 0); + break; + } + case OptimizationAlgorithm::kUserDefinedProgram: { + add_state_variable("parameters", + params.user_defined_program().padding_values(0)); + int num_slots = -1; + TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots)); + if (num_slots + 1 != + params.user_defined_program().padding_values_size()) { + return errors::InvalidArgument( + "Number of slots does not agree with the number of padding values " + "specified."); + } + for (int i = 0; i < num_slots; ++i) { + add_state_variable(absl::StrCat("Slot_", i), + params.user_defined_program().padding_values(i + 1)); + } break; } case OptimizationAlgorithm::PARAMETERS_NOT_SET: { @@ -313,6 +349,7 @@ std::vector GetOptimizationAlgorithms() { OptimizationAlgorithm::kOnlineYogi, OptimizationAlgorithm::kProximalYogi, OptimizationAlgorithm::kFrequencyEstimator, + OptimizationAlgorithm::kUserDefinedProgram, }; }