Raise error when calling .fit() w/ batch_size and a tf dataset
This commit is contained in:
parent
513fff7cf6
commit
8f278b5c18
@ -690,6 +690,7 @@ class DatasetAdapter(DataAdapter):
|
|||||||
y=None,
|
y=None,
|
||||||
sample_weights=None,
|
sample_weights=None,
|
||||||
steps=None,
|
steps=None,
|
||||||
|
batch_size=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(DatasetAdapter, self).__init__(x, y, **kwargs)
|
super(DatasetAdapter, self).__init__(x, y, **kwargs)
|
||||||
# Note that the dataset instance is immutable, its fine to reuse the user
|
# Note that the dataset instance is immutable, its fine to reuse the user
|
||||||
@ -699,7 +700,7 @@ class DatasetAdapter(DataAdapter):
|
|||||||
# The user-provided steps.
|
# The user-provided steps.
|
||||||
self._user_steps = steps
|
self._user_steps = steps
|
||||||
|
|
||||||
self._validate_args(y, sample_weights, steps)
|
self._validate_args(y, sample_weights, steps, batch_size)
|
||||||
|
|
||||||
def get_dataset(self):
|
def get_dataset(self):
|
||||||
return self._dataset
|
return self._dataset
|
||||||
@ -728,7 +729,7 @@ class DatasetAdapter(DataAdapter):
|
|||||||
return (self._user_steps is None or
|
return (self._user_steps is None or
|
||||||
cardinality.cardinality(self._dataset).numpy() == self._user_steps)
|
cardinality.cardinality(self._dataset).numpy() == self._user_steps)
|
||||||
|
|
||||||
def _validate_args(self, y, sample_weights, steps):
|
def _validate_args(self, y, sample_weights, steps, batch_size):
|
||||||
"""Validates `__init__` arguments."""
|
"""Validates `__init__` arguments."""
|
||||||
# Arguments that shouldn't be passed.
|
# Arguments that shouldn't be passed.
|
||||||
if not is_none_or_empty(y):
|
if not is_none_or_empty(y):
|
||||||
@ -738,6 +739,10 @@ class DatasetAdapter(DataAdapter):
|
|||||||
raise ValueError("`sample_weight` argument is not supported when using "
|
raise ValueError("`sample_weight` argument is not supported when using "
|
||||||
"dataset as input.")
|
"dataset as input.")
|
||||||
|
|
||||||
|
if batch_size is not None:
|
||||||
|
raise ValueError("`batch_size` argument must not be specified when "
|
||||||
|
"using dataset as input.")
|
||||||
|
|
||||||
if steps is None:
|
if steps is None:
|
||||||
if _is_distributed_dataset(self._dataset):
|
if _is_distributed_dataset(self._dataset):
|
||||||
raise ValueError("When providing a distributed dataset, you must "
|
raise ValueError("When providing a distributed dataset, you must "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user