Remove colocate_with function from FTRL optimizer for use with DistributionStrategy.
PiperOrigin-RevId: 286981819 Change-Id: I76e644013c8f84e0a3e90b5e2039acb2904023e6
parent 49f453f1f9
commit 6a4342982f
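The change targets the V1 FtrlOptimizer: its _create_slots no longer wraps slot creation in ops.colocate_with(v), so the slot variables are created through whatever variable creator is active and a DistributionStrategy can place or mirror them itself. A minimal sketch (not part of the commit) of the kind of usage this is meant to support, using the public TF 1.x-style Estimator API under MirroredStrategy; the model, shapes, and hyperparameters below are illustrative assumptions only:

import tensorflow as tf


def model_fn(features, labels, mode):
  # Illustrative model: one dense layer with a squared-error loss (train mode only).
  logits = tf.compat.v1.layers.dense(features["x"], 1, name="dense")
  loss = tf.compat.v1.losses.mean_squared_error(labels, logits)
  # V1 FTRL optimizer; its "accum" and "linear" slot variables are now created
  # through the strategy's variable-creation path rather than being pinned to
  # the primary variable's device by ops.colocate_with.
  optimizer = tf.compat.v1.train.FtrlOptimizer(0.001)
  train_op = optimizer.minimize(
      loss, global_step=tf.compat.v1.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def input_fn():
  features = {"x": tf.constant([[1., 2., 3., 4.]] * 8)}
  labels = tf.constant([[1.]] * 8)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)


config = tf.estimator.RunConfig(
    train_distribute=tf.distribute.MirroredStrategy())
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn, steps=2)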
@@ -44,7 +44,11 @@ from tensorflow.python.ops.losses import losses_impl
 VAR_MAP_V1 = {
     "GradientDescent": ("dense/kernel", "dense/bias"),
     "Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
-                "dense/bias")
+                "dense/bias"),
+    "Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
+             "dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
+    "RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
+                "dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
 }

 VAR_MAP_V2 = {
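The new "Ftrl" entry lists the variables the test expects for a dense layer trained with the V1 FTRL optimizer: the "accum" slot is named after the primary variable plus the optimizer name (dense/kernel/Ftrl), and the "linear" slot picks up the uniquified Ftrl_1 suffix. A small sketch (not from the commit; the variable name and shapes are stand-ins) showing where those names come from:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Stand-in for the dense layer's kernel variable used in the test model.
kernel = tf.Variable(tf.zeros([4, 1]), name="dense/kernel")
loss = tf.reduce_sum(tf.square(kernel))

opt = tf.compat.v1.train.FtrlOptimizer(0.001)
train_op = opt.minimize(loss)  # triggers _create_slots for "accum" and "linear"

print(opt.get_slot_names())                 # ['accum', 'linear']
print(opt.get_slot(kernel, "accum").name)   # dense/kernel/Ftrl:0
print(opt.get_slot(kernel, "linear").name)  # dense/kernel/Ftrl_1:0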
@@ -40,6 +40,7 @@ from tensorflow.python.tpu import device_assignment as device_assignment_lib
 from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import adagrad
 from tensorflow.python.training import adam
+from tensorflow.python.training import ftrl
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import rmsprop

@@ -130,11 +131,16 @@ adagrad_optimizer_v1_fn = combinations.NamedObject(
     "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
 adam_optimizer_v1_fn = combinations.NamedObject(
     "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
+ftrl_optimizer_v1_fn = combinations.NamedObject(
+    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
 rmsprop_optimizer_v1_fn = combinations.NamedObject(
     "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))

 # TODO(shiningsun): consider adding the other v1 optimizers
-optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+optimizers_v1 = [
+    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
+    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
+]

 adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
     "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
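With ftrl_optimizer_v1_fn added to optimizers_v1, FTRL is exercised by any test parameterized over that list. Roughly how such a combination is consumed; the test class, strategy choice, and the module that exports optimizers_v1 are assumptions here, not from this commit:

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.platform import test


class OptimizerWithStrategyTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          optimizer_fn=strategy_combinations.optimizers_v1,  # assumed location
          mode=["graph"]))
  def testCreatesOptimizer(self, distribution, optimizer_fn):
    with distribution.scope():
      # optimizer_fn is one of the NamedObjects above,
      # e.g. "FtrlV1" -> ftrl.FtrlOptimizer(0.001).
      optimizer = optimizer_fn()
      self.assertIsNotNone(optimizer)


if __name__ == "__main__":
  test.main()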
@@ -132,7 +132,6 @@ class FtrlOptimizer(optimizer.Optimizer):
   def _create_slots(self, var_list):
     # Create the "accum" and "linear" slots.
     for v in var_list:
-      with ops.colocate_with(v):
-        val = constant_op.constant(
-            self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
-        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
+      val = constant_op.constant(
+          self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
+      self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
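For reference, how _create_slots reads after the removal. The comments are editorial, not from the source; the "linear" slot is created by a similar call just below the hunk shown above:

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for v in var_list:
      # Constant initializer for the "accum" slot; no longer wrapped in
      # ops.colocate_with(v), so slot placement is left to the caller's
      # variable-creation context (e.g. a DistributionStrategy scope).
      val = constant_op.constant(
          self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
      self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)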