Update keras to use collections.abc instead of collections, which will be deprecated in py3.9

PiperOrigin-RevId: 339588609
Change-Id: I30553aa31227ff8e56a15509d0261ec64de3240c
This commit is contained in:
Scott Zhu 2020-10-28 20:12:11 -07:00 committed by TensorFlower Gardener
parent fb7a8d29bb
commit 9278b9421f
10 changed files with 28 additions and 24 deletions

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import collections.abc as collections_abc
import copy import copy
import csv import csv
import io import io
@ -2608,7 +2609,7 @@ class CSVLogger(Callback):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, six.string_types): if isinstance(k, six.string_types):
return k return k
elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray: elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
return '"[%s]"' % (', '.join(map(str, k))) return '"[%s]"' % (', '.join(map(str, k)))
else: else:
return k return k

View File

@ -968,7 +968,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
form of datasets, generators, or `keras.utils.Sequence` instances form of datasets, generators, or `keras.utils.Sequence` instances
(since they generate batches). (since they generate batches).
validation_freq: Only relevant if validation data is provided. Integer validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.). or `collections.abc.Container` instance (e.g. list, tuple, etc.).
If an integer, specifies how many training epochs to run before a If an integer, specifies how many training epochs to run before a
new validation run is performed, e.g. `validation_freq=2` runs new validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on validation every 2 epochs. If a Container, specifies the epochs on

View File

