Avoid deprecation warning related to the use of `collections.Sequence`. It will stop working in Python 3.8.
PiperOrigin-RevId: 316965801
Change-Id: Ia44313b1920653a0dd0a94d404ac914b08239c43
parent cbf8f57413
commit 92d68dd5f1
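The change follows the standard Python 2/3 compatibility pattern: import the abstract base classes from `collections.abc` where available and fall back to the old `collections` aliases on Python 2. A minimal standalone sketch of the pattern the diff below introduces (the `print` demos are illustrative and not part of the commit):

# Sketch of the compatibility import this commit adds.
try:
  from collections import abc as collections_abc  # Python 3: ABCs live in collections.abc
except ImportError:  # Python 2: ABCs are still attributes of collections
  import collections as collections_abc

# isinstance checks behave exactly as before, minus the DeprecationWarning
# that accessing `collections.Sequence` emits on recent Python 3 releases.
print(isinstance([1, 2, 3], collections_abc.Sequence))  # True
print(isinstance((1, 2), collections_abc.Sequence))     # True
print(isinstance({"a": 1}, collections_abc.Sequence))   # False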
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
-import collections
 import contextlib
 import functools
 import itertools
@@ -57,7 +56,6 @@ try:
   from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top
 except ImportError:
   scipy_sparse = None
-
 try:
   import pandas as pd  # pylint: disable=g-import-not-at-top
 except ImportError:
@@ -786,7 +784,6 @@ class GeneratorDataAdapter(DataAdapter):
     # Since we have to know the dtype of the python generator when we build the
     # dataset, we have to look at a batch to infer the structure.
     peek, x = self._peek_and_restore(x)
-    assert_not_namedtuple(peek)
     peek = self._standardize_batch(peek)
     peek = _process_tensorlike(peek)
 
@@ -1070,21 +1067,6 @@ def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
   return sample_weight_modes
 
 
-def assert_not_namedtuple(x):
-  if (isinstance(x, tuple) and
-      # TODO(b/144192902): Use a namedtuple checking utility.
-      hasattr(x, "_fields") and
-      isinstance(x._fields, collections.Sequence) and
-      all(isinstance(f, six.string_types) for f in x._fields)):
-    raise ValueError(
-        "Received namedtuple ({}) with fields `{}` as input. namedtuples "
-        "cannot, in general, be unambiguously resolved into `x`, `y`, "
-        "and `sample_weight`. For this reason Keras has elected not to "
-        "support them. If you would like the value to be unpacked, "
-        "please explicitly convert it to a tuple before passing it to "
-        "Keras.".format(x.__class__, x._fields))
-
-
 class DataHandler(object):
   """Handles iterating over epoch-level `tf.data.Iterator` objects."""
 
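The `assert_not_namedtuple` helper above is deleted outright rather than migrated, and its call site in `GeneratorDataAdapter` is removed with it. For reference, a hypothetical version of the same check written against the non-deprecated ABC would look like the sketch below; this is not part of the commit, and the shortened error message is illustrative:

from collections import abc as collections_abc

import six


def assert_not_namedtuple(x):
  # A namedtuple is a tuple subclass whose `_fields` attribute is a
  # sequence of field-name strings.
  if (isinstance(x, tuple) and
      hasattr(x, "_fields") and
      isinstance(x._fields, collections_abc.Sequence) and
      all(isinstance(f, six.string_types) for f in x._fields)):
    raise ValueError(
        "Received namedtuple ({}) with fields `{}` as input; convert it to a "
        "plain tuple before passing it to Keras.".format(x.__class__,
                                                         x._fields))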
@@ -19,8 +19,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 import numpy as np
 
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
@@ -49,6 +47,11 @@ from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import keras_export
 from tensorflow.tools.docs import doc_controls
 
+try:
+  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
+except ImportError:  # For Python 2
+  import collections as collections_abc  # pylint: disable=g-import-not-at-top
+
 
 RECURRENT_DROPOUT_WARNING_MSG = (
     'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
@@ -828,7 +831,7 @@ class RNN(Layer):
     # input shape: `(samples, time (padded with zeros), input_dim)`
     # note that the .build() method of subclasses MUST define
     # self.input_spec and self.state_spec with complete input shapes.
-    if (isinstance(inputs, collections.Sequence)
+    if (isinstance(inputs, collections_abc.Sequence)
         and not isinstance(inputs, tuple)):
       # get initial_state from full input spec
       # as they could be copied to multiple GPU.
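The final hunk only swaps the deprecated name; the behaviour is unchanged: a list (or other non-tuple sequence) of inputs is treated as packing the initial states alongside the actual input, while a plain tuple is left alone. A small standalone sketch of that condition, using a hypothetical function name and toy values that are not from the diff:

from collections import abc as collections_abc


def packs_initial_state(inputs):
  # Mirrors the updated check in `RNN`: a non-tuple Sequence (typically a
  # list) is assumed to carry the initial states after the actual inputs.
  return (isinstance(inputs, collections_abc.Sequence)
          and not isinstance(inputs, tuple))


print(packs_initial_state(["inputs", "state_0"]))  # True: lists carry states
print(packs_initial_state(("inputs", "state_0")))  # False: tuples are left intact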