Merge pull request #41385 from tomerk/cherrypicks_ORV3K

[Cherrypick:r2.3] Explicitly raise a (clearer) error message when models end up in inva…
This commit is contained in:
Goldie Gadde 2020-07-17 09:20:32 -07:00 committed by GitHub
commit 89bb4c3f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View File

@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
# might want to turn it off, like Sequential model.
self._auto_track_sub_layers = True
# Mark this layer as having been originally built as a tf1 layer/model
self._originally_built_as_v1 = True
@trackable.no_automatic_dependency_tracking
@generic_utils.default
def build(self, input_shape):
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
ValueError: if the layer's `call` method returns None (an invalid value).
RuntimeError: if `super().__init__()` was not called in the constructor.
"""
self._assert_built_as_v1()
if not hasattr(self, '_thread_local'):
raise RuntimeError(
'You must call `super().__init__()` in the layer constructor.')
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
return outputs
def _assert_built_as_v1(self):
if not hasattr(self, '_originally_built_as_v1'):
raise ValueError(
'Your Layer or Model is in an invalid state. This can happen if you '
'are interleaving estimator/non-estimator models or '
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
'with models/layers created outside of it. '
'Converting a model to an estimator (via model_to_estimator) '
'invalidates all models/layers made before the conversion (even '
'if they were not the model converted to an estimator). '
'Similarly, making a layer or a model inside a '
'a tf.compat.v1.Graph invalidates all layers/models you previously '
'made outside of the graph.')
@property
def dtype(self):
return self._dtype_policy.variable_dtype

View File

@ -303,6 +303,7 @@ class Model(training_lib.Model):
ValueError: In case of invalid arguments for
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
self._assert_built_as_v1()
self._run_eagerly = kwargs.pop('run_eagerly', None)
self._experimental_run_tf_function = kwargs.pop(
'experimental_run_tf_function', True)
@ -773,6 +774,7 @@ class Model(training_lib.Model):
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('fit_v1').set(True)
# Legacy support
if 'nb_epoch' in kwargs:
@ -893,6 +895,7 @@ class Model(training_lib.Model):
Raises:
ValueError: in case of invalid arguments.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('evaluate_v1').set(True)
self._assert_compile_was_called()
self._check_call_args('evaluate')
@ -972,6 +975,7 @@ class Model(training_lib.Model):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('predict_v1').set(True)
self._check_call_args('predict')