Remove unused symbols in vars_test.
PiperOrigin-RevId: 324686959 Change-Id: If5a8d2ccf6d4baa4e1f19d83a2d54d359c6e6514
This commit is contained in:
parent
0d85fa03ef
commit
d84acd6e45
@ -26,7 +26,6 @@ from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.distribute import tpu_values
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
@ -664,26 +663,6 @@ class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
|
||||
|
||||
|
||||
def _make_replica_local(method, strategy=None):
|
||||
if strategy is None:
|
||||
devices = ("/device:GPU:0", "/device:CPU:0")
|
||||
else:
|
||||
devices = strategy.extended.worker_devices
|
||||
|
||||
v = []
|
||||
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
|
||||
with ops.device(d):
|
||||
v.append(variable_scope.get_variable(
|
||||
name=n, initializer=init, use_resource=True))
|
||||
|
||||
if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
|
||||
var_cls = tpu_values.TPUSyncOnReadVariable
|
||||
else:
|
||||
var_cls = values.SyncOnReadVariable
|
||||
replica_local = var_cls(strategy, v, method)
|
||||
return v, replica_local
|
||||
|
||||
|
||||
class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
@ -1258,12 +1237,5 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
|
||||
|
||||
|
||||
def _make_index_slices(vals, indices, dense_shape=None):
|
||||
if dense_shape:
|
||||
dense_shape = array_ops.identity(dense_shape)
|
||||
return indexed_slices.IndexedSlices(
|
||||
array_ops.identity(vals), array_ops.identity(indices), dense_shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user