Add operator overloads to AutoCastVariable.

The code was copied from DistributionStrategy at 81acfa851e/tensorflow/python/distribute/values.py (L401) with slight modifications.

PiperOrigin-RevId: 256469842
This commit is contained in:
Reed Wanderman-Milne 2019-07-03 18:00:44 -07:00 committed by TensorFlower Gardener
parent f3fec15c7c
commit 604988b5d4
2 changed files with 95 additions and 24 deletions

View File

@ -148,8 +148,63 @@ class AutoCastVariable(trackable.Trackable):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
# TODO(reedwm): Define operator overloads.
# Operator overloads:
# Note we only overload operators that support floating-point types, as
# non-float variables cannot be wrapped with an AutoCastVariable.
# pylint: disable=multiple-statements
def __add__(self, o): return self.value() + o
def __radd__(self, o): return o + self.value()
def __sub__(self, o): return self.value() - o
def __rsub__(self, o): return o - self.value()
def __mul__(self, o): return self.value() * o
def __rmul__(self, o): return o * self.value()
def __truediv__(self, o): return self.value() / o
def __rtruediv__(self, o): return o / self.value()
def __floordiv__(self, o): return self.value() // o
def __rfloordiv__(self, o): return o // self.value()
def __mod__(self, o): return self.value() % o
def __rmod__(self, o): return o % self.value()
def __lt__(self, o): return self.value() < o
def __le__(self, o): return self.value() <= o
def __gt__(self, o): return self.value() > o
def __ge__(self, o): return self.value() >= o
def __getitem__(self, o): return self.value()[o]
def __pow__(self, o, modulo=None): return pow(self.value(), o, modulo)
def __rpow__(self, o): return pow(o, self.value())
def __neg__(self): return -self.value()
def __abs__(self): return abs(self.value())
def __div__(self, o):
try:
return self.value().__div__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __rdiv__(self, o):
try:
return self.value().__rdiv__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __matmul__(self, o):
try:
return self.value().__matmul__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __rmatmul__(self, o):
try:
return self.value().__rmatmul__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
# pylint: enable=multiple-statements
ops.register_tensor_conversion_function(
AutoCastVariable, AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access

View File

@ -97,30 +97,46 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(*TESTCASES)
def test_operator_overloads(self, distribute):
with get_distribute_scope(distribute):
x = get_var(1., dtypes.float32)
for read_dtype in (dtypes.float32, dtypes.float16):
x = get_var(7., dtypes.float32)
x = get_autocast_var(x, distribute)
x._read_dtype = read_dtype
self.evaluate(x.initializer)
self.assertAlmostEqual(8, self.evaluate(x + 1))
self.assertAlmostEqual(10, self.evaluate(3 + x))
self.assertAlmostEqual(14, self.evaluate(x + x))
self.assertAlmostEqual(5, self.evaluate(x - 2))
self.assertAlmostEqual(6, self.evaluate(13 - x))
self.assertAlmostEqual(0, self.evaluate(x - x))
self.assertAlmostEqual(14, self.evaluate(x * 2))
self.assertAlmostEqual(21, self.evaluate(3 * x))
self.assertAlmostEqual(49, self.evaluate(x * x))
self.assertAlmostEqual(3.5, self.evaluate(x / 2))
self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
self.assertAlmostEqual(3, self.evaluate(x // 2))
self.assertAlmostEqual(2, self.evaluate(15 // x))
if read_dtype == dtypes.float32:
# The "mod" operator does not support float16
self.assertAlmostEqual(1, self.evaluate(x % 2))
self.assertAlmostEqual(2, self.evaluate(16 % x))
self.assertTrue(self.evaluate(x < 12))
self.assertTrue(self.evaluate(x <= 12))
self.assertFalse(self.evaluate(x > 12))
self.assertFalse(self.evaluate(x >= 12))
self.assertFalse(self.evaluate(12 < x))
self.assertFalse(self.evaluate(12 <= x))
self.assertTrue(self.evaluate(12 > x))
self.assertTrue(self.evaluate(12 >= x))
self.assertAlmostEqual(343, self.evaluate(pow(x, 3)), places=4)
self.assertAlmostEqual(128, self.evaluate(pow(2, x)), places=4)
self.assertAlmostEqual(-7, self.evaluate(-x))
self.assertAlmostEqual(7, self.evaluate(abs(x)))
v1 = constant_op.constant(2., dtype=dtypes.float32)
v2 = constant_op.constant(2., dtype=dtypes.float16)
# Because autocast variables do not yet define operator overloads, the
# operator is defined by the non-variable tensor
# Test variable as the LHS. Currently, this is not supported with
# distributed autocast variables
if not distribute:
self.assertEqual(self.evaluate(x + v1), 3.)
x._read_dtype = dtypes.float16
self.assertEqual(self.evaluate(x + v2), 3.)
# Test variable as the RHS
x._read_dtype = dtypes.float32
self.assertEqual(self.evaluate(v1 + x), 3.)
x._read_dtype = dtypes.float16
self.assertEqual(self.evaluate(v2 + x), 3.)
x = get_var([7, 8, 9], dtypes.float32)
x = get_autocast_var(x, distribute)
x._read_dtype = read_dtype
self.evaluate(x.initializer)
self.assertEqual(self.evaluate(x[1]), 8)
@parameterized.named_parameters(*TESTCASES)
def test_assign(self, distribute):