diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 247db8d5d17..191ce9dee2b 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -270,6 +270,53 @@ class ResourceApplyAdagrad : public XlaOpKernel {
 REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagrad);
 
+class ResourceApplyAdagradV2 : public XlaOpKernel {
+ public:
+  explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(2);
+
+    TensorShape var_shape, accum_shape;
+    xla::XlaOp var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    TensorShape lr_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    TensorShape epsilon_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon_shape.DebugString()));
+
+    TensorShape grad_shape = ctx->InputShape(4);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp epsilon = ctx->Input(3);
+    xla::XlaOp grad = ctx->Input(4);
+
+    accum = accum + xla::Square(grad);
+    var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
+  }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdagradV2);
+
 class ResourceApplyProximalAdagrad : public XlaOpKernel {
  public:
   explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 1243e31a047..2db431c0413 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -57,6 +57,7 @@ CreateResourceOpInfoMap() {
   add("ResourceApplyAdaMax" , kReadWrite, kVariable);
   add("ResourceApplyAdadelta" , kReadWrite, kVariable);
   add("ResourceApplyAdagrad" , kReadWrite, kVariable);
+  add("ResourceApplyAdagradV2" , kReadWrite, kVariable);
   add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
   add("ResourceApplyAdam" , kReadWrite, kVariable);
   add("ResourceApplyAddSign" , kReadWrite, kVariable);
diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt
new file mode 100644
index 00000000000..07366bfd367
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt
@@ -0,0 +1,53 @@
+op {
+  graph_op_name: "ApplyAdagradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "var"
+    description: <
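
Note: for reference, the element-wise update that the new ResourceApplyAdagradV2 kernel expresses with XLA ops is, per element, accum += grad^2 followed by var -= lr * grad / (sqrt(accum) + epsilon). Below is a minimal standalone C++ sketch of that step, not part of this change; the helper name AdagradV2Step is hypothetical and used only for illustration.

    #include <cmath>
    #include <cstddef>

    // One AdagradV2 step applied in place. var, accum, and grad are assumed to
    // have the same length n; the XLA kernel above enforces the analogous
    // shape constraints with OP_REQUIRES.
    void AdagradV2Step(float* var, float* accum, const float* grad, float lr,
                       float epsilon, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        accum[i] += grad[i] * grad[i];                            // accumulate squared gradient
        var[i] -= grad[i] * lr / (std::sqrt(accum[i]) + epsilon);  // scaled update
      }
    }

Unlike the existing ResourceApplyAdagrad kernel, the V2 op takes epsilon as an explicit scalar input (input index 3) and adds it to sqrt(accum) before dividing, which is why the kernel validates epsilon_shape alongside lr_shape.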