Introduce PackedDistributedVariable, which packs multiple variables distributed across devices.
Introduce PackedVarAndDevice, which represents a packed variable on a given device.
PiperOrigin-RevId: 315769635
Change-Id: Ia63b72610afeb7139bd8370bc47067a1fb165307
This commit is contained in:
parent
cc8505eb36
commit
f5547e8125
@ -675,12 +675,25 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Library for packed_distributed_variable.py, which defines
# PackedDistributedVariable and PackedVarAndDevice.
py_library(
    name = "packed_distributed_variable",
    srcs = ["packed_distributed_variable.py"],
    deps = [
        ":device_util",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python/eager:context",
    ],
)
|
||||
|
||||
py_library(
|
||||
name = "values",
|
||||
srcs = ["values.py"],
|
||||
deps = [
|
||||
":device_util",
|
||||
":distribute_lib",
|
||||
":packed_distributed_variable",
|
||||
":reduce_util",
|
||||
":values_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -1710,6 +1723,24 @@ py_library(
|
||||
deps = ["@six_archive//:six"],
|
||||
)
|
||||
|
||||
# Unit tests for packed_distributed_variable.py.
# NOTE(review): the test file also imports tensorflow.python.framework.config,
# which is not listed explicitly in deps below -- confirm it is reachable
# through one of the listed targets.
py_test(
    name = "packed_distributed_variable_test",
    srcs = ["packed_distributed_variable_test.py"],
    tags = [
        "nomac",  # TODO(b/145922293): It would cause a Python segfault on macos
    ],
    deps = [
        ":device_util",
        ":packed_distributed_variable",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
    ],
)
|
||||
|
||||
py_test(
|
||||
name = "multi_process_runner_test",
|
||||
srcs = ["multi_process_runner_test.py"],
|
||||
|
338
tensorflow/python/distribute/packed_distributed_variable.py
Normal file
338
tensorflow/python/distribute/packed_distributed_variable.py
Normal file
@ -0,0 +1,338 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A variable which packs a list of variables distributed across devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
|
||||
|
||||
class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    The trainable flag, shape, dtype, synchronization, constraint, aggregation
    and distribution strategy of the packed variable are all taken from the
    first variable in `distributed_variables`.

    Args:
      distributed_variables: A non-empty list (or other iterable) of resource
        variables to pack. The iterable is copied into a private list, so
        later changes to the caller's container do not affect this object.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      **unused_kwargs: Accepted and ignored.

    Raises:
      ValueError: If not executing eagerly, if `distributed_variables` is
        empty, or if any of its elements is not a resource variable.
    """
    if not context.executing_eagerly():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    # The variables are iterated several times below (validation, devices,
    # handles). Materialize them once so that a one-shot iterable such as a
    # generator is not silently exhausted after the first pass, and so that
    # `_devices` and `_distributed_variables` always stay index-aligned even
    # if the caller mutates its own list afterwards.
    distributed_variables = list(distributed_variables or ())
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    # Device of each packed variable, in the same order as the variables.
    self._devices = [v.device for v in distributed_variables]
    first_var = distributed_variables[0]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        # A single packed handle built from every component's handle.
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=first_var.trainable,
            shape=first_var.shape,
            dtype=first_var.dtype,
            handle=handle,
            synchronization=first_var.synchronization,
            constraint=first_var.constraint,
            aggregation=first_var.aggregation,
            distribute_strategy=first_var._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    """The devices of the packed variables, in packing order."""
    return self._devices

  def get_var_on_device(self, device):
    """Returns the packed component whose device equals `device`.

    Args:
      device: Device string; compared for exact equality against the `device`
        attribute recorded for each packed variable.

    Raises:
      ValueError: If no packed variable was recorded on `device`.
    """
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    """Returns the packed component on the canonicalized current device."""
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    """The handle stored on this object as `_handle`."""
    # NOTE(review): `_handle` is presumably set by the base-class constructor
    # from the packed handle passed in `__init__` -- confirm.
    return self._handle

  def _read_variable_op(self):
    """Reads the current-device component eagerly, else defers to the base."""
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    """Returns the result of `_read_variable_op()`."""
    return self._read_variable_op()

  def is_initialized(self, name=None):
    """Returns the logical AND of the initialization state on every device.

    When executing eagerly, each packed component is queried directly.
    Otherwise the base-class check is issued once under each recorded device
    scope. In both cases `name` is applied to the final `logical_and` op.

    Args:
      name: Optional name for the resulting op.
    """
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      # The last component is combined separately so the final op gets `name`.
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
      result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    """Applies `update_fn` to the right target for the execution mode.

    Eagerly, the target is the component on the current device; otherwise it
    is the base-class view of this packed variable.

    Args:
      update_fn: Callable invoked as `update_fn(var, value, **kwargs)`.
      value: The value forwarded to `update_fn`.
      **kwargs: Extra keyword arguments forwarded to `update_fn`.
    """
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Dispatches `assign_sub` through `_update`."""
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Dispatches `assign_add` through `_update`."""
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Dispatches `assign` through `_update`."""
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_sub` through `_update`."""
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_add` through `_update`."""
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_mul` through `_update`."""
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_div` through `_update`."""
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_min` through `_update`."""
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_max` through `_update`."""
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_update` through `_update`."""
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts the current-device component eagerly, else defers to the base."""
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
|
||||
|
||||
|
||||
class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device.

  Most methods enter `ops.device(device)` and then forward the call to the
  wrapped variable. Attributes not defined here are looked up on the wrapped
  variable.
  """

  def __init__(self, var, device):
    """Stores the wrapped variable and the device to scope calls to.

    Args:
      var: The packed variable to wrap.
      device: The device passed to `ops.device` around forwarded calls.
    """
    self._var = var
    self._device = device

  def __getattr__(self, name):
    """Forwards lookups of attributes not found on this object to the var."""
    # `__getattr__` only runs when normal lookup fails. If `_var` itself is
    # missing (e.g. on an instance created without running `__init__`, as
    # `copy` and `pickle` do before probing for `__setstate__`), delegating
    # through `self._var` would re-enter this method without bound. Fail with
    # a regular AttributeError instead of a RecursionError.
    if name == "_var":
      raise AttributeError(name)
    return getattr(self._var, name)

  def var(self):
    """Returns the wrapped packed variable."""
    return self._var

  def value(self):
    """Returns `var.value()` evaluated under the held device scope."""
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    """Returns `var.read_value()` evaluated under the held device scope."""
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    """The wrapped variable's initial value for the held device."""
    return self._var.initial_value(self._device)

  def initialized_value(self):
    """Returns `var.initialized_value()` under the held device scope."""
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    """The device this wrapper scopes calls to."""
    return self._device

  @property
  def handle(self):
    """The wrapped variable's handle (no device scope is entered)."""
    return self._var.handle

  @property
  def op(self):
    """The wrapped variable's `op`, looked up under the held device scope."""
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Forwards `assign_sub` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Forwards `assign_add` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Forwards `assign` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_sub` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_add` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_mul` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_div` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_min` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_max` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Forwards `scatter_update` to the var under the held device scope."""
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Forwards tensor conversion to the var under the held device scope."""
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    """Returns the wrapped variable's graph element."""
    return self._var._as_graph_element()  # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _tensor_conversion_packed_var_and_device(var,
|
||||
dtype=None,
|
||||
name=None,
|
||||
as_ref=False):
|
||||
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
|
||||
|
||||
|
||||
# Register the converter above for PackedVarAndDevice so that instances can be
# converted to tensors through `_tensor_conversion_packed_var_and_device`.
ops.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
|
107
tensorflow/python/distribute/packed_distributed_variable_test.py
Normal file
107
tensorflow/python/distribute/packed_distributed_variable_test.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import packed_distributed_variable
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PackedDistributedVariableTest(test.TestCase):
  """Tests for PackedDistributedVariable and PackedVarAndDevice."""

  def setUp(self):
    """Configures two logical CPU devices on the first physical CPU."""
    super(PackedDistributedVariableTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testPackedVariable(self):
    """Checks per-device reads and updates, eagerly and inside a function."""
    with ops.device('/cpu:0'):
      v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
    with ops.device('/cpu:1'):
      v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    self.assertTrue(packed_var.handle.is_packed)
    # `is_initialized` is a method: it has to be called, otherwise the
    # assertion only checks the truthiness of the bound method object and can
    # never fail.
    self.assertTrue(packed_var.is_initialized())

    with ops.device('/cpu:0'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v0)
      val0 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val0, 3.0)

    with ops.device('/cpu:1'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v1)
      val1 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val1, 3.0)

    @def_function.function
    def update_var():
      # Both components hold 3.0 on entry (see the eager updates above).
      with ops.device('/cpu:0'):
        packed_var.assign_add(3.0).assign_sub(1.0)
        read0 = packed_var.value()
      with ops.device('/cpu:1'):
        packed_var.assign_sub(4.0).assign_sub(2.0)
        read1 = packed_var.value()

      return read0, read1

    self.assertAllEqual(update_var(), (5.0, -3.0))

  def testPackedVarAndDevice(self):
    """Checks that PackedVarAndDevice targets the device it was given."""
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])

    packed_var0 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device0)
    self.assertTrue(packed_var0.handle.is_packed)
    self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)

    packed_var1 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device1)
    self.assertAllEqual(packed_var1.assign(3.0), 3.0)

    @def_function.function
    def func():
      # On entry the device0 component is 1.0 and the device1 component is 3.0.
      var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
      var0.assign_add(3.0)
      var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
      return var0.value(), math_ops.add(var1, 2.0)

    self.assertAllEqual(func(), (4.0, 5.0))
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Eager execution is enabled before the tests run; the
  # PackedDistributedVariable constructor raises ValueError when not executing
  # eagerly.
  ops.enable_eager_execution()
  test.main()
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import packed_distributed_variable as packed
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -419,6 +420,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
self._aggregation = aggregation
|
||||
super(DistributedVariable, self).__init__(values)
|
||||
self._common_name = self._primary.name.split(":")[0]
|
||||
|
||||
# Packed variable is used to reduce the overhead of function execution.
|
||||
# For a DistributedVariable, only one variable handle is captured into a
|
||||
# function graph. It's only supported in eager mode.
|
||||
if ops.executing_eagerly_outside_functions() and getattr(
|
||||
strategy, "_enable_packed_variable_in_eager_mode", False):
|
||||
name = "%s/packed/" % self._common_name
|
||||
self._packed_var = packed.PackedDistributedVariable(values, name=name)
|
||||
else:
|
||||
self._packed_var = None
|
||||
|
||||
# tf.keras keeps track of variables initialized using this attribute. When
|
||||
# tf.keras gets the default session, it initializes all uninitialized vars.
|
||||
# We need to make _keras_initialized a member of DistributedVariable because
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import packed_distributed_variable as packed
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.distribute import tpu_values
|
||||
@ -40,6 +41,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
@ -703,6 +705,27 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(distribution.run(assign)))
|
||||
|
||||
def testPackedVariable(self, distribution, synchronization, aggregation):
  """Checks that `_packed_var` is only created when the strategy opts in."""
  with distribution.scope():
    v0 = variables_lib.Variable(
        0., synchronization=synchronization, aggregation=aggregation)
  # assertIsNone is an identity check, so it does not depend on how `==` is
  # defined for the object under test.
  self.assertIsNone(v0._packed_var)

  device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type
  for d in v0._devices:
    if device.DeviceSpec.from_string(d).device_type != device_type:
      self.skipTest("Packing variables on devices of different types "
                    "is not supported yet.")

  # Restore the flag once the test finishes so that the opt-in does not leak
  # into other tests in case the same strategy object is reused.
  flag_name = "_enable_packed_variable_in_eager_mode"
  self.addCleanup(setattr, distribution, flag_name,
                  getattr(distribution, flag_name, False))
  distribution._enable_packed_variable_in_eager_mode = True
  with distribution.scope():
    v1 = variables_lib.Variable(
        0., synchronization=synchronization, aggregation=aggregation)
  if ops.executing_eagerly_outside_functions():
    self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
  else:
    self.assertIsNone(v1._packed_var)
|
||||
|
||||
|
||||
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -281,6 +281,9 @@ class EagerResourceDeleter(object):
|
||||
# Resources follow object-identity when executing eagerly, so it is safe to
|
||||
# delete the resource we have a handle to.
|
||||
try:
|
||||
# A packed EagerTensor doesn't own any resource.
|
||||
if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
|
||||
return
|
||||
# This resource was created in eager mode. However, this destructor may be
|
||||
# running in graph mode (especially during unit tests). To clean up
|
||||
# successfully, we switch back into eager mode temporarily.
|
||||
|
Loading…
Reference in New Issue
Block a user