Add operator overloads to AutoCastVariable.
The code was copied from DistributionStrategy at 81acfa851e/tensorflow/python/distribute/values.py (L401)
with slight modifications.
PiperOrigin-RevId: 256469842
This commit is contained in:
parent
f3fec15c7c
commit
604988b5d4
@ -148,8 +148,63 @@ class AutoCastVariable(trackable.Trackable):
|
||||
"""Pass resource_variable_ops.is_resource_variable check."""
|
||||
pass
|
||||
|
||||
# TODO(reedwm): Define operator overloads.
|
||||
# Operator overloads:
|
||||
# Note we only overload operators that support floating-point types, as
|
||||
# non-float variables cannot be wrapped with an AutoCastVariable.
|
||||
|
||||
# pylint: disable=multiple-statements
|
||||
def __add__(self, o): return self.value() + o
|
||||
def __radd__(self, o): return o + self.value()
|
||||
def __sub__(self, o): return self.value() - o
|
||||
def __rsub__(self, o): return o - self.value()
|
||||
def __mul__(self, o): return self.value() * o
|
||||
def __rmul__(self, o): return o * self.value()
|
||||
def __truediv__(self, o): return self.value() / o
|
||||
def __rtruediv__(self, o): return o / self.value()
|
||||
def __floordiv__(self, o): return self.value() // o
|
||||
|
||||
def __rfloordiv__(self, o): return o // self.value()
|
||||
def __mod__(self, o): return self.value() % o
|
||||
def __rmod__(self, o): return o % self.value()
|
||||
def __lt__(self, o): return self.value() < o
|
||||
def __le__(self, o): return self.value() <= o
|
||||
def __gt__(self, o): return self.value() > o
|
||||
def __ge__(self, o): return self.value() >= o
|
||||
def __getitem__(self, o): return self.value()[o]
|
||||
def __pow__(self, o, modulo=None): return pow(self.value(), o, modulo)
|
||||
def __rpow__(self, o): return pow(o, self.value())
|
||||
def __neg__(self): return -self.value()
|
||||
def __abs__(self): return abs(self.value())
|
||||
|
||||
def __div__(self, o):
|
||||
try:
|
||||
return self.value().__div__(o)
|
||||
except AttributeError:
|
||||
# See https://docs.python.org/3/library/constants.html#NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
def __rdiv__(self, o):
|
||||
try:
|
||||
return self.value().__rdiv__(o)
|
||||
except AttributeError:
|
||||
# See https://docs.python.org/3/library/constants.html#NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
def __matmul__(self, o):
|
||||
try:
|
||||
return self.value().__matmul__(o)
|
||||
except AttributeError:
|
||||
# See https://docs.python.org/3/library/constants.html#NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
def __rmatmul__(self, o):
|
||||
try:
|
||||
return self.value().__rmatmul__(o)
|
||||
except AttributeError:
|
||||
# See https://docs.python.org/3/library/constants.html#NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
# pylint: enable=multiple-statements
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
AutoCastVariable, AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
|
||||
|
@ -97,30 +97,46 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_operator_overloads(self, distribute):
|
||||
with get_distribute_scope(distribute):
|
||||
x = get_var(1., dtypes.float32)
|
||||
x = get_autocast_var(x, distribute)
|
||||
self.evaluate(x.initializer)
|
||||
for read_dtype in (dtypes.float32, dtypes.float16):
|
||||
x = get_var(7., dtypes.float32)
|
||||
x = get_autocast_var(x, distribute)
|
||||
x._read_dtype = read_dtype
|
||||
self.evaluate(x.initializer)
|
||||
self.assertAlmostEqual(8, self.evaluate(x + 1))
|
||||
self.assertAlmostEqual(10, self.evaluate(3 + x))
|
||||
self.assertAlmostEqual(14, self.evaluate(x + x))
|
||||
self.assertAlmostEqual(5, self.evaluate(x - 2))
|
||||
self.assertAlmostEqual(6, self.evaluate(13 - x))
|
||||
self.assertAlmostEqual(0, self.evaluate(x - x))
|
||||
self.assertAlmostEqual(14, self.evaluate(x * 2))
|
||||
self.assertAlmostEqual(21, self.evaluate(3 * x))
|
||||
self.assertAlmostEqual(49, self.evaluate(x * x))
|
||||
self.assertAlmostEqual(3.5, self.evaluate(x / 2))
|
||||
self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
|
||||
self.assertAlmostEqual(3, self.evaluate(x // 2))
|
||||
self.assertAlmostEqual(2, self.evaluate(15 // x))
|
||||
if read_dtype == dtypes.float32:
|
||||
# The "mod" operator does not support float16
|
||||
self.assertAlmostEqual(1, self.evaluate(x % 2))
|
||||
self.assertAlmostEqual(2, self.evaluate(16 % x))
|
||||
self.assertTrue(self.evaluate(x < 12))
|
||||
self.assertTrue(self.evaluate(x <= 12))
|
||||
self.assertFalse(self.evaluate(x > 12))
|
||||
self.assertFalse(self.evaluate(x >= 12))
|
||||
self.assertFalse(self.evaluate(12 < x))
|
||||
self.assertFalse(self.evaluate(12 <= x))
|
||||
self.assertTrue(self.evaluate(12 > x))
|
||||
self.assertTrue(self.evaluate(12 >= x))
|
||||
self.assertAlmostEqual(343, self.evaluate(pow(x, 3)), places=4)
|
||||
self.assertAlmostEqual(128, self.evaluate(pow(2, x)), places=4)
|
||||
self.assertAlmostEqual(-7, self.evaluate(-x))
|
||||
self.assertAlmostEqual(7, self.evaluate(abs(x)))
|
||||
|
||||
v1 = constant_op.constant(2., dtype=dtypes.float32)
|
||||
v2 = constant_op.constant(2., dtype=dtypes.float16)
|
||||
|
||||
# Because autocast variables do not yet define operator overloads, the
|
||||
# operator is defined by the non-variable tensor
|
||||
|
||||
# Test variable as the LHS. Currently, this is not supported with
|
||||
# distributed autocast variables
|
||||
if not distribute:
|
||||
self.assertEqual(self.evaluate(x + v1), 3.)
|
||||
|
||||
x._read_dtype = dtypes.float16
|
||||
self.assertEqual(self.evaluate(x + v2), 3.)
|
||||
|
||||
# Test variable as the RHS
|
||||
x._read_dtype = dtypes.float32
|
||||
self.assertEqual(self.evaluate(v1 + x), 3.)
|
||||
|
||||
x._read_dtype = dtypes.float16
|
||||
self.assertEqual(self.evaluate(v2 + x), 3.)
|
||||
x = get_var([7, 8, 9], dtypes.float32)
|
||||
x = get_autocast_var(x, distribute)
|
||||
x._read_dtype = read_dtype
|
||||
self.evaluate(x.initializer)
|
||||
self.assertEqual(self.evaluate(x[1]), 8)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_assign(self, distribute):
|
||||
|
Loading…
Reference in New Issue
Block a user