Validate that all variables under the model were created in the distribution strategy scope.
PiperOrigin-RevId: 227595107
This commit is contained in:
parent
26d8486baa
commit
29bc2e98ea
@ -245,6 +245,18 @@ def all_strategy_combinations():
|
||||
return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
|
||||
|
||||
|
||||
def all_strategy_combinations_minus_default():
|
||||
strategy_minus_default_combinations = combinations.combine(
|
||||
distribution=[
|
||||
combinations.one_device_strategy,
|
||||
combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
combinations.mirrored_strategy_with_two_gpus,
|
||||
combinations.core_mirrored_strategy_with_gpu_and_cpu,
|
||||
combinations.core_mirrored_strategy_with_two_gpus],
|
||||
mode=['graph', 'eager'])
|
||||
return strategy_minus_default_combinations + tpu_strategy_combinations()
|
||||
|
||||
|
||||
# TODO(priyag): Add v2 optimizers here.
|
||||
def strategy_and_optimizer_combinations():
|
||||
return combinations.times(
|
||||
@ -1149,5 +1161,37 @@ class TestDistributionStrategyWithNormalizationLayer(
|
||||
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
class TestDistributionStrategyVariableValidation(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(all_strategy_combinations_minus_default())
|
||||
def test_layer_outside_scope(self, distribution):
|
||||
with self.cached_session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'was not created in the distribution strategy'):
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
with distribution.scope():
|
||||
model = keras.Model(x, y)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
|
||||
model.compile(optimizer, loss, metrics=metrics)
|
||||
|
||||
@combinations.generate(all_strategy_combinations_minus_default())
|
||||
def test_model_outside_scope(self, distribution):
|
||||
with self.cached_session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'was not created in the distribution strategy'):
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
with distribution.scope():
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
|
||||
model.compile(optimizer, loss, metrics=metrics)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -548,6 +548,20 @@ class Model(Network):
|
||||
trainable_weights = self.trainable_weights
|
||||
self._collected_trainable_weights = trainable_weights
|
||||
|
||||
# Validate all variables were correctly created in distribution scope.
|
||||
if self._distribution_strategy and not self._compile_distribution:
|
||||
for v in self.variables:
|
||||
if v.distribute_strategy is not self._distribution_strategy:
|
||||
raise ValueError(
|
||||
'Variable (%s) was not created in the distribution strategy '
|
||||
'scope of (%s). It is most likely due to not all layers or '
|
||||
'the model or optimizer being created outside the distribution '
|
||||
'strategy scope. Try to make sure your code looks similar '
|
||||
'to the following.\n'
|
||||
'with strategy.scope():\n'
|
||||
' model=_create_model()\n'
|
||||
' model.compile(...)'% (v, self._distribution_strategy))
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""Returns the model's metrics added using `compile`, `add_metric` APIs."""
|
||||
|
Loading…
Reference in New Issue
Block a user