[TF2XLA] Inject XLA context in Function._defun_with_scope

This covers more codepaths and subsumes the need for the two previously existing injections.

PiperOrigin-RevId: 326748241
Change-Id: I7d660282fee4127afe180bbd83d44f0a41f273d6
commit 108b0edc94
parent ba58b8cafa
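For context: `experimental_compile=True` asks TensorFlow to compile the traced function with XLA, and V2 control flow only produces an XLA-compatible graph when tracing happens inside an `XLAControlFlowContext`. This change moves that context injection into `Function._defun_with_scope` so every tracing codepath gets it. A minimal user-level sketch of the feature involved (public `tf.function` API of this era; the function body is illustrative only):

import tensorflow as tf

@tf.function(experimental_compile=True)  # request XLA compilation
def f(x):
  # Traced under an XLAControlFlowContext, so the emitted graph is
  # XLA-compatible.
  return tf.where(x > 0.0, x * x, -x)

print(f(tf.constant(2.0)).numpy())  # 4.0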
@@ -593,6 +593,8 @@ class Function(object):
     """Creates a defun wrapped inside a variable creator scope."""
 
     weak_wrapped_fn = None
+    compile_with_xla = self._experimental_compile
+
     def wrapped_fn(*args, **kwds):
       """Wraps `self._python_function` in a variable creator scope."""
       # We register a variable creator with reduced priority. If an outer
@@ -607,10 +609,22 @@ class Function(object):
       # and so variable initializers can't depend on function arguments. This is
       # better than the alternative, tracing the initialization graph but giving
       # the user a variable type they didn't want.
-      with ops.get_default_graph()._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
+      default_graph = ops.get_default_graph()
+      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
         # __wrapped__ allows AutoGraph to swap in a converted function. We give
         # the function a weak reference to itself to avoid a reference cycle.
-        return weak_wrapped_fn().__wrapped__(*args, **kwds)
+        if compile_with_xla and \
+            not control_flow_util.GraphOrParentsInXlaContext(default_graph):
+          xla_context = control_flow_ops.XLAControlFlowContext()
+          try:
+            xla_context.Enter()
+            out = weak_wrapped_fn().__wrapped__(*args, **kwds)
+          finally:
+            xla_context.Exit()
+        else:
+          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
+        return out
 
     weak_wrapped_fn = weakref.ref(wrapped_fn)
 
     return self._defun(tf_decorator.make_decorator(
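The guarded Enter/Exit pattern introduced above, shown standalone. This sketch uses the same internal TF modules the diff touches; the helper name `run_in_xla_context` and its callback argument are hypothetical, not part of TF:

from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util

def run_in_xla_context(build_fn):  # hypothetical helper, for illustration
  """Invokes `build_fn` inside an XLAControlFlowContext, avoiding nesting."""
  graph = ops.get_default_graph()
  # If the graph (or any parent graph) is already in an XLA context,
  # entering another one would nest contexts, so just call through.
  if control_flow_util.GraphOrParentsInXlaContext(graph):
    return build_fn()
  xla_context = control_flow_ops.XLAControlFlowContext()
  try:
    xla_context.Enter()
    return build_fn()
  finally:
    xla_context.Exit()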
@@ -769,23 +783,8 @@ class Function(object):
 
     tracing_count = self._get_tracing_count()
     with trace.Trace(self._name) as tm:
-      if self._experimental_compile and (
-          not control_flow_util.GraphOrParentsInXlaContext(
-              ops.get_default_graph())):
-        # V2 control flow relies on XLAControlFlowContext to generate a
-        # XLA-compatible function graph. If the function is already called
-        # inside an XLA context, we don't create nested XLA context.
-        compiler = "xla"
-        xla_context = control_flow_ops.XLAControlFlowContext()
-        try:
-          xla_context.Enter()
-          result = self._call(*args, **kwds)
-        finally:
-          xla_context.Exit()
-      else:
-        compiler = "nonXla"
-        result = self._call(*args, **kwds)
-
+      result = self._call(*args, **kwds)
+      compiler = "xla" if self._experimental_compile else "nonXla"
       new_tracing_count = self._get_tracing_count()
       without_tracing = (tracing_count == new_tracing_count)
       execution_mode = "notTraced" if without_tracing else "traced"
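One behavioral nuance of this hunk: `compiler` feeds the function-call metrics, and it previously reported "nonXla" whenever the call was already inside an XLA context, even with `experimental_compile` set; now it tracks the flag alone. In isolation (the variable below is a stand-in for `self._experimental_compile`):

experimental_compile = True  # stand-in for self._experimental_compile
# Old labeling also required not being inside an ambient XLA context;
# new labeling depends only on the requested flag.
compiler = "xla" if experimental_compile else "nonXla"
assert compiler == "xla"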
@@ -218,6 +218,9 @@ class DefFunctionTest(xla_test.XLATestCase):
         y = f(x)
       return y, tape.gradient(y, x)
 
+    # Test that XLA context gets correctly propagated.
+    g._get_concrete_function_garbage_collected(2.0)(2.0)
+
     self.assertAllClose(40.0, f(2.0))
     self.assertAllClose([40.0, 28.0], g(2.0))
     self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
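The bodies of the test's `f` and `g` sit outside this hunk's context, so the sketch below uses a hypothetical stand-in. It shows the pattern the new line exercises: obtaining a concrete function triggers tracing through `wrapped_fn`, which is exactly where the XLA context is now injected (`_get_concrete_function_garbage_collected` is the internal variant of `get_concrete_function`):

import tensorflow as tf

@tf.function(experimental_compile=True)
def h(x):  # hypothetical stand-in for the test's f/g
  return x * x

# Tracing happens here, inside the injected XLAControlFlowContext.
concrete = h.get_concrete_function(tf.TensorSpec([], tf.float32))
print(concrete(tf.constant(2.0)).numpy())  # 4.0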
@@ -57,7 +57,6 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import functional_ops
@@ -1940,24 +1939,14 @@ class ConcreteFunction(object):
         possible_gradient_type,
         executing_eagerly)
     forward_function, args_with_tangents = forward_backward.forward()
-    compiled_with_xla = self._attrs.get("_XlaMustCompile", False) and \
-        not control_flow_util.GraphOrParentsInXlaContext(default_graph)
-    xla_context = control_flow_ops.XLAControlFlowContext()
-    try:
-      if compiled_with_xla:
-        xla_context.Enter()
-      if executing_eagerly:
-        flat_outputs = forward_function.call(
-            ctx, args_with_tangents,
-            cancellation_manager=cancellation_manager)
-      else:
-        with default_graph._override_gradient_function(  # pylint: disable=protected-access
-            {"PartitionedCall": self._get_gradient_function(),
-             "StatefulPartitionedCall": self._get_gradient_function()}):
-          flat_outputs = forward_function.call(ctx, args_with_tangents)
-    finally:
-      if compiled_with_xla:
-        xla_context.Exit()
+    if executing_eagerly:
+      flat_outputs = forward_function.call(
+          ctx, args_with_tangents, cancellation_manager=cancellation_manager)
+    else:
+      with default_graph._override_gradient_function(  # pylint: disable=protected-access
+          {"PartitionedCall": self._get_gradient_function(),
+           "StatefulPartitionedCall": self._get_gradient_function()}):
+        flat_outputs = forward_function.call(ctx, args_with_tangents)
     forward_backward.record(flat_outputs)
     return self._build_call_outputs(flat_outputs)
 
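Why dropping this injection is safe: `XLAControlFlowContext` only affects graph construction, and after this change the context is already entered at trace time in `_defun_with_scope`; `_call_flat` merely executes an already-built graph. A sketch of that trace-versus-execute split (public API; illustrative function):

import tensorflow as tf

@tf.function(experimental_compile=True)
def square(x):
  # Graph built under XLAControlFlowContext, entered during tracing.
  return x * x

square(tf.constant(3.0))  # first call: traces (context active), then runs
square(tf.constant(4.0))  # later calls: execute the cached graph only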