Update loss_scaling_gradient_tape_test to not rely on Keras.

PiperOrigin-RevId: 316000157
Change-Id: Ibcb3571498c26535c0505da328ab7ab0e5d8619f
This commit is contained in:
Scott Zhu 2020-06-11 16:22:46 -07:00 committed by TensorFlower Gardener
parent bf81b60578
commit 8b997d655d

View File

@ -27,7 +27,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_combinations
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@ -133,28 +132,24 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
with strategy.scope():
x1 = variables.Variable(1.0) # Distributed variable
x2 = variables.Variable([1.0, 2.0]) # Distributed non-scalar variable
# Distributed AutoCastVariable
x3 = autocast_variable.create_autocast_variable(variables.Variable(2.0))
x4 = variables.Variable(2.0) # Non-distributed variable
x5 = constant_op.constant(2.0) # Tensor
x3 = variables.Variable(2.0) # Non-distributed variable
x4 = constant_op.constant(2.0) # Tensor
def run_fn():
with lsgt.LossScaleGradientTape(loss_scale) as g:
g.watch(x5)
y = x1 * x2 * x3 * x4 * x5
return g.gradient(y, [x1, x2, x3, x4, x5])
x1g, x2g, x3g, x4g, x5g = self._run_with_strategy(run_fn, strategy,
use_tf_function)
g.watch(x4)
y = x1 * x2 * x3 * x4
return g.gradient(y, [x1, x2, x3, x4])
x1g, x2g, x3g, x4g = self._run_with_strategy(run_fn, strategy,
use_tf_function)
self.assertEqual(loss_scale(), 32)
for dy_dx1 in x1g:
self.assertEqual(dy_dx1, 24.0)
self.assertEqual(dy_dx1, 12.0)
for dy_dx2 in x2g:
self.assertAllEqual(dy_dx2, [8.0, 8.0])
self.assertAllEqual(dy_dx2, [4.0, 4.0])
for dy_dx3 in x3g:
self.assertEqual(dy_dx3, 12.0)
self.assertEqual(dy_dx3, 6.0)
for dy_dx4 in x4g:
self.assertEqual(dy_dx4, 12.0)
for dy_dx5 in x5g:
self.assertEqual(dy_dx5, 12.0)
self.assertEqual(dy_dx4, 6.0)
@test_combinations.generate(test_combinations.combine(
loss_scale=[loss_scale_module.FixedLossScale,