From a37a3569f4faece52d56fffb9aef3757cf6d03f1 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Fri, 8 May 2020 11:02:28 -0700 Subject: [PATCH] Add a nested tf.function with control flow test. PiperOrigin-RevId: 310589571 Change-Id: Icb71cd7f50d77fe4b67ba21bedf415cdc8ff24bd --- .../custom_training_loop_models_test.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py index 3c748bd7364..48f2af0349a 100644 --- a/tensorflow/python/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/distribute/custom_training_loop_models_test.py @@ -378,6 +378,46 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase): for model_v, model2_v in zip(model.variables, model2.variables): self.assertAllClose(model_v.numpy(), model2_v.numpy()) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, mode=["eager"])) + def test_nested_tf_functions_with_control_flow(self, distribution): + inputs = np.random.random((10, 3)).astype(np.float32) + targets = np.ones((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat() + dataset = dataset.batch(10, drop_remainder=True) + input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) + + def get_model(): + x = keras.layers.Input(shape=(3,), name="input") + y = keras.layers.Dense(4, name="dense")(x) + model = keras.Model(x, y) + return model + + with distribution.scope(): + model = get_model() + optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01) + + @def_function.function + def train_step(iterator): + + def step_fn(inputs): + images, targets = inputs + with backprop.GradientTape() as tape: + outputs = model(images) + loss = math_ops.reduce_sum(outputs - targets) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + + distribution.run(step_fn, args=(next(iterator),)) + + @def_function.function + def train_steps(iterator): + for _ in math_ops.range(10): + train_step(iterator) + + train_steps(input_iterator) + @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies,