Raise error when calling .fit() w/ batch_size and a tf dataset

This commit is contained in:
Jonah Kohn 2020-07-22 16:16:29 -07:00 committed by GitHub
parent 513fff7cf6
commit 8f278b5c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -690,6 +690,7 @@ class DatasetAdapter(DataAdapter):
y=None, y=None,
sample_weights=None, sample_weights=None,
steps=None, steps=None,
batch_size=None,
**kwargs): **kwargs):
super(DatasetAdapter, self).__init__(x, y, **kwargs) super(DatasetAdapter, self).__init__(x, y, **kwargs)
# Note that the dataset instance is immutable, its fine to reuse the user # Note that the dataset instance is immutable, its fine to reuse the user
@ -699,7 +700,7 @@ class DatasetAdapter(DataAdapter):
# The user-provided steps. # The user-provided steps.
self._user_steps = steps self._user_steps = steps
self._validate_args(y, sample_weights, steps) self._validate_args(y, sample_weights, steps, batch_size)
def get_dataset(self): def get_dataset(self):
return self._dataset return self._dataset
@ -728,7 +729,7 @@ class DatasetAdapter(DataAdapter):
return (self._user_steps is None or return (self._user_steps is None or
cardinality.cardinality(self._dataset).numpy() == self._user_steps) cardinality.cardinality(self._dataset).numpy() == self._user_steps)
def _validate_args(self, y, sample_weights, steps): def _validate_args(self, y, sample_weights, steps, batch_size):
"""Validates `__init__` arguments.""" """Validates `__init__` arguments."""
# Arguments that shouldn't be passed. # Arguments that shouldn't be passed.
if not is_none_or_empty(y): if not is_none_or_empty(y):
@ -738,6 +739,10 @@ class DatasetAdapter(DataAdapter):
raise ValueError("`sample_weight` argument is not supported when using " raise ValueError("`sample_weight` argument is not supported when using "
"dataset as input.") "dataset as input.")
if batch_size is not None:
raise ValueError("`batch_size` argument must not be specified when "
"using dataset as input.")
if steps is None: if steps is None:
if _is_distributed_dataset(self._dataset): if _is_distributed_dataset(self._dataset):
raise ValueError("When providing a distributed dataset, you must " raise ValueError("When providing a distributed dataset, you must "