Remove colocate_with
function from FTRL optimizer for use with DistributionStrategy.
PiperOrigin-RevId: 286981819 Change-Id: I76e644013c8f84e0a3e90b5e2039acb2904023e6
This commit is contained in:
parent
49f453f1f9
commit
6a4342982f
@ -44,7 +44,11 @@ from tensorflow.python.ops.losses import losses_impl
|
||||
VAR_MAP_V1 = {
|
||||
"GradientDescent": ("dense/kernel", "dense/bias"),
|
||||
"Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
|
||||
"dense/bias")
|
||||
"dense/bias"),
|
||||
"Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
|
||||
"dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
|
||||
"RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
|
||||
"dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
|
||||
}
|
||||
|
||||
VAR_MAP_V2 = {
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.tpu import device_assignment as device_assignment_lib
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
from tensorflow.python.training import adagrad
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import ftrl
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import rmsprop
|
||||
|
||||
@ -130,11 +131,16 @@ adagrad_optimizer_v1_fn = combinations.NamedObject(
|
||||
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
|
||||
adam_optimizer_v1_fn = combinations.NamedObject(
|
||||
"AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
|
||||
ftrl_optimizer_v1_fn = combinations.NamedObject(
|
||||
"FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
|
||||
rmsprop_optimizer_v1_fn = combinations.NamedObject(
|
||||
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
|
||||
|
||||
# TODO(shiningsun): consider adding the other v1 optimizers
|
||||
optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
|
||||
optimizers_v1 = [
|
||||
gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
|
||||
ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
|
||||
]
|
||||
|
||||
adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
|
||||
"AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
|
||||
|
@ -132,11 +132,10 @@ class FtrlOptimizer(optimizer.Optimizer):
|
||||
def _create_slots(self, var_list):
|
||||
# Create the "accum" and "linear" slots.
|
||||
for v in var_list:
|
||||
with ops.colocate_with(v):
|
||||
val = constant_op.constant(
|
||||
self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
|
||||
self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
|
||||
self._zeros_slot(v, "linear", self._linear_name or self._name)
|
||||
val = constant_op.constant(
|
||||
self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
|
||||
self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
|
||||
self._zeros_slot(v, "linear", self._linear_name or self._name)
|
||||
|
||||
def _prepare(self):
|
||||
self._learning_rate_tensor = ops.convert_to_tensor(
|
||||
|
Loading…
Reference in New Issue
Block a user