Add a section on control flow.
PiperOrigin-RevId: 260924357
This commit is contained in:
parent
926adde044
commit
0fd17699e8
517
tensorflow/python/autograph/g3doc/reference/control_flow.md
Normal file
517
tensorflow/python/autograph/g3doc/reference/control_flow.md
Normal file
@ -0,0 +1,517 @@
|
||||
# AutoGraph reference
|
||||
|
||||
[Index](index.md)
|
||||
|
||||
## Control flow
|
||||
|
||||
AutoGraph rewrites all control flow statements with specialized AutoGraph
|
||||
function calls. These function calls are capable of executing the corresponding
|
||||
control flow statement using Python semantics for effects outside the Python
|
||||
interpreter itself (see the [Introduction](intro.md)).
|
||||
|
||||
### Dispatch rules
|
||||
|
||||
Key Point: Only statements that are conditioned on, or iterate over, a
|
||||
TensorFlow object such as `tf.Tensor`, are converted into TensorFlow ops.
|
||||
|
||||
As described in the [Introduction](intro.md), AutoGraph aims to preserve the
|
||||
semantics of valid Python code. If a control flow statement runs in graph
|
||||
execution without raising an error, then AutoGraph will also execute it as
|
||||
normal Python control flow. Statements which would normally raise an error, for
|
||||
example because a `tf.Tensor` cannot be used as a `bool` in an `if` statement,
|
||||
are converted to TensorFlow control flow ops.
|
||||
|
||||
#### Analogy with compile-time constants and code optimization
|
||||
|
||||
From the perspective of a TensorFlow graph, non-Tensor values, for example an
|
||||
integer or a NumPy array, are _constants_: they do not change value while the
|
||||
graph executes.
|
||||
|
||||
For example, in the graph below, the condition is always `True` (it is
|
||||
invariant):
|
||||
|
||||
```
|
||||
x = 1
|
||||
y = tf.cond(x > 0, lambda: 3 * x, lambda 5 * x)
|
||||
```
|
||||
|
||||
That is equivalent to the code below:
|
||||
|
||||
```
|
||||
x = 1
|
||||
y = 3 * x
|
||||
```
|
||||
|
||||
In the example above, we've optimized away the conditional on a constant
|
||||
condition. The AutoGraph dispatch rules have the same effect: anything that is
|
||||
not a TensorFlow object is a compile-time constant for TensorFlow, and can be
|
||||
optimized away. For this reason, you can usually mix Python and TensorFlow
|
||||
computation and it will transparently have the expected result even
|
||||
when only some computations are executed in the graph.
|
||||
|
||||
<!-- TODO(mdan): This is actually a limitation (a very subtle one) -->
|
||||
Caution: The assumption of invariant code made above is not true if the
|
||||
TensorFlow graph had callbacks into the Python code. If you modify data
|
||||
from within a `tf.py_function`, then the code outside a `tf.py_function`
|
||||
will have unpredictable behavior if it depends on the same data.
|
||||
|
||||
For example, the `tf.cond` that runs as part of the `if` statement below will
|
||||
miss the update made by `f`:
|
||||
|
||||
```
|
||||
n = [10]
|
||||
def f():
|
||||
n[0] = 20
|
||||
return 0
|
||||
tf.py_function(f, (), (tf.int32,))
|
||||
if tf.equal(n[0], 10):
|
||||
tf.print('n is 10')
|
||||
```
|
||||
|
||||
```
|
||||
n is 10
|
||||
```
|
||||
|
||||
### Compound symbols
|
||||
|
||||
AutoGraph usually handles basic symbols:
|
||||
|
||||
```
|
||||
if a < 0:
|
||||
a = -a
|
||||
```
|
||||
|
||||
```
|
||||
a = tf.cond(a < 0, lambda: -a, lambda: a)
|
||||
```
|
||||
|
||||
But it can also handle complex symbols in many cases. For example, if we treat
|
||||
`a.b` as a symbol in the code below, then we can use it as if it were a basic
|
||||
symbol name:
|
||||
|
||||
```
|
||||
if a.b < 0
|
||||
a.b = -a.b
|
||||
```
|
||||
|
||||
```
|
||||
a.b = tf.cond(a.b < 0, lambda: -a.b, lambda: a.b)
|
||||
```
|
||||
|
||||
This is useful in methods, which can operate on properties of `self`, as well as
|
||||
working directly on more complex object structures or collections.
|
||||
|
||||
Caution: There are certain [limitations](limitations.md) around using Python
|
||||
collections and object mutation. When in doubt, place the values you work
|
||||
with into local variables and operate on those.
|
||||
|
||||
### Effects of the tracing process
|
||||
|
||||
#### All Python code paths are executed during tracing
|
||||
|
||||
When constructing a graph, TensorFlow _traces_ the code. The tracing of control
|
||||
flow requires visiting _every possible code path_ (usually once).
|
||||
|
||||
Note: In rare cases, the runtime may decide to trace some code paths several
|
||||
times. For example, the condition of a `while` statement may be executed twice,
|
||||
first with a temporary graph, to determine whether it evaluates to a
|
||||
`tf.Tensor`, then if it is a `tf.Tensor`, it's executed a second time in the
|
||||
proper graph.
|
||||
|
||||
In other words, when tracing executes both branches of an if statement.
|
||||
Similarly, the body of loops is executed once (even if the loop would otherwise
|
||||
not iterate at all).
|
||||
|
||||
This explains why inserting `print` statements in an `if` statement produces
|
||||
this output:
|
||||
|
||||
```
|
||||
print('before if')
|
||||
if tf.constant(True):
|
||||
print('true branch')
|
||||
else:
|
||||
print('false branch')
|
||||
print('after if')
|
||||
```
|
||||
|
||||
```
|
||||
before if
|
||||
true branch
|
||||
false branch
|
||||
after if
|
||||
```
|
||||
|
||||
Note: Control flow that is not executed as a TensorFlow graph is not traced. Its
|
||||
body will execute as expected.
|
||||
|
||||
Example of code that runs as regular Python code:
|
||||
|
||||
```
|
||||
print('before if')
|
||||
if True: # Condition not a Tensor, running normally
|
||||
print('true branch')
|
||||
else:
|
||||
print('false branch')
|
||||
print('after if')
|
||||
```
|
||||
|
||||
```
|
||||
before if
|
||||
true branch
|
||||
after if
|
||||
```
|
||||
|
||||
#### Python values modified in TensorFlow control flow become Tensors
|
||||
|
||||
If a symbol is modified in a TensorFlow control flow statement, then it becomes
|
||||
a `tf.Tensor`, even if it started off as a Python promitive value.
|
||||
|
||||
For example, the conditional below will run as a `tf.cond` (its condition is a
|
||||
`tf.Tensor`), which in turn will cause `i` to become a `tf.Tensor`.
|
||||
|
||||
```
|
||||
i = 0
|
||||
if tf.greater(i, 0):
|
||||
i = 1
|
||||
# i is not a Tensor
|
||||
```
|
||||
|
||||
### `if` statements
|
||||
|
||||
`if` statements whose condition is a `tf.Tensor` are executed as TensorFlow
|
||||
conditionals by converting them to `tf.cond`:
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = 1
|
||||
else:
|
||||
x = 2
|
||||
```
|
||||
|
||||
`if` statements whose condition is not a `tf.Tensor` are executed as normal
|
||||
Python:
|
||||
|
||||
```
|
||||
if np.random.uniform() > 0.5:
|
||||
x = 1
|
||||
else:
|
||||
x = 2
|
||||
```
|
||||
|
||||
`if` statements executed as TensorFlow conditionals are subject to restrictions
|
||||
(see [limitations](limitations.md)). All symbols affected by the statement and
|
||||
used thereafter must be:
|
||||
|
||||
* of a data type understood by TensorFlow
|
||||
* defined in both branches
|
||||
* of consistent dtypes in both branches, for TensorFlow entities
|
||||
* of consistent structure in both branches, for static collections (such as
|
||||
lists or tuples)
|
||||
|
||||
### `while` statements
|
||||
|
||||
`while` statements whose condition is a `tf.Tensor` are executed as TensorFlow
|
||||
loops by converting them to `tf.while_loop`:
|
||||
|
||||
```
|
||||
x = 0
|
||||
while tf.random.uniform(()) > 0.5:
|
||||
x = x + 1
|
||||
```
|
||||
|
||||
`while` statements whose condition is not a `tf.Tensor` are executed as normal
|
||||
Python:
|
||||
|
||||
```
|
||||
x = 0
|
||||
while np.random.uniform() > 0.5:
|
||||
x = x + 1
|
||||
```
|
||||
|
||||
`while` statements executed as TensorFlow loops are subject to restrictions
|
||||
(see [limitations](limitations.md)). All symbols affected by the statement and
|
||||
used thereafter must be:
|
||||
|
||||
* of a data type understood by TensorFlow
|
||||
* defined before the loop
|
||||
* of consistent dtype at the beginning and the end of the loop,
|
||||
for TensorFlow entities
|
||||
* either of consistent shape at the beginning and the end of the loop,
|
||||
for TensorFlow entities, or declared in `shape_invariants`
|
||||
* of consistent structure at the beginning and the end of the loop, for
|
||||
static collections (such as lists or tuples)
|
||||
|
||||
Caution: A `while` loop whose condition is a Python scalar will execute as
|
||||
normal Python. If you intended to run the loop as a TensorFlow loop, the loop
|
||||
will replicate its body in the graph (it is unrolled). To avoid that, make sure
|
||||
its condition is converted to a `tf.Tensor`, using for instance `tf.constant`.
|
||||
|
||||
For example, the following loop is unrolled, even though the list contains
|
||||
`tf.Tensor` values, because the type of `l` is a Python `list`:
|
||||
|
||||
```
|
||||
l = [tf.constant(1), tf.constant(2), tf.constant(3)]
|
||||
for i in l:
|
||||
tf.print(i) # This is unrolled - three `tf.print`s are built in the graph.
|
||||
```
|
||||
|
||||
If you wish for the loop to run as a TensorFlow loop, stack the loop:
|
||||
|
||||
```
|
||||
l = [tf.constant(1), tf.constant(2), tf.constant(3)]
|
||||
for i in tf.stack(l):
|
||||
tf.print(i) # This runs as a TensorFlow loop.
|
||||
```
|
||||
|
||||
<!-- TODO(mdan): List this under limitations -->
|
||||
Caution: A loop in which the type of the condition condition changes across
|
||||
iterations, in a way that would influence the way the loop is executed, is not
|
||||
allowed in AutoGraph.
|
||||
|
||||
For example, the loop below will generate an error. After the first iteration,
|
||||
`i` becomes a tf.Tensor, because
|
||||
|
||||
```
|
||||
i = 0
|
||||
while i < 10: # `i < 10` is a Python bool - run as normal while loop
|
||||
i = tf.constant(1) # Error -- `i < 10` would now be a `tf.Tensor`
|
||||
```
|
||||
|
||||
### `for` statements
|
||||
|
||||
`for` statements that iterate over a `tf.Tensor` are executed as TensorFlow
|
||||
loops by converting them to a `tf.while_loop` which iterates over the first
|
||||
dimension (equivalent to NumPy):
|
||||
|
||||
```
|
||||
for i in tf.constant(((1, 2), (3, 4))):
|
||||
tf.print('iteration:', i)
|
||||
```
|
||||
|
||||
```
|
||||
iteration: [1, 2]
|
||||
iteration: [3, 4]
|
||||
```
|
||||
|
||||
Note: If possible, AutoGraph will also set the `maximum_iteration` parameter
|
||||
of the `tf.while_loop`.
|
||||
|
||||
`for` statements that iterate over a the output of a `tf.range` are executed as
|
||||
TensorFlow loops by converting them to a `tf.while_loop` which uses the
|
||||
arguments passed to the `tf.range`:
|
||||
|
||||
```
|
||||
for i in tf.range(3):
|
||||
tf.print('iteration:', i)
|
||||
```
|
||||
|
||||
`for` statements that iterate over a `tf.data.Dataset` and which do not contain
|
||||
`break` or `return` statements are executed as TensorFlow loops by converting
|
||||
them to `tf.data.Dataset.reduce` ops:
|
||||
|
||||
```
|
||||
for i in tf.data.Dataset.range(3):
|
||||
tf.print('iteration:', i)
|
||||
```
|
||||
|
||||
`for` statements that iterate over a _distributed_ `tf.data.Dataset` and which
|
||||
do not contain `break` or `return` statements are executed as TensorFlow loops
|
||||
by converting them to the datasets' `reduce` ops:
|
||||
|
||||
```
|
||||
for i in tf.distribute.OneDeviceStrategy('cpu').experimental_distribute_dataset(
|
||||
tf.data.Dataset.range(3)):
|
||||
tf.print('iteration:', i)
|
||||
```
|
||||
|
||||
`for` statements that iterate over a `tf.data.Dataset` and which contain
|
||||
`break` or `return` statements are executed as TensorFlow loops by converting
|
||||
them to a combination of `tf.data.Dataset.scan`, `tf.data.Dataset.take_while`
|
||||
and `tf.data.Dataset.reduce` ops:
|
||||
|
||||
```
|
||||
for i in tf.data.Dataset.range(3):
|
||||
tf.print('iteration:', i)
|
||||
break
|
||||
```
|
||||
|
||||
```
|
||||
iteration: 1
|
||||
```
|
||||
|
||||
`for` statements that iterate over a `tf.data.Dataset` _iterator_ are executed
|
||||
as TensorFlow loops by converting them to a combination of `tf.while_loop`,
|
||||
and `tf.cond` ops:
|
||||
|
||||
```
|
||||
for i in iter(tf.data.Dataset.range(3)):
|
||||
tf.print('iteration:', i)
|
||||
```
|
||||
|
||||
`for` statements that iterate over a type different from any of the above are
|
||||
executed as normal Python:
|
||||
|
||||
```
|
||||
for i in [1, 2, 3]:
|
||||
print('iteration:', i)
|
||||
```
|
||||
|
||||
Caution: A `for` loop over a `list` or `tuple` of `tf.Tensor` is considered to
|
||||
iterate over a Python `list` (or respectively `tuple`), therefore will be
|
||||
executed as normal Python. If you intended to run it as a TensorFlow loop,
|
||||
use `tf.stack` or `tf.concat`.
|
||||
|
||||
Caution: A `for` loop over a Python `range` will be executed as normal Python.
|
||||
If you intended to run it as a TensorFlow loop, `tf.range`.
|
||||
|
||||
Note: AutoGraph may output a warning when it believes that you are unrolling
|
||||
a loop inefficiently. However, the warning thresholds are very conservative.
|
||||
|
||||
### `break` statements
|
||||
|
||||
Code blocks in which `break` statements are used are rewritten with equivalent
|
||||
code that uses extra control booleans and conditionals. The control booleans are
|
||||
used directly in `while` loops. In the case of `for` loops, the AutoGraph
|
||||
corresponding operator accepts an `extra_test` argument which is similar to
|
||||
the conditional of a while loop, and which contains the control boolean.
|
||||
|
||||
For example, the `while` loop below is rewritten as (showing the output of the
|
||||
`break` transformation only):
|
||||
|
||||
```
|
||||
while i < 10:
|
||||
if i > 3:
|
||||
break
|
||||
i += 1
|
||||
```
|
||||
|
||||
```
|
||||
break_ = False
|
||||
while i < 10 and not break_:
|
||||
if i > 3:
|
||||
break_ = True
|
||||
continue # The continue statement is also rewritten in a subsequent pass
|
||||
i += 1
|
||||
```
|
||||
|
||||
Another example shows how the control boolean is used in the overload of a `for`
|
||||
loop (showing portions of the final output):
|
||||
|
||||
```
|
||||
for i in range(10):
|
||||
if i > 3:
|
||||
break
|
||||
```
|
||||
|
||||
```
|
||||
break_ = False
|
||||
...
|
||||
def extra_test(break_):
|
||||
return ag__.not_(break_)
|
||||
# break_ becomes a loop variable.
|
||||
break_, = ag__.for_stmt(range(10), extra_test, ..., (break_,))
|
||||
```
|
||||
|
||||
### `continue` statements
|
||||
|
||||
Code blocks in which `continue` statements are used are rewritten with
|
||||
equivalent code that uses extra control booleans and conditionals, similar to
|
||||
how `break` is handled.
|
||||
|
||||
For example, the `for` loop below is rewritten as (showing the output of the
|
||||
`continue` transformation only):
|
||||
|
||||
```
|
||||
for i in range(10):
|
||||
if i > 3:
|
||||
continue
|
||||
```
|
||||
|
||||
```
|
||||
for i in range(10):
|
||||
continue_ = False
|
||||
if i > 3:
|
||||
continue_ = True
|
||||
if not continue_:
|
||||
i += 1
|
||||
```
|
||||
|
||||
Notice that unlike `break`, `continue` statements are local to the loop and do
|
||||
not influence the number of iterations.
|
||||
|
||||
### `return` statements
|
||||
|
||||
`return` statements are also rewritten using control symbols, in a manner
|
||||
similar to how `break` is converted. In the case of `return` statements, an
|
||||
additional symbol keeps track of the return value.
|
||||
|
||||
Depending on the structure of the code, the return value might be undefined
|
||||
in parts of the code (for example on code paths in which no return statement
|
||||
has executed). AutoGraph keeps track of this by using a special value.
|
||||
This special value is converted to `None` (the default return value) upon
|
||||
exiting the function.
|
||||
|
||||
Caution: TensorFlow control flow doe not support undefined values, and an
|
||||
undefined return value is no exception. Therefore, AutoGraph will raise an
|
||||
error for TensorFlow control flow in which the return value is not known for
|
||||
all code paths.
|
||||
|
||||
For example, the following code raises an error because the return value would
|
||||
be undefined when the random number would be less than 0.5:
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
return 1
|
||||
```
|
||||
|
||||
```
|
||||
ValueError: A value must also be returned from the else branch.
|
||||
```
|
||||
|
||||
An example of rewriting a `while` (showing the output of the `return`
|
||||
transformation only):
|
||||
|
||||
```
|
||||
def f():
|
||||
while i < 10:
|
||||
if i > 3:
|
||||
return 1
|
||||
i += 1
|
||||
```
|
||||
|
||||
```
|
||||
def f():
|
||||
do_return = False
|
||||
retval_ = ag__.UndefinedReturnValue()
|
||||
while i < 10 and not do_return:
|
||||
if i > 3:
|
||||
do_return = True
|
||||
retval_ = 1
|
||||
if not do_return:
|
||||
i += 1
|
||||
return ag__.retval(retval_) # Transforms any UndefinedReturnValue to None
|
||||
```
|
||||
|
||||
Note: AutoGraph performs an additional code normalization in which an `if`
|
||||
statement with no `else` branch contains a `return` statement it is rewritten as
|
||||
an `if-else` statement in which the code that follows the statement is moved
|
||||
under the `else` branch.
|
||||
|
||||
Example (showing the normalization only):
|
||||
|
||||
```
|
||||
def f():
|
||||
if i > 3:
|
||||
return 1
|
||||
i += 1
|
||||
```
|
||||
|
||||
```
|
||||
def f():
|
||||
if i > 3:
|
||||
return 1
|
||||
else:
|
||||
i += 1
|
||||
```
|
||||
|
||||
|
@ -8,16 +8,17 @@ graph.
|
||||
* [Introduction](intro.md)
|
||||
* [Interacting with the generated code](generated_code.md)
|
||||
* [Debugging AutoGraph code](debugging.md)
|
||||
* Control Flow (coming soon)
|
||||
* [Control flow](control_flow.md)
|
||||
* Functions calls (coming soon)
|
||||
* Exception handling (coming soon)
|
||||
* Conversion mechanics (coming soon)
|
||||
* Collections (coming soon)
|
||||
* Exceptions (coming soon)
|
||||
* Builtin Functions (coming soon)
|
||||
* Datasets (coming soon)
|
||||
* [Limitations](limitations.md)
|
||||
* Common errors (coming soon)
|
||||
|
||||
For more information on AutoGraph, see the following articles:
|
||||
|
||||
* [AutoGraph tutorial](https://www.tensorflow.org/alpha/guide/autograph)
|
||||
* [AutoGraph tutorial](https://www.tensorflow.org/alpha/beta/autograph)
|
||||
* [Eager tutorial](https://www.tensorflow.org/alpha/guide/eager)
|
||||
* [TensorFlow 2.0 Alpha](https://www.tensorflow.org/alpha)
|
||||
* [AutoGraph blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
|
||||
|
@ -10,6 +10,12 @@ However, when applied to TensorFlow control flow (for example, an if statement
|
||||
with a `tf.Tensor` condition), there are certain limitations. This section
|
||||
describes these limitations and practices that will allow you to avoid them.
|
||||
|
||||
Key Term: Python variables refer to Python symbols (or symbols for short) and
|
||||
should not be confused with TensorFlow variables.
|
||||
|
||||
Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
|
||||
value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.
|
||||
|
||||
### Indirect modifications and hidden side effects in TensorFlow control flow
|
||||
|
||||
<!-- TODO(mdan) Refine this paragraph well - it's important -->
|
||||
@ -22,12 +28,10 @@ flow statements into equivalent TensorFlow ops. This process requires "wiring"
|
||||
variables in the Python code whose values are affected these statements control
|
||||
flow into the respective ops.
|
||||
|
||||
Note: Python variables should not be confused with TensorFlow variables.
|
||||
|
||||
The examples below use a `while` loop, but the same notions extend to all
|
||||
control flow: `if` and `for` statements.
|
||||
|
||||
In the example below, `x` needs to become a _loop variable_ of the
|
||||
In the example below, `x` needs to become a loop variable of the
|
||||
corresponding `tf.while_loop':
|
||||
|
||||
```
|
||||
@ -255,6 +259,156 @@ for i in tf.range(10):
|
||||
d = {key: value + i for key, value in d.items()} # Okay
|
||||
```
|
||||
|
||||
### Shape and dtype consistency in TensorFlow control flow
|
||||
|
||||
Unlike Python, TensorFlow has limited support for dynamic typing. This means
|
||||
that tensors must maintain consistent shapes and dtypes across control flow
|
||||
paths.
|
||||
|
||||
Note: In general, these restrictions do not apply in control flow in Eager
|
||||
execution, because Eager execution uses Python control flow, rather than
|
||||
TensorFlow control flow ops.
|
||||
|
||||
#### Consistency of dtype
|
||||
|
||||
The dtypes across all code paths must be consistent in conditionals and loops.
|
||||
|
||||
For example, if a `tf.cond` (and correspondingly, an AutoGraph `if`) sets a
|
||||
tensor value conditionally, then that tensor must have the same shape and dtype
|
||||
in both branches of the conditional.
|
||||
|
||||
Example of illegal dtype change in a conditional:
|
||||
|
||||
```
|
||||
x = tf.cond(
|
||||
tf.random.uniform(()) > 0.5,
|
||||
lambda: tf.constant(1, dtype=tf.int32),
|
||||
lambda: tf.constant(1, dtype=tf.float32)) # Error -- inconsistent dtypes: int32, float32
|
||||
```
|
||||
|
||||
The same restriction in AutoGraph code:
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1, dtype=tf.int32)
|
||||
else:
|
||||
x = tf.constant(1, dtype=tf.float32) # Error -- inconsistent dtypes: int32, float32
|
||||
```
|
||||
|
||||
Example of illegal dtype change in a loop:
|
||||
|
||||
```
|
||||
# This won't work - "x" changes dtype inside the loop.
|
||||
x = tf.while_loop(
|
||||
lambda _: tf.random.uniform(()) > 0.5,
|
||||
lambda x: tf.constant(1, dtype=tf.float32),
|
||||
loop_vars=(tf.constant(1, dtype=tf.int32),)) # Error -- inconsistent dtypes: int32, float32
|
||||
```
|
||||
|
||||
The same restriction in AutoGraph code:
|
||||
|
||||
```
|
||||
x = tf.constant(0, dtype=tf.int32)
|
||||
while tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(0, dtype=tf.float32) # Error -- inconsistent dtypes: int32, float32
|
||||
```
|
||||
|
||||
#### Consistency of shape
|
||||
|
||||
The shapes across all code paths must be consistent in loops only. When tensors
|
||||
do need to change shape across iterations, use `shape_invariants`.
|
||||
|
||||
Note: Shapes are allowed to be inconsistent in conditionals. The result will be
|
||||
a partially dynamic shape.
|
||||
|
||||
In a `tf.while_loop` (and correspondingly, an AutoGraph `while` or `for` loop)
|
||||
all loop variables must maintain consistent shape and dtype across iterations.
|
||||
That is, every loop variable must have the same shape at the end of the loop
|
||||
body as the shape that it had at the beginning of the loop body.
|
||||
|
||||
Example of illegal shape change in a loop:
|
||||
|
||||
```
|
||||
def loop_body(x): # x.shape is ()
|
||||
return tf.constant((1, 2, 3)) # Error -- inconsistent shapes: (), (3,)
|
||||
|
||||
x = tf.while_loop(
|
||||
lambda _: tf.random.uniform(()) > 0.5,
|
||||
loop_body,
|
||||
loop_vars=(tf.constant(1,))
|
||||
```
|
||||
|
||||
The same restriction in AutoGraph code:
|
||||
|
||||
```
|
||||
x = tf.constant(0, dtype=tf.int32)
|
||||
while tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(0, dtype=tf.float32) # Error -- inconsistent shapes: (), (3,)
|
||||
```
|
||||
|
||||
### Undefined and None values in TensorFlow
|
||||
|
||||
TensorFlow does not support undefined and `None` values. All tensors must have
|
||||
a value.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
x = tf.cond(
|
||||
tf.random.uniform(()) > 0.5,
|
||||
lambda: tf.constant(1),
|
||||
lambda: None) # Error -- a Tensor cannot be None
|
||||
```
|
||||
|
||||
The same restriction carries over in AutoGraph, but only if the symbol is used
|
||||
after the conditional (otherwise AutoGraph avoids making it a return value
|
||||
of the `tf.cond`):
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
x = None
|
||||
tf.print(x) # Error -- x may be None here
|
||||
```
|
||||
|
||||
A related but less obvious restriction in AutoGraph forbids symbols to be
|
||||
defined in only one branch of TensorFlow control flow, if the symbol is
|
||||
used afterwards:
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
pass
|
||||
tf.print(x) # Error -- x may be undefined here
|
||||
```
|
||||
|
||||
Similarly, variables defined in a loop may not be used outside the loop, again
|
||||
if the symbol is used afterwards:
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
tf.print(x) # Error -- x may be undefined here
|
||||
```
|
||||
|
||||
Avoid these limitations by defining a default value before the control flow
|
||||
statement:
|
||||
|
||||
```
|
||||
x = tf.constant()
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
tf.print(x) # Okay -- x is either 0 or 1
|
||||
```
|
||||
|
||||
Note: `None` values and undefined symbols are allowed in Eager control flow,
|
||||
because Eager execution uses Python control flow, rather than TensorFlow
|
||||
control flow ops.
|
||||
|
||||
### Access to source code
|
||||
|
||||
Key point: AutoGraph can only handle functions whose source code can be
|
||||
|
Loading…
Reference in New Issue
Block a user