Return a meaningful error for dynamic shape inputs with outside compilation head extraction in TPUs.

PiperOrigin-RevId: 311490072
Change-Id: Idc7bf1764aba1fcbfcf830e36a5b575b387923d7
This commit is contained in:
Ruoxin Sang 2020-05-14 01:38:34 -07:00 committed by TensorFlower Gardener
parent da78c46560
commit ca18db7f3f

View File

@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.eager import remote from tensorflow.python.eager import remote
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase):
# for non-local TPU. # for non-local TPU.
if FLAGS.tpu: if FLAGS.tpu:
self.skipTest("Recovery fails for non-local TPU, see b/148150981") self.skipTest("Recovery fails for non-local TPU, see b/148150981")
# Disable automatic outside compilation.
config.set_soft_device_placement(False)
strategy = get_tpu_strategy() strategy = get_tpu_strategy()
@def_function.function @def_function.function
@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase):
good_run() good_run()
def test_dynamic_shape_with_outside_compilation_failure(self):
# Enable automatic outside compilation.
config.set_soft_device_placement(True)
strategy = get_tpu_strategy()
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
2, drop_remainder=False)
dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dataset)
@def_function.function
def train_fn(iterator):
def step_fn(inputs):
_, inputs = inputs
return math_ops.reduce_sum(inputs)
return strategy.experimental_local_results(
strategy.run(step_fn, args=(next(iterator),)))
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self): def test_computation_on_subset_cores(self):
resolver = get_tpu_cluster_resolver() resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver) remote.connect_to_cluster(resolver)