From 6f0dd6eac487a70b908e2e509a43e17fb1a3cba2 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan <kaftan@google.com> Date: Tue, 15 Sep 2020 10:50:06 -0700 Subject: [PATCH] Replace keras usages of private `function.defun` with `tf.function` PiperOrigin-RevId: 331804876 Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b --- .../distribute/mirrored_strategy_test.py | 6 ++--- tensorflow/python/keras/engine/base_layer.py | 4 ++-- .../python/keras/engine/sequential_test.py | 4 ++-- .../python/keras/engine/training_test.py | 3 +-- tensorflow/python/keras/metrics_test.py | 3 +-- .../optimizer_v2/gradient_descent_test.py | 23 +++++++++++-------- .../python/keras/tests/saved_model_test.py | 7 ++---- 7 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tensorflow/python/keras/distribute/mirrored_strategy_test.py b/tensorflow/python/keras/distribute/mirrored_strategy_test.py index 1303952bf78..fc800d4b210 100644 --- a/tensorflow/python/keras/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/keras/distribute/mirrored_strategy_test.py @@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import function +from tensorflow.python.eager import def_function from tensorflow.python.framework import test_combinations as combinations from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core @@ -55,13 +55,13 @@ class MiniModel(keras_training.Model): distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, ], - mode=["graph", "eager"])) + mode=["eager"])) class MirroredStrategyDefunTest(test.TestCase): def testTrain(self, distribution): with distribution.scope(): mock_model = MiniModel() - mock_model.call = function.defun(mock_model.call) + mock_model.call = def_function.function(mock_model.call) def loss_fn(ctx): del ctx diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 0efcf47bc09..bd71bd4b7b2 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import execute -from tensorflow.python.eager import function from tensorflow.python.eager import monitoring from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer): return op.outputs[0] return op.outputs - @function.defun + @def_function.function def _defun_call(self, inputs): """Wraps the op creation method in an Eager function for `run_eagerly`.""" return self._make_op(inputs) diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 1c8510ff3c9..6a9a3bf9bcc 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -24,7 +24,7 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context -from tensorflow.python.eager import function +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase): def __init__(self, name=None): super(MySequential, self).__init__(name=name) - self.call = function.defun(self.call) + self.call = def_function.function(self.call) model = MySequential() model.add(keras.layers.Dense(4, activation='relu')) diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 3ce9a1ac01c..1f8f8cb1b52 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -29,7 +29,6 @@ import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.eager import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util as tf_test_util @@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase): layer = layers_module.Dense(1, kernel_regularizer='l1') layer(array_ops.ones([1, 10])) - @function.defun + @def_function.function def get_losses(): return layer.losses diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index a4f61082c2d..b297063e0d3 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.eager import def_function -from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase): """Ensure that variables are created correctly in a tf function.""" m = metrics.MeanTensor(dtype=dtypes.float64) - @eager_function.defun + @def_function.function def call_metric(x): return m(x) diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py index 15a501f5259..165102bede5 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py @@ -23,7 +23,7 @@ import numpy as np from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import function +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase): [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1)) @combinations.generate(combinations.combine(mode=["eager"])) - def testCapturingInDefunWhileExecutingEagerly(self): + def testCapturingInFunctionWhileExecutingEagerly(self): optimizer = gradient_descent.SGD(1.0) + var_holder = {} def step(): - self.v = variables.Variable(1.0) - with backprop.GradientTape() as tape: - loss = self.v**2 - grad = tape.gradient(loss, self.v) - optimizer.apply_gradients([(grad, self.v)]) - return self.v.read_value() + if not var_holder: + var_holder["var"] = variables.Variable(1.0) + else: + var_holder["var"].assign(1.0) - compiled_step = function.defun(step) + with backprop.GradientTape() as tape: + loss = var_holder["var"]**2 + grad = tape.gradient(loss, var_holder["var"]) + optimizer.apply_gradients([(grad, var_holder["var"])]) + return var_holder["var"].read_value() + + compiled_step = def_function.function(step) self.assertEqual(float(step()), -1.0) self.assertEqual(float(compiled_step()), -1.0) diff --git a/tensorflow/python/keras/tests/saved_model_test.py b/tensorflow/python/keras/tests/saved_model_test.py index cd6363b8855..9264a60eb55 100644 --- a/tensorflow/python/keras/tests/saved_model_test.py +++ b/tensorflow/python/keras/tests/saved_model_test.py @@ -22,7 +22,7 @@ import os import sys from tensorflow.python.eager import backprop -from tensorflow.python.eager import function +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_spec @@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint): self.dense = core.Dense(1) self.optimizer = adam.Adam(0.01) - # Using defun due to control flow v2 cycles, b/121159261. def_function uses - # conds to gate variable initialization and so triggers cond reference cycles, - # but the thing being wrapped here does not use cond itself. - @function.defun( + @def_function.function( input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32), tensor_spec.TensorSpec([None], dtypes.float32)), )