Let adagrad use fused op.
PiperOrigin-RevId: 261015592
This commit is contained in:
parent
3351279836
commit
d5c6687d99
@ -270,6 +270,53 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyAdagrad);
|
||||
|
||||
// XLA kernel for the fused ResourceApplyAdagradV2 op.
//
// Applies the AdagradV2 update (Adagrad with an explicit epsilon term):
//   accum += grad^2
//   var   -= grad * lr / (sqrt(accum) + epsilon)
// `var` and `accum` are read from and written back to resource variables.
class ResourceApplyAdagradV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Element type is taken from the `lr` input (input 2); all numeric
    // inputs of this op share the same T.
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    // lr (input 2) and epsilon (input 3) must be scalars.
    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape epsilon_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp epsilon = ctx->Input(3);
    xla::XlaOp grad = ctx->Input(4);

    // V2 differs from plain Adagrad by adding epsilon to sqrt(accum) in the
    // denominator instead of relying on accum being strictly positive.
    accum = accum + xla::Square(grad);
    var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagradV2);
|
||||
|
||||
class ResourceApplyProximalAdagrad : public XlaOpKernel {
|
||||
public:
|
||||
explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
|
||||
|
@ -57,6 +57,7 @@ CreateResourceOpInfoMap() {
|
||||
add("ResourceApplyAdaMax" , kReadWrite, kVariable);
|
||||
add("ResourceApplyAdadelta" , kReadWrite, kVariable);
|
||||
add("ResourceApplyAdagrad" , kReadWrite, kVariable);
|
||||
add("ResourceApplyAdagradV2" , kReadWrite, kVariable),
|
||||
add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
|
||||
add("ResourceApplyAdam" , kReadWrite, kVariable);
|
||||
add("ResourceApplyAddSign" , kReadWrite, kVariable);
|
||||
|
@ -0,0 +1,53 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAdagradV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Constant factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
|
||||
END
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
op {
|
||||
graph_op_name: "ResourceApplyAdagradV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Constant factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
|
||||
END
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
op {
|
||||
graph_op_name: "ResourceSparseApplyAdagradV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Learning rate. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Constant factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A vector of indices into the first dimension of var and accum.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update relevant entries in \'*var\' and \'*accum\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
That is for rows we have grad for, we update var and accum as follows:
|
||||
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
|
||||
END
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
op {
|
||||
graph_op_name: "SparseApplyAdagradV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Learning rate. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Constant factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A vector of indices into the first dimension of var and accum.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update relevant entries in \'*var\' and \'*accum\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
That is for rows we have grad for, we update var and accum as follows:
|
||||
$$accum += grad * grad$$
$$var -= lr * grad * (1 / (sqrt(accum) + epsilon))$$
|
||||
END
|
||||
}
|
@ -161,6 +161,20 @@ struct ApplyAdagrad<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
// CPU specialization of the fused AdagradV2 update:
//   accum += grad^2                       (skipped when !update_slots)
//   var   -= grad * lr / (sqrt(accum) + epsilon)
template <typename T>
struct ApplyAdagradV2<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
    // update_slots=false freezes the accumulator; only var is updated.
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    // Shape agreement between var/accum/grad is enforced by the op wrapper,
    // not here.
    var.device(d) -= grad * lr() / (accum.sqrt() + epsilon());
  }
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyProximalAdagrad<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -1264,6 +1278,106 @@ REGISTER_KERNELS(GPU, double);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
// Dense kernel backing ApplyAdagradV2 (ref variables) and
// ResourceApplyAdagradV2 (resource variables). Validates inputs, then
// delegates the arithmetic to functor::ApplyAdagradV2.
template <typename Device, typename T>
class ApplyAdagradV2Op : public OpKernel {
 public:
  explicit ApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compute(OpKernelContext* ctx) override {
    const bool sparse = false;
    // Locks var (input 0) and accum (input 1) in a canonical order when
    // use_locking is set; `locks` releases them on scope exit.
    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
        ctx, use_exclusive_lock_, sparse, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, sparse, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, sparse, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& epsilon = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));
    const Tensor& grad = ctx->input(4);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyAdagradV2<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                         lr.scalar<T>(), epsilon.scalar<T>(),
                                         grad.flat<T>(), update_slots_);

    // For the ref-variable variant, forwards input 0 to output 0 ("out").
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;  // value of the "use_locking" attr
  bool update_slots_;        // value of the "update_slots" attr
};
|
||||
|
||||
// Registers the dense AdagradV2 kernels (ref and resource variants) for
// device D and element type T. The resource variant pins the variable
// handles to host memory.
#define REGISTER_KERNELS(D, T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("ApplyAdagradV2").Device(DEVICE_##D).TypeConstraint<T>("T"),   \
      ApplyAdagradV2Op<D##Device, T>);                                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradV2")                  \
                              .HostMemory("var")                          \
                              .HostMemory("accum")                        \
                              .Device(DEVICE_##D)                         \
                              .TypeConstraint<T>("T"),                    \
                          ApplyAdagradV2Op<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void ApplyAdagradV2<GPUDevice, T>::operator()(                          \
      const GPUDevice& d, typename TTypes<T>::Flat var,                   \
      typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
      typename TTypes<T>::ConstScalar epsilon,                            \
      typename TTypes<T>::ConstFlat grad, bool update_slots);             \
  extern template struct ApplyAdagradV2<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ApplyProximalAdagradOp : public OpKernel {
|
||||
public:
|
||||
@ -1530,6 +1644,179 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
// Note, this op works on cpu only.
// Sparse AdagradV2 kernel: for each i in `indices`, updates row
// var[index] / accum[index] with
//   accum += grad^2                      (skipped when update_slots=false)
//   var   -= lr * grad / (sqrt(accum) + epsilon)
// Backs both SparseApplyAdagradV2 and ResourceSparseApplyAdagradV2.
template <typename T, typename Tindex>
class SparseApplyAdagradV2Op : public OpKernel {
 public:
  explicit SparseApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    const bool sparse = true;
    auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
        ctx, use_exclusive_lock_, sparse, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, sparse, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 1, use_exclusive_lock_, sparse, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& epsilon = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));
    const Tensor& grad = ctx->input(4);
    const Tensor& indices = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    // All trailing dimensions of var and grad must agree; inner_dim is the
    // flattened size of one row.
    int64 inner_dim = 1;
    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(strings::StrCat(
                      "var and grad must match in dimension ", d)));
      inner_dim *= grad.dim_size(d);
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    // This op is implemented only for CPU device.
    const auto& d = ctx->eigen_cpu_device();