@ -94,7 +94,7 @@ def model_iteration(model,
validation from data tensors). Ignored with the default value of validation from data tensors). Ignored with the default value of
`None`. `None`.
validation_freq: Only relevant if validation data is provided. Integer or validation_freq: Only relevant if validation data is provided. Integer or
`collections.Container` instance (e.g. list, tuple, etc.). If an `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
integer, specifies how many training epochs to run before a new integer, specifies how many training epochs to run before a new
validation run is performed, e.g. `validation_freq=2` runs validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on validation every 2 epochs. If a Container, specifies the epochs on

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import abc import abc
import atexit import atexit
import collections import collections
import collections.abc as collections_abc
import functools import functools
import multiprocessing.pool import multiprocessing.pool
import threading import threading
@ -584,7 +585,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
'You should provide one `' + weight_type + '`' 'You should provide one `' + weight_type + '`'
'array per model output.') 'array per model output.')
return x_weight return x_weight
if isinstance(x_weight, collections.Mapping): if isinstance(x_weight, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names) generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
x_weights = [] x_weights = []
for name in output_names: for name in output_names:
@ -771,7 +772,7 @@ def collect_per_output_metric_info(metrics,
[metrics_module.clone_metric(m) for m in metrics]) [metrics_module.clone_metric(m) for m in metrics])
else: else:
nested_metrics = [metrics] nested_metrics = [metrics]
elif isinstance(metrics, collections.Mapping): elif isinstance(metrics, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys('metrics', metrics, output_names) generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
nested_metrics = [] nested_metrics = []
for name in output_names: for name in output_names:
@ -1087,7 +1088,7 @@ def get_loss_function(loss):
'before passing them to Model.compile.'.format(loss)) 'before passing them to Model.compile.'.format(loss))
# Deserialize loss configuration, if needed. # Deserialize loss configuration, if needed.
if isinstance(loss, collections.Mapping): if isinstance(loss, collections_abc.Mapping):
loss = losses.get(loss) loss = losses.get(loss)
# Custom callable class. # Custom callable class.
@ -1304,7 +1305,7 @@ def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
ValueError: In case of invalid `sample_weight_mode` input. ValueError: In case of invalid `sample_weight_mode` input.
""" """
if isinstance(sample_weight_mode, collections.Mapping): if isinstance(sample_weight_mode, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys( generic_utils.check_for_unexpected_keys(
'sample_weight_mode', sample_weight_mode, 'sample_weight_mode', sample_weight_mode,
[e.output_name for e in training_endpoints]) [e.output_name for e in training_endpoints])
@ -1351,7 +1352,7 @@ def prepare_loss_functions(loss, output_names):
ValueError: If loss is a dict with keys not in model output names, ValueError: If loss is a dict with keys not in model output names,
or if loss is a list with len not equal to model outputs. or if loss is a list with len not equal to model outputs.
""" """
if isinstance(loss, collections.Mapping): if isinstance(loss, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys('loss', loss, output_names) generic_utils.check_for_unexpected_keys('loss', loss, output_names)
loss_functions = [] loss_functions = []
for name in output_names: for name in output_names:
@ -1363,7 +1364,7 @@ def prepare_loss_functions(loss, output_names):
loss_functions.append(get_loss_function(loss.get(name, None))) loss_functions.append(get_loss_function(loss.get(name, None)))
elif isinstance(loss, six.string_types): elif isinstance(loss, six.string_types):
loss_functions = [get_loss_function(loss) for _ in output_names] loss_functions = [get_loss_function(loss) for _ in output_names]
elif isinstance(loss, collections.Sequence): elif isinstance(loss, collections_abc.Sequence):
if len(loss) != len(output_names): if len(loss) != len(output_names):
raise ValueError('When passing a list as loss, it should have one entry ' raise ValueError('When passing a list as loss, it should have one entry '
'per model outputs. The model has {} outputs, but you ' 'per model outputs. The model has {} outputs, but you '
@ -1397,7 +1398,7 @@ def prepare_loss_weights(training_endpoints, loss_weights=None):
if loss_weights is None: if loss_weights is None:
for e in training_endpoints: for e in training_endpoints:
e.loss_weight = 1. e.loss_weight = 1.
elif isinstance(loss_weights, collections.Mapping): elif isinstance(loss_weights, collections_abc.Mapping):
generic_utils.check_for_unexpected_keys( generic_utils.check_for_unexpected_keys(
'loss_weights', loss_weights, 'loss_weights', loss_weights,
[e.output_name for e in training_endpoints]) [e.output_name for e in training_endpoints])
@ -1704,9 +1705,9 @@ def should_run_validation(validation_freq, epoch):
raise ValueError('`validation_freq` can not be less than 1.') raise ValueError('`validation_freq` can not be less than 1.')
return one_indexed_epoch % validation_freq == 0 return one_indexed_epoch % validation_freq == 0
if not isinstance(validation_freq, collections.Container): if not isinstance(validation_freq, collections_abc.Container):
raise ValueError('`validation_freq` must be an Integer or ' raise ValueError('`validation_freq` must be an Integer or '
'`collections.Container` (e.g. list, tuple, etc.)') '`collections.abc.Container` (e.g. list, tuple, etc.)')
return one_indexed_epoch in validation_freq return one_indexed_epoch in validation_freq

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import collections.abc as collections_abc
import warnings import warnings
import numpy as np import numpy as np
@ -745,7 +746,7 @@ class Model(training_lib.Model):
the dataset at each epoch. This ensures that the same validation the dataset at each epoch. This ensures that the same validation
samples are used every time. samples are used every time.
validation_freq: Only relevant if validation data is provided. Integer validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.). or `collections.abc.Container` instance (e.g. list, tuple, etc.).
If an integer, specifies how many training epochs to run before a If an integer, specifies how many training epochs to run before a
new validation run is performed, e.g. `validation_freq=2` runs new validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on validation every 2 epochs. If a Container, specifies the epochs on
@ -1207,7 +1208,7 @@ class Model(training_lib.Model):
# at this point. # at this point.
if self.run_eagerly or self._distribution_strategy: if self.run_eagerly or self._distribution_strategy:
inputs = training_utils_v1.cast_if_floating_dtype(inputs) inputs = training_utils_v1.cast_if_floating_dtype(inputs)
if isinstance(inputs, collections.Sequence): if isinstance(inputs, collections_abc.Sequence):
# Unwrap lists with only one input, as we do when training on batch # Unwrap lists with only one input, as we do when training on batch
if len(inputs) == 1: if len(inputs) == 1:
inputs = inputs[0] inputs = inputs[0]

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections.abc as collections_abc
import functools import functools
import itertools import itertools
import unittest import unittest
@ -473,7 +473,7 @@ def _test_or_class_decorator(test_or_class, single_method_decorator):
The decorated result. The decorated result.
""" """
def _decorate_test_or_class(obj): def _decorate_test_or_class(obj):
if isinstance(obj, collections.Iterable): if isinstance(obj, collections_abc.Iterable):
return itertools.chain.from_iterable( return itertools.chain.from_iterable(
single_method_decorator(method) for method in obj) single_method_decorator(method) for method in obj)
if isinstance(obj, type): if isinstance(obj, type):

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections.abc as collections_abc
import numpy as np import numpy as np
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -37,7 +37,7 @@ class PreprocessingLayerTest(test.TestCase):
self.assertEqual(len(a), len(b)) self.assertEqual(len(a), len(b))
for a_value, b_value in zip(a, b): for a_value, b_value in zip(a, b):
self.assertAllCloseOrEqual(a_value, b_value, msg=msg) self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
elif isinstance(a, collections.Mapping): elif isinstance(a, collections_abc.Mapping):
self.assertEqual(len(a), len(b)) self.assertEqual(len(a), len(b))
for key, a_value in a.items(): for key, a_value in a.items():
b_value = b[key] b_value = b[key]

View File

@ -19,7 +19,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections.abc as collections_abc
import warnings import warnings
import numpy as np import numpy as np
@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls
@ -828,7 +829,7 @@ class RNN(Layer):
# input shape: `(samples, time (padded with zeros), input_dim)` # input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define # note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes. # self.input_spec and self.state_spec with complete input shapes.
if (isinstance(inputs, collections.Sequence) if (isinstance(inputs, collections_abc.Sequence)
and not isinstance(inputs, tuple)): and not isinstance(inputs, tuple)):
# get initial_state from full input spec # get initial_state from full input spec
# as they could be copied to multiple GPU. # as they could be copied to multiple GPU.

View File

@ -25,7 +25,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections.abc as collections_abc
import json import json
import numpy as np import numpy as np
import wrapt import wrapt
@ -117,7 +117,7 @@ def get_json_type(obj):
if isinstance(obj, dtypes.DType): if isinstance(obj, dtypes.DType):
return obj.name return obj.name
if isinstance(obj, collections.Mapping): if isinstance(obj, collections_abc.Mapping):
return dict(obj) return dict(obj)
if obj is Ellipsis: if obj is Ellipsis:

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections.abc as collections_abc
import copy import copy
import os import os
import six import six
@ -82,7 +82,7 @@ def model_input_signature(model, keep_original_batch_size=False):
input_specs = _enforce_names_consistency(input_specs) input_specs = _enforce_names_consistency(input_specs)
# Return a list with a single element as the model's input signature. # Return a list with a single element as the model's input signature.
if isinstance(input_specs, if isinstance(input_specs,
collections.Sequence) and len(input_specs) == 1: collections_abc.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries, # Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list. # which should also be wrapped as a single-element list.
return input_specs return input_specs