Replace keras usages of private function.defun with tf.function
PiperOrigin-RevId: 331804876 Change-Id: If44165155c160ffe35a263f9d5c98f6a73ccb41b
This commit is contained in:
parent
9b6a1850d9
commit
6f0dd6eac4
@ -23,7 +23,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
|
|||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import test_combinations as combinations
|
from tensorflow.python.framework import test_combinations as combinations
|
||||||
from tensorflow.python.keras.engine import training as keras_training
|
from tensorflow.python.keras.engine import training as keras_training
|
||||||
from tensorflow.python.keras.layers import core as keras_core
|
from tensorflow.python.keras.layers import core as keras_core
|
||||||
@ -55,13 +55,13 @@ class MiniModel(keras_training.Model):
|
|||||||
distribution=[
|
distribution=[
|
||||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
],
|
],
|
||||||
mode=["graph", "eager"]))
|
mode=["eager"]))
|
||||||
class MirroredStrategyDefunTest(test.TestCase):
|
class MirroredStrategyDefunTest(test.TestCase):
|
||||||
|
|
||||||
def testTrain(self, distribution):
|
def testTrain(self, distribution):
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
mock_model = MiniModel()
|
mock_model = MiniModel()
|
||||||
mock_model.call = function.defun(mock_model.call)
|
mock_model.call = def_function.function(mock_model.call)
|
||||||
|
|
||||||
def loss_fn(ctx):
|
def loss_fn(ctx):
|
||||||
del ctx
|
del ctx
|
||||||
|
|||||||
@ -37,8 +37,8 @@ from tensorflow.python.autograph.core import ag_ctx
|
|||||||
from tensorflow.python.autograph.impl import api as autograph
|
from tensorflow.python.autograph.impl import api as autograph
|
||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import function
|
|
||||||
from tensorflow.python.eager import monitoring
|
from tensorflow.python.eager import monitoring
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -3187,7 +3187,7 @@ class TensorFlowOpLayer(Layer):
|
|||||||
return op.outputs[0]
|
return op.outputs[0]
|
||||||
return op.outputs
|
return op.outputs
|
||||||
|
|
||||||
@function.defun
|
@def_function.function
|
||||||
def _defun_call(self, inputs):
|
def _defun_call(self, inputs):
|
||||||
"""Wraps the op creation method in an Eager function for `run_eagerly`."""
|
"""Wraps the op creation method in an Eager function for `run_eagerly`."""
|
||||||
return self._make_op(inputs)
|
return self._make_op(inputs)
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import numpy as np
|
|||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -456,7 +456,7 @@ class TestSequentialEagerIntegration(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None):
|
||||||
super(MySequential, self).__init__(name=name)
|
super(MySequential, self).__init__(name=name)
|
||||||
self.call = function.defun(self.call)
|
self.call = def_function.function(self.call)
|
||||||
|
|
||||||
model = MySequential()
|
model = MySequential()
|
||||||
model.add(keras.layers.Dense(4, activation='relu'))
|
model.add(keras.layers.Dense(4, activation='relu'))
|
||||||
|
|||||||
@ -29,7 +29,6 @@ import six
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
@ -992,7 +991,7 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
layer = layers_module.Dense(1, kernel_regularizer='l1')
|
layer = layers_module.Dense(1, kernel_regularizer='l1')
|
||||||
layer(array_ops.ones([1, 10]))
|
layer(array_ops.ones([1, 10]))
|
||||||
|
|
||||||
@function.defun
|
@def_function.function
|
||||||
def get_losses():
|
def get_losses():
|
||||||
return layer.losses
|
return layer.losses
|
||||||
|
|
||||||
|
|||||||
@ -26,7 +26,6 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function as eager_function
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
@ -1476,7 +1475,7 @@ class MeanTensorTest(test.TestCase, parameterized.TestCase):
|
|||||||
"""Ensure that variables are created correctly in a tf function."""
|
"""Ensure that variables are created correctly in a tf function."""
|
||||||
m = metrics.MeanTensor(dtype=dtypes.float64)
|
m = metrics.MeanTensor(dtype=dtypes.float64)
|
||||||
|
|
||||||
@eager_function.defun
|
@def_function.function
|
||||||
def call_metric(x):
|
def call_metric(x):
|
||||||
return m(x)
|
return m(x)
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -259,18 +259,23 @@ class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
[[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
|
[[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=["eager"]))
|
@combinations.generate(combinations.combine(mode=["eager"]))
|
||||||
def testCapturingInDefunWhileExecutingEagerly(self):
|
def testCapturingInFunctionWhileExecutingEagerly(self):
|
||||||
optimizer = gradient_descent.SGD(1.0)
|
optimizer = gradient_descent.SGD(1.0)
|
||||||
|
|
||||||
|
var_holder = {}
|
||||||
def step():
|
def step():
|
||||||
self.v = variables.Variable(1.0)
|
if not var_holder:
|
||||||
with backprop.GradientTape() as tape:
|
var_holder["var"] = variables.Variable(1.0)
|
||||||
loss = self.v**2
|
else:
|
||||||
grad = tape.gradient(loss, self.v)
|
var_holder["var"].assign(1.0)
|
||||||
optimizer.apply_gradients([(grad, self.v)])
|
|
||||||
return self.v.read_value()
|
|
||||||
|
|
||||||
compiled_step = function.defun(step)
|
with backprop.GradientTape() as tape:
|
||||||
|
loss = var_holder["var"]**2
|
||||||
|
grad = tape.gradient(loss, var_holder["var"])
|
||||||
|
optimizer.apply_gradients([(grad, var_holder["var"])])
|
||||||
|
return var_holder["var"].read_value()
|
||||||
|
|
||||||
|
compiled_step = def_function.function(step)
|
||||||
|
|
||||||
self.assertEqual(float(step()), -1.0)
|
self.assertEqual(float(step()), -1.0)
|
||||||
self.assertEqual(float(compiled_step()), -1.0)
|
self.assertEqual(float(compiled_step()), -1.0)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
@ -41,10 +41,7 @@ class _ModelWithOptimizerUsingDefun(util.Checkpoint):
|
|||||||
self.dense = core.Dense(1)
|
self.dense = core.Dense(1)
|
||||||
self.optimizer = adam.Adam(0.01)
|
self.optimizer = adam.Adam(0.01)
|
||||||
|
|
||||||
# Using defun due to control flow v2 cycles, b/121159261. def_function uses
|
@def_function.function(
|
||||||
# conds to gate variable initialization and so triggers cond reference cycles,
|
|
||||||
# but the thing being wrapped here does not use cond itself.
|
|
||||||
@function.defun(
|
|
||||||
input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
|
input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
|
||||||
tensor_spec.TensorSpec([None], dtypes.float32)),
|
tensor_spec.TensorSpec([None], dtypes.float32)),
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user