[TF:XLA] Add support for update_slots=False in ApplyAdagrad.

PiperOrigin-RevId: 269524145
This commit is contained in:
Tres Popp 2019-09-17 02:23:32 -07:00 committed by TensorFlower Gardener
parent d5efc0e9fd
commit 5453aee488
2 changed files with 22 additions and 9 deletions

View File

@ -232,7 +232,9 @@ REGISTER_XLA_OP(
class ResourceApplyAdagrad : public XlaOpKernel {
public:
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(2);
@ -261,11 +263,16 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
if (update_slots_) {
accum = accum + xla::Square(grad);
}
var = var - grad * lr * xla::Rsqrt(accum);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
private:
bool update_slots_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyAdagrad);
@ -273,7 +280,9 @@ REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
class ResourceApplyAdagradV2 : public XlaOpKernel {
public:
explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(2);
@ -308,11 +317,16 @@ class ResourceApplyAdagradV2 : public XlaOpKernel {
xla::XlaOp epsilon = ctx->Input(3);
xla::XlaOp grad = ctx->Input(4);
if (update_slots_) {
accum = accum + xla::Square(grad);
}
var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
private:
bool update_slots_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
ResourceApplyAdagradV2);

View File

@ -2,6 +2,7 @@
# Optimization routines.
load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
package(
default_visibility = ["//tensorflow:__subpackages__"],
@ -210,13 +211,12 @@ py_test(
],
)
py_test(
cuda_py_test(
name = "reg_adagrad_optimizer_test",
srcs = ["python/training/reg_adagrad_optimizer_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
additional_deps = [
":opt_py",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@ -226,7 +226,6 @@ py_test(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)