Implement several optimizations to reduce graph construction time.
In approximately decreasing order of significance: 1) Cache various to_string, from_string, and string to string functionality in device.py. 2) Optimize DeviceSpec.to_string to reduce unnecessary string copies. 3) _Skip no-op device assignments when creating ops. (When possible.) 4) Remove hash caching in DeviceSpec (since it can now be computed much more cheaply) which allows less aggressive locking. 5) Misc finesse around how high traffic functions (millions of calls). PiperOrigin-RevId: 242996847
This commit is contained in:
parent
c39738021b
commit
8d24f6ae5c
@ -628,6 +628,7 @@ py_library(
|
||||
deps = [
|
||||
":constant_op",
|
||||
":device",
|
||||
":device_spec",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":function",
|
||||
@ -649,6 +650,7 @@ py_library(
|
||||
deps = [
|
||||
":constant_op",
|
||||
":device",
|
||||
":device_spec",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":op_def_library",
|
||||
@ -747,6 +749,15 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "device_spec",
|
||||
srcs = ["framework/device_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "device",
|
||||
srcs = ["framework/device.py"],
|
||||
@ -1668,6 +1679,19 @@ tf_py_test(
|
||||
main = "framework/sparse_tensor_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "framework_device_spec_test",
|
||||
size = "small",
|
||||
srcs = ["framework/device_spec_test.py"],
|
||||
additional_deps = [
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_test_lib",
|
||||
":platform_test",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
main = "framework/device_spec_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "framework_device_test",
|
||||
size = "small",
|
||||
@ -3924,6 +3948,7 @@ py_library(
|
||||
":control_flow_ops",
|
||||
":data_flow_ops",
|
||||
":device",
|
||||
":device_spec",
|
||||
":distribute",
|
||||
":errors",
|
||||
":framework",
|
||||
|
@ -51,10 +51,13 @@ def canonicalize(d, default=None):
|
||||
result = tf_device.DeviceSpec(
|
||||
replica=0, task=0, device_type="CPU", device_index=0)
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
result.job = "localhost"
|
||||
result = result.replace(job="localhost")
|
||||
if default:
|
||||
result.merge_from(tf_device.DeviceSpec.from_string(default))
|
||||
result.merge_from(d)
|
||||
result = result.make_merged_spec(
|
||||
tf_device.DeviceSpec.from_string(default))
|
||||
|
||||
# Apply `d` last, so that it takes precidence over the defaults.
|
||||
result = result.make_merged_spec(d)
|
||||
return result.to_string()
|
||||
|
||||
|
||||
@ -83,6 +86,9 @@ class _FakeOperation(object):
|
||||
def _set_device(self, device):
|
||||
self.device = ops._device_string(device) # pylint: disable=protected-access
|
||||
|
||||
def _set_device_from_string(self, device_str):
|
||||
self.device = device_str
|
||||
|
||||
|
||||
def current():
|
||||
"""Return a string (not canonicalized) for the current device."""
|
||||
|
@ -68,7 +68,7 @@ def _enter_graph(g, eager, creator_stack=None):
|
||||
|
||||
def _cpu_device(device):
|
||||
cpu_device = tf_device.DeviceSpec.from_string(device)
|
||||
cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
|
||||
cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
|
||||
return cpu_device.to_string()
|
||||
|
||||
|
||||
@ -297,7 +297,7 @@ def _is_device_list_local(devices):
|
||||
"""
|
||||
all_local = None
|
||||
for d in devices:
|
||||
d_spec = tf_device.DeviceSpec().parse_from_string(d)
|
||||
d_spec = tf_device.DeviceSpec.from_string(d)
|
||||
is_local = d_spec.job in (None, "localhost")
|
||||
|
||||
if all_local is None: # Determine all_local from first device.
|
||||
@ -345,7 +345,7 @@ def _group_device_list(devices):
|
||||
device_dict = {}
|
||||
|
||||
for d in devices:
|
||||
d_spec = tf_device.DeviceSpec().parse_from_string(d)
|
||||
d_spec = tf_device.DeviceSpec.from_string(d)
|
||||
|
||||
# Create an entry for the task_type.
|
||||
if d_spec.job not in device_dict:
|
||||
@ -361,7 +361,7 @@ def _group_device_list(devices):
|
||||
|
||||
|
||||
def _is_gpu_device(device):
|
||||
return tf_device.DeviceSpec().parse_from_string(device).device_type == "GPU"
|
||||
return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
|
||||
|
||||
|
||||
def _infer_num_gpus_per_worker(devices):
|
||||
@ -396,7 +396,7 @@ def _infer_num_gpus_per_worker(devices):
|
||||
raise ValueError("All workers should have the same number of GPUs.")
|
||||
|
||||
for d in device_in_task:
|
||||
d_spec = tf_device.DeviceSpec().parse_from_string(d)
|
||||
d_spec = tf_device.DeviceSpec.from_string(d)
|
||||
if (d_spec.device_type == "GPU" and
|
||||
d_spec.device_index >= num_gpus):
|
||||
raise ValueError("GPU `device_index` on a worker should be "
|
||||
|
@ -86,9 +86,6 @@ class DistributedValuesTest(test.TestCase):
|
||||
self.assertEqual(canonical_cpu, v.devices)
|
||||
v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,))
|
||||
self.assertEqual(canonical_cpu, v.devices)
|
||||
with self.assertRaises(AssertionError):
|
||||
v = values.DistributedValues(
|
||||
values.SingleDeviceMap("/device:cpu:0"), (42,))
|
||||
|
||||
def testIsTensorLike(self):
|
||||
with context.graph_mode(), \
|
||||
|
@ -87,6 +87,7 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:device_spec",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
|
@ -1024,7 +1024,7 @@ class _EagerDeviceContext(object):
|
||||
ctx.ensure_initialized()
|
||||
new_device_spec = pydev.DeviceSpec.from_string(
|
||||
ctx._context_devices[0]) # pylint: disable=protected-access
|
||||
new_device_spec.merge_from(device_spec)
|
||||
new_device_spec = new_device_spec.make_merged_spec(device_spec)
|
||||
else:
|
||||
new_device_spec = pydev.DeviceSpec.from_string("")
|
||||
new_device_name = new_device_spec.to_string()
|
||||
|
@ -18,275 +18,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import threading
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import device_spec
|
||||
|
||||
@tf_export(v1=["DeviceSpec"])
|
||||
class DeviceSpec(object):
|
||||
"""Represents a (possibly partial) specification for a TensorFlow device.
|
||||
|
||||
`DeviceSpec`s are used throughout TensorFlow to describe where state is stored
|
||||
and computations occur. Using `DeviceSpec` allows you to parse device spec
|
||||
strings to verify their validity, merge them or compose them programmatically.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Place the operations on device "GPU:0" in the "ps" job.
|
||||
device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
|
||||
with tf.device(device_spec):
|
||||
# Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
|
||||
my_var = tf.Variable(..., name="my_variable")
|
||||
squared_var = tf.square(my_var)
|
||||
```
|
||||
|
||||
If a `DeviceSpec` is partially specified, it will be merged with other
|
||||
`DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
|
||||
components defined in inner scopes take precedence over those defined in
|
||||
outer scopes.
|
||||
|
||||
```python
|
||||
with tf.device(DeviceSpec(job="train", )):
|
||||
with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0):
|
||||
# Nodes created here will be assigned to /job:ps/device:GPU:0.
|
||||
with tf.device(DeviceSpec(device_type="GPU", device_index=1):
|
||||
# Nodes created here will be assigned to /job:train/device:GPU:1.
|
||||
```
|
||||
|
||||
A `DeviceSpec` consists of 5 components -- each of
|
||||
which is optionally specified:
|
||||
|
||||
* Job: The job name.
|
||||
* Replica: The replica index.
|
||||
* Task: The task index.
|
||||
* Device type: The device type string (e.g. "CPU" or "GPU").
|
||||
* Device index: The device index.
|
||||
"""
|
||||
|
||||
def __init__(self, job=None, replica=None, task=None, device_type=None,
|
||||
device_index=None):
|
||||
"""Create a new `DeviceSpec` object.
|
||||
|
||||
Args:
|
||||
job: string. Optional job name.
|
||||
replica: int. Optional replica index.
|
||||
task: int. Optional task index.
|
||||
device_type: Optional device type string (e.g. "CPU" or "GPU")
|
||||
device_index: int. Optional device index. If left
|
||||
unspecified, device represents 'any' device_index.
|
||||
"""
|
||||
self._set_job(job)
|
||||
self._set_replica(replica)
|
||||
self._set_task(task)
|
||||
if device_type == "cpu" or device_type == "gpu":
|
||||
# For backwards compatibility only, we support lowercase variants of
|
||||
# cpu and gpu but turn them into uppercase here.
|
||||
self._set_device_type(device_type.upper())
|
||||
else:
|
||||
self._set_device_type(device_type)
|
||||
self._set_device_index(device_index)
|
||||
self._to_string = self._device_to_string()
|
||||
self._hash = hash(self._to_string)
|
||||
|
||||
def _clear(self):
|
||||
self._job = None
|
||||
self._replica = None
|
||||
self._task = None
|
||||
self._device_type = None
|
||||
self._device_index = None
|
||||
self._to_string = None
|
||||
self._hash = None
|
||||
|
||||
def _sync(self):
|
||||
"""Sync device internal states."""
|
||||
self._to_string = self._device_to_string()
|
||||
self._hash = hash(self._to_string)
|
||||
|
||||
@property
|
||||
def job(self):
|
||||
return self._job
|
||||
|
||||
def _set_job(self, job):
|
||||
if job is not None:
|
||||
self._job = str(job)
|
||||
else:
|
||||
self._job = None
|
||||
|
||||
@job.setter
|
||||
def job(self, job):
|
||||
self._set_job(job)
|
||||
self._sync()
|
||||
|
||||
@property
|
||||
def replica(self):
|
||||
return self._replica
|
||||
|
||||
def _set_replica(self, replica):
|
||||
if replica is not None:
|
||||
self._replica = int(replica)
|
||||
else:
|
||||
self._replica = None
|
||||
|
||||
@replica.setter
|
||||
def replica(self, replica):
|
||||
self._set_replica(replica)
|
||||
self._sync()
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
def _set_task(self, task):
|
||||
if task is not None:
|
||||
self._task = int(task)
|
||||
else:
|
||||
self._task = None
|
||||
|
||||
@task.setter
|
||||
def task(self, task):
|
||||
self._set_task(task)
|
||||
self._sync()
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
return self._device_type
|
||||
|
||||
def _set_device_type(self, device_type):
|
||||
self._device_type = device_type
|
||||
|
||||
@device_type.setter
|
||||
def device_type(self, device_type):
|
||||
self._set_device_type(device_type)
|
||||
self._sync()
|
||||
|
||||
@property
|
||||
def device_index(self):
|
||||
return self._device_index
|
||||
|
||||
def _set_device_index(self, device_index):
|
||||
self._device_index = device_index
|
||||
|
||||
@device_index.setter
|
||||
def device_index(self, device_index):
|
||||
self._set_device_index(device_index)
|
||||
self._sync()
|
||||
|
||||
def parse_from_string(self, spec):
|
||||
"""Parse a `DeviceSpec` name into its components.
|
||||
|
||||
Args:
|
||||
spec: a string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
|
||||
or
|
||||
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
|
||||
as cpu and gpu are mutually exclusive.
|
||||
All entries are optional.
|
||||
|
||||
Returns:
|
||||
The `DeviceSpec`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the spec was not valid.
|
||||
"""
|
||||
self._clear()
|
||||
splits = [x.split(":") for x in spec.split("/")]
|
||||
for y in splits:
|
||||
ly = len(y)
|
||||
if y:
|
||||
if ly == 2 and y[0] == "job":
|
||||
self._set_job(y[1])
|
||||
elif ly == 2 and y[0] == "replica":
|
||||
self._set_replica(y[1])
|
||||
elif ly == 2 and y[0] == "task":
|
||||
self._set_task(y[1])
|
||||
elif ((ly == 1 or ly == 2) and
|
||||
((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
|
||||
if self.device_type is not None:
|
||||
raise ValueError("Cannot specify multiple device types: %s" % spec)
|
||||
self._set_device_type(y[0].upper())
|
||||
if ly == 2 and y[1] != "*":
|
||||
self._set_device_index(int(y[1]))
|
||||
elif ly == 3 and y[0] == "device":
|
||||
if self.device_type is not None:
|
||||
raise ValueError("Cannot specify multiple device types: %s" % spec)
|
||||
self._set_device_type(y[1])
|
||||
if y[2] != "*":
|
||||
self._set_device_index(int(y[2]))
|
||||
elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison
|
||||
raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))
|
||||
|
||||
self._sync()
|
||||
|
||||
return self
|
||||
|
||||
def merge_from(self, dev):
|
||||
"""Merge the properties of "dev" into this `DeviceSpec`.
|
||||
|
||||
Args:
|
||||
dev: a `DeviceSpec`.
|
||||
"""
|
||||
if dev.job is not None:
|
||||
self._set_job(dev.job)
|
||||
if dev.replica is not None:
|
||||
self._set_replica(dev.replica)
|
||||
if dev.task is not None:
|
||||
self._set_task(dev.task)
|
||||
if dev.device_type is not None:
|
||||
self._set_device_type(dev.device_type)
|
||||
if dev.device_index is not None:
|
||||
self._set_device_index(dev.device_index)
|
||||
|
||||
self._sync()
|
||||
|
||||
def _device_to_string(self):
|
||||
"""Private method that returns a string representation of `DeviceSpec`."""
|
||||
dev = ""
|
||||
if self.job is not None:
|
||||
dev += "/job:" + self.job
|
||||
if self.replica is not None:
|
||||
dev += "/replica:" + str(self.replica)
|
||||
if self.task is not None:
|
||||
dev += "/task:" + str(self.task)
|
||||
if self.device_type is not None:
|
||||
device_index_string = "*"
|
||||
if self.device_index is not None:
|
||||
device_index_string = str(self.device_index)
|
||||
dev += "/device:%s:%s" % (self.device_type, device_index_string)
|
||||
return dev
|
||||
|
||||
def to_string(self):
|
||||
"""Return a string representation of this `DeviceSpec`.
|
||||
|
||||
Returns:
|
||||
a string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
|
||||
"""
|
||||
return self._to_string
|
||||
|
||||
@staticmethod
|
||||
def from_string(spec):
|
||||
"""Construct a `DeviceSpec` from a string.
|
||||
|
||||
Args:
|
||||
spec: a string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
|
||||
or
|
||||
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
|
||||
as cpu and gpu are mutually exclusive.
|
||||
All entries are optional.
|
||||
|
||||
Returns:
|
||||
A DeviceSpec.
|
||||
"""
|
||||
return DeviceSpec().parse_from_string(spec)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.to_string() == other.to_string()
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
if tf2.enabled():
|
||||
DeviceSpec = device_spec.DeviceSpecV2
|
||||
else:
|
||||
DeviceSpec = device_spec.DeviceSpecV1
|
||||
|
||||
|
||||
def check_valid(spec):
|
||||
@ -302,24 +42,26 @@ def check_valid(spec):
|
||||
DeviceSpec.from_string(spec)
|
||||
|
||||
|
||||
def is_device_spec(obj):
|
||||
"""Abstract away the fact that DeviceSpecV2 is the base class."""
|
||||
return isinstance(obj, device_spec.DeviceSpecV2)
|
||||
|
||||
|
||||
def canonical_name(device):
|
||||
"""Returns a canonical name for the given `DeviceSpec` or device name."""
|
||||
if device is None:
|
||||
return ""
|
||||
if isinstance(device, DeviceSpec):
|
||||
if is_device_spec(device):
|
||||
return device.to_string()
|
||||
else:
|
||||
device = DeviceSpec.from_string(device)
|
||||
return device.to_string()
|
||||
|
||||
|
||||
# Cache from DeviceSpec objects to their corresponding device functions.
|
||||
# This cache is maintained for correctness, not performance: it makes it
|
||||
# possible to compare the device function stacks belonging to different
|
||||
# graphs in a meaningful way.
|
||||
_cached_device_functions = {}
|
||||
_cached_device_specs = {}
|
||||
_cache_lock = threading.Lock()
|
||||
# Performance caches
|
||||
_cached_mergers = {}
|
||||
_cache_lock = threading.RLock()
|
||||
_string_merge_cache = {}
|
||||
|
||||
|
||||
def merge_device(spec):
|
||||
@ -343,29 +85,99 @@ def merge_device(spec):
|
||||
the returned device function's with block.
|
||||
|
||||
Returns:
|
||||
A device function with the above-described behavior.
|
||||
A MergeDevice object with the above-described behavior.
|
||||
|
||||
Raises:
|
||||
ValueError: if the spec was not valid.
|
||||
"""
|
||||
|
||||
if isinstance(spec, MergeDevice):
|
||||
return spec
|
||||
|
||||
with _cache_lock:
|
||||
if not isinstance(spec, DeviceSpec):
|
||||
cached_device_spec = _cached_device_specs.get(spec, None)
|
||||
if cached_device_spec is None:
|
||||
device_spec = DeviceSpec.from_string(spec or "")
|
||||
_cached_device_specs[spec] = device_spec
|
||||
spec = device_spec
|
||||
else:
|
||||
spec = cached_device_spec
|
||||
cached_function = _cached_device_functions.get(spec, None)
|
||||
if cached_function is not None:
|
||||
return cached_function
|
||||
merger = _cached_mergers.get(spec)
|
||||
if merger:
|
||||
return merger
|
||||
|
||||
def _device_function(node_def):
|
||||
current_device = DeviceSpec.from_string(node_def.device or "")
|
||||
copy_spec = copy.copy(spec)
|
||||
copy_spec.merge_from(current_device) # current_device takes precedence.
|
||||
return copy_spec
|
||||
merger = MergeDevice(spec)
|
||||
_cached_mergers[spec] = merger
|
||||
return merger
|
||||
|
||||
_cached_device_functions[spec] = _device_function
|
||||
return _device_function
|
||||
|
||||
class MergeDevice(object):
|
||||
"""Wraps a device specification (DeviceSpec or str) with merge functionality.
|
||||
|
||||
When called, this class will merge a node_def with its own spec. It also
|
||||
exposes a `shortcut_string_merge` method which can significantly improve
|
||||
performance of device placement.
|
||||
"""
|
||||
|
||||
def __init__(self, spec):
|
||||
if isinstance(spec, device_spec.DeviceSpecV2):
|
||||
self._spec = spec
|
||||
elif isinstance(spec, device_spec.DeviceSpecV1):
|
||||
# Capture a snapshot of spec.
|
||||
self._spec = spec.__class__.from_string(spec.to_string())
|
||||
else:
|
||||
self._spec = DeviceSpec.from_string(spec)
|
||||
|
||||
def __call__(self, node_def):
|
||||
# In general a user may create a device function which takes into account
|
||||
# arbitrary properties of an op. (For instance dynamically placing ops based
|
||||
# on type.) So even though the standard DeviceSpec route only uses the
|
||||
# device attribute, we take an entire node_def to maintain a consistent
|
||||
# signature with general device functions.
|
||||
current_device = DeviceSpec.from_string(node_def.device or "")
|
||||
return self._spec.make_merged_spec(current_device)
|
||||
|
||||
def shortcut_string_merge(self, node_def):
|
||||
"""Merge a node def without materializing a full DeviceSpec object.
|
||||
|
||||
Often a device merge is invoked in order to generate a string which can be
|
||||
passed into the c api. In such a case, we can cache the
|
||||
node_def.device -> merge_result_string
|
||||
|
||||
map, and in most cases avoid:
|
||||
- Materializing a copy of self._spec (In the case of DeviceSpecV1)
|
||||
- Materializing a DeviceSpec for node_def.device
|
||||
- A DeviceSpec.merge_from invocation
|
||||
|
||||
In practice the cache hit rate for this function is very high, because the
|
||||
number of invocations when iterating through the device stack is much
|
||||
larger than the number of devices.
|
||||
|
||||
Args:
|
||||
node_def: An Operation (or Operation-like) to merge device constraints
|
||||
with self._spec
|
||||
|
||||
Returns:
|
||||
A string containing the merged device specification.
|
||||
"""
|
||||
device = node_def.device or ""
|
||||
|
||||
merge_key = (self._spec, device)
|
||||
result = _string_merge_cache.get(merge_key)
|
||||
if result is None:
|
||||
# This update is not atomic, however because the merge is stateless
|
||||
# we don't need to lock when updating the cache.
|
||||
result = self.__call__(node_def).to_string()
|
||||
_string_merge_cache[merge_key] = result
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return "{} (spec: {})".format(
|
||||
super(MergeDevice, self).__repr__(), self._spec.to_string())
|
||||
|
||||
@property
|
||||
def is_null_merge(self):
|
||||
"""Indicate whether the wrapped spec is empty.
|
||||
|
||||
In the degenerate case where self._spec is an empty specification, a caller
|
||||
may wish to skip a merge step entirely. (However this class does not have
|
||||
enough information to make that determination.)
|
||||
|
||||
Returns:
|
||||
A boolean indicating whether a device merge will be trivial.
|
||||
"""
|
||||
return not bool(self._spec.to_string())
|
||||
|
428
tensorflow/python/framework/device_spec.py
Normal file
428
tensorflow/python/framework/device_spec.py
Normal file
@ -0,0 +1,428 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Class to represent a device."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# == Global Implementation Details =============================================
|
||||
# ==============================================================================
|
||||
_STRING_TO_COMPONENTS_CACHE = {}
|
||||
_COMPONENTS_TO_STRING_CACHE = {}
|
||||
|
||||
|
||||
def _as_str_or_none(inp):
|
||||
return None if inp is None else str(inp)
|
||||
|
||||
|
||||
def _as_int_or_none(inp):
|
||||
return None if inp is None else int(inp)
|
||||
|
||||
|
||||
def _as_device_str_or_none(device_type):
|
||||
# For backwards compatibility only, we support lowercase variants of
|
||||
# cpu and gpu but turn them into uppercase here.
|
||||
if device_type in ("cpu", "gpu"):
|
||||
return device_type.upper()
|
||||
return _as_str_or_none(device_type)
|
||||
|
||||
|
||||
@tf_export("DeviceSpec", v1=[])
|
||||
class DeviceSpecV2(object):
|
||||
"""Represents a (possibly partial) specification for a TensorFlow device.
|
||||
|
||||
`DeviceSpec`s are used throughout TensorFlow to describe where state is stored
|
||||
and computations occur. Using `DeviceSpec` allows you to parse device spec
|
||||
strings to verify their validity, merge them or compose them programmatically.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Place the operations on device "GPU:0" in the "ps" job.
|
||||
device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
|
||||
with tf.device(device_spec):
|
||||
# Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
|
||||
my_var = tf.Variable(..., name="my_variable")
|
||||
squared_var = tf.square(my_var)
|
||||
```
|
||||
|
||||
If a `DeviceSpec` is partially specified, it will be merged with other
|
||||
`DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
|
||||
components defined in inner scopes take precedence over those defined in
|
||||
outer scopes.
|
||||
|
||||
```python
|
||||
with tf.device(DeviceSpec(job="train", )):
|
||||
with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0):
|
||||
# Nodes created here will be assigned to /job:ps/device:GPU:0.
|
||||
with tf.device(DeviceSpec(device_type="GPU", device_index=1):
|
||||
# Nodes created here will be assigned to /job:train/device:GPU:1.
|
||||
```
|
||||
|
||||
A `DeviceSpec` consists of 5 components -- each of
|
||||
which is optionally specified:
|
||||
|
||||
* Job: The job name.
|
||||
* Replica: The replica index.
|
||||
* Task: The task index.
|
||||
* Device type: The device type string (e.g. "CPU" or "GPU").
|
||||
* Device index: The device index.
|
||||
"""
|
||||
|
||||
__slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
|
||||
"_as_string", "_hash")
|
||||
|
||||
def __init__(self, job=None, replica=None, task=None, device_type=None,
|
||||
device_index=None):
|
||||
"""Create a new `DeviceSpec` object.
|
||||
|
||||
Args:
|
||||
job: string. Optional job name.
|
||||
replica: int. Optional replica index.
|
||||
task: int. Optional task index.
|
||||
device_type: Optional device type string (e.g. "CPU" or "GPU")
|
||||
device_index: int. Optional device index. If left
|
||||
unspecified, device represents 'any' device_index.
|
||||
"""
|
||||
self._job = _as_str_or_none(job)
|
||||
self._replica = _as_int_or_none(replica)
|
||||
self._task = _as_int_or_none(task)
|
||||
self._device_type = _as_device_str_or_none(device_type)
|
||||
self._device_index = _as_int_or_none(device_index)
|
||||
self._as_string = self._components_to_string(
|
||||
job=self._job, replica=self._replica, task=self._task,
|
||||
device_type=self._device_type, device_index=self._device_index)
|
||||
self._hash = hash(self.to_string())
|
||||
|
||||
def to_string(self):
|
||||
"""Return a string representation of this `DeviceSpec`.
|
||||
|
||||
Returns:
|
||||
a string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
|
||||
"""
|
||||
return self._as_string
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, spec):
|
||||
"""Construct a `DeviceSpec` from a string.
|
||||
|
||||
Args:
|
||||
spec: a string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
|
||||
or
|
||||
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
|
||||
as cpu and gpu are mutually exclusive.
|
||||
All entries are optional.
|
||||
|
||||
Returns:
|
||||
A DeviceSpec.
|
||||
"""
|
||||
return cls(*cls._string_to_components(spec))
|
||||
|
||||
def parse_from_string(self, spec):
|
||||
"""Parse a `DeviceSpec` name into its components.
|
||||
|
||||
2.x behavior change:
|
||||
In TensorFlow 1.x, this function mutates its own state and returns itself.
|
||||
In 2.x, DeviceSpecs are immutable, and this function will return a
|
||||
DeviceSpec which contains the spec.
|
||||
|
||||
Recommended:
|
||||
```
|
||||
# my_spec and my_updated_spec are unrelated.
|
||||
my_spec = tf.DeviceSpec.from_string("/CPU:0")
|
||||
my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
|
||||
with tf.device(my_updated_spec):
|
||||
...
|
||||
```
|
||||
|
||||
Will work in 1.x and 2.x (though deprecated in 2.x):
|
||||
```
|
||||
my_spec = tf.DeviceSpec.from_string("/CPU:0")
|
||||
my_updated_spec = my_spec.parse_from_string("/GPU:0")
|
||||
with tf.device(my_updated_spec):
|
||||
...
|
||||
```
|
||||
|
||||
Will NOT work in 2.x:
|
||||
```
|
||||
my_spec = tf.DeviceSpec.from_string("/CPU:0")
|
||||
my_spec.parse_from_string("/GPU:0") # <== Will not update my_spec
|
||||
with tf.device(my_spec):
|
||||
...
|
||||
```
|
||||
|
||||
In general, `DeviceSpec.from_string` should completely replace
|
||||
`DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
|
||||
completely replace setting attributes directly.
|
||||
|
||||
Args:
|
||||
spec: an optional string of the form
|
||||
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
|
||||
or
|
||||
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
|
||||
as cpu and gpu are mutually exclusive.
|
||||
All entries are optional.
|
||||
|
||||
Returns:
|
||||
The `DeviceSpec`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the spec was not valid.
|
||||
"""
|
||||
return self.from_string(spec)
|
||||
|
||||
def make_merged_spec(self, dev):
|
||||
"""Returns a new DeviceSpec which incorporates `dev`.
|
||||
|
||||
When combining specs, `dev` will take precidence over the current spec.
|
||||
So for instance:
|
||||
```
|
||||
first_spec = tf.DeviceSpec(job=0, device_type="CPU")
|
||||
second_spec = tf.DeviceSpec(device_type="GPU")
|
||||
combined_spec = first_spec.make_merged_spec(second_spec)
|
||||
```
|
||||
|
||||
is equivalent to:
|
||||
```
|
||||
combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
|
||||
```
|
||||
|
||||
Args:
|
||||
dev: a `DeviceSpec`
|
||||
|
||||
Returns:
|
||||
A new `DeviceSpec` which combines `self` and `dev`
|
||||
"""
|
||||
return self.__class__(*self._get_combined_properties(dev))
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Convenience method for making a new DeviceSpec by overriding fields.
|
||||
|
||||
For instance:
|
||||
```
|
||||
my_spec = DeviceSpec=(job="my_job", device="CPU")
|
||||
my_updated_spec = my_spec.replace(device="GPU")
|
||||
my_other_spec = my_spec.replace(device=None)
|
||||
```
|
||||
|
||||
Args:
|
||||
**kwargs: This method takes the same args as the DeviceSpec constructor
|
||||
|
||||
Returns:
|
||||
A DeviceSpec with the fields specified in kwargs overridden.
|
||||
"""
|
||||
init_kwargs = dict(
|
||||
job=self.job, replica=self.replica, task=self.task,
|
||||
device_type=self.device_type, device_index=self.device_index)
|
||||
|
||||
# Explicitly provided kwargs take precidence.
|
||||
init_kwargs.update(kwargs)
|
||||
return self.__class__(**init_kwargs)
|
||||
|
||||
@property
|
||||
def job(self):
|
||||
return self._job
|
||||
|
||||
@property
|
||||
def replica(self):
|
||||
return self._replica
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
return self._device_type
|
||||
|
||||
@property
|
||||
def device_index(self):
|
||||
return self._device_index
|
||||
|
||||
def _get_combined_properties(self, dev):
|
||||
"""Combine the current DeviceSpec with another DeviceSpec.
|
||||
|
||||
The combination of DeviceSpecs is will give priority to dev.
|
||||
|
||||
Args:
|
||||
dev: a `DeviceSpec`
|
||||
|
||||
Returns:
|
||||
A tuple of (job, replica, task, device_type, device_index) which
|
||||
represents the combination of self and dev.
|
||||
"""
|
||||
return (
|
||||
dev.job if dev.job is not None else self.job,
|
||||
dev.replica if dev.replica is not None else self.replica,
|
||||
dev.task if dev.task is not None else self.task,
|
||||
dev.device_type if dev.device_type is not None else self.device_type,
|
||||
dev.device_index if dev.device_index is not None else self.device_index,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _string_to_components(spec=None):
|
||||
"""Stateless portion of device spec string parsing.
|
||||
|
||||
Args:
|
||||
spec: An optional string specifying a device specification.
|
||||
|
||||
Returns:
|
||||
The parsed components of `spec`. Note that the result of this function
|
||||
must go through attribute setters of DeviceSpec, and should therefore NOT
|
||||
be used directly.
|
||||
"""
|
||||
cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
raw_spec = spec # keep a copy of the original to update the cache
|
||||
job, replica, task, device_type, device_index = None, None, None, None, None
|
||||
|
||||
spec = spec or ""
|
||||
splits = [x.split(":") for x in spec.split("/")]
|
||||
for y in splits:
|
||||
ly = len(y)
|
||||
if y:
|
||||
# NOTE(taylorrobie): these will go through setters later.
|
||||
if ly == 2 and y[0] == "job":
|
||||
job = y[1]
|
||||
elif ly == 2 and y[0] == "replica":
|
||||
replica = y[1]
|
||||
elif ly == 2 and y[0] == "task":
|
||||
task = y[1]
|
||||
elif ((ly == 1 or ly == 2) and
|
||||
((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
|
||||
if device_type is not None:
|
||||
raise ValueError("Cannot specify multiple device types: %s" % spec)
|
||||
device_type = y[0].upper()
|
||||
if ly == 2 and y[1] != "*":
|
||||
device_index = int(y[1])
|
||||
elif ly == 3 and y[0] == "device":
|
||||
if device_type is not None:
|
||||
raise ValueError("Cannot specify multiple device types: %s" % spec)
|
||||
device_type = y[1]
|
||||
if y[2] != "*":
|
||||
device_index = int(y[2])
|
||||
elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison
|
||||
raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))
|
||||
|
||||
output = (job, replica, task, device_type, device_index)
|
||||
_STRING_TO_COMPONENTS_CACHE[raw_spec] = output
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _components_to_string(job, replica, task, device_type, device_index):
|
||||
"""Stateless portion of `to_string` (separated to allow caching)."""
|
||||
key = (job, replica, task, device_type, device_index)
|
||||
cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
output = []
|
||||
if job is not None:
|
||||
output.append("/job:" + job)
|
||||
if replica is not None:
|
||||
output.append("/replica:" + str(replica))
|
||||
if task is not None:
|
||||
output.append("/task:" + str(task))
|
||||
if device_type is not None:
|
||||
device_index_string = "*"
|
||||
if device_index is not None:
|
||||
# Unlike the others, device_index is stored as an int.
|
||||
device_index_string = str(device_index)
|
||||
output.append("/device:%s:%s" % (device_type, device_index_string))
|
||||
|
||||
output = "".join(output)
|
||||
_COMPONENTS_TO_STRING_CACHE[key] = output
|
||||
return output
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.to_string() == other.to_string())
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
|
||||
|
||||
@tf_export(v1=["DeviceSpec"]) # pylint: disable=missing-docstring
|
||||
class DeviceSpecV1(DeviceSpecV2):
|
||||
__doc__ = DeviceSpecV2.__doc__
|
||||
__slots__ = DeviceSpecV2.__slots__
|
||||
|
||||
@DeviceSpecV2.job.setter
|
||||
def job(self, job):
|
||||
self._job = _as_str_or_none(job)
|
||||
self._as_string, self._hash = None, None
|
||||
|
||||
@DeviceSpecV2.replica.setter
|
||||
def replica(self, replica):
|
||||
self._replica = _as_int_or_none(replica)
|
||||
self._as_string, self._hash = None, None
|
||||
|
||||
@DeviceSpecV2.task.setter
|
||||
def task(self, task):
|
||||
self._task = _as_int_or_none(task)
|
||||
self._as_string, self._hash = None, None
|
||||
|
||||
@DeviceSpecV2.device_type.setter
|
||||
def device_type(self, device_type):
|
||||
self._device_type = _as_device_str_or_none(device_type)
|
||||
self._as_string, self._hash = None, None
|
||||
|
||||
@DeviceSpecV2.device_index.setter
|
||||
def device_index(self, device_index):
|
||||
self._device_index = _as_int_or_none(device_index)
|
||||
self._as_string, self._hash = None, None
|
||||
|
||||
def __hash__(self):
|
||||
if self._hash is None:
|
||||
self._hash = hash(self.to_string())
|
||||
return self._hash
|
||||
|
||||
def to_string(self):
|
||||
if self._as_string is None:
|
||||
self._as_string = self._components_to_string(
|
||||
job=self.job, replica=self.replica, task=self.task,
|
||||
device_type=self.device_type, device_index=self.device_index)
|
||||
return self._as_string
|
||||
|
||||
def parse_from_string(self, spec):
|
||||
(self.job, self.replica, self.task, self.device_type, self.device_index
|
||||
) = self._string_to_components(spec)
|
||||
|
||||
return self
|
||||
|
||||
def merge_from(self, dev):
|
||||
"""Merge the properties of "dev" into this `DeviceSpec`.
|
||||
|
||||
Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
|
||||
immutable.
|
||||
|
||||
Args:
|
||||
dev: a `DeviceSpec`.
|
||||
"""
|
||||
(self.job, self.replica, self.task, self.device_type, self.device_index
|
||||
) = self._get_combined_properties(dev)
|
||||
|
||||
# Use parent class docstrings for public methods.
|
||||
to_string.__doc__ = DeviceSpecV2.to_string.__doc__
|
||||
parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__
|
229
tensorflow/python/framework/device_spec_test.py
Normal file
229
tensorflow/python/framework/device_spec_test.py
Normal file
@ -0,0 +1,229 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework.device_spec."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import device_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1),
|
||||
("v2", device_spec.DeviceSpecV2))
|
||||
|
||||
|
||||
class DeviceSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def test_empty(self, device_spec_type):
|
||||
d = device_spec_type()
|
||||
self.assertEqual("", d.to_string())
|
||||
d.parse_from_string("")
|
||||
self.assertEqual("", d.to_string())
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def test_constructor(self, device_spec_type):
|
||||
d = device_spec_type(job="j", replica=0, task=1,
|
||||
device_type="CPU", device_index=2)
|
||||
self.assertEqual("j", d.job)
|
||||
self.assertEqual(0, d.replica)
|
||||
self.assertEqual(1, d.task)
|
||||
self.assertEqual("CPU", d.device_type)
|
||||
self.assertEqual(2, d.device_index)
|
||||
self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())
|
||||
|
||||
d = device_spec_type(device_type="GPU", device_index=0)
|
||||
self.assertEqual("/device:GPU:0", d.to_string())
|
||||
|
||||
def testto_string_legacy(self):
|
||||
"""DeviceSpecV1 allows direct mutation."""
|
||||
d = device_spec.DeviceSpecV1()
|
||||
d.job = "foo"
|
||||
self.assertEqual("/job:foo", d.to_string())
|
||||
d.task = 3
|
||||
self.assertEqual("/job:foo/task:3", d.to_string())
|
||||
d.device_type = "CPU"
|
||||
d.device_index = 0
|
||||
self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
|
||||
d.task = None
|
||||
d.replica = 12
|
||||
self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
|
||||
d.device_type = "GPU"
|
||||
d.device_index = 2
|
||||
self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
|
||||
d.device_type = "CPU"
|
||||
d.device_index = 1
|
||||
self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())
|
||||
d.device_type = None
|
||||
d.device_index = None
|
||||
self.assertEqual("/job:foo/replica:12", d.to_string())
|
||||
|
||||
# Test wildcard
|
||||
d = device_spec.DeviceSpecV1(job="foo", replica=12, task=3,
|
||||
device_type="GPU")
|
||||
self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def test_replace(self, device_spec_type):
|
||||
d = device_spec_type()
|
||||
d = d.replace(job="foo")
|
||||
self.assertEqual("/job:foo", d.to_string())
|
||||
|
||||
d = d.replace(task=3)
|
||||
self.assertEqual("/job:foo/task:3", d.to_string())
|
||||
|
||||
d = d.replace(device_type="CPU", device_index=0)
|
||||
self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
|
||||
|
||||
d = d.replace(task=None, replica=12)
|
||||
self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
|
||||
|
||||
d = d.replace(device_type="GPU", device_index=2)
|
||||
self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
|
||||
|
||||
d = d.replace(device_type="CPU", device_index=1)
|
||||
self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())
|
||||
|
||||
d = d.replace(device_type=None, device_index=None)
|
||||
self.assertEqual("/job:foo/replica:12", d.to_string())
|
||||
|
||||
# Test wildcard
|
||||
d = device_spec.DeviceSpecV1(job="foo", replica=12, task=3,
|
||||
device_type="GPU")
|
||||
self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def testto_string(self, device_spec_type):
|
||||
d = device_spec_type(job="foo")
|
||||
self.assertEqual("/job:foo", d.to_string())
|
||||
|
||||
d = device_spec_type(job="foo", task=3)
|
||||
self.assertEqual("/job:foo/task:3", d.to_string())
|
||||
|
||||
d = device_spec_type(job="foo", task=3, device_type="cpu", device_index=0)
|
||||
self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
|
||||
|
||||
d = device_spec_type(job="foo", replica=12, device_type="cpu",
|
||||
device_index=0)
|
||||
self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
|
||||
|
||||
d = device_spec_type(job="foo", replica=12, device_type="gpu",
|
||||
device_index=2)
|
||||
self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
|
||||
|
||||
d = device_spec_type(job="foo", replica=12)
|
||||
self.assertEqual("/job:foo/replica:12", d.to_string())
|
||||
|
||||
# Test wildcard
|
||||
d = device_spec_type(job="foo", replica=12, task=3, device_type="GPU")
|
||||
self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
|
||||
|
||||
def test_parse_legacy(self):
|
||||
d = device_spec.DeviceSpecV1()
|
||||
d.parse_from_string("/job:foo/replica:0")
|
||||
self.assertEqual("/job:foo/replica:0", d.to_string())
|
||||
|
||||
d.parse_from_string("/replica:1/task:0/cpu:0")
|
||||
self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
|
||||
d.parse_from_string("/replica:1/task:0/device:CPU:0")
|
||||
self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
|
||||
d.parse_from_string("/job:muu/device:GPU:2")
|
||||
self.assertEqual("/job:muu/device:GPU:2", d.to_string())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Cannot specify multiple"):
|
||||
d.parse_from_string("/job:muu/device:GPU:2/cpu:0")
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def test_to_from_string(self, device_spec_type):
|
||||
d = device_spec_type.from_string("/job:foo/replica:0")
|
||||
self.assertEqual("/job:foo/replica:0", d.to_string())
|
||||
self.assertEqual(0, d.replica)
|
||||
|
||||
d = device_spec_type.from_string("/replica:1/task:0/cpu:0")
|
||||
self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
self.assertAllEqual([1, 0, "CPU", 0],
|
||||
[d.replica, d.task, d.device_type, d.device_index])
|
||||
|
||||
d = device_spec_type.from_string("/replica:1/task:0/device:CPU:0")
|
||||
self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
self.assertAllEqual([1, 0, "CPU", 0],
|
||||
[d.replica, d.task, d.device_type, d.device_index])
|
||||
|
||||
d = device_spec_type.from_string("/job:muu/device:GPU:2")
|
||||
self.assertEqual("/job:muu/device:GPU:2", d.to_string())
|
||||
self.assertAllEqual(["muu", "GPU", 2],
|
||||
[d.job, d.device_type, d.device_index])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Cannot specify multiple"):
|
||||
d.parse_from_string("/job:muu/device:GPU:2/cpu:0")
|
||||
|
||||
def test_merge_legacy(self):
|
||||
d = device_spec.DeviceSpecV1.from_string("/job:foo/replica:0")
|
||||
self.assertEqual("/job:foo/replica:0", d.to_string())
|
||||
|
||||
d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/device:GPU:2"))
|
||||
self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
|
||||
|
||||
d = device_spec.DeviceSpecV1()
|
||||
d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/cpu:0"))
|
||||
self.assertEqual("/task:1/device:CPU:0", d.to_string())
|
||||
|
||||
d.merge_from(device_spec.DeviceSpecV1.from_string("/job:boo/device:GPU:0"))
|
||||
self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
|
||||
|
||||
d.merge_from(device_spec.DeviceSpecV1.from_string("/job:muu/cpu:2"))
|
||||
self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
|
||||
d.merge_from(device_spec.DeviceSpecV1.from_string(
|
||||
"/job:muu/device:MyFunnyDevice:2"))
|
||||
self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
|
||||
|
||||
def test_merge_removed(self):
|
||||
with self.assertRaises(AttributeError):
|
||||
d = device_spec.DeviceSpecV2()
|
||||
d.merge_from(device_spec.DeviceSpecV2.from_string("/task:1/cpu:0"))
|
||||
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def test_combine(self, device_spec_type):
|
||||
d = device_spec_type.from_string("/job:foo/replica:0")
|
||||
self.assertEqual("/job:foo/replica:0", d.to_string())
|
||||
|
||||
d = d.make_merged_spec(
|
||||
device_spec_type.from_string("/task:1/device:GPU:2"))
|
||||
self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
|
||||
|
||||
d = device_spec_type()
|
||||
d = d.make_merged_spec(device_spec_type.from_string("/task:1/cpu:0"))
|
||||
self.assertEqual("/task:1/device:CPU:0", d.to_string())
|
||||
|
||||
d = d.make_merged_spec(
|
||||
device_spec_type.from_string("/job:boo/device:GPU:0"))
|
||||
self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
|
||||
|
||||
d = d.make_merged_spec(device_spec_type.from_string("/job:muu/cpu:2"))
|
||||
self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
|
||||
d = d.make_merged_spec(device_spec_type.from_string(
|
||||
"/job:muu/device:MyFunnyDevice:2"))
|
||||
self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -18,120 +18,41 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import device_spec
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class DeviceTest(test_util.TensorFlowTestCase):
|
||||
TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1),
|
||||
("v2", device_spec.DeviceSpecV2))
|
||||
|
||||
def testEmpty(self):
|
||||
d = device.DeviceSpec()
|
||||
self.assertEquals("", d.to_string())
|
||||
d.parse_from_string("")
|
||||
self.assertEquals("", d.to_string())
|
||||
|
||||
def testConstructor(self):
|
||||
d = device.DeviceSpec(job="j", replica=0, task=1,
|
||||
device_type="CPU", device_index=2)
|
||||
self.assertEqual("j", d.job)
|
||||
self.assertEqual(0, d.replica)
|
||||
self.assertEqual(1, d.task)
|
||||
self.assertEqual("CPU", d.device_type)
|
||||
self.assertEqual(2, d.device_index)
|
||||
self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())
|
||||
class DeviceTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
d = device.DeviceSpec(device_type="GPU", device_index=0)
|
||||
self.assertEquals("/device:GPU:0", d.to_string())
|
||||
|
||||
def testto_string(self):
|
||||
d = device.DeviceSpec()
|
||||
d.job = "foo"
|
||||
self.assertEquals("/job:foo", d.to_string())
|
||||
d.task = 3
|
||||
self.assertEquals("/job:foo/task:3", d.to_string())
|
||||
d.device_type = "CPU"
|
||||
d.device_index = 0
|
||||
self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string())
|
||||
d.task = None
|
||||
d.replica = 12
|
||||
self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string())
|
||||
d.device_type = "GPU"
|
||||
d.device_index = 2
|
||||
self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string())
|
||||
d.device_type = "CPU"
|
||||
d.device_index = 1
|
||||
self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string())
|
||||
d.device_type = None
|
||||
d.device_index = None
|
||||
d.cpu = None
|
||||
self.assertEquals("/job:foo/replica:12", d.to_string())
|
||||
|
||||
# Test wildcard
|
||||
d = device.DeviceSpec(job="foo", replica=12, task=3, device_type="GPU")
|
||||
self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
|
||||
|
||||
def testParse(self):
|
||||
d = device.DeviceSpec()
|
||||
d.parse_from_string("/job:foo/replica:0")
|
||||
self.assertEquals("/job:foo/replica:0", d.to_string())
|
||||
d.parse_from_string("/replica:1/task:0/cpu:0")
|
||||
self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
d.parse_from_string("/replica:1/task:0/device:CPU:0")
|
||||
self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
|
||||
d.parse_from_string("/job:muu/device:GPU:2")
|
||||
self.assertEquals("/job:muu/device:GPU:2", d.to_string())
|
||||
with self.assertRaises(Exception) as e:
|
||||
d.parse_from_string("/job:muu/device:GPU:2/cpu:0")
|
||||
self.assertTrue("Cannot specify multiple device" in str(e.exception))
|
||||
|
||||
def testFromString(self):
|
||||
d = device.DeviceSpec.from_string("/job:foo/replica:0")
|
||||
self.assertEquals("/job:foo/replica:0", d.to_string())
|
||||
with self.assertRaises(Exception) as e:
|
||||
d = device.DeviceSpec.from_string("/job:muu/device:GPU:2/cpu:0")
|
||||
self.assertTrue("Cannot specify multiple device" in str(e.exception))
|
||||
|
||||
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/cpu:*")
|
||||
self.assertEquals(None, d.device_index)
|
||||
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/gpu:7")
|
||||
self.assertEquals(7, d.device_index)
|
||||
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/device:GPU:7")
|
||||
self.assertEquals(7, d.device_index)
|
||||
|
||||
def testMerge(self):
|
||||
d = device.DeviceSpec.from_string("/job:foo/replica:0")
|
||||
self.assertEquals("/job:foo/replica:0", d.to_string())
|
||||
d.merge_from(device.DeviceSpec.from_string("/task:1/device:GPU:2"))
|
||||
self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
|
||||
|
||||
d = device.DeviceSpec()
|
||||
d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0"))
|
||||
self.assertEquals("/task:1/device:CPU:0", d.to_string())
|
||||
d.merge_from(device.DeviceSpec.from_string("/job:boo/device:GPU:0"))
|
||||
self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string())
|
||||
d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2"))
|
||||
self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string())
|
||||
d.merge_from(device.DeviceSpec.from_string(
|
||||
"/job:muu/device:MyFunnyDevice:2"))
|
||||
self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
|
||||
@parameterized.named_parameters(*TEST_V1_AND_V2)
|
||||
def testMerge(self, DeviceSpec): # pylint: disable=invalid-name
|
||||
d = DeviceSpec.from_string("/job:muu/task:1/device:MyFunnyDevice:2")
|
||||
self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with ops.device(device.merge_device("/device:GPU:0")):
|
||||
var1 = variables.Variable(1.0)
|
||||
self.assertEquals("/device:GPU:0", var1.device)
|
||||
self.assertEqual("/device:GPU:0", var1.device)
|
||||
with ops.device(device.merge_device("/job:worker")):
|
||||
var2 = variables.Variable(1.0)
|
||||
self.assertEquals("/job:worker/device:GPU:0", var2.device)
|
||||
self.assertEqual("/job:worker/device:GPU:0", var2.device)
|
||||
with ops.device(device.merge_device("/device:CPU:0")):
|
||||
var3 = variables.Variable(1.0)
|
||||
self.assertEquals("/job:worker/device:CPU:0", var3.device)
|
||||
self.assertEqual("/job:worker/device:CPU:0", var3.device)
|
||||
with ops.device(device.merge_device("/job:ps")):
|
||||
var4 = variables.Variable(1.0)
|
||||
self.assertEquals("/job:ps/device:CPU:0", var4.device)
|
||||
self.assertEqual("/job:ps/device:CPU:0", var4.device)
|
||||
|
||||
def testCanonicalName(self):
|
||||
self.assertEqual("/job:foo/replica:0",
|
||||
@ -159,21 +80,17 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
def testCheckValid(self):
|
||||
device.check_valid("/job:foo/replica:0")
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
with self.assertRaisesRegexp(ValueError, "invalid literal for int"):
|
||||
device.check_valid("/job:j/replica:foo")
|
||||
self.assertTrue("invalid literal for int" in str(e.exception))
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
with self.assertRaisesRegexp(ValueError, "invalid literal for int"):
|
||||
device.check_valid("/job:j/task:bar")
|
||||
self.assertTrue("invalid literal for int" in str(e.exception))
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
with self.assertRaisesRegexp(ValueError, "Unknown attribute: 'bar'"):
|
||||
device.check_valid("/bar:muu/baz:2")
|
||||
self.assertTrue("Unknown attribute: 'bar'" in str(e.exception))
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
with self.assertRaisesRegexp(ValueError, "Cannot specify multiple device"):
|
||||
device.check_valid("/cpu:0/device:GPU:2")
|
||||
self.assertTrue("Cannot specify multiple device" in str(e.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -81,9 +81,15 @@ class _UserDeviceSpec(object):
|
||||
|
||||
def __init__(self, device_name_or_function):
|
||||
self._device_name_or_function = device_name_or_function
|
||||
|
||||
self.display_name = str(self._device_name_or_function)
|
||||
if callable(self._device_name_or_function):
|
||||
self.function = device_name_or_function
|
||||
self.raw_string = None
|
||||
|
||||
if isinstance(device_name_or_function, pydev.MergeDevice):
|
||||
self.is_null_merge = device_name_or_function.is_null_merge
|
||||
|
||||
elif callable(device_name_or_function):
|
||||
self.is_null_merge = False
|
||||
dev_func = self._device_name_or_function
|
||||
func_name = function_utils.get_func_name(dev_func)
|
||||
func_code = function_utils.get_func_code(dev_func)
|
||||
@ -95,13 +101,26 @@ class _UserDeviceSpec(object):
|
||||
lineno = -1
|
||||
self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
|
||||
|
||||
self.raw_string = None
|
||||
elif device_name_or_function is None:
|
||||
# NOTE(taylorrobie): This MUST be False. None signals a break in the
|
||||
# device stack, so `is_null_merge` must be False for such a case to
|
||||
# allow callers to safely skip over null merges without missing a None.
|
||||
self.is_null_merge = False
|
||||
|
||||
self.function = self._device_name_or_function
|
||||
if not (self._device_name_or_function is None or
|
||||
callable(self._device_name_or_function)):
|
||||
self.raw_string = self._device_name_or_function
|
||||
self.function = pydev.merge_device(self._device_name_or_function)
|
||||
else:
|
||||
self.raw_string = device_name_or_function
|
||||
self.function = pydev.merge_device(device_name_or_function)
|
||||
self.is_null_merge = self.function.is_null_merge
|
||||
|
||||
# We perform this check in __init__ because it is of non-trivial cost,
|
||||
# and self.string_merge is typically called many times.
|
||||
self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)
|
||||
|
||||
def string_merge(self, node_def):
|
||||
if self.fast_string_merge:
|
||||
return self.function.shortcut_string_merge(node_def)
|
||||
|
||||
return compat.as_str(_device_string(self.function(node_def)))
|
||||
|
||||
|
||||
class NullContextmanager(object):
|
||||
@ -617,7 +636,9 @@ class Tensor(_TensorLike):
|
||||
|
||||
def __eq__(self, other):
|
||||
# Necessary to support Python's collection membership operators
|
||||
return id(self) == id(other)
|
||||
|
||||
# NOTE(taylorrobie): equivalent to: id(self) == id(other)
|
||||
return self is other
|
||||
|
||||
def __copy__(self):
|
||||
# TODO(b/77597810): get rid of Tensor copies.
|
||||
@ -1743,7 +1764,7 @@ IndexedSlicesValue = collections.namedtuple(
|
||||
|
||||
|
||||
def _device_string(dev_spec):
|
||||
if isinstance(dev_spec, pydev.DeviceSpec):
|
||||
if pydev.is_device_spec(dev_spec):
|
||||
return dev_spec.to_string()
|
||||
else:
|
||||
return dev_spec
|
||||
@ -2224,10 +2245,22 @@ class Operation(object):
|
||||
Args:
|
||||
device: string or device.. The device to set.
|
||||
"""
|
||||
self._set_device_from_string(compat.as_str(_device_string(device)))
|
||||
|
||||
def _set_device_from_string(self, device_str):
|
||||
"""Fast path to set device if the type is known to be a string.
|
||||
|
||||
This function is called frequently enough during graph construction that
|
||||
there are non-trivial performance gains if the caller can guarantee that
|
||||
the specified device is already a string.
|
||||
|
||||
Args:
|
||||
device_str: A string specifying where to place this op.
|
||||
"""
|
||||
c_api.SetRequestedDevice(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
self._c_op, # pylint: disable=protected-access
|
||||
compat.as_str(_device_string(device)))
|
||||
device_str)
|
||||
|
||||
def _update_input(self, index, tensor):
|
||||
"""Update the input to this operation at the given index.
|
||||
@ -4493,10 +4526,21 @@ class Graph(object):
|
||||
# We apply here because the result can depend on the Operation's
|
||||
# signature, which is computed in the Operation constructor.
|
||||
# pylint: disable=protected-access
|
||||
prior_device_string = None
|
||||
for device_spec in self._device_function_stack.peek_objs():
|
||||
if device_spec.is_null_merge:
|
||||
continue
|
||||
|
||||
if device_spec.function is None:
|
||||
break
|
||||
op._set_device(device_spec.function(op))
|
||||
|
||||
device_string = device_spec.string_merge(op)
|
||||
|
||||
# Take advantage of the fact that None is a singleton and Python interns
|
||||
# strings, since identity checks are faster than equality checks.
|
||||
if device_string is not prior_device_string:
|
||||
op._set_device_from_string(device_string)
|
||||
prior_device_string = device_string
|
||||
op._device_code_locations = self._snapshot_device_function_stack_metadata()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as tfdev
|
||||
from tensorflow.python.framework import dtypes as dtypes_module
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
@ -532,8 +533,13 @@ class _TfDeviceCaptureOp(object):
|
||||
|
||||
def _set_device(self, device):
|
||||
"""This method captures TF's explicit device scope setting."""
|
||||
if tfdev.is_device_spec(device):
|
||||
device = device.to_string()
|
||||
self.device = device
|
||||
|
||||
def _set_device_from_string(self, device_str):
|
||||
self.device = device_str
|
||||
|
||||
|
||||
def _get_current_tf_device():
|
||||
"""Return explicit device of current context, otherwise returns `None`.
|
||||
@ -546,7 +552,7 @@ def _get_current_tf_device():
|
||||
graph = get_graph()
|
||||
op = _TfDeviceCaptureOp()
|
||||
graph._apply_device_functions(op)
|
||||
return op.device
|
||||
return tfdev.DeviceSpec.from_string(op.device)
|
||||
|
||||
|
||||
def _is_current_explicit_device(device_type):
|
||||
|
@ -291,6 +291,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
else:
|
||||
self._device = device
|
||||
|
||||
def _set_device_from_string(self, device_str):
|
||||
self._device = device_str
|
||||
|
||||
if self._outside_compilation_cluster:
|
||||
raise NotImplementedError("Cannot nest outside_compilation clusters")
|
||||
if cluster:
|
||||
|
@ -120,13 +120,13 @@ class _ReplicaDeviceChooser(object):
|
||||
|
||||
current_job, ps_job = current_device.job, ps_device.job
|
||||
if ps_job and (not current_job or current_job == ps_job):
|
||||
ps_device.task = self._ps_strategy(op)
|
||||
ps_device = ps_device.replace(task=self._ps_strategy(op))
|
||||
|
||||
ps_device.merge_from(current_device)
|
||||
ps_device = ps_device.make_merged_spec(current_device)
|
||||
return ps_device.to_string()
|
||||
|
||||
worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
|
||||
worker_device.merge_from(current_device)
|
||||
worker_device = worker_device.make_merged_spec(current_device)
|
||||
return worker_device.to_string()
|
||||
|
||||
|
||||
|
@ -50,8 +50,7 @@ def set_cpu0(device_string):
|
||||
A device string.
|
||||
"""
|
||||
parsed_device = pydev.DeviceSpec.from_string(device_string)
|
||||
parsed_device.device_type = "CPU"
|
||||
parsed_device.device_index = 0
|
||||
parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
|
||||
return parsed_device.to_string()
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.DeviceSpec"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.device.DeviceSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "device_index"
|
||||
@ -28,7 +29,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_string"
|
||||
argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_merged_spec"
|
||||
argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "merge_from"
|
||||
@ -38,6 +43,10 @@ tf_class {
|
||||
name: "parse_from_string"
|
||||
argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "replace"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "to_string"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
49
tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt
Normal file
49
tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt
Normal file
@ -0,0 +1,49 @@
|
||||
path: "tensorflow.DeviceSpec"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "device_index"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "device_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "job"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "replica"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_string"
|
||||
argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_merged_spec"
|
||||
argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "parse_from_string"
|
||||
argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "replace"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "to_string"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -12,6 +12,10 @@ tf_module {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DeviceSpec"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "GradientTape"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user