Explicitly raise a (clearer) error message when models end up in invalid states due to interleaving graph and eager.
In rare cases code may have run w/o crashing when in these invalid states, but it's safer to error with an explanation rather than risk silent failures/fragile behavior. PiperOrigin-RevId: 321192744 Change-Id: I9e97ac3b7cea27c9b389e5202de9f1c09a4aa2b8
This commit is contained in:
parent
14b2d686d6
commit
ef4db27b31
|
@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
|
|||
# might want to turn it off, like Sequential model.
|
||||
self._auto_track_sub_layers = True
|
||||
|
||||
# Mark this layer as having been originally built as a tf1 layer/model
|
||||
self._originally_built_as_v1 = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
@generic_utils.default
|
||||
def build(self, input_shape):
|
||||
|
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
|
|||
ValueError: if the layer's `call` method returns None (an invalid value).
|
||||
RuntimeError: if `super().__init__()` was not called in the constructor.
|
||||
"""
|
||||
self._assert_built_as_v1()
|
||||
|
||||
if not hasattr(self, '_thread_local'):
|
||||
raise RuntimeError(
|
||||
'You must call `super().__init__()` in the layer constructor.')
|
||||
|
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
|
|||
|
||||
return outputs
|
||||
|
||||
def _assert_built_as_v1(self):
|
||||
if not hasattr(self, '_originally_built_as_v1'):
|
||||
raise ValueError(
|
||||
'Your Layer or Model is in an invalid state. This can happen if you '
|
||||
'are interleaving estimator/non-estimator models or '
|
||||
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
|
||||
'with models/layers created outside of it. '
|
||||
'Converting a model to an estimator (via model_to_estimator) '
|
||||
'invalidates all models/layers made before the conversion (even '
|
||||
'if they were not the model converted to an estimator). '
|
||||
'Similarly, making a layer or a model inside a '
|
||||
'a tf.compat.v1.Graph invalidates all layers/models you previously '
|
||||
'made outside of the graph.')
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype_policy.variable_dtype
|
||||
|
|
|
@ -303,6 +303,7 @@ class Model(training_lib.Model):
|
|||
ValueError: In case of invalid arguments for
|
||||
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
||||
"""
|
||||
self._assert_built_as_v1()
|
||||
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
||||
self._experimental_run_tf_function = kwargs.pop(
|
||||
'experimental_run_tf_function', True)
|
||||
|
@ -773,6 +774,7 @@ class Model(training_lib.Model):
|
|||
ValueError: In case of mismatch between the provided input data
|
||||
and what the model expects.
|
||||
"""
|
||||
self._assert_built_as_v1()
|
||||
_keras_api_gauge.get_cell('fit_v1').set(True)
|
||||
# Legacy support
|
||||
if 'nb_epoch' in kwargs:
|
||||
|
@ -893,6 +895,7 @@ class Model(training_lib.Model):
|
|||
Raises:
|
||||
ValueError: in case of invalid arguments.
|
||||
"""
|
||||
self._assert_built_as_v1()
|
||||
_keras_api_gauge.get_cell('evaluate_v1').set(True)
|
||||
self._assert_compile_was_called()
|
||||
self._check_call_args('evaluate')
|
||||
|
@ -972,6 +975,7 @@ class Model(training_lib.Model):
|
|||
or in case a stateful model receives a number of samples
|
||||
that is not a multiple of the batch size.
|
||||
"""
|
||||
self._assert_built_as_v1()
|
||||
_keras_api_gauge.get_cell('predict_v1').set(True)
|
||||
self._check_call_args('predict')
|
||||
|
||||
|
|
Loading…
Reference in New Issue