Add outside compilation test with TPUPartitionedCall.
PiperOrigin-RevId: 339877462 Change-Id: Ic52ed793da9b3d48c79df173fd000c9930036839
This commit is contained in:
parent
d4e2d789d0
commit
60aa0451fe
@ -34,7 +34,9 @@ from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -47,8 +49,10 @@ from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import summary_ops_v2 as summary
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.tpu import functional as tpu_functional
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
||||
@ -91,6 +95,36 @@ def _events_from_logdir(test_case, logdir):
|
||||
return result
|
||||
|
||||
|
||||
def _rewrite_func_wrapper(tf_func):
|
||||
|
||||
def tpu_fn(*args, **kwargs):
|
||||
# tpu.rewrite only accepts list of tensors as input. We need to flatten
|
||||
# keyword arguments to meet this requirement.
|
||||
concrete = tf_func.get_concrete_function(*(list(args) +
|
||||
list(kwargs.values())))
|
||||
return tpu.rewrite(concrete.__call__, list(args) + list(kwargs.values()))
|
||||
|
||||
return def_function.function(tpu_fn)
|
||||
|
||||
|
||||
def _tpu_partitioned_call_wrapper(tf_func):
|
||||
"""Wrap a tensorflow Function with TPUPartitionedCall."""
|
||||
|
||||
def inner_func(*args, **kwargs):
|
||||
concrete = tf_func.get_concrete_function(*args, **kwargs)
|
||||
# TPUPartitionedCall only accepts list of tensors as input args.
|
||||
# Flatten keyword arguments and do some basic ordering:
|
||||
# Positional args + Flattened keyword args + Captured args.
|
||||
op_args = list(args) + list(kwargs.values()) + concrete.captured_inputs
|
||||
return tpu_functional.TPUPartitionedCall(
|
||||
args=op_args,
|
||||
device_ordinal=tpu_ops.tpu_ordinal_selector(),
|
||||
Tout=[o.type for o in concrete.function_def.signature.output_arg],
|
||||
f=concrete)
|
||||
|
||||
return def_function.function(inner_func)
|
||||
|
||||
|
||||
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -653,6 +687,34 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
|
||||
self.assertAllEqual(
|
||||
strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/167235391): Reenable this test once function calls are handled "
|
||||
"by MLIR bridge."
|
||||
)
|
||||
def testOutsideCompilationWithTPUPartitionedCallOp(self):
|
||||
"""Tests that control flow with TPUPartitionedCall including outside_compilation works."""
|
||||
get_tpu_strategy()
|
||||
|
||||
def host_computation(x):
|
||||
return x + 1
|
||||
|
||||
@def_function.function()
|
||||
def train_step(x):
|
||||
x2 = x + 5.0
|
||||
logging_ops.print_v2(x2)
|
||||
x2 = tpu.outside_compilation(host_computation, x2)
|
||||
return x2 + 4.0
|
||||
|
||||
tpu_fn = _rewrite_func_wrapper(train_step)
|
||||
partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn)
|
||||
|
||||
concrete = partitioned_tpu_fn.get_concrete_function(
|
||||
x=tensor_spec.TensorSpec(
|
||||
shape=(1), dtype=dtypes.float32, name="input_tensor"))
|
||||
|
||||
self.assertIsInstance(
|
||||
concrete(array_ops.ones((1), dtype=dtypes.float32))[0], ops.Tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user