Remove unused symbols in vars_test.

PiperOrigin-RevId: 324686959
Change-Id: If5a8d2ccf6d4baa4e1f19d83a2d54d359c6e6514
This commit is contained in:
Anjali Sridhar 2020-08-03 14:43:36 -07:00 committed by TensorFlower Gardener
parent 0d85fa03ef
commit d84acd6e45

View File

@ -26,7 +26,6 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
@ -664,26 +663,6 @@ class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
def _make_replica_local(method, strategy=None):
if strategy is None:
devices = ("/device:GPU:0", "/device:CPU:0")
else:
devices = strategy.extended.worker_devices
v = []
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
with ops.device(d):
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
var_cls = tpu_values.TPUSyncOnReadVariable
else:
var_cls = values.SyncOnReadVariable
replica_local = var_cls(strategy, v, method)
return v, replica_local
class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_and_run_tf_function_combinations())
@ -1258,12 +1237,5 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
def _make_index_slices(vals, indices, dense_shape=None):
if dense_shape:
dense_shape = array_ops.identity(dense_shape)
return indexed_slices.IndexedSlices(
array_ops.identity(vals), array_ops.identity(indices), dense_shape)
if __name__ == "__main__":
test.main()