Make async_scope and async_clear_error experimental APIs.

PiperOrigin-RevId: 297523880
Change-Id: I4b76fdcb5d06a0e79317ad29ed314bf963a0d6a4
This commit is contained in:
Haoyu Zhang 2020-02-26 23:23:11 -08:00 committed by TensorFlower Gardener
parent 70cdf91366
commit a0434c8fa6
4 changed files with 88 additions and 25 deletions

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import contextlib
import copy
import os
import random
import threading
@ -2142,6 +2143,51 @@ def check_alive(worker_name):
return context().check_alive(worker_name)
@tf_export("experimental.async_scope")
@tf_contextlib.contextmanager
def async_scope():
  """Context manager for grouping async operations.

  Ops/function calls inside the scope can return before finishing the actual
  execution. When exiting the async scope, a synchronization barrier will be
  automatically added to ensure the completion of all async op and function
  execution, potentially raising exceptions if async execution results in
  an error state.

  Users may write the following code to asynchronously invoke `train_step_fn`
  and log the `loss` metric for every `num_steps` steps in a training loop.
  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
  throw OutOfRangeError when running out of data. In that case:

  ```
  try:
    with tf.experimental.async_scope():
      for _ in range(num_steps):
        # Step function updates the metric `loss` internally
        train_step_fn()
  except tf.errors.OutOfRangeError:
    tf.experimental.async_clear_error()
  logging.info('loss =', loss.numpy())
  ```

  Yields:
    Context manager for grouping async operations.
  """
  # TODO(haoyuzhang): replace env var once we have a config method to turn on
  # and off async streaming RPC
  remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
  # Remember the caller's setting (None if unset) so it can be restored on exit.
  old_policy = os.environ.get(remote_async_env_var)
  try:
    os.environ[remote_async_env_var] = str(True)
    yield
  finally:
    # Sync executors first so pending async errors surface to the caller of
    # the scope, then restore the env var exactly as it was.
    context().sync_executors()
    if old_policy is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = old_policy
def async_wait():
  """Sync all async operations and raise any errors during execution.

  Calling this method creates a synchronization barrier for
  all async op and function execution. It only returns when all pending nodes
  are finished, potentially raising exceptions if async execution results in
  an error state.

  NOTE(review): the usage example that previously lived here referenced the
  internal `context.async_clear_error()` API and is removed by this change;
  see `tf.experimental.async_scope` for the public-API usage pattern.
  """
  context().sync_executors()
@tf_export("experimental.async_clear_error")
def async_clear_error():
"""Clear pending operations and error statuses in async execution.
@ -2193,7 +2216,7 @@ def async_clear_error():
# Step function updates the metric `loss` internally
train_step_fn()
except tf.errors.OutOfRangeError:
context.async_clear_error()
tf.experimental.async_clear_error()
break
logging.info('loss =', loss.numpy())
```

View File

@ -219,6 +219,30 @@ class RemoteAsyncTest(test.TestCase):
self.assertAllEqual(v.numpy(), 4.0)
def test_out_of_range_with_async_scope(self):
  # Build a 2-element dataset on the remote worker; a third `get_next` call
  # will raise OutOfRangeError.
  with ops.device('/job:worker/task:0'):
    dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
    dataset = dataset.batch(1, drop_remainder=False)
    iterator = iter(dataset)
    v = variables.Variable(1.0)

  @def_function.function
  def train_step(iterator):
    # Consume one batch and accumulate its mean into `v`.
    i = next(iterator)
    v.assign_add(math_ops.reduce_mean(i))

  num_steps = 3
  try:
    # Inside the async scope, step calls may return before execution
    # finishes; the OutOfRangeError from the third step is expected to
    # surface when the scope's exit barrier syncs the executors.
    with context.async_scope():
      for _ in range(num_steps):
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
  except errors.OutOfRangeError:
    # Clear the pending error state so subsequent eager ops can run.
    context.async_clear_error()

  # Only the two real batches were applied: 1.0 + 1.0 + 2.0 = 4.0.
  self.assertAllEqual(v.numpy(), 4.0)
class MultiWorkersTest(test.TestCase, parameterized.TestCase):

View File

@ -1,5 +1,13 @@
path: "tensorflow.experimental"
tf_module {
member_method {
name: "async_clear_error"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "async_scope"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "function_executor_type"
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"

View File

@ -4,6 +4,14 @@ tf_module {
name: "tensorrt"
mtype: "<type \'module\'>"
}
member_method {
name: "async_clear_error"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "async_scope"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "function_executor_type"
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"