From 97d5010a8fab6c9125516f2eb404a0c4543c955a Mon Sep 17 00:00:00 2001
From: Scott Zhu <scottzhu@google.com>
Date: Wed, 8 Apr 2020 10:27:04 -0700
Subject: [PATCH] Move keras related eager memory test to keras

PiperOrigin-RevId: 305504849
Change-Id: If783553d443ad4c900eebbb72970ea582cf801a5
---
 tensorflow/python/eager/memory_tests/BUILD    |  1 -
 .../python/eager/memory_tests/memory_test.py  | 42 ----------
 tensorflow/python/keras/tests/BUILD           | 25 ++++++
 tensorflow/python/keras/tests/memory_test.py  | 80 +++++++++++++++++++
 4 files changed, 105 insertions(+), 43 deletions(-)
 create mode 100644 tensorflow/python/keras/tests/memory_test.py

diff --git a/tensorflow/python/eager/memory_tests/BUILD b/tensorflow/python/eager/memory_tests/BUILD
index c9694c64694..419de91b42a 100644
--- a/tensorflow/python/eager/memory_tests/BUILD
+++ b/tensorflow/python/eager/memory_tests/BUILD
@@ -34,7 +34,6 @@ cuda_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/keras",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/eager/memory_tests/memory_test.py b/tensorflow/python/eager/memory_tests/memory_test.py
index ba94621f67b..ba831b5ba8c 100644
--- a/tensorflow/python/eager/memory_tests/memory_test.py
+++ b/tensorflow/python/eager/memory_tests/memory_test.py
@@ -24,7 +24,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import keras
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -38,17 +37,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.variables import Variable
 
 
-class SingleLayerNet(keras.Model):
-  """Simple keras model used to ensure that there are no leaks."""
-
-  def __init__(self):
-    super(SingleLayerNet, self).__init__()
-    self.fc1 = keras.layers.Dense(5)
-
-  def call(self, x):
-    return self.fc1(x)
-
-
 class MemoryTest(test.TestCase):
 
   def testMemoryLeakAnonymousVariable(self):
@@ -61,36 +49,6 @@ class MemoryTest(test.TestCase):
 
     memory_test_util.assert_no_leak(f, num_iters=10000)
 
-  def testMemoryLeakInSimpleModelForwardOnly(self):
-    if not memory_test_util.memory_profiler_is_available():
-      self.skipTest("memory_profiler required to run this test")
-
-    inputs = array_ops.zeros([32, 100], dtypes.float32)
-    net = SingleLayerNet()
-
-    def f():
-      with backprop.GradientTape():
-        net(inputs)
-
-    memory_test_util.assert_no_leak(f)
-
-  def testMemoryLeakInSimpleModelForwardAndBackward(self):
-    if not memory_test_util.memory_profiler_is_available():
-      self.skipTest("memory_profiler required to run this test")
-
-    inputs = array_ops.zeros([32, 100], dtypes.float32)
-    net = SingleLayerNet()
-
-    def f():
-      with backprop.GradientTape() as tape:
-        result = net(inputs)
-
-      tape.gradient(result, net.variables)
-
-      del tape
-
-    memory_test_util.assert_no_leak(f)
-
   def testMemoryLeakInFunction(self):
     if not memory_test_util.memory_profiler_is_available():
       self.skipTest("memory_profiler required to run this test")
diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD
index bcbb7a375d0..18f9575cecc 100644
--- a/tensorflow/python/keras/tests/BUILD
+++ b/tensorflow/python/keras/tests/BUILD
@@ -2,6 +2,7 @@
 #   Contains Keras test utils and integration tests.
 
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
     default_visibility = [
@@ -128,6 +129,30 @@ tf_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "memory_test",
+    size = "medium",
+    srcs = ["memory_test.py"],
+    tags = [
+        "manual",
+        "no_oss",
+        "notap",  #TODO(b/140640597): this test is flaky at the moment
+        "optonly",  # The test is too slow in non-opt mode
+    ],
+    # TODO(b/140065350): Re-enable
+    xla_enable_strict_auto_jit = False,
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/eager/memory_tests:memory_test_util",
+        "//tensorflow/python/keras",
+        "@six_archive//:six",
+    ],
+)
+
 tf_py_test(
     name = "temporal_sample_weights_correctness_test",
     srcs = ["temporal_sample_weights_correctness_test.py"],
diff --git a/tensorflow/python/keras/tests/memory_test.py b/tensorflow/python/keras/tests/memory_test.py
new file mode 100644
index 00000000000..753820d3295
--- /dev/null
+++ b/tensorflow/python/keras/tests/memory_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for memory leaks in eager execution.
+
+It is possible that this test suite will eventually become flaky due to taking
+too long to run (since the tests iterate many times), but for now they are
+helpful for finding memory leaks since not all PyObject leaks are found by
+introspection (test_util decorators). Please be careful adding new tests here.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import test
+from tensorflow.python.eager.memory_tests import memory_test_util
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+
+
+class SingleLayerNet(keras.Model):
+  """Simple keras model used to ensure that there are no leaks."""
+
+  def __init__(self):
+    super(SingleLayerNet, self).__init__()
+    self.fc1 = keras.layers.Dense(5)
+
+  def call(self, x):
+    return self.fc1(x)
+
+
+class MemoryTest(test.TestCase):
+
+  def testMemoryLeakInSimpleModelForwardOnly(self):
+    if not memory_test_util.memory_profiler_is_available():
+      self.skipTest("memory_profiler required to run this test")
+
+    inputs = array_ops.zeros([32, 100], dtypes.float32)
+    net = SingleLayerNet()
+
+    def f():
+      with backprop.GradientTape():
+        net(inputs)
+
+    memory_test_util.assert_no_leak(f)
+
+  def testMemoryLeakInSimpleModelForwardAndBackward(self):
+    if not memory_test_util.memory_profiler_is_available():
+      self.skipTest("memory_profiler required to run this test")
+
+    inputs = array_ops.zeros([32, 100], dtypes.float32)
+    net = SingleLayerNet()
+
+    def f():
+      with backprop.GradientTape() as tape:
+        result = net(inputs)
+
+      tape.gradient(result, net.variables)
+
+      del tape
+
+    memory_test_util.assert_no_leak(f)
+
+
+if __name__ == "__main__":
+  test.main()