Remove the ops.colocate_with call from the FTRL optimizer so it works with DistributionStrategy.

PiperOrigin-RevId: 286981819
Change-Id: I76e644013c8f84e0a3e90b5e2039acb2904023e6
A. Unique TensorFlower 2019-12-23 21:58:16 -08:00 committed by TensorFlower Gardener
parent 49f453f1f9
commit 6a4342982f
3 changed files with 16 additions and 7 deletions
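
For context, a minimal, hedged sketch of the scenario this change targets (the variable and optimizer below are illustrative, not part of the commit): slot variables for the v1 FTRL optimizer should be placed by the DistributionStrategy rather than pinned with ops.colocate_with(v).

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Variables created under the scope are mirrored; the strategy, not an
  # explicit ops.colocate_with(v), decides where they and their optimizer
  # slots live.
  weights = tf.Variable([1.0, 2.0], name="weights")
  # The v1 FTRL optimizer creates its "accum" and "linear" slots lazily in
  # _create_slots() the first time minimize()/apply_gradients() runs.
  opt = tf.compat.v1.train.FtrlOptimizer(learning_rate=0.001)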


@@ -44,7 +44,11 @@ from tensorflow.python.ops.losses import losses_impl
 VAR_MAP_V1 = {
     "GradientDescent": ("dense/kernel", "dense/bias"),
     "Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
-                "dense/bias")
+                "dense/bias"),
+    "Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
+             "dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
+    "RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
+                "dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
 }
 VAR_MAP_V2 = {

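The "dense/kernel/Ftrl" and "dense/kernel/Ftrl_1" names expected above are the two FTRL slot variables ("accum" and "linear"); both are created under the optimizer name "Ftrl", so the second is uniquified with a "_1" suffix. A hedged, standalone sketch (the variable scope and shapes are illustrative assumptions):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.Graph().as_default():
  with tf.variable_scope("dense"):
    kernel = tf.get_variable("kernel", shape=[2])
  opt = tf.train.FtrlOptimizer(learning_rate=0.001)
  grad = tf.constant([0.1, 0.1])
  opt.apply_gradients([(grad, kernel)])       # triggers _create_slots()
  print(opt.get_slot_names())                 # ['accum', 'linear']
  print(opt.get_slot(kernel, "accum").name)   # e.g. dense/kernel/Ftrl:0
  print(opt.get_slot(kernel, "linear").name)  # e.g. dense/kernel/Ftrl_1:0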

@@ -40,6 +40,7 @@ from tensorflow.python.tpu import device_assignment as device_assignment_lib
 from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import adagrad
 from tensorflow.python.training import adam
+from tensorflow.python.training import ftrl
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import rmsprop
@@ -130,11 +131,16 @@ adagrad_optimizer_v1_fn = combinations.NamedObject(
     "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
 adam_optimizer_v1_fn = combinations.NamedObject(
     "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
+ftrl_optimizer_v1_fn = combinations.NamedObject(
+    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
 rmsprop_optimizer_v1_fn = combinations.NamedObject(
     "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
 # TODO(shiningsun): consider adding the other v1 optimizers
-optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+optimizers_v1 = [
+    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
+    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
+]
 adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
     "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))

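Assuming the file above is tensorflow/python/distribute/strategy_combinations.py (the path is not shown in this view), the new ftrl_optimizer_v1_fn entry and the expanded optimizers_v1 list are typically consumed by parameterized distribute tests along these lines (the test class, method, and chosen strategy are hypothetical):

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.platform import test


class FtrlCombinationsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          optimizer_fn=strategy_combinations.optimizers_v1,
          mode=["graph"]))
  def testCreateOptimizerUnderStrategy(self, distribution, optimizer_fn):
    with distribution.scope():
      optimizer = optimizer_fn()  # e.g. builds ftrl.FtrlOptimizer(0.001)
      self.assertIsNotNone(optimizer)


if __name__ == "__main__":
  test.main()

Each generated test case receives one NamedObject-wrapped constructor from optimizers_v1; calling it builds the corresponding v1 optimizer inside the strategy scope.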

@@ -132,11 +132,10 @@ class FtrlOptimizer(optimizer.Optimizer):
   def _create_slots(self, var_list):
     # Create the "accum" and "linear" slots.
     for v in var_list:
-      with ops.colocate_with(v):
-        val = constant_op.constant(
-            self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
-        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
-        self._zeros_slot(v, "linear", self._linear_name or self._name)
+      val = constant_op.constant(
+          self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
+      self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
+      self._zeros_slot(v, "linear", self._linear_name or self._name)

   def _prepare(self):
     self._learning_rate_tensor = ops.convert_to_tensor(
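
The commit message only states that the removal is needed for DistributionStrategy; one plausible reading is that an explicit ops.colocate_with(v) on strategy-created (e.g. mirrored) variables conflicts with the strategy's own placement of slot variables. For illustration only (not code from this commit), the strategy-aware way to co-locate a slot-like variable with its primary is tf.distribute.StrategyExtended.colocate_vars_with:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  primary = tf.Variable([1.0, 2.0], name="weights")
  # Let the strategy (rather than ops.colocate_with) keep this slot-like
  # variable on the same device(s) as `primary`.
  with strategy.extended.colocate_vars_with(primary):
    accum = tf.Variable(tf.fill([2], 0.1), name="accum")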