[TF2XLA] Inject XLA context in Function._defun_with_scope

This covers more codepaths, and subsumes the need for the two previously
existing injections.

PiperOrigin-RevId: 326748241
Change-Id: I7d660282fee4127afe180bbd83d44f0a41f273d6
This commit is contained in:
George Karpenkov 2020-08-14 16:09:44 -07:00 committed by TensorFlower Gardener
parent ba58b8cafa
commit 108b0edc94
3 changed files with 29 additions and 38 deletions

View File

@ -593,6 +593,8 @@ class Function(object):
"""Creates a defun wrapped inside a variable creator scope."""
weak_wrapped_fn = None
compile_with_xla = self._experimental_compile
def wrapped_fn(*args, **kwds):
"""Wraps `self._python_function` in a variable creator scope."""
# We register a variable creator with reduced priority. If an outer
@ -607,10 +609,22 @@ class Function(object):
# and so variable initializers can't depend on function arguments. This is
# better than the alternative, tracing the initialization graph but giving
# the user a variable type they didn't want.
with ops.get_default_graph()._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access
default_graph = ops.get_default_graph()
with default_graph._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access
# __wrapped__ allows AutoGraph to swap in a converted function. We give
# the function a weak reference to itself to avoid a reference cycle.
return weak_wrapped_fn().__wrapped__(*args, **kwds)
if compile_with_xla and \
not control_flow_util.GraphOrParentsInXlaContext(default_graph):
xla_context = control_flow_ops.XLAControlFlowContext()
try:
xla_context.Enter()
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
finally:
xla_context.Exit()
else:
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
return out
weak_wrapped_fn = weakref.ref(wrapped_fn)
return self._defun(tf_decorator.make_decorator(
@ -769,23 +783,8 @@ class Function(object):
tracing_count = self._get_tracing_count()
with trace.Trace(self._name) as tm:
if self._experimental_compile and (
not control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph())):
# V2 control flow relies on XLAControlFlowContext to generate a
# XLA-compatible function graph. If the function is already called
# inside an XLA context, we don't create nested XLA context.
compiler = "xla"
xla_context = control_flow_ops.XLAControlFlowContext()
try:
xla_context.Enter()
result = self._call(*args, **kwds)
finally:
xla_context.Exit()
else:
compiler = "nonXla"
result = self._call(*args, **kwds)
compiler = "xla" if self._experimental_compile else "nonXla"
new_tracing_count = self._get_tracing_count()
without_tracing = (tracing_count == new_tracing_count)
execution_mode = "notTraced" if without_tracing else "traced"

View File

@ -218,6 +218,9 @@ class DefFunctionTest(xla_test.XLATestCase):
y = f(x)
return y, tape.gradient(y, x)
# Test that XLA context gets correctly propagated.
g._get_concrete_function_garbage_collected(2.0)(2.0)
self.assertAllClose(40.0, f(2.0))
self.assertAllClose([40.0, 28.0], g(2.0))
self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))

View File

@ -57,7 +57,6 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
@ -1940,24 +1939,14 @@ class ConcreteFunction(object):
possible_gradient_type,
executing_eagerly)
forward_function, args_with_tangents = forward_backward.forward()
compiled_with_xla = self._attrs.get("_XlaMustCompile", False) and \
not control_flow_util.GraphOrParentsInXlaContext(default_graph)
xla_context = control_flow_ops.XLAControlFlowContext()
try:
if compiled_with_xla:
xla_context.Enter()
if executing_eagerly:
flat_outputs = forward_function.call(
ctx, args_with_tangents,
cancellation_manager=cancellation_manager)
ctx, args_with_tangents, cancellation_manager=cancellation_manager)
else:
with default_graph._override_gradient_function( # pylint: disable=protected-access
{"PartitionedCall": self._get_gradient_function(),
"StatefulPartitionedCall": self._get_gradient_function()}):
flat_outputs = forward_function.call(ctx, args_with_tangents)
finally:
if compiled_with_xla:
xla_context.Exit()
forward_backward.record(flat_outputs)
return self._build_call_outputs(flat_outputs)