Workaround for bizarre issue related to aliasing tensors in cond_v2's transform
PiperOrigin-RevId: 224228895
This commit is contained in:
parent
3f43965a44
commit
5abe466090
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -412,11 +414,19 @@ class BatchNormalizationV2(Layer):
|
|||||||
def _assign_moving_average(self, variable, value, momentum):
|
def _assign_moving_average(self, variable, value, momentum):
|
||||||
with ops.name_scope(None, 'AssignMovingAvg',
|
with ops.name_scope(None, 'AssignMovingAvg',
|
||||||
[variable, value, momentum]) as scope:
|
[variable, value, momentum]) as scope:
|
||||||
with ops.colocate_with(variable):
|
# TODO(apassos,srbs,skyewm): the colocation constraints here are disabled
|
||||||
|
# because of a bug which leads cond_v2 to skip rewriting them creating
|
||||||
|
# conflicts.
|
||||||
|
if tf2.enabled():
|
||||||
|
cm = contextlib.contextmanager(lambda: (yield))
|
||||||
|
else:
|
||||||
|
cm = ops.colocate_with(variable)
|
||||||
|
with cm:
|
||||||
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
|
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
|
||||||
if decay.dtype != variable.dtype.base_dtype:
|
if decay.dtype != variable.dtype.base_dtype:
|
||||||
decay = math_ops.cast(decay, variable.dtype.base_dtype)
|
decay = math_ops.cast(decay, variable.dtype.base_dtype)
|
||||||
update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
|
update_delta = (
|
||||||
|
variable - math_ops.cast(value, variable.dtype)) * decay
|
||||||
return state_ops.assign_sub(variable, update_delta, name=scope)
|
return state_ops.assign_sub(variable, update_delta, name=scope)
|
||||||
|
|
||||||
def _fused_batch_norm(self, inputs, training):
|
def _fused_batch_norm(self, inputs, training):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user