    if (N > 0) {
      // Rough per-row cost estimate used to size the parallel shards below.
      const int in_bytes = inner_dim * sizeof(T) * 3;
      const int out_bytes = inner_dim * sizeof(T) * 2;
      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
                                      Eigen::TensorOpCost::MulCost<T>() * 2);
      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);

      if (inner_dim > 1) {
        // Matrix case: each index selects a whole row; use Eigen chips.
        const Tindex first_dim_size = var.dim_size(0);
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat_outer_dims<T>();
        auto accum_flat = accum.flat_outer_dims<T>();
        auto grad_flat = grad.flat_outer_dims<T>();
        const T lr_scalar = lr.scalar<T>()();
        const T epsilon_scalar = epsilon.scalar<T>()();

        // Validate all indices up front so the parallel shards can assume
        // they are in range (OP_REQUIRES cannot abort from inside a shard).
        for (Tindex i = 0; i < N; ++i) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
        }

        // NOTE(review): shards partition i-ranges, not index-ranges —
        // repeated indices across shards would race; presumably acceptable
        // for this op's semantics (confirm against the dense-duplicate
        // handling in callers).
        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
          for (Tindex i = start_idx; i < end_idx; ++i) {
            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
            auto a = accum_flat.template chip<0>(index);
            auto g = grad_flat.template chip<0>(i);
            auto v = var_flat.template chip<0>(index);
            if (update_slots_) {
              a += g.square();
            }
            v -= g.constant(lr_scalar) * g /
                 (a.sqrt() + a.constant(epsilon_scalar));
          }
        };

        d.parallelFor(N, cost, shard);

      } else {
        // Vector case: each row is a single scalar; update elementwise.
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto accum_flat = accum.flat<T>();
        auto grad_flat = grad.flat<T>();
        T lr_scalar = lr.scalar<T>()();
        const T epsilon_scalar = epsilon.scalar<T>()();
        const Tindex first_dim_size = accum_flat.size();

        // Up-front bounds check, as in the matrix case.
        for (Tindex i = 0; i < N; ++i) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
        }

        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
          for (Tindex i = start_idx; i < end_idx; ++i) {
            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
            T& a = accum_flat(index);
            const T& g = grad_flat(i);
            if (update_slots_) {
              a += g * g;
            }
            var_flat(index) -=
                lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon_scalar);
          }
        };

        d.parallelFor(N, cost, shard);
      }
    }

    // For the ref-variable variant, forwards input 0 to output 0 ("out").
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;  // value of the "use_locking" attr
  bool update_slots_;        // value of the "update_slots" attr
};
|
||||
|
||||
// Registers the CPU-only sparse AdagradV2 kernels (ref and resource
// variants) for element type T and index type Tindices.
#define REGISTER_KERNELS(T, Tindices)                                \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2")               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdagradV2Op<T, Tindices>);      \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2")       \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdagradV2Op<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(T, int32);   \
  REGISTER_KERNELS(T, int64);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
