diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index f7ecb294c44..54c2598324c 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -18,13 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized +import numpy as np + from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import def_function from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import flags from tensorflow.python.tpu import tpu @@ -52,7 +57,7 @@ def get_tpu_strategy(): return tpu_lib.TPUStrategy(resolver) -class TpuOutsideCompilationTest(test.TestCase): +class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): def testResourceVariableAssignOnHost(self): strategy = get_tpu_strategy() @@ -79,6 +84,26 @@ class TpuOutsideCompilationTest(test.TestCase): self.assertAllEqual(4.0 * strategy.num_replicas_in_sync, v2.numpy()) self.assertAllEqual(5.0, v.numpy()) + def testHostNoInput(self): + strategy = get_tpu_strategy() + + def outside_fn(): + logging_ops.print_v2("Outside compiled") + + @def_function.function + def train_step(): + + def tpu_fn(x): + x2 = x + 5.0 + tpu.outside_compilation(outside_fn) + return x2 + 5.0 + + return strategy.run(tpu_fn, args=(25.0,)) + + self.assertAllEqual( + strategy.experimental_local_results(train_step()), + constant_op.constant(35., shape=(strategy.num_replicas_in_sync))) + def testHostInputOnly(self): strategy = get_tpu_strategy() @@ -120,13 +145,71 @@ class TpuOutsideCompilationTest(test.TestCase): strategy.experimental_local_results(train_step()), constant_op.constant(36., shape=(strategy.num_replicas_in_sync))) - def testOutsideCompilationControlFlowIf(self): + def testHostMultipleInputs(self): + strategy = get_tpu_strategy() + val0 = np.arange(6).reshape((2, 3)).astype(np.float32) + val1 = np.arange(6).reshape((3, 2)).astype(np.float32) + + def outside_fn(arg0, arg1): + tmp = array_ops.reshape(arg1, array_ops.shape(arg0)) + ret0 = arg0 + tmp + ret1 = math_ops.matmul(arg0, arg1) + ret2 = array_ops.concat([arg0, tmp], 0) + return ret0, ret1, ret2 + + @def_function.function + def train_step(): + + def tpu_fn(x, y): + a = x + 7.0 + b = y * 2.0 + c, d, e = tpu.outside_compilation(outside_fn, a, b) + return (math_ops.reduce_max(c) + math_ops.reduce_min(d) + + math_ops.reduce_sum(e)) + + return strategy.run(tpu_fn, args=(val0, val1)) + + self.assertAllEqual( + strategy.experimental_local_results(train_step()), + constant_op.constant(213., shape=(strategy.num_replicas_in_sync))) + + def testMultipleClusters(self): + strategy = get_tpu_strategy() + + def outside_fn1(x): + logging_ops.print_v2("Outside compiled", x) + return x + 6.0 + + def outside_fn2(x): + logging_ops.print_v2("Outside compiled", x) + return x - 18.0 + + @def_function.function + def train_step(): + + def tpu_fn(x): + x2 = x + 5.0 + output1 = tpu.outside_compilation(outside_fn1, x2) + x3 = output1 + 3.0 + output2 = tpu.outside_compilation(outside_fn2, x3) + return output2 + + return strategy.run(tpu_fn, args=(25.0,)) + + self.assertAllEqual( + strategy.experimental_local_results(train_step()), + constant_op.constant(21., shape=(strategy.num_replicas_in_sync))) + + @parameterized.parameters((True), (False)) + def testOutsideCompilationControlFlowIf(self, take_true_branch): strategy = get_tpu_strategy() def outside_fn(x): logging_ops.print_v2("Outside compiled", x) return x + 6.0 + input_value = 51.0 if take_true_branch else 25.0 + @def_function.function def train_step(): @@ -137,11 +220,15 @@ class TpuOutsideCompilationTest(test.TestCase): else: return x2 - return strategy.run(tpu_fn, args=(25.0,)) + return strategy.run(tpu_fn, args=(input_value,)) + output_value = 36.0 + if take_true_branch: + output_value = 56.0 self.assertAllEqual( strategy.experimental_local_results(train_step()), - constant_op.constant(36., shape=(strategy.num_replicas_in_sync))) + constant_op.constant( + output_value, shape=(strategy.num_replicas_in_sync))) def testOutsideCompilationControlFlowWhile(self): strategy = get_tpu_strategy()