Replace keras usages of private function.defun with tf.function
PiperOrigin-RevId: 331804876 Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b
This commit is contained in:
parent 9b6a1850d9
commit 6f0dd6eac4
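For context: function.defun was the pre-TF2 internal tracing API, and def_function.function is the implementation behind the public tf.function symbol. In user code the same migration looks like the minimal sketch below (the toy function is hypothetical, not code from this commit):

    import tensorflow as tf

    # Before (the private API this commit removes):
    #   from tensorflow.python.eager import function
    #   square = function.defun(lambda x: x * x)

    # After: the public tf.function wrapper traces the Python function
    # into a graph and reuses the trace for matching input signatures.
    square = tf.function(lambda x: x * x)
    print(square(tf.constant(3.0)))  # tf.Tensor(9.0, shape=(), dtype=float32)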
@@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.keras.layers import core as keras_core
@@ -55,13 +55,13 @@ class MiniModel(keras_training.Model):
         distribution=[
             strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
         ],
-        mode=["graph", "eager"]))
+        mode=["eager"]))
 class MirroredStrategyDefunTest(test.TestCase):

   def testTrain(self, distribution):
     with distribution.scope():
       mock_model = MiniModel()
-      mock_model.call = function.defun(mock_model.call)
+      mock_model.call = def_function.function(mock_model.call)

       def loss_fn(ctx):
         del ctx
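The pattern under test here is rebinding a model's call method to a compiled version. Dropping the "graph" mode is consistent with tf.function itself providing the graph execution path from eager mode. With public symbols the same pattern looks like this sketch (TinyModel is a hypothetical stand-in for the test's MiniModel):

    import tensorflow as tf

    class TinyModel(tf.keras.Model):

      def __init__(self):
        super(TinyModel, self).__init__()
        self.fc = tf.keras.layers.Dense(1)

      def call(self, inputs):
        return self.fc(inputs)

    model = TinyModel()
    # Rebind call to a traced function, as the test does.
    model.call = tf.function(model.call)
    print(model.call(tf.ones([2, 3])).shape)  # (2, 1)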
@@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute
-from tensorflow.python.eager import function
 from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer):
       return op.outputs[0]
     return op.outputs

-  @function.defun
+  @def_function.function
   def _defun_call(self, inputs):
     """Wraps the op creation method in an Eager function for `run_eagerly`."""
     return self._make_op(inputs)
@@ -24,7 +24,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):

       def __init__(self, name=None):
         super(MySequential, self).__init__(name=name)
-        self.call = function.defun(self.call)
+        self.call = def_function.function(self.call)

     model = MySequential()
     model.add(keras.layers.Dense(4, activation='relu'))
@@ -29,7 +29,6 @@ import six
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util as tf_test_util
@@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase):
     layer = layers_module.Dense(1, kernel_regularizer='l1')
     layer(array_ops.ones([1, 10]))

-    @function.defun
+    @def_function.function
     def get_losses():
       return layer.losses

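What this test checks — that a layer's regularization losses can be read from inside a compiled function — can be reproduced with public symbols (a sketch, not the test file itself):

    import tensorflow as tf

    layer = tf.keras.layers.Dense(1, kernel_regularizer='l1')
    layer(tf.ones([1, 10]))  # build the layer so its kernel (and loss) exists

    @tf.function
    def get_losses():
      # layer.losses recomputes the L1 penalty from the current kernel
      # value, so it is traceable inside tf.function.
      return layer.losses

    print(get_losses())  # [<tf.Tensor: L1 penalty on the kernel>]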
@@ -26,7 +26,6 @@ from absl.testing import parameterized
 import numpy as np

 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function as eager_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase):
     """Ensure that variables are created correctly in a tf function."""
     m = metrics.MeanTensor(dtype=dtypes.float64)

-    @eager_function.defun
+    @def_function.function
     def call_metric(x):
       return m(x)

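The docstring above names the interesting part: MeanTensor creates its state variables lazily on first call, and def_function.function supports variable creation on the first trace. A public-API sketch of the same check:

    import tensorflow as tf

    m = tf.keras.metrics.MeanTensor(dtype=tf.float64)

    @tf.function
    def call_metric(x):
      return m(x)

    # The first call traces the function and creates the metric's variables.
    call_metric(tf.constant([100.0, 40.0], dtype=tf.float64))
    print(m.result())  # [100., 40.]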
@@ -23,7 +23,7 @@ import numpy as np

 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
         [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))

   @combinations.generate(combinations.combine(mode=["eager"]))
-  def testCapturingInDefunWhileExecutingEagerly(self):
+  def testCapturingInFunctionWhileExecutingEagerly(self):
     optimizer = gradient_descent.SGD(1.0)

+    var_holder = {}
     def step():
-      self.v = variables.Variable(1.0)
-      with backprop.GradientTape() as tape:
-        loss = self.v**2
-      grad = tape.gradient(loss, self.v)
-      optimizer.apply_gradients([(grad, self.v)])
-      return self.v.read_value()
+      if not var_holder:
+        var_holder["var"] = variables.Variable(1.0)
+      else:
+        var_holder["var"].assign(1.0)
+
+      with backprop.GradientTape() as tape:
+        loss = var_holder["var"]**2
+      grad = tape.gradient(loss, var_holder["var"])
+      optimizer.apply_gradients([(grad, var_holder["var"])])
+      return var_holder["var"].read_value()

-    compiled_step = function.defun(step)
+    compiled_step = def_function.function(step)

     self.assertEqual(float(step()), -1.0)
     self.assertEqual(float(compiled_step()), -1.0)
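The test body changes more than the wrapper here because tf.function only permits creating tf.Variables on the first call of a traced function; creating a fresh Variable on every invocation raises a ValueError under def_function.function. The var_holder dict is the standard create-once-then-reset workaround, sketched here in isolation with public symbols:

    import tensorflow as tf

    var_holder = {}

    @tf.function
    def step():
      if not var_holder:
        # Allowed: variable creation on the first trace only.
        var_holder["var"] = tf.Variable(1.0)
      else:
        # Later calls must reuse the existing variable.
        var_holder["var"].assign(1.0)
      return var_holder["var"].read_value()

    print(step())  # 1.0 -- the first call creates the variable
    print(step())  # 1.0 -- subsequent calls just reset it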
@@ -22,7 +22,7 @@ import os
 import sys

 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
@@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint):
     self.dense = core.Dense(1)
     self.optimizer = adam.Adam(0.01)

-  # Using defun due to control flow v2 cycles, b/121159261. def_function uses
-  # conds to gate variable initialization and so triggers cond reference cycles,
-  # but the thing being wrapped here does not use cond itself.
-  @function.defun(
+  @def_function.function(
       input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
                        tensor_spec.TensorSpec([None], dtypes.float32)),
   )
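This hunk also deletes the old caveat comment about b/121159261, which suggests the cond reference-cycle reason for preferring defun here no longer applies. The input_signature argument carries over unchanged to the public API; a hedged sketch of its effect (a hypothetical standalone function, not the model above):

    import tensorflow as tf

    # A fixed input signature means one trace serves all batch sizes, and the
    # function can be saved to a SavedModel without being called first.
    @tf.function(input_signature=(
        tf.TensorSpec([None, 2], tf.float32),
        tf.TensorSpec([None], tf.float32),
    ))
    def call(x, y):
      return tf.reduce_sum(x, axis=1) + y

    print(call(tf.ones([3, 2]), tf.zeros([3])))  # [2. 2. 2.]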