|
||||
|
||||
// Note, this op works on cpu only.
|
||||
template <typename T, typename Tindex>
|
||||
class SparseApplyProximalAdagradOp : public OpKernel {
|
||||
|
@ -71,6 +71,15 @@ struct ApplyAdagrad {
|
||||
typename TTypes<T>::ConstFlat grad, bool update_slots);
|
||||
};
|
||||
|
||||
// Functor computing the fused AdagradV2 update on device `d`:
//   accum += grad^2                      (skipped when !update_slots)
//   var   -= grad * lr / (sqrt(accum) + epsilon)
// Specialized per device in the corresponding .cc/.cu.cc files.
template <typename Device, typename T>
struct ApplyAdagradV2 {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots);
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct ApplyAdagradDA {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat var,
|
||||
|
@ -53,6 +53,25 @@ struct ApplyAdagrad<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
// GPU specialization of the fused AdagradV2 update. Scalars lr/epsilon are
// reshaped to rank-1 and broadcast to grad's length so the whole update is
// a single elementwise Eigen expression on the device.
template <typename T>
struct ApplyAdagradV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    // update_slots=false freezes the accumulator; only var is updated.
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    const auto update =
        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
  }
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdadelta<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -348,6 +367,10 @@ template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
|
||||
template struct functor::ApplyAdagrad<GPUDevice, float>;
|
||||
template struct functor::ApplyAdagrad<GPUDevice, double>;
|
||||
|
||||
// Explicit instantiations of the GPU AdagradV2 functor for the supported
// floating-point types (matching the forward declarations in training_ops.cc).
template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
|
||||
|
||||
template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
|
||||
template struct functor::ApplyAdadelta<GPUDevice, float>;
|
||||
template struct functor::ApplyAdadelta<GPUDevice, double>;
|
||||
|
@ -245,6 +245,20 @@ static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Shape function shared by the (Sparse/Resource)ApplyAdagradV2 ops:
// var and accum must merge, lr and epsilon must be scalars, and grad (plus
// indices when `sparse`) is checked against var's shape. The merged shape
// is propagated to output 0 when the op has one (the ref variants' "out").
static Status ApplyAdagradV2ShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}
|
||||
|
||||
REGISTER_OP("ApplyAdagrad")
|
||||
.Input("var: Ref(T)")
|
||||
.Input("accum: Ref(T)")
|
||||
@ -270,6 +284,33 @@ REGISTER_OP("ResourceApplyAdagrad")
|
||||
return ApplyAdagradShapeFn(c, false /* sparse */);
|
||||
});
|
||||
|
||||
// Ref-variable variant of the fused AdagradV2 update; returns the updated
// var through "out".
REGISTER_OP("ApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradV2ShapeFn(c, false /* sparse */);
    });

