Fix error when using the colocate_with function with DistributionStrategy.
PiperOrigin-RevId: 208702602
parent d576b76bf5
commit 0c98648e9a
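For context, the failing pattern involves a variable whose static shape is unknown and whose slot variables are created with a callable initializer (the path exercised by DistributionStrategy together with ops.colocate_with). A minimal sketch of that usage against the TF 1.x public API, mirroring the new test added below (illustrative only, not part of the commit):

import tensorflow as tf

# The initializer is a Tensor and validate_shape=False, so the variable's
# static shape stays undefined.
var0 = tf.get_variable("var0", initializer=tf.constant(1.), validate_shape=False)

# Creating the Adagrad accumulator slot for such a variable needs a callable
# initializer, which is the code path this commit fixes.
opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
train_op = opt.apply_gradients([(tf.constant(0.1), var0)])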
@@ -802,6 +802,7 @@ class _VariableStore(object):
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True

    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

@@ -837,9 +838,6 @@ class _VariableStore(object):
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
    if not shape.is_fully_defined() and not initializing_from_value:
      raise ValueError("Shape of a new variable (%s) must be fully defined, "
                       "but instead was %s." % (name, shape))

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
@@ -854,8 +852,11 @@ class _VariableStore(object):
    # Instantiate initializer if provided initializer is a type object.
    if isinstance(initializer, type(init_ops.Initializer)):
      initializer = initializer(dtype=dtype)
    if validate_shape:
      init_val = lambda: initializer(  # pylint: disable=g-long-lambda
          shape.as_list(), dtype=dtype, partition_info=partition_info)
    else:
      init_val = initializer
    variable_dtype = dtype.base_dtype

    # Create the variable.
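The net effect of the hunks above: when the provided initializer is callable, validate_shape decides whether it is wrapped in a lambda that bakes in the static shape or passed through untouched, so the variable's initial value can keep a dynamic shape. A simplified, hypothetical helper capturing the same dispatch (a sketch, not the actual _VariableStore code):

def make_init_val(initializer, shape, dtype, partition_info, validate_shape):
  # Sketch only: mirrors the branch added in the diff above.
  if callable(initializer):
    if validate_shape:
      # Static shape is known and enforced: bake it into the deferred call.
      return lambda: initializer(
          shape.as_list(), dtype=dtype, partition_info=partition_info)
    # Shape is dynamic: hand the callable through unchanged and let it decide
    # its own output shape at variable-creation time.
    return initializer
  # Non-callable initializers (e.g. Tensors) are used as the value directly.
  return initializer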
@@ -70,17 +70,17 @@ class AdagradOptimizer(optimizer.Optimizer):

  def _create_slots(self, var_list):
    for v in var_list:
      with ops.colocate_with(v):
        dtype = v.dtype.base_dtype
        if v.get_shape().is_fully_defined():
          init = init_ops.constant_initializer(self._initial_accumulator_value,
                                               dtype=dtype)
        else:
          def init(v=v, dtype=dtype):
            # Use a Tensor instead of initializer if variable does not have
            # static shape.
            init_constant = gen_array_ops.fill(array_ops.shape(v),
                                               self._initial_accumulator_value)
            return math_ops.cast(init_constant, dtype)
      self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype,
                                              "accumulator", self._name)

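The adagrad change follows the same idea: when the variable's static shape is not fully defined, the accumulator's initial value is now produced by a callable closed over the variable, so it is only evaluated at slot-creation time in the correct colocation context instead of eagerly as a Tensor. A standalone sketch of that pattern using public TF 1.x ops (hypothetical helper name, not the TF implementation):

import tensorflow as tf

def make_accumulator_init(v, initial_value, dtype):
  """Return a closure producing the accumulator's initial value lazily."""
  def init():
    # The runtime shape is read from the variable itself, so this works even
    # when v's static shape is unknown.
    filled = tf.fill(tf.shape(v), initial_value)
    return tf.cast(filled, dtype)
  return init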
@@ -302,6 +302,39 @@ class AdagradOptimizerTest(test.TestCase):

      # Creating optimizer should cause no exception.
      adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)

  def testDynamicShapeVariableWithCallableInit(self):
    var0 = variable_scope.get_variable("var0",
                                       initializer=constant_op.constant(1.),
                                       validate_shape=False)
    self.assertFalse(var0.shape.is_fully_defined())

    grads0 = constant_op.constant(0.1, dtype=dtypes.float32)
    learning_rate = lambda: 3.0

    ada_opt = adagrad.AdagradOptimizer(
        learning_rate, initial_accumulator_value=0.1, use_locking=True)

    if not context.executing_eagerly():
      ada_update = ada_opt.apply_gradients(
          zip([grads0], [var0]))
      self.evaluate(variables.global_variables_initializer())

    # Fetch params to validate initial values
    v0_val = self.evaluate([var0])
    self.assertAllClose([1.0], v0_val)

    # Run 3 steps of adagrad
    for _ in range(3):
      if not context.executing_eagerly():
        self.evaluate(ada_update)
      else:
        ada_opt.apply_gradients(zip([grads0], [var0]))

    # Validate updated params
    v0_val = self.evaluate([var0])
    self.assertAllCloseAccordingToType(
        np.array([-1.6026098728179932]), v0_val)


if __name__ == "__main__":
  test.main()
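For reference, the expected value in the new test follows from the standard Adagrad update (accum += grad**2; var -= lr * grad / sqrt(accum)) with lr = 3.0, grad = 0.1, initial var 1.0 and initial accumulator 0.1, run for three steps. A quick NumPy check of that arithmetic (illustrative only):

import numpy as np

lr, grad = 3.0, 0.1
var, accum = 1.0, 0.1
for _ in range(3):
  accum += grad ** 2                 # 0.11, 0.12, 0.13
  var -= lr * grad / np.sqrt(accum)  # ~0.0955, ~-0.7706, ~-1.6026
print(var)  # ~-1.60261, matching the test's float32 expectation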