Relax shapes for Keras _on_batch functions.
The current train,test,predict_on_batch functions use a regular tf.function when not in eager mode, which causes a retrace for every new batch size. Similarly, if sequences are passed on input, every different sequence size causes a retrace. Passing experimental_relax_shapes=True allow gracefully handle these cases. Fixes #34907.
This commit is contained in:
parent
115ea3db34
commit
f835a4a795
@ -119,7 +119,9 @@ def _make_on_batch_function(model, mode):
|
||||
func = model
|
||||
|
||||
if not model.run_eagerly:
|
||||
func = def_function.function(func)
|
||||
# Pass `experimental_relax_shapes` to avoid retracing for dynamic batch size,
|
||||
# variable length sequences, etc.
|
||||
func = def_function.function(func, experimental_relax_shapes=True)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user