Relax shapes for Keras _on_batch functions.

The current train,test,predict_on_batch functions use a regular
tf.function when not in eager mode, which causes a retrace for every
new batch size. Similarly, if sequences are passed on input, every
different sequence size causes a retrace.

Passing experimental_relax_shapes=True allow gracefully handle these
cases.

Fixes #34907.
This commit is contained in:
Milan Straka 2019-12-11 08:40:37 +01:00
parent 115ea3db34
commit f835a4a795

View File

@ -119,7 +119,9 @@ def _make_on_batch_function(model, mode):
func = model
if not model.run_eagerly:
func = def_function.function(func)
# Pass `experimental_relax_shapes` to avoid retracing for dynamic batch size,
# variable length sequences, etc.
func = def_function.function(func, experimental_relax_shapes=True)
return func