Do not cache external values in TPUReplicateContext.AddValue if there is no outer context, since caching is then a no-op. Doing so leads to name collisions in the presence of nested graphs, since op names are not necessarily unique across graphs.

PiperOrigin-RevId: 330734233
Change-Id: I51402fb233d371634cecabf46c4cb7eef2398d5c
Saurabh Saxena 2020-09-09 09:09:11 -07:00 committed by TensorFlower Gardener
parent 7d0994a2ba
commit bf53969dfa
2 changed files with 62 additions and 31 deletions
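
Background for the fix: TensorFlow op names are unique only within a single graph, so tensors in two different graphs can share the name "a:0" while having different shapes. A minimal sketch of the collision, written against the public tf.Graph API (illustrative only, not part of this change):

# Minimal sketch: op names are only unique per graph, so a cache keyed on
# tensor names and shared across graphs can hand back the wrong tensor.
import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
  z = tf.zeros([2, 3], name="a")   # named "a:0" inside g1

g2 = tf.Graph()
with g2.as_default():
  z1 = tf.zeros([3, 2], name="a")  # also named "a:0", but inside g2

assert z.name == z1.name    # identical names...
assert z.shape != z1.shape  # ...yet entirely different tensors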


@@ -604,6 +604,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def AddValue(self, val):
     """Add `val` to the current context and its outer context recursively."""
+    if not self._outer_context:
+      return val
+
     if val.name in self._values:
       # Use the real value if it comes from outer context.
       result = self._external_values.get(val.name)
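
To see why the early return matters, here is a condensed, hypothetical sketch of the name-keyed caching pattern AddValue follows (simplified for illustration; the real method is the one in the hunk above). Without the early return, a context with no outer context would still record val under val.name, and a later same-named tensor from another graph would resolve to the stale entry:

# Hypothetical, condensed sketch of AddValue's name-keyed caching; names
# mirror the hunk above, but this is not the real implementation.
class ReplicateContext:

  def __init__(self, outer_context=None):
    self._outer_context = outer_context
    self._values = set()        # names already routed through this context
    self._external_values = {}  # name -> value substituted from outside

  def AddValue(self, val):
    if not self._outer_context:
      # The fix: with no outer context there is nothing to capture, so
      # caching `val` under `val.name` would only poison later lookups by
      # ops from other graphs that happen to reuse the same name.
      return val
    if val.name in self._values:
      # Name seen before: reuse whatever was substituted for it.
      result = self._external_values.get(val.name)
      return val if result is None else result
    self._values.add(val.name)
    result = self._outer_context.AddValue(val)
    self._external_values[val.name] = result
    return result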


@@ -19,11 +19,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.framework import constant_op
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -40,49 +40,77 @@ from tensorflow.python.tpu import training_loop
 class TPUContextTest(test.TestCase):
 
   @test_util.deprecated_graph_mode_only
   def testIsInContext(self):
     """Test that control_flow_util can check that we're in a TPU context."""
-    z1 = array_ops.identity(1)
-    pivot = control_flow_ops.no_op()
-    context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
-    context.Enter()
-    z2 = array_ops.identity(1)
-    context.Exit()
-    self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
-    self.assertTrue(control_flow_util.IsInXLAContext(z2.op))
+    with ops.Graph().as_default():
+      z1 = array_ops.identity(1)
+      pivot = control_flow_ops.no_op()
+      context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
+      context.Enter()
+      z2 = array_ops.identity(1)
+      context.Exit()
+      self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
+      self.assertTrue(control_flow_util.IsInXLAContext(z2.op))
+
+  def testHandlesNameCollision(self):
+    """Test AddValue handles name collisions for ops from different graphs."""
+    with ops.Graph().as_default():
+      z = array_ops.zeros([2, 3], name="a")
+      assert z.name == "a:0", "Expected: a:0, Found: %s" % z.name
+
+      @def_function.function
+      def f():
+        pivot = control_flow_ops.no_op()
+        context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
+        context.Enter()
+        array_ops.identity(z)  # Capture z.
+        z1 = array_ops.zeros([3, 2], name="a")
+        assert z1.name == "a:0", "Expected: a:0, Found: %s" % z1.name
+        z2 = array_ops.zeros([3, 2], name="a")
+        # Prior to fixing b/166794533 this would fail with a shape mismatch
+        # because context.AddValue would have cached `z` by its name which
+        # collides with z1's name.
+        result = z1 + z2
+        context.Exit()
+        return result
+
+      f.get_concrete_function()
 
 
 class TPULayerRewriteTest(test.TestCase):
 
   @test_util.deprecated_graph_mode_only
   def testUsingInfeedQueueWithRegularizer(self):
     """Test that Layer regularizers can reference data created in loops."""
-    def make_regularizer(scale):
-      return lambda inputs: scale * math_ops.reduce_sum(math_ops.square(inputs))
-
-    def training_step(inputs, scale):
-      outputs = convolutional.conv2d(
-          inputs,
-          filters=16,
-          kernel_size=(3, 3),
-          data_format="channels_first",
-          kernel_regularizer=make_regularizer(scale))
-      loss = math_ops.reduce_mean(math_ops.square(outputs))
-      return loss.op
-
-    inputs = array_ops.zeros(shape=(128, 32, 32, 16))
-    scale = array_ops.ones(shape=())
-    infeed = tpu_feed.InfeedQueue(
-        tuple_types=[dtypes.float32, dtypes.float32],
-        tuple_shapes=[inputs.shape, scale.shape])
-
-    def loop():
-      return training_loop.repeat(5, training_step, infeed_queue=infeed)
-
-    # This should not throw an error.
-    tpu.rewrite(loop)
+    with ops.Graph().as_default():
+      def make_regularizer(scale):
+        def regularizer(inputs):
+          return scale * math_ops.reduce_sum(math_ops.square(inputs))
+        return regularizer
+
+      def training_step(inputs, scale):
+        outputs = convolutional.conv2d(
+            inputs,
+            filters=16,
+            kernel_size=(3, 3),
+            data_format="channels_first",
+            kernel_regularizer=make_regularizer(scale))
+        loss = math_ops.reduce_mean(math_ops.square(outputs))
+        return loss.op
+
+      inputs = array_ops.zeros(shape=(128, 32, 32, 16))
+      scale = array_ops.ones(shape=())
+      infeed = tpu_feed.InfeedQueue(
+          tuple_types=[dtypes.float32, dtypes.float32],
+          tuple_shapes=[inputs.shape, scale.shape])
+
+      def loop():
+        return training_loop.repeat(5, training_step, infeed_queue=infeed)
+
+      # This should not throw an error.
+      tpu.rewrite(loop)
 
 
 class TPUGraphPruneTest(test.TestCase):