Update tf.function doc: mention that while is captured, and explain how python numerical values are handled

This commit is contained in:
Aurélien Geron 2019-03-25 10:01:20 +08:00 committed by GitHub
parent 13fe6ef76e
commit d10bd87f2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -762,7 +762,8 @@ def function(func=None,
assert (h().numpy() == f(x, y).numpy()).all() assert (h().numpy() == f(x, y).numpy()).all()
# Data-dependent control flow is also captured in the graph. Supported # Data-dependent control flow is also captured in the graph. Supported
# control flow statements include `if`, `for`, `break`, `continue`, `return`. # control flow statements include `if`, `for`, `while`, `break`, `continue`,
# `return`.
@tf.function @tf.function
def g(x): def g(x):
if tf.reduce_sum(x) > 0: if tf.reduce_sum(x) > 0:
@ -784,7 +785,13 @@ def function(func=None,
``` ```
Note that unlike other TensorFlow operations, we don't convert python Note that unlike other TensorFlow operations, we don't convert python
numerical inputs to tensors. numerical inputs to tensors. Moreover, a new graph is generated for each
distinct python numerical value, for example calling `g(2)` and `g(3)` will
generate two new graphs (while only one is generated if you call
`g(tf.constant(2))` and `g(tf.constant(3))`). Therefore, python numerical
inputs should be restricted to arguments that will have few distinct values,
such as hyperparameters like the number of layers in a neural network. This
allows TensorFlow to optimize each variant of the neural network.
_Referencing `tf.Variable`s_ _Referencing `tf.Variable`s_