Fix an error which results in two different strategy instances being used (one is used for creating variables under a given scope and the second is used to call run). Add a second regex assertion since we will only have one kind of tf.distribute variable (DistributedVariable) going forward.

PiperOrigin-RevId: 322671804
Change-Id: Ie9e992669e44486f3217845221577b0535516184
Anjali Sridhar 2020-07-22 15:45:59 -07:00 committed by TensorFlower Gardener
parent 39e13608cb
commit ff95948ceb
2 changed files with 15 additions and 7 deletions
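For context, a minimal sketch (not part of the commit) of the pattern the tests are moved to: one MirroredStrategy instance is created up front and used both for the variable-creating scope and for the later call to run(), instead of constructing a throwaway instance inline just to open the scope. The logical-device setup, device names, and the step function below are illustrative assumptions, not code from this change.

import tensorflow as tf

# Assumption: split the single physical CPU into two logical devices so a
# MirroredStrategy can place one replica on each (the real tests configure
# their devices in test setup instead).
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])

# One strategy instance owns both the variable-creating scope and run().
strategy = tf.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
with strategy.scope():
  v = tf.Variable(1.0)

@tf.function
def step():
  # Reads the distributed variable on each replica.
  return v * 2.0

per_replica = strategy.run(step)
print(strategy.experimental_local_results(per_replica))  # one value per replica

# The bug being fixed in the tests: opening the scope via
#   with tf.distribute.MirroredStrategy([...]).scope(): ...
# and then calling run() on a second MirroredStrategy instance means the
# strategy that created the variables is not the one executing the step.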


@@ -432,13 +432,21 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
   )
   def test_repr_distributed(self):
-    with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
+    strategy = mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2'])
+    with strategy.scope():
       x = get_var(1., dtypes.float32)
       x = autocast_variable.create_autocast_variable(x)
-      self.assertRegex(
-          repr(x).replace('\n', ' '),
-          '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
-          'inner_variable=MirroredVariable.*>')
+      use_policy = getattr(strategy.extended, '_use_policy', False)
+      if use_policy:
+        self.assertRegex(
+            repr(x).replace('\n', ' '),
+            '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
+            'inner_variable=DistributedVariable.*>')
+      else:
+        self.assertRegex(
+            repr(x).replace('\n', ' '),
+            '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
+            'inner_variable=MirroredVariable.*>')
 
   @parameterized.named_parameters(
       ('v1', gradient_descent_v1.GradientDescentOptimizer),


@@ -114,7 +114,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
       with lsgt.LossScaleGradientTape(loss_scale) as g:
         y = x * x
       return g.gradient(y, x, output_gradients=constant_op.constant(2.0))
-    dy_dx_list = self._run_with_strategy(run_fn, strategy_fn(), use_tf_function)
+    dy_dx_list = self._run_with_strategy(run_fn, strategy, use_tf_function)
     self.assertEqual(loss_scale(), 32)
     for dy_dx in dy_dx_list:
       self.assertEqual(dy_dx, 12.0)
@@ -236,7 +236,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
         dy_dx = g.gradient(y, x)
       return dz_dx, dy_dx
-    dz_dx_list, dy_dx_list = self._run_with_strategy(run_fn, strategy_fn(),
+    dz_dx_list, dy_dx_list = self._run_with_strategy(run_fn, strategy,
                                                      use_tf_function)
     for dz_dx in dz_dx_list:
       self.assertEqual(dz_dx, 108.0)