Automated rollback of commit 302acb768a8f00605f84430fabd28b6daa2c4c77

PiperOrigin-RevId: 240438115
Yanan Cao 2019-03-26 15:15:34 -07:00 committed by TensorFlower Gardener
parent 6f9ed358d9
commit 8573a5abc8
20 changed files with 1160 additions and 908 deletions

View File

@@ -10,6 +10,7 @@ package_group(
"//tensorflow/compiler/tests/...",
"//tensorflow/compiler/tf2xla/...",
"//tensorflow/contrib/compiler/...",
"//tensorflow/python/compiler/...",
],
)

View File

@@ -24,49 +24,20 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":xla",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/compiler/xla:compiler_py",
],
)
cuda_py_test(
name = "jit_test",
size = "small",
srcs = ["jit_test.py"],
additional_deps = [
":compiler_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
xla_enabled = True,
)
py_library(
name = "xla",
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/jit:xla_ops_py",
"//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/compiler/xla:compiler_py",
"//tensorflow/python/estimator:estimator_py",
],
)
@@ -79,17 +50,12 @@ cuda_py_test(
"@absl_py//absl/testing:parameterized",
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
],
tags = [

View File

@@ -18,101 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.compiler.xla import jit
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
_XLA_SCOPE_KEY = ("__xla_scope",)
class _XlaScope(object):
"""Keeps track of previous XLA scope calls, and depth of current call."""
def __init__(self, count, depth):
self.count = count
self.depth = depth
@contextlib.contextmanager
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
"""Enable or disable JIT compilation of operators within the scope.
NOTE: This is an experimental feature.
The compilation is a hint and only supported on a best-effort basis.
Example usage:
with tf.contrib.compiler.experimental_jit_scope():
c = tf.matmul(a, b) # compiled
with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
d = tf.matmul(a, c) # not compiled
with tf.contrib.compiler.experimental_jit_scope(
compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
e = tf.matmul(a, b) + d # matmul is compiled, the addition is not.
Example of separate_compiled_gradients:
# In the example below, the computations for f, g and h will all be compiled
# in separate scopes.
with tf.contrib.compiler.experimental_jit_scope(
separate_compiled_gradients=True):
f = tf.matmul(a, b)
g = tf.gradients([f], [a, b], name='mygrads1')
h = tf.gradients([f], [a, b], name='mygrads2')
Args:
compile_ops: Whether to enable or disable compilation in the scope.
Either a Python bool, or a callable that accepts the parameter
`node_def` and returns a python bool.
separate_compiled_gradients: If true put each gradient subgraph into a
separate compilation scope. This gives fine-grained control over which
portions of the graph will be compiled as a single unit. Compiling
gradients separately may yield better performance for some graphs.
The scope is named based on the scope of the forward computation as well
as the name of the gradients. As a result, the gradients will be compiled
in a scope that is separate from both the forward computation, and from
other gradients.
Yields:
The current scope, enabling or disabling compilation.
"""
if callable(compile_ops):
def xla_compile(node_def):
return attr_value_pb2.AttrValue(b=compile_ops(node_def))
else:
xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
attrs = {
"_XlaCompile":
xla_compile,
"_XlaSeparateCompiledGradients":
attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
}
# Find the singleton counter for the current scoped graph. If it
# doesn't exist, create one.
xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
if not xla_scope_counter:
xla_scope_counter = _XlaScope(0, 0)
ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
else:
xla_scope_counter = xla_scope_counter[0]
if xla_scope_counter.depth == 0:
# If we're at the root xla scope, we can increase the counter so
# future calls to jit_scope use a different scope value.
# If we're already within a scope, we'll be fusing using the scope
# controlled by the parent.
attrs["_XlaScope"] = attr_value_pb2.AttrValue(
s=("jit_scope_%d" % xla_scope_counter.count).encode())
xla_scope_counter.count += 1
xla_scope_counter.depth += 1
# pylint: disable=protected-access
with ops.get_default_graph()._attr_scope(attrs):
yield
# pylint: enable=protected-access
xla_scope_counter.depth -= 1
experimental_jit_scope = jit.experimental_jit_scope
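# Hedged usage sketch, not part of this diff: the contrib module above is now
# just an alias for the relocated implementation, so (assuming both modules
# are importable in this build) the two names resolve to the same function.
from tensorflow.contrib.compiler import jit as contrib_jit
from tensorflow.python.compiler.xla import jit as xla_jit

assert contrib_jit.experimental_jit_scope is xla_jit.experimental_jit_scope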

View File

@@ -18,511 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.compiler.xla import xla
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
# Operations that indicate some error in the user's graph. For example, XLA
# computation should not have any Placeholder op.
_BLACKLISTED_OPS = set([
'Placeholder',
])
# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
'AudioSummary',
'AudioSummaryV2',
'HistogramSummary',
'ImageSummary',
'MergeSummary',
'Print',
'ScalarSummary',
'TensorSummary',
'TensorSummaryV2',
])
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
"""Builds an operator that compiles and runs `computation` with XLA.
Args:
computation: A Python function that builds a computation to apply to the
input. If the function takes n inputs, 'inputs' should be a list of n
tensors.
`computation` may return a list of operations and tensors. Tensors must
come before operations in the returned list. The return value of
`compile` is a list of tensors corresponding to the tensors from the
output of `computation`.
All `Operation`s returned from `computation` will be executed when
evaluating any of the returned output tensors.
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
can be a nested structure containing values that are convertible to
tensors. Note that passing an N-dimension list of compatible values will
result in an N-dimension list of scalar tensors rather than a single Rank-N
tensor. If you need different behavior, convert part of inputs to tensors
with `tf.convert_to_tensor`.
Returns:
Same data structure as if computation(*inputs) is called directly with some
exceptions for correctness. Exceptions include:
1) None output: a NoOp would be returned which control-depends on
computation.
2) Single value output: A tuple containing the value would be returned.
3) Operation-only outputs: a NoOp would be returned which
control-depends on computation.
TODO(b/121383831): Investigate into removing these special cases.
"""
# pylint: disable=protected-access
return _compile_internal(computation, inputs)
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.
The primary role of `XLACompileContext` is to mark operators inside an
xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
a unique name.
`ControlFlowContext` is used to perform the annotation since it integrates
with Tensorflow constructs like ResourceVariables. For example, if a
`ResourceVariable` is constructed inside a xla.compile() block, the
`ResourceVariable` implementation can use
`with ops.control_dependencies(None)` to build the variable's definition
outside the compiled computation.
"""
def __init__(self, name, pivot):
"""Builds a new XLACompileContext.
Args:
name: a unique name for the context, used to populate the
`_xla_compile_id` attribute.
pivot: a pivot node. Nodes in the XLACompileContext that do not have any
inputs will have a control dependency on the pivot node. This ensures
that nodes are correctly included in any enclosing control flow
contexts.
"""
super(XLACompileContext, self).__init__()
self._name = name
self._name_as_bytes = compat.as_bytes(name)
self._unsupported_ops = []
self._pivot = pivot
def report_unsupported_operations(self):
if self._unsupported_ops:
op_str = '\n'.join([
' %s (%s)' % (op.type, op.name)
for op in self._unsupported_ops[:_MAX_WARNING_LINES]
])
logging.warning('%d unsupported operations found: \n%s',
len(self._unsupported_ops), op_str)
if len(self._unsupported_ops) > _MAX_WARNING_LINES:
logging.warning('... and %d more',
len(self._unsupported_ops) - _MAX_WARNING_LINES)
def _RemoveExternalControlEdges(self, op):
"""Remove any external control dependency on this op."""
internal_control_inputs = []
external_control_inputs = []
for x in op.control_inputs:
# pylint: disable=protected-access
is_internal_op = False
ctxt = x._get_control_flow_context()
while ctxt is not None:
if ctxt == self:
is_internal_op = True
break
ctxt = ctxt._outer_context
if is_internal_op:
internal_control_inputs.append(x)
else:
external_control_inputs.append(x)
# pylint: enable=protected-access
# pylint: disable=protected-access
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
# pylint: enable=protected-access
return internal_control_inputs, external_control_inputs
def AddOp(self, op):
"""Create op in XLACompileContext and notifies outer context recursively."""
# pylint: disable=protected-access
if op.type in _BLACKLISTED_OPS:
logging.error(
'Operation of type %s (%s) is not supported in XLA. Execution will '
'fail if this op is used in the graph. ', op.type, op.name)
# TODO(ycao): Automatically disable summaries instead of reporting them.
if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
if any(x.dtype._is_ref_dtype for x in op.inputs):
raise NotImplementedError(
'Non-resource Variables are not supported inside XLA computations '
'(operator name: %s)' % op.name)
if _XLA_COMPILE_ATTR in op.node_def.attr:
raise ValueError('XLA compiled computations cannot be nested, (operator '
'name: %s)' % op.name)
op._set_attr(
_XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
op.graph.prevent_feeding(op)
op.graph.prevent_fetching(op)
# Remove any control edges from outer control flow contexts. These may cause
# mismatched frame errors. An example is when one of op's inputs is
# generated in a different While control flow context.
(internal_control_inputs,
external_control_inputs) = self._RemoveExternalControlEdges(op)
if not op.inputs:
# Add a control edge from the control pivot to this op.
if not internal_control_inputs:
# pylint: disable=protected-access
op._add_control_input(self._pivot)
# pylint: enable=protected-access
else:
for index in xrange(len(op.inputs)):
x = op.inputs[index]
real_x = self.AddValue(x)
if real_x != x:
op._update_input(index, real_x) # pylint: disable=protected-access
if external_control_inputs:
# Use an identity to pull control inputs as data inputs. Note that we
# ignore ops which don't have outputs. TODO(phawkins): fix that.
with ops.control_dependencies(None):
self.Enter()
external_control_inputs = [
array_ops.identity(x.outputs[0]).op
for x in external_control_inputs
if x.outputs
]
self.Exit()
# pylint: disable=protected-access
op._add_control_inputs(external_control_inputs)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
output_names = [x.name for x in op.outputs]
context = self
while context is not None:
# pylint: disable=protected-access
context._values.update(output_names)
context = context._outer_context
# pylint: enable=protected-access
if self._outer_context:
self._outer_context.AddInnerOp(op)
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
# Use the real value if it comes from outer context.
result = self._external_values.get(val.name)
return val if result is None else result
result = val
self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
self._values.add(result.name)
self._external_values[val.name] = result
return result
def AddInnerOp(self, op):
self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@property
def grad_state(self):
# Define the gradient loop state associated with the XLACompileContext to
# be None, as the XLACompileContext does not get nested, nor does the
# grad_state outside the XLACompileContext affect the graph inside, so the
# grad_state should be as if this is the top-level gradient state.
return None
@property
def back_prop(self):
"""Forwards to the enclosing while context, if any."""
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
def _compile_internal(computation, inputs=None):
"""Builds graph operators that compiles and symbolically executes computation.
Args:
computation: A Python function that builds the computation to compile and
execute.
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
can be a nested structure containing values that are convertible to
tensors. Note that passing an N-dimension list of compatible values will
result in an N-dimension list of scalar tensors rather than a single Rank-N
tensor. If you need different behavior, convert part of inputs to tensors
with `tf.convert_to_tensor`.
Returns:
Same data structure as if computation(*inputs) is called directly with some
exceptions for correctness. Exceptions include: 1) None output 2) Single
value output 3) Operation-only outputs
Raises:
ValueError: If any element in computation outputs is neither an Operation
nor a value that can be converted to a tensor.
ValueError: If computation outputs is non-flat and contains any Operations.
TypeError: If `inputs` is not a list or tuple.
"""
if inputs is None:
inputs = []
if not isinstance(inputs, collections.Sequence):
raise TypeError('inputs must be a list')
# Flatten inputs.
flat_inputs = nest.flatten(inputs)
# Converts inputs to Tensors.
flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
cluster_name = ops.get_default_graph().unique_name('cluster')
pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
context = XLACompileContext(name=cluster_name, pivot=pivot)
try:
context.Enter()
# Add identity ops so even unused inputs are 'consumed' by the
# computation.
flat_inputs = [
array_ops.identity(x, name='input_{}'.format(i))
for i, x in enumerate(flat_inputs)
]
# Re-pack flat_inputs in same structure as 'inputs'.
computation_inputs = nest.pack_sequence_as(
structure=inputs, flat_sequence=flat_inputs)
# Only resource variables work inside an XLA computation, so turn on
# resource variables for the computation.
vscope = variable_scope.get_variable_scope()
saved_use_resource = vscope.use_resource
vscope.set_use_resource(True)
with _disable_summary_context():
outputs = computation(*computation_inputs)
# Restore variable scope after computation.
vscope.set_use_resource(saved_use_resource)
outputs_is_flat = is_flat(outputs)
if outputs_is_flat:
output_tensors, control_deps = _postprocess_flat_outputs(outputs)
else:
output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
context.ExitResult(output_tensors)
finally:
context.report_unsupported_operations()
context.Exit()
# When XLA computation returns only operations and no tensors, a NoOp
# dependent on the operations in outputs is returned. Otherwise final
# outputs would be empty and there is no way to trigger returned
# operations.
if not output_tensors:
return control_flow_ops.group(control_deps, name='output_0')
output_tensors = [
xla_ops.xla_cluster_output(o, name='output{}'.format(i))
for i, o in enumerate(output_tensors)
]
with ops.control_dependencies(control_deps):
# Wraps the outputs in identity operators that carry control
# dependencies.
output_tensors = [
array_ops.identity(o, name='output_%d' % i)
for i, o in enumerate(output_tensors)
]
# If `computation` returned non-flat output structure, pack output tensors
# back into same structure.
if not outputs_is_flat:
output_tensors = nest.pack_sequence_as(
structure=outputs, flat_sequence=output_tensors)
return output_tensors
def is_flat(outputs):
"""Checks if outputs is a flat structure.
Following structures and values are considered flat:
1) None
2) A single object
3) A list or tuple of Tensors/Operations
The only structures that this function understands are sequences and
dictionaries. E.g. this means that if outputs contains a single
user-defined Object, it is considered to be flat. Errors are raised later on
if that Object cannot be converted to a Tensor.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
A boolean indicating whether outputs is flat.
"""
# If outputs is a list or tuple, check if it has any nested structure. If
# there is, then outputs is non-flat.
if isinstance(outputs, collections.Sequence):
for o in outputs:
if isinstance(o, collections.Sequence) or isinstance(o, dict):
return False
# If outputs is a dict, it is non-flat.
if isinstance(outputs, dict):
return False
# Getting here means either outputs itself is a single non-structured value
# or it is a flat list of single non-structured values.
return True
def _postprocess_flat_outputs(outputs):
"""Validates flat outputs and adds back device assignments.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
Tensors and Operations extracted from outputs.
"""
# The following code segment preserves legacy behavior. Previously we only
# supported flat outputs, and for consistency it was nice to convert even a
# single element into a tuple. But now that we support arbitrary output
# structure, this is no longer necessary.
# TODO(b/121383831): Migrate all legacy use cases and delete this special
# case.
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
outputs = tuple()
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
# Append `no_op` here so that return value of this function always contains
# at least one op that can trigger XlaLaunch node.
outputs += (control_flow_ops.no_op(),)
try:
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
except Exception as e:
raise ValueError(
'XLA computation function return values must all either be Operations'
' or convertible to Tensors. Got error: "%s"' % str(e))
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
'XLA computation function must return zero or more Tensor values '
'followed by zero or more Operations.')
new_output_tensors = []
for t in output_tensors:
with ops.device(t.device if t.device else ''):
new_output_tensors.append(array_ops.identity(t))
return new_output_tensors, output_operations
def _postprocess_non_flat_outputs(outputs):
"""Validates non-flat outputs and adds back device assignments.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
Tensors extracted from outputs and an empty list because Operations are not
allowed in non-flat outputs.
"""
# Convert all non-Operation outputs to Tensors.
new_output_tensors = []
for o in nest.flatten(outputs):
if isinstance(o, ops.Operation):
raise ValueError(
'xla.compile does not support Operation as return value in non-flat '
'output structure. You can set returned Operations as control '
'dependencies of returned Tensors so Operations are triggered when '
'Tensors are evaluated. Operation found: "%s"' % o.name)
try:
o = ops.convert_to_tensor(o)
except Exception as e:
raise ValueError(
'XLA computation function return values must all either be '
'Operations or convertible to Tensors. Got error: "%s"' % str(e))
# Makes sure even pass-through inputs/outputs are touched in compile
# context by creating an Identity node inside compile context.
with ops.device(o.device if o.device else ''):
new_output_tensors.append(array_ops.identity(o))
return new_output_tensors, []
@contextlib.contextmanager
def _disable_summary_context():
"""Enters a context where all summary ops are skipped.
Summaries are not yet supported in xla.compile(). So we provide this context
manager that can skip creating summary ops. This is a temporary workaround due
to XLA not supporting summary ops.
Yields:
None.
"""
original_skip_summary_func = summary_op_util.skip_summary
summary_op_util.skip_summary = lambda: True
try:
yield
finally:
summary_op_util.skip_summary = original_skip_summary_func
compile = xla.compile # pylint: disable=redefined-builtin
check_function_argument_count = xla.check_function_argument_count
class _CapturedObject(object):
"""A placeholder to capture an object."""
@@ -788,51 +294,3 @@ def estimator_model_fn(target_model_fn=None):
return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
return decorated(target_model_fn) if target_model_fn else decorated
def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to an XLA function.
Args:
func: the Python function that will be called to generate the body of an XLA
computation graph.
input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
Returns:
None if function can be called with the supplied number of
arguments, or an error string if it cannot.
"""
def format_error(complaint, quantity):
return '%s %d argument%s' % (complaint, quantity, ''
if quantity == 1 else 's')
num_args_supplied = input_arity
if infeed_queue is not None:
num_args_supplied += infeed_queue.number_of_tuple_elements
arg_spec = tf_inspect.getargspec(func)
num_func_args = len(arg_spec.args)
if arg_spec.defaults is None:
num_func_defaults = 0
else:
num_func_defaults = len(arg_spec.defaults)
min_func_args = num_func_args - num_func_defaults
if num_args_supplied < min_func_args:
# The number of supplied arguments is not enough to call the function.
if num_func_defaults == 0 and arg_spec.varargs is None:
return format_error('exactly', num_func_args)
else:
return format_error('at least', min_func_args)
if arg_spec.varargs is None and num_args_supplied > num_func_args:
# The number of supplied arguments is too many to call the function.
if num_func_defaults == 0:
return format_error('exactly', num_func_args)
else:
return format_error('at most', num_func_args)
# Reaching here means either
# 1) There are varargs, func can accept any number of arguments greater than
# the minimum.
# 2) Number of supplied arguments falls in range of acceptable argument count
# of func.
return None
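# Hedged sketch, not part of this diff: check_function_argument_count returns
# None when the arity works out and an error string otherwise, mirroring the
# unit tests below. The helper function `body` is illustrative only.
from tensorflow.python.compiler.xla import xla as xla_lib

def body(x, y, z=1):
  return x + y + z

print(xla_lib.check_function_argument_count(body, 2, None))  # None
print(xla_lib.check_function_argument_count(body, 1, None))  # 'at least 2 arguments'
print(xla_lib.check_function_argument_count(body, 4, None))  # 'at most 3 arguments'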

View File

@@ -23,20 +23,13 @@ from absl.testing import parameterized
from tensorflow.contrib.compiler import xla
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python import summary
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import training
@@ -48,226 +41,6 @@ _EXPECTED_FEATURE = 2
_EXPECTED_LABEL = 3
class XLACompileContextTest(test.TestCase):
def create_test_xla_compile_context(self):
computation_name = ops.get_default_graph().unique_name('computation')
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
return xla.XLACompileContext(name=computation_name, pivot=pivot)
def test_report_unsupported_operations(self):
"""Tests that unsupported operations are detected."""
context = self.create_test_xla_compile_context()
context.Enter()
dummy_tensor = constant_op.constant(1.1)
audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
image_summary = summary.image('image_summary', dummy_tensor)
scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor)
summary.merge(
[
audio_summary, histogram_summary, image_summary, scalar_summary,
tensor_summary
],
name='merge_summary')
logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
context.Exit()
unsupported_ops_names = [op.name for op in context._unsupported_ops]
self.assertEqual(unsupported_ops_names, [
u'audio_summary', u'histogram_summary', u'image_summary',
u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
u'print_op'
])
def test_resource_variable(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=True)
context = self.create_test_xla_compile_context()
context.Enter()
state_ops.assign(a, a + 1)
context.Exit()
def test_non_resource_variable_error(self):
"""Tests that non-resource variable usage is disallowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=False)
context = self.create_test_xla_compile_context()
context.Enter()
with self.assertRaisesRegexp(
NotImplementedError, 'Non-resource Variables are not supported inside '
r'XLA computations \(operator name: Assign\)'):
state_ops.assign(a, a + 1)
context.Exit()
def test_nested_xla_compile_error(self):
"""Tests that nested XLA computation leads to fatal error."""
context1 = self.create_test_xla_compile_context()
context1.Enter()
context2 = self.create_test_xla_compile_context()
context2.Enter()
with self.assertRaisesRegexp(ValueError,
'XLA compiled computations cannot be nested'):
constant_op.constant(1)
context2.Exit()
context1.Exit()
def test_xla_compile_attr(self):
"""Tests that ops are tagged with XLA compile ID attribute."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn('_xla_compile_id', op.op.node_def.attr)
def test_op_without_input(self):
"""Tests that ops without inputs depend on pivot correctly."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(context._pivot, op.op.control_inputs)
def test_external_control_edges(self):
"""Tests that external control edges are handled correctly."""
i = constant_op.constant(1)
op1 = constant_op.constant(1)
with ops.control_dependencies([op1]):
op2 = constant_op.constant(1)
self.assertIn(op1.op, op2.op.control_inputs)
def while_body(i):
del i # unused
context = self.create_test_xla_compile_context()
context.Enter()
with ops.control_dependencies([op1]):
op3 = constant_op.constant(1)
context.Exit()
self.assertNotIn(op1.op, op3.op.control_inputs)
return op3
control_flow_ops.while_loop(
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
def test_op_output_marked_as_seen(self):
"""Tests that any op output is marked as seen in context."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(op.name, context._values)
def testOpIsInContext(self):
"""Tests that XLACompileContext is recognized as an XLA context."""
op1 = constant_op.constant(1)
context = self.create_test_xla_compile_context()
context.Enter()
op2 = constant_op.constant(2)
context.Exit()
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
def testOpPreventFeeding(self):
"""Tests that ops created inside XLACompileContext can not be fed."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_feedable(op.op))
def testOpPreventFetching(self):
"""Tests that ops created inside XLACompileContext can not be fetched."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_fetchable(op.op))
class CheckFunctionArgumentCountTest(test.TestCase):
def testSimple(self):
"""Tests that arg checker works for functions with no varargs or defaults.
"""
def func(x, y, z):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, None))
queue = tpu_feed.InfeedQueue(2)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, queue))
def testDefaultArgs(self):
"""Tests that arg checker works for a function with no varargs."""
def func(x, y, z=17):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, queue))
def testVarArgs(self):
"""Tests that arg checker works for a function with varargs."""
def func(x, y, *z):
return x + y + len(z)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
def testVarArgsAndDefaults(self):
"""Tests that arg checker works for a function with varargs and defaults."""
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
return x + y + z + len(q)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
def _test_train_model_fn(features, labels, mode, params):
"""A dummy model_fn for testing purpose."""
del features, labels, params

View File

@@ -144,6 +144,10 @@ from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
# XLA JIT compiler APIs.
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
# Required due to `rnn` and `rnn_cell` not being imported in `nn` directly
# (due to a circular dependency issue: rnn depends on layers).
nn.dynamic_rnn = rnn.dynamic_rnn

View File

@@ -15,5 +15,7 @@ py_library(
srcs_version = "PY2AND3",
deps = if_not_windows([
"//tensorflow/python/compiler/tensorrt:init_py",
]),
]) + [
"//tensorflow/python/compiler/xla:compiler_py",
],
)

View File

@@ -0,0 +1,85 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "compiler_py",
srcs = [
"__init__.py",
"jit.py",
],
srcs_version = "PY2AND3",
deps = [
":xla",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
cuda_py_test(
name = "jit_test",
size = "small",
srcs = ["jit_test.py"],
additional_deps = [
":compiler_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
xla_enabled = True,
)
py_library(
name = "xla",
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/jit:xla_ops_py",
"//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util",
],
)
cuda_py_test(
name = "xla_test",
srcs = ["xla_test.py"],
additional_deps = [
":xla",
"@absl_py//absl/testing:parameterized",
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:variable_scope",
],
tags = [
"no_mac",
"no_windows",
],
xla_enabled = True,
)

View File

@@ -0,0 +1,24 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module for controlling the Tensorflow/XLA JIT compiler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
# pylint: enable=unused-import

View File

@@ -0,0 +1,120 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for controlling the Tensorflow/XLA JIT compiler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export
_XLA_SCOPE_KEY = ("__xla_scope",)
class _XlaScope(object):
"""Keeps track of previous XLA scope calls, and depth of current call."""
def __init__(self, count, depth):
self.count = count
self.depth = depth
@contextlib.contextmanager
@tf_export(v1=["xla.experimental.jit_scope"])
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
"""Enable or disable JIT compilation of operators within the scope.
NOTE: This is an experimental feature.
The compilation is a hint and only supported on a best-effort basis.
Example usage:
with tf.xla.experimental.jit_scope():
c = tf.matmul(a, b) # compiled
with tf.xla.experimental.jit_scope(compile_ops=False):
d = tf.matmul(a, c) # not compiled
with tf.xla.experimental.jit_scope(
compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
e = tf.matmul(a, b) + d # matmul is compiled, the addition is not.
Example of separate_compiled_gradients:
# In the example below, the computations for f, g and h will all be compiled
# in separate scopes.
with tf.xla.experimental.jit_scope(
separate_compiled_gradients=True):
f = tf.matmul(a, b)
g = tf.gradients([f], [a, b], name='mygrads1')
h = tf.gradients([f], [a, b], name='mygrads2')
Args:
compile_ops: Whether to enable or disable compilation in the scope.
Either a Python bool, or a callable that accepts the parameter
`node_def` and returns a python bool.
separate_compiled_gradients: If true put each gradient subgraph into a
separate compilation scope. This gives fine-grained control over which
portions of the graph will be compiled as a single unit. Compiling
gradients separately may yield better performance for some graphs.
The scope is named based on the scope of the forward computation as well
as the name of the gradients. As a result, the gradients will be compiled
in a scope that is separate from both the forward computation, and from
other gradients.
Yields:
The current scope, enabling or disabling compilation.
"""
if callable(compile_ops):
def xla_compile(node_def):
return attr_value_pb2.AttrValue(b=compile_ops(node_def))
else:
xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
attrs = {
"_XlaCompile":
xla_compile,
"_XlaSeparateCompiledGradients":
attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
}
# Find the singleton counter for the current scoped graph. If it
# doesn't exist, create one.
xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
if not xla_scope_counter:
xla_scope_counter = _XlaScope(0, 0)
ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
else:
xla_scope_counter = xla_scope_counter[0]
if xla_scope_counter.depth == 0:
# If we're at the root xla scope, we can increase the counter so
# future calls to jit_scope use a different scope value.
# If we're already within a scope, we'll be fusing using the scope
# controlled by the parent.
attrs["_XlaScope"] = attr_value_pb2.AttrValue(
s=("jit_scope_%d" % xla_scope_counter.count).encode())
xla_scope_counter.count += 1
xla_scope_counter.depth += 1
# pylint: disable=protected-access
with ops.get_default_graph()._attr_scope(attrs):
yield
# pylint: enable=protected-access
xla_scope_counter.depth -= 1
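# Hedged usage sketch, not part of this file: graph-mode (TF 1.x) use of the
# exported tf.xla.experimental.jit_scope API defined above. Tensor names are
# illustrative.
import tensorflow as tf

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[1.0, 0.0], [0.0, 1.0]])

with tf.xla.experimental.jit_scope():
  c = tf.matmul(a, b)   # hinted for XLA compilation
with tf.xla.experimental.jit_scope(compile_ops=False):
  d = tf.matmul(a, c)   # explicitly left uncompiled

with tf.compat.v1.Session() as sess:
  print(sess.run(d))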

View File

@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.compiler.jit."""
"""Tests for python.compiler.xla.jit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.compiler import jit
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -44,6 +45,7 @@ def enable_jit_nonstateful(node_def):
raise ValueError("Unregistered op being created: %s" % node_def)
@test_util.run_v1_only("b/128927195")
class JITTest(test.TestCase):
def compute(self, use_jit, compute_fn):
@@ -170,6 +172,7 @@ class JITTest(test.TestCase):
self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)
@test_util.run_v1_only("b/128927195")
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):

View File

@@ -0,0 +1,604 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
# Operations that indicate some error in the user's graph. For example, XLA
# computation should not have any Placeholder op.
_BLACKLISTED_OPS = set([
'Placeholder',
])
# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
'AudioSummary',
'AudioSummaryV2',
'HistogramSummary',
'ImageSummary',
'MergeSummary',
'Print',
'ScalarSummary',
'TensorSummary',
'TensorSummaryV2',
])
@tf_export(v1=['xla.experimental.compile'])
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
"""Builds an operator that compiles and runs `computation` with XLA.
Args:
computation: A Python function that builds a computation to apply to the
input. If the function takes n inputs, 'inputs' should be a list of n
tensors.
`computation` may return a list of operations and tensors. Tensors must
come before operations in the returned list. The return value of
`compile` is a list of tensors corresponding to the tensors from the
output of `computation`.
All `Operation`s returned from `computation` will be executed when
evaluating any of the returned output tensors.
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
can be a nested structure containing values that are convertible to
tensors. Note that passing an N-dimension list of compatible values will
result in an N-dimension list of scalar tensors rather than a single Rank-N
tensor. If you need different behavior, convert part of inputs to tensors
with `tf.convert_to_tensor`.
Returns:
Same data structure as if computation(*inputs) is called directly with some
exceptions for correctness. Exceptions include:
1) None output: a NoOp would be returned which control-depends on
computation.
2) Single value output: A tuple containing the value would be returned.
3) Operation-only outputs: a NoOp would be returned which
control-depends on computation.
TODO(b/121383831): Investigate into removing these special cases.
"""
# pylint: disable=protected-access
return _compile_internal(computation, inputs)
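# Hedged usage sketch, not part of this file: calling the exported
# tf.xla.experimental.compile API defined above in TF 1.x graph mode. Because
# `computation` returns a single tensor, the result comes back wrapped in a
# one-element list (special case 2 in the docstring). Names are illustrative.
import tensorflow as tf

def computation(x, y):
  return tf.matmul(x, y) + 1.0

x = tf.ones([2, 2])
y = tf.ones([2, 2])
(result,) = tf.xla.experimental.compile(computation, inputs=[x, y])

with tf.compat.v1.Session() as sess:
  print(sess.run(result))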
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.
The primary role of `XLACompileContext` is to mark operators inside an
xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
a unique name.
`ControlFlowContext` is used to perform the annotation since it integrates
with Tensorflow constructs like ResourceVariables. For example, if a
`ResourceVariable` is constructed inside a xla.compile() block, the
`ResourceVariable` implementation can use
`with ops.control_dependencies(None)` to build the variable's definition
outside the compiled computation.
"""
def __init__(self, name, pivot):
"""Builds a new XLACompileContext.
Args:
name: a unique name for the context, used to populate the
`_xla_compile_id` attribute.
pivot: a pivot node. Nodes in the XLACompileContext that do not have any
inputs will have a control dependency on the pivot node. This ensures
that nodes are correctly included in any enclosing control flow
contexts.
"""
super(XLACompileContext, self).__init__()
self._name = name
self._name_as_bytes = compat.as_bytes(name)
self._unsupported_ops = []
self._pivot = pivot
def report_unsupported_operations(self):
if self._unsupported_ops:
op_str = '\n'.join([
' %s (%s)' % (op.type, op.name)
for op in self._unsupported_ops[:_MAX_WARNING_LINES]
])
logging.warning('%d unsupported operations found: \n%s',
len(self._unsupported_ops), op_str)
if len(self._unsupported_ops) > _MAX_WARNING_LINES:
logging.warning('... and %d more',
len(self._unsupported_ops) - _MAX_WARNING_LINES)
def _RemoveExternalControlEdges(self, op):
"""Remove any external control dependency on this op."""
internal_control_inputs = []
external_control_inputs = []
for x in op.control_inputs:
# pylint: disable=protected-access
is_internal_op = False
ctxt = x._get_control_flow_context()
while ctxt is not None:
if ctxt == self:
is_internal_op = True
break
ctxt = ctxt._outer_context
if is_internal_op:
internal_control_inputs.append(x)
else:
external_control_inputs.append(x)
# pylint: enable=protected-access
# pylint: disable=protected-access
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
# pylint: enable=protected-access
return internal_control_inputs, external_control_inputs
def AddOp(self, op):
"""Create op in XLACompileContext and notifies outer context recursively."""
# pylint: disable=protected-access
if op.type in _BLACKLISTED_OPS:
logging.error(
'Operation of type %s (%s) is not supported in XLA. Execution will '
'fail if this op is used in the graph. ', op.type, op.name)
# TODO(ycao): Automatically disable summaries instead of reporting them.
if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
if any(x.dtype._is_ref_dtype for x in op.inputs):
raise NotImplementedError(
'Non-resource Variables are not supported inside XLA computations '
'(operator name: %s)' % op.name)
if _XLA_COMPILE_ATTR in op.node_def.attr:
raise ValueError('XLA compiled computations cannot be nested, (operator '
'name: %s)' % op.name)
op._set_attr(
_XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
op.graph.prevent_feeding(op)
op.graph.prevent_fetching(op)
# Remove any control edges from outer control flow contexts. These may cause
# mismatched frame errors. An example is when one of op's inputs is
# generated in a different While control flow context.
(internal_control_inputs,
external_control_inputs) = self._RemoveExternalControlEdges(op)
if not op.inputs:
# Add a control edge from the control pivot to this op.
if not internal_control_inputs:
# pylint: disable=protected-access
op._add_control_input(self._pivot)
# pylint: enable=protected-access
else:
for index in xrange(len(op.inputs)):
x = op.inputs[index]
real_x = self.AddValue(x)
if real_x != x:
op._update_input(index, real_x) # pylint: disable=protected-access
if external_control_inputs:
# Use an identity to pull control inputs as data inputs. Note that we
# ignore ops which don't have outputs. TODO(phawkins): fix that.
with ops.control_dependencies(None):
self.Enter()
external_control_inputs = [
array_ops.identity(x.outputs[0]).op
for x in external_control_inputs
if x.outputs
]
self.Exit()
# pylint: disable=protected-access
op._add_control_inputs(external_control_inputs)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
output_names = [x.name for x in op.outputs]
context = self
while context is not None:
# pylint: disable=protected-access
context._values.update(output_names)
context = context._outer_context
# pylint: enable=protected-access
if self._outer_context:
self._outer_context.AddInnerOp(op)
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
# Use the real value if it comes from outer context.
result = self._external_values.get(val.name)
return val if result is None else result
result = val
self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
self._values.add(result.name)
self._external_values[val.name] = result
return result
def AddInnerOp(self, op):
self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@property
def grad_state(self):
# Define the gradient loop state associated with the XLACompileContext to
# be None, as the XLACompileContext does not get nested, nor does the
# grad_state outside the XLACompileContext affect the graph inside, so the
# grad_state should be as if this is the top-level gradient state.
return None
@property
def back_prop(self):
"""Forwards to the enclosing while context, if any."""
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
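# Hedged sketch, not part of this file: XLACompileContext is internal (see the
# class docstring above), but its effect can be observed by entering it around
# op creation, as the unit tests for this module do.
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops

name = ops.get_default_graph().unique_name('computation')
pivot = control_flow_ops.no_op(name=name + '/pivot')
ctx = xla.XLACompileContext(name=name, pivot=pivot)
ctx.Enter()
c = constant_op.constant(1)
ctx.Exit()
assert '_xla_compile_id' in c.op.node_def.attr  # op tagged for XLA clustering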
def _compile_internal(computation, inputs=None):
"""Builds graph operators that compiles and symbolically executes computation.
Args:
computation: A Python function that builds the computation to compile and
execute.
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
can be a nested structure containing values that are convertible to
tensors. Note that passing an N-dimension list of compatible values will
result in an N-dimension list of scalar tensors rather than a single Rank-N
tensor. If you need different behavior, convert part of inputs to tensors
with `tf.convert_to_tensor`.
Returns:
Same data structure as if computation(*inputs) is called directly with some
exceptions for correctness. Exceptions include: 1) None output 2) Single
value output 3) Operation-only outputs
Raises:
ValueError: If any element in computation outputs is neither an Operation
nor a value that can be converted to a tensor.
ValueError: If computation outputs is non-flat and contains any Operations.
TypeError: If `inputs` is not a list or tuple.
"""
if inputs is None:
inputs = []
if not isinstance(inputs, collections.Sequence):
raise TypeError('inputs must be a list')
# Flatten inputs.
flat_inputs = nest.flatten(inputs)
# Converts inputs to Tensors.
flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
cluster_name = ops.get_default_graph().unique_name('cluster')
pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
context = XLACompileContext(name=cluster_name, pivot=pivot)
try:
context.Enter()
# Add identity ops so even unused inputs are 'consumed' by the
# computation.
flat_inputs = [
array_ops.identity(x, name='input_{}'.format(i))
for i, x in enumerate(flat_inputs)
]
# Re-pack flat_inputs in same structure as 'inputs'.
computation_inputs = nest.pack_sequence_as(
structure=inputs, flat_sequence=flat_inputs)
# Only resource variables work inside an XLA computation, so turn on
# resource variables for the computation.
vscope = variable_scope.get_variable_scope()
saved_use_resource = vscope.use_resource
vscope.set_use_resource(True)
with _disable_summary_context():
outputs = computation(*computation_inputs)
# Restore variable scope after computation.
vscope.set_use_resource(saved_use_resource)
outputs_is_flat = is_flat(outputs)
if outputs_is_flat:
output_tensors, control_deps = _postprocess_flat_outputs(outputs)
else:
output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
context.ExitResult(output_tensors)
finally:
context.report_unsupported_operations()
context.Exit()
# When XLA computation returns only operations and no tensors, a NoOp
# dependent on the operations in outputs is returned. Otherwise final
# outputs would be empty and there is no way to trigger returned
# operations.
if not output_tensors:
return control_flow_ops.group(control_deps, name='output_0')
output_tensors = [
xla_ops.xla_cluster_output(o, name='output{}'.format(i))
for i, o in enumerate(output_tensors)
]
with ops.control_dependencies(control_deps):
# Wraps the outputs in identity operators that carry control
# dependencies.
output_tensors = [
array_ops.identity(o, name='output_%d' % i)
for i, o in enumerate(output_tensors)
]
# If `computation` returned non-flat output structure, pack output tensors
# back into same structure.
if not outputs_is_flat:
output_tensors = nest.pack_sequence_as(
structure=outputs, flat_sequence=output_tensors)
return output_tensors
def is_flat(outputs):
"""Checks if outputs is a flat structure.
Following structures and values are considered flat:
1) None
2) A single object
3) A list or tuple of Tensors/Operations
The only structures that this function understands are sequences and
dictionaries. E.g. this means that if outputs contains a single
user-defined Object, it is considered to be flat. Errors are raised later on
if that Object cannot be converted to a Tensor.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
A boolean indicating whether outputs is flat.
"""
# If outputs is a list or tuple, check if it has any nested structure. If
# there is, then outputs is non-flat.
if isinstance(outputs, collections.Sequence):
for o in outputs:
if isinstance(o, collections.Sequence) or isinstance(o, dict):
return False
# If outputs is a dict, it is non-flat.
if isinstance(outputs, dict):
return False
# Getting here means either outputs itself is a single non-structured value
# or it is a flat list of single non-structured values.
return True
def _postprocess_flat_outputs(outputs):
"""Validates flat outputs and adds back device assignments.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
Tensors and Operations extracted from outputs.
"""
# The following code segment preserves legacy behavior. Previously we only
# supported flat outputs, so for consistency it was nice to convert even a
# single element into a tuple. Now that we support arbitrary output
# structures, this is no longer necessary.
# TODO(b/121383831): Migrate all legacy use cases and delete this special
# case.
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
outputs = tuple()
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
# Append `no_op` here so that the return value of this function always
# contains at least one op that can trigger the XlaLaunch node.
outputs += (control_flow_ops.no_op(),)
try:
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
except Exception as e:
raise ValueError(
'XLA computation function return values must all either be Operations'
' or convertible to Tensors. Got error: "%s"' % str(e))
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
'XLA computation function must return zero or more Tensor values '
'followed by zero or more Operations.')
new_output_tensors = []
for t in output_tensors:
with ops.device(t.device if t.device else ''):
new_output_tensors.append(array_ops.identity(t))
return new_output_tensors, output_operations
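# Example (illustrative): for a flat computation that returns `loss, update_op`
# (a Tensor followed by an Operation), this yields
# ([identity(loss)], [update_op, no_op]); returning the Operation first, e.g.
# `update_op, loss`, triggers the 'zero or more Tensor values followed by zero
# or more Operations' ValueError above.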
def _postprocess_non_flat_outputs(outputs):
"""Validates non-flat outputs and adds back device assignments.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
Tensors extracted from outputs, and an empty list because Operations are not
allowed in non-flat outputs.
"""
# Convert all non-Operation outputs to Tensors.
new_output_tensors = []
for o in nest.flatten(outputs):
if isinstance(o, ops.Operation):
raise ValueError(
'xla.compile does not support Operation as return value in non-flat '
'output structure. You can set returned Operations as control '
'dependencies of returned Tensors so Operations are triggered when '
'Tensors are evaluated. Operation found: "%s"' % o.name)
try:
o = ops.convert_to_tensor(o)
except Exception as e:
raise ValueError(
'XLA computation function return values must all either be '
'Operations or convertible to Tensors. Got error: "%s"' % str(e))
# Makes sure even pass-through inputs/outputs are touched in compile
# context by creating an Identity node inside compile context.
with ops.device(o.device if o.device else ''):
new_output_tensors.append(array_ops.identity(o))
return new_output_tensors, []
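# Example (an illustrative sketch; `step` is assumed to be an existing resource
# variable): instead of returning an Operation inside a non-flat structure,
# attach it as a control dependency of a returned Tensor:
#
#   def computation(x):
#     update_op = step.assign_add(1)
#     with ops.control_dependencies([update_op]):
#       out = array_ops.identity(x)
#     return {'out': out}  # dict (non-flat) containing only Tensors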
@contextlib.contextmanager
def _disable_summary_context():
"""Enters a context where all summary ops are skipped.
Summaries are not yet supported in xla.compile(), so this context manager
skips creating summary ops. This is a temporary workaround until XLA supports
summaries.
Yields:
None.
"""
original_skip_summary_func = summary_op_util.skip_summary
summary_op_util.skip_summary = lambda: True
try:
yield
finally:
summary_op_util.skip_summary = original_skip_summary_func
class _CapturedObject(object):
"""A placeholder to capture an object."""
def __init__(self):
self._object = None
def capture(self, o):
if self._object:
raise RuntimeError(
'InternalError: _CapturedObject can capture only once. Please file a '
'bug.')
self._object = o
def get(self):
return self._object
def _get_scaffold(captured_scaffold_fn):
"""Retrieves the Scaffold from `captured_scaffold_fn`."""
scaffold_fn = captured_scaffold_fn.get()
if not scaffold_fn:
return None
scaffold = scaffold_fn()
if scaffold is None:
raise ValueError(
'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
return scaffold
def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to an XLA function.
Args:
func: the Python function that will be called to generate the body of an XLA
computation graph.
input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
Returns:
None if the function can be called with the supplied number of
arguments, or an error string if it cannot.
"""
def format_error(complaint, quantity):
return '%s %d argument%s' % (complaint, quantity, ''
if quantity == 1 else 's')
num_args_supplied = input_arity
if infeed_queue is not None:
num_args_supplied += infeed_queue.number_of_tuple_elements
arg_spec = tf_inspect.getargspec(func)
num_func_args = len(arg_spec.args)
if arg_spec.defaults is None:
num_func_defaults = 0
else:
num_func_defaults = len(arg_spec.defaults)
min_func_args = num_func_args - num_func_defaults
if num_args_supplied < min_func_args:
# The supplied number of arguments is not enough to call the function.
if num_func_defaults == 0 and arg_spec.varargs is None:
return format_error('exactly', num_func_args)
else:
return format_error('at least', min_func_args)
if arg_spec.varargs is None and num_args_supplied > num_func_args:
# The supplied number of arguments is too many to call the function.
if num_func_defaults == 0:
return format_error('exactly', num_func_args)
else:
return format_error('at most', num_func_args)
# Reaching here means either
# 1) There are varargs, func can accept any number of arguments greater than
# the minimum.
# 2) Number of supplied arguments falls in range of acceptable argument count
# of func.
return None
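# Example (illustrative): for `def fn(x, y, z=1): ...`,
#   check_function_argument_count(fn, 3, None)  # -> None (callable)
#   check_function_argument_count(fn, 1, None)  # -> 'at least 2 arguments'
#   check_function_argument_count(fn, 4, None)  # -> 'at most 3 arguments'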

View File

@ -0,0 +1,267 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.xla.xla."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python import summary
from tensorflow.python.compiler.xla import xla
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
_TRAIN = model_fn_lib.ModeKeys.TRAIN
_EVAL = model_fn_lib.ModeKeys.EVAL
_EXPECTED_LOSS = 1
_EXPECTED_FEATURE = 2
_EXPECTED_LABEL = 3
@test_util.run_v1_only('b/128927195')
class XLACompileContextTest(test.TestCase):
def create_test_xla_compile_context(self):
computation_name = ops.get_default_graph().unique_name('computation')
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
return xla.XLACompileContext(name=computation_name, pivot=pivot)
def test_report_unsupported_operations(self):
"""Tests that unsupported operations are detected."""
context = self.create_test_xla_compile_context()
context.Enter()
dummy_tensor = constant_op.constant(1.1)
audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
image_summary = summary.image('image_summary', dummy_tensor)
scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor)
summary.merge(
[
audio_summary, histogram_summary, image_summary, scalar_summary,
tensor_summary
],
name='merge_summary')
logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
context.Exit()
unsupported_ops_names = [op.name for op in context._unsupported_ops]
self.assertEqual(unsupported_ops_names, [
u'audio_summary', u'histogram_summary', u'image_summary',
u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
u'print_op'
])
def test_resource_variable(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=True)
context = self.create_test_xla_compile_context()
context.Enter()
state_ops.assign(a, a + 1)
context.Exit()
def test_non_resource_variable_error(self):
"""Tests that non-resource variable usage is disallowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=False)
context = self.create_test_xla_compile_context()
context.Enter()
with self.assertRaisesRegexp(
NotImplementedError, 'Non-resource Variables are not supported inside '
r'XLA computations \(operator name: Assign\)'):
state_ops.assign(a, a + 1)
context.Exit()
def test_nested_xla_compile_error(self):
"""Tests that nested XLA computation leads to fatal error."""
context1 = self.create_test_xla_compile_context()
context1.Enter()
context2 = self.create_test_xla_compile_context()
context2.Enter()
with self.assertRaisesRegexp(ValueError,
'XLA compiled computations cannot be nested'):
constant_op.constant(1)
context2.Exit()
context1.Exit()
def test_xla_compile_attr(self):
"""Tests that ops are tagged with XLA compile ID attribute."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn('_xla_compile_id', op.op.node_def.attr)
def test_op_without_input(self):
"""Tests that ops without inputs depend on pivot correctly."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(context._pivot, op.op.control_inputs)
def test_external_control_edges(self):
"""Tests that external control edges are handled correctly."""
i = constant_op.constant(1)
op1 = constant_op.constant(1)
with ops.control_dependencies([op1]):
op2 = constant_op.constant(1)
self.assertIn(op1.op, op2.op.control_inputs)
def while_body(i):
del i # unused
context = self.create_test_xla_compile_context()
context.Enter()
with ops.control_dependencies([op1]):
op3 = constant_op.constant(1)
context.Exit()
self.assertNotIn(op1.op, op3.op.control_inputs)
return op3
control_flow_ops.while_loop(
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
def test_op_output_marked_as_seen(self):
"""Tests that any op output is marked as seen in context."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(op.name, context._values)
def testOpIsInContext(self):
"""Tests that XLACompileContext is recognized as an XLA context."""
op1 = constant_op.constant(1)
context = self.create_test_xla_compile_context()
context.Enter()
op2 = constant_op.constant(2)
context.Exit()
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
def testOpPreventFeeding(self):
"""Tests that ops created inside XLACompileContext can not be fed."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_feedable(op.op))
def testOpPreventFetching(self):
"""Tests that ops created inside XLACompileContext can not be fetched."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_fetchable(op.op))
@test_util.run_v1_only('b/128927195')
class CheckFunctionArgumentCountTest(test.TestCase):
def testSimple(self):
"""Tests that arg checker works for functions with no varargs or defaults.
"""
def func(x, y, z):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, None))
queue = tpu_feed.InfeedQueue(2)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, queue))
def testDefaultArgs(self):
"""Tests that arg checker works for a function with no varargs."""
def func(x, y, z=17):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, queue))
def testVarArgs(self):
"""Tests that arg checker works for a function with varargs."""
def func(x, y, *z):
return x + y + len(z)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
def testVarArgsAndDefaults(self):
"""Tests that arg checker works for a function with varargs and defaults."""
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
return x + y + z + len(q)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
if __name__ == '__main__':
test.main()

View File

@ -83,6 +83,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"train/queue_runner/__init__.py",
"user_ops/__init__.py",
"version/__init__.py",
"xla/__init__.py",
"xla/experimental/__init__.py",
# END GENERATED FILES
]

View File

@ -668,6 +668,10 @@ tf_module {
name: "version"
mtype: "<type \'module\'>"
}
member {
name: "xla"
mtype: "<type \'module\'>"
}
member {
name: "zeros_initializer"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,11 @@
path: "tensorflow.xla.experimental"
tf_module {
member_method {
name: "compile"
argspec: "args=[\'computation\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "jit_scope"
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.xla"
tf_module {
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -883,6 +883,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.conv2d_transpose",
"tf.test.compute_gradient":
"tf.compat.v1.test.compute_gradient",
"tf.xla.experimental.compile":
"tf.compat.v1.xla.experimental.compile",
"tf.xla.experimental.jit_scope":
"tf.compat.v1.xla.experimental.jit_scope",
}
# pylint: enable=line-too-long

View File

@ -1547,6 +1547,17 @@ def _log_prob(self, x):
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def testXlaExperimental(self):
text = "tf.xla.experimental.jit_scope(0)"
expected_text = "tf.compat.v1.xla.experimental.jit_scope(0)"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
text = "tf.xla.experimental.compile(0)"
expected_text = "tf.compat.v1.xla.experimental.compile(0)"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
class TestUpgradeFiles(test_util.TensorFlowTestCase):

View File

@ -65,6 +65,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/pyct/testing:test_modules",
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python/compiler:compiler",
"//tensorflow/python:cond_v2",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:meta_graph_testdata",