Avoid deprecation warning related to the use of `collections.Sequence`. It will stop working in Python 3.8.
PiperOrigin-RevId: 316965801 Change-Id: Ia44313b1920653a0dd0a94d404ac914b08239c43
This commit is contained in:
parent
cbf8f57413
commit
92d68dd5f1
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
|
@ -57,7 +56,6 @@ try:
|
|||
from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
scipy_sparse = None
|
||||
|
||||
try:
|
||||
import pandas as pd # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
|
@ -786,7 +784,6 @@ class GeneratorDataAdapter(DataAdapter):
|
|||
# Since we have to know the dtype of the python generator when we build the
|
||||
# dataset, we have to look at a batch to infer the structure.
|
||||
peek, x = self._peek_and_restore(x)
|
||||
assert_not_namedtuple(peek)
|
||||
peek = self._standardize_batch(peek)
|
||||
peek = _process_tensorlike(peek)
|
||||
|
||||
|
@ -1070,21 +1067,6 @@ def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
|
|||
return sample_weight_modes
|
||||
|
||||
|
||||
def assert_not_namedtuple(x):
|
||||
if (isinstance(x, tuple) and
|
||||
# TODO(b/144192902): Use a namedtuple checking utility.
|
||||
hasattr(x, "_fields") and
|
||||
isinstance(x._fields, collections.Sequence) and
|
||||
all(isinstance(f, six.string_types) for f in x._fields)):
|
||||
raise ValueError(
|
||||
"Received namedtuple ({}) with fields `{}` as input. namedtuples "
|
||||
"cannot, in general, be unambiguously resolved into `x`, `y`, "
|
||||
"and `sample_weight`. For this reason Keras has elected not to "
|
||||
"support them. If you would like the value to be unpacked, "
|
||||
"please explicitly convert it to a tuple before passing it to "
|
||||
"Keras.".format(x.__class__, x._fields))
|
||||
|
||||
|
||||
class DataHandler(object):
|
||||
"""Handles iterating over epoch-level `tf.data.Iterator` objects."""
|
||||
|
||||
|
|
|
@ -19,8 +19,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
|
@ -49,6 +47,11 @@ from tensorflow.python.util import nest
|
|||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
try:
|
||||
from collections import abc as collections_abc # pylint: disable=g-import-not-at-top
|
||||
except ImportError: # For Python 2
|
||||
import collections as collections_abc # pylint: disable=g-import-not-at-top
|
||||
|
||||
|
||||
RECURRENT_DROPOUT_WARNING_MSG = (
|
||||
'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
|
||||
|
@ -828,7 +831,7 @@ class RNN(Layer):
|
|||
# input shape: `(samples, time (padded with zeros), input_dim)`
|
||||
# note that the .build() method of subclasses MUST define
|
||||
# self.input_spec and self.state_spec with complete input shapes.
|
||||
if (isinstance(inputs, collections.Sequence)
|
||||
if (isinstance(inputs, collections_abc.Sequence)
|
||||
and not isinstance(inputs, tuple)):
|
||||
# get initial_state from full input spec
|
||||
# as they could be copied to multiple GPU.
|
||||
|
|
Loading…
Reference in New Issue