Add outside compilation test with TPUPartitionedCall.

PiperOrigin-RevId: 339877462
Change-Id: Ic52ed793da9b3d48c79df173fd000c9930036839
This commit is contained in:
Ken Franko 2020-10-30 08:45:34 -07:00 committed by TensorFlower Gardener
parent d4e2d789d0
commit 60aa0451fe

View File

@ -34,7 +34,9 @@ from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
@ -47,8 +49,10 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu.ops import tpu_ops
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@ -91,6 +95,36 @@ def _events_from_logdir(test_case, logdir):
return result
def _rewrite_func_wrapper(tf_func):
def tpu_fn(*args, **kwargs):
# tpu.rewrite only accepts list of tensors as input. We need to flatten
# keyword arguments to meet this requirement.
concrete = tf_func.get_concrete_function(*(list(args) +
list(kwargs.values())))
return tpu.rewrite(concrete.__call__, list(args) + list(kwargs.values()))
return def_function.function(tpu_fn)
def _tpu_partitioned_call_wrapper(tf_func):
"""Wrap a tensorflow Function with TPUPartitionedCall."""
def inner_func(*args, **kwargs):
concrete = tf_func.get_concrete_function(*args, **kwargs)
# TPUPartitionedCall only accepts list of tensors as input args.
# Flatten keyword arguments and do some basic ordering:
# Positional args + Flattened keyword args + Captured args.
op_args = list(args) + list(kwargs.values()) + concrete.captured_inputs
return tpu_functional.TPUPartitionedCall(
args=op_args,
device_ordinal=tpu_ops.tpu_ordinal_selector(),
Tout=[o.type for o in concrete.function_def.signature.output_arg],
f=concrete)
return def_function.function(inner_func)
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
def setUp(self):
@ -653,6 +687,34 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
self.assertAllEqual(
strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])
@test_util.disable_mlir_bridge(
"TODO(b/167235391): Reenable this test once function calls are handled "
"by MLIR bridge."
)
def testOutsideCompilationWithTPUPartitionedCallOp(self):
"""Tests that control flow with TPUPartitionedCall including outside_compilation works."""
get_tpu_strategy()
def host_computation(x):
return x + 1
@def_function.function()
def train_step(x):
x2 = x + 5.0
logging_ops.print_v2(x2)
x2 = tpu.outside_compilation(host_computation, x2)
return x2 + 4.0
tpu_fn = _rewrite_func_wrapper(train_step)
partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn)
concrete = partitioned_tpu_fn.get_concrete_function(
x=tensor_spec.TensorSpec(
shape=(1), dtype=dtypes.float32, name="input_tensor"))
self.assertIsInstance(
concrete(array_ops.ones((1), dtype=dtypes.float32))[0], ops.Tensor)
if __name__ == "__main__":
test.main()