Update loss_scaling_gradient_tape_test to not rely on Keras.
PiperOrigin-RevId: 316000157 Change-Id: Ibcb3571498c26535c0505da328ab7ab0e5d8619f
This commit is contained in:
parent
bf81b60578
commit
8b997d655d
@ -27,7 +27,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_combinations
|
||||
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -133,28 +132,24 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
|
||||
with strategy.scope():
|
||||
x1 = variables.Variable(1.0) # Distributed variable
|
||||
x2 = variables.Variable([1.0, 2.0]) # Distributed non-scalar variable
|
||||
# Distributed AutoCastVariable
|
||||
x3 = autocast_variable.create_autocast_variable(variables.Variable(2.0))
|
||||
x4 = variables.Variable(2.0) # Non-distributed variable
|
||||
x5 = constant_op.constant(2.0) # Tensor
|
||||
x3 = variables.Variable(2.0) # Non-distributed variable
|
||||
x4 = constant_op.constant(2.0) # Tensor
|
||||
def run_fn():
|
||||
with lsgt.LossScaleGradientTape(loss_scale) as g:
|
||||
g.watch(x5)
|
||||
y = x1 * x2 * x3 * x4 * x5
|
||||
return g.gradient(y, [x1, x2, x3, x4, x5])
|
||||
x1g, x2g, x3g, x4g, x5g = self._run_with_strategy(run_fn, strategy,
|
||||
use_tf_function)
|
||||
g.watch(x4)
|
||||
y = x1 * x2 * x3 * x4
|
||||
return g.gradient(y, [x1, x2, x3, x4])
|
||||
x1g, x2g, x3g, x4g = self._run_with_strategy(run_fn, strategy,
|
||||
use_tf_function)
|
||||
self.assertEqual(loss_scale(), 32)
|
||||
for dy_dx1 in x1g:
|
||||
self.assertEqual(dy_dx1, 24.0)
|
||||
self.assertEqual(dy_dx1, 12.0)
|
||||
for dy_dx2 in x2g:
|
||||
self.assertAllEqual(dy_dx2, [8.0, 8.0])
|
||||
self.assertAllEqual(dy_dx2, [4.0, 4.0])
|
||||
for dy_dx3 in x3g:
|
||||
self.assertEqual(dy_dx3, 12.0)
|
||||
self.assertEqual(dy_dx3, 6.0)
|
||||
for dy_dx4 in x4g:
|
||||
self.assertEqual(dy_dx4, 12.0)
|
||||
for dy_dx5 in x5g:
|
||||
self.assertEqual(dy_dx5, 12.0)
|
||||
self.assertEqual(dy_dx4, 6.0)
|
||||
|
||||
@test_combinations.generate(test_combinations.combine(
|
||||
loss_scale=[loss_scale_module.FixedLossScale,
|
||||
|
Loading…
Reference in New Issue
Block a user