Update keras to use collections.abc instead of collections, which will be deprecated in py3.9
PiperOrigin-RevId: 339588609 Change-Id: I30553aa31227ff8e56a15509d0261ec64de3240c
This commit is contained in:
parent
fb7a8d29bb
commit
9278b9421f
tensorflow/python/keras
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import copy
|
||||
import csv
|
||||
import io
|
||||
@ -2608,7 +2609,7 @@ class CSVLogger(Callback):
|
||||
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
|
||||
if isinstance(k, six.string_types):
|
||||
return k
|
||||
elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray:
|
||||
elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
|
||||
return '"[%s]"' % (', '.join(map(str, k)))
|
||||
else:
|
||||
return k
|
||||
|
@ -968,7 +968,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
form of datasets, generators, or `keras.utils.Sequence` instances
|
||||
(since they generate batches).
|
||||
validation_freq: Only relevant if validation data is provided. Integer
|
||||
or `collections.Container` instance (e.g. list, tuple, etc.).
|
||||
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
|
||||
If an integer, specifies how many training epochs to run before a
|
||||
new validation run is performed, e.g. `validation_freq=2` runs
|
||||
validation every 2 epochs. If a Container, specifies the epochs on
|
||||
|
@ -94,7 +94,7 @@ def model_iteration(model,
|
||||
validation from data tensors). Ignored with the default value of
|
||||
`None`.
|
||||
validation_freq: Only relevant if validation data is provided. Integer or
|
||||
`collections.Container` instance (e.g. list, tuple, etc.). If an
|
||||
`collections.abc.Container` instance (e.g. list, tuple, etc.). If an
|
||||
integer, specifies how many training epochs to run before a new
|
||||
validation run is performed, e.g. `validation_freq=2` runs
|
||||
validation every 2 epochs. If a Container, specifies the epochs on
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import abc
|
||||
import atexit
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import functools
|
||||
import multiprocessing.pool
|
||||
import threading
|
||||
@ -584,7 +585,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
||||
'You should provide one `' + weight_type + '`'
|
||||
'array per model output.')
|
||||
return x_weight
|
||||
if isinstance(x_weight, collections.Mapping):
|
||||
if isinstance(x_weight, collections_abc.Mapping):
|
||||
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
|
||||
x_weights = []
|
||||
for name in output_names:
|
||||
@ -771,7 +772,7 @@ def collect_per_output_metric_info(metrics,
|
||||
[metrics_module.clone_metric(m) for m in metrics])
|
||||
else:
|
||||
nested_metrics = [metrics]
|
||||
elif isinstance(metrics, collections.Mapping):
|
||||
elif isinstance(metrics, collections_abc.Mapping):
|
||||
generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
|
||||
nested_metrics = []
|
||||
for name in output_names:
|
||||
@ -1087,7 +1088,7 @@ def get_loss_function(loss):
|
||||
'before passing them to Model.compile.'.format(loss))
|
||||
|
||||
# Deserialize loss configuration, if needed.
|
||||
if isinstance(loss, collections.Mapping):
|
||||
if isinstance(loss, collections_abc.Mapping):
|
||||
loss = losses.get(loss)
|
||||
|
||||
# Custom callable class.
|
||||
@ -1304,7 +1305,7 @@ def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
|
||||
ValueError: In case of invalid `sample_weight_mode` input.
|
||||
"""
|
||||
|
||||
if isinstance(sample_weight_mode, collections.Mapping):
|
||||
if isinstance(sample_weight_mode, collections_abc.Mapping):
|
||||
generic_utils.check_for_unexpected_keys(
|
||||
'sample_weight_mode', sample_weight_mode,
|
||||
[e.output_name for e in training_endpoints])
|
||||
@ -1351,7 +1352,7 @@ def prepare_loss_functions(loss, output_names):
|
||||
ValueError: If loss is a dict with keys not in model output names,
|
||||
or if loss is a list with len not equal to model outputs.
|
||||
"""
|
||||
if isinstance(loss, collections.Mapping):
|
||||
if isinstance(loss, collections_abc.Mapping):
|
||||
generic_utils.check_for_unexpected_keys('loss', loss, output_names)
|
||||
loss_functions = []
|
||||
for name in output_names:
|
||||
@ -1363,7 +1364,7 @@ def prepare_loss_functions(loss, output_names):
|
||||
loss_functions.append(get_loss_function(loss.get(name, None)))
|
||||
elif isinstance(loss, six.string_types):
|
||||
loss_functions = [get_loss_function(loss) for _ in output_names]
|
||||
elif isinstance(loss, collections.Sequence):
|
||||
elif isinstance(loss, collections_abc.Sequence):
|
||||
if len(loss) != len(output_names):
|
||||
raise ValueError('When passing a list as loss, it should have one entry '
|
||||
'per model outputs. The model has {} outputs, but you '
|
||||
@ -1397,7 +1398,7 @@ def prepare_loss_weights(training_endpoints, loss_weights=None):
|
||||
if loss_weights is None:
|
||||
for e in training_endpoints:
|
||||
e.loss_weight = 1.
|
||||
elif isinstance(loss_weights, collections.Mapping):
|
||||
elif isinstance(loss_weights, collections_abc.Mapping):
|
||||
generic_utils.check_for_unexpected_keys(
|
||||
'loss_weights', loss_weights,
|
||||
[e.output_name for e in training_endpoints])
|
||||
@ -1704,9 +1705,9 @@ def should_run_validation(validation_freq, epoch):
|
||||
raise ValueError('`validation_freq` can not be less than 1.')
|
||||
return one_indexed_epoch % validation_freq == 0
|
||||
|
||||
if not isinstance(validation_freq, collections.Container):
|
||||
if not isinstance(validation_freq, collections_abc.Container):
|
||||
raise ValueError('`validation_freq` must be an Integer or '
|
||||
'`collections.Container` (e.g. list, tuple, etc.)')
|
||||
'`collections.abc.Container` (e.g. list, tuple, etc.)')
|
||||
return one_indexed_epoch in validation_freq
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -745,7 +746,7 @@ class Model(training_lib.Model):
|
||||
the dataset at each epoch. This ensures that the same validation
|
||||
samples are used every time.
|
||||
validation_freq: Only relevant if validation data is provided. Integer
|
||||
or `collections.Container` instance (e.g. list, tuple, etc.).
|
||||
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
|
||||
If an integer, specifies how many training epochs to run before a
|
||||
new validation run is performed, e.g. `validation_freq=2` runs
|
||||
validation every 2 epochs. If a Container, specifies the epochs on
|
||||
@ -1207,7 +1208,7 @@ class Model(training_lib.Model):
|
||||
# at this point.
|
||||
if self.run_eagerly or self._distribution_strategy:
|
||||
inputs = training_utils_v1.cast_if_floating_dtype(inputs)
|
||||
if isinstance(inputs, collections.Sequence):
|
||||
if isinstance(inputs, collections_abc.Sequence):
|
||||
# Unwrap lists with only one input, as we do when training on batch
|
||||
if len(inputs) == 1:
|
||||
inputs = inputs[0]
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import functools
|
||||
import itertools
|
||||
import unittest
|
||||
@ -473,7 +473,7 @@ def _test_or_class_decorator(test_or_class, single_method_decorator):
|
||||
The decorated result.
|
||||
"""
|
||||
def _decorate_test_or_class(obj):
|
||||
if isinstance(obj, collections.Iterable):
|
||||
if isinstance(obj, collections_abc.Iterable):
|
||||
return itertools.chain.from_iterable(
|
||||
single_method_decorator(method) for method in obj)
|
||||
if isinstance(obj, type):
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
@ -37,7 +37,7 @@ class PreprocessingLayerTest(test.TestCase):
|
||||
self.assertEqual(len(a), len(b))
|
||||
for a_value, b_value in zip(a, b):
|
||||
self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
|
||||
elif isinstance(a, collections.Mapping):
|
||||
elif isinstance(a, collections_abc.Mapping):
|
||||
self.assertEqual(len(a), len(b))
|
||||
for key, a_value in a.items():
|
||||
b_value = b[key]
|
||||
|
@ -19,7 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
@ -828,7 +829,7 @@ class RNN(Layer):
|
||||
# input shape: `(samples, time (padded with zeros), input_dim)`
|
||||
# note that the .build() method of subclasses MUST define
|
||||
# self.input_spec and self.state_spec with complete input shapes.
|
||||
if (isinstance(inputs, collections.Sequence)
|
||||
if (isinstance(inputs, collections_abc.Sequence)
|
||||
and not isinstance(inputs, tuple)):
|
||||
# get initial_state from full input spec
|
||||
# as they could be copied to multiple GPU.
|
||||
|
@ -25,7 +25,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import json
|
||||
import numpy as np
|
||||
import wrapt
|
||||
@ -117,7 +117,7 @@ def get_json_type(obj):
|
||||
if isinstance(obj, dtypes.DType):
|
||||
return obj.name
|
||||
|
||||
if isinstance(obj, collections.Mapping):
|
||||
if isinstance(obj, collections_abc.Mapping):
|
||||
return dict(obj)
|
||||
|
||||
if obj is Ellipsis:
|
||||
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import collections.abc as collections_abc
|
||||
import copy
|
||||
import os
|
||||
import six
|
||||
@ -82,7 +82,7 @@ def model_input_signature(model, keep_original_batch_size=False):
|
||||
input_specs = _enforce_names_consistency(input_specs)
|
||||
# Return a list with a single element as the model's input signature.
|
||||
if isinstance(input_specs,
|
||||
collections.Sequence) and len(input_specs) == 1:
|
||||
collections_abc.Sequence) and len(input_specs) == 1:
|
||||
# Note that the isinstance check filters out single-element dictionaries,
|
||||
# which should also be wrapped as a single-element list.
|
||||
return input_specs
|
||||
|
Loading…
Reference in New Issue
Block a user