diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6bd45e6ee2b..7c4f33c3265 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -628,6 +628,7 @@ py_library( deps = [ ":constant_op", ":device", + ":device_spec", ":dtypes", ":framework_ops", ":function", @@ -649,6 +650,7 @@ py_library( deps = [ ":constant_op", ":device", + ":device_spec", ":dtypes", ":framework_ops", ":op_def_library", @@ -747,6 +749,15 @@ py_library( ], ) +py_library( + name = "device_spec", + srcs = ["framework/device_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":util", + ], +) + py_library( name = "device", srcs = ["framework/device.py"], @@ -1668,6 +1679,19 @@ tf_py_test( main = "framework/sparse_tensor_test.py", ) +tf_py_test( + name = "framework_device_spec_test", + size = "small", + srcs = ["framework/device_spec_test.py"], + additional_deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":platform_test", + "//tensorflow/core:protos_all_py", + ], + main = "framework/device_spec_test.py", +) + tf_py_test( name = "framework_device_test", size = "small", @@ -3924,6 +3948,7 @@ py_library( ":control_flow_ops", ":data_flow_ops", ":device", + ":device_spec", ":distribute", ":errors", ":framework", diff --git a/tensorflow/python/distribute/device_util.py b/tensorflow/python/distribute/device_util.py index c3d858690e0..94b9708a64e 100644 --- a/tensorflow/python/distribute/device_util.py +++ b/tensorflow/python/distribute/device_util.py @@ -51,10 +51,13 @@ def canonicalize(d, default=None): result = tf_device.DeviceSpec( replica=0, task=0, device_type="CPU", device_index=0) if ops.executing_eagerly_outside_functions(): - result.job = "localhost" + result = result.replace(job="localhost") if default: - result.merge_from(tf_device.DeviceSpec.from_string(default)) - result.merge_from(d) + result = result.make_merged_spec( + tf_device.DeviceSpec.from_string(default)) + + # Apply `d` last, so that it takes precedence over the defaults. 
+ result = result.make_merged_spec(d) return result.to_string() @@ -83,6 +86,9 @@ class _FakeOperation(object): def _set_device(self, device): self.device = ops._device_string(device) # pylint: disable=protected-access + def _set_device_from_string(self, device_str): + self.device = device_str + def current(): """Return a string (not canonicalized) for the current device.""" diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index e0694411a17..e38e722cb16 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -68,7 +68,7 @@ def _enter_graph(g, eager, creator_stack=None): def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) + cpu_device = cpu_device.replace(device_type="CPU", device_index=0) return cpu_device.to_string() @@ -297,7 +297,7 @@ def _is_device_list_local(devices): """ all_local = None for d in devices: - d_spec = tf_device.DeviceSpec().parse_from_string(d) + d_spec = tf_device.DeviceSpec.from_string(d) is_local = d_spec.job in (None, "localhost") if all_local is None: # Determine all_local from first device. @@ -345,7 +345,7 @@ def _group_device_list(devices): device_dict = {} for d in devices: - d_spec = tf_device.DeviceSpec().parse_from_string(d) + d_spec = tf_device.DeviceSpec.from_string(d) # Create an entry for the task_type. 
if d_spec.job not in device_dict: @@ -361,7 +361,7 @@ def _group_device_list(devices): def _is_gpu_device(device): - return tf_device.DeviceSpec().parse_from_string(device).device_type == "GPU" + return tf_device.DeviceSpec.from_string(device).device_type == "GPU" def _infer_num_gpus_per_worker(devices): @@ -396,7 +396,7 @@ def _infer_num_gpus_per_worker(devices): raise ValueError("All workers should have the same number of GPUs.") for d in device_in_task: - d_spec = tf_device.DeviceSpec().parse_from_string(d) + d_spec = tf_device.DeviceSpec.from_string(d) if (d_spec.device_type == "GPU" and d_spec.device_index >= num_gpus): raise ValueError("GPU `device_index` on a worker should be " diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index a855ec6adaf..2eec7bea0c4 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -86,9 +86,6 @@ class DistributedValuesTest(test.TestCase): self.assertEqual(canonical_cpu, v.devices) v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,)) self.assertEqual(canonical_cpu, v.devices) - with self.assertRaises(AssertionError): - v = values.DistributedValues( - values.SingleDeviceMap("/device:cpu:0"), (42,)) def testIsTensorLike(self): with context.graph_mode(), \ diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 47a8bef3b94..4739470bb77 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -87,6 +87,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:device", + "//tensorflow/python:device_spec", "//tensorflow/python:errors", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 2863ac853f7..a2e5c62e767 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1024,7 +1024,7 @@ 
class _EagerDeviceContext(object): ctx.ensure_initialized() new_device_spec = pydev.DeviceSpec.from_string( ctx._context_devices[0]) # pylint: disable=protected-access - new_device_spec.merge_from(device_spec) + new_device_spec = new_device_spec.make_merged_spec(device_spec) else: new_device_spec = pydev.DeviceSpec.from_string("") new_device_name = new_device_spec.to_string() diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py index 7261f7a3526..2ffe97360f3 100644 --- a/tensorflow/python/framework/device.py +++ b/tensorflow/python/framework/device.py @@ -18,275 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy import threading -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python import tf2 +from tensorflow.python.framework import device_spec -@tf_export(v1=["DeviceSpec"]) -class DeviceSpec(object): - """Represents a (possibly partial) specification for a TensorFlow device. - - `DeviceSpec`s are used throughout TensorFlow to describe where state is stored - and computations occur. Using `DeviceSpec` allows you to parse device spec - strings to verify their validity, merge them or compose them programmatically. - - Example: - - ```python - # Place the operations on device "GPU:0" in the "ps" job. - device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0) - with tf.device(device_spec): - # Both my_var and squared_var will be placed on /job:ps/device:GPU:0. - my_var = tf.Variable(..., name="my_variable") - squared_var = tf.square(my_var) - ``` - - If a `DeviceSpec` is partially specified, it will be merged with other - `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec` - components defined in inner scopes take precedence over those defined in - outer scopes. 
- - ```python - with tf.device(DeviceSpec(job="train", )): - with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0): - # Nodes created here will be assigned to /job:ps/device:GPU:0. - with tf.device(DeviceSpec(device_type="GPU", device_index=1): - # Nodes created here will be assigned to /job:train/device:GPU:1. - ``` - - A `DeviceSpec` consists of 5 components -- each of - which is optionally specified: - - * Job: The job name. - * Replica: The replica index. - * Task: The task index. - * Device type: The device type string (e.g. "CPU" or "GPU"). - * Device index: The device index. - """ - - def __init__(self, job=None, replica=None, task=None, device_type=None, - device_index=None): - """Create a new `DeviceSpec` object. - - Args: - job: string. Optional job name. - replica: int. Optional replica index. - task: int. Optional task index. - device_type: Optional device type string (e.g. "CPU" or "GPU") - device_index: int. Optional device index. If left - unspecified, device represents 'any' device_index. - """ - self._set_job(job) - self._set_replica(replica) - self._set_task(task) - if device_type == "cpu" or device_type == "gpu": - # For backwards compatibility only, we support lowercase variants of - # cpu and gpu but turn them into uppercase here. 
- self._set_device_type(device_type.upper()) - else: - self._set_device_type(device_type) - self._set_device_index(device_index) - self._to_string = self._device_to_string() - self._hash = hash(self._to_string) - - def _clear(self): - self._job = None - self._replica = None - self._task = None - self._device_type = None - self._device_index = None - self._to_string = None - self._hash = None - - def _sync(self): - """Sync device internal states.""" - self._to_string = self._device_to_string() - self._hash = hash(self._to_string) - - @property - def job(self): - return self._job - - def _set_job(self, job): - if job is not None: - self._job = str(job) - else: - self._job = None - - @job.setter - def job(self, job): - self._set_job(job) - self._sync() - - @property - def replica(self): - return self._replica - - def _set_replica(self, replica): - if replica is not None: - self._replica = int(replica) - else: - self._replica = None - - @replica.setter - def replica(self, replica): - self._set_replica(replica) - self._sync() - - @property - def task(self): - return self._task - - def _set_task(self, task): - if task is not None: - self._task = int(task) - else: - self._task = None - - @task.setter - def task(self, task): - self._set_task(task) - self._sync() - - @property - def device_type(self): - return self._device_type - - def _set_device_type(self, device_type): - self._device_type = device_type - - @device_type.setter - def device_type(self, device_type): - self._set_device_type(device_type) - self._sync() - - @property - def device_index(self): - return self._device_index - - def _set_device_index(self, device_index): - self._device_index = device_index - - @device_index.setter - def device_index(self, device_index): - self._set_device_index(device_index) - self._sync() - - def parse_from_string(self, spec): - """Parse a `DeviceSpec` name into its components. 
- - Args: - spec: a string of the form - /job:/replica:/task:/device:CPU: - or - /job:/replica:/task:/device:GPU: - as cpu and gpu are mutually exclusive. - All entries are optional. - - Returns: - The `DeviceSpec`. - - Raises: - ValueError: if the spec was not valid. - """ - self._clear() - splits = [x.split(":") for x in spec.split("/")] - for y in splits: - ly = len(y) - if y: - if ly == 2 and y[0] == "job": - self._set_job(y[1]) - elif ly == 2 and y[0] == "replica": - self._set_replica(y[1]) - elif ly == 2 and y[0] == "task": - self._set_task(y[1]) - elif ((ly == 1 or ly == 2) and - ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))): - if self.device_type is not None: - raise ValueError("Cannot specify multiple device types: %s" % spec) - self._set_device_type(y[0].upper()) - if ly == 2 and y[1] != "*": - self._set_device_index(int(y[1])) - elif ly == 3 and y[0] == "device": - if self.device_type is not None: - raise ValueError("Cannot specify multiple device types: %s" % spec) - self._set_device_type(y[1]) - if y[2] != "*": - self._set_device_index(int(y[2])) - elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison - raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec)) - - self._sync() - - return self - - def merge_from(self, dev): - """Merge the properties of "dev" into this `DeviceSpec`. - - Args: - dev: a `DeviceSpec`. 
- """ - if dev.job is not None: - self._set_job(dev.job) - if dev.replica is not None: - self._set_replica(dev.replica) - if dev.task is not None: - self._set_task(dev.task) - if dev.device_type is not None: - self._set_device_type(dev.device_type) - if dev.device_index is not None: - self._set_device_index(dev.device_index) - - self._sync() - - def _device_to_string(self): - """Private method that returns a string representation of `DeviceSpec`.""" - dev = "" - if self.job is not None: - dev += "/job:" + self.job - if self.replica is not None: - dev += "/replica:" + str(self.replica) - if self.task is not None: - dev += "/task:" + str(self.task) - if self.device_type is not None: - device_index_string = "*" - if self.device_index is not None: - device_index_string = str(self.device_index) - dev += "/device:%s:%s" % (self.device_type, device_index_string) - return dev - - def to_string(self): - """Return a string representation of this `DeviceSpec`. - - Returns: - a string of the form - /job:/replica:/task:/device::. - """ - return self._to_string - - @staticmethod - def from_string(spec): - """Construct a `DeviceSpec` from a string. - - Args: - spec: a string of the form - /job:/replica:/task:/device:CPU: - or - /job:/replica:/task:/device:GPU: - as cpu and gpu are mutually exclusive. - All entries are optional. - - Returns: - A DeviceSpec. 
- """ - return DeviceSpec().parse_from_string(spec) - - def __eq__(self, other): - return self.to_string() == other.to_string() - - def __hash__(self): - return self._hash +if tf2.enabled(): + DeviceSpec = device_spec.DeviceSpecV2 +else: + DeviceSpec = device_spec.DeviceSpecV1 def check_valid(spec): @@ -302,24 +42,26 @@ def check_valid(spec): DeviceSpec.from_string(spec) +def is_device_spec(obj): + """Abstract away the fact that DeviceSpecV2 is the base class.""" + return isinstance(obj, device_spec.DeviceSpecV2) + + def canonical_name(device): """Returns a canonical name for the given `DeviceSpec` or device name.""" if device is None: return "" - if isinstance(device, DeviceSpec): + if is_device_spec(device): return device.to_string() else: device = DeviceSpec.from_string(device) return device.to_string() -# Cache from DeviceSpec objects to their corresponding device functions. -# This cache is maintained for correctness, not performance: it makes it -# possible to compare the device function stacks belonging to different -# graphs in a meaningful way. -_cached_device_functions = {} -_cached_device_specs = {} -_cache_lock = threading.Lock() +# Performance caches +_cached_mergers = {} +_cache_lock = threading.RLock() +_string_merge_cache = {} def merge_device(spec): @@ -343,29 +85,99 @@ def merge_device(spec): the returned device function's with block. Returns: - A device function with the above-described behavior. + A MergeDevice object with the above-described behavior. Raises: ValueError: if the spec was not valid. 
""" + + if isinstance(spec, MergeDevice): + return spec + with _cache_lock: - if not isinstance(spec, DeviceSpec): - cached_device_spec = _cached_device_specs.get(spec, None) - if cached_device_spec is None: - device_spec = DeviceSpec.from_string(spec or "") - _cached_device_specs[spec] = device_spec - spec = device_spec - else: - spec = cached_device_spec - cached_function = _cached_device_functions.get(spec, None) - if cached_function is not None: - return cached_function + merger = _cached_mergers.get(spec) + if merger: + return merger - def _device_function(node_def): - current_device = DeviceSpec.from_string(node_def.device or "") - copy_spec = copy.copy(spec) - copy_spec.merge_from(current_device) # current_device takes precedence. - return copy_spec + merger = MergeDevice(spec) + _cached_mergers[spec] = merger + return merger - _cached_device_functions[spec] = _device_function - return _device_function + +class MergeDevice(object): + """Wraps a device specification (DeviceSpec or str) with merge functionality. + + When called, this class will merge a node_def with its own spec. It also + exposes a `shortcut_string_merge` method which can significantly improve + performance of device placement. + """ + + def __init__(self, spec): + if isinstance(spec, device_spec.DeviceSpecV2): + self._spec = spec + elif isinstance(spec, device_spec.DeviceSpecV1): + # Capture a snapshot of spec. + self._spec = spec.__class__.from_string(spec.to_string()) + else: + self._spec = DeviceSpec.from_string(spec) + + def __call__(self, node_def): + # In general a user may create a device function which takes into account + # arbitrary properties of an op. (For instance dynamically placing ops based + # on type.) So even though the standard DeviceSpec route only uses the + # device attribute, we take an entire node_def to maintain a consistent + # signature with general device functions. 
+ current_device = DeviceSpec.from_string(node_def.device or "") + return self._spec.make_merged_spec(current_device) + + def shortcut_string_merge(self, node_def): + """Merge a node def without materializing a full DeviceSpec object. + + Often a device merge is invoked in order to generate a string which can be + passed into the c api. In such a case, we can cache the + node_def.device -> merge_result_string + + map, and in most cases avoid: + - Materializing a copy of self._spec (In the case of DeviceSpecV1) + - Materializing a DeviceSpec for node_def.device + - A DeviceSpec.merge_from invocation + + In practice the cache hit rate for this function is very high, because the + number of invocations when iterating through the device stack is much + larger than the number of devices. + + Args: + node_def: An Operation (or Operation-like) to merge device constraints + with self._spec + + Returns: + A string containing the merged device specification. + """ + device = node_def.device or "" + + merge_key = (self._spec, device) + result = _string_merge_cache.get(merge_key) + if result is None: + # This update is not atomic, however because the merge is stateless + # we don't need to lock when updating the cache. + result = self.__call__(node_def).to_string() + _string_merge_cache[merge_key] = result + + return result + + def __repr__(self): + return "{} (spec: {})".format( + super(MergeDevice, self).__repr__(), self._spec.to_string()) + + @property + def is_null_merge(self): + """Indicate whether the wrapped spec is empty. + + In the degenerate case where self._spec is an empty specification, a caller + may wish to skip a merge step entirely. (However this class does not have + enough information to make that determination.) + + Returns: + A boolean indicating whether a device merge will be trivial. 
+ """ + return not bool(self._spec.to_string()) diff --git a/tensorflow/python/framework/device_spec.py b/tensorflow/python/framework/device_spec.py new file mode 100644 index 00000000000..83e517c2ae4 --- /dev/null +++ b/tensorflow/python/framework/device_spec.py @@ -0,0 +1,428 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class to represent a device.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + + +# ============================================================================== +# == Global Implementation Details ============================================= +# ============================================================================== +_STRING_TO_COMPONENTS_CACHE = {} +_COMPONENTS_TO_STRING_CACHE = {} + + +def _as_str_or_none(inp): + return None if inp is None else str(inp) + + +def _as_int_or_none(inp): + return None if inp is None else int(inp) + + +def _as_device_str_or_none(device_type): + # For backwards compatibility only, we support lowercase variants of + # cpu and gpu but turn them into uppercase here. 
+ if device_type in ("cpu", "gpu"): + return device_type.upper() + return _as_str_or_none(device_type) + + +@tf_export("DeviceSpec", v1=[]) +class DeviceSpecV2(object): + """Represents a (possibly partial) specification for a TensorFlow device. + + `DeviceSpec`s are used throughout TensorFlow to describe where state is stored + and computations occur. Using `DeviceSpec` allows you to parse device spec + strings to verify their validity, merge them or compose them programmatically. + + Example: + + ```python + # Place the operations on device "GPU:0" in the "ps" job. + device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0) + with tf.device(device_spec): + # Both my_var and squared_var will be placed on /job:ps/device:GPU:0. + my_var = tf.Variable(..., name="my_variable") + squared_var = tf.square(my_var) + ``` + + If a `DeviceSpec` is partially specified, it will be merged with other + `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec` + components defined in inner scopes take precedence over those defined in + outer scopes. + + ```python + with tf.device(DeviceSpec(job="train", )): + with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0): + # Nodes created here will be assigned to /job:ps/device:GPU:0. + with tf.device(DeviceSpec(device_type="GPU", device_index=1): + # Nodes created here will be assigned to /job:train/device:GPU:1. + ``` + + A `DeviceSpec` consists of 5 components -- each of + which is optionally specified: + + * Job: The job name. + * Replica: The replica index. + * Task: The task index. + * Device type: The device type string (e.g. "CPU" or "GPU"). + * Device index: The device index. + """ + + __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index", + "_as_string", "_hash") + + def __init__(self, job=None, replica=None, task=None, device_type=None, + device_index=None): + """Create a new `DeviceSpec` object. + + Args: + job: string. Optional job name. + replica: int. 
Optional replica index. + task: int. Optional task index. + device_type: Optional device type string (e.g. "CPU" or "GPU") + device_index: int. Optional device index. If left + unspecified, device represents 'any' device_index. + """ + self._job = _as_str_or_none(job) + self._replica = _as_int_or_none(replica) + self._task = _as_int_or_none(task) + self._device_type = _as_device_str_or_none(device_type) + self._device_index = _as_int_or_none(device_index) + self._as_string = self._components_to_string( + job=self._job, replica=self._replica, task=self._task, + device_type=self._device_type, device_index=self._device_index) + self._hash = hash(self.to_string()) + + def to_string(self): + """Return a string representation of this `DeviceSpec`. + + Returns: + a string of the form + /job:/replica:/task:/device::. + """ + return self._as_string + + @classmethod + def from_string(cls, spec): + """Construct a `DeviceSpec` from a string. + + Args: + spec: a string of the form + /job:/replica:/task:/device:CPU: + or + /job:/replica:/task:/device:GPU: + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + A DeviceSpec. + """ + return cls(*cls._string_to_components(spec)) + + def parse_from_string(self, spec): + """Parse a `DeviceSpec` name into its components. + + 2.x behavior change: + In TensorFlow 1.x, this function mutates its own state and returns itself. + In 2.x, DeviceSpecs are immutable, and this function will return a + DeviceSpec which contains the spec. + + Recommended: + ``` + # my_spec and my_updated_spec are unrelated. + my_spec = tf.DeviceSpec.from_string("/CPU:0") + my_updated_spec = tf.DeviceSpec.from_string("/GPU:0") + with tf.device(my_updated_spec): + ... + ``` + + Will work in 1.x and 2.x (though deprecated in 2.x): + ``` + my_spec = tf.DeviceSpec.from_string("/CPU:0") + my_updated_spec = my_spec.parse_from_string("/GPU:0") + with tf.device(my_updated_spec): + ... 
+ ``` + + Will NOT work in 2.x: + ``` + my_spec = tf.DeviceSpec.from_string("/CPU:0") + my_spec.parse_from_string("/GPU:0") # <== Will not update my_spec + with tf.device(my_spec): + ... + ``` + + In general, `DeviceSpec.from_string` should completely replace + `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should + completely replace setting attributes directly. + + Args: + spec: an optional string of the form + /job:/replica:/task:/device:CPU: + or + /job:/replica:/task:/device:GPU: + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + The `DeviceSpec`. + + Raises: + ValueError: if the spec was not valid. + """ + return self.from_string(spec) + + def make_merged_spec(self, dev): + """Returns a new DeviceSpec which incorporates `dev`. + + When combining specs, `dev` will take precedence over the current spec. + So for instance: + ``` + first_spec = tf.DeviceSpec(job=0, device_type="CPU") + second_spec = tf.DeviceSpec(device_type="GPU") + combined_spec = first_spec.make_merged_spec(second_spec) + ``` + + is equivalent to: + ``` + combined_spec = tf.DeviceSpec(job=0, device_type="GPU") + ``` + + Args: + dev: a `DeviceSpec` + + Returns: + A new `DeviceSpec` which combines `self` and `dev` + """ + return self.__class__(*self._get_combined_properties(dev)) + + def replace(self, **kwargs): + """Convenience method for making a new DeviceSpec by overriding fields. + + For instance: + ``` + my_spec = DeviceSpec(job="my_job", device_type="CPU") + my_updated_spec = my_spec.replace(device_type="GPU") + my_other_spec = my_spec.replace(device_type=None) + ``` + + Args: + **kwargs: This method takes the same args as the DeviceSpec constructor + + Returns: + A DeviceSpec with the fields specified in kwargs overridden. + """ + init_kwargs = dict( + job=self.job, replica=self.replica, task=self.task, + device_type=self.device_type, device_index=self.device_index) + + # Explicitly provided kwargs take precedence. 
+ init_kwargs.update(kwargs) + return self.__class__(**init_kwargs) + + @property + def job(self): + return self._job + + @property + def replica(self): + return self._replica + + @property + def task(self): + return self._task + + @property + def device_type(self): + return self._device_type + + @property + def device_index(self): + return self._device_index + + def _get_combined_properties(self, dev): + """Combine the current DeviceSpec with another DeviceSpec. + + The combination of DeviceSpecs will give priority to dev. + + Args: + dev: a `DeviceSpec` + + Returns: + A tuple of (job, replica, task, device_type, device_index) which + represents the combination of self and dev. + """ + return ( + dev.job if dev.job is not None else self.job, + dev.replica if dev.replica is not None else self.replica, + dev.task if dev.task is not None else self.task, + dev.device_type if dev.device_type is not None else self.device_type, + dev.device_index if dev.device_index is not None else self.device_index, + ) + + @staticmethod + def _string_to_components(spec=None): + """Stateless portion of device spec string parsing. + + Args: + spec: An optional string specifying a device specification. + + Returns: + The parsed components of `spec`. Note that the result of this function + must go through attribute setters of DeviceSpec, and should therefore NOT + be used directly. + """ + cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec) + if cached_result is not None: + return cached_result + + raw_spec = spec # keep a copy of the original to update the cache + job, replica, task, device_type, device_index = None, None, None, None, None + + spec = spec or "" + splits = [x.split(":") for x in spec.split("/")] + for y in splits: + ly = len(y) + if y: + # NOTE(taylorrobie): these will go through setters later. 
+ if ly == 2 and y[0] == "job": + job = y[1] + elif ly == 2 and y[0] == "replica": + replica = y[1] + elif ly == 2 and y[0] == "task": + task = y[1] + elif ((ly == 1 or ly == 2) and + ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))): + if device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + device_type = y[0].upper() + if ly == 2 and y[1] != "*": + device_index = int(y[1]) + elif ly == 3 and y[0] == "device": + if device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + device_type = y[1] + if y[2] != "*": + device_index = int(y[2]) + elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison + raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec)) + + output = (job, replica, task, device_type, device_index) + _STRING_TO_COMPONENTS_CACHE[raw_spec] = output + return output + + @staticmethod + def _components_to_string(job, replica, task, device_type, device_index): + """Stateless portion of `to_string` (separated to allow caching).""" + key = (job, replica, task, device_type, device_index) + cached_result = _COMPONENTS_TO_STRING_CACHE.get(key) + if cached_result is not None: + return cached_result + + output = [] + if job is not None: + output.append("/job:" + job) + if replica is not None: + output.append("/replica:" + str(replica)) + if task is not None: + output.append("/task:" + str(task)) + if device_type is not None: + device_index_string = "*" + if device_index is not None: + # Unlike the others, device_index is stored as an int. 
+ device_index_string = str(device_index) + output.append("/device:%s:%s" % (device_type, device_index_string)) + + output = "".join(output) + _COMPONENTS_TO_STRING_CACHE[key] = output + return output + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.to_string() == other.to_string()) + + def __hash__(self): + return self._hash + + +@tf_export(v1=["DeviceSpec"]) # pylint: disable=missing-docstring +class DeviceSpecV1(DeviceSpecV2): + __doc__ = DeviceSpecV2.__doc__ + __slots__ = DeviceSpecV2.__slots__ + + @DeviceSpecV2.job.setter + def job(self, job): + self._job = _as_str_or_none(job) + self._as_string, self._hash = None, None + + @DeviceSpecV2.replica.setter + def replica(self, replica): + self._replica = _as_int_or_none(replica) + self._as_string, self._hash = None, None + + @DeviceSpecV2.task.setter + def task(self, task): + self._task = _as_int_or_none(task) + self._as_string, self._hash = None, None + + @DeviceSpecV2.device_type.setter + def device_type(self, device_type): + self._device_type = _as_device_str_or_none(device_type) + self._as_string, self._hash = None, None + + @DeviceSpecV2.device_index.setter + def device_index(self, device_index): + self._device_index = _as_int_or_none(device_index) + self._as_string, self._hash = None, None + + def __hash__(self): + if self._hash is None: + self._hash = hash(self.to_string()) + return self._hash + + def to_string(self): + if self._as_string is None: + self._as_string = self._components_to_string( + job=self.job, replica=self.replica, task=self.task, + device_type=self.device_type, device_index=self.device_index) + return self._as_string + + def parse_from_string(self, spec): + (self.job, self.replica, self.task, self.device_type, self.device_index + ) = self._string_to_components(spec) + + return self + + def merge_from(self, dev): + """Merge the properties of "dev" into this `DeviceSpec`. + + Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become + immutable. 
+ + Args: + dev: a `DeviceSpec`. + """ + (self.job, self.replica, self.task, self.device_type, self.device_index + ) = self._get_combined_properties(dev) + + # Use parent class docstrings for public methods. + to_string.__doc__ = DeviceSpecV2.to_string.__doc__ + parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__ diff --git a/tensorflow/python/framework/device_spec_test.py b/tensorflow/python/framework/device_spec_test.py new file mode 100644 index 00000000000..850b9a561ae --- /dev/null +++ b/tensorflow/python/framework/device_spec_test.py @@ -0,0 +1,229 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.python.framework.device_spec.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.framework import device_spec +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1), + ("v2", device_spec.DeviceSpecV2)) + + +class DeviceSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def test_empty(self, device_spec_type): + d = device_spec_type() + self.assertEqual("", d.to_string()) + d.parse_from_string("") + self.assertEqual("", d.to_string()) + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def test_constructor(self, device_spec_type): + d = device_spec_type(job="j", replica=0, task=1, + device_type="CPU", device_index=2) + self.assertEqual("j", d.job) + self.assertEqual(0, d.replica) + self.assertEqual(1, d.task) + self.assertEqual("CPU", d.device_type) + self.assertEqual(2, d.device_index) + self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string()) + + d = device_spec_type(device_type="GPU", device_index=0) + self.assertEqual("/device:GPU:0", d.to_string()) + + def testto_string_legacy(self): + """DeviceSpecV1 allows direct mutation.""" + d = device_spec.DeviceSpecV1() + d.job = "foo" + self.assertEqual("/job:foo", d.to_string()) + d.task = 3 + self.assertEqual("/job:foo/task:3", d.to_string()) + d.device_type = "CPU" + d.device_index = 0 + self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string()) + d.task = None + d.replica = 12 + self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string()) + d.device_type = "GPU" + d.device_index = 2 + self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string()) + d.device_type = "CPU" + 
d.device_index = 1 + self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string()) + d.device_type = None + d.device_index = None + self.assertEqual("/job:foo/replica:12", d.to_string()) + + # Test wildcard + d = device_spec.DeviceSpecV1(job="foo", replica=12, task=3, + device_type="GPU") + self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def test_replace(self, device_spec_type): + d = device_spec_type() + d = d.replace(job="foo") + self.assertEqual("/job:foo", d.to_string()) + + d = d.replace(task=3) + self.assertEqual("/job:foo/task:3", d.to_string()) + + d = d.replace(device_type="CPU", device_index=0) + self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string()) + + d = d.replace(task=None, replica=12) + self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string()) + + d = d.replace(device_type="GPU", device_index=2) + self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string()) + + d = d.replace(device_type="CPU", device_index=1) + self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string()) + + d = d.replace(device_type=None, device_index=None) + self.assertEqual("/job:foo/replica:12", d.to_string()) + + # Test wildcard + d = device_spec_type(job="foo", replica=12, task=3, + device_type="GPU") + self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def testto_string(self, device_spec_type): + d = device_spec_type(job="foo") + self.assertEqual("/job:foo", d.to_string()) + + d = device_spec_type(job="foo", task=3) + self.assertEqual("/job:foo/task:3", d.to_string()) + + d = device_spec_type(job="foo", task=3, device_type="cpu", device_index=0) + self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string()) + + d = device_spec_type(job="foo", replica=12, device_type="cpu", + device_index=0) + self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string()) + + d =
device_spec_type(job="foo", replica=12, device_type="gpu", + device_index=2) + self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string()) + + d = device_spec_type(job="foo", replica=12) + self.assertEqual("/job:foo/replica:12", d.to_string()) + + # Test wildcard + d = device_spec_type(job="foo", replica=12, task=3, device_type="GPU") + self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) + + def test_parse_legacy(self): + d = device_spec.DeviceSpecV1() + d.parse_from_string("/job:foo/replica:0") + self.assertEqual("/job:foo/replica:0", d.to_string()) + + d.parse_from_string("/replica:1/task:0/cpu:0") + self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string()) + + d.parse_from_string("/replica:1/task:0/device:CPU:0") + self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string()) + + d.parse_from_string("/job:muu/device:GPU:2") + self.assertEqual("/job:muu/device:GPU:2", d.to_string()) + + with self.assertRaisesRegexp(ValueError, "Cannot specify multiple"): + d.parse_from_string("/job:muu/device:GPU:2/cpu:0") + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def test_to_from_string(self, device_spec_type): + d = device_spec_type.from_string("/job:foo/replica:0") + self.assertEqual("/job:foo/replica:0", d.to_string()) + self.assertEqual(0, d.replica) + + d = device_spec_type.from_string("/replica:1/task:0/cpu:0") + self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string()) + self.assertAllEqual([1, 0, "CPU", 0], + [d.replica, d.task, d.device_type, d.device_index]) + + d = device_spec_type.from_string("/replica:1/task:0/device:CPU:0") + self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string()) + self.assertAllEqual([1, 0, "CPU", 0], + [d.replica, d.task, d.device_type, d.device_index]) + + d = device_spec_type.from_string("/job:muu/device:GPU:2") + self.assertEqual("/job:muu/device:GPU:2", d.to_string()) + self.assertAllEqual(["muu", "GPU", 2], + [d.job, d.device_type, d.device_index]) + + with 
self.assertRaisesRegexp(ValueError, "Cannot specify multiple"): + d.parse_from_string("/job:muu/device:GPU:2/cpu:0") + + def test_merge_legacy(self): + d = device_spec.DeviceSpecV1.from_string("/job:foo/replica:0") + self.assertEqual("/job:foo/replica:0", d.to_string()) + + d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/device:GPU:2")) + self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string()) + + d = device_spec.DeviceSpecV1() + d.merge_from(device_spec.DeviceSpecV1.from_string("/task:1/cpu:0")) + self.assertEqual("/task:1/device:CPU:0", d.to_string()) + + d.merge_from(device_spec.DeviceSpecV1.from_string("/job:boo/device:GPU:0")) + self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string()) + + d.merge_from(device_spec.DeviceSpecV1.from_string("/job:muu/cpu:2")) + self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string()) + d.merge_from(device_spec.DeviceSpecV1.from_string( + "/job:muu/device:MyFunnyDevice:2")) + self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) + + def test_merge_removed(self): + d = device_spec.DeviceSpecV2() + with self.assertRaises(AttributeError): + d.merge_from(device_spec.DeviceSpecV2.from_string("/task:1/cpu:0")) + + @parameterized.named_parameters(*TEST_V1_AND_V2) + def test_combine(self, device_spec_type): + d = device_spec_type.from_string("/job:foo/replica:0") + self.assertEqual("/job:foo/replica:0", d.to_string()) + + d = d.make_merged_spec( + device_spec_type.from_string("/task:1/device:GPU:2")) + self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string()) + + d = device_spec_type() + d = d.make_merged_spec(device_spec_type.from_string("/task:1/cpu:0")) + self.assertEqual("/task:1/device:CPU:0", d.to_string()) + + d = d.make_merged_spec( + device_spec_type.from_string("/job:boo/device:GPU:0")) + self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string()) + + d = d.make_merged_spec(device_spec_type.from_string("/job:muu/cpu:2")) +
self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string()) + d = d.make_merged_spec(device_spec_type.from_string( + "/job:muu/device:MyFunnyDevice:2")) + self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py index cd4b4ea51e6..2b34c1ec7fd 100644 --- a/tensorflow/python/framework/device_test.py +++ b/tensorflow/python/framework/device_test.py @@ -18,120 +18,41 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.framework import device +from tensorflow.python.framework import device_spec from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class DeviceTest(test_util.TensorFlowTestCase): +TEST_V1_AND_V2 = (("v1", device_spec.DeviceSpecV1), + ("v2", device_spec.DeviceSpecV2)) - def testEmpty(self): - d = device.DeviceSpec() - self.assertEquals("", d.to_string()) - d.parse_from_string("") - self.assertEquals("", d.to_string()) - def testConstructor(self): - d = device.DeviceSpec(job="j", replica=0, task=1, - device_type="CPU", device_index=2) - self.assertEqual("j", d.job) - self.assertEqual(0, d.replica) - self.assertEqual(1, d.task) - self.assertEqual("CPU", d.device_type) - self.assertEqual(2, d.device_index) - self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string()) +class DeviceTest(test_util.TensorFlowTestCase, parameterized.TestCase): - d = device.DeviceSpec(device_type="GPU", device_index=0) - self.assertEquals("/device:GPU:0", d.to_string()) - - def testto_string(self): - d = device.DeviceSpec() - d.job = "foo" - self.assertEquals("/job:foo", d.to_string()) - d.task = 3 - 
self.assertEquals("/job:foo/task:3", d.to_string()) - d.device_type = "CPU" - d.device_index = 0 - self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string()) - d.task = None - d.replica = 12 - self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string()) - d.device_type = "GPU" - d.device_index = 2 - self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string()) - d.device_type = "CPU" - d.device_index = 1 - self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string()) - d.device_type = None - d.device_index = None - d.cpu = None - self.assertEquals("/job:foo/replica:12", d.to_string()) - - # Test wildcard - d = device.DeviceSpec(job="foo", replica=12, task=3, device_type="GPU") - self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) - - def testParse(self): - d = device.DeviceSpec() - d.parse_from_string("/job:foo/replica:0") - self.assertEquals("/job:foo/replica:0", d.to_string()) - d.parse_from_string("/replica:1/task:0/cpu:0") - self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string()) - d.parse_from_string("/replica:1/task:0/device:CPU:0") - self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string()) - d.parse_from_string("/job:muu/device:GPU:2") - self.assertEquals("/job:muu/device:GPU:2", d.to_string()) - with self.assertRaises(Exception) as e: - d.parse_from_string("/job:muu/device:GPU:2/cpu:0") - self.assertTrue("Cannot specify multiple device" in str(e.exception)) - - def testFromString(self): - d = device.DeviceSpec.from_string("/job:foo/replica:0") - self.assertEquals("/job:foo/replica:0", d.to_string()) - with self.assertRaises(Exception) as e: - d = device.DeviceSpec.from_string("/job:muu/device:GPU:2/cpu:0") - self.assertTrue("Cannot specify multiple device" in str(e.exception)) - - d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/cpu:*") - self.assertEquals(None, d.device_index) - d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/gpu:7") - self.assertEquals(7, 
d.device_index) - d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/device:GPU:7") - self.assertEquals(7, d.device_index) - - def testMerge(self): - d = device.DeviceSpec.from_string("/job:foo/replica:0") - self.assertEquals("/job:foo/replica:0", d.to_string()) - d.merge_from(device.DeviceSpec.from_string("/task:1/device:GPU:2")) - self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string()) - - d = device.DeviceSpec() - d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0")) - self.assertEquals("/task:1/device:CPU:0", d.to_string()) - d.merge_from(device.DeviceSpec.from_string("/job:boo/device:GPU:0")) - self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string()) - d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2")) - self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string()) - d.merge_from(device.DeviceSpec.from_string( - "/job:muu/device:MyFunnyDevice:2")) - self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) + @parameterized.named_parameters(*TEST_V1_AND_V2) + def testMerge(self, DeviceSpec): # pylint: disable=invalid-name + d = DeviceSpec.from_string("/job:muu/task:1/device:MyFunnyDevice:2") + self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) if not context.executing_eagerly(): with ops.device(device.merge_device("/device:GPU:0")): var1 = variables.Variable(1.0) - self.assertEquals("/device:GPU:0", var1.device) + self.assertEqual("/device:GPU:0", var1.device) with ops.device(device.merge_device("/job:worker")): var2 = variables.Variable(1.0) - self.assertEquals("/job:worker/device:GPU:0", var2.device) + self.assertEqual("/job:worker/device:GPU:0", var2.device) with ops.device(device.merge_device("/device:CPU:0")): var3 = variables.Variable(1.0) - self.assertEquals("/job:worker/device:CPU:0", var3.device) + self.assertEqual("/job:worker/device:CPU:0", var3.device) with ops.device(device.merge_device("/job:ps")): var4 = variables.Variable(1.0) - 
self.assertEquals("/job:ps/device:CPU:0", var4.device) + self.assertEqual("/job:ps/device:CPU:0", var4.device) def testCanonicalName(self): self.assertEqual("/job:foo/replica:0", @@ -159,21 +80,17 @@ class DeviceTest(test_util.TensorFlowTestCase): def testCheckValid(self): device.check_valid("/job:foo/replica:0") - with self.assertRaises(Exception) as e: + with self.assertRaisesRegexp(ValueError, "invalid literal for int"): device.check_valid("/job:j/replica:foo") - self.assertTrue("invalid literal for int" in str(e.exception)) - with self.assertRaises(Exception) as e: + with self.assertRaisesRegexp(ValueError, "invalid literal for int"): device.check_valid("/job:j/task:bar") - self.assertTrue("invalid literal for int" in str(e.exception)) - with self.assertRaises(Exception) as e: + with self.assertRaisesRegexp(ValueError, "Unknown attribute: 'bar'"): device.check_valid("/bar:muu/baz:2") - self.assertTrue("Unknown attribute: 'bar'" in str(e.exception)) - with self.assertRaises(Exception) as e: + with self.assertRaisesRegexp(ValueError, "Cannot specify multiple device"): device.check_valid("/cpu:0/device:GPU:2") - self.assertTrue("Cannot specify multiple device" in str(e.exception)) if __name__ == "__main__": diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 8b4086f573f..d3fc9528dad 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -81,9 +81,15 @@ class _UserDeviceSpec(object): def __init__(self, device_name_or_function): self._device_name_or_function = device_name_or_function - self.display_name = str(self._device_name_or_function) - if callable(self._device_name_or_function): + self.function = device_name_or_function + self.raw_string = None + + if isinstance(device_name_or_function, pydev.MergeDevice): + self.is_null_merge = device_name_or_function.is_null_merge + + elif callable(device_name_or_function): + self.is_null_merge = False dev_func = self._device_name_or_function 
func_name = function_utils.get_func_name(dev_func) func_code = function_utils.get_func_code(dev_func) @@ -95,13 +101,26 @@ class _UserDeviceSpec(object): lineno = -1 self.display_name = "%s<%s, %d>" % (func_name, fname, lineno) - self.raw_string = None + elif device_name_or_function is None: + # NOTE(taylorrobie): This MUST be False. None signals a break in the + # device stack, so `is_null_merge` must be False for such a case to + # allow callers to safely skip over null merges without missing a None. + self.is_null_merge = False - self.function = self._device_name_or_function - if not (self._device_name_or_function is None or - callable(self._device_name_or_function)): - self.raw_string = self._device_name_or_function - self.function = pydev.merge_device(self._device_name_or_function) + else: + self.raw_string = device_name_or_function + self.function = pydev.merge_device(device_name_or_function) + self.is_null_merge = self.function.is_null_merge + + # We perform this check in __init__ because it is of non-trivial cost, + # and self.string_merge is typically called many times. + self.fast_string_merge = isinstance(self.function, pydev.MergeDevice) + + def string_merge(self, node_def): + if self.fast_string_merge: + return self.function.shortcut_string_merge(node_def) + + return compat.as_str(_device_string(self.function(node_def))) class NullContextmanager(object): @@ -617,7 +636,9 @@ class Tensor(_TensorLike): def __eq__(self, other): # Necessary to support Python's collection membership operators - return id(self) == id(other) + + # NOTE(taylorrobie): equivalent to: id(self) == id(other) + return self is other def __copy__(self): # TODO(b/77597810): get rid of Tensor copies. 
@@ -1743,7 +1764,7 @@ IndexedSlicesValue = collections.namedtuple( def _device_string(dev_spec): - if isinstance(dev_spec, pydev.DeviceSpec): + if pydev.is_device_spec(dev_spec): return dev_spec.to_string() else: return dev_spec @@ -2224,10 +2245,22 @@ class Operation(object): Args: device: string or device.. The device to set. """ + self._set_device_from_string(compat.as_str(_device_string(device))) + + def _set_device_from_string(self, device_str): + """Fast path to set device if the type is known to be a string. + + This function is called frequently enough during graph construction that + there are non-trivial performance gains if the caller can guarantee that + the specified device is already a string. + + Args: + device_str: A string specifying where to place this op. + """ c_api.SetRequestedDevice( self._graph._c_graph, # pylint: disable=protected-access self._c_op, # pylint: disable=protected-access - compat.as_str(_device_string(device))) + device_str) def _update_input(self, index, tensor): """Update the input to this operation at the given index. @@ -4493,10 +4526,21 @@ class Graph(object): # We apply here because the result can depend on the Operation's # signature, which is computed in the Operation constructor. # pylint: disable=protected-access + prior_device_string = None for device_spec in self._device_function_stack.peek_objs(): + if device_spec.is_null_merge: + continue + if device_spec.function is None: break - op._set_device(device_spec.function(op)) + + device_string = device_spec.string_merge(op) + + # Take advantage of the fact that None is a singleton and Python interns + # strings, since identity checks are faster than equality checks. 
+ if device_string is not prior_device_string: + op._set_device_from_string(device_string) + prior_device_string = device_string op._device_code_locations = self._snapshot_device_function_stack_metadata() # pylint: enable=protected-access diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 5e2e3d4f324..6b9070da740 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -40,6 +40,7 @@ from tensorflow.python.eager import function as eager_function from tensorflow.python.eager import lift_to_graph from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as tfdev from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops @@ -532,8 +533,13 @@ class _TfDeviceCaptureOp(object): def _set_device(self, device): """This method captures TF's explicit device scope setting.""" + if tfdev.is_device_spec(device): + device = device.to_string() self.device = device + def _set_device_from_string(self, device_str): + self.device = device_str + def _get_current_tf_device(): """Return explicit device of current context, otherwise returns `None`. 
@@ -546,7 +552,7 @@ def _get_current_tf_device(): graph = get_graph() op = _TfDeviceCaptureOp() graph._apply_device_functions(op) - return op.device + return tfdev.DeviceSpec.from_string(op.device) def _is_current_explicit_device(device_type): diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index d5d7ea266da..208eb728eec 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -291,6 +291,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): else: self._device = device + def _set_device_from_string(self, device_str): + self._device = device_str + if self._outside_compilation_cluster: raise NotImplementedError("Cannot nest outside_compilation clusters") if cluster: diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 5874a1ff415..1f946796d27 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -120,13 +120,13 @@ class _ReplicaDeviceChooser(object): current_job, ps_job = current_device.job, ps_device.job if ps_job and (not current_job or current_job == ps_job): - ps_device.task = self._ps_strategy(op) + ps_device = ps_device.replace(task=self._ps_strategy(op)) - ps_device.merge_from(current_device) + ps_device = ps_device.make_merged_spec(current_device) return ps_device.to_string() worker_device = pydev.DeviceSpec.from_string(self._worker_device or "") - worker_device.merge_from(current_device) + worker_device = worker_device.make_merged_spec(current_device) return worker_device.to_string() diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index 9c6be690a86..9be3e2ce3de 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -50,8 +50,7 @@ def set_cpu0(device_string): A device string. 
""" parsed_device = pydev.DeviceSpec.from_string(device_string) - parsed_device.device_type = "CPU" - parsed_device.device_index = 0 + parsed_device = parsed_device.replace(device_type="CPU", device_index=0) return parsed_device.to_string() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt index ab442d123c4..f8d3cb4e2e9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.DeviceSpec" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "device_index" @@ -28,7 +29,11 @@ tf_class { } member_method { name: "from_string" - argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_merged_spec" + argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None" } member_method { name: "merge_from" @@ -38,6 +43,10 @@ tf_class { name: "parse_from_string" argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "replace" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } member_method { name: "to_string" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt new file mode 100644 index 00000000000..49ae20a33b2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.DeviceSpec" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "device_index" + mtype: "" + } + member { + name: "device_type" + mtype: "" + } + member { + name: "job" + mtype: "" + } + member { + name: "replica" + 
mtype: "" + } + member { + name: "task" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_string" + argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_merged_spec" + argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "parse_from_string" + argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "replace" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_string" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 6e7cc51a285..f8d86223b12 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "DType" mtype: "" } + member { + name: "DeviceSpec" + mtype: "" + } member { name: "GradientTape" mtype: ""