Workaround for bizarre issue related to aliasing tensors in cond_v2's transform
PiperOrigin-RevId: 224228895
This commit is contained in:
parent
3f43965a44
commit
5abe466090
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -412,11 +414,19 @@ class BatchNormalizationV2(Layer):
|
||||
def _assign_moving_average(self, variable, value, momentum):
|
||||
with ops.name_scope(None, 'AssignMovingAvg',
|
||||
[variable, value, momentum]) as scope:
|
||||
with ops.colocate_with(variable):
|
||||
# TODO(apassos,srbs,skyewm): the colocation constraints here are disabled
|
||||
# because of a bug which leads cond_v2 to skip rewriting them creating
|
||||
# conflicts.
|
||||
if tf2.enabled():
|
||||
cm = contextlib.contextmanager(lambda: (yield))
|
||||
else:
|
||||
cm = ops.colocate_with(variable)
|
||||
with cm:
|
||||
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
|
||||
if decay.dtype != variable.dtype.base_dtype:
|
||||
decay = math_ops.cast(decay, variable.dtype.base_dtype)
|
||||
update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
|
||||
update_delta = (
|
||||
variable - math_ops.cast(value, variable.dtype)) * decay
|
||||
return state_ops.assign_sub(variable, update_delta, name=scope)
|
||||
|
||||
def _fused_batch_norm(self, inputs, training):
|
||||
|
Loading…
Reference in New Issue
Block a user