// Resource-variable variant: no output; the update is applied through the
// var/accum handles.
REGISTER_OP("ResourceApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradV2ShapeFn(c, false /* sparse */);
    });
|
||||
|
||||
static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
|
||||
ShapeHandle unused;
|
||||
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
|
||||
@ -341,6 +382,37 @@ REGISTER_OP("ResourceSparseApplyAdagrad")
|
||||
return ApplyAdagradShapeFn(c, true /* sparse */);
|
||||
});
|
||||
|
||||
// Sparse ref-variable variant: updates only the rows of var/accum selected
// by "indices"; returns the updated var through "out".
REGISTER_OP("SparseApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradV2ShapeFn(c, true /* sparse */);
    });

// Sparse resource-variable variant: no output; updates happen through the
// var/accum handles.
REGISTER_OP("ResourceSparseApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradV2ShapeFn(c, true /* sparse */);
    });
|
||||
|
||||
static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
|
||||
ShapeHandle unused;
|
||||
ShapeHandle s = ShapeOrHandleShape(c, 0); // var
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
@ -29,6 +30,7 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -152,6 +154,15 @@ class Adagrad(optimizer_v2.OptimizerV2):
|
||||
or self._fallback_apply_state(var_device, var_dtype))
|
||||
|
||||
acc = self.get_slot(var, 'accumulator')
|
||||
if compat.forward_compatible(2019, 8, 20):
|
||||
return training_ops.resource_apply_adagrad_v2(
|
||||
var.handle,
|
||||
acc.handle,
|
||||
coefficients['lr_t'],
|
||||
coefficients['epsilon'],
|
||||
grad,
|
||||
use_locking=self._use_locking)
|
||||
|
||||
acc_t = state_ops.assign_add(
|
||||
acc, math_ops.square(grad), use_locking=self._use_locking)
|
||||
var_update = state_ops.assign_sub(
|
||||
@ -165,6 +176,15 @@ class Adagrad(optimizer_v2.OptimizerV2):
|
||||
or self._fallback_apply_state(var_device, var_dtype))
|
||||
|
||||
acc = self.get_slot(var, 'accumulator')
|
||||
if compat.forward_compatible(2019, 8, 20):
|
||||
return training_ops.resource_sparse_apply_adagrad_v2(
|
||||
var.handle,
|
||||
acc.handle,
|
||||
coefficients['lr_t'],
|
||||
coefficients['epsilon'],
|
||||
grad,
|
||||
indices,
|
||||
use_locking=self._use_locking)
|
||||
with ops.control_dependencies([
|
||||
resource_variable_ops.resource_scatter_add(acc.handle, indices,
|
||||
math_ops.square(grad))
|
||||
|
@ -161,6 +161,47 @@ class AdagradOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
def testBasicWithLargeEpsilon(self):
    """Dense Adagrad with epsilon=1.0 must match the numpy reference.

    Exercises the epsilon input of the fused AdagradV2 path with a value
    large enough to visibly change the update, in both graph and eager mode.
    """
    with self.cached_session():
      var0_np = np.array([1.0, 2.0])
      var1_np = np.array([3.0, 4.0])
      grads0_np = np.array([0.1, 0.1])
      grads1_np = np.array([0.01, 0.01])
      var0 = resource_variable_ops.ResourceVariable(var0_np)
      var1 = resource_variable_ops.ResourceVariable(var1_np)
      grads0 = constant_op.constant(grads0_np)
      grads1 = constant_op.constant(grads1_np)

      learning_rate = 3.0

      ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.0)

      # Reference accumulators start at 0.1 — presumably matching the
      # optimizer's default slot initialization (confirm against adagrad.py).
      accum0_np = np.array([0.1, 0.1])
      accum1_np = np.array([0.1, 0.1])

      if not context.executing_eagerly():
        ada_update = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

      # Fetch params to validate initial values
      v0_val, v1_val = self.evaluate([var0, var1])
      self.assertAllClose([1.0, 2.0], v0_val)
      self.assertAllClose([3.0, 4.0], v1_val)

      # Run 3 steps of adagrad
      for _ in range(3):
        if not context.executing_eagerly():
          self.evaluate(ada_update)
        else:
          ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        # Step the numpy reference (lr=3.0, epsilon=1.0) and compare.
        var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np,
                                                  3.0, 1.0)
        var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np,
                                                  3.0, 1.0)
        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
