Update tf.function doc: mention that while
is captured, and explain how python numerical values are handled
This commit is contained in:
parent
13fe6ef76e
commit
d10bd87f2c
@ -762,7 +762,8 @@ def function(func=None,
|
|||||||
assert (h().numpy() == f(x, y).numpy()).all()
|
assert (h().numpy() == f(x, y).numpy()).all()
|
||||||
|
|
||||||
# Data-dependent control flow is also captured in the graph. Supported
|
# Data-dependent control flow is also captured in the graph. Supported
|
||||||
# control flow statements include `if`, `for`, `break`, `continue`, `return`.
|
# control flow statements include `if`, `for`, `while`, `break`, `continue`,
|
||||||
|
# `return`.
|
||||||
@tf.function
|
@tf.function
|
||||||
def g(x):
|
def g(x):
|
||||||
if tf.reduce_sum(x) > 0:
|
if tf.reduce_sum(x) > 0:
|
||||||
@ -784,7 +785,13 @@ def function(func=None,
|
|||||||
```
|
```
|
||||||
|
|
||||||
Note that unlike other TensorFlow operations, we don't convert python
|
Note that unlike other TensorFlow operations, we don't convert python
|
||||||
numerical inputs to tensors.
|
numerical inputs to tensors. Moreover, a new graph is generated for each
|
||||||
|
distinct python numerical value, for example calling `g(2)` and `g(3)` will
|
||||||
|
generate two new graphs (while only one is generated if you call
|
||||||
|
`g(tf.constant(2))` and `g(tf.constant(3))`). Therefore, python numerical
|
||||||
|
inputs should be restricted to arguments that will have few distinct values,
|
||||||
|
such as hyperparameters like the number of layers in a neural network. This
|
||||||
|
allows TensorFlow to optimize each variant of the neural network.
|
||||||
|
|
||||||
_Referencing `tf.Variable`s_
|
_Referencing `tf.Variable`s_
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user