Replace keras usages of private function.defun with tf.function

PiperOrigin-RevId: 331804876
Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b
This commit is contained in:
Tomer Kaftan 2020-09-15 10:50:06 -07:00 committed by TensorFlower Gardener
parent 9b6a1850d9
commit 6f0dd6eac4
7 changed files with 25 additions and 25 deletions

View File

@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import function from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_combinations as combinations from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.keras.layers import core as keras_core
@ -55,13 +55,13 @@ class MiniModel(keras_training.Model):
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
], ],
mode=["graph", "eager"])) mode=["eager"]))
class MirroredStrategyDefunTest(test.TestCase): class MirroredStrategyDefunTest(test.TestCase):
def testTrain(self, distribution): def testTrain(self, distribution):
with distribution.scope(): with distribution.scope():
mock_model = MiniModel() mock_model = MiniModel()
mock_model.call = function.defun(mock_model.call) mock_model.call = def_function.function(mock_model.call)
def loss_fn(ctx): def loss_fn(ctx):
del ctx del ctx

View File

@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.eager import monitoring from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer):
return op.outputs[0] return op.outputs[0]
return op.outputs return op.outputs
@function.defun @def_function.function
def _defun_call(self, inputs): def _defun_call(self, inputs):
"""Wraps the op creation method in an Eager function for `run_eagerly`.""" """Wraps the op creation method in an Eager function for `run_eagerly`."""
return self._make_op(inputs) return self._make_op(inputs)

View File

@ -24,7 +24,7 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import function from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):
def __init__(self, name=None): def __init__(self, name=None):
super(MySequential, self).__init__(name=name) super(MySequential, self).__init__(name=name)
self.call = function.defun(self.call) self.call = def_function.function(self.call)
model = MySequential() model = MySequential()
model.add(keras.layers.Dense(4, activation='relu')) model.add(keras.layers.Dense(4, activation='relu'))

View File

@ -29,7 +29,6 @@ import six
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase):
layer = layers_module.Dense(1, kernel_regularizer='l1') layer = layers_module.Dense(1, kernel_regularizer='l1')
layer(array_ops.ones([1, 10])) layer(array_ops.ones([1, 10]))
@function.defun @def_function.function
def get_losses(): def get_losses():
return layer.losses return layer.losses

View File

@ -26,7 +26,6 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase):
"""Ensure that variables are created correctly in a tf function.""" """Ensure that variables are created correctly in a tf function."""
m = metrics.MeanTensor(dtype=dtypes.float64) m = metrics.MeanTensor(dtype=dtypes.float64)
@eager_function.defun @def_function.function
def call_metric(x): def call_metric(x):
return m(x) return m(x)

View File

@ -23,7 +23,7 @@ import numpy as np
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
[[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1)) [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
@combinations.generate(combinations.combine(mode=["eager"])) @combinations.generate(combinations.combine(mode=["eager"]))
def testCapturingInDefunWhileExecutingEagerly(self): def testCapturingInFunctionWhileExecutingEagerly(self):
optimizer = gradient_descent.SGD(1.0) optimizer = gradient_descent.SGD(1.0)
var_holder = {}
def step(): def step():
self.v = variables.Variable(1.0) if not var_holder:
with backprop.GradientTape() as tape: var_holder["var"] = variables.Variable(1.0)
loss = self.v**2 else:
grad = tape.gradient(loss, self.v) var_holder["var"].assign(1.0)
optimizer.apply_gradients([(grad, self.v)])
return self.v.read_value()
compiled_step = function.defun(step) with backprop.GradientTape() as tape:
loss = var_holder["var"]**2
grad = tape.gradient(loss, var_holder["var"])
optimizer.apply_gradients([(grad, var_holder["var"])])
return var_holder["var"].read_value()
compiled_step = def_function.function(step)
self.assertEqual(float(step()), -1.0) self.assertEqual(float(step()), -1.0)
self.assertEqual(float(compiled_step()), -1.0) self.assertEqual(float(compiled_step()), -1.0)

View File

@ -22,7 +22,7 @@ import os
import sys import sys
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint):
self.dense = core.Dense(1) self.dense = core.Dense(1)
self.optimizer = adam.Adam(0.01) self.optimizer = adam.Adam(0.01)
# Using defun due to control flow v2 cycles, b/121159261. def_function uses @def_function.function(
# conds to gate variable initialization and so triggers cond reference cycles,
# but the thing being wrapped here does not use cond itself.
@function.defun(
input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32), input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
tensor_spec.TensorSpec([None], dtypes.float32)), tensor_spec.TensorSpec([None], dtypes.float32)),
) )