def testBasicWithLearningRateInverseTimeDecay(self):
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
@ -308,6 +349,41 @@ class AdagradOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
def testSparseSingleVarDim(self):
    """Sparse Adagrad on a one-element variable must match the numpy reference.

    Covers the sparse kernel's inner_dim == 1 (vector) code path with
    epsilon=1.0, across half/float32/float64.
    """
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0_np = np.array([1.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        grads0_np_indices = np.array([0], dtype=np.int32)
        # NOTE(review): dense_shape is [3] while var0 has shape [1] —
        # looks unused by this update path, but confirm it is intentional.
        grads0 = ops.IndexedSlices(
            constant_op.constant(grads0_np[grads0_np_indices]),
            constant_op.constant(grads0_np_indices), constant_op.constant([3]))
        learning_rate = 3.0
        ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.)
        ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values
        self.assertAllClose([1.0], var0.eval())

        # Reference accumulator starts at 0.1 — presumably matching the
        # optimizer's default slot initialization (confirm against adagrad.py).
        accum0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)

        # Run 3 step of sgd
        for _ in range(3):
          ada_update.run()

          var0_np, accum0_np = sparse_adagrad_update_numpy(
              var0_np,
              accum0_np,
              grads0_np_indices,
              grads0_np[grads0_np_indices],
              learning_rate,
              epsilon=1.)
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSparseRepeatedIndices(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
|
@ -124,6 +124,10 @@ tf_module {
|
||||
name: "ApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ApplyAdam"
|
||||
argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
@ -3160,6 +3164,10 @@ tf_module {
|
||||
name: "ResourceApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceApplyAdam"
|
||||
argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
@ -3280,6 +3288,10 @@ tf_module {
|
||||
name: "ResourceSparseApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceSparseApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceSparseApplyCenteredRMSProp"
|
||||
argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
@ -3760,6 +3772,10 @@ tf_module {
|
||||
name: "SparseApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseApplyCenteredRMSProp"
|
||||
argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
|
@ -124,6 +124,10 @@ tf_module {
|
||||
name: "ApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ApplyAdam"
|
||||
argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
@ -3160,6 +3164,10 @@ tf_module {
|
||||
name: "ResourceApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceApplyAdam"
|
||||
argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
@ -3280,6 +3288,10 @@ tf_module {
|
||||
name: "ResourceSparseApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceSparseApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceSparseApplyCenteredRMSProp"
|
||||
argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
@ -3760,6 +3772,10 @@ tf_module {
|
||||
name: "SparseApplyAdagradDA"
|
||||
argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseApplyAdagradV2"
|
||||
argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseApplyCenteredRMSProp"
|
||||
argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user