Do not cache external values in TPUReplicateContext.AddValue if there is no outer context, since doing so is a no-op. Caching in that case leads to name collisions in the presence of nested graphs, since op names are not necessarily unique across graphs.
PiperOrigin-RevId: 330734233 Change-Id: I51402fb233d371634cecabf46c4cb7eef2398d5c
This commit is contained in:
parent
7d0994a2ba
commit
bf53969dfa
@ -604,6 +604,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
|
||||
def AddValue(self, val):
|
||||
"""Add `val` to the current context and its outer context recursively."""
|
||||
if not self._outer_context:
|
||||
return val
|
||||
|
||||
if val.name in self._values:
|
||||
# Use the real value if it comes from outer context.
|
||||
result = self._external_values.get(val.name)
|
||||
|
@ -19,11 +19,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -40,49 +40,77 @@ from tensorflow.python.tpu import training_loop
|
||||
|
||||
class TPUContextTest(test.TestCase):
  """Tests for TPUReplicateContext entry/exit and value capture."""

  @test_util.deprecated_graph_mode_only
  def testIsInContext(self):
    """Test that control_flow_util can check that we're in a TPU context."""
    # Build everything in an explicit graph so op names cannot collide with
    # ops created by other tests sharing a default graph.
    with ops.Graph().as_default():
      # z1 is created *before* entering the context, so it must not be
      # reported as inside an XLA context.
      z1 = array_ops.identity(1)
      pivot = control_flow_ops.no_op()
      context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
      context.Enter()
      # z2 is created while the TPUReplicateContext is active.
      z2 = array_ops.identity(1)
      context.Exit()
      self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
      self.assertTrue(control_flow_util.IsInXLAContext(z2.op))

  def testHandlesNameCollision(self):
    """Test AddValue handles name collisions for ops from different graphs."""
    with ops.Graph().as_default():
      # Outer-graph tensor named "a"; it will be captured by the tf.function
      # below, where a *different* tensor also named "a:0" exists.
      z = array_ops.zeros([2, 3], name="a")
      assert z.name == "a:0", "Expected: a:0, Found: %s" % z.name

      @def_function.function
      def f():
        pivot = control_flow_ops.no_op()
        context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
        context.Enter()
        array_ops.identity(z)  # Capture z.
        # Inside the function graph, names restart, so z1 gets the same
        # "a:0" name as the outer-graph tensor z (but a different shape).
        z1 = array_ops.zeros([3, 2], name="a")
        assert z1.name == "a:0", "Expected: a:0, Found: %s" % z1.name
        z2 = array_ops.zeros([3, 2], name="a")
        # Prior to fixing b/166794533 this would fail with a shape mismatch
        # because context.AddValue would have cached `z` by its name which
        # collides with z1's name.
        result = z1 + z2
        context.Exit()
        return result

      # Tracing the function is what exercises AddValue; no execution needed.
      f.get_concrete_function()
class TPULayerRewriteTest(test.TestCase):
  """Tests that layer-created ops survive tpu.rewrite's graph rewriting."""

  @test_util.deprecated_graph_mode_only
  def testUsingInfeedQueueWithRegularizer(self):
    """Test that Layer regularizers can reference data created in loops."""
    # All helpers and tensors are defined inside an explicit graph so the
    # rewrite operates on a self-contained graph.
    with ops.Graph().as_default():

      def make_regularizer(scale):
        # Returns a regularizer closure that captures `scale`, a tensor
        # produced outside the training loop body.
        def regularizer(inputs):
          return scale * math_ops.reduce_sum(math_ops.square(inputs))
        return regularizer

      def training_step(inputs, scale):
        # One step of the loop body: a conv layer whose kernel regularizer
        # references the loop-external `scale` tensor.
        outputs = convolutional.conv2d(
            inputs,
            filters=16,
            kernel_size=(3, 3),
            data_format="channels_first",
            kernel_regularizer=make_regularizer(scale))
        loss = math_ops.reduce_mean(math_ops.square(outputs))
        return loss.op

      inputs = array_ops.zeros(shape=(128, 32, 32, 16))
      scale = array_ops.ones(shape=())
      # Infeed supplies (inputs, scale) pairs to each loop iteration.
      infeed = tpu_feed.InfeedQueue(
          tuple_types=[dtypes.float32, dtypes.float32],
          tuple_shapes=[inputs.shape, scale.shape])

      def loop():
        return training_loop.repeat(5, training_step, infeed_queue=infeed)

      # This should not throw an error.
      tpu.rewrite(loop)
class TPUGraphPruneTest(test.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user