Validate that all variables under the model were created in the distribution strategy scope.
PiperOrigin-RevId: 227595107
This commit is contained in:
parent
26d8486baa
commit
29bc2e98ea
@ -245,6 +245,18 @@ def all_strategy_combinations():
|
|||||||
return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
|
return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
|
||||||
|
|
||||||
|
|
||||||
|
def all_strategy_combinations_minus_default():
|
||||||
|
strategy_minus_default_combinations = combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
combinations.one_device_strategy,
|
||||||
|
combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
combinations.mirrored_strategy_with_two_gpus,
|
||||||
|
combinations.core_mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
combinations.core_mirrored_strategy_with_two_gpus],
|
||||||
|
mode=['graph', 'eager'])
|
||||||
|
return strategy_minus_default_combinations + tpu_strategy_combinations()
|
||||||
|
|
||||||
|
|
||||||
# TODO(priyag): Add v2 optimizers here.
|
# TODO(priyag): Add v2 optimizers here.
|
||||||
def strategy_and_optimizer_combinations():
|
def strategy_and_optimizer_combinations():
|
||||||
return combinations.times(
|
return combinations.times(
|
||||||
@ -1149,5 +1161,37 @@ class TestDistributionStrategyWithNormalizationLayer(
|
|||||||
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
|
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistributionStrategyVariableValidation(test.TestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(all_strategy_combinations_minus_default())
|
||||||
|
def test_layer_outside_scope(self, distribution):
|
||||||
|
with self.cached_session():
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'was not created in the distribution strategy'):
|
||||||
|
x = keras.layers.Input(shape=(3,), name='input')
|
||||||
|
y = keras.layers.Dense(4, name='dense')(x)
|
||||||
|
with distribution.scope():
|
||||||
|
model = keras.Model(x, y)
|
||||||
|
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
|
||||||
|
loss = 'mse'
|
||||||
|
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
|
||||||
|
model.compile(optimizer, loss, metrics=metrics)
|
||||||
|
|
||||||
|
@combinations.generate(all_strategy_combinations_minus_default())
|
||||||
|
def test_model_outside_scope(self, distribution):
|
||||||
|
with self.cached_session():
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'was not created in the distribution strategy'):
|
||||||
|
x = keras.layers.Input(shape=(3,), name='input')
|
||||||
|
y = keras.layers.Dense(4, name='dense')(x)
|
||||||
|
model = keras.Model(x, y)
|
||||||
|
with distribution.scope():
|
||||||
|
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
|
||||||
|
loss = 'mse'
|
||||||
|
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
|
||||||
|
model.compile(optimizer, loss, metrics=metrics)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -548,6 +548,20 @@ class Model(Network):
|
|||||||
trainable_weights = self.trainable_weights
|
trainable_weights = self.trainable_weights
|
||||||
self._collected_trainable_weights = trainable_weights
|
self._collected_trainable_weights = trainable_weights
|
||||||
|
|
||||||
|
# Validate all variables were correctly created in distribution scope.
|
||||||
|
if self._distribution_strategy and not self._compile_distribution:
|
||||||
|
for v in self.variables:
|
||||||
|
if v.distribute_strategy is not self._distribution_strategy:
|
||||||
|
raise ValueError(
|
||||||
|
'Variable (%s) was not created in the distribution strategy '
|
||||||
|
'scope of (%s). It is most likely due to not all layers or '
|
||||||
|
'the model or optimizer being created outside the distribution '
|
||||||
|
'strategy scope. Try to make sure your code looks similar '
|
||||||
|
'to the following.\n'
|
||||||
|
'with strategy.scope():\n'
|
||||||
|
' model=_create_model()\n'
|
||||||
|
' model.compile(...)'% (v, self._distribution_strategy))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metrics(self):
|
def metrics(self):
|
||||||
"""Returns the model's metrics added using `compile`, `add_metric` APIs."""
|
"""Returns the model's metrics added using `compile`, `add_metric` APIs."""
|
||||||
|
Loading…
Reference in New Issue
Block a user