Use __slots__ for small classes

Lukas Geiger 2020-06-25 01:24:15 +02:00
parent 4aa879aab5
commit c9b4679c9b
34 changed files with 167 additions and 1 deletion
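The change is the same in every file touched here: small helper classes declare `__slots__`, so their instances no longer carry a per-instance `__dict__`. A minimal sketch of the effect, using an illustrative class rather than any of the TensorFlow classes below:

```python
import sys


class Plain(object):
    """Ordinary class: each instance gets its own __dict__."""

    def __init__(self, handle, device):
        self._handle = handle
        self._device = device


class Slotted(object):
    """Same attributes, but stored in fixed slots instead of a dict."""

    __slots__ = ["_handle", "_device"]

    def __init__(self, handle, device):
        self._handle = handle
        self._device = device


p, s = Plain(0, "cpu"), Slotted(0, "cpu")
print(sys.getsizeof(p.__dict__))  # per-instance dict overhead exists
print(hasattr(s, "__dict__"))     # False: no per-instance dict
s._handle = 1                     # declared slots behave like normal attributes
# s.extra = 2                     # would raise AttributeError: not a declared slot
```

The trade-off is that instances can only hold the named attributes, which is why each slot list in the diff mirrors exactly what the class's `__init__` assigns.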

View File

@ -3770,6 +3770,8 @@ class BatchDataset(UnaryDataset):
class _NumpyIterator(object):
"""Iterator over a dataset with elements converted to numpy."""
__slots__ = ["_iterator"]
def __init__(self, dataset):
self._iterator = iter(dataset)

View File

@ -522,6 +522,8 @@ class IteratorResourceDeleter(object):
object is part of a reference cycle, the cycle will be collectable.
"""
__slots__ = ["_deleter", "_handle", "_device", "_eager_mode"]
def __init__(self, handle, device, deleter):
self._deleter = deleter
self._handle = handle

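As the docstring above notes, these small resource-deleter helpers exist so that cleanup lives in a tiny object's `__del__` rather than on the iterator itself, keeping reference cycles through the iterator collectable; `__slots__` keeps the helper cheap. A rough sketch of the pattern with illustrative names (not TensorFlow's API):

```python
class _HandleDeleter(object):
    """Illustrative helper that owns only what cleanup needs."""

    __slots__ = ["_handle", "_delete_fn"]

    def __init__(self, handle, delete_fn):
        self._handle = handle
        self._delete_fn = delete_fn

    def __del__(self):
        # Only this small helper defines __del__; the owning resource
        # does not, so cycles involving the owner remain collectable.
        self._delete_fn(self._handle)


class Resource(object):
    """Owner object: free to participate in reference cycles."""

    __slots__ = ["_handle", "_deleter"]

    def __init__(self, handle, delete_fn):
        self._handle = handle
        self._deleter = _HandleDeleter(handle, delete_fn)


r = Resource("h", print)
del r  # the helper is collected with the owner and calls print("h")
```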
View File

@ -377,6 +377,9 @@ class MultiDeviceIteratorResourceDeleter(object):
object is part of a reference cycle, the cycle will be collectible.
"""
__slots__ = ["_deleter", "_multi_device_iterator", "_iterators",
"_device", "_eager_mode"]
def __init__(self, multi_device_iterator, iterators, device, deleter):
self._deleter = deleter
self._multi_device_iterator = multi_device_iterator

View File

@ -84,6 +84,8 @@ def resolve(d):
class _FakeNodeDef(object):
"""A fake NodeDef for _FakeOperation."""
__slots__ = ["op", "name"]
def __init__(self):
self.op = ""
self.name = ""

View File

@ -250,6 +250,8 @@ def get_update_replica_id():
class UpdateContext(object):
"""Context manager when you are in `update()` or `update_non_slot()`."""
__slots__ = ["_replica_id", "_old_replica_id"]
def __init__(self, replica_id):
self._replica_id = replica_id
self._old_replica_id = None
@ -454,6 +456,9 @@ class InputContext(object):
source etc).
"""
__slots__ = ["_num_input_pipelines", "_input_pipeline_id",
"_num_replicas_in_sync"]
def __init__(self,
num_input_pipelines=1,
input_pipeline_id=0,
@ -545,6 +550,8 @@ class ValueContext(object):
"""
__slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
def __init__(self,
replica_id_in_sync_group=0,
num_replicas_in_sync=1):
@ -2923,6 +2930,8 @@ class _DefaultDistributionStrategy(StrategyV1):
class _DefaultDistributionContext(object):
"""Context manager setting the default `tf.distribute.Strategy`."""
__slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
def __init__(self, strategy):
def creator(next_creator, **kwargs):

View File

@ -24,6 +24,8 @@ from tensorflow.python import pywrap_tfe
class CancellationManager(object):
"""A mechanism for cancelling blocking computation."""
__slots__ = ["_impl"]
def __init__(self):
self._impl = pywrap_tfe.TFE_NewCancellationManager()

View File

@ -80,6 +80,8 @@ _python_eager_context_create_counter = monitoring.Counter(
class _EagerTensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
__slots__ = ["_data", "_max_items", "_max_tensor_size"]
def __init__(self, max_items=256, max_tensor_size=10000):
self._data = collections.OrderedDict()
self._max_items = max_items
@ -107,6 +109,8 @@ class FunctionCallOptions(object):
Eager functions are functions decorated with tf.contrib.eager.defun.
"""
__slots__ = ["_config_proto_serialized", "_executor_type"]
def __init__(self, executor_type=None, config_proto=None):
"""Constructor.
@ -161,6 +165,8 @@ _tensor_caches_map = {}
class _TensorCaches(threading.local):
"""Thread local tensor caches."""
__slots__ = ["_ones_rank_cache", "_zeros_cache"]
def __init__(self):
super(_TensorCaches, self).__init__()
self._ones_rank_cache = None
@ -316,6 +322,8 @@ class PhysicalDevice(
class _AtomicCounter(object):
"""A simple atomic counter."""
__slots__ = ["_value", "_lock"]
def __init__(self):
self._value = 0
self._lock = threading.Lock()
@ -332,6 +340,8 @@ _context_id_counter = _AtomicCounter()
class _TensorCacheDeleter(object):
"""Deletes tensor caches for a given context."""
__slots__ = ["_context_id"]
def __init__(self, context_id):
self._context_id = context_id
@ -1730,6 +1740,8 @@ class Context(object):
class _EagerDeviceContext(object):
"""Context-manager forcing placement of ops and Tensors on a device."""
__slots__ = ["_device_name", "_ctx", "_stack"]
def __init__(self, ctx, device_name):
self._device_name = device_name
self._ctx = ctx

View File

@ -54,6 +54,8 @@ FREQUENT_TRACING_WARNING_THRESHOLD = 5
class _CallCounter(object):
"""Class keeping track of how many recent calls triggered tracing."""
__slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
def __init__(self, max_call_history):
self._max_call_history = max_call_history
self._calls_per_tracings = []
@ -84,6 +86,8 @@ class _CallCounter(object):
class _FrequentTracingDetector(object):
"""Class for frequent retracing detection and warning."""
__slots__ = ["_counters", "_lock"]
def __init__(self):
self._counters = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
self._lock = threading.Lock()
@ -428,6 +432,8 @@ def functions_run_eagerly():
class FunctionDeleter(object):
__slots__ = ["func_graph"]
def __init__(self, func_graph):
self.func_graph = func_graph

View File

@ -40,6 +40,8 @@ class Executor(object):
```
"""
__slots__ = ["_handle"]
def __init__(self, handle):
self._handle = handle

View File

@ -262,6 +262,8 @@ def _parse_func_attrs(attributes):
class _InterpolateFunctionError(object):
"""Context Manager that interpolates the exception from 'top_level_func'."""
__slots__ = ["_func"]
def __init__(self, top_level_func):
self._func = top_level_func
@ -378,6 +380,8 @@ def _enclosing_xla_context():
class _EagerDefinedFunctionDeleter(object):
"""Unregister function from eager context."""
__slots__ = ["name"]
def __init__(self, name):
self.name = name
@ -1410,6 +1414,9 @@ _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
class _ForwardBackwardCall(object):
"""Holds the state of a function call between execution and recording."""
__slots__ = ["_functions", "_inference_args", "_input_tangents",
"_tape_watching"]
def __init__(self, functions, inference_args, input_tangents, tape_watching):
"""Collects information about the function call.
@ -2740,6 +2747,9 @@ class FunctionCache(object):
"""A lightweight container for cached functions.
"""
__slots__ = ["missed", "primary", "arg_relaxed_specs",
"arg_relaxed", "_garbage_collectors"]
def __init__(self):
# The set of functions that have been missed; entries are CacheKey with
# input_signature `None` (e.g. a "call context key")
@ -3771,6 +3781,8 @@ def class_method_to_instance_method(original_function, instance):
class _FunctionGarbageCollector(object):
"""Cleans up cycles when a defun goes out of scope."""
__slots__ = ["_cache"]
def __init__(self, cache):
self._cache = cache
@ -3788,6 +3800,8 @@ class _FunctionGarbageCollector(object):
class ConcreteFunctionGarbageCollector(object):
"""Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
__slots__ = ["_func_graph"]
def __init__(self, func_graph):
self._func_graph = func_graph
@ -3807,6 +3821,8 @@ class ConcreteFunctionGarbageCollector(object):
class _Marker(object):
"""Markers used to pretty-print nested args in function signatures."""
__slots__ = ["_s"]
def __init__(self, s):
self._s = s

View File

@ -104,6 +104,8 @@ _sampler_methods = [
class Metric(object):
"""The base class of metric."""
__slots__ = ["_metric", "_metric_name", "_metric_methods", "_label_length"]
def __init__(self, metric_name, metric_methods, label_length, *args):
"""Creates a new metric.
@ -145,6 +147,8 @@ class Metric(object):
class CounterCell(object):
"""CounterCell stores each value of a Counter."""
__slots__ = ["_cell"]
def __init__(self, cell):
"""Creates a new CounterCell.
@ -174,6 +178,8 @@ class Counter(Metric):
user to increment each value.
"""
__slots__ = []
def __init__(self, name, description, *labels):
"""Creates a new Counter.
@ -193,6 +199,8 @@ class Counter(Metric):
class IntGaugeCell(object):
"""A single integer value stored in an `IntGauge`."""
__slots__ = ["_cell"]
def __init__(self, cell):
"""Creates a new IntGaugeCell.
@ -222,6 +230,8 @@ class IntGauge(Metric):
allows the user to set each value.
"""
__slots__ = []
def __init__(self, name, description, *labels):
"""Creates a new IntGauge.
@ -241,6 +251,8 @@ class IntGauge(Metric):
class StringGaugeCell(object):
"""A single string value stored in an `StringGauge`."""
__slots__ = ["_cell"]
def __init__(self, cell):
"""Creates a new StringGaugeCell.
@ -273,6 +285,8 @@ class StringGauge(Metric):
allows the user to set each value.
"""
__slots__ = []
def __init__(self, name, description, *labels):
"""Creates a new StringGauge.
@ -292,6 +306,8 @@ class StringGauge(Metric):
class BoolGaugeCell(object):
"""A single boolean value stored in an `BoolGauge`."""
__slots__ = ["_cell"]
def __init__(self, cell):
"""Creates a new BoolGaugeCell.
@ -321,6 +337,8 @@ class BoolGauge(Metric):
allows the user to set each value.
"""
__slots__ = []
def __init__(self, name, description, *labels):
"""Creates a new BoolGauge.
@ -340,6 +358,8 @@ class BoolGauge(Metric):
class SamplerCell(object):
"""SamplerCell stores each value of a Sampler."""
__slots__ = ["_cell"]
def __init__(self, cell):
"""Creates a new SamplerCell.
@ -373,6 +393,8 @@ class SamplerCell(object):
class Buckets(object):
"""Bucketing strategies for the samplers."""
__slots__ = ["buckets"]
def __init__(self, buckets):
"""Creates a new Buckets.
@ -393,6 +415,8 @@ class ExponentialBuckets(Buckets):
scale * growth_factor^(i + 1), ..., DBL_MAX].
"""
__slots__ = []
def __init__(self, scale, growth_factor, bucket_count):
"""Creates a new exponential Buckets.
@ -415,6 +439,8 @@ class Sampler(Metric):
user to add a sample to each histogram value.
"""
__slots__ = []
def __init__(self, name, buckets, description, *labels):
"""Creates a new Sampler.
@ -435,6 +461,8 @@ class Sampler(Metric):
class MonitoredTimer(object):
"""A context manager to measure the walltime and increment a Counter cell."""
__slots__ = ["cell", "t"]
def __init__(self, cell):
"""Creates a new MonitoredTimer.

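Several subclasses in this file (`Counter`, `IntGauge`, `StringGauge`, `BoolGauge`, `ExponentialBuckets`, `Sampler`) get an empty `__slots__ = []`. The usual reason, sketched below with an illustrative pair of classes: a subclass reintroduces a per-instance `__dict__` unless it declares `__slots__` too, even when it adds no attributes of its own.

```python
class Base(object):
    __slots__ = ["_cell"]


class Leaky(Base):
    # No __slots__ here: instances regain a __dict__, losing the benefit.
    pass


class Tight(Base):
    # Empty __slots__: adds no attributes and keeps instances dict-free.
    __slots__ = []


print(hasattr(Leaky(), "__dict__"))  # True
print(hasattr(Tight(), "__dict__"))  # False
```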
View File

@ -35,6 +35,8 @@ distribution_strategy_context = LazyLoader(
class Tape(object):
"""Represents a gradient propagation trace."""
__slots__ = ["_tape"]
def __init__(self, tape):
self._tape = tape
@ -72,6 +74,8 @@ class VariableWatcher(object):
assert variable_watcher.watched_variables == [var]
"""
__slots__ = ["_variable_watcher"]
def __init__(self):
self._variable_watcher = None

View File

@ -175,6 +175,9 @@ class AutomaticControlDependencies(object):
NOT THREAD SAFE
"""
__slots__ = ["_returned_tensors", "ops_which_must_run", "_graph",
"_n_operations", "collective_manager_ids_used"]
def __init__(self):
self._returned_tensors = object_identity.ObjectIdentitySet()
self.ops_which_must_run = set()

View File

@ -29,6 +29,8 @@ from tensorflow.python.util import tf_contextlib
class ScopedTFStatus(object):
"""Wrapper around TF_Status that handles deletion."""
__slots__ = ["status"]
def __init__(self):
self.status = c_api.TF_NewStatus()
@ -42,6 +44,8 @@ class ScopedTFStatus(object):
class ScopedTFGraph(object):
"""Wrapper around TF_Graph that handles deletion."""
__slots__ = ["graph", "deleter"]
def __init__(self):
self.graph = c_api.TF_NewGraph()
# Note: when we're destructing the global context (i.e when the process is
@ -57,6 +61,8 @@ class ScopedTFGraph(object):
class ScopedTFImportGraphDefOptions(object):
"""Wrapper around TF_ImportGraphDefOptions that handles deletion."""
__slots__ = ["options"]
def __init__(self):
self.options = c_api.TF_NewImportGraphDefOptions()
@ -70,6 +76,8 @@ class ScopedTFImportGraphDefOptions(object):
class ScopedTFImportGraphDefResults(object):
"""Wrapper around TF_ImportGraphDefOptions that handles deletion."""
__slots__ = ["results"]
def __init__(self, results):
self.results = results
@ -83,6 +91,8 @@ class ScopedTFImportGraphDefResults(object):
class ScopedTFFunction(object):
"""Wrapper around TF_Function that handles deletion."""
__slots__ = ["func", "deleter"]
def __init__(self, func):
self.func = func
# Note: when we're destructing the global context (i.e when the process is
@ -100,6 +110,8 @@ class ScopedTFFunction(object):
class ScopedTFBuffer(object):
"""An internal class to help manage the TF_Buffer lifetime."""
__slots__ = ["buffer"]
def __init__(self, buf_string):
self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
@ -114,6 +126,8 @@ class ApiDefMap(object):
be queried by op name.
"""
__slots__ = ["_api_def_map", "_op_per_name"]
def __init__(self):
op_def_proto = op_def_pb2.OpList()
buf = c_api.TF_GetAllOpList()

View File

@ -112,6 +112,8 @@ class MergeDevice(object):
performance of device placement.
"""
__slots__ = ["_spec"]
def __init__(self, spec):
if isinstance(spec, device_spec.DeviceSpecV2):
self._spec = spec

View File

@ -194,6 +194,8 @@ class Defun(object):
class _DefinedFunctionDeleter(object):
"""Unregister function from eager context."""
__slots__ = ["name"]
def __init__(self, name):
self.name = name

View File

@ -2583,6 +2583,8 @@ class RegisterGradient(object):
that defines the operation.
"""
__slots__ = ["_op_type"]
def __init__(self, op_type):
"""Creates a new decorator with `op_type` as the Operation type.
@ -2679,6 +2681,8 @@ class OpStats(object):
"""
__slots__ = ["_statistic_type", "_value"]
def __init__(self, statistic_type, value=None):
"""Sets up the initial placeholders for the statistics."""
self.statistic_type = statistic_type
@ -2758,6 +2762,8 @@ class RegisterStatistics(object):
"""
__slots__ = ["_op_type", "_statistic_type"]
def __init__(self, op_type, statistic_type):
"""Saves the `op_type` as the `Operation` type."""
if not isinstance(op_type, six.string_types):
@ -5176,6 +5182,8 @@ class enable_auto_cast_variables(object):
`dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
"""
__slots__ = ["_dtype", "_graph", "_prev_read_dtype"]
def __init__(self, dtype, graph=None):
if dtype and not dtype.is_floating:
self._dtype = None
@ -6529,6 +6537,8 @@ class name_scope_v1(object): # pylint: disable=invalid-name
```
"""
__slots__ = ["_name", "_name_scope"]
@property
def name(self):
return self._name
@ -6582,6 +6592,8 @@ class name_scope_v2(object):
will generate `MyOp_1/a`, etc.
"""
__slots__ = ["_name", "_exit_fns"]
def __init__(self, name):
"""Initialize the context manager.
@ -6918,6 +6930,8 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
class _TensorIterator(object):
"""Iterates over the leading dim of a Tensor. Performs no error checks."""
__slots__ = ["_tensor", "_index", "_limit"]
def __init__(self, tensor, dim0):
self._tensor = tensor
self._index = 0

View File

@ -36,6 +36,8 @@ _TYPE_TAG = "type"
class Registry(object):
"""Provides a registry for saving objects."""
__slots__ = ["_name", "_registry"]
def __init__(self, name):
"""Creates a new registry."""
self._name = name

View File

@ -67,6 +67,8 @@ def _recursive_apply(tensors, apply_fn):
class _ControlOutputCache(object):
"""Helper class to manage calculating and caching control_outputs in graph."""
__slots__ = ["cache"]
def __init__(self):
self.cache = {}

View File

@ -54,6 +54,8 @@ class InputSpec(object):
a specific dimension value.
"""
__slots__ = ['dtype', 'shape', 'ndim', 'max_ndim', 'min_ndim', 'axes']
def __init__(self,
dtype=None,
shape=None,

View File

@ -49,6 +49,8 @@ class _UnwrapPreventer(object):
unwrapped by DistributionStrategy
"""
__slots__ = ["value"]
def __init__(self, value):
self.value = value

View File

@ -719,6 +719,8 @@ class ConversionNotImplementedError(Exception):
class _PforInput(object):
"""Input object passed to registered pfor converters."""
__slots__ = ["pfor", "_op", "_inputs"]
def __init__(self, pfor, op, inputs):
"""Creates a _PforInput object.

View File

@ -265,6 +265,8 @@ class EagerResourceDeleter(object):
the cycle will be collectable.
"""
__slots__ = ["_handle", "_handle_device", "_context"]
def __init__(self, handle, handle_device):
if not isinstance(handle, ops.Tensor):
raise ValueError(

View File

@ -63,6 +63,8 @@ _api_usage_gauge = monitoring.BoolGauge(
class _PartitionInfo(object):
"""Holds partition info used by initializer functions."""
__slots__ = ["_full_shape", "_var_offset"]
def __init__(self, full_shape, var_offset):
"""Constructor.
@ -279,6 +281,8 @@ class _VariableStore(object):
the corresponding TensorFlow Variables as values.
"""
__slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
def __init__(self):
"""Create a variable store."""
self._vars = {} # A dictionary of the stored TensorFlow variables.

View File

@ -41,6 +41,8 @@ from tensorflow.python.util import nest
class _SingleDeviceSaver(object):
"""Saves and restores checkpoints from the current device."""
__slots__ = ["_saveable_objects"]
def __init__(self, saveable_objects):
"""Specify a list of `SaveableObject`s to save and restore.

View File

@ -553,6 +553,8 @@ def _ready(op, sess, msg):
class _CountDownTimer(object):
__slots__ = ["_start_time_secs", "_duration_secs"]
def __init__(self, duration_secs):
self._start_time_secs = time.time()
self._duration_secs = duration_secs

View File

@ -190,6 +190,8 @@ class PythonStringStateSaveable(PythonStateSaveable):
class CheckpointPosition(object):
"""Indicates a position within a `_CheckpointRestoreCoordinator`."""
__slots__ = ["_checkpoint", "_proto_id"]
def __init__(self, checkpoint, proto_id):
"""Specify an object within a checkpoint.

View File

@ -59,6 +59,8 @@ class NoDependency(object):
variables will appear in `Model.variables`).
"""
__slots__ = ["value"]
def __init__(self, value):
self.value = value

View File

@ -140,6 +140,8 @@ def delete_tracking(obj, name):
class ResourceTracker(object):
"""An object that tracks a list of resources."""
__slots__ = ["_resources"]
def __init__(self):
self._resources = []
@ -183,6 +185,8 @@ def resource_tracker_scope(resource_tracker):
class CapturableResourceDeleter(object):
"""Deleter to destroy CapturableResource without overriding its __del__()."""
__slots__ = ["_destruction_context", "_destroy_resource"]
def __init__(self, destroy_resource_fn=None):
if destroy_resource_fn:
self._destroy_resource = destroy_resource_fn

View File

@ -86,6 +86,8 @@ class _ObjectGraphProtoPrettyPrinter(object):
repeated naming is cheap after the first.
"""
__slots__ = ["_object_graph_proto", "_node_name_cache"]
def __init__(self, object_graph_proto):
self._object_graph_proto = object_graph_proto
self._node_name_cache = None
@ -124,6 +126,9 @@ class _ObjectGraphProtoPrettyPrinter(object):
class _CheckpointRestoreCoordinatorDeleter(object):
"""Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""
__slots__ = ["expect_partial", "object_graph_proto",
"matched_proto_ids", "unused_attributes"]
def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
unused_attributes):
self.expect_partial = expect_partial

View File

@ -51,6 +51,8 @@ class GroupLock(object):
can also use the `acquire` and `release` methods directly.
"""
__slots__ = ["_ready", "_num_groups", "_group_member_counts"]
def __init__(self, num_groups=2):
"""Initialize a group lock.
@ -116,6 +118,8 @@ class GroupLock(object):
class _Context(object):
"""Context manager helper for `GroupLock`."""
__slots__ = ["_lock", "_group_id"]
def __init__(self, lock, group_id):
self._lock = lock
self._group_id = group_id

View File

@ -344,6 +344,8 @@ _same_namedtuples = _pywrap_utils.SameNamedtuples
class _DotString(object):
__slots__ = []
def __str__(self):
return "."

View File

@ -122,6 +122,8 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
and comparing based on the equality of their contents by default).
"""
__slots__ = ["_storage"]
def __init__(self):
self._storage = {}
@ -171,6 +173,8 @@ class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
class ObjectIdentitySet(collections_abc.MutableSet):
"""Like the built-in set, but compares objects with "is"."""
__slots__ = ["_storage"]
def __init__(self, *args):
self._storage = set(self._wrap_key(obj) for obj in list(*args))