Replace keras usages of private function.defun with tf.function

PiperOrigin-RevId: 331804876
Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b
This commit is contained in:
Tomer Kaftan 2020-09-15 10:50:06 -07:00 committed by TensorFlower Gardener
parent 9b6a1850d9
commit 6f0dd6eac4
7 changed files with 25 additions and 25 deletions

View File

@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core
@ -55,13 +55,13 @@ class MiniModel(keras_training.Model):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
mode=["eager"]))
class MirroredStrategyDefunTest(test.TestCase):
def testTrain(self, distribution):
with distribution.scope():
mock_model = MiniModel()
mock_model.call = function.defun(mock_model.call)
mock_model.call = def_function.function(mock_model.call)
def loss_fn(ctx):
del ctx

View File

@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer):
return op.outputs[0]
return op.outputs
@function.defun
@def_function.function
def _defun_call(self, inputs):
"""Wraps the op creation method in an Eager function for `run_eagerly`."""
return self._make_op(inputs)

View File

@ -24,7 +24,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):
def __init__(self, name=None):
super(MySequential, self).__init__(name=name)
self.call = function.defun(self.call)
self.call = def_function.function(self.call)
model = MySequential()
model.add(keras.layers.Dense(4, activation='relu'))

View File

@ -29,7 +29,6 @@ import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase):
layer = layers_module.Dense(1, kernel_regularizer='l1')
layer(array_ops.ones([1, 10]))
@function.defun
@def_function.function
def get_losses():
return layer.losses

View File

@ -26,7 +26,6 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase):
"""Ensure that variables are created correctly in a tf function."""
m = metrics.MeanTensor(dtype=dtypes.float64)
@eager_function.defun
@def_function.function
def call_metric(x):
return m(x)

View File

@ -23,7 +23,7 @@ import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
[[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
@combinations.generate(combinations.combine(mode=["eager"]))
def testCapturingInDefunWhileExecutingEagerly(self):
def testCapturingInFunctionWhileExecutingEagerly(self):
optimizer = gradient_descent.SGD(1.0)
var_holder = {}
def step():
self.v = variables.Variable(1.0)
with backprop.GradientTape() as tape:
loss = self.v**2
grad = tape.gradient(loss, self.v)
optimizer.apply_gradients([(grad, self.v)])
return self.v.read_value()
if not var_holder:
var_holder["var"] = variables.Variable(1.0)
else:
var_holder["var"].assign(1.0)
compiled_step = function.defun(step)
with backprop.GradientTape() as tape:
loss = var_holder["var"]**2
grad = tape.gradient(loss, var_holder["var"])
optimizer.apply_gradients([(grad, var_holder["var"])])
return var_holder["var"].read_value()
compiled_step = def_function.function(step)
self.assertEqual(float(step()), -1.0)
self.assertEqual(float(compiled_step()), -1.0)

View File

@ -22,7 +22,7 @@ import os
import sys
from tensorflow.python.eager import backprop
from tensorflow.python.eager import function
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint):
self.dense = core.Dense(1)
self.optimizer = adam.Adam(0.01)
# Using defun due to control flow v2 cycles, b/121159261. def_function uses
# conds to gate variable initialization and so triggers cond reference cycles,
# but the thing being wrapped here does not use cond itself.
@function.defun(
@def_function.function(
input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
tensor_spec.TensorSpec([None], dtypes.float32)),
)