Add Adadelta optimizer
This commit is contained in:
parent
c1c4cbf144
commit
cbfe3a0f5a
BIN
cuda-repo-ubuntu1404_7.5-18_amd64.deb
Normal file
BIN
cuda-repo-ubuntu1404_7.5-18_amd64.deb
Normal file
Binary file not shown.
@ -36,6 +36,22 @@ struct ApplyGradientDescent<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdadelta<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::Flat accum_update,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstScalar rho,
|
||||
typename TTypes<T>::ConstScalar epsilon,
|
||||
typename TTypes<T>::ConstFlat grad) {
|
||||
accum.device(d) = accum * rho() + grad.square() * (1 - rho());
|
||||
const auto update = accum_update * (accum + epsilon()).rsqrt() * grad;
|
||||
accum_update.device(d) = accum_update * rho() + update.square() * (1 - rho());
|
||||
var.device(d) -= update * lr();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdagrad<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -224,6 +240,266 @@ REGISTER_KERNELS(GPU, double);
|
||||
#endif
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ApplyAdadeltaOp : public OpKernel {
|
||||
public:
|
||||
explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
if (use_exclusive_lock_) {
|
||||
mutex_lock l1(*ctx->input_ref_mutex(0));
|
||||
// Don't try to acquire a lock on the second ref as they share the same
|
||||
// mutex.
|
||||
//
|
||||
// mutex_lock l2(*ctx->input_ref_mutex(1));
|
||||
DoValidate(ctx);
|
||||
if (!ctx->status().ok()) return;
|
||||
DoCompute(ctx);
|
||||
} else {
|
||||
DoValidate(ctx);
|
||||
if (!ctx->status().ok()) return;
|
||||
DoCompute(ctx);
|
||||
}
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
|
||||
void DoValidate(OpKernelContext* ctx) {
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(0)));
|
||||
OP_REQUIRES(
|
||||
ctx, accum.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(1)));
|
||||
OP_REQUIRES(
|
||||
ctx, accum_update.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(2)));
|
||||
|
||||
const Tensor& lr = ctx->input(3);
|
||||
const Tensor& rho = ctx->input(4);
|
||||
const Tensor& epsilon = ctx->input(5);
|
||||
const Tensor& grad = ctx->input(6);
|
||||
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
|
||||
errors::InvalidArgument("lr is not a scalar: ",
|
||||
lr.shape().DebugString()));
|
||||
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
|
||||
errors::InvalidArgument("rho is not a scalar: ",
|
||||
rho.shape().DebugString()));
|
||||
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
|
||||
errors::InvalidArgument("epsilon is not a scalar: ",
|
||||
epsilon.shape().DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.shape().IsSameSize(accum.shape()),
|
||||
errors::InvalidArgument("var and accum do not have the same shape",
|
||||
var.shape().DebugString(), " ",
|
||||
accum.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, var.shape().IsSameSize(grad.shape()),
|
||||
errors::InvalidArgument("var and grad do not have the same shape",
|
||||
var.shape().DebugString(), " ",
|
||||
grad.shape().DebugString()));
|
||||
}
|
||||
|
||||
void DoCompute(OpKernelContext* ctx) {
|
||||
const Device& device = ctx->template eigen_device<Device>();
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
|
||||
const Tensor& lr = ctx->input(3);
|
||||
const Tensor& rho = ctx->input(4);
|
||||
const Tensor& epsilon = ctx->input(5);
|
||||
const Tensor& grad = ctx->input(6);
|
||||
|
||||
functor::ApplyAdadelta<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
|
||||
accum_update.flat<T>(), lr.scalar<T>(),
|
||||
rho.scalar<T>(), epsilon.scalar<T>(),
|
||||
grad.flat<T>());
|
||||
}
|
||||
};
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdadeltaOp<D##Device, T>);
|
||||
|
||||
REGISTER_KERNELS(CPU, float);
|
||||
REGISTER_KERNELS(CPU, double);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void ApplyAdadelta<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::Flat var, \
|
||||
typename TTypes<T>::Flat accum, \
|
||||
typename TTypes<T>::Flat accum_update, \
|
||||
typename TTypes<T>::ConstScalar lr, \
|
||||
typename TTypes<T>::ConstScalar rho, \
|
||||
typename TTypes<T>::ConstScalar epsilon, \
|
||||
typename TTypes<T>::ConstFlat grad); \
|
||||
extern template struct ApplyAdadelta<GPUDevice, T>;
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
REGISTER_KERNELS(GPU, float);
|
||||
REGISTER_KERNELS(GPU, double);
|
||||
#endif
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
// Note, this op works on cpu only.
|
||||
template <typename T, typename Tindex>
|
||||
class SparseApplyAdadeltaOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
mutex* mu_var = ctx->input_ref_mutex(0);
|
||||
// mu_accum is actually the same mutex as mu_var since currently we use a
|
||||
// global mutex.
|
||||
//
|
||||
// mutex* mu_accum = ctx->input_ref_mutex(1);
|
||||
if (use_exclusive_lock_) {
|
||||
mu_var->lock();
|
||||
}
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum_grad = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(0)));
|
||||
OP_REQUIRES(
|
||||
ctx, accum_grad.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(1)));
|
||||
OP_REQUIRES(
|
||||
ctx, accum_update.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized variables: ", def().input(2)));
|
||||
OP_REQUIRES(
|
||||
ctx, var.shape().IsSameSize(accum_grad.shape()),
|
||||
errors::InvalidArgument("var and accum_grad do not have the same shape",
|
||||
var.shape().DebugString(), " ",
|
||||
accum_grad.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, var.shape().IsSameSize(accum_update.shape()),
|
||||
errors::InvalidArgument("var and accum_update do not have the same shape",
|
||||
var.shape().DebugString(), " ",
|
||||
accum_update.shape().DebugString()));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
|
||||
errors::InvalidArgument("var must be at least 1 dimensional"));
|
||||
|
||||
const Tensor& lr = ctx->input(3);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
|
||||
errors::InvalidArgument("lr is not a scalar: ",
|
||||
lr.shape().DebugString()));
|
||||
const Tensor& rho = ctx->input(4);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
|
||||
errors::InvalidArgument("rho is not a scalar: ",
|
||||
rho.shape().DebugString()));
|
||||
const Tensor& epsilon = ctx->input(5);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
|
||||
errors::InvalidArgument("epsilon is not a scalar: ",
|
||||
epsilon.shape().DebugString()));
|
||||
const Tensor& grad = ctx->input(6);
|
||||
const Tensor& indices = ctx->input(7);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
|
||||
errors::InvalidArgument("indices must be one-dimensional"));
|
||||
|
||||
for (int d = 1; d < var.dims(); d++) {
|
||||
OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"var and grad must match in dimension ", d)));
|
||||
}
|
||||
const Tindex N = indices.dim_size(0);
|
||||
OP_REQUIRES(
|
||||
ctx, grad.dim_size(0) == N,
|
||||
errors::InvalidArgument(
|
||||
"grad must be the same size as indices in the first dimension."));
|
||||
|
||||
if (N > 0) {
|
||||
const Tindex first_dim_size = var.dim_size(0);
|
||||
// Validate all the indices are in range
|
||||
auto indices_vec = indices.vec<Tindex>();
|
||||
for (Tindex i = 0; i < N; i++) {
|
||||
const Tindex index = indices_vec(i);
|
||||
OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
|
||||
errors::InvalidArgument(
|
||||
strings::StrCat("Index ", index, " at offset ", i,
|
||||
" in indices is out of range")));
|
||||
}
|
||||
|
||||
auto var_flat = var.flat_outer_dims<T>();
|
||||
auto accum_grad_flat = accum_grad.flat_outer_dims<T>();
|
||||
auto accum_update_flat = accum_update.flat_outer_dims<T>();
|
||||
auto grad_flat = grad.flat_outer_dims<T>();
|
||||
const T lr_scalar = lr.scalar<T>()();
|
||||
const T rho_scalar = rho.scalar<T>()();
|
||||
const T epsilon_scalar = epsilon.scalar<T>()();
|
||||
|
||||
for (Tindex i = 0; i < N; i++) {
|
||||
const Tindex index = indices_vec(i);
|
||||
auto accum_ = accum_grad_flat.template chip<0>(index);
|
||||
auto accum_update_ = accum_update_flat.template chip<0>(index);
|
||||
auto grad_ = grad_flat.template chip<0>(i);
|
||||
|
||||
accum_ = accum_ * accum_.constant(rho_scalar) + grad_.square() * grad_.constant(1 - rho_scalar);
|
||||
const auto update = (accum_update_ + accum_update_.constant(epsilon_scalar)).sqrt() * (accum_ + accum_.constant(epsilon_scalar)).rsqrt() * grad_;
|
||||
accum_update_ = accum_update_ * accum_update_.constant(rho_scalar) + update.square() * update.constant(1 - rho_scalar);
|
||||
|
||||
auto v = var_flat.template chip<0>(index);
|
||||
v -= update * update.constant(lr_scalar);
|
||||
}
|
||||
}
|
||||
if (use_exclusive_lock_) {
|
||||
mu_var->unlock();
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdadeltaOp<T, Tindices>);
|
||||
|
||||
REGISTER_KERNELS(float, int32);
|
||||
REGISTER_KERNELS(float, int64);
|
||||
REGISTER_KERNELS(double, int32);
|
||||
REGISTER_KERNELS(double, int64);
|
||||
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ApplyAdagradOp : public OpKernel {
|
||||
public:
|
||||
|
@ -33,6 +33,17 @@ struct ApplyGradientDescent {
|
||||
typename TTypes<T>::ConstFlat delta);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct ApplyAdadelta {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat var,
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::Flat accum_update,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstScalar rho,
|
||||
typename TTypes<T>::ConstScalar epsilon,
|
||||
typename TTypes<T>::ConstFlat grad);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct ApplyAdagrad {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat var,
|
||||
|
@ -51,6 +51,33 @@ struct ApplyAdagrad<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdadelta<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
typename TTypes<T>::Flat accum,
|
||||
typename TTypes<T>::Flat accum_update,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstScalar rho,
|
||||
typename TTypes<T>::ConstScalar epsilon,
|
||||
typename TTypes<T>::ConstFlat grad) {
|
||||
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
|
||||
bcast[0] = grad.dimension(0);
|
||||
Eigen::Sizes<1> single;
|
||||
|
||||
accum.device(d) =
|
||||
accum_update * rho.reshape(single).broadcast(bcast) +
|
||||
grad.square() * (grad.constant(1) - rho.reshape(single).broadcast(bcast));
|
||||
const auto update =
|
||||
(accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
|
||||
(accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
|
||||
accum_update.device(d) =
|
||||
accum_update * rho.reshape(single).broadcast(bcast) +
|
||||
update.square() * (grad.constant(1) - rho.reshape(single).broadcast(bcast));
|
||||
var.device(d) -= update * lr.reshape(single).broadcast(bcast);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct ApplyMomentum<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -129,6 +156,9 @@ template struct functor::ApplyGradientDescent<GPUDevice, double>;
|
||||
template struct functor::ApplyAdagrad<GPUDevice, float>;
|
||||
template struct functor::ApplyAdagrad<GPUDevice, double>;
|
||||
|
||||
template struct functor::ApplyAdadelta<GPUDevice, float>;
|
||||
template struct functor::ApplyAdadelta<GPUDevice, double>;
|
||||
|
||||
template struct functor::ApplyMomentum<GPUDevice, float>;
|
||||
template struct functor::ApplyMomentum<GPUDevice, double>;
|
||||
|
||||
|
@ -392,6 +392,73 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ApplyAdadelta"
|
||||
input_arg {
|
||||
name: "var"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum_update"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "lr"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "rho"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "epsilon"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "grad"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "out"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ApplyAdagrad"
|
||||
input_arg {
|
||||
@ -12465,6 +12532,87 @@ op {
|
||||
type: "int"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseApplyAdadelta"
|
||||
input_arg {
|
||||
name: "var"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum_update"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "lr"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "rho"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "epsilon"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "grad"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "out"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseApplyAdagrad"
|
||||
input_arg {
|
||||
|
@ -288,6 +288,84 @@ op {
|
||||
summary: "Computes the \"logical or\" of elements across dimensions of a tensor."
|
||||
description: "Reduces `input` along the dimensions given in `reduction_indices`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_indices`. If `keep_dims` is true, the reduced dimensions are\nretained with length 1."
|
||||
}
|
||||
op {
|
||||
name: "ApplyAdadelta"
|
||||
input_arg {
|
||||
name: "var"
|
||||
description: "Should be from a Variable()."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum"
|
||||
description: "Should be from a Variable()."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum_update"
|
||||
description: "Should be from a Variable()."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "lr"
|
||||
description: "Scaling factor. Must be a scalar."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "rho"
|
||||
description: "Decay factor. Must be a scalar."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "epsilon"
|
||||
description: "Constant factor. Must be a scalar."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "grad"
|
||||
description: "The gradient."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "out"
|
||||
description: "Same as \"var\"."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "If True, updating of the var, accum and update_accum tensors will be protected by\na lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Update \'*var\' according to the adadelta scheme."
|
||||
description: "accum = rho() * accum + (1 - rho()) * grad.square();\nupdate = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad;\nupdate_accum = rho() * update_accum + (1 - rho()) * update.square();\nvar -= update;"
|
||||
}
|
||||
op {
|
||||
name: "ApplyAdagrad"
|
||||
input_arg {
|
||||
@ -9293,6 +9371,88 @@ op {
|
||||
summary: "SpaceToDepth for tensors of type T."
|
||||
description: "Rearranges blocks of spatial data, into depth. More specifically,\nthis op outputs a copy of the input tensor where values from the `height`\nand `width` dimensions are moved to the `depth` dimension.\nThe attr `block_size` indicates the input block size and how the data is moved.\n\n * Non-overlapping blocks of size `block_size x block size` are rearranged\n into depth at each location.\n * The depth of the output tensor is `input_depth * block_size * block_size`.\n * The input tensor\'s height and width must be divisible by block_size.\n\nThat is, assuming the input is in the shape:\n`[batch, height, width, depth]`,\nthe shape of the output will be:\n`[batch, height/block_size, width/block_size, depth*block_size*block_size]`\n\nThis operation requires that the input tensor be of rank 4, and that\n`block_size` be >=1 and a divisor of both the input `height` and `width`.\n\nThis operation is useful for resizing the activations between convolutions\n(but keeping all data), e.g. instead of pooling. It is also useful for training\npurely convolutional models.\n\nFor example, given this input of shape `[1, 2, 2, 1]`, and block_size of 2:\n\n```prettyprint\nx = [[[[1], [2]],\n [[3], [4]]]]\n```\n\nThis operation will output a tensor of shape `[1, 1, 1, 4]`:\n\n```prettyprint\n[[[[1, 2, 3, 4]]]]\n```\n\nHere, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`,\nthe corresponding output will have a single element (i.e. width and height are\nboth 1) and will have a depth of 4 channels (1 * block_size * block_size).\nThe output element shape is `[1, 1, 4]`.\n\nFor an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g.\n\n```prettyprint\nx = [[[[1, 2, 3], [4, 5, 6]],\n [[7, 8, 9], [10, 11, 12]]]]\n```\n\nThis operation, for block_size of 2, will return the following tensor of shape\n`[1, 1, 1, 12]`\n\n```prettyprint\n[[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]\n```\n\nSimilarly, for the following input of shape `[1 4 4 1]`, and a block size of 2:\n\n```prettyprint\nx = [[ [1], [2], [5], [6]],\n [ [3], [4], [7], [8]],\n [ [9], [10], [13], [14]],\n [ [11], [12], [15], [16]]]\n```\n\nthe operator will return the following tensor of shape `[1 2 2 4]`:\n\n```prettyprint\nx = [[[[1, 2, 3, 4],\n [5, 6, 7, 8]],\n [[9, 10, 11, 12],\n [13, 14, 15, 16]]]]\n```"
|
||||
}
|
||||
op {
|
||||
name: "SparseApplyAdadelta"
|
||||
input_arg {
|
||||
name: "var"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "accum_update"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "lr"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "rho"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "epsilon"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "grad"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "out"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
summary: "var: Should be from a Variable()."
|
||||
}
|
||||
op {
|
||||
name: "SparseApplyAdagrad"
|
||||
input_arg {
|
||||
|
@ -35,6 +35,64 @@ use_locking: If True, the subtraction will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ApplyAdadelta")
|
||||
.Input("var: Ref(T)")
|
||||
.Input("accum: Ref(T)")
|
||||
.Input("accum_update: Ref(T)")
|
||||
.Input("lr: T")
|
||||
.Input("rho: T")
|
||||
.Input("epsilon: T")
|
||||
.Input("grad: T")
|
||||
.Output("out: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(R"doc(
|
||||
Update '*var' according to the adadelta scheme.
|
||||
|
||||
accum = rho() * accum + (1 - rho()) * grad.square();
|
||||
update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad;
|
||||
update_accum = rho() * update_accum + (1 - rho()) * update.square();
|
||||
var -= update;
|
||||
|
||||
var: Should be from a Variable().
|
||||
accum: Should be from a Variable().
|
||||
accum_update: Should be from a Variable().
|
||||
lr: Scaling factor. Must be a scalar.
|
||||
rho: Decay factor. Must be a scalar.
|
||||
epsilon: Constant factor. Must be a scalar.
|
||||
grad: The gradient.
|
||||
out: Same as "var".
|
||||
use_locking: If True, updating of the var, accum and update_accum tensors will be protected by
|
||||
a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SparseApplyAdadelta")
|
||||
.Input("var: Ref(T)")
|
||||
.Input("accum: Ref(T)")
|
||||
.Input("accum_update: Ref(T)")
|
||||
.Input("lr: T")
|
||||
.Input("rho: T")
|
||||
.Input("epsilon: T")
|
||||
.Input("grad: T")
|
||||
.Input("indices: Tindices")
|
||||
.Output("out: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(R"doc(
|
||||
var: Should be from a Variable().
|
||||
accum_grad: Should be from a Variable().
|
||||
accum_update:: Should be from a Variable().
|
||||
lr: Learning rate. Must be a scalar.
|
||||
rho: Decay factor. Must be a scalar.
|
||||
epsilon: Constant factor. Must be a scalar.
|
||||
grad: The gradient.
|
||||
indices: A vector of indices into the first dimension of var and accum.
|
||||
out: Same as "var".
|
||||
use_locking: If True, updating of the var and accum tensors will be protected by
|
||||
a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ApplyAdagrad")
|
||||
.Input("var: Ref(T)")
|
||||
.Input("accum: Ref(T)")
|
||||
|
84
tensorflow/python/training/adadelta.py
Normal file
84
tensorflow/python/training/adadelta.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Adadelta for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import training_ops
|
||||
|
||||
|
||||
class AdadeltaOptimizer(optimizer.Optimizer):
|
||||
"""Optimizer that implements the Adadelta algorithm.
|
||||
|
||||
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
|
||||
([pdf](http://arxiv.org/pdf/1212.570.pdf))
|
||||
|
||||
@@__init__
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
|
||||
use_locking=False, name="Adadelta"):
|
||||
"""Construct a new Adadelta optimizer.
|
||||
|
||||
Args:
|
||||
learning_rate: A `Tensor` or a floating point value. The learning rate.
|
||||
rho: A `Tensor` or a floating point value. The decay rate.
|
||||
epsilon: A `Tensor` or a floating point value. A constant epsilon used
|
||||
to better conditioning the grad update.
|
||||
use_locking: If `True` use locks for update operations.
|
||||
name: Optional name prefix for the operations created when applying
|
||||
gradients. Defaults to "Adadelta".
|
||||
"""
|
||||
super(AdadeltaOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._rho = rho
|
||||
self._epsilon = epsilon
|
||||
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._rho_t = None
|
||||
self._epsilon_t = None
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "accum", self._name)
|
||||
self._zeros_slot(v, "accum_update", self._name)
|
||||
|
||||
def _prepare(self):
|
||||
self._lr_t = ops.convert_to_tensor(self._lr, name="lr")
|
||||
self._rho_t = ops.convert_to_tensor(self._rho, name="rho")
|
||||
self._epsilon_t = ops.convert_to_tensor(self._epsilon,
|
||||
name="epsilon")
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
accum = self.get_slot(var, "accum")
|
||||
accum_update = self.get_slot(var, "accum_update")
|
||||
return training_ops.apply_adadelta(
|
||||
var, accum, accum_update,
|
||||
self._lr_t, self._rho_t, self._epsilon_t, grad,
|
||||
use_locking=self._use_locking)
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
accum = self.get_slot(var, "accum")
|
||||
accum_update = self.get_slot(var, "accum_update")
|
||||
return training_ops.sparse_apply_adadelta(
|
||||
var, accum, accum_update, self._lr_t,
|
||||
self._rho_t, self._epsilon_t, grad.values,
|
||||
grad.indices, use_locking=self._use_locking)
|
113
tensorflow/python/training/adadelta_test.py
Normal file
113
tensorflow/python/training/adadelta_test.py
Normal file
@ -0,0 +1,113 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for Adadelta Optimizer."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.python.platform
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class AdadeltaOptimizerTest(tf.test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session():
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
var1 = tf.Variable([3.0, 4.0])
|
||||
grads0 = tf.constant([0.1, 0.1])
|
||||
grads1 = tf.constant([0.01, 0.01])
|
||||
lr = 1.0
|
||||
rho = 0.95
|
||||
epsilon = 1e-8
|
||||
|
||||
adadelta_opt = tf.train.AdadeltaOptimizer(lr, rho=rho, epsilon=epsilon)
|
||||
adadelta_update = adadelta_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||
tf.initialize_all_variables().run()
|
||||
|
||||
# Check we have slots
|
||||
self.assertEqual(["accum", "accum_update"],
|
||||
adadelta_opt.get_slot_names())
|
||||
slot0 = adadelta_opt.get_slot(var0, "accum")
|
||||
self.assertEquals(slot0.get_shape(), var0.get_shape())
|
||||
self.assertFalse(slot0 in tf.trainable_variables())
|
||||
|
||||
slot0_update = adadelta_opt.get_slot(var0, "accum_update")
|
||||
self.assertEquals(slot0_update.get_shape(), var0.get_shape())
|
||||
self.assertFalse(slot0_update in tf.trainable_variables())
|
||||
|
||||
|
||||
slot1 = adadelta_opt.get_slot(var1, "accum")
|
||||
self.assertEquals(slot1.get_shape(), var1.get_shape())
|
||||
self.assertFalse(slot1 in tf.trainable_variables())
|
||||
|
||||
slot1_update = adadelta_opt.get_slot(var1, "accum_update")
|
||||
self.assertEquals(slot1_update.get_shape(), var1.get_shape())
|
||||
self.assertFalse(slot1_update in tf.trainable_variables())
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
||||
|
||||
adadelta_update.run()
|
||||
|
||||
# Check that the accumulators have been updated.
|
||||
grad = 0.1
|
||||
accum = 0
|
||||
accum_update = 0
|
||||
|
||||
accum = accum * rho + (grad**2) * (1 - rho)
|
||||
update1 = np.sqrt(accum_update + epsilon) * (1. / np.sqrt(accum + epsilon)) * grad
|
||||
accum_update = accum_update * rho + (update1**2) * (1.0 - rho)
|
||||
|
||||
self.assertAllClose(np.array([accum, accum]), slot0.eval())
|
||||
self.assertAllClose(np.array([accum_update, accum_update]), slot0_update.eval())
|
||||
|
||||
# Check that the parameters have been updated.
|
||||
self.assertAllClose(np.array([1.0 - update1 * lr,
|
||||
2.0 - update1 * lr]),
|
||||
var0.eval(), rtol=1e-3)
|
||||
|
||||
self.assertAllClose(np.array([3.0 - update1 * lr,
|
||||
4.0 - update1 * lr]),
|
||||
var1.eval(), rtol=1e-3)
|
||||
|
||||
# Step 2: the momentum accumulators contain the previous update.
|
||||
accum = accum * rho + (grad**2) * (1 - rho)
|
||||
update2 = ((accum_update + epsilon)**0.5) * (1. / (accum + epsilon)**0.5) * grad
|
||||
accum_update = accum_update * rho + (update2**2) * (1.0 - rho)
|
||||
|
||||
adadelta_update.run()
|
||||
|
||||
# Check that the momentum accumulators have been updated.
|
||||
self.assertAllClose(np.array([accum, accum]), slot0.eval())
|
||||
self.assertAllClose(np.array([accum_update, accum_update]), slot0_update.eval())
|
||||
|
||||
# Check that the parameters have been updated.
|
||||
self.assertAllClose(
|
||||
np.array([1.0 - update1 - update2,
|
||||
2.0 - update1 - update2]),
|
||||
var0.eval(), rtol=1e-3)
|
||||
|
||||
self.assertAllClose(np.array([3.0 - update1 - update2,
|
||||
4.0 - update1 - update2]),
|
||||
var1.eval(), rtol=1e-3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -28,6 +28,7 @@ of the subclasses.
|
||||
@@Optimizer
|
||||
|
||||
@@GradientDescentOptimizer
|
||||
@@AdadeltaOptimizer
|
||||
@@AdagradOptimizer
|
||||
@@MomentumOptimizer
|
||||
@@AdamOptimizer
|
||||
@ -134,6 +135,7 @@ from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
|
||||
from tensorflow.python.training.adadelta import AdadeltaOptimizer
|
||||
from tensorflow.python.training.adagrad import AdagradOptimizer
|
||||
from tensorflow.python.training.adam import AdamOptimizer
|
||||
from tensorflow.python.training.ftrl import FtrlOptimizer
|
||||
|
@ -47,6 +47,18 @@ def _AssertInputIsScalar(op, index):
|
||||
op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
|
||||
|
||||
|
||||
@ops.RegisterShape("ApplyAdadelta")
|
||||
def _ApplyAdadeltaShape(op):
|
||||
"""Shape function for the ApplyAdadelta op."""
|
||||
var_shape = op.inputs[0].get_shape()
|
||||
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
|
||||
accum_update_shape = op.inputs[2].get_shape().merge_with(var_shape)
|
||||
_AssertInputIsScalar(op, 3) # lr
|
||||
_AssertInputIsScalar(op, 4) # rho
|
||||
_AssertInputIsScalar(op, 5) # epsilon
|
||||
grad_shape = op.inputs[6].get_shape().merge_with(accum_shape)
|
||||
return [grad_shape]
|
||||
|
||||
@ops.RegisterShape("ApplyAdagrad")
|
||||
def _ApplyAdagradShape(op):
|
||||
"""Shape function for the ApplyAdagrad op."""
|
||||
@ -120,6 +132,20 @@ def _ApplyGradientDescentShape(op):
|
||||
delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
|
||||
return [delta_shape]
|
||||
|
||||
@ops.RegisterShape("SparseApplyAdadelta")
|
||||
def _SparseApplyAdadeltaShape(op):
|
||||
"""Shape function for the SparseApplyAdadelta op."""
|
||||
var_shape = op.inputs[0].get_shape()
|
||||
accum_grad_shape = op.inputs[1].get_shape().merge_with(var_shape)
|
||||
accum_update_shape = op.inputs[2].get_shape().merge_with(accum_grad_shape)
|
||||
_AssertInputIsScalar(op, 3) # lr
|
||||
_AssertInputIsScalar(op, 4) # decay_rate
|
||||
_AssertInputIsScalar(op, 5) # epsilon
|
||||
grad_shape = op.inputs[6].get_shape().merge_with(
|
||||
tensor_shape.TensorShape([None]).concatenate(accum_update_shape[1:]))
|
||||
unused_indices_shape = op.inputs[7].get_shape().merge_with(
|
||||
tensor_shape.vector(grad_shape[0]))
|
||||
return [accum_update_shape]
|
||||
|
||||
@ops.RegisterShape("SparseApplyAdagrad")
|
||||
def _SparseApplyAdagradShape(op):
|
||||
|
Loading…
Reference in New Issue
Block a user