Return a meaningful error for dynamic shape inputs with outside compilation head extraction in TPUs.
PiperOrigin-RevId: 311490072 Change-Id: Idc7bf1764aba1fcbfcf830e36a5b575b387923d7
This commit is contained in:
parent
da78c46560
commit
ca18db7f3f
@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.eager import remote
|
from tensorflow.python.eager import remote
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase):
|
|||||||
# for non-local TPU.
|
# for non-local TPU.
|
||||||
if FLAGS.tpu:
|
if FLAGS.tpu:
|
||||||
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
|
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
|
||||||
|
|
||||||
|
# Disable automatic outside compilation.
|
||||||
|
config.set_soft_device_placement(False)
|
||||||
strategy = get_tpu_strategy()
|
strategy = get_tpu_strategy()
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase):
|
|||||||
|
|
||||||
good_run()
|
good_run()
|
||||||
|
|
||||||
|
def test_dynamic_shape_with_outside_compilation_failure(self):
|
||||||
|
# Enable automatic outside compilation.
|
||||||
|
config.set_soft_device_placement(True)
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
|
||||||
|
2, drop_remainder=False)
|
||||||
|
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||||
|
iterator = iter(dataset)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def train_fn(iterator):
|
||||||
|
|
||||||
|
def step_fn(inputs):
|
||||||
|
_, inputs = inputs
|
||||||
|
return math_ops.reduce_sum(inputs)
|
||||||
|
|
||||||
|
return strategy.experimental_local_results(
|
||||||
|
strategy.run(step_fn, args=(next(iterator),)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
|
||||||
|
logging.info(train_fn(iterator))
|
||||||
|
|
||||||
def test_computation_on_subset_cores(self):
|
def test_computation_on_subset_cores(self):
|
||||||
resolver = get_tpu_cluster_resolver()
|
resolver = get_tpu_cluster_resolver()
|
||||||
remote.connect_to_cluster(resolver)
|
remote.connect_to_cluster(resolver)
|
||||||
|
Loading…
Reference in New Issue
Block a user