Fix corner case from recent Model.*_generator optimization
PiperOrigin-RevId: 276768411 Change-Id: If19a841effbde629cff8f71a3054f2cc4623838f
This commit is contained in:
parent
4afd931c76
commit
49feca4ec7
@ -809,7 +809,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
return t
|
||||
return np.array(t, dtype=backend.floatx())
|
||||
|
||||
canonicalized_peek = nest.list_to_tuple(
|
||||
canonicalized_peek = nest._list_to_tuple(
|
||||
nest.map_structure(convert_for_inspection, peek[:elements_to_keep]))
|
||||
nested_dtypes = nest.map_structure(lambda t: t.dtype, canonicalized_peek)
|
||||
nested_shape = nest.map_structure(dynamic_shape_like, canonicalized_peek)
|
||||
@ -851,7 +851,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
"""Optional compatibility layer between user's data and Dataset."""
|
||||
must_prune_nones = (elements_to_keep != len(peek))
|
||||
try:
|
||||
nest.assert_same_structure(peek, nest.list_to_tuple(peek))
|
||||
nest.assert_same_structure(peek, nest._list_to_tuple(peek))
|
||||
must_extract_lists = False
|
||||
except TypeError:
|
||||
must_extract_lists = True
|
||||
@ -868,7 +868,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
batch = (batch,)
|
||||
|
||||
if must_extract_lists:
|
||||
batch = nest.list_to_tuple(batch)
|
||||
batch = nest._list_to_tuple(batch)
|
||||
|
||||
if must_prune_nones:
|
||||
batch = batch[:elements_to_keep]
|
||||
|
@ -2001,10 +2001,16 @@ def unpack_validation_data(validation_data, raise_if_ambiguous=True):
|
||||
val_y = None
|
||||
val_sample_weight = None
|
||||
elif len(validation_data) == 2:
|
||||
val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
|
||||
val_sample_weight = None
|
||||
try:
|
||||
val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
|
||||
val_sample_weight = None
|
||||
except ValueError:
|
||||
val_x, val_y, val_sample_weight = validation_data, None, None
|
||||
elif len(validation_data) == 3:
|
||||
val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
|
||||
try:
|
||||
val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
|
||||
except ValueError:
|
||||
val_x, val_y, val_sample_weight = validation_data, None, None
|
||||
else:
|
||||
if raise_if_ambiguous:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user