Explicitly raise a (clearer) error message when models end up in invalid states due to interleaving graph and eager.
In rare cases code may have run w/o crashing when in these invalid states, but it's safer to error with an explanation rather than risk silent failures/fragile behavior. PiperOrigin-RevId: 321192744 Change-Id: I9e97ac3b7cea27c9b389e5202de9f1c09a4aa2b8
This commit is contained in:
parent
14b2d686d6
commit
ef4db27b31
|
@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
|
||||||
# might want to turn it off, like Sequential model.
|
# might want to turn it off, like Sequential model.
|
||||||
self._auto_track_sub_layers = True
|
self._auto_track_sub_layers = True
|
||||||
|
|
||||||
|
# Mark this layer as having been originally built as a tf1 layer/model
|
||||||
|
self._originally_built_as_v1 = True
|
||||||
|
|
||||||
@trackable.no_automatic_dependency_tracking
|
@trackable.no_automatic_dependency_tracking
|
||||||
@generic_utils.default
|
@generic_utils.default
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
|
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
|
||||||
ValueError: if the layer's `call` method returns None (an invalid value).
|
ValueError: if the layer's `call` method returns None (an invalid value).
|
||||||
RuntimeError: if `super().__init__()` was not called in the constructor.
|
RuntimeError: if `super().__init__()` was not called in the constructor.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
|
|
||||||
if not hasattr(self, '_thread_local'):
|
if not hasattr(self, '_thread_local'):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'You must call `super().__init__()` in the layer constructor.')
|
'You must call `super().__init__()` in the layer constructor.')
|
||||||
|
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def _assert_built_as_v1(self):
|
||||||
|
if not hasattr(self, '_originally_built_as_v1'):
|
||||||
|
raise ValueError(
|
||||||
|
'Your Layer or Model is in an invalid state. This can happen if you '
|
||||||
|
'are interleaving estimator/non-estimator models or '
|
||||||
|
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
|
||||||
|
'with models/layers created outside of it. '
|
||||||
|
'Converting a model to an estimator (via model_to_estimator) '
|
||||||
|
'invalidates all models/layers made before the conversion (even '
|
||||||
|
'if they were not the model converted to an estimator). '
|
||||||
|
'Similarly, making a layer or a model inside a '
|
||||||
|
'a tf.compat.v1.Graph invalidates all layers/models you previously '
|
||||||
|
'made outside of the graph.')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
return self._dtype_policy.variable_dtype
|
return self._dtype_policy.variable_dtype
|
||||||
|
|
|
@ -303,6 +303,7 @@ class Model(training_lib.Model):
|
||||||
ValueError: In case of invalid arguments for
|
ValueError: In case of invalid arguments for
|
||||||
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
||||||
self._experimental_run_tf_function = kwargs.pop(
|
self._experimental_run_tf_function = kwargs.pop(
|
||||||
'experimental_run_tf_function', True)
|
'experimental_run_tf_function', True)
|
||||||
|
@ -773,6 +774,7 @@ class Model(training_lib.Model):
|
||||||
ValueError: In case of mismatch between the provided input data
|
ValueError: In case of mismatch between the provided input data
|
||||||
and what the model expects.
|
and what the model expects.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('fit_v1').set(True)
|
_keras_api_gauge.get_cell('fit_v1').set(True)
|
||||||
# Legacy support
|
# Legacy support
|
||||||
if 'nb_epoch' in kwargs:
|
if 'nb_epoch' in kwargs:
|
||||||
|
@ -893,6 +895,7 @@ class Model(training_lib.Model):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: in case of invalid arguments.
|
ValueError: in case of invalid arguments.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('evaluate_v1').set(True)
|
_keras_api_gauge.get_cell('evaluate_v1').set(True)
|
||||||
self._assert_compile_was_called()
|
self._assert_compile_was_called()
|
||||||
self._check_call_args('evaluate')
|
self._check_call_args('evaluate')
|
||||||
|
@ -972,6 +975,7 @@ class Model(training_lib.Model):
|
||||||
or in case a stateful model receives a number of samples
|
or in case a stateful model receives a number of samples
|
||||||
that is not a multiple of the batch size.
|
that is not a multiple of the batch size.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('predict_v1').set(True)
|
_keras_api_gauge.get_cell('predict_v1').set(True)
|
||||||
self._check_call_args('predict')
|
self._check_call_args('predict')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue