Workaround for bizarre issue related to aliasing tensors in cond_v2's transform

PiperOrigin-RevId: 224228895
Alexandre Passos 2018-12-05 15:08:01 -08:00 committed by TensorFlower Gardener
parent 3f43965a44
commit 5abe466090


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
+
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -412,11 +414,19 @@ class BatchNormalizationV2(Layer):
   def _assign_moving_average(self, variable, value, momentum):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
-      with ops.colocate_with(variable):
+      # TODO(apassos,srbs,skyewm): the colocation constraints here are disabled
+      # because of a bug which leads cond_v2 to skip rewriting them, creating
+      # conflicts.
+      if tf2.enabled():
+        cm = contextlib.contextmanager(lambda: (yield))()
+      else:
+        cm = ops.colocate_with(variable)
+      with cm:
         decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
         if decay.dtype != variable.dtype.base_dtype:
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
-        update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
+        update_delta = (
+            variable - math_ops.cast(value, variable.dtype)) * decay
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
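
The heart of the workaround: when TF2 behavior is enabled, the colocation constraint is swapped for a do-nothing context manager, so cond_v2's graph transform never sees a constraint it would fail to rewrite. Below is a minimal standalone sketch of both halves of the changed block: the one-shot no-op context manager built from a generator lambda (note the trailing call parentheses, without which the `with` statement would fail), and the exponential-moving-average arithmetic it wraps. All names and numbers here are illustrative, not part of the commit.

import contextlib

# A lambda whose body is a bare `(yield)` is a generator function, so
# contextlib.contextmanager turns it into a factory for a context manager
# that does nothing on enter or exit. The trailing () builds one instance;
# each instance is single-use.
noop_cm = contextlib.contextmanager(lambda: (yield))()

with noop_cm:
    # The update performed inside the block is a standard exponential
    # moving average:
    #   variable -= (variable - value) * (1 - momentum)
    # which rearranges to
    #   variable = momentum * variable + (1 - momentum) * value
    momentum = 0.99
    variable, value = 10.0, 20.0  # illustrative numbers
    update_delta = (variable - value) * (1.0 - momentum)
    variable -= update_delta
    assert abs(variable - (0.99 * 10.0 + 0.01 * 20.0)) < 1e-9  # -> 10.1

On Python 3.7+, contextlib.nullcontext() expresses the same no-op intent directly; the generator-lambda spelling presumably reflects the older interpreters TensorFlow still supported at the time.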