[TF:XLA] Add support for update_slots=False in ApplyAdagrad.

PiperOrigin-RevId: 269524145
Author: Tres Popp, 2019-09-17 02:23:32 -07:00 (committed by TensorFlower Gardener)
Parent: d5efc0e9fd
Commit: 5453aee488
2 changed files with 22 additions and 9 deletions
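
In short: the XLA kernels for ResourceApplyAdagrad and ResourceApplyAdagradV2 now read the op's update_slots attribute and, when it is false, skip folding grad^2 into the accumulator, so only the variable is updated. A usage sketch from Python (values and variable names are mine; run eagerly like this, the call goes through the regular TF kernel, and wrapping it in an XLA-compiled tf.function is what would route it through the kernels patched below):

import tensorflow as tf

var = tf.Variable([1.0, 2.0])
accum = tf.Variable([4.0, 4.0])

# With update_slots=False the accumulator is frozen: only var moves.
tf.raw_ops.ResourceApplyAdagrad(
    var=var.handle, accum=accum.handle,
    lr=tf.constant(0.1), grad=tf.constant([0.5, 0.5]),
    use_locking=False, update_slots=False)

print(accum.numpy())  # still [4. 4.]: no grad^2 was accumulated
print(var.numpy())    # [0.975 1.975]: var -= 0.1 * 0.5 / sqrt(4.0)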

First changed file: the C++ XLA kernel implementations.

@@ -232,7 +232,9 @@ REGISTER_XLA_OP(
 class ResourceApplyAdagrad : public XlaOpKernel {
  public:
-  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }
 
   void Compile(XlaOpKernelContext* ctx) override {
     DataType type = ctx->input_type(2);
@@ -261,11 +263,16 @@ class ResourceApplyAdagrad : public XlaOpKernel {
     xla::XlaOp lr = ctx->Input(2);
     xla::XlaOp grad = ctx->Input(3);
 
-    accum = accum + xla::Square(grad);
+    if (update_slots_) {
+      accum = accum + xla::Square(grad);
+    }
     var = var - grad * lr * xla::Rsqrt(accum);
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
   }
+
+ private:
+  bool update_slots_;
 };
 
 REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagrad);
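
Read as straight-line math, the patched Compile body above reduces to the following (a NumPy sketch for illustration; the function name and signature are mine, not the kernel's):

import numpy as np

def resource_apply_adagrad(var, accum, lr, grad, update_slots=True):
    # Optionally fold grad^2 into the accumulator, then step the variable
    # by lr * grad * rsqrt(accum), mirroring the XLA ops above.
    if update_slots:
        accum = accum + np.square(grad)
    var = var - grad * lr / np.sqrt(accum)  # == grad * lr * Rsqrt(accum)
    return var, accum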
@@ -273,7 +280,9 @@ REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
 class ResourceApplyAdagradV2 : public XlaOpKernel {
  public:
   explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
-      : XlaOpKernel(ctx) {}
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }
 
   void Compile(XlaOpKernelContext* ctx) override {
     DataType type = ctx->input_type(2);
@@ -308,11 +317,16 @@ class ResourceApplyAdagradV2 : public XlaOpKernel {
     xla::XlaOp epsilon = ctx->Input(3);
     xla::XlaOp grad = ctx->Input(4);
 
-    accum = accum + xla::Square(grad);
+    if (update_slots_) {
+      accum = accum + xla::Square(grad);
+    }
     var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
   }
+
+ private:
+  bool update_slots_;
 };
 
 REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagradV2);
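
ResourceApplyAdagradV2 gets the identical skip logic; the only difference is the denominator, sqrt(accum) + epsilon instead of rsqrt(accum). In the same illustrative NumPy shorthand:

import numpy as np

def resource_apply_adagrad_v2(var, accum, lr, epsilon, grad, update_slots=True):
    if update_slots:
        accum = accum + np.square(grad)
    # epsilon keeps the denominator away from zero, e.g. when
    # update_slots=False leaves accum at a zero initial value.
    var = var - grad * lr / (np.sqrt(accum) + epsilon)
    return var, accum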

Second changed file: the Bazel BUILD file that declares the affected test target.

@@ -2,6 +2,7 @@
 # Optimization routines.
 
 load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
     default_visibility = ["//tensorflow:__subpackages__"],
@@ -210,13 +211,12 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "reg_adagrad_optimizer_test",
     srcs = ["python/training/reg_adagrad_optimizer_test.py"],
-    python_version = "PY2",
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":opt_py",
+        "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
@@ -226,7 +226,6 @@ py_test(
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
-        "//third_party/py/numpy",
     ],
 )
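
The build change is mechanical but worth a note: cuda_py_test registers both CPU and GPU variants of the test, and the 2019-era macro in tensorflow.bzl takes its dependency list through additional_deps rather than deps, which is why python_version and srcs_version drop out and //third_party/py/numpy shifts position in the list. With a GPU-enabled build, something like

bazel test --config=cuda //<package>:reg_adagrad_optimizer_test

would run the GPU variant (the package path is not captured in this view).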