diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index ef1d2a94992..1f5077a75ae 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1959,6 +1959,9 @@ def _split_dataset_batch(dataset, split_batch_by): prefetch_buffer = None if isinstance(dataset, dataset_ops.PrefetchDataset): prefetch_buffer = dataset._buffer_size + elif (isinstance(dataset, dataset_ops.DatasetV1Adapter) + and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)): + prefetch_buffer = dataset._dataset._buffer_size # pylint: enable=protected-access if tensor_util.is_tensor(batch_size):