Write the example in a more TF-friendly fashion.

PiperOrigin-RevId: 356499216
Change-Id: I70710b504c1dc2eef41c2efd64b8e05bed1f52a2
This commit is contained in:
Dan Moldovan 2021-02-09 07:42:11 -08:00 committed by TensorFlower Gardener
parent 03773a81bd
commit 385b8f50bf

View File

@ -303,25 +303,39 @@ while x > 0:
c.y += 1 # Okay -- c.y can now be properly tracked!
```
Another possibility is to rely on immutable objects. This may lead to many
temporary objects when executing eagerly, but their number is greatly reduced
in `@tf.function`:
Another possibility is to rely on immutable objects with value semantics. This
may lead to many temporary objects when executing eagerly, but their number is
greatly reduced in `@tf.function`:
```
class MyClass(object):
class MyClass(collections.namedtuple('MyClass', ('y',))):
def change(self):
self.y += 1
return self
new_y = self.y + 1
return MyClass(new_y)
c = MyClass()
while x > 0:
c = c.change() # Okay -- c is now a loop var.
```
It is also recommended to use a functional programming style with such immutable
objects - that is, all arguments are inputs, all changes are return values:
```
def use_my_class(c: MyClass) -> MyClass:
new_c = c.change()
return new_c
```
Don't worry about creating a few extra objects - they are only used at trace
time, and don't exist at graph execution.
Note: TensorFlow control flow does not currently support arbitrary Python
objects, but it does support basic collection objects such as `list`, `dict`,
`tuple`, `namedtuple` and their subclasses. Design your objects as subclasses
of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple).
of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple),
or other types that [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest/map_structure)
recognizes.
#### Variables closed over by lambda functions