[TF2XLA] [NFC] Test that aliased updates do not actually increase memory usage

PiperOrigin-RevId: 324705992
Change-Id: I1c7c19867c3b7086dab61427da7e7b78547e4c59
This commit is contained in:
George Karpenkov 2020-08-03 16:19:08 -07:00 committed by TensorFlower Gardener
parent 3cf7683cfe
commit 474d3df724

View File

@ -461,26 +461,23 @@ class DefFunctionTest(xla_test.XLATestCase):
def testUpdateVariable(self):
with ops.device('device:{}:0'.format(self.device)):
v = variables.Variable(3.1)
on_gpu = 'gpu' in self.device.lower()
v = variables.Variable([3.1, 3.2])
@def_function.function(experimental_compile=True)
def update_var(a, b):
v.assign_add(a * b)
update_var(constant_op.constant(0.7), constant_op.constant(0.6))
self.assertAllClose(v, 3.52)
arg1 = random_ops.random_normal([2])
arg2 = random_ops.random_normal([2])
def testUpdateVariableVector(self):
with ops.device('device:{}:0'.format(self.device)):
v = variables.Variable([3.1, 3.1])
@def_function.function(experimental_compile=True)
def update_var(a, b):
v.assign_add(a * b)
update_var(
constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
self.assertAllClose(v, [3.52, 3.52])
initial_usage = context.context().get_total_memory_usage(
v.device) if on_gpu else 0
update_var(arg1, arg2)
final_usage = context.context().get_total_memory_usage(
v.device) if on_gpu else 0
self.assertEqual(initial_usage, final_usage)
@test_util.disable_mlir_bridge('TODO(b/162381930): MLIR bridge renames '
' functions')
@ -524,11 +521,19 @@ class DefFunctionTest(xla_test.XLATestCase):
def f(a, b):
return (a, b)
a = constant_op.constant([0.7])
b = constant_op.constant([0.6])
a = random_ops.random_normal([10, 10])
b = random_ops.random_normal([10, 10])
on_gpu = 'gpu' in self.device.lower()
initial_usage = context.context().get_total_memory_usage(
b.backing_device) if on_gpu else 0
f(a, b)
final_usage = context.context().get_total_memory_usage(
b.backing_device) if on_gpu else 0
self.assertEqual(initial_usage, final_usage)
if __name__ == '__main__':
ops.enable_eager_execution()