From 5abe466090e4e115389a017634880907f5a0c172 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Wed, 5 Dec 2018 15:08:01 -0800 Subject: [PATCH] Workaround for bizarre issue related to aliasing tensors in cond_v2's transform PiperOrigin-RevId: 224228895 --- tensorflow/python/keras/layers/normalization.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index d9584976555..37894a3d3d1 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib + from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -412,11 +414,19 @@ class BatchNormalizationV2(Layer): def _assign_moving_average(self, variable, value, momentum): with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope: - with ops.colocate_with(variable): + # TODO(apassos,srbs,skyewm): the colocation constraints here are disabled + # because of a bug which leads cond_v2 to skip rewriting them creating + # conflicts. + if tf2.enabled(): + cm = contextlib.contextmanager(lambda: (yield)) + else: + cm = ops.colocate_with(variable) + with cm: decay = ops.convert_to_tensor(1.0 - momentum, name='decay') if decay.dtype != variable.dtype.base_dtype: decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay + update_delta = ( + variable - math_ops.cast(value, variable.dtype)) * decay return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training):