[TF:XLA] Add support for update_slots=False in ApplyAdagrad.
PiperOrigin-RevId: 269524145
This commit is contained in:
parent
d5efc0e9fd
commit
5453aee488
@ -232,7 +232,9 @@ REGISTER_XLA_OP(
|
||||
|
||||
class ResourceApplyAdagrad : public XlaOpKernel {
|
||||
public:
|
||||
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType type = ctx->input_type(2);
|
||||
@ -261,11 +263,16 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
||||
xla::XlaOp lr = ctx->Input(2);
|
||||
xla::XlaOp grad = ctx->Input(3);
|
||||
|
||||
if (update_slots_) {
|
||||
accum = accum + xla::Square(grad);
|
||||
}
|
||||
var = var - grad * lr * xla::Rsqrt(accum);
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
|
||||
}
|
||||
|
||||
private:
|
||||
bool update_slots_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyAdagrad);
|
||||
@ -273,7 +280,9 @@ REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
|
||||
class ResourceApplyAdagradV2 : public XlaOpKernel {
|
||||
public:
|
||||
explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {}
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType type = ctx->input_type(2);
|
||||
@ -308,11 +317,16 @@ class ResourceApplyAdagradV2 : public XlaOpKernel {
|
||||
xla::XlaOp epsilon = ctx->Input(3);
|
||||
xla::XlaOp grad = ctx->Input(4);
|
||||
|
||||
if (update_slots_) {
|
||||
accum = accum + xla::Square(grad);
|
||||
}
|
||||
var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
|
||||
}
|
||||
|
||||
private:
|
||||
bool update_slots_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyAdagradV2);
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Optimization routines.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -210,13 +211,12 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "reg_adagrad_optimizer_test",
|
||||
srcs = ["python/training/reg_adagrad_optimizer_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
additional_deps = [
|
||||
":opt_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -226,7 +226,6 @@ py_test(
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user