Make async_scope and async_clear_error experimental APIs.
PiperOrigin-RevId: 297523880 Change-Id: I4b76fdcb5d06a0e79317ad29ed314bf963a0d6a4
This commit is contained in:
parent
70cdf91366
commit
a0434c8fa6
tensorflow
python/eager
tools/api/golden
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
|
||||
@ -2142,6 +2143,51 @@ def check_alive(worker_name):
|
||||
return context().check_alive(worker_name)
|
||||
|
||||
|
||||
@tf_export("experimental.async_scope")
|
||||
@tf_contextlib.contextmanager
|
||||
def async_scope():
|
||||
"""Context manager for grouping async operations.
|
||||
|
||||
Ops/function calls inside the scope can return before finishing the actual
|
||||
execution. When exiting the async scope, a synchronization barrier will be
|
||||
automatically added to ensure the completion of all async op and function
|
||||
execution, potentially raising exceptions if async execution results in
|
||||
an error state.
|
||||
|
||||
Users may write the following code to asynchronuously invoke `train_step_fn`
|
||||
and log the `loss` metric for every `num_steps` steps in a training loop.
|
||||
`train_step_fn` internally consumes data using `iterator.get_next()`, and may
|
||||
throw OutOfRangeError when running out of data. In the case:
|
||||
|
||||
```
|
||||
try:
|
||||
with tf.experimental.async_scope():
|
||||
for _ in range(num_steps):
|
||||
# Step function updates the metric `loss` internally
|
||||
train_step_fn()
|
||||
except tf.errors.OutOfRangeError:
|
||||
tf.experimental.async_clear_error()
|
||||
logging.info('loss =', loss.numpy())
|
||||
```
|
||||
|
||||
Yields:
|
||||
Context manager for grouping async operations.
|
||||
"""
|
||||
# TODO(haoyuzhang): replace env var once we have a config method to turn on
|
||||
# and off async streaming RPC
|
||||
remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
|
||||
old_policy = os.environ.get(remote_async_env_var)
|
||||
try:
|
||||
os.environ[remote_async_env_var] = str(True)
|
||||
yield
|
||||
finally:
|
||||
context().sync_executors()
|
||||
if old_policy is None:
|
||||
del os.environ[remote_async_env_var]
|
||||
else:
|
||||
os.environ[remote_async_env_var] = old_policy
|
||||
|
||||
|
||||
def async_wait():
|
||||
"""Sync all async operations and raise any errors during execution.
|
||||
|
||||
@ -2150,34 +2196,11 @@ def async_wait():
|
||||
all async op and function execution. It only returns when all pending nodes
|
||||
are finished, potentially raising exceptions if async execution results in
|
||||
an error state.
|
||||
|
||||
Users may write the following code to asynchronuously invoke `train_step_fn`
|
||||
and log the `loss` metric for every `num_steps` steps in a training loop.
|
||||
`train_step_fn` internally consumes data using `iterator.get_next()`, and may
|
||||
throw OutOfRangeError when running out of data. In the case:
|
||||
- If the exception is thrown during the loop of scheduling function steps,
|
||||
the next call to function triggers an exception. In the except block,
|
||||
we clear the error and break from the loop;
|
||||
- If all `train_step_fn`s are scheduled before throwing an exception, we
|
||||
block at the last iteration to wait for the scheduled functions to finish
|
||||
excution and throw the OutOfRangeError.
|
||||
|
||||
```
|
||||
for i in range(num_steps):
|
||||
try:
|
||||
# Step function updates the metric `loss` internally
|
||||
train_step_fn()
|
||||
if i == num_steps - 1:
|
||||
context.async_wait()
|
||||
except tf.errors.OutOfRangeError:
|
||||
context.async_clear_error()
|
||||
break
|
||||
logging.info('loss =', loss.numpy())
|
||||
```
|
||||
"""
|
||||
context().sync_executors()
|
||||
|
||||
|
||||
@tf_export("experimental.async_clear_error")
|
||||
def async_clear_error():
|
||||
"""Clear pending operations and error statuses in async execution.
|
||||
|
||||
@ -2193,7 +2216,7 @@ def async_clear_error():
|
||||
# Step function updates the metric `loss` internally
|
||||
train_step_fn()
|
||||
except tf.errors.OutOfRangeError:
|
||||
context.async_clear_error()
|
||||
tf.experimental.async_clear_error()
|
||||
break
|
||||
logging.info('loss =', loss.numpy())
|
||||
```
|
||||
|
@ -219,6 +219,30 @@ class RemoteAsyncTest(test.TestCase):
|
||||
|
||||
self.assertAllEqual(v.numpy(), 4.0)
|
||||
|
||||
def test_out_of_range_with_async_scope(self):
|
||||
|
||||
with ops.device('/job:worker/task:0'):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
|
||||
dataset = dataset.batch(1, drop_remainder=False)
|
||||
iterator = iter(dataset)
|
||||
v = variables.Variable(1.0)
|
||||
|
||||
@def_function.function
|
||||
def train_step(iterator):
|
||||
i = next(iterator)
|
||||
v.assign_add(math_ops.reduce_mean(i))
|
||||
|
||||
num_steps = 3
|
||||
try:
|
||||
with context.async_scope():
|
||||
for _ in range(num_steps):
|
||||
with ops.device('/job:worker/task:0'):
|
||||
train_step(iterator)
|
||||
except errors.OutOfRangeError:
|
||||
context.async_clear_error()
|
||||
|
||||
self.assertAllEqual(v.numpy(), 4.0)
|
||||
|
||||
|
||||
class MultiWorkersTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -1,5 +1,13 @@
|
||||
path: "tensorflow.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "async_clear_error"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "async_scope"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "function_executor_type"
|
||||
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -4,6 +4,14 @@ tf_module {
|
||||
name: "tensorrt"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "async_clear_error"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "async_scope"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "function_executor_type"
|
||||
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user