Remove more parts of DCASGD missed in the first pass. (47949b)
PiperOrigin-RevId: 164914552
parent 73b3d52c7b
commit b8d13d218f
@@ -34,15 +34,6 @@ struct ApplyGradientDescent {
                   typename TTypes<T>::ConstFlat delta);
 };
 
-template <typename Device, typename T>
-struct ApplyDelayCompensatedGradientDescent {
-  void operator()(const Device& d, typename TTypes<T>::Flat var,
-                  typename TTypes<T>::ConstScalar alpha,
-                  typename TTypes<T>::ConstFlat delta,
-                  typename TTypes<T>::ConstScalar lambda,
-                  typename TTypes<T>::Flat shadow);
-};
-
 template <typename Device, typename T>
 struct ApplyAdadelta {
   void operator()(const Device& d, typename TTypes<T>::Flat var,
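For context, the functor removed above computed the delay-compensated SGD step on-device. The following standalone C++ sketch only illustrates the same update rule on plain vectors; it is not the removed TensorFlow/Eigen kernel, and DcasgdUpdate and its parameter names are illustrative assumptions.

#include <cstddef>
#include <vector>

// Illustrative helper (not part of TensorFlow): applies, element-wise,
//   var -= alpha * (delta + lambda * delta * (var - shadow))
// and then refreshes the shadow copy, mirroring the removed functor's arguments.
void DcasgdUpdate(std::vector<float>& var, std::vector<float>& shadow,
                  const std::vector<float>& delta, float alpha, float lambda) {
  for (std::size_t i = 0; i < var.size(); ++i) {
    // The lambda term scales the correction by the drift between the current
    // variable and its shadow copy (delay compensation).
    var[i] -= alpha * (delta[i] + lambda * delta[i] * (var[i] - shadow[i]));
    shadow[i] = var[i];  // shadow is updated to the new value of var
  }
}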
@@ -1035,59 +1035,6 @@ op {
     }
   }
 }
-op {
-  name: "ApplyDelayCompensatedGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lambda"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shadow"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
 op {
   name: "ApplyFtrl"
   input_arg {
@@ -1049,67 +1049,6 @@ op {
   summary: "Update \'*var\' according to the centered RMSProp algorithm."
   description: "The centered RMSProp algorithm uses an estimate of the centered second moment\n(i.e., the variance) for normalization, as opposed to regular RMSProp, which\nuses the (uncentered) second moment. This often helps with training, but is\nslightly more expensive in terms of computation and memory.\n\nNote that in dense implementation of this algorithm, mg, ms, and mom will\nupdate even if the grad is zero, but in this sparse implementation, mg, ms,\nand mom will not update in iterations during which the grad is zero.\n\nmean_square = decay * mean_square + (1-decay) * gradient ** 2\nmean_grad = decay * mean_grad + (1-decay) * gradient\n\nDelta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)\n\nmg <- rho * mg_{t-1} + (1-rho) * grad\nms <- rho * ms_{t-1} + (1-rho) * grad * grad\nmom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\nvar <- var - mom"
 }
-op {
-  name: "ApplyDelayCompensatedGradientDescent"
-  input_arg {
-    name: "var"
-    description: "Should be from a Variable()."
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    description: "Scaling factor. Must be a scalar."
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    description: "The change."
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lambda"
-    description: "The variance parameter."
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shadow"
-    description: "Same as \"var\"."
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-    description: "If `True`, the subtraction will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
-  }
-  summary: "var -= alpha * (delta + lambda * delta * (var - shadow))"
-  description: "Update \'*shadow\' by changing it to the new value of \'var\'"
-  is_stateful: true
-}
 op {
   name: "ApplyFtrl"
   input_arg {
@@ -103,28 +103,6 @@ use_locking: If `True`, the subtraction will be protected by a lock;
 otherwise the behavior is undefined, but may exhibit less contention.
 )doc");
 
-REGISTER_OP("ApplyDelayCompensatedGradientDescent")
-    .Input("var: resource")
-    .Input("alpha: T")
-    .Input("delta: T")
-    .Input("lambda: T")
-    .Input("shadow: resource")
-    .Attr("T: numbertype")
-    .Attr("use_locking: bool = false")
-    .SetShapeFn(ApplyGradientDescentShapeFn)
-    .Doc(R"doc(
-var -= alpha * (delta + lambda * delta * (var - shadow))
-Update '*shadow' by changing it to the new value of 'var'
-
-var: Should be from a Variable().
-alpha: Scaling factor. Must be a scalar.
-delta: The change.
-lambda: The variance parameter.
-shadow: Same as "var".
-use_locking: If `True`, the subtraction will be protected by a lock;
-otherwise the behavior is undefined, but may exhibit less contention.
-)doc");
-
 static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
                                                   bool sparse) {
   ShapeHandle unused;
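As a quick sanity check of the documented rule var -= alpha * (delta + lambda * delta * (var - shadow)), here is a tiny self-contained C++ program with made-up scalar values; the constants are assumptions for illustration, not taken from any TensorFlow test.

#include <cstdio>

int main() {
  float var = 1.0f;            // current parameter value
  float shadow = 0.8f;         // stale copy the delayed gradient was computed at
  const float delta = 0.5f;    // the (stale) change
  const float alpha = 0.1f;    // scaling factor
  const float lambda = 0.04f;  // variance / compensation parameter

  // Compensated step: the extra term grows with the drift (var - shadow).
  var -= alpha * (delta + lambda * delta * (var - shadow));
  shadow = var;  // '*shadow' is changed to the new value of 'var'

  std::printf("var = %f, shadow = %f\n", var, shadow);  // var = 0.949600
  // An uncompensated step, var -= alpha * delta, would give 0.950000 instead.
  return 0;
}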