diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD
index 2cd25b4272e..2ce2b132c64 100644
--- a/tensorflow/core/protobuf/tpu/BUILD
+++ b/tensorflow/core/protobuf/tpu/BUILD
@@ -30,6 +30,7 @@ tf_proto_library(
         "optimization_parameters.proto",
     ],
     cc_api_version = 2,
+    protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"],
     visibility = ["//visibility:public"],
 )
 
diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
index b95574f7827..76c817957df 100644
--- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto
+++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -3,6 +3,7 @@ syntax = "proto3";
 package tensorflow.tpu;
 
 import "google/protobuf/wrappers.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
 
 message ClippingLimits {
   google.protobuf.FloatValue lower = 1;  // -inf if not set
@@ -317,6 +318,34 @@ message FrequencyEstimatorParameters {
   float weight_exponent = 4;
 }
 
+// A user-defined optimizer.
+// The contained HLO program must take the following arguments in the following
+// order:
+// 1.  gradients
+// 2.  table weights
+// 3.  slot variables
+// 4.  an optional scalar input that is passed in via the dynamic learning
+//     rate mechanism.
+//
+// It must return/end in a tuple op that contains the following values in the
+// following order:
+// 1.  new table values
+// 2.  new slot variable value
+//
+// The program must have shape (1,1) with dtype float32 throughout and only use
+// HLO that operate elementwise (e.g., no reduce, no variables, no control flow
+// and no broadcasting outside of the single scalar input).
+// The HLO program should be written as if it were a dense update. It will be
+// called on each row that needs an update and will applied elementwise.
+message UserDefinedProgramParameters {
+  xla.HloModuleProto program = 1;
+  // Padding values for the parameter and the slots, see
+  // StateVariableSpecification.padding_initial_value below for more details on
+  // how this should be set. One value is needed for the weights and one for
+  // each slot.
+  repeated float padding_values = 2;
+}
+
 // Status of using gradient accumulation (doing two passes over the input
 // gradients: one to accumulate them into a temporary array and another to apply
 // them using the actual optimization algorithm). The extra message is to wrap
@@ -395,6 +424,7 @@ message OptimizationParameters {
     OnlineYogiParameters online_yogi = 20;
     ProximalYogiParameters proximal_yogi = 21;
     FrequencyEstimatorParameters frequency_estimator = 23;
+    UserDefinedProgramParameters user_defined_program = 24;
   }
 
   reserved 15;  // Old use_gradient_accumulation.
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 5022ad6228d..03b37ea1918 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -27,6 +27,9 @@ cc_library(
     hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index 46633d22f9d..84d8ea70308 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -15,6 +15,9 @@ limitations under the License.
 
 #include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
 
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
       return "ProximalYogi";
     case OptimizationAlgorithm::kFrequencyEstimator:
       return "FrequencyEstimator";
+    case OptimizationAlgorithm::kUserDefinedProgram:
+      return "UserDefinedProgram";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "*** Not set ***";
   }
@@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
       return "proximal Yogi";
     case OptimizationAlgorithm::kFrequencyEstimator:
       return "frequency estimator";
+    case OptimizationAlgorithm::kUserDefinedProgram:
+      return "UserDefinedProgram";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "unknown (not specified)";
   }
@@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
     case OptimizationAlgorithm::kFrequencyEstimator:
       *count = 1;
       return Status::OK();
+    case OptimizationAlgorithm::kUserDefinedProgram: {
+      const xla::ProgramShapeProto& program_shape =
+          params.user_defined_program().program().host_program_shape();
+
+      const int num_inputs = program_shape.parameters_size();
+      const int num_outputs = program_shape.result().tuple_shapes_size();
+
+      if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
+                               (num_inputs != num_outputs + 2))) {
+        return errors::InvalidArgument(
+            "User-defined TPU embedding optimizer program must have at least "
+            "two inputs and the number of outputs must be 1 or 2 less than the "
+            "number of inputs. Received ",
+            num_inputs, " input(s) and ", num_outputs, "output(s).");
+      }
+
+      *count = num_outputs - 1;
+
+      return Status::OK();
+    }
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return errors::InvalidArgument("No optimization algorithm specified");
   }
