[TF2XLA] [NFC] Test that aliased updates do not actually increase memory usage
PiperOrigin-RevId: 324705992 Change-Id: I1c7c19867c3b7086dab61427da7e7b78547e4c59
This commit is contained in:
parent
3cf7683cfe
commit
474d3df724
@ -461,26 +461,23 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
def testUpdateVariable(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
v = variables.Variable(3.1)
|
||||
|
||||
on_gpu = 'gpu' in self.device.lower()
|
||||
v = variables.Variable([3.1, 3.2])
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def update_var(a, b):
|
||||
v.assign_add(a * b)
|
||||
|
||||
update_var(constant_op.constant(0.7), constant_op.constant(0.6))
|
||||
self.assertAllClose(v, 3.52)
|
||||
arg1 = random_ops.random_normal([2])
|
||||
arg2 = random_ops.random_normal([2])
|
||||
|
||||
def testUpdateVariableVector(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
v = variables.Variable([3.1, 3.1])
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def update_var(a, b):
|
||||
v.assign_add(a * b)
|
||||
|
||||
update_var(
|
||||
constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
|
||||
self.assertAllClose(v, [3.52, 3.52])
|
||||
initial_usage = context.context().get_total_memory_usage(
|
||||
v.device) if on_gpu else 0
|
||||
update_var(arg1, arg2)
|
||||
final_usage = context.context().get_total_memory_usage(
|
||||
v.device) if on_gpu else 0
|
||||
self.assertEqual(initial_usage, final_usage)
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/162381930): MLIR bridge renames '
|
||||
' functions')
|
||||
@ -524,11 +521,19 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
def f(a, b):
|
||||
return (a, b)
|
||||
|
||||
a = constant_op.constant([0.7])
|
||||
b = constant_op.constant([0.6])
|
||||
a = random_ops.random_normal([10, 10])
|
||||
b = random_ops.random_normal([10, 10])
|
||||
|
||||
on_gpu = 'gpu' in self.device.lower()
|
||||
initial_usage = context.context().get_total_memory_usage(
|
||||
b.backing_device) if on_gpu else 0
|
||||
|
||||
f(a, b)
|
||||
|
||||
final_usage = context.context().get_total_memory_usage(
|
||||
b.backing_device) if on_gpu else 0
|
||||
self.assertEqual(initial_usage, final_usage)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user