Return a meaningful error for dynamic shape inputs with outside compilation head extraction in TPUs.

PiperOrigin-RevId: 311495416
Change-Id: I42b12ac545224c32e770d963a5f3f333ba280531
This commit is contained in:
A. Unique TensorFlower 2020-05-14 02:32:13 -07:00 committed by TensorFlower Gardener
parent 5767af0cd2
commit 85bf5f7c20
1 changed files with 0 additions and 26 deletions

View File

@ -28,7 +28,6 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.eager import remote from tensorflow.python.eager import remote
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -141,9 +140,6 @@ class TPUStrategyTest(test.TestCase):
# for non-local TPU. # for non-local TPU.
if FLAGS.tpu: if FLAGS.tpu:
self.skipTest("Recovery fails for non-local TPU, see b/148150981") self.skipTest("Recovery fails for non-local TPU, see b/148150981")
# Disable automatic outside compilation.
config.set_soft_device_placement(False)
strategy = get_tpu_strategy() strategy = get_tpu_strategy()
@def_function.function @def_function.function
@ -168,28 +164,6 @@ class TPUStrategyTest(test.TestCase):
good_run() good_run()
def test_dynamic_shape_with_outside_compilation_failure(self):
# Enable automatic outside compilation.
config.set_soft_device_placement(True)
strategy = get_tpu_strategy()
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
2, drop_remainder=False)
dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dataset)
@def_function.function
def train_fn(iterator):
def step_fn(inputs):
_, inputs = inputs
return math_ops.reduce_sum(inputs)
return strategy.experimental_local_results(
strategy.run(step_fn, args=(next(iterator),)))
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self): def test_computation_on_subset_cores(self):
resolver = get_tpu_cluster_resolver() resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver) remote.connect_to_cluster(resolver)