Use __slots__ for small classes
This commit is contained in:
parent
4aa879aab5
commit
c9b4679c9b
@ -3770,6 +3770,8 @@ class BatchDataset(UnaryDataset):
|
||||
class _NumpyIterator(object):
|
||||
"""Iterator over a dataset with elements converted to numpy."""
|
||||
|
||||
__slots__ = ["_iterator"]
|
||||
|
||||
def __init__(self, dataset):
|
||||
self._iterator = iter(dataset)
|
||||
|
||||
|
@ -522,6 +522,8 @@ class IteratorResourceDeleter(object):
|
||||
object is part of a reference cycle, the cycle will be collectable.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deleter", "_handle", "_device", "_eager_mode"]
|
||||
|
||||
def __init__(self, handle, device, deleter):
|
||||
self._deleter = deleter
|
||||
self._handle = handle
|
||||
|
@ -377,6 +377,9 @@ class MultiDeviceIteratorResourceDeleter(object):
|
||||
object is part of a reference cycle, the cycle will be collectible.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deleter", "_multi_device_iterator", "_iterators",
|
||||
"_device", "_eager_mode"]
|
||||
|
||||
def __init__(self, multi_device_iterator, iterators, device, deleter):
|
||||
self._deleter = deleter
|
||||
self._multi_device_iterator = multi_device_iterator
|
||||
|
@ -84,6 +84,8 @@ def resolve(d):
|
||||
class _FakeNodeDef(object):
|
||||
"""A fake NodeDef for _FakeOperation."""
|
||||
|
||||
__slots__ = ["op", "name"]
|
||||
|
||||
def __init__(self):
|
||||
self.op = ""
|
||||
self.name = ""
|
||||
|
@ -250,6 +250,8 @@ def get_update_replica_id():
|
||||
class UpdateContext(object):
|
||||
"""Context manager when you are in `update()` or `update_non_slot()`."""
|
||||
|
||||
__slots__ = ["_replica_id", "_old_replica_id"]
|
||||
|
||||
def __init__(self, replica_id):
|
||||
self._replica_id = replica_id
|
||||
self._old_replica_id = None
|
||||
@ -454,6 +456,9 @@ class InputContext(object):
|
||||
source etc).
|
||||
"""
|
||||
|
||||
__slots__ = ["_num_input_pipelines", "_input_pipeline_id",
|
||||
"_num_replicas_in_sync"]
|
||||
|
||||
def __init__(self,
|
||||
num_input_pipelines=1,
|
||||
input_pipeline_id=0,
|
||||
@ -545,6 +550,8 @@ class ValueContext(object):
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
|
||||
|
||||
def __init__(self,
|
||||
replica_id_in_sync_group=0,
|
||||
num_replicas_in_sync=1):
|
||||
@ -2923,6 +2930,8 @@ class _DefaultDistributionStrategy(StrategyV1):
|
||||
class _DefaultDistributionContext(object):
|
||||
"""Context manager setting the default `tf.distribute.Strategy`."""
|
||||
|
||||
__slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
|
||||
|
||||
def __init__(self, strategy):
|
||||
|
||||
def creator(next_creator, **kwargs):
|
||||
|
@ -24,6 +24,8 @@ from tensorflow.python import pywrap_tfe
|
||||
class CancellationManager(object):
|
||||
"""A mechanism for cancelling blocking computation."""
|
||||
|
||||
__slots__ = ["_impl"]
|
||||
|
||||
def __init__(self):
|
||||
self._impl = pywrap_tfe.TFE_NewCancellationManager()
|
||||
|
||||
|
@ -80,6 +80,8 @@ _python_eager_context_create_counter = monitoring.Counter(
|
||||
class _EagerTensorCache(object):
|
||||
"""Simple cache which evicts items based on length in a FIFO manner."""
|
||||
|
||||
__slots__ = ["_data", "_max_items", "_max_tensor_size"]
|
||||
|
||||
def __init__(self, max_items=256, max_tensor_size=10000):
|
||||
self._data = collections.OrderedDict()
|
||||
self._max_items = max_items
|
||||
@ -107,6 +109,8 @@ class FunctionCallOptions(object):
|
||||
Eager functions are functions decorated with tf.contrib.eager.defun.
|
||||
"""
|
||||
|
||||
__slots__ = ["_config_proto_serialized", "_executor_type"]
|
||||
|
||||
def __init__(self, executor_type=None, config_proto=None):
|
||||
"""Constructor.
|
||||
|
||||
@ -161,6 +165,8 @@ _tensor_caches_map = {}
|
||||
class _TensorCaches(threading.local):
|
||||
"""Thread local tensor caches."""
|
||||
|
||||
__slots__ = ["_ones_rank_cache", "_zeros_cache"]
|
||||
|
||||
def __init__(self):
|
||||
super(_TensorCaches, self).__init__()
|
||||
self._ones_rank_cache = None
|
||||
@ -316,6 +322,8 @@ class PhysicalDevice(
|
||||
class _AtomicCounter(object):
|
||||
"""A simple atomic counter."""
|
||||
|
||||
__slots__ = ["_value", "_lock"]
|
||||
|
||||
def __init__(self):
|
||||
self._value = 0
|
||||
self._lock = threading.Lock()
|
||||
@ -332,6 +340,8 @@ _context_id_counter = _AtomicCounter()
|
||||
class _TensorCacheDeleter(object):
|
||||
"""Deletes tensor caches for a given context."""
|
||||
|
||||
__slots__ = ["_context_id"]
|
||||
|
||||
def __init__(self, context_id):
|
||||
self._context_id = context_id
|
||||
|
||||
@ -1730,6 +1740,8 @@ class Context(object):
|
||||
class _EagerDeviceContext(object):
|
||||
"""Context-manager forcing placement of ops and Tensors on a device."""
|
||||
|
||||
__slots__ = ["_device_name", "_ctx", "_stack"]
|
||||
|
||||
def __init__(self, ctx, device_name):
|
||||
self._device_name = device_name
|
||||
self._ctx = ctx
|
||||
|
@ -54,6 +54,8 @@ FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
||||
class _CallCounter(object):
|
||||
"""Class keeping track of how many recent calls triggered tracing."""
|
||||
|
||||
__slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
|
||||
|
||||
def __init__(self, max_call_history):
|
||||
self._max_call_history = max_call_history
|
||||
self._calls_per_tracings = []
|
||||
@ -84,6 +86,8 @@ class _CallCounter(object):
|
||||
class _FrequentTracingDetector(object):
|
||||
"""Class for frequent retracing detection and warning."""
|
||||
|
||||
__slots__ = ["_counters", "_lock"]
|
||||
|
||||
def __init__(self):
|
||||
self._counters = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
|
||||
self._lock = threading.Lock()
|
||||
@ -428,6 +432,8 @@ def functions_run_eagerly():
|
||||
|
||||
class FunctionDeleter(object):
|
||||
|
||||
__slots__ = ["func_graph"]
|
||||
|
||||
def __init__(self, func_graph):
|
||||
self.func_graph = func_graph
|
||||
|
||||
|
@ -40,6 +40,8 @@ class Executor(object):
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ["_handle"]
|
||||
|
||||
def __init__(self, handle):
|
||||
self._handle = handle
|
||||
|
||||
|
@ -262,6 +262,8 @@ def _parse_func_attrs(attributes):
|
||||
class _InterpolateFunctionError(object):
|
||||
"""Context Manager that interpolates the exception from 'top_level_func'."""
|
||||
|
||||
__slots__ = ["_func"]
|
||||
|
||||
def __init__(self, top_level_func):
|
||||
self._func = top_level_func
|
||||
|
||||
@ -378,6 +380,8 @@ def _enclosing_xla_context():
|
||||
class _EagerDefinedFunctionDeleter(object):
|
||||
"""Unregister function from eager context."""
|
||||
|
||||
__slots__ = ["name"]
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@ -1410,6 +1414,9 @@ _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
|
||||
class _ForwardBackwardCall(object):
|
||||
"""Holds the state of a function call between execution and recording."""
|
||||
|
||||
__slots__ = ["_functions", "_inference_args", "_input_tangents",
|
||||
"_tape_watching"]
|
||||
|
||||
def __init__(self, functions, inference_args, input_tangents, tape_watching):
|
||||
"""Collects information about the function call.
|
||||
|
||||
@ -2740,6 +2747,9 @@ class FunctionCache(object):
|
||||
"""A lightweight container for cached functions.
|
||||
"""
|
||||
|
||||
__slots__ = ["missed", "primary", "arg_relaxed_specs",
|
||||
"arg_relaxed", "_garbage_collectors"]
|
||||
|
||||
def __init__(self):
|
||||
# The set of functions that have been missed; entries are CacheKey with
|
||||
# input_signature `None` (e.g. a "call context key")
|
||||
@ -3771,6 +3781,8 @@ def class_method_to_instance_method(original_function, instance):
|
||||
class _FunctionGarbageCollector(object):
|
||||
"""Cleans up cycles when a defun goes out of scope."""
|
||||
|
||||
__slots__ = ["_cache"]
|
||||
|
||||
def __init__(self, cache):
|
||||
self._cache = cache
|
||||
|
||||
@ -3788,6 +3800,8 @@ class _FunctionGarbageCollector(object):
|
||||
class ConcreteFunctionGarbageCollector(object):
|
||||
"""Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
|
||||
|
||||
__slots__ = ["_func_graph"]
|
||||
|
||||
def __init__(self, func_graph):
|
||||
self._func_graph = func_graph
|
||||
|
||||
@ -3807,6 +3821,8 @@ class ConcreteFunctionGarbageCollector(object):
|
||||
class _Marker(object):
|
||||
"""Markers used to pretty-print nested args in function signatures."""
|
||||
|
||||
__slots__ = ["_s"]
|
||||
|
||||
def __init__(self, s):
|
||||
self._s = s
|
||||
|
||||
|
@ -104,6 +104,8 @@ _sampler_methods = [
|
||||
class Metric(object):
|
||||
"""The base class of metric."""
|
||||
|
||||
__slots__ = ["_metric", "_metric_name", "_metric_methods", "_label_length"]
|
||||
|
||||
def __init__(self, metric_name, metric_methods, label_length, *args):
|
||||
"""Creates a new metric.
|
||||
|
||||
@ -145,6 +147,8 @@ class Metric(object):
|
||||
class CounterCell(object):
|
||||
"""CounterCell stores each value of a Counter."""
|
||||
|
||||
__slots__ = ["_cell"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new CounterCell.
|
||||
|
||||
@ -174,6 +178,8 @@ class Counter(Metric):
|
||||
user to increment each value.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, name, description, *labels):
|
||||
"""Creates a new Counter.
|
||||
|
||||
@ -193,6 +199,8 @@ class Counter(Metric):
|
||||
class IntGaugeCell(object):
|
||||
"""A single integer value stored in an `IntGauge`."""
|
||||
|
||||
__slots__ = ["_cell"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new IntGaugeCell.
|
||||
|
||||
@ -222,6 +230,8 @@ class IntGauge(Metric):
|
||||
allows the user to set each value.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, name, description, *labels):
|
||||
"""Creates a new IntGauge.
|
||||
|
||||
@ -241,6 +251,8 @@ class IntGauge(Metric):
|
||||
class StringGaugeCell(object):
|
||||
"""A single string value stored in an `StringGauge`."""
|
||||
|
||||
__slots__ = ["_cell"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new StringGaugeCell.
|
||||
|
||||
@ -273,6 +285,8 @@ class StringGauge(Metric):
|
||||
allows the user to set each value.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, name, description, *labels):
|
||||
"""Creates a new StringGauge.
|
||||
|
||||
@ -292,6 +306,8 @@ class StringGauge(Metric):
|
||||
class BoolGaugeCell(object):
|
||||
"""A single boolean value stored in an `BoolGauge`."""
|
||||
|
||||
__slots__ = ["_cell"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new BoolGaugeCell.
|
||||
|
||||
@ -321,6 +337,8 @@ class BoolGauge(Metric):
|
||||
allows the user to set each value.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, name, description, *labels):
|
||||
"""Creates a new BoolGauge.
|
||||
|
||||
@ -340,6 +358,8 @@ class BoolGauge(Metric):
|
||||
class SamplerCell(object):
|
||||
"""SamplerCell stores each value of a Sampler."""
|
||||
|
||||
__slots__ = ["_cell"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new SamplerCell.
|
||||
|
||||
@ -373,6 +393,8 @@ class SamplerCell(object):
|
||||
class Buckets(object):
|
||||
"""Bucketing strategies for the samplers."""
|
||||
|
||||
__slots__ = ["buckets"]
|
||||
|
||||
def __init__(self, buckets):
|
||||
"""Creates a new Buckets.
|
||||
|
||||
@ -393,6 +415,8 @@ class ExponentialBuckets(Buckets):
|
||||
scale * growth_factor^(i + 1), ..., DBL_MAX].
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, scale, growth_factor, bucket_count):
|
||||
"""Creates a new exponential Buckets.
|
||||
|
||||
@ -415,6 +439,8 @@ class Sampler(Metric):
|
||||
user to add a sample to each histogram value.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, name, buckets, description, *labels):
|
||||
"""Creates a new Sampler.
|
||||
|
||||
@ -435,6 +461,8 @@ class Sampler(Metric):
|
||||
class MonitoredTimer(object):
|
||||
"""A context manager to measure the walltime and increment a Counter cell."""
|
||||
|
||||
__slots__ = ["cell", "t"]
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new MonitoredTimer.
|
||||
|
||||
|
@ -35,6 +35,8 @@ distribution_strategy_context = LazyLoader(
|
||||
class Tape(object):
|
||||
"""Represents a gradient propagation trace."""
|
||||
|
||||
__slots__ = ["_tape"]
|
||||
|
||||
def __init__(self, tape):
|
||||
self._tape = tape
|
||||
|
||||
@ -72,6 +74,8 @@ class VariableWatcher(object):
|
||||
assert variable_watcher.watched_variables == [var]
|
||||
"""
|
||||
|
||||
__slots__ = ["_variable_watcher"]
|
||||
|
||||
def __init__(self):
|
||||
self._variable_watcher = None
|
||||
|
||||
|
@ -175,6 +175,9 @@ class AutomaticControlDependencies(object):
|
||||
NOT THREAD SAFE
|
||||
"""
|
||||
|
||||
__slots__ = ["_returned_tensors", "ops_which_must_run", "_graph",
|
||||
"_n_operations", "collective_manager_ids_used"]
|
||||
|
||||
def __init__(self):
|
||||
self._returned_tensors = object_identity.ObjectIdentitySet()
|
||||
self.ops_which_must_run = set()
|
||||
|
@ -29,6 +29,8 @@ from tensorflow.python.util import tf_contextlib
|
||||
class ScopedTFStatus(object):
|
||||
"""Wrapper around TF_Status that handles deletion."""
|
||||
|
||||
__slots__ = ["status"]
|
||||
|
||||
def __init__(self):
|
||||
self.status = c_api.TF_NewStatus()
|
||||
|
||||
@ -42,6 +44,8 @@ class ScopedTFStatus(object):
|
||||
class ScopedTFGraph(object):
|
||||
"""Wrapper around TF_Graph that handles deletion."""
|
||||
|
||||
__slots__ = ["graph", "deleter"]
|
||||
|
||||
def __init__(self):
|
||||
self.graph = c_api.TF_NewGraph()
|
||||
# Note: when we're destructing the global context (i.e when the process is
|
||||
@ -57,6 +61,8 @@ class ScopedTFGraph(object):
|
||||
class ScopedTFImportGraphDefOptions(object):
|
||||
"""Wrapper around TF_ImportGraphDefOptions that handles deletion."""
|
||||
|
||||
__slots__ = ["options"]
|
||||
|
||||
def __init__(self):
|
||||
self.options = c_api.TF_NewImportGraphDefOptions()
|
||||
|
||||
@ -70,6 +76,8 @@ class ScopedTFImportGraphDefOptions(object):
|
||||
class ScopedTFImportGraphDefResults(object):
|
||||
"""Wrapper around TF_ImportGraphDefOptions that handles deletion."""
|
||||
|
||||
__slots__ = ["results"]
|
||||
|
||||
def __init__(self, results):
|
||||
self.results = results
|
||||
|
||||
@ -83,6 +91,8 @@ class ScopedTFImportGraphDefResults(object):
|
||||
class ScopedTFFunction(object):
|
||||
"""Wrapper around TF_Function that handles deletion."""
|
||||
|
||||
__slots__ = ["func", "deleter"]
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
# Note: when we're destructing the global context (i.e when the process is
|
||||
@ -100,6 +110,8 @@ class ScopedTFFunction(object):
|
||||
class ScopedTFBuffer(object):
|
||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||
|
||||
__slots__ = ["buffer"]
|
||||
|
||||
def __init__(self, buf_string):
|
||||
self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
|
||||
|
||||
@ -114,6 +126,8 @@ class ApiDefMap(object):
|
||||
be queried by op name.
|
||||
"""
|
||||
|
||||
__slots__ = ["_api_def_map", "_op_per_name"]
|
||||
|
||||
def __init__(self):
|
||||
op_def_proto = op_def_pb2.OpList()
|
||||
buf = c_api.TF_GetAllOpList()
|
||||
|
@ -112,6 +112,8 @@ class MergeDevice(object):
|
||||
performance of device placement.
|
||||
"""
|
||||
|
||||
__slots__ = ["_spec"]
|
||||
|
||||
def __init__(self, spec):
|
||||
if isinstance(spec, device_spec.DeviceSpecV2):
|
||||
self._spec = spec
|
||||
|
@ -194,6 +194,8 @@ class Defun(object):
|
||||
class _DefinedFunctionDeleter(object):
|
||||
"""Unregister function from eager context."""
|
||||
|
||||
__slots__ = ["name"]
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
@ -2583,6 +2583,8 @@ class RegisterGradient(object):
|
||||
that defines the operation.
|
||||
"""
|
||||
|
||||
__slots__ = ["_op_type"]
|
||||
|
||||
def __init__(self, op_type):
|
||||
"""Creates a new decorator with `op_type` as the Operation type.
|
||||
|
||||
@ -2679,6 +2681,8 @@ class OpStats(object):
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ["_statistic_type", "_value"]
|
||||
|
||||
def __init__(self, statistic_type, value=None):
|
||||
"""Sets up the initial placeholders for the statistics."""
|
||||
self.statistic_type = statistic_type
|
||||
@ -2758,6 +2762,8 @@ class RegisterStatistics(object):
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ["_op_type", "_statistic_type"]
|
||||
|
||||
def __init__(self, op_type, statistic_type):
|
||||
"""Saves the `op_type` as the `Operation` type."""
|
||||
if not isinstance(op_type, six.string_types):
|
||||
@ -5176,6 +5182,8 @@ class enable_auto_cast_variables(object):
|
||||
`dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
|
||||
"""
|
||||
|
||||
__slots__ = ["_dtype", "_graph", "_prev_read_dtype"]
|
||||
|
||||
def __init__(self, dtype, graph=None):
|
||||
if dtype and not dtype.is_floating:
|
||||
self._dtype = None
|
||||
@ -6529,6 +6537,8 @@ class name_scope_v1(object): # pylint: disable=invalid-name
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ["_name", "_name_scope"]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
@ -6582,6 +6592,8 @@ class name_scope_v2(object):
|
||||
will generate `MyOp_1/a`, etc.
|
||||
"""
|
||||
|
||||
__slots__ = ["_name", "_exit_fns"]
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize the context manager.
|
||||
|
||||
@ -6918,6 +6930,8 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
|
||||
class _TensorIterator(object):
|
||||
"""Iterates over the leading dim of a Tensor. Performs no error checks."""
|
||||
|
||||
__slots__ = ["_tensor", "_index", "_limit"]
|
||||
|
||||
def __init__(self, tensor, dim0):
|
||||
self._tensor = tensor
|
||||
self._index = 0
|
||||
|
@ -36,6 +36,8 @@ _TYPE_TAG = "type"
|
||||
class Registry(object):
|
||||
"""Provides a registry for saving objects."""
|
||||
|
||||
__slots__ = ["_name", "_registry"]
|
||||
|
||||
def __init__(self, name):
|
||||
"""Creates a new registry."""
|
||||
self._name = name
|
||||
|
@ -67,6 +67,8 @@ def _recursive_apply(tensors, apply_fn):
|
||||
class _ControlOutputCache(object):
|
||||
"""Helper class to manage calculating and caching control_outputs in graph."""
|
||||
|
||||
__slots__ = ["cache"]
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
|
@ -3452,7 +3452,7 @@ _VALUE_SET_CODE_STRING = """
|
||||
>>> print(K.get_value(v))
|
||||
3.0
|
||||
|
||||
Variable semantics in TensorFlow 2 are eager execution friendly. The above
|
||||
Variable semantics in TensorFlow 2 are eager execution friendly. The above
|
||||
code is roughly equivalent to:
|
||||
|
||||
>>> v = tf.Variable(1.)
|
||||
|
@ -54,6 +54,8 @@ class InputSpec(object):
|
||||
a specific dimension value.
|
||||
"""
|
||||
|
||||
__slots__ = ['dtype', 'shape', 'ndim', 'max_ndim', 'min_ndim', 'axes']
|
||||
|
||||
def __init__(self,
|
||||
dtype=None,
|
||||
shape=None,
|
||||
|
@ -49,6 +49,8 @@ class _UnwrapPreventer(object):
|
||||
unwrapped by DistributionStrategy
|
||||
"""
|
||||
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
@ -719,6 +719,8 @@ class ConversionNotImplementedError(Exception):
|
||||
class _PforInput(object):
|
||||
"""Input object passed to registered pfor converters."""
|
||||
|
||||
__slots__ = ["pfor", "_op", "_inputs"]
|
||||
|
||||
def __init__(self, pfor, op, inputs):
|
||||
"""Creates a _PforInput object.
|
||||
|
||||
|
@ -265,6 +265,8 @@ class EagerResourceDeleter(object):
|
||||
the cycle will be collectable.
|
||||
"""
|
||||
|
||||
__slots__ = ["_handle", "_handle_device", "_context"]
|
||||
|
||||
def __init__(self, handle, handle_device):
|
||||
if not isinstance(handle, ops.Tensor):
|
||||
raise ValueError(
|
||||
|
@ -63,6 +63,8 @@ _api_usage_gauge = monitoring.BoolGauge(
|
||||
class _PartitionInfo(object):
|
||||
"""Holds partition info used by initializer functions."""
|
||||
|
||||
__slots__ = ["_full_shape", "_var_offset"]
|
||||
|
||||
def __init__(self, full_shape, var_offset):
|
||||
"""Constructor.
|
||||
|
||||
@ -279,6 +281,8 @@ class _VariableStore(object):
|
||||
the corresponding TensorFlow Variables as values.
|
||||
"""
|
||||
|
||||
__slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
|
||||
|
||||
def __init__(self):
|
||||
"""Create a variable store."""
|
||||
self._vars = {} # A dictionary of the stored TensorFlow variables.
|
||||
|
@ -41,6 +41,8 @@ from tensorflow.python.util import nest
|
||||
class _SingleDeviceSaver(object):
|
||||
"""Saves and restores checkpoints from the current device."""
|
||||
|
||||
__slots__ = ["_saveable_objects"]
|
||||
|
||||
def __init__(self, saveable_objects):
|
||||
"""Specify a list of `SaveableObject`s to save and restore.
|
||||
|
||||
|
@ -553,6 +553,8 @@ def _ready(op, sess, msg):
|
||||
|
||||
class _CountDownTimer(object):
|
||||
|
||||
__slots__ = ["_start_time_secs", "_duration_secs"]
|
||||
|
||||
def __init__(self, duration_secs):
|
||||
self._start_time_secs = time.time()
|
||||
self._duration_secs = duration_secs
|
||||
|
@ -190,6 +190,8 @@ class PythonStringStateSaveable(PythonStateSaveable):
|
||||
class CheckpointPosition(object):
|
||||
"""Indicates a position within a `_CheckpointRestoreCoordinator`."""
|
||||
|
||||
__slots__ = ["_checkpoint", "_proto_id"]
|
||||
|
||||
def __init__(self, checkpoint, proto_id):
|
||||
"""Specify an object within a checkpoint.
|
||||
|
||||
|
@ -59,6 +59,8 @@ class NoDependency(object):
|
||||
variables will appear in `Model.variables`).
|
||||
"""
|
||||
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
@ -140,6 +140,8 @@ def delete_tracking(obj, name):
|
||||
class ResourceTracker(object):
|
||||
"""An object that tracks a list of resources."""
|
||||
|
||||
__slots__ = ["_resources"]
|
||||
|
||||
def __init__(self):
|
||||
self._resources = []
|
||||
|
||||
@ -183,6 +185,8 @@ def resource_tracker_scope(resource_tracker):
|
||||
class CapturableResourceDeleter(object):
|
||||
"""Deleter to destroy CapturableResource without overriding its __del__()."""
|
||||
|
||||
__slots__ = ["_destruction_context", "_destroy_resource"]
|
||||
|
||||
def __init__(self, destroy_resource_fn=None):
|
||||
if destroy_resource_fn:
|
||||
self._destroy_resource = destroy_resource_fn
|
||||
|
@ -86,6 +86,8 @@ class _ObjectGraphProtoPrettyPrinter(object):
|
||||
repeated naming is cheap after the first.
|
||||
"""
|
||||
|
||||
__slots__ = ["_object_graph_proto", "_node_name_cache"]
|
||||
|
||||
def __init__(self, object_graph_proto):
|
||||
self._object_graph_proto = object_graph_proto
|
||||
self._node_name_cache = None
|
||||
@ -124,6 +126,9 @@ class _ObjectGraphProtoPrettyPrinter(object):
|
||||
class _CheckpointRestoreCoordinatorDeleter(object):
|
||||
"""Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""
|
||||
|
||||
__slots__ = ["expect_partial", "object_graph_proto",
|
||||
"matched_proto_ids", "unused_attributes"]
|
||||
|
||||
def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
|
||||
unused_attributes):
|
||||
self.expect_partial = expect_partial
|
||||
|
@ -51,6 +51,8 @@ class GroupLock(object):
|
||||
can also use the `acquire` and `release` method directly.
|
||||
"""
|
||||
|
||||
__slots__ = ["_ready", "_num_groups", "_group_member_counts"]
|
||||
|
||||
def __init__(self, num_groups=2):
|
||||
"""Initialize a group lock.
|
||||
|
||||
@ -116,6 +118,8 @@ class GroupLock(object):
|
||||
class _Context(object):
|
||||
"""Context manager helper for `GroupLock`."""
|
||||
|
||||
__slots__ = ["_lock", "_group_id"]
|
||||
|
||||
def __init__(self, lock, group_id):
|
||||
self._lock = lock
|
||||
self._group_id = group_id
|
||||
|
@ -344,6 +344,8 @@ _same_namedtuples = _pywrap_utils.SameNamedtuples
|
||||
|
||||
class _DotString(object):
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __str__(self):
|
||||
return "."
|
||||
|
||||
|
@ -122,6 +122,8 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
|
||||
and comparing based on the equality of their contents by default).
|
||||
"""
|
||||
|
||||
__slots__ = ["_storage"]
|
||||
|
||||
def __init__(self):
|
||||
self._storage = {}
|
||||
|
||||
@ -171,6 +173,8 @@ class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
|
||||
class ObjectIdentitySet(collections_abc.MutableSet):
|
||||
"""Like the built-in set, but compares objects with "is"."""
|
||||
|
||||
__slots__ = ["_storage"]
|
||||
|
||||
def __init__(self, *args):
|
||||
self._storage = set(self._wrap_key(obj) for obj in list(*args))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user