From 604988b5d4e8cec6564db6502e6e40eefac8fc67 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Wed, 3 Jul 2019 18:00:44 -0700
Subject: [PATCH] Add operator overloads to AutoCastVariable.

The code was copied from DistributionStrategy at
https://github.com/tensorflow/tensorflow/blob/81acfa851ecf413df02c6bdf4795630524f2f859/tensorflow/python/distribute/values.py#L401
with slight modifications.

PiperOrigin-RevId: 256469842
---
 .../experimental/autocast_variable.py      | 57 ++++++++++++++++-
 .../experimental/autocast_variable_test.py | 62 ++++++++++++-------
 2 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index 53ea015222c..e968594ef08 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -148,8 +148,63 @@ class AutoCastVariable(trackable.Trackable):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
 
-  # TODO(reedwm): Define operator overloads.
+  # Operator overloads:
+  # Note we only overload operators that support floating-point types, as
+  # non-float variables cannot be wrapped with an AutoCastVariable.
+  # pylint: disable=multiple-statements
+  def __add__(self, o): return self.value() + o
+  def __radd__(self, o): return o + self.value()
+  def __sub__(self, o): return self.value() - o
+  def __rsub__(self, o): return o - self.value()
+  def __mul__(self, o): return self.value() * o
+  def __rmul__(self, o): return o * self.value()
+  def __truediv__(self, o): return self.value() / o
+  def __rtruediv__(self, o): return o / self.value()
+  def __floordiv__(self, o): return self.value() // o
+
+  def __rfloordiv__(self, o): return o // self.value()
+  def __mod__(self, o): return self.value() % o
+  def __rmod__(self, o): return o % self.value()
+  def __lt__(self, o): return self.value() < o
+  def __le__(self, o): return self.value() <= o
+  def __gt__(self, o): return self.value() > o
+  def __ge__(self, o): return self.value() >= o
+  def __getitem__(self, o): return self.value()[o]
+  def __pow__(self, o, modulo=None): return pow(self.value(), o, modulo)
+  def __rpow__(self, o): return pow(o, self.value())
+  def __neg__(self): return -self.value()
+  def __abs__(self): return abs(self.value())
+
+  def __div__(self, o):
+    try:
+      return self.value().__div__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rdiv__(self, o):
+    try:
+      return self.value().__rdiv__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __matmul__(self, o):
+    try:
+      return self.value().__matmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rmatmul__(self, o):
+    try:
+      return self.value().__rmatmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  # pylint: enable=multiple-statements
 
 
 ops.register_tensor_conversion_function(
     AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
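Note on the pattern above: every one-line overload delegates to value(), which
reads the variable and casts it to the read dtype, so the returned tensor's own
operator does the actual work. __div__, __rdiv__, __matmul__ and __rmatmul__
instead go through try/except, since the wrapped value may not define those
methods (__div__/__rdiv__ exist only under Python 2 division, __matmul__ only
on Python 3.5+); returning NotImplemented lets Python fall back to the other
operand. A minimal runnable sketch of the same pattern, with a hypothetical
_Wrapper class standing in for AutoCastVariable:

class _Wrapper(object):
  """Toy stand-in for AutoCastVariable: wraps a value, delegates operators."""

  def __init__(self, v):
    self._v = v

  def value(self):
    # The real AutoCastVariable would cast to its read dtype here.
    return self._v

  # One-line delegating overloads, mirroring the patch.
  def __add__(self, o): return self.value() + o
  def __radd__(self, o): return o + self.value()

  def __matmul__(self, o):
    # The wrapped value may lack __matmul__ (e.g. a float, or Python < 3.5),
    # so defer to the other operand by returning NotImplemented.
    try:
      return self.value().__matmul__(o)
    except AttributeError:
      return NotImplemented


w = _Wrapper(7.0)
assert w + 1 == 8.0   # __add__ delegates: value() + 1
assert 3 + w == 10.0  # __radd__ handles the reflected case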
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
index ed31d202d39..850691c86b4 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
@@ -97,30 +97,46 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(*TESTCASES)
   def test_operator_overloads(self, distribute):
     with get_distribute_scope(distribute):
-      x = get_var(1., dtypes.float32)
-      x = get_autocast_var(x, distribute)
-      self.evaluate(x.initializer)
+      for read_dtype in (dtypes.float32, dtypes.float16):
+        x = get_var(7., dtypes.float32)
+        x = get_autocast_var(x, distribute)
+        x._read_dtype = read_dtype
+        self.evaluate(x.initializer)
+        self.assertAlmostEqual(8, self.evaluate(x + 1))
+        self.assertAlmostEqual(10, self.evaluate(3 + x))
+        self.assertAlmostEqual(14, self.evaluate(x + x))
+        self.assertAlmostEqual(5, self.evaluate(x - 2))
+        self.assertAlmostEqual(6, self.evaluate(13 - x))
+        self.assertAlmostEqual(0, self.evaluate(x - x))
+        self.assertAlmostEqual(14, self.evaluate(x * 2))
+        self.assertAlmostEqual(21, self.evaluate(3 * x))
+        self.assertAlmostEqual(49, self.evaluate(x * x))
+        self.assertAlmostEqual(3.5, self.evaluate(x / 2))
+        self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
+        self.assertAlmostEqual(3, self.evaluate(x // 2))
+        self.assertAlmostEqual(2, self.evaluate(15 // x))
+        if read_dtype == dtypes.float32:
+          # The "mod" operator does not support float16
+          self.assertAlmostEqual(1, self.evaluate(x % 2))
+          self.assertAlmostEqual(2, self.evaluate(16 % x))
+        self.assertTrue(self.evaluate(x < 12))
+        self.assertTrue(self.evaluate(x <= 12))
+        self.assertFalse(self.evaluate(x > 12))
+        self.assertFalse(self.evaluate(x >= 12))
+        self.assertFalse(self.evaluate(12 < x))
+        self.assertFalse(self.evaluate(12 <= x))
+        self.assertTrue(self.evaluate(12 > x))
+        self.assertTrue(self.evaluate(12 >= x))
+        self.assertAlmostEqual(343, self.evaluate(pow(x, 3)), places=4)
+        self.assertAlmostEqual(128, self.evaluate(pow(2, x)), places=4)
+        self.assertAlmostEqual(-7, self.evaluate(-x))
+        self.assertAlmostEqual(7, self.evaluate(abs(x)))
-
-      v1 = constant_op.constant(2., dtype=dtypes.float32)
-      v2 = constant_op.constant(2., dtype=dtypes.float16)
-
-      # Because autocast variables do not yet define operator overloads, the
-      # operator is defined by the non-variable tensor
-
-      # Test variable as the LHS. Currently, this is not supported with
-      # distributed autocast variables
-      if not distribute:
-        self.assertEqual(self.evaluate(x + v1), 3.)
-
-        x._read_dtype = dtypes.float16
-        self.assertEqual(self.evaluate(x + v2), 3.)
-
-      # Test variable as the RHS
-      x._read_dtype = dtypes.float32
-      self.assertEqual(self.evaluate(v1 + x), 3.)
-
-      x._read_dtype = dtypes.float16
-      self.assertEqual(self.evaluate(v2 + x), 3.)
+        x = get_var([7, 8, 9], dtypes.float32)
+        x = get_autocast_var(x, distribute)
+        x._read_dtype = read_dtype
+        self.evaluate(x.initializer)
+        self.assertEqual(self.evaluate(x[1]), 8)
 
   @parameterized.named_parameters(*TESTCASES)
   def test_assign(self, distribute):
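Usage note: with these overloads in place, an AutoCastVariable can appear on
either side of an arithmetic operator, and the result carries the variable's
read dtype. The sketch below mirrors what the updated test exercises; it is
illustrative only, since constructing AutoCastVariable directly and overriding
_read_dtype are internal, experimental details that may change:

from tensorflow.python.framework import dtypes
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import variables

v = variables.Variable(7., dtype=dtypes.float32)
x = autocast_variable.AutoCastVariable(v)
x._read_dtype = dtypes.float16  # test-style override of the read dtype

y = x + 1   # __add__: value() casts the read value to float16, then adds
z = 3 * x   # __rmul__: the reflected form now works as well
print(y.dtype)  # expected: float16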