From 6f0dd6eac487a70b908e2e509a43e17fb1a3cba2 Mon Sep 17 00:00:00 2001
From: Tomer Kaftan <kaftan@google.com>
Date: Tue, 15 Sep 2020 10:50:06 -0700
Subject: [PATCH] Replace keras usages of private `function.defun` with
 `tf.function`

PiperOrigin-RevId: 331804876
Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b
---
 .../distribute/mirrored_strategy_test.py      |  6 ++---
 tensorflow/python/keras/engine/base_layer.py  |  4 ++--
 .../python/keras/engine/sequential_test.py    |  4 ++--
 .../python/keras/engine/training_test.py      |  3 +--
 tensorflow/python/keras/metrics_test.py       |  3 +--
 .../optimizer_v2/gradient_descent_test.py     | 23 +++++++++++--------
 .../python/keras/tests/saved_model_test.py    |  7 ++----
 7 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/tensorflow/python/keras/distribute/mirrored_strategy_test.py b/tensorflow/python/keras/distribute/mirrored_strategy_test.py
index 1303952bf78..fc800d4b210 100644
--- a/tensorflow/python/keras/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/keras/distribute/mirrored_strategy_test.py
@@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.keras.layers import core as keras_core
@@ -55,13 +55,13 @@ class MiniModel(keras_training.Model):
         distribution=[
             strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
         ],
-        mode=["graph", "eager"]))
+        mode=["eager"]))
 class MirroredStrategyDefunTest(test.TestCase):
 
   def testTrain(self, distribution):
     with distribution.scope():
       mock_model = MiniModel()
-      mock_model.call = function.defun(mock_model.call)
+      mock_model.call = def_function.function(mock_model.call)
 
       def loss_fn(ctx):
         del ctx
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 0efcf47bc09..bd71bd4b7b2 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute
-from tensorflow.python.eager import function
 from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer):
         return op.outputs[0]
       return op.outputs
 
-  @function.defun
+  @def_function.function
   def _defun_call(self, inputs):
     """Wraps the op creation method in an Eager function for `run_eagerly`."""
     return self._make_op(inputs)
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 1c8510ff3c9..6a9a3bf9bcc 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -24,7 +24,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):
 
       def __init__(self, name=None):
         super(MySequential, self).__init__(name=name)
-        self.call = function.defun(self.call)
+        self.call = def_function.function(self.call)
 
     model = MySequential()
     model.add(keras.layers.Dense(4, activation='relu'))
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 3ce9a1ac01c..1f8f8cb1b52 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -29,7 +29,6 @@ import six
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util as tf_test_util
@@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase):
     layer = layers_module.Dense(1, kernel_regularizer='l1')
     layer(array_ops.ones([1, 10]))
 
-    @function.defun
+    @def_function.function
     def get_losses():
       return layer.losses
 
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index a4f61082c2d..b297063e0d3 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -26,7 +26,6 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function as eager_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase):
     """Ensure that variables are created correctly in a tf function."""
     m = metrics.MeanTensor(dtype=dtypes.float64)
 
-    @eager_function.defun
+    @def_function.function
     def call_metric(x):
       return m(x)
 
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
index 15a501f5259..165102bede5 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
@@ -23,7 +23,7 @@ import numpy as np
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
             [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
 
   @combinations.generate(combinations.combine(mode=["eager"]))
-  def testCapturingInDefunWhileExecutingEagerly(self):
+  def testCapturingInFunctionWhileExecutingEagerly(self):
     optimizer = gradient_descent.SGD(1.0)
 
+    var_holder = {}
     def step():
-      self.v = variables.Variable(1.0)
-      with backprop.GradientTape() as tape:
-        loss = self.v**2
-      grad = tape.gradient(loss, self.v)
-      optimizer.apply_gradients([(grad, self.v)])
-      return self.v.read_value()
+      if not var_holder:
+        var_holder["var"] = variables.Variable(1.0)
+      else:
+        var_holder["var"].assign(1.0)
 
-    compiled_step = function.defun(step)
+      with backprop.GradientTape() as tape:
+        loss = var_holder["var"]**2
+      grad = tape.gradient(loss, var_holder["var"])
+      optimizer.apply_gradients([(grad, var_holder["var"])])
+      return var_holder["var"].read_value()
+
+    compiled_step = def_function.function(step)
 
     self.assertEqual(float(step()), -1.0)
     self.assertEqual(float(compiled_step()), -1.0)
diff --git a/tensorflow/python/keras/tests/saved_model_test.py b/tensorflow/python/keras/tests/saved_model_test.py
index cd6363b8855..9264a60eb55 100644
--- a/tensorflow/python/keras/tests/saved_model_test.py
+++ b/tensorflow/python/keras/tests/saved_model_test.py
@@ -22,7 +22,7 @@ import os
 import sys
 
 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
@@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint):
     self.dense = core.Dense(1)
     self.optimizer = adam.Adam(0.01)
 
-  # Using defun due to control flow v2 cycles, b/121159261. def_function uses
-  # conds to gate variable initialization and so triggers cond reference cycles,
-  # but the thing being wrapped here does not use cond itself.
-  @function.defun(
+  @def_function.function(
       input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
                        tensor_spec.TensorSpec([None], dtypes.float32)),
   )