Rename import to avoid conflict with parameters.
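
The `values` module used to be imported under the same bare name that some helpers want for a parameter; with the module imported as `values_lib`, a parameter can be called `values` without shadowing the module (see the `_make_index_slices` hunk at the end of this change, where its `vals` parameter is renamed back to `values`). A minimal, self-contained sketch of the shadowing problem and the alias fix, using hypothetical stand-in names (`statistics` plays the role of the `values` module, `summarize*` are made-up helpers), not code from this change:

    import statistics  # stand-in for the module formerly imported as `values`

    def summarize(statistics):
      # The parameter shadows the module inside this function, so
      # `statistics.mean(...)` would look up an attribute on the argument.
      return len(statistics)

    # The fix applied in this change: alias the module so the bare name
    # stays free for parameters and locals.
    import statistics as statistics_lib

    def summarize_fixed(statistics):
      return statistics_lib.mean(statistics)  # module reachable via the alias

    print(summarize_fixed([1.0, 2.0, 3.0]))  # -> 2.0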
PiperOrigin-RevId: 318113928
Change-Id: I06da2eff0eacc8e5bedfd20cf7334c17bcab008f
parent 830b3c0eff
commit dc30240f53
@@ -36,7 +36,7 @@ from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import tpu_values
-from tensorflow.python.distribute import values
+from tensorflow.python.distribute import values as values_lib
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -69,7 +69,7 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
   def testGetEager(self):
     one = constant_op.constant(1)
     two = constant_op.constant(2)
-    v = values.DistributedValues((one, two))
+    v = values_lib.DistributedValues((one, two))
     self.assertEqual(one, v._get())
     with distribute_lib.ReplicaContext(None, 1):
       self.assertEqual(two, v._get())
@@ -78,7 +78,7 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
     with context.graph_mode(), ops.Graph().as_default():
       one = constant_op.constant(1)
       two = constant_op.constant(2)
-      v = values.DistributedValues((one, two))
+      v = values_lib.DistributedValues((one, two))
       self.assertEqual(one, v._get())
       with distribute_lib.ReplicaContext(None, 1):
         self.assertEqual(two, v._get())
@@ -291,14 +291,14 @@ class DistributedDelegateTest(test.TestCase):
       def __init__(self, x):
         self.x = x

-    v = values.DistributedDelegate((Foo(7), Foo(8)))
+    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
     self.assertEqual(7, v.x)
     with self.assertRaises(AttributeError):
       _ = v.y

   @test_util.run_in_graph_and_eager_modes
   def testOperatorOverride(self):
-    v = values.DistributedDelegate((7, 8))
+    v = values_lib.DistributedDelegate((7, 8))
     # v should act like int(7).
     self.assertEqual(8, v + 1)
     self.assertEqual(10, 3 + v)
@@ -348,7 +348,7 @@ class DistributedDelegateTest(test.TestCase):
       def __init__(self, x):
         self.x = x

-    v = values.DistributedDelegate((Foo(7), Foo(8)))
+    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
     v_shallow_copy = copy.copy(v)
     self.assertEqual(v.x, v_shallow_copy.x)
     v_deep_copy = copy.deepcopy(v)
@@ -369,7 +369,7 @@ def _make_mirrored_val(init_val=5.0):
   for d, _ in zip(devices, ["v", "v/replica"]):
     with ops.device(d):
       v.append(constant_op.constant(init_val))
-  return values.Mirrored(v)
+  return values_lib.Mirrored(v)


 def _make_mirrored():
@@ -379,7 +379,7 @@ def _make_mirrored():
     with ops.device(d):
       v.append(variable_scope.get_variable(
           name=n, initializer=init, use_resource=True))
-  mirrored = values.MirroredVariable(
+  mirrored = values_lib.MirroredVariable(
       None, v, variable_scope.VariableAggregation.SUM)
   return mirrored

@@ -396,7 +396,7 @@ def mirrored_and_tpu_strategy_combinations():

 class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

-  def _is_per_replica(self, result, expected, klass=values.PerReplica):
+  def _is_per_replica(self, result, expected, klass=values_lib.PerReplica):
     self.assertIsInstance(result, klass)
     for i, exp in enumerate(expected):
       self.assertEqual(exp, result.values[i])
@@ -443,21 +443,21 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
     # Normally a mirrored value would be the same across devices, but
     # for a test it is convenient to be able to tell the values apart.
     result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
-                                      values.Mirrored)
+                                      values_lib.Mirrored)
     self.assertIsInstance(result, tuple)
     self.assertLen(result, 3)
-    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
-    self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)
+    self._is_per_replica(result[0], ["a1", "a2"], values_lib.Mirrored)
+    self._is_per_replica(result[2], ["h1", "h2"], values_lib.Mirrored)

     self.assertIsInstance(result[1], list)
     self.assertLen(result[1], 3)
-    self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
-    self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)
+    self._is_per_replica(result[1][0], ["b1", "b2"], values_lib.Mirrored)
+    self._is_per_replica(result[1][2], ["g1", "g2"], values_lib.Mirrored)

     self.assertIsInstance(result[1][1], dict)
     self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
-    self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
-    self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
+    self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values_lib.Mirrored)
+    self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values_lib.Mirrored)

     # Also test that we can undo the merge using select_replica()
     self.assertEqual(_nested_value("1"),
@@ -474,8 +474,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
     result = distribute_utils.regroup([("1", "2"), ("3", "4")])
     self.assertIsInstance(result, tuple)
     self.assertLen(result, 2)
-    self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
-    self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
+    self._is_per_replica(result[0], ("1", "3"), values_lib.PerReplica)
+    self._is_per_replica(result[1], ("2", "4"), values_lib.PerReplica)

   @combinations.generate(
       combinations.combine(
@@ -785,7 +785,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   def testVariableOnAnotherDevice(self):
     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
-    mirrored = values.MirroredVariable(
+    mirrored = values_lib.MirroredVariable(
         None, (v,), variable_scope.VariableAggregation.MEAN)

     self.assertEqual(v.name, mirrored.name)
@@ -942,7 +942,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
       with ops.device("/device:GPU:0"):
         v = variable_scope.get_variable(
             name="v", initializer=1., use_resource=True)
-      mirrored = values.MirroredVariable(
+      mirrored = values_lib.MirroredVariable(
           distribution, (v,), variable_scope.VariableAggregation.MEAN)
       sess.run(variables_lib.global_variables_initializer())
       sess.run({"complicated": mirrored})
@@ -1451,7 +1451,7 @@ def _make_replica_local(method, strategy=None):
   if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
     var_cls = tpu_values.TPUSyncOnReadVariable
   else:
-    var_cls = values.SyncOnReadVariable
+    var_cls = values_lib.SyncOnReadVariable
   replica_local = var_cls(strategy, v, method)
   return v, replica_local

@@ -1483,7 +1483,7 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):

     v = variable_scope.get_variable(
         name="v", initializer=[1.], use_resource=True)
-    replica_local = values.SyncOnReadVariable(
+    replica_local = values_lib.SyncOnReadVariable(
         None, (v,), variable_scope.VariableAggregation.MEAN)
     self.assertEqual(2., self.evaluate(add1(replica_local)))

@@ -2035,7 +2035,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2053,7 +2053,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2071,7 +2071,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2089,7 +2089,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2107,7 +2107,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2125,7 +2125,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
         aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2143,7 +2143,7 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
        aggregation=aggregation)
     self.evaluate(v.initializer)

-    delta = values.PerReplica([
+    delta = values_lib.PerReplica([
         indexed_slices.IndexedSlices(
             values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
         indexed_slices.IndexedSlices(
@@ -2175,7 +2175,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   def testTypeSpec(self):
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values_lib.PerReplica(vals)

     spec = per_replica._type_spec
     self.assertEqual(spec._value_specs,
@@ -2183,7 +2183,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   def testTypeSpecRoundTrip(self):
     vals = (constant_op.constant(1.),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values_lib.PerReplica(vals)

     spec = per_replica._type_spec
     tensor_list = spec._to_components(per_replica)
@@ -2193,7 +2193,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   def testTypeSpecNest(self):
     vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
-    per_replica = values.PerReplica(vals)
+    per_replica = values_lib.PerReplica(vals)

     # Note: nest.map_structure exercises nest.flatten and
     # nest.pack_sequence_as.
@@ -2206,7 +2206,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   @test_util.run_in_graph_and_eager_modes
   def testIsGraphTensor(self):
-    per_replica = values.PerReplica((constant_op.constant(1.),))
+    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
     for t in nest.flatten(per_replica, expand_composites=True):
       self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())

@@ -2218,7 +2218,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       traces.append(None)  # Only happens on trace.
       return x

-    per_replica = values.PerReplica((constant_op.constant(1.),))
+    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

     # Trace once.
     f(per_replica)
@@ -2232,13 +2232,13 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
       per_replica = per_replica_spec._from_components(vals)

       output = f(per_replica)
-      self.assertIsInstance(output, values.PerReplica)
+      self.assertIsInstance(output, values_lib.PerReplica)
       self.assertAllEqual(output._values, per_replica._values)
       self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.

   def testFunctionCanReturnPerReplica(self):
     f = def_function.function(lambda x: x)
-    x = values.PerReplica((constant_op.constant(1.),))
+    x = values_lib.PerReplica((constant_op.constant(1.),))
     y = f(x)
     self.assertIsNot(x, y)
     nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
@@ -2246,8 +2246,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   @test_util.run_in_graph_and_eager_modes
   def testCondWithTensorValues(self):
-    per_replica_1 = values.PerReplica((constant_op.constant("a"),))
-    per_replica_2 = values.PerReplica((constant_op.constant(["b", "c"]),))
+    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
+    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
     condition = array_ops.placeholder_with_default(True, [])

     result = control_flow_ops.cond(
@@ -2258,8 +2258,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   @test_util.run_in_graph_and_eager_modes
   def testCondWithValuesConvertibleToTensor(self):
-    per_replica_1 = values.PerReplica(("a",))
-    per_replica_2 = values.PerReplica(("b",))
+    per_replica_1 = values_lib.PerReplica(("a",))
+    per_replica_2 = values_lib.PerReplica(("b",))
     condition = array_ops.placeholder_with_default(True, [])

     result = control_flow_ops.cond(
@@ -2270,8 +2270,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):

   @test_util.build_as_function_and_v1_graph
   def testCondWithValuesNotConvertibleToTensor(self):
-    per_replica_1 = values.PerReplica(({"a"},))
-    per_replica_2 = values.PerReplica(({"b", "c"},))
+    per_replica_1 = values_lib.PerReplica(({"a"},))
+    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
     condition = array_ops.placeholder(dtypes.bool, [])

     with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
@@ -2279,11 +2279,11 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
           condition, lambda: per_replica_1, lambda: per_replica_2)


-def _make_index_slices(vals, indices, dense_shape=None):
+def _make_index_slices(values, indices, dense_shape=None):
   if dense_shape:
     dense_shape = array_ops.identity(dense_shape)
   return indexed_slices.IndexedSlices(
-      array_ops.identity(vals), array_ops.identity(indices), dense_shape)
+      array_ops.identity(values), array_ops.identity(indices), dense_shape)


 if __name__ == "__main__":