Add a section on control flow.

PiperOrigin-RevId: 260924357
Dan Moldovan 2019-07-31 07:42:52 -07:00 committed by TensorFlower Gardener
parent 926adde044
commit 0fd17699e8
3 changed files with 680 additions and 8 deletions


@ -0,0 +1,517 @@
# AutoGraph reference
[Index](index.md)
## Control flow
AutoGraph rewrites all control flow statements with specialized AutoGraph
function calls. These functions are capable of executing the corresponding
control flow statement while preserving Python semantics for effects outside
the Python interpreter itself (see the [Introduction](intro.md)).
### Dispatch rules
Key Point: Only statements that are conditioned on, or iterate over, a
TensorFlow object such as `tf.Tensor`, are converted into TensorFlow ops.
As described in the [Introduction](intro.md), AutoGraph aims to preserve the
semantics of valid Python code. If a control flow statement runs in graph
execution without raising an error, then AutoGraph will also execute it as
normal Python control flow. Statements which would normally raise an error, for
example because a `tf.Tensor` cannot be used as a `bool` in an `if` statement,
are converted to TensorFlow control flow ops.
#### Analogy with compile-time constants and code optimization
From the perspective of a TensorFlow graph, non-Tensor values, for example an
integer or a NumPy array, are _constants_: they do not change value while the
graph executes.
For example, in the code below, the condition is always `True` at graph
construction time (it is invariant):
```
x = 1
y = tf.cond(x > 0, lambda: 3 * x, lambda: 5 * x)
```
That is equivalent to the code below:
```
x = 1
y = 3 * x
```
In the example above, we've optimized away the conditional on a constant
condition. The AutoGraph dispatch rules have the same effect: anything that is
not a TensorFlow object is a compile-time constant for TensorFlow, and can be
optimized away. For this reason, you can usually mix Python and TensorFlow
computation and it will transparently have the expected result even
when only some computations are executed in the graph.
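For instance, the sketch below (a hypothetical example; `tf.function` is used
here to trigger graph construction) mixes a Python integer, which is folded
into the graph as a constant, with a `tf.Tensor` argument, which remains
symbolic:
```
import tensorflow as tf

@tf.function
def scale(x):
  factor = 3       # Python int: a compile-time constant for the graph
  if factor > 0:   # Python condition: evaluated once, during tracing
    x = factor * x
  return x

print(scale(tf.constant(2)))  # tf.Tensor(6, shape=(), dtype=int32)
```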
<!-- TODO(mdan): This is actually a limitation (a very subtle one) -->
Caution: The invariance assumption made above does not hold if the TensorFlow
graph has callbacks into Python code. If you modify data from within a
`tf.py_function`, then code outside the `tf.py_function` that depends on the
same data will behave unpredictably.
For example, the `tf.cond` that runs as part of the `if` statement below will
miss the update made by `f`:
```
n = [10]

def f():
  n[0] = 20
  return 0

tf.py_function(f, (), (tf.int32,))

if tf.equal(n[0], 10):
  tf.print('n is 10')
```
```
n is 10
```
### Compound symbols
AutoGraph usually handles basic symbols:
```
if a < 0:
  a = -a
```
```
a = tf.cond(a < 0, lambda: -a, lambda: a)
```
But it can also handle complex symbols in many cases. For example, if we treat
`a.b` as a symbol in the code below, then we can use it as if it were a basic
symbol name:
```
if a.b < 0:
  a.b = -a.b
```
```
a.b = tf.cond(a.b < 0, lambda: -a.b, lambda: a.b)
```
This is useful in methods, which can operate on properties of `self`, as well
as in code that works directly on more complex object structures or
collections.
Caution: There are certain [limitations](limitations.md) around using Python
collections and object mutation. When in doubt, place the values you work
with into local variables and operate on those.
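For example, here is a minimal sketch of that practice; the `Holder` class and
the `flip` function are hypothetical:
```
import tensorflow as tf

class Holder(object):
  def __init__(self, b):
    self.b = b

@tf.function
def flip(a):
  b = a.b     # copy the attribute into a local variable
  if b < 0:   # AutoGraph converts this to a tf.cond, since b is a tf.Tensor
    b = -b
  a.b = b     # write the result back after the control flow
  return a.b

print(flip(Holder(tf.constant(-3))))  # tf.Tensor(3, shape=(), dtype=int32)
```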
### Effects of the tracing process
#### All Python code paths are executed during tracing
When constructing a graph, TensorFlow _traces_ the code. The tracing of control
flow requires visiting _every possible code path_ (usually once).
Note: In rare cases, the runtime may decide to trace some code paths several
times. For example, the condition of a `while` statement may be executed twice,
first in a temporary graph, to determine whether it evaluates to a
`tf.Tensor`; if it does, it is executed a second time in the proper graph.
In other words, tracing executes both branches of an `if` statement.
Similarly, the body of a loop is executed once (even if the loop would
otherwise not iterate at all).
This explains why inserting `print` statements in an `if` statement produces
this output:
```
print('before if')

if tf.constant(True):
  print('true branch')
else:
  print('false branch')

print('after if')
```
```
before if
true branch
false branch
after if
```
Note: Control flow that is not executed as a TensorFlow graph is not traced. Its
body will execute as expected.
Example of code that runs as regular Python code:
```
print('before if')

if True:  # Condition not a Tensor, running normally
  print('true branch')
else:
  print('false branch')

print('after if')
```
```
before if
true branch
after if
```
#### Python values modified in TensorFlow control flow become Tensors
If a symbol is modified in a TensorFlow control flow statement, then it becomes
a `tf.Tensor`, even if it started off as a Python primitive value.
For example, the conditional below will run as a `tf.cond` (its condition is a
`tf.Tensor`), which in turn will cause `i` to become a `tf.Tensor`.
```
i = 0
if tf.greater(i, 0):
  i = 1
# i is now a tf.Tensor, because the `if` statement ran as a tf.cond
```
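A sketch that makes this type change observable (hypothetical example, traced
with `tf.function`):
```
import tensorflow as tf

@tf.function
def f(x):
  i = 0      # starts as a Python int
  if x > 0:  # tf.Tensor condition, so the if becomes a tf.cond
    i = 1
  # i is now a tf.Tensor on every path, even when the branch is not taken
  return i

print(f(tf.constant(-1)))  # tf.Tensor(0, shape=(), dtype=int32)
```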
### `if` statements
`if` statements whose condition is a `tf.Tensor` are executed as TensorFlow
conditionals by converting them to `tf.cond`:
```
if tf.random.uniform(()) > 0.5:
  x = 1
else:
  x = 2
```
`if` statements whose condition is not a `tf.Tensor` are executed as normal
Python:
```
if np.random.uniform() > 0.5:
  x = 1
else:
  x = 2
```
`if` statements executed as TensorFlow conditionals are subject to restrictions
(see [limitations](limitations.md)); a sketch that satisfies them follows the
list below. All symbols affected by the statement and used thereafter must be:
* of a data type understood by TensorFlow
* defined in both branches
* of consistent dtypes in both branches, for TensorFlow entities
* of consistent structure in both branches, for static collections (such as
  lists or tuples)
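A minimal sketch that satisfies all of the restrictions above (hypothetical
example):
```
import tensorflow as tf

@tf.function
def f(c):
  if c > 0:
    x = tf.constant(1)  # int32
  else:
    x = tf.constant(0)  # also int32, and defined in both branches
  return x

print(f(tf.constant(1)))  # tf.Tensor(1, shape=(), dtype=int32)
```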
### `while` statements
`while` statements whose condition is a `tf.Tensor` are executed as TensorFlow
loops by converting them to `tf.while_loop`:
```
x = 0
while tf.random.uniform(()) > 0.5:
  x = x + 1
```
`while` statements whose condition is not a `tf.Tensor` are executed as normal
Python:
```
x = 0
while np.random.uniform() > 0.5:
  x = x + 1
```
`while` statements executed as TensorFlow loops are subject to restrictions
(see [limitations](limitations.md)); a sketch of the `shape_invariants`
mechanism follows the list below. All symbols affected by the statement and
used thereafter must be:
* of a data type understood by TensorFlow
* defined before the loop
* of consistent dtype at the beginning and the end of the loop,
  for TensorFlow entities
* either of consistent shape at the beginning and the end of the loop,
  for TensorFlow entities, or declared in `shape_invariants`
* of consistent structure at the beginning and the end of the loop, for
  static collections (such as lists or tuples)
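The `shape_invariants` mechanism referenced above belongs to `tf.while_loop`.
A sketch of declaring a dynamic dimension for a loop variable that grows
across iterations:
```
import tensorflow as tf

i, x = tf.while_loop(
    lambda i, x: i < 3,
    lambda i, x: (i + 1, tf.concat([x, x], axis=0)),  # x doubles in length
    loop_vars=(tf.constant(0), tf.constant([1])),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])))

print(x)  # tf.Tensor([1 1 1 1 1 1 1 1], shape=(8,), dtype=int32)
```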
Caution: A `while` loop whose condition is a Python scalar will execute as
normal Python, and if its body creates TensorFlow ops, the body is replicated
in the graph once per iteration (the loop is unrolled). If you intended to run
it as a TensorFlow loop, make sure its condition is a `tf.Tensor`, for instance
by using `tf.constant`.
For example, the following loop is unrolled, even though the list contains
`tf.Tensor` values, because the type of `l` is a Python `list`:
```
l = [tf.constant(1), tf.constant(2), tf.constant(3)]
for i in l:
  tf.print(i)  # This is unrolled - three `tf.print`s are built in the graph.
```
If you wish for the loop to run as a TensorFlow loop, stack the list:
```
l = [tf.constant(1), tf.constant(2), tf.constant(3)]
for i in tf.stack(l):
  tf.print(i)  # This runs as a TensorFlow loop.
```
<!-- TODO(mdan): List this under limitations -->
Caution: A loop in which the type of the condition changes across iterations,
in a way that would influence how the loop is executed, is not allowed in
AutoGraph.
For example, the loop below generates an error: after the first iteration, `i`
would become a `tf.Tensor`, so the condition `i < 10` would change from a
Python `bool` into a `tf.Tensor`, altering how the loop executes:
```
i = 0
while i < 10:        # `i < 10` is a Python bool - run as normal while loop
  i = tf.constant(1) # Error -- `i < 10` would now be a `tf.Tensor`
```
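A sketch of the fix: make `i` a `tf.Tensor` before the loop, so its type is
stable across iterations:
```
import tensorflow as tf

@tf.function
def f():
  i = tf.constant(0)  # i is a tf.Tensor from the start
  while i < 10:       # condition is a tf.Tensor: a single tf.while_loop
    i += 1
  return i

print(f())  # tf.Tensor(10, shape=(), dtype=int32)
```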
### `for` statements
`for` statements that iterate over a `tf.Tensor` are executed as TensorFlow
loops by converting them to a `tf.while_loop` which iterates over the first
dimension (equivalent to NumPy):
```
for i in tf.constant(((1, 2), (3, 4))):
  tf.print('iteration:', i)
```
```
iteration: [1 2]
iteration: [3 4]
```
Note: If possible, AutoGraph will also set the `maximum_iterations` parameter
of the `tf.while_loop`.
`for` statements that iterate over the output of a `tf.range` are executed as
TensorFlow loops by converting them to a `tf.while_loop` which uses the
arguments passed to the `tf.range`:
```
for i in tf.range(3):
  tf.print('iteration:', i)
```
`for` statements that iterate over a `tf.data.Dataset` and which do not contain
`break` or `return` statements are executed as TensorFlow loops by converting
them to `tf.data.Dataset.reduce` ops:
```
for i in tf.data.Dataset.range(3):
  tf.print('iteration:', i)
```
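For intuition, a hand-written `tf.data.Dataset.reduce` in the same spirit as
that conversion; this sketch accumulates a sum instead of printing:
```
import tensorflow as tf

ds = tf.data.Dataset.range(3)  # yields 0, 1, 2 as int64
total = ds.reduce(tf.constant(0, dtype=tf.int64), lambda acc, i: acc + i)
print(total)  # tf.Tensor(3, shape=(), dtype=int64)
```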
`for` statements that iterate over a _distributed_ `tf.data.Dataset` and which
do not contain `break` or `return` statements are executed as TensorFlow loops
by converting them to the datasets' `reduce` ops:
```
for i in tf.distribute.OneDeviceStrategy('cpu').experimental_distribute_dataset(
    tf.data.Dataset.range(3)):
  tf.print('iteration:', i)
```
`for` statements that iterate over a `tf.data.Dataset` and which contain
`break` or `return` statements are executed as TensorFlow loops by converting
them to a combination of `tf.data.Dataset.scan`, `tf.data.Dataset.take_while`
and `tf.data.Dataset.reduce` ops:
```
for i in tf.data.Dataset.range(3):
  tf.print('iteration:', i)
  break
```
```
iteration: 0
```
`for` statements that iterate over a `tf.data.Dataset` _iterator_ are executed
as TensorFlow loops by converting them to a combination of `tf.while_loop` and
`tf.cond` ops:
```
for i in iter(tf.data.Dataset.range(3)):
  tf.print('iteration:', i)
```
`for` statements that iterate over a type different from any of the above are
executed as normal Python:
```
for i in [1, 2, 3]:
  print('iteration:', i)
```
Caution: A `for` loop over a `list` or `tuple` of `tf.Tensor` is considered to
iterate over a Python `list` (or respectively, `tuple`), and will therefore be
executed as normal Python. If you intended to run it as a TensorFlow loop,
use `tf.stack` or `tf.concat`.
Caution: A `for` loop over a Python `range` will be executed as normal Python.
If you intended to run it as a TensorFlow loop, use `tf.range` instead (see
the sketch below).
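A sketch contrasting the two behaviors, assuming the code is converted by
AutoGraph (for example inside a `tf.function`):
```
import tensorflow as tf

@tf.function
def f():
  s = tf.constant(0)
  for i in range(3):     # Python range: unrolled, 3 additions in the graph
    s += i
  for i in tf.range(3):  # tf.Tensor range: a single TensorFlow loop
    s += i
  return s

print(f())  # tf.Tensor(6, shape=(), dtype=int32)
```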
Note: AutoGraph may output a warning when it believes that you are unrolling
a loop inefficiently. However, the warning thresholds are very conservative.
### `break` statements
Code blocks in which `break` statements are used are rewritten with equivalent
code that uses extra control booleans and conditionals. The control booleans are
used directly in `while` loops. In the case of `for` loops, the corresponding
AutoGraph operator accepts an `extra_test` argument which plays a role similar
to the condition of a `while` loop, and which contains the control boolean.
For example, the `while` loop below is rewritten as (showing the output of the
`break` transformation only):
```
while i < 10:
  if i > 3:
    break
  i += 1
```
```
break_ = False
while i < 10 and not break_:
  if i > 3:
    break_ = True
    continue  # The continue statement is also rewritten in a subsequent pass
  i += 1
```
Another example shows how the control boolean is used in the overload of a `for`
loop (showing portions of the final output):
```
for i in range(10):
  if i > 3:
    break
```
```
break_ = False
...
def extra_test(break_):
  return ag__.not_(break_)
# break_ becomes a loop variable.
break_, = ag__.for_stmt(range(10), extra_test, ..., (break_,))
```
### `continue` statements
Code blocks in which `continue` statements are used are rewritten with
equivalent code that uses extra control booleans and conditionals, similar to
how `break` is handled.
For example, the `for` loop below is rewritten as (showing the output of the
`continue` transformation only):
```
for i in range(10):
  if i > 3:
    continue
  i += 1
```
```
for i in range(10):
  continue_ = False
  if i > 3:
    continue_ = True
  if not continue_:
    i += 1
```
Notice that unlike `break`, `continue` statements are local to the loop and do
not influence the number of iterations.
### `return` statements
`return` statements are also rewritten using control symbols, in a manner
similar to how `break` is converted. In the case of `return` statements, an
additional symbol keeps track of the return value.
Depending on the structure of the code, the return value might be undefined
in parts of the code (for example on code paths in which no return statement
has executed). AutoGraph keeps track of this by using a special value.
This special value is converted to `None` (the default return value) upon
exiting the function.
Caution: TensorFlow control flow does not support undefined values, and an
undefined return value is no exception. Therefore, AutoGraph will raise an
error for TensorFlow control flow in which the return value is not known for
all code paths.
For example, the following code raises an error, because the return value
would be undefined whenever the random number is not greater than 0.5:
```
if tf.random.uniform(()) > 0.5:
  return 1
```
```
ValueError: A value must also be returned from the else branch.
```
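A sketch of the fix: return a value on every code path:
```
import tensorflow as tf

@tf.function
def f():
  if tf.random.uniform(()) > 0.5:
    return 1
  else:
    return 0  # every code path now returns a value

print(f())  # tf.Tensor(1, ...) or tf.Tensor(0, ...)
```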
An example of rewriting a `while` loop (showing the output of the `return`
transformation only):
```
def f():
  while i < 10:
    if i > 3:
      return 1
    i += 1
```
```
def f():
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  while i < 10 and not do_return:
    if i > 3:
      do_return = True
      retval_ = 1
    if not do_return:
      i += 1
  return ag__.retval(retval_)  # Transforms any UndefinedReturnValue to None
```
Note: AutoGraph performs an additional code normalization, in which an `if`
statement with no `else` branch that contains a `return` statement is rewritten
as an `if-else` statement, with the code that follows the statement moved under
the `else` branch.
Example (showing the normalization only):
```
def f():
  if i > 3:
    return 1
  i += 1
```
```
def f():
  if i > 3:
    return 1
  else:
    i += 1
```


@ -8,16 +8,17 @@ graph.
* [Introduction](intro.md)
* [Interacting with the generated code](generated_code.md)
* [Debugging AutoGraph code](debugging.md)
- * Control Flow (coming soon)
+ * [Control flow](control_flow.md)
* Function calls (coming soon)
* Exception handling (coming soon)
* Conversion mechanics (coming soon)
* Collections (coming soon)
* Exceptions (coming soon)
* Builtin Functions (coming soon)
* Datasets (coming soon)
* [Limitations](limitations.md)
* Common errors (coming soon)
For more information on AutoGraph, see the following articles:
- * [AutoGraph tutorial](https://www.tensorflow.org/alpha/guide/autograph)
+ * [AutoGraph tutorial](https://www.tensorflow.org/alpha/beta/autograph)
* [Eager tutorial](https://www.tensorflow.org/alpha/guide/eager)
* [TensorFlow 2.0 Alpha](https://www.tensorflow.org/alpha)
* [AutoGraph blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)


@ -10,6 +10,12 @@ However, when applied to TensorFlow control flow (for example, an if statement
with a `tf.Tensor` condition), there are certain limitations. This section
describes these limitations and practices that will allow you to avoid them.
Key Term: Python variables refer to Python symbols (or symbols for short) and
should not be confused with TensorFlow variables.
Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.
### Indirect modifications and hidden side effects in TensorFlow control flow
<!-- TODO(mdan) Refine this paragraph well - it's important -->
@ -22,12 +28,10 @@ flow statements into equivalent TensorFlow ops. This process requires "wiring"
variables in the Python code whose values are affected by these statements into
the respective control flow ops.
- Note: Python variables should not be confused with TensorFlow variables.
The examples below use a `while` loop, but the same notions extend to all
control flow: `if` and `for` statements.
- In the example below, `x` needs to become a _loop variable_ of the
+ In the example below, `x` needs to become a loop variable of the
corresponding `tf.while_loop`:
```
@ -255,6 +259,156 @@ for i in tf.range(10):
  d = {key: value + i for key, value in d.items()}  # Okay
```
### Shape and dtype consistency in TensorFlow control flow
Unlike Python, TensorFlow has limited support for dynamic typing. This means
that tensors must maintain consistent shapes and dtypes across control flow
paths.
Note: In general, these restrictions do not apply in control flow in Eager
execution, because Eager execution uses Python control flow, rather than
TensorFlow control flow ops.
#### Consistency of dtype
The dtypes across all code paths must be consistent in conditionals and loops.
For example, if a `tf.cond` (and correspondingly, an AutoGraph `if`) sets a
tensor value conditionally, then that tensor must have the same dtype in both
branches of the conditional.
Example of illegal dtype change in a conditional:
```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1, dtype=tf.int32),
    lambda: tf.constant(1, dtype=tf.float32))  # Error -- inconsistent dtypes: int32, float32
```
The same restriction in AutoGraph code:
```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1, dtype=tf.int32)
else:
  x = tf.constant(1, dtype=tf.float32)  # Error -- inconsistent dtypes: int32, float32
```
Example of illegal dtype change in a loop:
```
# This won't work - "x" changes dtype inside the loop.
x = tf.while_loop(
    lambda _: tf.random.uniform(()) > 0.5,
    lambda x: tf.constant(1, dtype=tf.float32),
    loop_vars=(tf.constant(1, dtype=tf.int32),))  # Error -- inconsistent dtypes: int32, float32
```
The same restriction in AutoGraph code:
```
x = tf.constant(0, dtype=tf.int32)
while tf.random.uniform(()) > 0.5:
  x = tf.constant(0, dtype=tf.float32)  # Error -- inconsistent dtypes: int32, float32
```
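A sketch of one possible fix: make every branch produce the same dtype, for
instance with an explicit `tf.cast`:
```
import tensorflow as tf

@tf.function
def f():
  if tf.random.uniform(()) > 0.5:
    x = tf.constant(1, dtype=tf.float32)
  else:
    x = tf.cast(tf.constant(1, dtype=tf.int32), tf.float32)  # now float32 too
  return x

print(f())  # tf.Tensor(1.0, shape=(), dtype=float32)
```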
#### Consistency of shape
Shapes must be consistent across all code paths in loops only. When tensors
do need to change shape across iterations, use `shape_invariants`.
Note: Shapes are allowed to be inconsistent in conditionals. The result will be
a partially dynamic shape.
In a `tf.while_loop` (and correspondingly, an AutoGraph `while` or `for` loop)
all loop variables must maintain consistent shape and dtype across iterations.
That is, every loop variable must have the same shape at the end of the loop
body as the shape that it had at the beginning of the loop body.
Example of illegal shape change in a loop:
```
def loop_body(x):  # x.shape is ()
  return tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)

x = tf.while_loop(
    lambda _: tf.random.uniform(()) > 0.5,
    loop_body,
    loop_vars=(tf.constant(1),))
The same restriction in AutoGraph code:
```
x = tf.constant(1)  # x.shape is ()
while tf.random.uniform(()) > 0.5:
  x = tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)
```
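By contrast, a conditional tolerates shape differences; a sketch showing the
partially dynamic result shape (hypothetical example):
```
import tensorflow as tf

@tf.function
def f():
  x = tf.cond(
      tf.random.uniform(()) > 0.5,
      lambda: tf.constant([1]),     # shape (1,)
      lambda: tf.constant([1, 2]))  # shape (2,) -- allowed in a conditional
  print(x.shape)  # (None,) at trace time: a partially dynamic shape
  return x

f()
```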
### Undefined and None values in TensorFlow
TensorFlow does not support undefined and `None` values. All tensors must have
a value.
Example:
```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1),
    lambda: None)  # Error -- a Tensor cannot be None
```
The same restriction carries over in AutoGraph, but only if the symbol is used
after the conditional (otherwise AutoGraph avoids making it a return value
of the `tf.cond`):
```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  x = None

tf.print(x)  # Error -- x may be None here
```
A related but less obvious restriction in AutoGraph forbids symbols to be
defined in only one branch of TensorFlow control flow, if the symbol is
used afterwards:
```
del x

if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  pass

tf.print(x)  # Error -- x may be undefined here
```
Similarly, a variable that is first defined inside a loop may not be used
after the loop, because the loop body might never execute:
```
del x

while tf.random.uniform(()) > 0.5:
  x = tf.constant(1)

tf.print(x)  # Error -- x may be undefined here
```
Avoid these limitations by defining a default value before the control flow
statement:
```
x = tf.constant(0)

if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)

tf.print(x)  # Okay -- x is either 0 or 1
```
Note: `None` values and undefined symbols are allowed in Eager control flow,
because Eager execution uses Python control flow, rather than TensorFlow
control flow ops.
### Access to source code
Key point: AutoGraph can only handle functions whose source code can be