This takes the easy but not entirely satisfying "wrap it in a tf.function" approach, similar to what tf.vectorized_map does. This means we re-trace the cond's branches every time tf.cond runs.

If this turns out to be a performance bottleneck, there are a few options. One is to check whether the condition parallel tensor will actually take different branches on different devices, and fall back to ordinary eager execution if not. Another is to tweak the calling code (e.g. batch normalization) to wrap the cond itself in a tf.function, where the trace can be cached. We could also implement cond in the parallel device itself, producing null optionals on devices that don't take a branch, but that seems quite complicated.

PiperOrigin-RevId: 345741960
Change-Id: Iaa543e03a2dab96dc0fa0cd453f48718f42d31a8
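A minimal sketch of the "wrap the cond itself in a tf.function" mitigation described above. The function name and branch bodies are illustrative placeholders, not code from this change:

```python
import tensorflow as tf


@tf.function
def normalize(x, training):
  # Because the cond lives inside a tf.function, its trace is cached per
  # input signature rather than rebuilt on every call. With a tensor
  # predicate, both branches are built into the graph once at trace time
  # and the branch is selected at runtime.
  return tf.cond(training,
                 lambda: x * 2.0,   # stand-in for the training branch
                 lambda: x / 2.0)   # stand-in for the inference branch


x = tf.constant([1.0, 2.0])
normalize(x, tf.constant(True))   # traced once
normalize(x, tf.constant(False))  # same signature, reuses the cached trace
```

Passing the condition as a tensor (rather than a Python bool) is what keeps the cached trace reusable across both branch outcomes.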
Files changed:

- BUILD
- parallel_device_test.py
- parallel_device.py
- pywrap_parallel_device.cc
- saving.py