Add outside compilation test with TPUPartitionedCall.
PiperOrigin-RevId: 339877462 Change-Id: Ic52ed793da9b3d48c79df173fd000c9930036839
This commit is contained in:
parent
d4e2d789d0
commit
60aa0451fe
@ -34,7 +34,9 @@ from tensorflow.python.eager import remote
|
|||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.lib.io import tf_record
|
from tensorflow.python.lib.io import tf_record
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -47,8 +49,10 @@ from tensorflow.python.ops import string_ops
|
|||||||
from tensorflow.python.ops import summary_ops_v2 as summary
|
from tensorflow.python.ops import summary_ops_v2 as summary
|
||||||
from tensorflow.python.platform import flags
|
from tensorflow.python.platform import flags
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.tpu import functional as tpu_functional
|
||||||
from tensorflow.python.tpu import tpu
|
from tensorflow.python.tpu import tpu
|
||||||
from tensorflow.python.tpu import tpu_strategy_util
|
from tensorflow.python.tpu import tpu_strategy_util
|
||||||
|
from tensorflow.python.tpu.ops import tpu_ops
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
||||||
@ -91,6 +95,36 @@ def _events_from_logdir(test_case, logdir):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_func_wrapper(tf_func):
|
||||||
|
|
||||||
|
def tpu_fn(*args, **kwargs):
|
||||||
|
# tpu.rewrite only accepts list of tensors as input. We need to flatten
|
||||||
|
# keyword arguments to meet this requirement.
|
||||||
|
concrete = tf_func.get_concrete_function(*(list(args) +
|
||||||
|
list(kwargs.values())))
|
||||||
|
return tpu.rewrite(concrete.__call__, list(args) + list(kwargs.values()))
|
||||||
|
|
||||||
|
return def_function.function(tpu_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def _tpu_partitioned_call_wrapper(tf_func):
|
||||||
|
"""Wrap a tensorflow Function with TPUPartitionedCall."""
|
||||||
|
|
||||||
|
def inner_func(*args, **kwargs):
|
||||||
|
concrete = tf_func.get_concrete_function(*args, **kwargs)
|
||||||
|
# TPUPartitionedCall only accepts list of tensors as input args.
|
||||||
|
# Flatten keyword arguments and do some basic ordering:
|
||||||
|
# Positional args + Flattened keyword args + Captured args.
|
||||||
|
op_args = list(args) + list(kwargs.values()) + concrete.captured_inputs
|
||||||
|
return tpu_functional.TPUPartitionedCall(
|
||||||
|
args=op_args,
|
||||||
|
device_ordinal=tpu_ops.tpu_ordinal_selector(),
|
||||||
|
Tout=[o.type for o in concrete.function_def.signature.output_arg],
|
||||||
|
f=concrete)
|
||||||
|
|
||||||
|
return def_function.function(inner_func)
|
||||||
|
|
||||||
|
|
||||||
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -653,6 +687,34 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
|
|||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])
|
strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/167235391): Reenable this test once function calls are handled "
|
||||||
|
"by MLIR bridge."
|
||||||
|
)
|
||||||
|
def testOutsideCompilationWithTPUPartitionedCallOp(self):
|
||||||
|
"""Tests that control flow with TPUPartitionedCall including outside_compilation works."""
|
||||||
|
get_tpu_strategy()
|
||||||
|
|
||||||
|
def host_computation(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
@def_function.function()
|
||||||
|
def train_step(x):
|
||||||
|
x2 = x + 5.0
|
||||||
|
logging_ops.print_v2(x2)
|
||||||
|
x2 = tpu.outside_compilation(host_computation, x2)
|
||||||
|
return x2 + 4.0
|
||||||
|
|
||||||
|
tpu_fn = _rewrite_func_wrapper(train_step)
|
||||||
|
partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn)
|
||||||
|
|
||||||
|
concrete = partitioned_tpu_fn.get_concrete_function(
|
||||||
|
x=tensor_spec.TensorSpec(
|
||||||
|
shape=(1), dtype=dtypes.float32, name="input_tensor"))
|
||||||
|
|
||||||
|
self.assertIsInstance(
|
||||||
|
concrete(array_ops.ones((1), dtype=dtypes.float32))[0], ops.Tensor)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user