Use __slots__ for small classes
Parent: 4aa879aab5
Commit: c9b4679c9b
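For context, a minimal illustrative sketch (not part of the diff; the class names Plain and Slotted are made up) of what declaring __slots__ buys small helper classes like the ones changed below: instances store attributes in fixed slots instead of a per-instance __dict__, which cuts memory for objects created in large numbers and rejects accidental attribute typos.

import sys

class Plain(object):
  def __init__(self, handle):
    self.handle = handle

class Slotted(object):
  __slots__ = ["handle"]  # fixed attribute storage, no per-instance __dict__

  def __init__(self, handle):
    self.handle = handle

print(hasattr(Plain(1), "__dict__"))     # True: each instance carries a dict
print(hasattr(Slotted(1), "__dict__"))   # False: attributes live in slots
print(sys.getsizeof(Plain(1).__dict__))  # per-instance dict overhead avoided by Slotted
try:
  Slotted(1).other = 2
except AttributeError:
  print("slotted instances reject attributes outside __slots__")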
@@ -3770,6 +3770,8 @@ class BatchDataset(UnaryDataset):
 class _NumpyIterator(object):
   """Iterator over a dataset with elements converted to numpy."""
 
+  __slots__ = ["_iterator"]
+
   def __init__(self, dataset):
     self._iterator = iter(dataset)
 
@@ -522,6 +522,8 @@ class IteratorResourceDeleter(object):
   object is part of a reference cycle, the cycle will be collectable.
   """
 
+  __slots__ = ["_deleter", "_handle", "_device", "_eager_mode"]
+
   def __init__(self, handle, device, deleter):
     self._deleter = deleter
     self._handle = handle
@@ -377,6 +377,9 @@ class MultiDeviceIteratorResourceDeleter(object):
   object is part of a reference cycle, the cycle will be collectible.
   """
 
+  __slots__ = ["_deleter", "_multi_device_iterator", "_iterators",
+               "_device", "_eager_mode"]
+
   def __init__(self, multi_device_iterator, iterators, device, deleter):
     self._deleter = deleter
     self._multi_device_iterator = multi_device_iterator
@@ -84,6 +84,8 @@ def resolve(d):
 class _FakeNodeDef(object):
   """A fake NodeDef for _FakeOperation."""
 
+  __slots__ = ["op", "name"]
+
   def __init__(self):
     self.op = ""
     self.name = ""
@@ -250,6 +250,8 @@ def get_update_replica_id():
 class UpdateContext(object):
   """Context manager when you are in `update()` or `update_non_slot()`."""
 
+  __slots__ = ["_replica_id", "_old_replica_id"]
+
   def __init__(self, replica_id):
     self._replica_id = replica_id
     self._old_replica_id = None
@@ -454,6 +456,9 @@ class InputContext(object):
   source etc).
   """
 
+  __slots__ = ["_num_input_pipelines", "_input_pipeline_id",
+               "_num_replicas_in_sync"]
+
   def __init__(self,
                num_input_pipelines=1,
                input_pipeline_id=0,
@@ -545,6 +550,8 @@ class ValueContext(object):
 
   """
 
+  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
+
   def __init__(self,
                replica_id_in_sync_group=0,
                num_replicas_in_sync=1):
@@ -2923,6 +2930,8 @@ class _DefaultDistributionStrategy(StrategyV1):
 class _DefaultDistributionContext(object):
   """Context manager setting the default `tf.distribute.Strategy`."""
 
+  __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
+
   def __init__(self, strategy):
 
     def creator(next_creator, **kwargs):
@@ -24,6 +24,8 @@ from tensorflow.python import pywrap_tfe
 class CancellationManager(object):
   """A mechanism for cancelling blocking computation."""
 
+  __slots__ = ["_impl"]
+
   def __init__(self):
     self._impl = pywrap_tfe.TFE_NewCancellationManager()
 
@@ -80,6 +80,8 @@ _python_eager_context_create_counter = monitoring.Counter(
 class _EagerTensorCache(object):
   """Simple cache which evicts items based on length in a FIFO manner."""
 
+  __slots__ = ["_data", "_max_items", "_max_tensor_size"]
+
   def __init__(self, max_items=256, max_tensor_size=10000):
     self._data = collections.OrderedDict()
     self._max_items = max_items
@@ -107,6 +109,8 @@ class FunctionCallOptions(object):
   Eager functions are functions decorated with tf.contrib.eager.defun.
   """
 
+  __slots__ = ["_config_proto_serialized", "_executor_type"]
+
   def __init__(self, executor_type=None, config_proto=None):
     """Constructor.
 
@@ -161,6 +165,8 @@ _tensor_caches_map = {}
 class _TensorCaches(threading.local):
   """Thread local tensor caches."""
 
+  __slots__ = ["_ones_rank_cache", "_zeros_cache"]
+
   def __init__(self):
     super(_TensorCaches, self).__init__()
     self._ones_rank_cache = None
@@ -316,6 +322,8 @@ class PhysicalDevice(
 class _AtomicCounter(object):
   """A simple atomic counter."""
 
+  __slots__ = ["_value", "_lock"]
+
   def __init__(self):
     self._value = 0
     self._lock = threading.Lock()
@@ -332,6 +340,8 @@ _context_id_counter = _AtomicCounter()
 class _TensorCacheDeleter(object):
   """Deletes tensor caches for a given context."""
 
+  __slots__ = ["_context_id"]
+
   def __init__(self, context_id):
     self._context_id = context_id
 
@@ -1730,6 +1740,8 @@ class Context(object):
 class _EagerDeviceContext(object):
   """Context-manager forcing placement of ops and Tensors on a device."""
 
+  __slots__ = ["_device_name", "_ctx", "_stack"]
+
   def __init__(self, ctx, device_name):
     self._device_name = device_name
     self._ctx = ctx
@@ -54,6 +54,8 @@ FREQUENT_TRACING_WARNING_THRESHOLD = 5
 class _CallCounter(object):
   """Class keeping track of how many recent calls triggered tracing."""
 
+  __slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
+
   def __init__(self, max_call_history):
     self._max_call_history = max_call_history
     self._calls_per_tracings = []
@@ -84,6 +86,8 @@ class _CallCounter(object):
 class _FrequentTracingDetector(object):
   """Class for frequent retracing detection and warning."""
 
+  __slots__ = ["_counters", "_lock"]
+
   def __init__(self):
     self._counters = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
     self._lock = threading.Lock()
@@ -428,6 +432,8 @@ def functions_run_eagerly():
 
 class FunctionDeleter(object):
 
+  __slots__ = ["func_graph"]
+
   def __init__(self, func_graph):
     self.func_graph = func_graph
 
@@ -40,6 +40,8 @@ class Executor(object):
   ```
   """
 
+  __slots__ = ["_handle"]
+
   def __init__(self, handle):
     self._handle = handle
 
@@ -262,6 +262,8 @@ def _parse_func_attrs(attributes):
 class _InterpolateFunctionError(object):
   """Context Manager that interpolates the exception from 'top_level_func'."""
 
+  __slots__ = ["_func"]
+
   def __init__(self, top_level_func):
     self._func = top_level_func
 
@@ -378,6 +380,8 @@ def _enclosing_xla_context():
 class _EagerDefinedFunctionDeleter(object):
   """Unregister function from eager context."""
 
+  __slots__ = ["name"]
+
   def __init__(self, name):
     self.name = name
 
@@ -1410,6 +1414,9 @@ _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
 class _ForwardBackwardCall(object):
   """Holds the state of a function call between execution and recording."""
 
+  __slots__ = ["_functions", "_inference_args", "_input_tangents",
+               "_tape_watching"]
+
   def __init__(self, functions, inference_args, input_tangents, tape_watching):
     """Collects information about the function call.
 
@@ -2740,6 +2747,9 @@ class FunctionCache(object):
   """A lightweight container for cached functions.
   """
 
+  __slots__ = ["missed", "primary", "arg_relaxed_specs",
+               "arg_relaxed", "_garbage_collectors"]
+
   def __init__(self):
     # The set of functions that have been missed; entries are CacheKey with
     # input_signature `None` (e.g. a "call context key")
@@ -3771,6 +3781,8 @@ def class_method_to_instance_method(original_function, instance):
 class _FunctionGarbageCollector(object):
   """Cleans up cycles when a defun goes out of scope."""
 
+  __slots__ = ["_cache"]
+
   def __init__(self, cache):
     self._cache = cache
 
@@ -3788,6 +3800,8 @@ class _FunctionGarbageCollector(object):
 class ConcreteFunctionGarbageCollector(object):
   """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
 
+  __slots__ = ["_func_graph"]
+
   def __init__(self, func_graph):
     self._func_graph = func_graph
 
@@ -3807,6 +3821,8 @@ class ConcreteFunctionGarbageCollector(object):
 class _Marker(object):
   """Markers used to pretty-print nested args in function signatures."""
 
+  __slots__ = ["_s"]
+
   def __init__(self, s):
     self._s = s
 
@@ -104,6 +104,8 @@ _sampler_methods = [
 class Metric(object):
   """The base class of metric."""
 
+  __slots__ = ["_metric", "_metric_name", "_metric_methods", "_label_length"]
+
   def __init__(self, metric_name, metric_methods, label_length, *args):
     """Creates a new metric.
 
@@ -145,6 +147,8 @@ class Metric(object):
 class CounterCell(object):
   """CounterCell stores each value of a Counter."""
 
+  __slots__ = ["_cell"]
+
   def __init__(self, cell):
     """Creates a new CounterCell.
 
@@ -174,6 +178,8 @@ class Counter(Metric):
   user to increment each value.
   """
 
+  __slots__ = []
+
   def __init__(self, name, description, *labels):
     """Creates a new Counter.
 
@@ -193,6 +199,8 @@ class Counter(Metric):
 class IntGaugeCell(object):
   """A single integer value stored in an `IntGauge`."""
 
+  __slots__ = ["_cell"]
+
   def __init__(self, cell):
     """Creates a new IntGaugeCell.
 
@@ -222,6 +230,8 @@ class IntGauge(Metric):
   allows the user to set each value.
   """
 
+  __slots__ = []
+
   def __init__(self, name, description, *labels):
     """Creates a new IntGauge.
 
@@ -241,6 +251,8 @@ class IntGauge(Metric):
 class StringGaugeCell(object):
   """A single string value stored in an `StringGauge`."""
 
+  __slots__ = ["_cell"]
+
   def __init__(self, cell):
     """Creates a new StringGaugeCell.
 
@@ -273,6 +285,8 @@ class StringGauge(Metric):
   allows the user to set each value.
   """
 
+  __slots__ = []
+
   def __init__(self, name, description, *labels):
     """Creates a new StringGauge.
 
@@ -292,6 +306,8 @@ class StringGauge(Metric):
 class BoolGaugeCell(object):
   """A single boolean value stored in an `BoolGauge`."""
 
+  __slots__ = ["_cell"]
+
   def __init__(self, cell):
     """Creates a new BoolGaugeCell.
 
@@ -321,6 +337,8 @@ class BoolGauge(Metric):
   allows the user to set each value.
   """
 
+  __slots__ = []
+
   def __init__(self, name, description, *labels):
     """Creates a new BoolGauge.
 
@@ -340,6 +358,8 @@ class BoolGauge(Metric):
 class SamplerCell(object):
   """SamplerCell stores each value of a Sampler."""
 
+  __slots__ = ["_cell"]
+
   def __init__(self, cell):
     """Creates a new SamplerCell.
 
@@ -373,6 +393,8 @@ class SamplerCell(object):
 class Buckets(object):
   """Bucketing strategies for the samplers."""
 
+  __slots__ = ["buckets"]
+
   def __init__(self, buckets):
     """Creates a new Buckets.
 
@@ -393,6 +415,8 @@ class ExponentialBuckets(Buckets):
   scale * growth_factor^(i + 1), ..., DBL_MAX].
   """
 
+  __slots__ = []
+
   def __init__(self, scale, growth_factor, bucket_count):
     """Creates a new exponential Buckets.
 
@@ -415,6 +439,8 @@ class Sampler(Metric):
   user to add a sample to each histogram value.
   """
 
+  __slots__ = []
+
   def __init__(self, name, buckets, description, *labels):
     """Creates a new Sampler.
 
@@ -435,6 +461,8 @@ class Sampler(Metric):
 class MonitoredTimer(object):
   """A context manager to measure the walltime and increment a Counter cell."""
 
+  __slots__ = ["cell", "t"]
+
   def __init__(self, cell):
     """Creates a new MonitoredTimer.
 
@@ -35,6 +35,8 @@ distribution_strategy_context = LazyLoader(
 class Tape(object):
   """Represents a gradient propagation trace."""
 
+  __slots__ = ["_tape"]
+
   def __init__(self, tape):
     self._tape = tape
 
@@ -72,6 +74,8 @@ class VariableWatcher(object):
   assert variable_watcher.watched_variables == [var]
   """
 
+  __slots__ = ["_variable_watcher"]
+
   def __init__(self):
     self._variable_watcher = None
 
@@ -175,6 +175,9 @@ class AutomaticControlDependencies(object):
   NOT THREAD SAFE
   """
 
+  __slots__ = ["_returned_tensors", "ops_which_must_run", "_graph",
+               "_n_operations", "collective_manager_ids_used"]
+
   def __init__(self):
     self._returned_tensors = object_identity.ObjectIdentitySet()
     self.ops_which_must_run = set()
@@ -29,6 +29,8 @@ from tensorflow.python.util import tf_contextlib
 class ScopedTFStatus(object):
   """Wrapper around TF_Status that handles deletion."""
 
+  __slots__ = ["status"]
+
   def __init__(self):
     self.status = c_api.TF_NewStatus()
 
@@ -42,6 +44,8 @@ class ScopedTFStatus(object):
 class ScopedTFGraph(object):
   """Wrapper around TF_Graph that handles deletion."""
 
+  __slots__ = ["graph", "deleter"]
+
   def __init__(self):
     self.graph = c_api.TF_NewGraph()
     # Note: when we're destructing the global context (i.e when the process is
@@ -57,6 +61,8 @@ class ScopedTFGraph(object):
 class ScopedTFImportGraphDefOptions(object):
   """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
 
+  __slots__ = ["options"]
+
   def __init__(self):
     self.options = c_api.TF_NewImportGraphDefOptions()
 
@@ -70,6 +76,8 @@ class ScopedTFImportGraphDefOptions(object):
 class ScopedTFImportGraphDefResults(object):
   """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
 
+  __slots__ = ["results"]
+
   def __init__(self, results):
     self.results = results
 
@@ -83,6 +91,8 @@ class ScopedTFImportGraphDefResults(object):
 class ScopedTFFunction(object):
   """Wrapper around TF_Function that handles deletion."""
 
+  __slots__ = ["func", "deleter"]
+
   def __init__(self, func):
     self.func = func
     # Note: when we're destructing the global context (i.e when the process is
@@ -100,6 +110,8 @@ class ScopedTFFunction(object):
 class ScopedTFBuffer(object):
   """An internal class to help manage the TF_Buffer lifetime."""
 
+  __slots__ = ["buffer"]
+
   def __init__(self, buf_string):
     self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
 
@@ -114,6 +126,8 @@ class ApiDefMap(object):
   be queried by op name.
   """
 
+  __slots__ = ["_api_def_map", "_op_per_name"]
+
   def __init__(self):
     op_def_proto = op_def_pb2.OpList()
     buf = c_api.TF_GetAllOpList()
@@ -112,6 +112,8 @@ class MergeDevice(object):
   performance of device placement.
   """
 
+  __slots__ = ["_spec"]
+
   def __init__(self, spec):
     if isinstance(spec, device_spec.DeviceSpecV2):
       self._spec = spec
@@ -194,6 +194,8 @@ class Defun(object):
 class _DefinedFunctionDeleter(object):
   """Unregister function from eager context."""
 
+  __slots__ = ["name"]
+
   def __init__(self, name):
     self.name = name
 
@@ -2583,6 +2583,8 @@ class RegisterGradient(object):
   that defines the operation.
   """
 
+  __slots__ = ["_op_type"]
+
   def __init__(self, op_type):
     """Creates a new decorator with `op_type` as the Operation type.
 
@@ -2679,6 +2681,8 @@ class OpStats(object):
 
   """
 
+  __slots__ = ["_statistic_type", "_value"]
+
   def __init__(self, statistic_type, value=None):
     """Sets up the initial placeholders for the statistics."""
     self.statistic_type = statistic_type
@@ -2758,6 +2762,8 @@ class RegisterStatistics(object):
 
   """
 
+  __slots__ = ["_op_type", "_statistic_type"]
+
   def __init__(self, op_type, statistic_type):
     """Saves the `op_type` as the `Operation` type."""
     if not isinstance(op_type, six.string_types):
@@ -5176,6 +5182,8 @@ class enable_auto_cast_variables(object):
   `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
   """
 
+  __slots__ = ["_dtype", "_graph", "_prev_read_dtype"]
+
   def __init__(self, dtype, graph=None):
     if dtype and not dtype.is_floating:
       self._dtype = None
@@ -6529,6 +6537,8 @@ class name_scope_v1(object):  # pylint: disable=invalid-name
   ```
   """
 
+  __slots__ = ["_name", "_name_scope"]
+
   @property
   def name(self):
     return self._name
@@ -6582,6 +6592,8 @@ class name_scope_v2(object):
   will generate `MyOp_1/a`, etc.
   """
 
+  __slots__ = ["_name", "_exit_fns"]
+
   def __init__(self, name):
     """Initialize the context manager.
 
@@ -6918,6 +6930,8 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
 class _TensorIterator(object):
   """Iterates over the leading dim of a Tensor. Performs no error checks."""
 
+  __slots__ = ["_tensor", "_index", "_limit"]
+
   def __init__(self, tensor, dim0):
     self._tensor = tensor
     self._index = 0
@@ -36,6 +36,8 @@ _TYPE_TAG = "type"
 class Registry(object):
   """Provides a registry for saving objects."""
 
+  __slots__ = ["_name", "_registry"]
+
   def __init__(self, name):
     """Creates a new registry."""
     self._name = name
@@ -67,6 +67,8 @@ def _recursive_apply(tensors, apply_fn):
 class _ControlOutputCache(object):
   """Helper class to manage calculating and caching control_outputs in graph."""
 
+  __slots__ = ["cache"]
+
   def __init__(self):
     self.cache = {}
 
@@ -54,6 +54,8 @@ class InputSpec(object):
   a specific dimension value.
   """
 
+  __slots__ = ['dtype', 'shape', 'ndim', 'max_ndim', 'min_ndim', 'axes']
+
   def __init__(self,
                dtype=None,
                shape=None,
@@ -49,6 +49,8 @@ class _UnwrapPreventer(object):
   unwrapped by DistributionStrategy
   """
 
+  __slots__ = ["value"]
+
   def __init__(self, value):
     self.value = value
 
@@ -719,6 +719,8 @@ class ConversionNotImplementedError(Exception):
 class _PforInput(object):
   """Input object passed to registered pfor converters."""
 
+  __slots__ = ["pfor", "_op", "_inputs"]
+
   def __init__(self, pfor, op, inputs):
     """Creates a _PforInput object.
 
@@ -265,6 +265,8 @@ class EagerResourceDeleter(object):
   the cycle will be collectable.
   """
 
+  __slots__ = ["_handle", "_handle_device", "_context"]
+
   def __init__(self, handle, handle_device):
     if not isinstance(handle, ops.Tensor):
       raise ValueError(
@@ -63,6 +63,8 @@ _api_usage_gauge = monitoring.BoolGauge(
 class _PartitionInfo(object):
   """Holds partition info used by initializer functions."""
 
+  __slots__ = ["_full_shape", "_var_offset"]
+
   def __init__(self, full_shape, var_offset):
     """Constructor.
 
@@ -279,6 +281,8 @@ class _VariableStore(object):
   the corresponding TensorFlow Variables as values.
   """
 
+  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
+
   def __init__(self):
     """Create a variable store."""
     self._vars = {}  # A dictionary of the stored TensorFlow variables.
@@ -41,6 +41,8 @@ from tensorflow.python.util import nest
 class _SingleDeviceSaver(object):
   """Saves and restores checkpoints from the current device."""
 
+  __slots__ = ["_saveable_objects"]
+
   def __init__(self, saveable_objects):
     """Specify a list of `SaveableObject`s to save and restore.
 
@@ -553,6 +553,8 @@ def _ready(op, sess, msg):
 
 class _CountDownTimer(object):
 
+  __slots__ = ["_start_time_secs", "_duration_secs"]
+
   def __init__(self, duration_secs):
     self._start_time_secs = time.time()
     self._duration_secs = duration_secs
@@ -190,6 +190,8 @@ class PythonStringStateSaveable(PythonStateSaveable):
 class CheckpointPosition(object):
   """Indicates a position within a `_CheckpointRestoreCoordinator`."""
 
+  __slots__ = ["_checkpoint", "_proto_id"]
+
   def __init__(self, checkpoint, proto_id):
     """Specify an object within a checkpoint.
 
@@ -59,6 +59,8 @@ class NoDependency(object):
   variables will appear in `Model.variables`).
   """
 
+  __slots__ = ["value"]
+
   def __init__(self, value):
     self.value = value
 
@@ -140,6 +140,8 @@ def delete_tracking(obj, name):
 class ResourceTracker(object):
   """An object that tracks a list of resources."""
 
+  __slots__ = ["_resources"]
+
   def __init__(self):
     self._resources = []
 
@@ -183,6 +185,8 @@ def resource_tracker_scope(resource_tracker):
 class CapturableResourceDeleter(object):
   """Deleter to destroy CapturableResource without overriding its __del__()."""
 
+  __slots__ = ["_destruction_context", "_destroy_resource"]
+
   def __init__(self, destroy_resource_fn=None):
     if destroy_resource_fn:
       self._destroy_resource = destroy_resource_fn
@@ -86,6 +86,8 @@ class _ObjectGraphProtoPrettyPrinter(object):
   repeated naming is cheap after the first.
   """
 
+  __slots__ = ["_object_graph_proto", "_node_name_cache"]
+
   def __init__(self, object_graph_proto):
     self._object_graph_proto = object_graph_proto
     self._node_name_cache = None
@@ -124,6 +126,9 @@ class _ObjectGraphProtoPrettyPrinter(object):
 class _CheckpointRestoreCoordinatorDeleter(object):
   """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""
 
+  __slots__ = ["expect_partial", "object_graph_proto",
+               "matched_proto_ids", "unused_attributes"]
+
   def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
                unused_attributes):
     self.expect_partial = expect_partial
@@ -51,6 +51,8 @@ class GroupLock(object):
   can also use the `acquire` and `release` method directly.
   """
 
+  __slots__ = ["_ready", "_num_groups", "_group_member_counts"]
+
   def __init__(self, num_groups=2):
     """Initialize a group lock.
 
@@ -116,6 +118,8 @@ class GroupLock(object):
   class _Context(object):
     """Context manager helper for `GroupLock`."""
 
+    __slots__ = ["_lock", "_group_id"]
+
     def __init__(self, lock, group_id):
       self._lock = lock
       self._group_id = group_id
@@ -344,6 +344,8 @@ _same_namedtuples = _pywrap_utils.SameNamedtuples
 
 class _DotString(object):
 
+  __slots__ = []
+
   def __str__(self):
     return "."
 
@@ -122,6 +122,8 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
   and comparing based on the equality of their contents by default).
   """
 
+  __slots__ = ["_storage"]
+
   def __init__(self):
     self._storage = {}
 
@@ -171,6 +173,8 @@ class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
 class ObjectIdentitySet(collections_abc.MutableSet):
   """Like the built-in set, but compares objects with "is"."""
 
+  __slots__ = ["_storage"]
+
   def __init__(self, *args):
     self._storage = set(self._wrap_key(obj) for obj in list(*args))
 
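A note on the empty declarations above: several of the Metric and Buckets subclasses (Counter, IntGauge, StringGauge, BoolGauge, Sampler, ExponentialBuckets) get __slots__ = []. A small sketch, with made-up class names, of why an empty list still matters: a subclass that omits __slots__ silently reintroduces a per-instance __dict__ even when its base class declares slots.

class Base(object):
  __slots__ = ["_cell"]

class LeakyChild(Base):   # no __slots__: instances regain a __dict__
  pass

class TightChild(Base):
  __slots__ = []          # declares no new slots, keeps instances dict-free

print(hasattr(LeakyChild(), "__dict__"))  # True
print(hasattr(TightChild(), "__dict__"))  # False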