Fix error which results in 2 different strategy instances being used(one is used for creating variables under a given scope and the second is used to call run
. Add a second regex assertion since we will only have one kind of tf.distribute variables (DistributedVariable) going forward.
PiperOrigin-RevId: 322671804 Change-Id: Ie9e992669e44486f3217845221577b0535516184
This commit is contained in:
parent
39e13608cb
commit
ff95948ceb
@ -432,13 +432,21 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
)
|
||||
|
||||
def test_repr_distributed(self):
|
||||
with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
|
||||
strategy = mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2'])
|
||||
with strategy.scope():
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = autocast_variable.create_autocast_variable(x)
|
||||
self.assertRegex(
|
||||
repr(x).replace('\n', ' '),
|
||||
'<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
|
||||
'inner_variable=MirroredVariable.*>')
|
||||
use_policy = getattr(strategy.extended, '_use_policy', False)
|
||||
if use_policy:
|
||||
self.assertRegex(
|
||||
repr(x).replace('\n', ' '),
|
||||
'<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
|
||||
'inner_variable=DistributedVariable.*>')
|
||||
else:
|
||||
self.assertRegex(
|
||||
repr(x).replace('\n', ' '),
|
||||
'<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
|
||||
'inner_variable=MirroredVariable.*>')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('v1', gradient_descent_v1.GradientDescentOptimizer),
|
||||
|
@ -114,7 +114,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
|
||||
with lsgt.LossScaleGradientTape(loss_scale) as g:
|
||||
y = x * x
|
||||
return g.gradient(y, x, output_gradients=constant_op.constant(2.0))
|
||||
dy_dx_list = self._run_with_strategy(run_fn, strategy_fn(), use_tf_function)
|
||||
dy_dx_list = self._run_with_strategy(run_fn, strategy, use_tf_function)
|
||||
self.assertEqual(loss_scale(), 32)
|
||||
for dy_dx in dy_dx_list:
|
||||
self.assertEqual(dy_dx, 12.0)
|
||||
@ -236,7 +236,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
|
||||
dy_dx = g.gradient(y, x)
|
||||
return dz_dx, dy_dx
|
||||
|
||||
dz_dx_list, dy_dx_list = self._run_with_strategy(run_fn, strategy_fn(),
|
||||
dz_dx_list, dy_dx_list = self._run_with_strategy(run_fn, strategy,
|
||||
use_tf_function)
|
||||
for dz_dx in dz_dx_list:
|
||||
self.assertEqual(dz_dx, 108.0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user