diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 3672ef64da3..33c868d02be 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -681,6 +681,7 @@ class DatasetAdapter(DataAdapter):
                y=None,
                sample_weights=None,
                steps=None,
+               batch_size=None,
                **kwargs):
     super(DatasetAdapter, self).__init__(x, y, **kwargs)
     # Note that the dataset instance is immutable, its fine to reuse the user
@@ -690,7 +691,7 @@ class DatasetAdapter(DataAdapter):
     # The user-provided steps.
     self._user_steps = steps
 
-    self._validate_args(y, sample_weights, steps)
+    self._validate_args(y, sample_weights, steps, batch_size)
 
   def get_dataset(self):
     return self._dataset
@@ -719,7 +720,7 @@ class DatasetAdapter(DataAdapter):
     return (self._user_steps is None or
             cardinality.cardinality(self._dataset).numpy() == self._user_steps)
 
-  def _validate_args(self, y, sample_weights, steps):
+  def _validate_args(self, y, sample_weights, steps, batch_size):
     """Validates `__init__` arguments."""
     # Arguments that shouldn't be passed.
     if not is_none_or_empty(y):
@@ -729,6 +730,10 @@ class DatasetAdapter(DataAdapter):
       raise ValueError("`sample_weight` argument is not supported when using "
                        "dataset as input.")
 
+    if batch_size is not None:
+      raise ValueError("`batch_size` argument must not be specified when "
+                       "using dataset as input.")
+
     if steps is None:
       if _is_distributed_dataset(self._dataset):
         raise ValueError("When providing a distributed dataset, you must "
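
For context, a minimal sketch of the user-facing behavior this patch enforces: a tf.data.Dataset already carries its own batching, so passing `batch_size` alongside one is redundant and now raises the ValueError added above in `DatasetAdapter._validate_args`. The model, shapes, and data below are illustrative assumptions, not part of the patch.

import numpy as np
import tensorflow as tf

# Illustrative model; any compiled Keras model would do.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

# The dataset is batched explicitly, so it determines its own batch size.
dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(32, 4).astype(np.float32),
     np.random.rand(32, 1).astype(np.float32))).batch(8)

model.fit(dataset, epochs=1)  # OK: batch size comes from the dataset itself.

try:
    # With this patch, DatasetAdapter rejects the redundant argument.
    model.fit(dataset, batch_size=8, epochs=1)
except ValueError as e:
    print(e)  # `batch_size` argument must not be specified when using dataset as input.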