Explicitly raise a (clearer) error message when models end up in invalid states due to interleaving graph and eager.

In rare cases code may have run w/o crashing when in these invalid states, but it's safer to error with an explanation rather than risk silent failures/fragile behavior.

PiperOrigin-RevId: 321192744
Change-Id: I9e97ac3b7cea27c9b389e5202de9f1c09a4aa2b8
This commit is contained in:
Tomer Kaftan 2020-07-14 11:02:34 -07:00
parent 14b2d686d6
commit ef4db27b31
2 changed files with 23 additions and 0 deletions

View File

@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
# might want to turn it off, like Sequential model.
self._auto_track_sub_layers = True
# Mark this layer as having been originally built as a tf1 layer/model
self._originally_built_as_v1 = True
@trackable.no_automatic_dependency_tracking
@generic_utils.default
def build(self, input_shape):
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
ValueError: if the layer's `call` method returns None (an invalid value).
RuntimeError: if `super().__init__()` was not called in the constructor.
"""
self._assert_built_as_v1()
if not hasattr(self, '_thread_local'):
raise RuntimeError(
'You must call `super().__init__()` in the layer constructor.')
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
return outputs
def _assert_built_as_v1(self):
if not hasattr(self, '_originally_built_as_v1'):
raise ValueError(
'Your Layer or Model is in an invalid state. This can happen if you '
'are interleaving estimator/non-estimator models or '
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
'with models/layers created outside of it. '
'Converting a model to an estimator (via model_to_estimator) '
'invalidates all models/layers made before the conversion (even '
'if they were not the model converted to an estimator). '
'Similarly, making a layer or a model inside a '
'a tf.compat.v1.Graph invalidates all layers/models you previously '
'made outside of the graph.')
@property
def dtype(self):
return self._dtype_policy.variable_dtype

View File

@ -303,6 +303,7 @@ class Model(training_lib.Model):
ValueError: In case of invalid arguments for
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
self._assert_built_as_v1()
self._run_eagerly = kwargs.pop('run_eagerly', None)
self._experimental_run_tf_function = kwargs.pop(
'experimental_run_tf_function', True)
@ -773,6 +774,7 @@ class Model(training_lib.Model):
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('fit_v1').set(True)
# Legacy support
if 'nb_epoch' in kwargs:
@ -893,6 +895,7 @@ class Model(training_lib.Model):
Raises:
ValueError: in case of invalid arguments.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('evaluate_v1').set(True)
self._assert_compile_was_called()
self._check_call_args('evaluate')
@ -972,6 +975,7 @@ class Model(training_lib.Model):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('predict_v1').set(True)
self._check_call_args('predict')