Implement several optimizations to reduce graph construction time.

In approximately decreasing order of significance:

1) Cache various to_string, from_string, and string to string functionality in device.py.

2) Optimize DeviceSpec.to_string to reduce unnecessary string copies.

3) Skip no-op device assignments when creating ops (when possible).

4) Remove hash caching in DeviceSpec (since it can now be computed much more cheaply) which allows less aggressive locking.

5) Miscellaneous fine-tuning of high-traffic functions (those called millions of times).

PiperOrigin-RevId: 242996847
This commit is contained in:
Taylor Robie 2019-04-10 20:53:31 -07:00 committed by TensorFlower Gardener
parent c39738021b
commit 8d24f6ae5c
18 changed files with 955 additions and 426 deletions

View File

@ -628,6 +628,7 @@ py_library(
deps = [
":constant_op",
":device",
":device_spec",
":dtypes",
":framework_ops",
":function",
@ -649,6 +650,7 @@ py_library(
deps = [
":constant_op",
":device",
":device_spec",
":dtypes",
":framework_ops",
":op_def_library",
@ -747,6 +749,15 @@ py_library(
],
)
py_library(
name = "device_spec",
srcs = ["framework/device_spec.py"],
srcs_version = "PY2AND3",
deps = [
":util",
],
)
py_library(
name = "device",
srcs = ["framework/device.py"],
@ -1668,6 +1679,19 @@ tf_py_test(
main = "framework/sparse_tensor_test.py",
)
tf_py_test(
name = "framework_device_spec_test",
size = "small",
srcs = ["framework/device_spec_test.py"],
additional_deps = [
":framework_for_generated_wrappers",
":framework_test_lib",
":platform_test",
"//tensorflow/core:protos_all_py",
],
main = "framework/device_spec_test.py",
)
tf_py_test(
name = "framework_device_test",
size = "small",
@ -3924,6 +3948,7 @@ py_library(
":control_flow_ops",
":data_flow_ops",
":device",
":device_spec",
":distribute",
":errors",
":framework",

View File

@ -51,10 +51,13 @@ def canonicalize(d, default=None):
result = tf_device.DeviceSpec(
replica=0, task=0, device_type="CPU", device_index=0)
if ops.executing_eagerly_outside_functions():
result.job = "localhost"
result = result.replace(job="localhost")
if default:
result.merge_from(tf_device.DeviceSpec.from_string(default))
result.merge_from(d)
result = result.make_merged_spec(
tf_device.DeviceSpec.from_string(default))
# Apply `d` last, so that it takes precedence over the defaults.
result = result.make_merged_spec(d)
return result.to_string()
@ -83,6 +86,9 @@ class _FakeOperation(object):
def _set_device(self, device):
self.device = ops._device_string(device) # pylint: disable=protected-access
def _set_device_from_string(self, device_str):
self.device = device_str
def current():
"""Return a string (not canonicalized) for the current device."""

View File

@ -68,7 +68,7 @@ def _enter_graph(g, eager, creator_stack=None):
def _cpu_device(device):
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
return cpu_device.to_string()
@ -297,7 +297,7 @@ def _is_device_list_local(devices):
"""
all_local = None
for d in devices:
d_spec = tf_device.DeviceSpec().parse_from_string(d)
d_spec = tf_device.DeviceSpec.from_string(d)
is_local = d_spec.job in (None, "localhost")
if all_local is None: # Determine all_local from first device.
@ -345,7 +345,7 @@ def _group_device_list(devices):
device_dict = {}
for d in devices:
d_spec = tf_device.DeviceSpec().parse_from_string(d)
d_spec = tf_device.DeviceSpec.from_string(d)
# Create an entry for the task_type.
if d_spec.job not in device_dict:
@ -361,7 +361,7 @@ def _group_device_list(devices):
def _is_gpu_device(device):
return tf_device.DeviceSpec().parse_from_string(device).device_type == "GPU"
return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
def _infer_num_gpus_per_worker(devices):
@ -396,7 +396,7 @@ def _infer_num_gpus_per_worker(devices):
raise ValueError("All workers should have the same number of GPUs.")
for d in device_in_task:
d_spec = tf_device.DeviceSpec().parse_from_string(d)
d_spec = tf_device.DeviceSpec.from_string(d)
if (d_spec.device_type == "GPU" and
d_spec.device_index >= num_gpus):
raise ValueError("GPU `device_index` on a worker should be "

View File

@ -86,9 +86,6 @@ class DistributedValuesTest(test.TestCase):
self.assertEqual(canonical_cpu, v.devices)
v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,))
self.assertEqual(canonical_cpu, v.devices)
with self.assertRaises(AssertionError):
v = values.DistributedValues(
values.SingleDeviceMap("/device:cpu:0"), (42,))
def testIsTensorLike(self):
with context.graph_mode(), \

View File

@ -87,6 +87,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:device",
"//tensorflow/python:device_spec",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",

View File

@ -1024,7 +1024,7 @@ class _EagerDeviceContext(object):
ctx.ensure_initialized()
new_device_spec = pydev.DeviceSpec.from_string(
ctx._context_devices[0]) # pylint: disable=protected-access
new_device_spec.merge_from(device_spec)
new_device_spec = new_device_spec.make_merged_spec(device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string("")
new_device_name = new_device_spec.to_string()

View File

@ -18,275 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import threading
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python import tf2
from tensorflow.python.framework import device_spec
@tf_export(v1=["DeviceSpec"])
class DeviceSpec(object):
"""Represents a (possibly partial) specification for a TensorFlow device.
`DeviceSpec`s are used throughout TensorFlow to describe where state is stored
and computations occur. Using `DeviceSpec` allows you to parse device spec
strings to verify their validity, merge them or compose them programmatically.
Example:
```python
# Place the operations on device "GPU:0" in the "ps" job.
device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
with tf.device(device_spec):
# Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
my_var = tf.Variable(..., name="my_variable")
squared_var = tf.square(my_var)
```
If a `DeviceSpec` is partially specified, it will be merged with other
`DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
components defined in inner scopes take precedence over those defined in
outer scopes.
```python
with tf.device(DeviceSpec(job="train", )):
with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0):
# Nodes created here will be assigned to /job:ps/device:GPU:0.
with tf.device(DeviceSpec(device_type="GPU", device_index=1):
# Nodes created here will be assigned to /job:train/device:GPU:1.
```
A `DeviceSpec` consists of 5 components -- each of
which is optionally specified:
* Job: The job name.
* Replica: The replica index.
* Task: The task index.
* Device type: The device type string (e.g. "CPU" or "GPU").
* Device index: The device index.
"""
def __init__(self, job=None, replica=None, task=None, device_type=None,
device_index=None):
"""Create a new `DeviceSpec` object.
Args:
job: string. Optional job name.
replica: int. Optional replica index.
task: int. Optional task index.
device_type: Optional device type string (e.g. "CPU" or "GPU")
device_index: int. Optional device index. If left
unspecified, device represents 'any' device_index.
"""
self._set_job(job)
self._set_replica(replica)
self._set_task(task)
if device_type == "cpu" or device_type == "gpu":
# For backwards compatibility only, we support lowercase variants of
# cpu and gpu but turn them into uppercase here.
self._set_device_type(device_type.upper())
else:
self._set_device_type(device_type)
self._set_device_index(device_index)
self._to_string = self._device_to_string()
self._hash = hash(self._to_string)
def _clear(self):
self._job = None
self._replica = None
self._task = None
self._device_type = None
self._device_index = None
self._to_string = None
self._hash = None
def _sync(self):
"""Sync device internal states."""
self._to_string = self._device_to_string()
self._hash = hash(self._to_string)
@property
def job(self):
return self._job
def _set_job(self, job):
if job is not None:
self._job = str(job)
else:
self._job = None
@job.setter
def job(self, job):
self._set_job(job)
self._sync()
@property
def replica(self):
return self._replica
def _set_replica(self, replica):
if replica is not None:
self._replica = int(replica)
else:
self._replica = None
@replica.setter
def replica(self, replica):
self._set_replica(replica)
self._sync()
@property
def task(self):
return self._task
def _set_task(self, task):
if task is not None:
self._task = int(task)
else:
self._task = None
@task.setter
def task(self, task):
self._set_task(task)
self._sync()
@property
def device_type(self):
return self._device_type
def _set_device_type(self, device_type):
self._device_type = device_type
@device_type.setter
def device_type(self, device_type):
self._set_device_type(device_type)
self._sync()
@property
def device_index(self):
return self._device_index
def _set_device_index(self, device_index):
self._device_index = device_index
@device_index.setter
def device_index(self, device_index):
self._set_device_index(device_index)
self._sync()
def parse_from_string(self, spec):
"""Parse a `DeviceSpec` name into its components.
Args:
spec: a string of the form
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
or
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
as cpu and gpu are mutually exclusive.
All entries are optional.
Returns:
The `DeviceSpec`.
Raises:
ValueError: if the spec was not valid.
"""
self._clear()
splits = [x.split(":") for x in spec.split("/")]
for y in splits:
ly = len(y)
if y:
if ly == 2 and y[0] == "job":
self._set_job(y[1])
elif ly == 2 and y[0] == "replica":
self._set_replica(y[1])
elif ly == 2 and y[0] == "task":
self._set_task(y[1])
elif ((ly == 1 or ly == 2) and
((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
if self.device_type is not None:
raise ValueError("Cannot specify multiple device types: %s" % spec)
self._set_device_type(y[0].upper())
if ly == 2 and y[1] != "*":
self._set_device_index(int(y[1]))
elif ly == 3 and y[0] == "device":
if self.device_type is not None:
raise ValueError("Cannot specify multiple device types: %s" % spec)
self._set_device_type(y[1])
if y[2] != "*":
self._set_device_index(int(y[2]))
elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison
raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))
self._sync()
return self
def merge_from(self, dev):
"""Merge the properties of "dev" into this `DeviceSpec`.
Args:
dev: a `DeviceSpec`.
"""
if dev.job is not None:
self._set_job(dev.job)
if dev.replica is not None:
self._set_replica(dev.replica)
if dev.task is not None:
self._set_task(dev.task)
if dev.device_type is not None:
self._set_device_type(dev.device_type)
if dev.device_index is not None:
self._set_device_index(dev.device_index)
self._sync()
def _device_to_string(self):
"""Private method that returns a string representation of `DeviceSpec`."""
dev = ""
if self.job is not None:
dev += "/job:" + self.job
if self.replica is not None:
dev += "/replica:" + str(self.replica)
if self.task is not None:
dev += "/task:" + str(self.task)
if self.device_type is not None:
device_index_string = "*"
if self.device_index is not None:
device_index_string = str(self.device_index)
dev += "/device:%s:%s" % (self.device_type, device_index_string)
return dev
def to_string(self):
"""Return a string representation of this `DeviceSpec`.
Returns:
a string of the form
/job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
"""
return self._to_string
@staticmethod
def from_string(spec):
"""Construct a `DeviceSpec` from a string.
Args:
spec: a string of the form
/job:<name>/replica:<id>/task:<id>/device:CPU:<id>
or
/job:<name>/replica:<id>/task:<id>/device:GPU:<id>
as cpu and gpu are mutually exclusive.
All entries are optional.
Returns:
A DeviceSpec.
"""
return DeviceSpec().parse_from_string(spec)
def __eq__(self, other):
return self.to_string() == other.to_string()
def __hash__(self):
return self._hash
if tf2.enabled():
DeviceSpec = device_spec.DeviceSpecV2
else:
DeviceSpec = device_spec.DeviceSpecV1
def check_valid(spec):
@ -302,24 +42,26 @@ def check_valid(spec):
DeviceSpec.from_string(spec)
def is_device_spec(obj):
  """Report whether `obj` is a DeviceSpec instance.

  The check is made against `device_spec.DeviceSpecV2` so that callers do not
  need to know which class sits at the root of the DeviceSpec hierarchy.

  Args:
    obj: Any object.

  Returns:
    True if `obj` is an instance of `device_spec.DeviceSpecV2` (or of one of
    its subclasses), False otherwise.
  """
  spec_base_class = device_spec.DeviceSpecV2
  return isinstance(obj, spec_base_class)
def canonical_name(device):
"""Returns a canonical name for the given `DeviceSpec` or device name."""
if device is None:
return ""
if isinstance(device, DeviceSpec):
if is_device_spec(device):
return device.to_string()
else:
device = DeviceSpec.from_string(device)
return device.to_string()
# Cache from DeviceSpec objects to their corresponding device functions.
# This cache is maintained for correctness, not performance: it makes it
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
_cached_device_specs = {}
_cache_lock = threading.Lock()
# Performance caches
_cached_mergers = {}
_cache_lock = threading.RLock()
_string_merge_cache = {}
def merge_device(spec):
@ -343,29 +85,99 @@ def merge_device(spec):
the returned device function's with block.
Returns:
A device function with the above-described behavior.
A MergeDevice object with the above-described behavior.
Raises:
ValueError: if the spec was not valid.
"""
if isinstance(spec, MergeDevice):
return spec
with _cache_lock:
if not isinstance(spec, DeviceSpec):
cached_device_spec = _cached_device_specs.get(spec, None)
if cached_device_spec is None:
device_spec = DeviceSpec.from_string(spec or "")
_cached_device_specs[spec] = device_spec
spec = device_spec
else:
spec = cached_device_spec
cached_function = _cached_device_functions.get(spec, None)
if cached_function is not None:
return cached_function
merger = _cached_mergers.get(spec)
if merger:
return merger
def _device_function(node_def):
current_device = DeviceSpec.from_string(node_def.device or "")
copy_spec = copy.copy(spec)
copy_spec.merge_from(current_device) # current_device takes precedence.
return copy_spec
merger = MergeDevice(spec)
_cached_mergers[spec] = merger
return merger
_cached_device_functions[spec] = _device_function
return _device_function
class MergeDevice(object):
  """Wraps a device specification (DeviceSpec or str) with merge functionality.

  When called, this class will merge a node_def with its own spec. It also
  exposes a `shortcut_string_merge` method which can significantly improve
  performance of device placement.
  """

  def __init__(self, spec):
    """Store the device specification that later merges start from.

    Args:
      spec: A `DeviceSpecV1`, a `DeviceSpecV2`, or a value accepted by
        `DeviceSpec.from_string` (e.g. a device string).
    """
    # DeviceSpecV1 is a subclass of DeviceSpecV2, so the V1 test has to come
    # first; otherwise the V2 test matches V1 objects as well and the snapshot
    # below is never taken.
    if isinstance(spec, device_spec.DeviceSpecV1):
      # Capture a snapshot of spec, since DeviceSpecV1 exposes attribute
      # setters and may be modified by the caller afterwards.
      self._spec = spec.__class__.from_string(spec.to_string())
    elif isinstance(spec, device_spec.DeviceSpecV2):
      self._spec = spec
    else:
      self._spec = DeviceSpec.from_string(spec)

  def __call__(self, node_def):
    """Merge `node_def.device` on top of the wrapped spec.

    Args:
      node_def: An object with a `device` attribute holding a device string
        (or a falsy value when no device is set).

    Returns:
      A DeviceSpec in which fields set on `node_def.device` take precedence
      over the fields of the wrapped spec.
    """
    # In general a user may create a device function which takes into account
    # arbitrary properties of an op. (For instance dynamically placing ops based
    # on type.) So even though the standard DeviceSpec route only uses the
    # device attribute, we take an entire node_def to maintain a consistent
    # signature with general device functions.
    current_device = DeviceSpec.from_string(node_def.device or "")
    return self._spec.make_merged_spec(current_device)

  def shortcut_string_merge(self, node_def):
    """Merge a node def without materializing a full DeviceSpec object.

    Often a device merge is invoked in order to generate a string which can be
    passed into the c api. In such a case, we can cache the
      node_def.device  ->  merge_result_string

    map, and in most cases avoid:
      - Materializing a copy of self._spec (In the case of DeviceSpecV1)
      - Materializing a DeviceSpec for node_def.device
      - A DeviceSpec.merge_from invocation

    In practice the cache hit rate for this function is very high, because the
    number of invocations when iterating through the device stack is much
    larger than the number of devices.

    Args:
      node_def: An Operation (or Operation-like) to merge device constraints
        with self._spec

    Returns:
      A string containing the merged device specification.
    """
    device = node_def.device or ""

    # The cache is module-level and keyed on both the wrapped spec and the
    # incoming device string.
    merge_key = (self._spec, device)
    result = _string_merge_cache.get(merge_key)
    if result is None:
      # This update is not atomic, however because the merge is stateless
      # we don't need to lock when updating the cache.
      result = self.__call__(node_def).to_string()
      _string_merge_cache[merge_key] = result

    return result

  def __repr__(self):
    return "{} (spec: {})".format(
        super(MergeDevice, self).__repr__(), self._spec.to_string())

  @property
  def is_null_merge(self):
    """Indicate whether the wrapped spec is empty.

    In the degenerate case where self._spec is an empty specification, a caller
    may wish to skip a merge step entirely. (However this class does not have
    enough information to make that determination.)

    Returns:
      A boolean indicating whether a device merge will be trivial.
    """
    return not bool(self._spec.to_string())

View File

@ -0,0 +1,428 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to represent a device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
_STRING_TO_COMPONENTS_CACHE = {}
_COMPONENTS_TO_STRING_CACHE = {}
def _as_str_or_none(inp):
return None if inp is None else str(inp)
def _as_int_or_none(inp):
return None if inp is None else int(inp)
def _as_device_str_or_none(device_type):
  """Normalize a device type value for storage on a DeviceSpec.

  Args:
    device_type: A device type (e.g. "CPU" or "GPU"), or None.

  Returns:
    "CPU" / "GPU" for the exact lowercase inputs "cpu" / "gpu"; otherwise the
    result of `_as_str_or_none(device_type)`.
  """
  # Lowercase "cpu" and "gpu" are accepted purely for backwards compatibility
  # and are stored in their uppercase form.
  is_legacy_lowercase = device_type in ("cpu", "gpu")
  if is_legacy_lowercase:
    return device_type.upper()
  return _as_str_or_none(device_type)
@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  with tf.device(DeviceSpec(job="train")):
    with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)

    # The string form and its hash are computed once, up front, from the
    # normalized components.
    self._as_string = self._components_to_string(
        job=self._job, replica=self._replica, task=self._task,
        device_type=self._device_type, device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    2.x behavior change:
      In TensorFlow 1.x, this function mutates its own state and returns itself.
      In 2.x, DeviceSpecs are immutable, and this function will return a
        DeviceSpec which contains the spec.

      Recommended:
        ```
        # my_spec and my_updated_spec are unrelated.
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will work in 1.x and 2.x (though deprecated in 2.x):
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = my_spec.parse_from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will NOT work in 2.x:
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
        with tf.device(my_spec):
          ...
        ```

      In general, `DeviceSpec.from_string` should completely replace
      `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
      completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = dict(
        job=self.job, replica=self.replica, task=self.task,
        device_type=self.device_type, device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    return self._job

  @property
  def replica(self):
    return self._replica

  @property
  def task(self):
    return self._task

  @property
  def device_type(self):
    return self._device_type

  @property
  def device_index(self):
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs gives priority to `dev`: a field of `self`
    is only used where the corresponding field of `dev` is None.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized in the module-level `_STRING_TO_COMPONENTS_CACHE`,
    keyed on the `spec` argument exactly as it was passed in.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec`. Note that the result of this function
      must go through attribute setters of DeviceSpec, and should therefore NOT
      be used directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a device index that `int` cannot parse.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    # "/job:foo/device:GPU:0" -> [[""], ["job", "foo"], ["device", "GPU", "0"]]
    splits = [x.split(":") for x in spec.split("/")]
    for y in splits:
      ly = len(y)
      if y:
        # NOTE(taylorrobie): these will go through setters later.
        if ly == 2 and y[0] == "job":
          job = y[1]
        elif ly == 2 and y[0] == "replica":
          replica = y[1]
        elif ly == 2 and y[0] == "task":
          task = y[1]
        elif ((ly == 1 or ly == 2) and
              ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
          # Short forms such as "/cpu:0", "/GPU:*" or "/gpu".
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[0].upper()
          if ly == 2 and y[1] != "*":
            device_index = int(y[1])
        elif ly == 3 and y[0] == "device":
          # Full form: "/device:<type>:<index or *>".
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[1]
          if y[2] != "*":
            device_index = int(y[2])
        elif ly and y[0] != "":  # pylint: disable=g-explicit-bool-comparison
          raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Results are memoized in the module-level `_COMPONENTS_TO_STRING_CACHE`.

    Args:
      job: Job name, or None.
      replica: Replica index, or None.
      task: Task index, or None.
      device_type: Device type string, or None.
      device_index: Device index, or None. Rendered as "*" when None and
        `device_type` is set.

    Returns:
      The device string; components that are None are omitted.
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      # A missing index is rendered as the "*" wildcard.
      device_index_string = "*"
      if device_index is not None:
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`; without this method two
    # equal specs would still compare as "not equal" there.
    return not self == other

  def __hash__(self):
    return self._hash
@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # Every storage slot is inherited from DeviceSpecV2. Listing the same names
  # again would create a second, shadowing set of slot descriptors (and a
  # second set of per-instance storage), so this class adds none of its own
  # while still avoiding a per-instance __dict__.
  __slots__ = ()

  # Each setter normalizes the new value and then drops the cached string and
  # hash so that they are recomputed lazily on next use.
  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # The cached hash is cleared by the setters; rebuild it on demand.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # The cached string is cleared by the setters; rebuild it on demand.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job, replica=self.replica, task=self.task,
          device_type=self.device_type, device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Unlike the DeviceSpecV2 implementation, this updates `self` in place
    # (through the setters above) and returns `self`.
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__

View File

@ -0,0 +1,229 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.device_spec."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1),
("v2", device_spec.DeviceSpecV2))
class DeviceSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests construction, parsing, formatting and merging of device specs.

  Tests decorated with `TEST_V1_AND_V2` run once for `DeviceSpecV1` and once
  for `DeviceSpecV2`. The `*_legacy` tests use `DeviceSpecV1` only, because
  they rely on mutating a spec in place.
  """

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def test_empty(self, device_spec_type):
    """A default-constructed spec formats as the empty string."""
    d = device_spec_type()
    self.assertEqual("", d.to_string())
    d.parse_from_string("")
    self.assertEqual("", d.to_string())

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def test_constructor(self, device_spec_type):
    """Constructor arguments are exposed as attributes and in to_string()."""
    d = device_spec_type(job="j", replica=0, task=1,
                         device_type="CPU", device_index=2)
    self.assertEqual("j", d.job)
    self.assertEqual(0, d.replica)
    self.assertEqual(1, d.task)
    self.assertEqual("CPU", d.device_type)
    self.assertEqual(2, d.device_index)
    self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())

    d = device_spec_type(device_type="GPU", device_index=0)
    self.assertEqual("/device:GPU:0", d.to_string())

  def testto_string_legacy(self):
    """DeviceSpecV1 allows direct mutation."""
    d = device_spec.DeviceSpecV1()
    d.job = "foo"
    self.assertEqual("/job:foo", d.to_string())
    d.task = 3
    self.assertEqual("/job:foo/task:3", d.to_string())
    d.device_type = "CPU"
    d.device_index = 0
    self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
    d.task = None
    d.replica = 12
    self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
    d.device_type = "GPU"
    d.device_index = 2
    self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
    d.device_type = "CPU"
    d.device_index = 1
    self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())
    d.device_type = None
    d.device_index = None
    self.assertEqual("/job:foo/replica:12", d.to_string())

    # Test wildcard
    d = device_spec.DeviceSpecV1(job="foo", replica=12, task=3,
                                 device_type="GPU")
    self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def test_replace(self, device_spec_type):
    """replace() returns a spec with the given fields overridden."""
    d = device_spec_type()
    d = d.replace(job="foo")
    self.assertEqual("/job:foo", d.to_string())
    d = d.replace(task=3)
    self.assertEqual("/job:foo/task:3", d.to_string())
    d = d.replace(device_type="CPU", device_index=0)
    self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
    d = d.replace(task=None, replica=12)
    self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
    d = d.replace(device_type="GPU", device_index=2)
    self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
    d = d.replace(device_type="CPU", device_index=1)
    self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())
    d = d.replace(device_type=None, device_index=None)
    self.assertEqual("/job:foo/replica:12", d.to_string())

    # Test wildcard. Use the parameterized class so that both the v1 and v2
    # variants of this test exercise their own implementation.
    d = device_spec_type(job="foo", replica=12, task=3, device_type="GPU")
    self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def testto_string(self, device_spec_type):
    """to_string() omits unset fields and upper-cases the device type."""
    d = device_spec_type(job="foo")
    self.assertEqual("/job:foo", d.to_string())
    d = device_spec_type(job="foo", task=3)
    self.assertEqual("/job:foo/task:3", d.to_string())
    d = device_spec_type(job="foo", task=3, device_type="cpu", device_index=0)
    self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
    d = device_spec_type(job="foo", replica=12, device_type="cpu",
                         device_index=0)
    self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
    d = device_spec_type(job="foo", replica=12, device_type="gpu",
                         device_index=2)
    self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
    d = device_spec_type(job="foo", replica=12)
    self.assertEqual("/job:foo/replica:12", d.to_string())

    # Test wildcard
    d = device_spec_type(job="foo", replica=12, task=3, device_type="GPU")
    self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())

  def test_parse_legacy(self):
    """DeviceSpecV1.parse_from_string() updates the spec in place."""
    d = device_spec.DeviceSpecV1()
    d.parse_from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())

    d.parse_from_string("/replica:1/task:0/cpu:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())

    d.parse_from_string("/replica:1/task:0/device:CPU:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())

    d.parse_from_string("/job:muu/device:GPU:2")
    self.assertEqual("/job:muu/device:GPU:2", d.to_string())

    with self.assertRaisesRegexp(ValueError, "Cannot specify multiple"):
      d.parse_from_string("/job:muu/device:GPU:2/cpu:0")

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def test_to_from_string(self, device_spec_type):
    """from_string() output round-trips through to_string()."""
    d = device_spec_type.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    self.assertEqual(0, d.replica)

    d = device_spec_type.from_string("/replica:1/task:0/cpu:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
    self.assertAllEqual([1, 0, "CPU", 0],
                        [d.replica, d.task, d.device_type, d.device_index])

    d = device_spec_type.from_string("/replica:1/task:0/device:CPU:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
    self.assertAllEqual([1, 0, "CPU", 0],
                        [d.replica, d.task, d.device_type, d.device_index])

    d = device_spec_type.from_string("/job:muu/device:GPU:2")
    self.assertEqual("/job:muu/device:GPU:2", d.to_string())
    self.assertAllEqual(["muu", "GPU", 2],
                        [d.job, d.device_type, d.device_index])

    with self.assertRaisesRegexp(ValueError, "Cannot specify multiple"):
      d.parse_from_string("/job:muu/device:GPU:2/cpu:0")

  def test_merge_legacy(self):
    """DeviceSpecV1.merge_from() overrides fields set in the other spec."""
    d = device_spec.DeviceSpecV1.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())

    d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/device:GPU:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    d = device_spec.DeviceSpecV1()
    d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/cpu:0"))
    self.assertEqual("/task:1/device:CPU:0", d.to_string())

    d.merge_from(device_spec.DeviceSpecV1.from_string("/job:boo/device:GPU:0"))
    self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())

    d.merge_from(device_spec.DeviceSpecV1.from_string("/job:muu/cpu:2"))
    self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
    d.merge_from(device_spec.DeviceSpecV1.from_string(
        "/job:muu/device:MyFunnyDevice:2"))
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())

  def test_merge_removed(self):
    """DeviceSpecV2 does not provide merge_from()."""
    # Build both specs outside the assertRaises block so that only the
    # `merge_from` call itself can satisfy the expected AttributeError.
    d = device_spec.DeviceSpecV2()
    other = device_spec.DeviceSpecV2.from_string("/task:1/cpu:0")
    with self.assertRaises(AttributeError):
      d.merge_from(other)

  @parameterized.named_parameters(*TEST_V1_AND_V2)
  def test_combine(self, device_spec_type):
    """make_merged_spec() returns a spec overridden by the other spec."""
    d = device_spec_type.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())

    d = d.make_merged_spec(
        device_spec_type.from_string("/task:1/device:GPU:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    d = device_spec_type()
    d = d.make_merged_spec(device_spec_type.from_string("/task:1/cpu:0"))
    self.assertEqual("/task:1/device:CPU:0", d.to_string())

    d = d.make_merged_spec(
        device_spec_type.from_string("/job:boo/device:GPU:0"))
    self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())

    d = d.make_merged_spec(device_spec_type.from_string("/job:muu/cpu:2"))
    self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
    d = d.make_merged_spec(device_spec_type.from_string(
        "/job:muu/device:MyFunnyDevice:2"))
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
# Run the tests through googletest when this file is executed as a script.
if __name__ == "__main__":
googletest.main()

View File

@ -18,120 +18,41 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import device
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class DeviceTest(test_util.TensorFlowTestCase):
TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1),
("v2", device_spec.DeviceSpecV2))
def testEmpty(self):
d = device.DeviceSpec()
self.assertEquals("", d.to_string())
d.parse_from_string("")
self.assertEquals("", d.to_string())
def testConstructor(self):
d = device.DeviceSpec(job="j", replica=0, task=1,
device_type="CPU", device_index=2)
self.assertEqual("j", d.job)
self.assertEqual(0, d.replica)
self.assertEqual(1, d.task)
self.assertEqual("CPU", d.device_type)
self.assertEqual(2, d.device_index)
self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())
class DeviceTest(test_util.TensorFlowTestCase, parameterized.TestCase):
d = device.DeviceSpec(device_type="GPU", device_index=0)
self.assertEquals("/device:GPU:0", d.to_string())
def testto_string(self):
d = device.DeviceSpec()
d.job = "foo"
self.assertEquals("/job:foo", d.to_string())
d.task = 3
self.assertEquals("/job:foo/task:3", d.to_string())
d.device_type = "CPU"
d.device_index = 0
self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string())
d.task = None
d.replica = 12
self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string())
d.device_type = "GPU"
d.device_index = 2
self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string())
d.device_type = "CPU"
d.device_index = 1
self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string())
d.device_type = None
d.device_index = None
d.cpu = None
self.assertEquals("/job:foo/replica:12", d.to_string())
# Test wildcard
d = device.DeviceSpec(job="foo", replica=12, task=3, device_type="GPU")
self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
def testParse(self):
d = device.DeviceSpec()
d.parse_from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
d.parse_from_string("/replica:1/task:0/cpu:0")
self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
d.parse_from_string("/replica:1/task:0/device:CPU:0")
self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
d.parse_from_string("/job:muu/device:GPU:2")
self.assertEquals("/job:muu/device:GPU:2", d.to_string())
with self.assertRaises(Exception) as e:
d.parse_from_string("/job:muu/device:GPU:2/cpu:0")
self.assertTrue("Cannot specify multiple device" in str(e.exception))
def testFromString(self):
d = device.DeviceSpec.from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
with self.assertRaises(Exception) as e:
d = device.DeviceSpec.from_string("/job:muu/device:GPU:2/cpu:0")
self.assertTrue("Cannot specify multiple device" in str(e.exception))
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/cpu:*")
self.assertEquals(None, d.device_index)
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/gpu:7")
self.assertEquals(7, d.device_index)
d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/device:GPU:7")
self.assertEquals(7, d.device_index)
def testMerge(self):
d = device.DeviceSpec.from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
d.merge_from(device.DeviceSpec.from_string("/task:1/device:GPU:2"))
self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
d = device.DeviceSpec()
d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0"))
self.assertEquals("/task:1/device:CPU:0", d.to_string())
d.merge_from(device.DeviceSpec.from_string("/job:boo/device:GPU:0"))
self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string())
d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2"))
self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string())
d.merge_from(device.DeviceSpec.from_string(
"/job:muu/device:MyFunnyDevice:2"))
self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
@parameterized.named_parameters(*TEST_V1_AND_V2)
def testMerge(self, DeviceSpec): # pylint: disable=invalid-name
d = DeviceSpec.from_string("/job:muu/task:1/device:MyFunnyDevice:2")
self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
if not context.executing_eagerly():
with ops.device(device.merge_device("/device:GPU:0")):
var1 = variables.Variable(1.0)
self.assertEquals("/device:GPU:0", var1.device)
self.assertEqual("/device:GPU:0", var1.device)
with ops.device(device.merge_device("/job:worker")):
var2 = variables.Variable(1.0)
self.assertEquals("/job:worker/device:GPU:0", var2.device)
self.assertEqual("/job:worker/device:GPU:0", var2.device)
with ops.device(device.merge_device("/device:CPU:0")):
var3 = variables.Variable(1.0)
self.assertEquals("/job:worker/device:CPU:0", var3.device)
self.assertEqual("/job:worker/device:CPU:0", var3.device)
with ops.device(device.merge_device("/job:ps")):
var4 = variables.Variable(1.0)
self.assertEquals("/job:ps/device:CPU:0", var4.device)
self.assertEqual("/job:ps/device:CPU:0", var4.device)
def testCanonicalName(self):
self.assertEqual("/job:foo/replica:0",
@ -159,21 +80,17 @@ class DeviceTest(test_util.TensorFlowTestCase):
def testCheckValid(self):
device.check_valid("/job:foo/replica:0")
with self.assertRaises(Exception) as e:
with self.assertRaisesRegexp(ValueError, "invalid literal for int"):
device.check_valid("/job:j/replica:foo")
self.assertTrue("invalid literal for int" in str(e.exception))
with self.assertRaises(Exception) as e:
with self.assertRaisesRegexp(ValueError, "invalid literal for int"):
device.check_valid("/job:j/task:bar")
self.assertTrue("invalid literal for int" in str(e.exception))
with self.assertRaises(Exception) as e:
with self.assertRaisesRegexp(ValueError, "Unknown attribute: 'bar'"):
device.check_valid("/bar:muu/baz:2")
self.assertTrue("Unknown attribute: 'bar'" in str(e.exception))
with self.assertRaises(Exception) as e:
with self.assertRaisesRegexp(ValueError, "Cannot specify multiple device"):
device.check_valid("/cpu:0/device:GPU:2")
self.assertTrue("Cannot specify multiple device" in str(e.exception))
if __name__ == "__main__":

View File

@ -81,9 +81,15 @@ class _UserDeviceSpec(object):
def __init__(self, device_name_or_function):
self._device_name_or_function = device_name_or_function
self.display_name = str(self._device_name_or_function)
if callable(self._device_name_or_function):
self.function = device_name_or_function
self.raw_string = None
if isinstance(device_name_or_function, pydev.MergeDevice):
self.is_null_merge = device_name_or_function.is_null_merge
elif callable(device_name_or_function):
self.is_null_merge = False
dev_func = self._device_name_or_function
func_name = function_utils.get_func_name(dev_func)
func_code = function_utils.get_func_code(dev_func)
@ -95,13 +101,26 @@ class _UserDeviceSpec(object):
lineno = -1
self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
self.raw_string = None
elif device_name_or_function is None:
# NOTE(taylorrobie): This MUST be False. None signals a break in the
# device stack, so `is_null_merge` must be False for such a case to
# allow callers to safely skip over null merges without missing a None.
self.is_null_merge = False
self.function = self._device_name_or_function
if not (self._device_name_or_function is None or
callable(self._device_name_or_function)):
self.raw_string = self._device_name_or_function
self.function = pydev.merge_device(self._device_name_or_function)
else:
self.raw_string = device_name_or_function
self.function = pydev.merge_device(device_name_or_function)
self.is_null_merge = self.function.is_null_merge
# We perform this check in __init__ because it is of non-trivial cost,
# and self.string_merge is typically called many times.
self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)
def string_merge(self, node_def):
  """Returns the device string obtained by applying this spec to `node_def`.

  When `self.fast_string_merge` is set, the work is delegated to
  `self.function.shortcut_string_merge`. Otherwise `self.function` is called
  on `node_def` and its result is converted to a string.
  """
  if not self.fast_string_merge:
    # General path: the function may return a device spec or a string.
    return compat.as_str(_device_string(self.function(node_def)))
  return self.function.shortcut_string_merge(node_def)
class NullContextmanager(object):
@ -617,7 +636,9 @@ class Tensor(_TensorLike):
def __eq__(self, other):
# Necessary to support Python's collection membership operators
return id(self) == id(other)
# NOTE(taylorrobie): equivalent to: id(self) == id(other)
return self is other
def __copy__(self):
# TODO(b/77597810): get rid of Tensor copies.
@ -1743,7 +1764,7 @@ IndexedSlicesValue = collections.namedtuple(
def _device_string(dev_spec):
if isinstance(dev_spec, pydev.DeviceSpec):
if pydev.is_device_spec(dev_spec):
return dev_spec.to_string()
else:
return dev_spec
@ -2224,10 +2245,22 @@ class Operation(object):
Args:
device: string or device. The device to set.
"""
self._set_device_from_string(compat.as_str(_device_string(device)))
def _set_device_from_string(self, device_str):
"""Fast path to set device if the type is known to be a string.
This function is called frequently enough during graph construction that
there are non-trivial performance gains if the caller can guarantee that
the specified device is already a string.
Args:
device_str: A string specifying where to place this op.
"""
c_api.SetRequestedDevice(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
compat.as_str(_device_string(device)))
device_str)
def _update_input(self, index, tensor):
"""Update the input to this operation at the given index.
@ -4493,10 +4526,21 @@ class Graph(object):
# We apply here because the result can depend on the Operation's
# signature, which is computed in the Operation constructor.
# pylint: disable=protected-access
prior_device_string = None
for device_spec in self._device_function_stack.peek_objs():
if device_spec.is_null_merge:
continue
if device_spec.function is None:
break
op._set_device(device_spec.function(op))
device_string = device_spec.string_merge(op)
# Compare by identity, which is cheaper than equality: None is a singleton
# and repeated merges are expected to yield the same (cached or interned)
# string object. A missed match only costs a redundant device assignment.
if device_string is not prior_device_string:
op._set_device_from_string(device_string)
prior_device_string = device_string
op._device_code_locations = self._snapshot_device_function_stack_metadata()
# pylint: enable=protected-access

View File

@ -40,6 +40,7 @@ from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tfdev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@ -532,8 +533,13 @@ class _TfDeviceCaptureOp(object):
def _set_device(self, device):
  """This method captures TF's explicit device scope setting."""
  # Device specs are stored in string form; any other value is kept as given.
  self.device = device.to_string() if tfdev.is_device_spec(device) else device
# String-only counterpart of `_set_device`: stores `device_str` unchanged.
def _set_device_from_string(self, device_str):
self.device = device_str
def _get_current_tf_device():
"""Return explicit device of current context, otherwise returns `None`.
@ -546,7 +552,7 @@ def _get_current_tf_device():
graph = get_graph()
op = _TfDeviceCaptureOp()
graph._apply_device_functions(op)
return op.device
return tfdev.DeviceSpec.from_string(op.device)
def _is_current_explicit_device(device_type):

View File

@ -291,6 +291,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
else:
self._device = device
# Records `device_str` on this context without any conversion.
def _set_device_from_string(self, device_str):
self._device = device_str
if self._outside_compilation_cluster:
raise NotImplementedError("Cannot nest outside_compilation clusters")
if cluster:

View File

@ -120,13 +120,13 @@ class _ReplicaDeviceChooser(object):
current_job, ps_job = current_device.job, ps_device.job
if ps_job and (not current_job or current_job == ps_job):
ps_device.task = self._ps_strategy(op)
ps_device = ps_device.replace(task=self._ps_strategy(op))
ps_device.merge_from(current_device)
ps_device = ps_device.make_merged_spec(current_device)
return ps_device.to_string()
worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
worker_device.merge_from(current_device)
worker_device = worker_device.make_merged_spec(current_device)
return worker_device.to_string()

View File

@ -50,8 +50,7 @@ def set_cpu0(device_string):
A device string.
"""
parsed_device = pydev.DeviceSpec.from_string(device_string)
parsed_device.device_type = "CPU"
parsed_device.device_index = 0
parsed_device = parsed_device.replace(device_type="CPU", device_index=0)
return parsed_device.to_string()

View File

@ -1,6 +1,7 @@
path: "tensorflow.DeviceSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.device.DeviceSpec\'>"
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV1\'>"
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV2\'>"
is_instance: "<type \'object\'>"
member {
name: "device_index"
@ -28,7 +29,11 @@ tf_class {
}
member_method {
name: "from_string"
argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_merged_spec"
argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge_from"
@ -38,6 +43,10 @@ tf_class {
name: "parse_from_string"
argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "replace"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "to_string"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -0,0 +1,49 @@
path: "tensorflow.DeviceSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.device_spec.DeviceSpecV2\'>"
is_instance: "<type \'object\'>"
member {
name: "device_index"
mtype: "<type \'property\'>"
}
member {
name: "device_type"
mtype: "<type \'property\'>"
}
member {
name: "job"
mtype: "<type \'property\'>"
}
member {
name: "replica"
mtype: "<type \'property\'>"
}
member {
name: "task"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_string"
argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_merged_spec"
argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "parse_from_string"
argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "replace"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "to_string"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -12,6 +12,10 @@ tf_module {
name: "DType"
mtype: "<type \'type\'>"
}
member {
name: "DeviceSpec"
mtype: "<type \'type\'>"
}
member {
name: "GradientTape"
mtype: "<type \'type\'>"