[TF:XLA] Add support for update_slots=False in ApplyAdagrad.
PiperOrigin-RevId: 269524145
commit 5453aee488
parent d5efc0e9fd
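For readers who want to exercise the new attribute, here is a minimal usage sketch in Python. It assumes the standard tf.raw_ops binding for this op; the variable values are illustrative and not part of this commit:

import tensorflow as tf

# Resource variables for a weight and its Adagrad accumulator slot.
var = tf.Variable([1.0, 2.0])
accum = tf.Variable([0.1, 0.1])

# With update_slots=False the accumulator is left untouched and only the
# variable moves: var -= lr * grad * rsqrt(accum).
tf.raw_ops.ResourceApplyAdagrad(
    var=var.handle,
    accum=accum.handle,
    lr=tf.constant(0.1),
    grad=tf.constant([0.5, 0.5]),
    update_slots=False)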
@@ -232,7 +232,9 @@ REGISTER_XLA_OP(
 class ResourceApplyAdagrad : public XlaOpKernel {
  public:
-  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }

   void Compile(XlaOpKernelContext* ctx) override {
     DataType type = ctx->input_type(2);
@@ -261,11 +263,16 @@ class ResourceApplyAdagrad : public XlaOpKernel {
     xla::XlaOp lr = ctx->Input(2);
     xla::XlaOp grad = ctx->Input(3);

-    accum = accum + xla::Square(grad);
+    if (update_slots_) {
+      accum = accum + xla::Square(grad);
+    }
     var = var - grad * lr * xla::Rsqrt(accum);
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
   }
+
+ private:
+  bool update_slots_;
 };
 REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagrad);
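In effect the kernel now computes the following update. A NumPy sketch of the semantics, for reference (a paraphrase of the code above, not a TensorFlow API):

import numpy as np

def resource_apply_adagrad(var, accum, lr, grad, update_slots=True):
    # The accumulator grows only when update_slots is true; the variable
    # step always divides by the square root of its current value.
    if update_slots:
        accum = accum + np.square(grad)
    var = var - grad * lr / np.sqrt(accum)  # xla::Rsqrt(x) is 1/sqrt(x)
    return var, accum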
@@ -273,7 +280,9 @@ REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
 class ResourceApplyAdagradV2 : public XlaOpKernel {
  public:
   explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
-      : XlaOpKernel(ctx) {}
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }

   void Compile(XlaOpKernelContext* ctx) override {
     DataType type = ctx->input_type(2);
@@ -308,11 +317,16 @@ class ResourceApplyAdagradV2 : public XlaOpKernel {
     xla::XlaOp epsilon = ctx->Input(3);
     xla::XlaOp grad = ctx->Input(4);

-    accum = accum + xla::Square(grad);
+    if (update_slots_) {
+      accum = accum + xla::Square(grad);
+    }
     var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
     OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
   }
+
+ private:
+  bool update_slots_;
 };
 REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagradV2);
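The V2 kernel gates the accumulator the same way and differs only in where epsilon enters the step. A sketch in the same style as above:

import numpy as np

def resource_apply_adagrad_v2(var, accum, lr, epsilon, grad, update_slots=True):
    if update_slots:
        accum = accum + np.square(grad)
    # V2 adds epsilon outside the square root instead of using rsqrt directly.
    var = var - grad * lr / (np.sqrt(accum) + epsilon)
    return var, accum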
@@ -2,6 +2,7 @@
 # Optimization routines.

 load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")

 package(
     default_visibility = ["//tensorflow:__subpackages__"],
@@ -210,13 +211,12 @@ py_test(
     ],
 )

-py_test(
+cuda_py_test(
     name = "reg_adagrad_optimizer_test",
     srcs = ["python/training/reg_adagrad_optimizer_test.py"],
-    python_version = "PY2",
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":opt_py",
+        "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
@@ -226,7 +226,6 @@ py_test(
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
-        "//third_party/py/numpy",
     ],
 )