Update keras to use collections.abc instead of collections, which will be deprecated in py3.9
PiperOrigin-RevId: 339588609 Change-Id: I30553aa31227ff8e56a15509d0261ec64de3240c
This commit is contained in:
parent
fb7a8d29bb
commit
9278b9421f
@ -21,6 +21,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import collections.abc as collections_abc
|
||||||
import copy
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
@ -2608,7 +2609,7 @@ class CSVLogger(Callback):
|
|||||||
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
|
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
|
||||||
if isinstance(k, six.string_types):
|
if isinstance(k, six.string_types):
|
||||||
return k
|
return k
|
||||||
elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray:
|
elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
|
||||||
return '"[%s]"' % (', '.join(map(str, k)))
|
return '"[%s]"' % (', '.join(map(str, k)))
|
||||||
else:
|
else:
|
||||||
return k
|
return k
|
||||||
|
@ -968,7 +968,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
form of datasets, generators, or `keras.utils.Sequence` instances
|
form of datasets, generators, or `keras.utils.Sequence` instances
|
||||||
(since they generate batches).
|
(since they generate batches).
|
||||||
validation_freq: Only relevant if validation data is provided. Integer
|
validation_freq: Only relevant if validation data is provided. Integer
|
||||||
or `collections.Container` instance (e.g. list, tuple, etc.).
|
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
|
||||||
If an integer, specifies how many training epochs to run before a
|
If an integer, specifies how many training epochs to run before a
|
||||||
new validation run is performed, e.g. `validation_freq=2` runs
|
new validation run is performed, e.g. `validation_freq=2` runs
|
||||||
validation every 2 epochs. If a Container, specifies the epochs on
|
validation every 2 epochs. If a Container, specifies the epochs on
|
||||||
|
@ -94,7 +94,7 @@ def model_iteration(model,
|
|||||||
validation from data tensors). Ignored with the default value of
|
validation from data tensors). Ignored with the default value of
|
||||||
`None`.
|
`None`.
|
||||||
validation_freq: Only relevant if validation data is provided. Integer or
|
validation_freq: Only relevant if validation data is provided. Integer or
|
||||||
`collections.Container` instance (e.g. list, tuple, etc.). If an
|
`collections.abc.Container` instance (e.g. list, tuple, etc.). If an
|
||||||
integer, specifies how many training epochs to run before a new
|
integer, specifies how many training epochs to run before a new
|
||||||
validation run is performed, e.g. `validation_freq=2` runs
|
validation run is performed, e.g. `validation_freq=2` runs
|
||||||
validation every 2 epochs. If a Container, specifies the epochs on
|
validation every 2 epochs. If a Container, specifies the epochs on
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
import abc
|
import abc
|
||||||
import atexit
|
import atexit
|
||||||
import collections
|
import collections
|
||||||
|
import collections.abc as collections_abc
|
||||||
import functools
|
import functools
|
||||||
import multiprocessing.pool
|
import multiprocessing.pool
|
||||||
import threading
|
import threading
|
||||||
@ -584,7 +585,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
|||||||
'You should provide one `' + weight_type + '`'
|
'You should provide one `' + weight_type + '`'
|
||||||
'array per model output.')
|
'array per model output.')
|
||||||
return x_weight
|
return x_weight
|
||||||
if isinstance(x_weight, collections.Mapping):
|
if isinstance(x_weight, collections_abc.Mapping):
|
||||||
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
|
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
|
||||||
x_weights = []
|
x_weights = []
|
||||||
for name in output_names:
|
for name in output_names:
|
||||||
@ -771,7 +772,7 @@ def collect_per_output_metric_info(metrics,
|
|||||||
[metrics_module.clone_metric(m) for m in metrics])
|
[metrics_module.clone_metric(m) for m in metrics])
|
||||||
else:
|
else:
|
||||||
nested_metrics = [metrics]
|
nested_metrics = [metrics]
|
||||||
elif isinstance(metrics, collections.Mapping):
|
elif isinstance(metrics, collections_abc.Mapping):
|
||||||
generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
|
generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
|
||||||
nested_metrics = []
|
nested_metrics = []
|
||||||
for name in output_names:
|
for name in output_names:
|
||||||
@ -1087,7 +1088,7 @@ def get_loss_function(loss):
|
|||||||
'before passing them to Model.compile.'.format(loss))
|
'before passing them to Model.compile.'.format(loss))
|
||||||
|
|
||||||
# Deserialize loss configuration, if needed.
|
# Deserialize loss configuration, if needed.
|
||||||
if isinstance(loss, collections.Mapping):
|
if isinstance(loss, collections_abc.Mapping):
|
||||||
loss = losses.get(loss)
|
loss = losses.get(loss)
|
||||||
|
|
||||||
# Custom callable class.
|
# Custom callable class.
|
||||||
@ -1304,7 +1305,7 @@ def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
|
|||||||
ValueError: In case of invalid `sample_weight_mode` input.
|
ValueError: In case of invalid `sample_weight_mode` input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(sample_weight_mode, collections.Mapping):
|
if isinstance(sample_weight_mode, collections_abc.Mapping):
|
||||||
generic_utils.check_for_unexpected_keys(
|
generic_utils.check_for_unexpected_keys(
|
||||||
'sample_weight_mode', sample_weight_mode,
|
'sample_weight_mode', sample_weight_mode,
|
||||||
[e.output_name for e in training_endpoints])
|
[e.output_name for e in training_endpoints])
|
||||||
@ -1351,7 +1352,7 @@ def prepare_loss_functions(loss, output_names):
|
|||||||
ValueError: If loss is a dict with keys not in model output names,
|
ValueError: If loss is a dict with keys not in model output names,
|
||||||
or if loss is a list with len not equal to model outputs.
|
or if loss is a list with len not equal to model outputs.
|
||||||
"""
|
"""
|
||||||
if isinstance(loss, collections.Mapping):
|
if isinstance(loss, collections_abc.Mapping):
|
||||||
generic_utils.check_for_unexpected_keys('loss', loss, output_names)
|
generic_utils.check_for_unexpected_keys('loss', loss, output_names)
|
||||||
loss_functions = []
|
loss_functions = []
|
||||||
for name in output_names:
|
for name in output_names:
|
||||||
@ -1363,7 +1364,7 @@ def prepare_loss_functions(loss, output_names):
|
|||||||
loss_functions.append(get_loss_function(loss.get(name, None)))
|
loss_functions.append(get_loss_function(loss.get(name, None)))
|
||||||
elif isinstance(loss, six.string_types):
|
elif isinstance(loss, six.string_types):
|
||||||
loss_functions = [get_loss_function(loss) for _ in output_names]
|
loss_functions = [get_loss_function(loss) for _ in output_names]
|
||||||
elif isinstance(loss, collections.Sequence):
|
elif isinstance(loss, collections_abc.Sequence):
|
||||||
if len(loss) != len(output_names):
|
if len(loss) != len(output_names):
|
||||||
raise ValueError('When passing a list as loss, it should have one entry '
|
raise ValueError('When passing a list as loss, it should have one entry '
|
||||||
'per model outputs. The model has {} outputs, but you '
|
'per model outputs. The model has {} outputs, but you '
|
||||||
@ -1397,7 +1398,7 @@ def prepare_loss_weights(training_endpoints, loss_weights=None):
|
|||||||
if loss_weights is None:
|
if loss_weights is None:
|
||||||
for e in training_endpoints:
|
for e in training_endpoints:
|
||||||
e.loss_weight = 1.
|
e.loss_weight = 1.
|
||||||
elif isinstance(loss_weights, collections.Mapping):
|
elif isinstance(loss_weights, collections_abc.Mapping):
|
||||||
generic_utils.check_for_unexpected_keys(
|
generic_utils.check_for_unexpected_keys(
|
||||||
'loss_weights', loss_weights,
|
'loss_weights', loss_weights,
|
||||||
[e.output_name for e in training_endpoints])
|
[e.output_name for e in training_endpoints])
|
||||||
@ -1704,9 +1705,9 @@ def should_run_validation(validation_freq, epoch):
|
|||||||
raise ValueError('`validation_freq` can not be less than 1.')
|
raise ValueError('`validation_freq` can not be less than 1.')
|
||||||
return one_indexed_epoch % validation_freq == 0
|
return one_indexed_epoch % validation_freq == 0
|
||||||
|
|
||||||
if not isinstance(validation_freq, collections.Container):
|
if not isinstance(validation_freq, collections_abc.Container):
|
||||||
raise ValueError('`validation_freq` must be an Integer or '
|
raise ValueError('`validation_freq` must be an Integer or '
|
||||||
'`collections.Container` (e.g. list, tuple, etc.)')
|
'`collections.abc.Container` (e.g. list, tuple, etc.)')
|
||||||
return one_indexed_epoch in validation_freq
|
return one_indexed_epoch in validation_freq
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import collections.abc as collections_abc
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -745,7 +746,7 @@ class Model(training_lib.Model):
|
|||||||
the dataset at each epoch. This ensures that the same validation
|
the dataset at each epoch. This ensures that the same validation
|
||||||
samples are used every time.
|
samples are used every time.
|
||||||
validation_freq: Only relevant if validation data is provided. Integer
|
validation_freq: Only relevant if validation data is provided. Integer
|
||||||
or `collections.Container` instance (e.g. list, tuple, etc.).
|
or `collections.abc.Container` instance (e.g. list, tuple, etc.).
|
||||||
If an integer, specifies how many training epochs to run before a
|
If an integer, specifies how many training epochs to run before a
|
||||||
new validation run is performed, e.g. `validation_freq=2` runs
|
new validation run is performed, e.g. `validation_freq=2` runs
|
||||||
validation every 2 epochs. If a Container, specifies the epochs on
|
validation every 2 epochs. If a Container, specifies the epochs on
|
||||||
@ -1207,7 +1208,7 @@ class Model(training_lib.Model):
|
|||||||
# at this point.
|
# at this point.
|
||||||
if self.run_eagerly or self._distribution_strategy:
|
if self.run_eagerly or self._distribution_strategy:
|
||||||
inputs = training_utils_v1.cast_if_floating_dtype(inputs)
|
inputs = training_utils_v1.cast_if_floating_dtype(inputs)
|
||||||
if isinstance(inputs, collections.Sequence):
|
if isinstance(inputs, collections_abc.Sequence):
|
||||||
# Unwrap lists with only one input, as we do when training on batch
|
# Unwrap lists with only one input, as we do when training on batch
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
inputs = inputs[0]
|
inputs = inputs[0]
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections.abc as collections_abc
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import unittest
|
import unittest
|
||||||
@ -473,7 +473,7 @@ def _test_or_class_decorator(test_or_class, single_method_decorator):
|
|||||||
The decorated result.
|
The decorated result.
|
||||||
"""
|
"""
|
||||||
def _decorate_test_or_class(obj):
|
def _decorate_test_or_class(obj):
|
||||||
if isinstance(obj, collections.Iterable):
|
if isinstance(obj, collections_abc.Iterable):
|
||||||
return itertools.chain.from_iterable(
|
return itertools.chain.from_iterable(
|
||||||
single_method_decorator(method) for method in obj)
|
single_method_decorator(method) for method in obj)
|
||||||
if isinstance(obj, type):
|
if isinstance(obj, type):
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections.abc as collections_abc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -37,7 +37,7 @@ class PreprocessingLayerTest(test.TestCase):
|
|||||||
self.assertEqual(len(a), len(b))
|
self.assertEqual(len(a), len(b))
|
||||||
for a_value, b_value in zip(a, b):
|
for a_value, b_value in zip(a, b):
|
||||||
self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
|
self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
|
||||||
elif isinstance(a, collections.Mapping):
|
elif isinstance(a, collections_abc.Mapping):
|
||||||
self.assertEqual(len(a), len(b))
|
self.assertEqual(len(a), len(b))
|
||||||
for key, a_value in a.items():
|
for key, a_value in a.items():
|
||||||
b_value = b[key]
|
b_value = b[key]
|
||||||
|
@ -19,7 +19,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections.abc as collections_abc
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -47,6 +47,7 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util.compat import collections_abc
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
|
|
||||||
@ -828,7 +829,7 @@ class RNN(Layer):
|
|||||||
# input shape: `(samples, time (padded with zeros), input_dim)`
|
# input shape: `(samples, time (padded with zeros), input_dim)`
|
||||||
# note that the .build() method of subclasses MUST define
|
# note that the .build() method of subclasses MUST define
|
||||||
# self.input_spec and self.state_spec with complete input shapes.
|
# self.input_spec and self.state_spec with complete input shapes.
|
||||||
if (isinstance(inputs, collections.Sequence)
|
if (isinstance(inputs, collections_abc.Sequence)
|
||||||
and not isinstance(inputs, tuple)):
|
and not isinstance(inputs, tuple)):
|
||||||
# get initial_state from full input spec
|
# get initial_state from full input spec
|
||||||
# as they could be copied to multiple GPU.
|
# as they could be copied to multiple GPU.
|
||||||
|
@ -25,7 +25,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections.abc as collections_abc
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import wrapt
|
import wrapt
|
||||||
@ -117,7 +117,7 @@ def get_json_type(obj):
|
|||||||
if isinstance(obj, dtypes.DType):
|
if isinstance(obj, dtypes.DType):
|
||||||
return obj.name
|
return obj.name
|
||||||
|
|
||||||
if isinstance(obj, collections.Mapping):
|
if isinstance(obj, collections_abc.Mapping):
|
||||||
return dict(obj)
|
return dict(obj)
|
||||||
|
|
||||||
if obj is Ellipsis:
|
if obj is Ellipsis:
|
||||||
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections.abc as collections_abc
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import six
|
import six
|
||||||
@ -82,7 +82,7 @@ def model_input_signature(model, keep_original_batch_size=False):
|
|||||||
input_specs = _enforce_names_consistency(input_specs)
|
input_specs = _enforce_names_consistency(input_specs)
|
||||||
# Return a list with a single element as the model's input signature.
|
# Return a list with a single element as the model's input signature.
|
||||||
if isinstance(input_specs,
|
if isinstance(input_specs,
|
||||||
collections.Sequence) and len(input_specs) == 1:
|
collections_abc.Sequence) and len(input_specs) == 1:
|
||||||
# Note that the isinstance check filters out single-element dictionaries,
|
# Note that the isinstance check filters out single-element dictionaries,
|
||||||
# which should also be wrapped as a single-element list.
|
# which should also be wrapped as a single-element list.
|
||||||
return input_specs
|
return input_specs
|
||||||
|
Loading…
Reference in New Issue
Block a user