Update keras to use collections.abc instead of collections, which will be deprecated in py3.9

PiperOrigin-RevId: 339588609
Change-Id: I30553aa31227ff8e56a15509d0261ec64de3240c
This commit is contained in:
Scott Zhu 2020-10-28 20:12:11 -07:00 committed by TensorFlower Gardener
parent fb7a8d29bb
commit 9278b9421f
10 changed files with 28 additions and 24 deletions

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import copy
import csv
import io
@ -2608,7 +2609,7 @@ class CSVLogger(Callback):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, six.string_types):
return k
elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray:
elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
return '"[%s]"' % (', '.join(map(str, k)))
else:
return k

View File

@ -968,7 +968,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
form of datasets, generators, or `keras.utils.Sequence` instances
(since they generate batches).
validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.).
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
If an integer, specifies how many training epochs to run before a
new validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on

View File

@ -94,7 +94,7 @@ def model_iteration(model,
validation from data tensors). Ignored with the default value of
`None`.
validation_freq: Only relevant if validation data is provided. Integer or
`collections.Container` instance (e.g. list, tuple, etc.). If an
`collections.abc.Container` instance (e.g. list, tuple, etc.). If an
integer, specifies how many training epochs to run before a new
validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import abc
import atexit
import collections
import collections.abc as collections_abc
import functools
import multiprocessing.pool
import threading
@ -584,7 +585,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
'You should provide one `' + weight_type + '`'
'array per model output.')
return x_weight
if isinstance(x_weight, collections.Mapping):
if isinstance(x_weight, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
x_weights = []
for name in output_names:
@ -771,7 +772,7 @@ def collect_per_output_metric_info(metrics,
[metrics_module.clone_metric(m) for m in metrics])
else:
nested_metrics = [metrics]
elif isinstance(metrics, collections.Mapping):
elif isinstance(metrics, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
nested_metrics = []
for name in output_names:
@ -1087,7 +1088,7 @@ def get_loss_function(loss):
'before passing them to Model.compile.'.format(loss))
# Deserialize loss configuration, if needed.
if isinstance(loss, collections.Mapping):
if isinstance(loss, collections_abc.Mapping):
loss = losses.get(loss)
# Custom callable class.
@ -1304,7 +1305,7 @@ def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
ValueError: In case of invalid `sample_weight_mode` input.
"""
if isinstance(sample_weight_mode, collections.Mapping):
if isinstance(sample_weight_mode, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys(
'sample_weight_mode', sample_weight_mode,
[e.output_name for e in training_endpoints])
@ -1351,7 +1352,7 @@ def prepare_loss_functions(loss, output_names):
ValueError: If loss is a dict with keys not in model output names,
or if loss is a list with len not equal to model outputs.
"""
if isinstance(loss, collections.Mapping):
if isinstance(loss, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys('loss', loss, output_names)
loss_functions = []
for name in output_names:
@ -1363,7 +1364,7 @@ def prepare_loss_functions(loss, output_names):
loss_functions.append(get_loss_function(loss.get(name, None)))
elif isinstance(loss, six.string_types):
loss_functions = [get_loss_function(loss) for _ in output_names]
elif isinstance(loss, collections.Sequence):
elif isinstance(loss, collections_abc.Sequence):
if len(loss) != len(output_names):
raise ValueError('When passing a list as loss, it should have one entry '
'per model outputs. The model has {} outputs, but you '
@ -1397,7 +1398,7 @@ def prepare_loss_weights(training_endpoints, loss_weights=None):
if loss_weights is None:
for e in training_endpoints:
e.loss_weight = 1.
elif isinstance(loss_weights, collections.Mapping):
elif isinstance(loss_weights, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys(
'loss_weights', loss_weights,
[e.output_name for e in training_endpoints])
@ -1704,9 +1705,9 @@ def should_run_validation(validation_freq, epoch):
raise ValueError('`validation_freq` can not be less than 1.')
return one_indexed_epoch % validation_freq == 0
if not isinstance(validation_freq, collections.Container):
if not isinstance(validation_freq, collections_abc.Container):
raise ValueError('`validation_freq` must be an Integer or '
'`collections.Container` (e.g. list, tuple, etc.)')
'`collections.abc.Container` (e.g. list, tuple, etc.)')
return one_indexed_epoch in validation_freq

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import warnings
import numpy as np
@ -745,7 +746,7 @@ class Model(training_lib.Model):
the dataset at each epoch. This ensures that the same validation
samples are used every time.
validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.).
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
If an integer, specifies how many training epochs to run before a
new validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on
@ -1207,7 +1208,7 @@ class Model(training_lib.Model):
# at this point.
if self.run_eagerly or self._distribution_strategy:
inputs = training_utils_v1.cast_if_floating_dtype(inputs)
if isinstance(inputs, collections.Sequence):
if isinstance(inputs, collections_abc.Sequence):
# Unwrap lists with only one input, as we do when training on batch
if len(inputs) == 1:
inputs = inputs[0]

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import functools
import itertools
import unittest
@ -473,7 +473,7 @@ def _test_or_class_decorator(test_or_class, single_method_decorator):
The decorated result.
"""
def _decorate_test_or_class(obj):
if isinstance(obj, collections.Iterable):
if isinstance(obj, collections_abc.Iterable):
return itertools.chain.from_iterable(
single_method_decorator(method) for method in obj)
if isinstance(obj, type):

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import numpy as np
from tensorflow.python.platform import test
@ -37,7 +37,7 @@ class PreprocessingLayerTest(test.TestCase):
self.assertEqual(len(a), len(b))
for a_value, b_value in zip(a, b):
self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
elif isinstance(a, collections.Mapping):
elif isinstance(a, collections_abc.Mapping):
self.assertEqual(len(a), len(b))
for key, a_value in a.items():
b_value = b[key]

View File

@ -19,7 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import warnings
import numpy as np
@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@ -828,7 +829,7 @@ class RNN(Layer):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if (isinstance(inputs, collections.Sequence)
if (isinstance(inputs, collections_abc.Sequence)
and not isinstance(inputs, tuple)):
# get initial_state from full input spec
# as they could be copied to multiple GPU.

View File

@ -25,7 +25,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import json
import numpy as np
import wrapt
@ -117,7 +117,7 @@ def get_json_type(obj):
if isinstance(obj, dtypes.DType):
return obj.name
if isinstance(obj, collections.Mapping):
if isinstance(obj, collections_abc.Mapping):
return dict(obj)
if obj is Ellipsis:

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc as collections_abc
import copy
import os
import six
@ -82,7 +82,7 @@ def model_input_signature(model, keep_original_batch_size=False):
input_specs = _enforce_names_consistency(input_specs)
# Return a list with a single element as the model's input signature.
if isinstance(input_specs,
collections.Sequence) and len(input_specs) == 1:
collections_abc.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list.
return input_specs