@@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification(
 Status GetOptimizationAlgorithmStateVariables(
     const OptimizationParameters& params, bool use_gradient_accumulation,
     std::vector<StateVariableSpecification>* state_variables) {
-  // The first parameter set is always the weights themselves.
-  state_variables->push_back(
-      MakeStandardStateVariableSpecification("parameters", 0.0));
   // The order of the returned parameters needs to match the offsets used by
   // the algorithm implementations in test_util.cc and
   // address_handler_program_creator.cc.
+  // The first parameter set is always the weights themselves.
+  auto add_state_variable = [&](const std::string& name, float value) {
+    state_variables->push_back(
+        MakeStandardStateVariableSpecification(name, value));
+  };
   switch (params.parameters_case()) {
     case OptimizationAlgorithm::kAdagrad: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.1));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.1);
       break;
     }
     case OptimizationAlgorithm::kBoundedAdagrad: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.1));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.1);
       break;
     }
     case OptimizationAlgorithm::kStochasticGradientDescent: {
-      // None.
+      add_state_variable("parameters", 0.0);
       break;
     }
     case OptimizationAlgorithm::kFtrl: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.1));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("linears", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.1);
+      add_state_variable("linears", 0.0);
       break;
     }
     case OptimizationAlgorithm::kAdam: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("momenta", 0.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("velocities", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("momenta", 0.0);
+      add_state_variable("velocities", 0.0);
       break;
     }
     case OptimizationAlgorithm::kMomentum: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("momenta", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("momenta", 0.0);
       break;
     }
     case OptimizationAlgorithm::kRmsProp: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("ms", 1.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("mom", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("ms", 1.0);
+      add_state_variable("mom", 0.0);
       break;
     }
     case OptimizationAlgorithm::kCenteredRmsProp: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("ms", 1.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("mom", 0.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("mg", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("ms", 1.0);
+      add_state_variable("mom", 0.0);
+      add_state_variable("mg", 0.0);
       break;
     }
     case OptimizationAlgorithm::kMdlAdagradLight: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.1));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("weights", 0.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("benefits", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.1);
+      add_state_variable("weights", 0.0);
+      add_state_variable("benefits", 0.0);
       break;
     }
     case OptimizationAlgorithm::kAdadelta: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.0));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("updates", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.0);
+      add_state_variable("updates", 0.0);
       break;
     }
     case OptimizationAlgorithm::kProximalAdagrad: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("accumulators", 0.1));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("accumulators", 0.1);
       break;
     }
     case OptimizationAlgorithm::kOnlineYogi: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("vs", 0.1));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("linears", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("vs", 0.1);
+      add_state_variable("linears", 0.0);
       break;
     }
     case OptimizationAlgorithm::kProximalYogi: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("v", 0.1));
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("m", 0.0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("v", 0.1);
+      add_state_variable("m", 0.0);
       break;
     }
     case OptimizationAlgorithm::kFrequencyEstimator: {
-      state_variables->push_back(
-          MakeStandardStateVariableSpecification("last_hit_step", 0));
+      add_state_variable("parameters", 0.0);
+      add_state_variable("last_hit_step", 0);
+      break;
+    }
+    case OptimizationAlgorithm::kUserDefinedProgram: {
+      add_state_variable("parameters",
+                         params.user_defined_program().padding_values(0));
+      int num_slots = -1;
+      TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
+      if (num_slots + 1 !=
+          params.user_defined_program().padding_values_size()) {
+        return errors::InvalidArgument(
+            "Number of slots does not agree with the number of padding values "
+            "specified.");
+      }
+      for (int i = 0; i < num_slots; ++i) {
+        add_state_variable(absl::StrCat("Slot_", i),
+                           params.user_defined_program().padding_values(i + 1));
+      }
       break;
     }
     case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
@@ -313,6 +349,7 @@ std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
       OptimizationAlgorithm::kOnlineYogi,
       OptimizationAlgorithm::kProximalYogi,
       OptimizationAlgorithm::kFrequencyEstimator,
+      OptimizationAlgorithm::kUserDefinedProgram,
   };
 }