STT-tensorflow/tensorflow/python/distribute/parallel_device
Allen Lavoie 6b4ba7fb16 Parallel device: make tf.cond work executing eagerly
This takes the easy, if not entirely satisfying, "wrap it in a tf.function" approach, similar to what tf.vectorized_map does. This means we'll re-trace the cond's branches every time tf.cond runs.
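
As a rough illustration only (not the code from this change; the helper name is hypothetical), the pattern is to build a fresh tf.function wrapper around the cond on each call. Because each wrapper starts with an empty trace cache, the branches get traced again every time:

    import tensorflow as tf

    # Sketch of the "wrap it in a tf.function" pattern: a fresh tf.function
    # per call lowers the cond to a single traced graph, but each new wrapper
    # has an empty trace cache, so the branches are re-traced on every call.
    def eager_cond_via_function(pred, true_fn, false_fn):
      wrapped = tf.function(lambda: tf.cond(pred, true_fn, false_fn))
      return wrapped()

    result = eager_cond_via_function(
        tf.constant(True), lambda: tf.constant(1.0), lambda: tf.constant(0.0))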

If this ends up being a performance bottleneck, there are a few things we can do. One is to check whether the condition parallel tensor will actually take different branches on different devices (and run the cond eagerly if not). Another is to tweak the calling code (e.g. batch normalization) to wrap the cond itself in a tf.function, where we'll be able to cache the trace. We could also implement cond in the parallel device itself, with null optionals if a device isn't taking a branch; that seems pretty complicated.
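
The second option, caching the trace by moving the tf.function boundary to the call site, would look roughly like the sketch below. The function here is a hypothetical stand-in for a batch-norm-style caller, not code from this change:

    import tensorflow as tf

    # Hypothetical caller with a training/inference switch: with tf.function
    # around the cond's call site, the trace is cached and reused for calls
    # with the same input signature instead of being rebuilt every time.
    @tf.function
    def apply_scale(x, training):
      return tf.cond(training,
                     lambda: x * 2.0,  # stand-in for the training branch
                     lambda: x)        # stand-in for the inference branch

    x = tf.constant(3.0)
    print(apply_scale(x, tf.constant(True)))   # traces once
    print(apply_scale(x, tf.constant(False)))  # reuses the cached trace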

PiperOrigin-RevId: 345741960
Change-Id: Iaa543e03a2dab96dc0fa0cd453f48718f42d31a8
2020-12-04 13:19:43 -08:00
BUILD                      Parallel device: fix variable initialization in tf.function  2020-09-28 09:16:33 -07:00
parallel_device_test.py    Parallel device: make tf.cond work executing eagerly          2020-12-04 13:19:43 -08:00
parallel_device.py         Parallel device: fix variable initialization in tf.function  2020-09-28 09:16:33 -07:00
pywrap_parallel_device.cc  Factor out C++ types from the parallel device                 2020-06-04 14:49:16 -07:00
saving.py                  Parallel device: fix variable initialization in tf.function  2020-09-28 09:16:33 -07:00