Fix corner case from recent Model.*_generator optimization

PiperOrigin-RevId: 276768411
Change-Id: If19a841effbde629cff8f71a3054f2cc4623838f
This commit is contained in:
Taylor Robie 2019-10-25 15:04:51 -07:00 committed by TensorFlower Gardener
parent 4afd931c76
commit 49feca4ec7
2 changed files with 12 additions and 6 deletions

View File

@ -809,7 +809,7 @@ class GeneratorDataAdapter(DataAdapter):
return t
return np.array(t, dtype=backend.floatx())
canonicalized_peek = nest.list_to_tuple(
canonicalized_peek = nest._list_to_tuple(
nest.map_structure(convert_for_inspection, peek[:elements_to_keep]))
nested_dtypes = nest.map_structure(lambda t: t.dtype, canonicalized_peek)
nested_shape = nest.map_structure(dynamic_shape_like, canonicalized_peek)
@ -851,7 +851,7 @@ class GeneratorDataAdapter(DataAdapter):
"""Optional compatibility layer between user's data and Dataset."""
must_prune_nones = (elements_to_keep != len(peek))
try:
nest.assert_same_structure(peek, nest.list_to_tuple(peek))
nest.assert_same_structure(peek, nest._list_to_tuple(peek))
must_extract_lists = False
except TypeError:
must_extract_lists = True
@ -868,7 +868,7 @@ class GeneratorDataAdapter(DataAdapter):
batch = (batch,)
if must_extract_lists:
batch = nest.list_to_tuple(batch)
batch = nest._list_to_tuple(batch)
if must_prune_nones:
batch = batch[:elements_to_keep]

View File

@ -2001,10 +2001,16 @@ def unpack_validation_data(validation_data, raise_if_ambiguous=True):
val_y = None
val_sample_weight = None
elif len(validation_data) == 2:
val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
val_sample_weight = None
try:
val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
val_sample_weight = None
except ValueError:
val_x, val_y, val_sample_weight = validation_data, None, None
elif len(validation_data) == 3:
val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
try:
val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
except ValueError:
val_x, val_y, val_sample_weight = validation_data, None, None
else:
if raise_if_ambiguous:
raise ValueError(