Merge pull request #41640 from jonah-kohn:master
PiperOrigin-RevId: 322950441 Change-Id: I99f8334bed734f4ece7bbc4c55d3d7c7ea734aba
This commit is contained in:
commit
a667a0c934
@ -681,6 +681,7 @@ class DatasetAdapter(DataAdapter):
|
||||
y=None,
|
||||
sample_weights=None,
|
||||
steps=None,
|
||||
batch_size=None,
|
||||
**kwargs):
|
||||
super(DatasetAdapter, self).__init__(x, y, **kwargs)
|
||||
# Note that the dataset instance is immutable, its fine to reuse the user
|
||||
@ -690,7 +691,7 @@ class DatasetAdapter(DataAdapter):
|
||||
# The user-provided steps.
|
||||
self._user_steps = steps
|
||||
|
||||
self._validate_args(y, sample_weights, steps)
|
||||
self._validate_args(y, sample_weights, steps, batch_size)
|
||||
|
||||
def get_dataset(self):
|
||||
return self._dataset
|
||||
@ -719,7 +720,7 @@ class DatasetAdapter(DataAdapter):
|
||||
return (self._user_steps is None or
|
||||
cardinality.cardinality(self._dataset).numpy() == self._user_steps)
|
||||
|
||||
def _validate_args(self, y, sample_weights, steps):
|
||||
def _validate_args(self, y, sample_weights, steps, batch_size):
|
||||
"""Validates `__init__` arguments."""
|
||||
# Arguments that shouldn't be passed.
|
||||
if not is_none_or_empty(y):
|
||||
@ -729,6 +730,10 @@ class DatasetAdapter(DataAdapter):
|
||||
raise ValueError("`sample_weight` argument is not supported when using "
|
||||
"dataset as input.")
|
||||
|
||||
if batch_size is not None:
|
||||
raise ValueError("`batch_size` argument must not be specified when "
|
||||
"using dataset as input.")
|
||||
|
||||
if steps is None:
|
||||
if _is_distributed_dataset(self._dataset):
|
||||
raise ValueError("When providing a distributed dataset, you must "
|
||||
|
Loading…
x
Reference in New Issue
Block a user