Automated rollback of commit 302acb768a8f00605f84430fabd28b6daa2c4c77
PiperOrigin-RevId: 240438115
This commit is contained in:
parent
6f9ed358d9
commit
8573a5abc8
@ -10,6 +10,7 @@ package_group(
|
||||
"//tensorflow/compiler/tests/...",
|
||||
"//tensorflow/compiler/tf2xla/...",
|
||||
"//tensorflow/contrib/compiler/...",
|
||||
"//tensorflow/python/compiler/...",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -24,49 +24,20 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":xla",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/compiler/xla:compiler_py",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "jit_test",
|
||||
size = "small",
|
||||
srcs = ["jit_test.py"],
|
||||
additional_deps = [
|
||||
":compiler_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
xla_enabled = True,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "xla",
|
||||
srcs = ["xla.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:xla_ops_py",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops_grad",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
"//tensorflow/python/compiler/xla:compiler_py",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
)
|
||||
@ -79,17 +50,12 @@ cuda_py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/contrib/tpu:tpu_estimator",
|
||||
"//tensorflow/contrib/tpu:tpu_lib",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
tags = [
|
||||
|
@ -18,101 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
_XLA_SCOPE_KEY = ("__xla_scope",)
|
||||
|
||||
|
||||
class _XlaScope(object):
|
||||
"""Keeps track of previous XLA scope calls, and depth of current call."""
|
||||
|
||||
def __init__(self, count, depth):
|
||||
self.count = count
|
||||
self.depth = depth
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
|
||||
"""Enable or disable JIT compilation of operators within the scope.
|
||||
|
||||
NOTE: This is an experimental feature.
|
||||
|
||||
The compilation is a hint and only supported on a best-effort basis.
|
||||
|
||||
Example usage:
|
||||
with tf.contrib.compiler.experimental_jit_scope():
|
||||
c = tf.matmul(a, b) # compiled
|
||||
with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
|
||||
d = tf.matmul(a, c) # not compiled
|
||||
with tf.contrib.compiler.experimental_jit_scope(
|
||||
compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
|
||||
e = tf.matmul(a, b) + d # matmul is compiled, the addition is not.
|
||||
|
||||
Example of separate_compiled_gradients:
|
||||
# In the example below, the computations for f, g and h will all be compiled
|
||||
# in separate scopes.
|
||||
with tf.contrib.compiler.experimental_jit_scope(
|
||||
separate_compiled_gradients=True):
|
||||
f = tf.matmul(a, b)
|
||||
g = tf.gradients([f], [a, b], name='mygrads1')
|
||||
h = tf.gradients([f], [a, b], name='mygrads2')
|
||||
|
||||
Args:
|
||||
compile_ops: Whether to enable or disable compilation in the scope.
|
||||
Either a Python bool, or a callable that accepts the parameter
|
||||
`node_def` and returns a python bool.
|
||||
separate_compiled_gradients: If true put each gradient subgraph into a
|
||||
separate compilation scope. This gives fine-grained control over which
|
||||
portions of the graph will be compiled as a single unit. Compiling
|
||||
gradients separately may yield better performance for some graphs.
|
||||
The scope is named based on the scope of the forward computation as well
|
||||
as the name of the gradients. As a result, the gradients will be compiled
|
||||
in a scope that is separate from both the forward computation, and from
|
||||
other gradients.
|
||||
Yields:
|
||||
The current scope, enabling or disabling compilation.
|
||||
|
||||
"""
|
||||
if callable(compile_ops):
|
||||
def xla_compile(node_def):
|
||||
return attr_value_pb2.AttrValue(b=compile_ops(node_def))
|
||||
else:
|
||||
xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
|
||||
|
||||
attrs = {
|
||||
"_XlaCompile":
|
||||
xla_compile,
|
||||
"_XlaSeparateCompiledGradients":
|
||||
attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
|
||||
}
|
||||
|
||||
# Find the singleton counter for the current scoped graph. If it
|
||||
# doesn't exist, create one.
|
||||
xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
|
||||
if not xla_scope_counter:
|
||||
xla_scope_counter = _XlaScope(0, 0)
|
||||
ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
|
||||
else:
|
||||
xla_scope_counter = xla_scope_counter[0]
|
||||
|
||||
if xla_scope_counter.depth == 0:
|
||||
# If we're at the root xla scope, we can increase the counter so
|
||||
# future calls to jit_scope use a different scope value.
|
||||
# If we're already within a scope, we'll be fusing using the scope
|
||||
# controlled by the parent.
|
||||
attrs["_XlaScope"] = attr_value_pb2.AttrValue(
|
||||
s=("jit_scope_%d" % xla_scope_counter.count).encode())
|
||||
xla_scope_counter.count += 1
|
||||
|
||||
xla_scope_counter.depth += 1
|
||||
|
||||
# pylint: disable=protected-access
|
||||
with ops.get_default_graph()._attr_scope(attrs):
|
||||
yield
|
||||
# pylint: enable=protected-access
|
||||
|
||||
xla_scope_counter.depth -= 1
|
||||
experimental_jit_scope = jit.experimental_jit_scope
|
||||
|
@ -18,511 +18,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.jit.ops import xla_ops
|
||||
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute import summary_op_util
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import function_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
_XLA_COMPILE_ATTR = '_xla_compile_id'
|
||||
_MAX_WARNING_LINES = 5
|
||||
|
||||
# Operations that indicate some error in the users graph. For example, XLA
|
||||
# computation should not have any Placeholder op.
|
||||
_BLACKLISTED_OPS = set([
|
||||
'Placeholder',
|
||||
])
|
||||
|
||||
# XLA doesn't currently support reading of intermediate tensors, thus some ops
|
||||
# are not supported.
|
||||
_UNSUPPORTED_OPS = set([
|
||||
'AudioSummary',
|
||||
'AudioSummaryV2',
|
||||
'HistogramSummary',
|
||||
'ImageSummary',
|
||||
'MergeSummary',
|
||||
'Print',
|
||||
'ScalarSummary',
|
||||
'TensorSummary',
|
||||
'TensorSummaryV2',
|
||||
])
|
||||
|
||||
|
||||
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
|
||||
"""Builds an operator that compiles and runs `computation` with XLA.
|
||||
|
||||
Args:
|
||||
computation: A Python function that builds a computation to apply to the
|
||||
input. If the function takes n inputs, 'inputs' should be a list of n
|
||||
tensors.
|
||||
|
||||
`computation` may return a list of operations and tensors. Tensors must
|
||||
come before operations in the returned list. The return value of
|
||||
`compile` is a list of tensors corresponding to the tensors from the
|
||||
output of `computation`.
|
||||
|
||||
All `Operation`s returned from `computation` will be executed when
|
||||
evaluating any of the returned output tensors.
|
||||
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
|
||||
can be a nested structure containing values that are convertible to
|
||||
tensors. Note that passing an N-dimension list of compatible values will
|
||||
result in a N-dimention list of scalar tensors rather than a single Rank-N
|
||||
tensors. If you need different behavior, convert part of inputs to tensors
|
||||
with `tf.convert_to_tensor`.
|
||||
|
||||
Returns:
|
||||
Same data structure as if computation(*inputs) is called directly with some
|
||||
exceptions for correctness. Exceptions include:
|
||||
1) None output: a NoOp would be returned which control-depends on
|
||||
computation.
|
||||
2) Single value output: A tuple containing the value would be returned.
|
||||
3) Operation-only outputs: a NoOp would be returned which
|
||||
control-depends on computation.
|
||||
TODO(b/121383831): Investigate into removing these special cases.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
return _compile_internal(computation, inputs)
|
||||
|
||||
|
||||
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
|
||||
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
|
||||
|
||||
THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY.
|
||||
|
||||
The primary role of `XLACompileContext` is to mark operators inside a
|
||||
xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
|
||||
a unique name.
|
||||
|
||||
`ControlFlowContext` is used to perform the annotation since it integrates
|
||||
with Tensorflow constructs like ResourceVariables. For example, if a
|
||||
`ResourceVariable` is constructed inside a xla.compile() block, the
|
||||
`ResourceVariable` implementation can use
|
||||
`with ops.control_dependencies(None)` to build the variable's definition
|
||||
outside the compiled computation.
|
||||
"""
|
||||
|
||||
def __init__(self, name, pivot):
|
||||
"""Builds a new XLACompileContext.
|
||||
|
||||
Args:
|
||||
name: a unique name for the context, used to populate the
|
||||
`_xla_compile_id` attribute.
|
||||
pivot: a pivot node. Nodes in the XLACompileContext that do not have any
|
||||
inputs will have a control dependency on the pivot node. This ensures
|
||||
that nodes are correctly included in any enclosing control flow
|
||||
contexts.
|
||||
"""
|
||||
super(XLACompileContext, self).__init__()
|
||||
self._name = name
|
||||
self._name_as_bytes = compat.as_bytes(name)
|
||||
self._unsupported_ops = []
|
||||
self._pivot = pivot
|
||||
|
||||
def report_unsupported_operations(self):
|
||||
if self._unsupported_ops:
|
||||
op_str = '\n'.join([
|
||||
' %s (%s)' % (op.type, op.name)
|
||||
for op in self._unsupported_ops[:_MAX_WARNING_LINES]
|
||||
])
|
||||
logging.warning('%d unsupported operations found: \n%s',
|
||||
len(self._unsupported_ops), op_str)
|
||||
if len(self._unsupported_ops) > _MAX_WARNING_LINES:
|
||||
logging.warning('... and %d more',
|
||||
len(self._unsupported_ops) - _MAX_WARNING_LINES)
|
||||
|
||||
def _RemoveExternalControlEdges(self, op):
|
||||
"""Remove any external control dependency on this op."""
|
||||
internal_control_inputs = []
|
||||
external_control_inputs = []
|
||||
for x in op.control_inputs:
|
||||
# pylint: disable=protected-access
|
||||
is_internal_op = False
|
||||
ctxt = x._get_control_flow_context()
|
||||
while ctxt is not None:
|
||||
if ctxt == self:
|
||||
is_internal_op = True
|
||||
break
|
||||
ctxt = ctxt._outer_context
|
||||
if is_internal_op:
|
||||
internal_control_inputs.append(x)
|
||||
else:
|
||||
external_control_inputs.append(x)
|
||||
# pylint: enable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
op._remove_all_control_inputs()
|
||||
op._add_control_inputs(internal_control_inputs)
|
||||
# pylint: enable=protected-access
|
||||
return internal_control_inputs, external_control_inputs
|
||||
|
||||
def AddOp(self, op):
|
||||
"""Create op in XLACompileContext and notifies outer context recursively."""
|
||||
# pylint: disable=protected-access
|
||||
if op.type in _BLACKLISTED_OPS:
|
||||
logging.error(
|
||||
'Operation of type %s (%s) is not supported in XLA. Execution will '
|
||||
'fail if this op is used in the graph. ', op.type, op.name)
|
||||
|
||||
# TODO(ycao): Automatically disable summaries instead of reporting them.
|
||||
if op.type in _UNSUPPORTED_OPS:
|
||||
self._unsupported_ops.append(op)
|
||||
|
||||
if any(x.dtype._is_ref_dtype for x in op.inputs):
|
||||
raise NotImplementedError(
|
||||
'Non-resource Variables are not supported inside XLA computations '
|
||||
'(operator name: %s)' % op.name)
|
||||
|
||||
if _XLA_COMPILE_ATTR in op.node_def.attr:
|
||||
raise ValueError('XLA compiled computations cannot be nested, (operator '
|
||||
'name: %s)' % op.name)
|
||||
|
||||
op._set_attr(
|
||||
_XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
|
||||
|
||||
op.graph.prevent_feeding(op)
|
||||
op.graph.prevent_fetching(op)
|
||||
|
||||
# Remove any control edges from outer control flow contexts. These may cause
|
||||
# mismatched frame errors. An example is when one of op's inputs is
|
||||
# generated in a different While control flow context.
|
||||
(internal_control_inputs,
|
||||
external_control_inputs) = self._RemoveExternalControlEdges(op)
|
||||
|
||||
if not op.inputs:
|
||||
# Add a control edge from the control pivot to this op.
|
||||
if not internal_control_inputs:
|
||||
# pylint: disable=protected-access
|
||||
op._add_control_input(self._pivot)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
for index in xrange(len(op.inputs)):
|
||||
x = op.inputs[index]
|
||||
real_x = self.AddValue(x)
|
||||
if real_x != x:
|
||||
op._update_input(index, real_x) # pylint: disable=protected-access
|
||||
|
||||
if external_control_inputs:
|
||||
# Use an identity to pull control inputs as data inputs. Note that we
|
||||
# ignore ops which don't have outputs. TODO(phawkins): fix that.
|
||||
with ops.control_dependencies(None):
|
||||
self.Enter()
|
||||
external_control_inputs = [
|
||||
array_ops.identity(x.outputs[0]).op
|
||||
for x in external_control_inputs
|
||||
if x.outputs
|
||||
]
|
||||
self.Exit()
|
||||
# pylint: disable=protected-access
|
||||
op._add_control_inputs(external_control_inputs)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Mark op's outputs as seen by this context and any outer contexts.
|
||||
output_names = [x.name for x in op.outputs]
|
||||
context = self
|
||||
while context is not None:
|
||||
# pylint: disable=protected-access
|
||||
context._values.update(output_names)
|
||||
context = context._outer_context
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if self._outer_context:
|
||||
self._outer_context.AddInnerOp(op)
|
||||
|
||||
def AddValue(self, val):
|
||||
"""Add `val` to the current context and its outer context recursively."""
|
||||
if val.name in self._values:
|
||||
# Use the real value if it comes from outer context.
|
||||
result = self._external_values.get(val.name)
|
||||
return val if result is None else result
|
||||
|
||||
result = val
|
||||
self._values.add(val.name)
|
||||
if self._outer_context:
|
||||
result = self._outer_context.AddValue(val)
|
||||
self._values.add(result.name)
|
||||
|
||||
self._external_values[val.name] = result
|
||||
|
||||
return result
|
||||
|
||||
def AddInnerOp(self, op):
|
||||
self.AddOp(op)
|
||||
if self._outer_context:
|
||||
self._outer_context.AddInnerOp(op)
|
||||
|
||||
@property
|
||||
def grad_state(self):
|
||||
# Define the gradient loop state associated with the XLACompileContext to
|
||||
# be None as the XLACompileContext does not get nested nor does the
|
||||
# grad_state outside the XLACompileContext affect the graph inside so the
|
||||
# grad_state should be as if this is the top-level gradient state.
|
||||
return None
|
||||
|
||||
@property
|
||||
def back_prop(self):
|
||||
"""Forwards to the enclosing while context, if any."""
|
||||
if self.GetWhileContext():
|
||||
return self.GetWhileContext().back_prop
|
||||
return False
|
||||
|
||||
|
||||
def _compile_internal(computation, inputs=None):
|
||||
"""Builds graph operators that compiles and symbolically executes computation.
|
||||
|
||||
Args:
|
||||
computation: A Python function that builds the computation to compile and
|
||||
execute.
|
||||
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
|
||||
can be a nested structure containing values that are convertible to
|
||||
tensors. Note that passing an N-dimension list of compatible values will
|
||||
result in a N-dimension list of scalar tensors rather than a single Rank-N
|
||||
tensors. If you need different behavior, convert part of inputs to tensors
|
||||
with `tf.convert_to_tensor`.
|
||||
|
||||
Returns:
|
||||
Same data structure as if computation(*inputs) is called directly with some
|
||||
exceptions for correctness. Exceptions include: 1) None output 2) Single
|
||||
value output 3) Operation-only outputs
|
||||
Raises:
|
||||
ValueError: If any element in computation outputs is neither an operations
|
||||
or a value that can be converted to tensor.
|
||||
ValueError: If computation outputs is non-flat and contains any Operations.
|
||||
TypeError: If `inputs` is not a list or tuple.
|
||||
"""
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
|
||||
if not isinstance(inputs, collections.Sequence):
|
||||
raise TypeError('inputs must be a list')
|
||||
|
||||
# Flatten inputs.
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
# Converts inputs to Tensors.
|
||||
flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
|
||||
|
||||
cluster_name = ops.get_default_graph().unique_name('cluster')
|
||||
pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
|
||||
context = XLACompileContext(name=cluster_name, pivot=pivot)
|
||||
try:
|
||||
context.Enter()
|
||||
|
||||
# Add identity ops so even unused inputs are 'consumed' by the
|
||||
# computation.
|
||||
flat_inputs = [
|
||||
array_ops.identity(x, name='input_{}'.format(i))
|
||||
for i, x in enumerate(flat_inputs)
|
||||
]
|
||||
|
||||
# Re-pack flat_inputs in same structure as 'inputs'.
|
||||
computation_inputs = nest.pack_sequence_as(
|
||||
structure=inputs, flat_sequence=flat_inputs)
|
||||
|
||||
# Only resource variables work inside an XLA computation, so turn on
|
||||
# resource variables for the computation.
|
||||
vscope = variable_scope.get_variable_scope()
|
||||
saved_use_resource = vscope.use_resource
|
||||
vscope.set_use_resource(True)
|
||||
|
||||
with _disable_summary_context():
|
||||
outputs = computation(*computation_inputs)
|
||||
|
||||
# Restore variable scope after computation.
|
||||
vscope.set_use_resource(saved_use_resource)
|
||||
|
||||
outputs_is_flat = is_flat(outputs)
|
||||
if outputs_is_flat:
|
||||
output_tensors, control_deps = _postprocess_flat_outputs(outputs)
|
||||
else:
|
||||
output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
|
||||
|
||||
context.ExitResult(output_tensors)
|
||||
finally:
|
||||
context.report_unsupported_operations()
|
||||
context.Exit()
|
||||
|
||||
# When XLA computation returns only operations and no tensors, a NoOp
|
||||
# dependent on the operations in outputs is returned. Otherwise final
|
||||
# outputs would be empty and there is no way to trigger returned
|
||||
# operations.
|
||||
if not output_tensors:
|
||||
return control_flow_ops.group(control_deps, name='output_0')
|
||||
|
||||
output_tensors = [
|
||||
xla_ops.xla_cluster_output(o, name='output{}'.format(i))
|
||||
for i, o in enumerate(output_tensors)
|
||||
]
|
||||
|
||||
with ops.control_dependencies(control_deps):
|
||||
# Wraps the outputs in identity operators that carries control
|
||||
# dependencies.
|
||||
output_tensors = [
|
||||
array_ops.identity(o, name='output_%d' % i)
|
||||
for i, o in enumerate(output_tensors)
|
||||
]
|
||||
|
||||
# If `computation` returned non-flat output structure, pack output tensors
|
||||
# back into same structure.
|
||||
if not outputs_is_flat:
|
||||
output_tensors = nest.pack_sequence_as(
|
||||
structure=outputs, flat_sequence=output_tensors)
|
||||
|
||||
return output_tensors
|
||||
|
||||
|
||||
def is_flat(outputs):
|
||||
"""Checks if outputs is a flat structure.
|
||||
|
||||
Following structures and values are considered flat:
|
||||
1) None
|
||||
2) A single object
|
||||
3) A list or tuple of Tensors/Operations
|
||||
|
||||
The only structures that this function understands are sequences and
|
||||
dictionaries. E.g. this means that if outputs contains a single
|
||||
user-defined Object, it is considered to be flat. Errors are raised later on
|
||||
if that Object cannot be converted to a Tensor.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
A boolean indicates whether outputs is flat.
|
||||
"""
|
||||
# If outputs is a list or tuple, check if it has any nested structure. If
|
||||
# there is, then outputs is non-flat.
|
||||
if isinstance(outputs, collections.Sequence):
|
||||
for o in outputs:
|
||||
if isinstance(o, collections.Sequence) or isinstance(o, dict):
|
||||
return False
|
||||
|
||||
# If outputs is a dict, it is non-flat.
|
||||
if isinstance(outputs, dict):
|
||||
return False
|
||||
|
||||
# Getting here means either outputs itself is a single non-structured value
|
||||
# or it is a flat list of single non-structured values.
|
||||
return True
|
||||
|
||||
|
||||
def _postprocess_flat_outputs(outputs):
|
||||
"""Validates flat outputs and adds back device assignments.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
Tensors and Operations extracted from outputs.
|
||||
"""
|
||||
# Following code segment is to preserve legacy behavior. Previously we only
|
||||
# supported flat outputs and thus for consistency it was nice to convert even
|
||||
# single element into a tuple. But now that we support arbitrary output
|
||||
# structure, this is no longer necessary.
|
||||
# TODO(b/121383831): Migrate all legacy use cases and delete this special
|
||||
# case.
|
||||
# If the computation returns `None`, make it an empty tuple.
|
||||
if outputs is None:
|
||||
outputs = tuple()
|
||||
# If the computation only returned one value, make it a tuple.
|
||||
if not isinstance(outputs, collections.Sequence):
|
||||
outputs = (outputs,)
|
||||
|
||||
# Append `no_op` here so that return value of this function always contains
|
||||
# at least one op that can trigger XlaLaunch node.
|
||||
outputs += (control_flow_ops.no_op(),)
|
||||
try:
|
||||
outputs = [
|
||||
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
|
||||
for o in outputs
|
||||
]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
'XLA computation function return values must all either be Operations'
|
||||
' or convertible to Tensors. Got error: "%s"' % str(e))
|
||||
|
||||
# Separates the returned Operations and Tensors.
|
||||
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
|
||||
output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
|
||||
|
||||
if outputs != output_tensors + output_operations:
|
||||
raise ValueError(
|
||||
'XLA computation function must return zero or more Tensor values '
|
||||
'followed by zero or more Operations.')
|
||||
|
||||
new_output_tensors = []
|
||||
for t in output_tensors:
|
||||
with ops.device(t.device if t.device else ''):
|
||||
new_output_tensors.append(array_ops.identity(t))
|
||||
|
||||
return new_output_tensors, output_operations
|
||||
|
||||
|
||||
def _postprocess_non_flat_outputs(outputs):
|
||||
"""Validates non-flat outputs and adds back device assignments.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
Tensors extracted from outputs and an empty list because Operations are not
|
||||
allowed in non-flat outputs..
|
||||
"""
|
||||
# Convert all non-Operation outputs to Tensors.
|
||||
new_output_tensors = []
|
||||
for o in nest.flatten(outputs):
|
||||
if isinstance(o, ops.Operation):
|
||||
raise ValueError(
|
||||
'xla.compile does not support Operation as return value in non-flat '
|
||||
'output structure. You can set returned Operations as control '
|
||||
'dependencies of returned Tensors so Operations are triggered when '
|
||||
'Tensors are evaluated. Operation found: "%s"' % o.name)
|
||||
|
||||
try:
|
||||
o = ops.convert_to_tensor(o)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
'XLA computation function return values must all either be '
|
||||
'Operations or convertible to Tensors. Got error: "%s"' % str(e))
|
||||
|
||||
# Makes sure even pass-through inputs/outputs are touched in compile
|
||||
# context by creating an Identity node inside compile context.
|
||||
with ops.device(o.device if o.device else ''):
|
||||
new_output_tensors.append(array_ops.identity(o))
|
||||
|
||||
return new_output_tensors, []
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _disable_summary_context():
|
||||
"""Enters a context where all summary ops are skipped.
|
||||
|
||||
Summaries are not yet supported in xla.compile(). So we provide this context
|
||||
manager that can skip creating summary ops. This is a temporary workaround due
|
||||
to XLA not supporting summary ops.
|
||||
|
||||
Yields:
|
||||
None.
|
||||
"""
|
||||
original_skip_summary_func = summary_op_util.skip_summary
|
||||
summary_op_util.skip_summary = lambda: True
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
summary_op_util.skip_summary = original_skip_summary_func
|
||||
|
||||
compile = xla.compile # pylint: disable=redefined-builtin
|
||||
check_function_argument_count = xla.check_function_argument_count
|
||||
|
||||
class _CapturedObject(object):
|
||||
"""A placeholder to capture an object."""
|
||||
@ -788,51 +294,3 @@ def estimator_model_fn(target_model_fn=None):
|
||||
return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
|
||||
|
||||
return decorated(target_model_fn) if target_model_fn else decorated
|
||||
|
||||
|
||||
def check_function_argument_count(func, input_arity, infeed_queue):
|
||||
"""Validate the number of input arguments to an XLA function.
|
||||
|
||||
Args:
|
||||
func: the Python function that will be called to generate the body of an XLA
|
||||
computation graph.
|
||||
input_arity: the number of explicit arguments supplied by the caller.
|
||||
infeed_queue: if not None, the infeed queue that will supply
|
||||
additional arguments to the function.
|
||||
|
||||
Returns:
|
||||
None if function can be called with the supplied number of
|
||||
arguments, or an error string if it cannot.
|
||||
"""
|
||||
def format_error(complaint, quantity):
|
||||
return '%s %d argument%s' % (complaint, quantity, ''
|
||||
if quantity == 1 else 's')
|
||||
|
||||
num_args_supplied = input_arity
|
||||
if infeed_queue is not None:
|
||||
num_args_supplied += infeed_queue.number_of_tuple_elements
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
num_func_args = len(arg_spec.args)
|
||||
if arg_spec.defaults is None:
|
||||
num_func_defaults = 0
|
||||
else:
|
||||
num_func_defaults = len(arg_spec.defaults)
|
||||
min_func_args = num_func_args - num_func_defaults
|
||||
if num_args_supplied < min_func_args:
|
||||
# The required number of arguments is not enough to call the function.
|
||||
if num_func_defaults == 0 and arg_spec.varargs is None:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at least', min_func_args)
|
||||
if arg_spec.varargs is None and num_args_supplied > num_func_args:
|
||||
# The required number of arguments is too many to call the function.
|
||||
if num_func_defaults == 0:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at most', num_func_args)
|
||||
# Reaching here means either
|
||||
# 1) There are varargs, func can accept any number of arguments greater than
|
||||
# the minimum.
|
||||
# 2) Number of supplied arguments falls in range of acceptable argument count
|
||||
# of func.
|
||||
return None
|
||||
|
@ -23,20 +23,13 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.contrib.compiler import xla
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
from tensorflow.contrib.training.python.training import hparam
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training
|
||||
|
||||
@ -48,226 +41,6 @@ _EXPECTED_FEATURE = 2
|
||||
_EXPECTED_LABEL = 3
|
||||
|
||||
|
||||
class XLACompileContextTest(test.TestCase):
|
||||
|
||||
def create_test_xla_compile_context(self):
|
||||
computation_name = ops.get_default_graph().unique_name('computation')
|
||||
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
|
||||
return xla.XLACompileContext(name=computation_name, pivot=pivot)
|
||||
|
||||
def test_report_unsupported_operations(self):
|
||||
"""Tests that unsupported operations are detected."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
dummy_tensor = constant_op.constant(1.1)
|
||||
audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
|
||||
histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
|
||||
image_summary = summary.image('image_summary', dummy_tensor)
|
||||
scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
|
||||
tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor)
|
||||
summary.merge(
|
||||
[
|
||||
audio_summary, histogram_summary, image_summary, scalar_summary,
|
||||
tensor_summary
|
||||
],
|
||||
name='merge_summary')
|
||||
logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
|
||||
context.Exit()
|
||||
|
||||
unsupported_ops_names = [op.name for op in context._unsupported_ops]
|
||||
self.assertEqual(unsupported_ops_names, [
|
||||
u'audio_summary', u'histogram_summary', u'image_summary',
|
||||
u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
|
||||
u'print_op'
|
||||
])
|
||||
|
||||
def test_resource_variable(self):
|
||||
"""Tests that resource variable usage is allowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', shape=(1), use_resource=True)
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
state_ops.assign(a, a + 1)
|
||||
context.Exit()
|
||||
|
||||
def test_non_resource_variable_error(self):
|
||||
"""Tests that non-resource variable usage is disallowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', shape=(1), use_resource=False)
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, 'Non-resource Variables are not supported inside '
|
||||
r'XLA computations \(operator name: Assign\)'):
|
||||
state_ops.assign(a, a + 1)
|
||||
context.Exit()
|
||||
|
||||
def test_nested_xla_compile_error(self):
|
||||
"""Tests that nested XLA computation leads to fatal error."""
|
||||
context1 = self.create_test_xla_compile_context()
|
||||
context1.Enter()
|
||||
|
||||
context2 = self.create_test_xla_compile_context()
|
||||
context2.Enter()
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'XLA compiled computations cannot be nested'):
|
||||
constant_op.constant(1)
|
||||
context2.Exit()
|
||||
context1.Exit()
|
||||
|
||||
def test_xla_compile_attr(self):
|
||||
"""Tests that ops are tagged with XLA compile ID attribute."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertIn('_xla_compile_id', op.op.node_def.attr)
|
||||
|
||||
def test_op_without_input(self):
|
||||
"""Tests that ops without inputs depend on pivot correctly."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
|
||||
self.assertIn(context._pivot, op.op.control_inputs)
|
||||
|
||||
def test_external_control_edges(self):
|
||||
"""Tests that external control edges are handled correctly."""
|
||||
i = constant_op.constant(1)
|
||||
op1 = constant_op.constant(1)
|
||||
|
||||
with ops.control_dependencies([op1]):
|
||||
op2 = constant_op.constant(1)
|
||||
self.assertIn(op1.op, op2.op.control_inputs)
|
||||
|
||||
def while_body(i):
|
||||
del i # unused
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
with ops.control_dependencies([op1]):
|
||||
op3 = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertNotIn(op1.op, op3.op.control_inputs)
|
||||
return op3
|
||||
|
||||
control_flow_ops.while_loop(
|
||||
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
|
||||
|
||||
def test_op_output_marked_as_seen(self):
|
||||
"""Tests that any op output is marked as seen in context."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
|
||||
self.assertIn(op.name, context._values)
|
||||
|
||||
def testOpIsInContext(self):
|
||||
"""Tests that XLACompileContext is recognized as an XLA context."""
|
||||
op1 = constant_op.constant(1)
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op2 = constant_op.constant(2)
|
||||
context.Exit()
|
||||
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
|
||||
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
|
||||
|
||||
def testOpPreventFeeding(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fed."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertFalse(op.graph.is_feedable(op.op))
|
||||
|
||||
def testOpPreventFetching(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fetched."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertFalse(op.graph.is_fetchable(op.op))
|
||||
|
||||
|
||||
class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
"""Tests that arg checker works for functions with no varargs or defaults.
|
||||
"""
|
||||
|
||||
def func(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual('exactly 3 arguments',
|
||||
xla.check_function_argument_count(func, 2, None))
|
||||
queue = tpu_feed.InfeedQueue(2)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual('exactly 3 arguments',
|
||||
xla.check_function_argument_count(func, 2, queue))
|
||||
|
||||
def testDefaultArgs(self):
|
||||
"""Tests that arg checker works for a function with no varargs."""
|
||||
|
||||
def func(x, y, z=17):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
self.assertEqual('at most 3 arguments',
|
||||
xla.check_function_argument_count(func, 4, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
self.assertEqual('at most 3 arguments',
|
||||
xla.check_function_argument_count(func, 4, queue))
|
||||
|
||||
def testVarArgs(self):
|
||||
"""Tests that arg checker works for a function with varargs."""
|
||||
|
||||
def func(x, y, *z):
|
||||
return x + y + len(z)
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
|
||||
def testVarArgsAndDefaults(self):
|
||||
"""Tests that arg checker works for a function with varargs and defaults."""
|
||||
|
||||
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
|
||||
return x + y + z + len(q)
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
|
||||
|
||||
def _test_train_model_fn(features, labels, mode, params):
|
||||
"""A dummy model_fn for testing purpose."""
|
||||
del features, labels, params
|
||||
|
@ -144,6 +144,10 @@ from tensorflow.python.framework.ops import enable_eager_execution
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
|
||||
# XLA JIT compiler APIs.
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
|
||||
# Required due to `rnn` and `rnn_cell` not being imported in `nn` directly
|
||||
# (due to a circular dependency issue: rnn depends on layers).
|
||||
nn.dynamic_rnn = rnn.dynamic_rnn
|
||||
|
@ -15,5 +15,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = if_not_windows([
|
||||
"//tensorflow/python/compiler/tensorrt:init_py",
|
||||
]),
|
||||
]) + [
|
||||
"//tensorflow/python/compiler/xla:compiler_py",
|
||||
],
|
||||
)
|
||||
|
85
tensorflow/python/compiler/xla/BUILD
Normal file
85
tensorflow/python/compiler/xla/BUILD
Normal file
@ -0,0 +1,85 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "compiler_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"jit.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":xla",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "jit_test",
|
||||
size = "small",
|
||||
srcs = ["jit_test.py"],
|
||||
additional_deps = [
|
||||
":compiler_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
xla_enabled = True,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "xla",
|
||||
srcs = ["xla.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:xla_ops_py",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops_grad",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "xla_test",
|
||||
srcs = ["xla_test.py"],
|
||||
additional_deps = [
|
||||
":xla",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/contrib/tpu:tpu_lib",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:variable_scope",
|
||||
],
|
||||
tags = [
|
||||
"no_mac",
|
||||
"no_windows",
|
||||
],
|
||||
xla_enabled = True,
|
||||
)
|
24
tensorflow/python/compiler/xla/__init__.py
Normal file
24
tensorflow/python/compiler/xla/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A module for controlling the Tensorflow/XLA JIT compiler."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
# pylint: enable=unused-import
|
120
tensorflow/python/compiler/xla/jit.py
Normal file
120
tensorflow/python/compiler/xla/jit.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library for controlling the Tensorflow/XLA JIT compiler."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_XLA_SCOPE_KEY = ("__xla_scope",)
|
||||
|
||||
|
||||
class _XlaScope(object):
|
||||
"""Keeps track of previous XLA scope calls, and depth of current call."""
|
||||
|
||||
def __init__(self, count, depth):
|
||||
self.count = count
|
||||
self.depth = depth
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@tf_export(v1=["xla.experimental.jit_scope"])
|
||||
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
|
||||
"""Enable or disable JIT compilation of operators within the scope.
|
||||
|
||||
NOTE: This is an experimental feature.
|
||||
|
||||
The compilation is a hint and only supported on a best-effort basis.
|
||||
|
||||
Example usage:
|
||||
with tf.xla.experimental.jit_scope():
|
||||
c = tf.matmul(a, b) # compiled
|
||||
with tf.xla.experimental.jit_scope(compile_ops=False):
|
||||
d = tf.matmul(a, c) # not compiled
|
||||
with tf.xla.experimental.jit_scope(
|
||||
compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
|
||||
e = tf.matmul(a, b) + d # matmul is compiled, the addition is not.
|
||||
|
||||
Example of separate_compiled_gradients:
|
||||
# In the example below, the computations for f, g and h will all be compiled
|
||||
# in separate scopes.
|
||||
with tf.xla.experimental.jit_scope(
|
||||
separate_compiled_gradients=True):
|
||||
f = tf.matmul(a, b)
|
||||
g = tf.gradients([f], [a, b], name='mygrads1')
|
||||
h = tf.gradients([f], [a, b], name='mygrads2')
|
||||
|
||||
Args:
|
||||
compile_ops: Whether to enable or disable compilation in the scope.
|
||||
Either a Python bool, or a callable that accepts the parameter
|
||||
`node_def` and returns a python bool.
|
||||
separate_compiled_gradients: If true put each gradient subgraph into a
|
||||
separate compilation scope. This gives fine-grained control over which
|
||||
portions of the graph will be compiled as a single unit. Compiling
|
||||
gradients separately may yield better performance for some graphs.
|
||||
The scope is named based on the scope of the forward computation as well
|
||||
as the name of the gradients. As a result, the gradients will be compiled
|
||||
in a scope that is separate from both the forward computation, and from
|
||||
other gradients.
|
||||
Yields:
|
||||
The current scope, enabling or disabling compilation.
|
||||
|
||||
"""
|
||||
if callable(compile_ops):
|
||||
def xla_compile(node_def):
|
||||
return attr_value_pb2.AttrValue(b=compile_ops(node_def))
|
||||
else:
|
||||
xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
|
||||
|
||||
attrs = {
|
||||
"_XlaCompile":
|
||||
xla_compile,
|
||||
"_XlaSeparateCompiledGradients":
|
||||
attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
|
||||
}
|
||||
|
||||
# Find the singleton counter for the current scoped graph. If it
|
||||
# doesn't exist, create one.
|
||||
xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
|
||||
if not xla_scope_counter:
|
||||
xla_scope_counter = _XlaScope(0, 0)
|
||||
ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
|
||||
else:
|
||||
xla_scope_counter = xla_scope_counter[0]
|
||||
|
||||
if xla_scope_counter.depth == 0:
|
||||
# If we're at the root xla scope, we can increase the counter so
|
||||
# future calls to jit_scope use a different scope value.
|
||||
# If we're already within a scope, we'll be fusing using the scope
|
||||
# controlled by the parent.
|
||||
attrs["_XlaScope"] = attr_value_pb2.AttrValue(
|
||||
s=("jit_scope_%d" % xla_scope_counter.count).encode())
|
||||
xla_scope_counter.count += 1
|
||||
|
||||
xla_scope_counter.depth += 1
|
||||
|
||||
# pylint: disable=protected-access
|
||||
with ops.get_default_graph()._attr_scope(attrs):
|
||||
yield
|
||||
# pylint: enable=protected-access
|
||||
|
||||
xla_scope_counter.depth -= 1
|
@ -12,18 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.compiler.jit."""
|
||||
"""Tests for python.compiler.xla.jit."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.compiler import jit
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -44,6 +45,7 @@ def enable_jit_nonstateful(node_def):
|
||||
raise ValueError("Unregistered op being created: %s" % node_def)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/128927195")
|
||||
class JITTest(test.TestCase):
|
||||
|
||||
def compute(self, use_jit, compute_fn):
|
||||
@ -170,6 +172,7 @@ class JITTest(test.TestCase):
|
||||
self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/128927195")
|
||||
class CompilationEnabledInGradientTest(test.TestCase):
|
||||
|
||||
def testCompilationInGradient(self):
|
604
tensorflow/python/compiler/xla/xla.py
Normal file
604
tensorflow/python/compiler/xla/xla.py
Normal file
@ -0,0 +1,604 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""xla is an experimental library that provides XLA support APIs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.jit.ops import xla_ops
|
||||
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute import summary_op_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_XLA_COMPILE_ATTR = '_xla_compile_id'
|
||||
_MAX_WARNING_LINES = 5
|
||||
|
||||
# Operations that indicate some error in the users graph. For example, XLA
|
||||
# computation should not have any Placeholder op.
|
||||
_BLACKLISTED_OPS = set([
|
||||
'Placeholder',
|
||||
])
|
||||
|
||||
# XLA doesn't currently support reading of intermediate tensors, thus some ops
|
||||
# are not supported.
|
||||
_UNSUPPORTED_OPS = set([
|
||||
'AudioSummary',
|
||||
'AudioSummaryV2',
|
||||
'HistogramSummary',
|
||||
'ImageSummary',
|
||||
'MergeSummary',
|
||||
'Print',
|
||||
'ScalarSummary',
|
||||
'TensorSummary',
|
||||
'TensorSummaryV2',
|
||||
])
|
||||
|
||||
|
||||
@tf_export(v1=['xla.experimental.compile'])
|
||||
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
|
||||
"""Builds an operator that compiles and runs `computation` with XLA.
|
||||
|
||||
Args:
|
||||
computation: A Python function that builds a computation to apply to the
|
||||
input. If the function takes n inputs, 'inputs' should be a list of n
|
||||
tensors.
|
||||
|
||||
`computation` may return a list of operations and tensors. Tensors must
|
||||
come before operations in the returned list. The return value of
|
||||
`compile` is a list of tensors corresponding to the tensors from the
|
||||
output of `computation`.
|
||||
|
||||
All `Operation`s returned from `computation` will be executed when
|
||||
evaluating any of the returned output tensors.
|
||||
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
|
||||
can be a nested structure containing values that are convertible to
|
||||
tensors. Note that passing an N-dimension list of compatible values will
|
||||
result in a N-dimension list of scalar tensors rather than a single Rank-N
|
||||
tensors. If you need different behavior, convert part of inputs to tensors
|
||||
with `tf.convert_to_tensor`.
|
||||
|
||||
Returns:
|
||||
Same data structure as if computation(*inputs) is called directly with some
|
||||
exceptions for correctness. Exceptions include:
|
||||
1) None output: a NoOp would be returned which control-depends on
|
||||
computation.
|
||||
2) Single value output: A tuple containing the value would be returned.
|
||||
3) Operation-only outputs: a NoOp would be returned which
|
||||
control-depends on computation.
|
||||
TODO(b/121383831): Investigate into removing these special cases.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
return _compile_internal(computation, inputs)
|
||||
|
||||
|
||||
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
|
||||
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
|
||||
|
||||
THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY.
|
||||
|
||||
The primary role of `XLACompileContext` is to mark operators inside a
|
||||
xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
|
||||
a unique name.
|
||||
|
||||
`ControlFlowContext` is used to perform the annotation since it integrates
|
||||
with Tensorflow constructs like ResourceVariables. For example, if a
|
||||
`ResourceVariable` is constructed inside a xla.compile() block, the
|
||||
`ResourceVariable` implementation can use
|
||||
`with ops.control_dependencies(None)` to build the variable's definition
|
||||
outside the compiled computation.
|
||||
"""
|
||||
|
||||
def __init__(self, name, pivot):
|
||||
"""Builds a new XLACompileContext.
|
||||
|
||||
Args:
|
||||
name: a unique name for the context, used to populate the
|
||||
`_xla_compile_id` attribute.
|
||||
pivot: a pivot node. Nodes in the XLACompileContext that do not have any
|
||||
inputs will have a control dependency on the pivot node. This ensures
|
||||
that nodes are correctly included in any enclosing control flow
|
||||
contexts.
|
||||
"""
|
||||
super(XLACompileContext, self).__init__()
|
||||
self._name = name
|
||||
self._name_as_bytes = compat.as_bytes(name)
|
||||
self._unsupported_ops = []
|
||||
self._pivot = pivot
|
||||
|
||||
def report_unsupported_operations(self):
|
||||
if self._unsupported_ops:
|
||||
op_str = '\n'.join([
|
||||
' %s (%s)' % (op.type, op.name)
|
||||
for op in self._unsupported_ops[:_MAX_WARNING_LINES]
|
||||
])
|
||||
logging.warning('%d unsupported operations found: \n%s',
|
||||
len(self._unsupported_ops), op_str)
|
||||
if len(self._unsupported_ops) > _MAX_WARNING_LINES:
|
||||
logging.warning('... and %d more',
|
||||
len(self._unsupported_ops) - _MAX_WARNING_LINES)
|
||||
|
||||
def _RemoveExternalControlEdges(self, op):
|
||||
"""Remove any external control dependency on this op."""
|
||||
internal_control_inputs = []
|
||||
external_control_inputs = []
|
||||
for x in op.control_inputs:
|
||||
# pylint: disable=protected-access
|
||||
is_internal_op = False
|
||||
ctxt = x._get_control_flow_context()
|
||||
while ctxt is not None:
|
||||
if ctxt == self:
|
||||
is_internal_op = True
|
||||
break
|
||||
ctxt = ctxt._outer_context
|
||||
if is_internal_op:
|
||||
internal_control_inputs.append(x)
|
||||
else:
|
||||
external_control_inputs.append(x)
|
||||
# pylint: enable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
op._remove_all_control_inputs()
|
||||
op._add_control_inputs(internal_control_inputs)
|
||||
# pylint: enable=protected-access
|
||||
return internal_control_inputs, external_control_inputs
|
||||
|
||||
def AddOp(self, op):
|
||||
"""Create op in XLACompileContext and notifies outer context recursively."""
|
||||
# pylint: disable=protected-access
|
||||
if op.type in _BLACKLISTED_OPS:
|
||||
logging.error(
|
||||
'Operation of type %s (%s) is not supported in XLA. Execution will '
|
||||
'fail if this op is used in the graph. ', op.type, op.name)
|
||||
|
||||
# TODO(ycao): Automatically disable summaries instead of reporting them.
|
||||
if op.type in _UNSUPPORTED_OPS:
|
||||
self._unsupported_ops.append(op)
|
||||
|
||||
if any(x.dtype._is_ref_dtype for x in op.inputs):
|
||||
raise NotImplementedError(
|
||||
'Non-resource Variables are not supported inside XLA computations '
|
||||
'(operator name: %s)' % op.name)
|
||||
|
||||
if _XLA_COMPILE_ATTR in op.node_def.attr:
|
||||
raise ValueError('XLA compiled computations cannot be nested, (operator '
|
||||
'name: %s)' % op.name)
|
||||
|
||||
op._set_attr(
|
||||
_XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
|
||||
|
||||
op.graph.prevent_feeding(op)
|
||||
op.graph.prevent_fetching(op)
|
||||
|
||||
# Remove any control edges from outer control flow contexts. These may cause
|
||||
# mismatched frame errors. An example is when one of op's inputs is
|
||||
# generated in a different While control flow context.
|
||||
(internal_control_inputs,
|
||||
external_control_inputs) = self._RemoveExternalControlEdges(op)
|
||||
|
||||
if not op.inputs:
|
||||
# Add a control edge from the control pivot to this op.
|
||||
if not internal_control_inputs:
|
||||
# pylint: disable=protected-access
|
||||
op._add_control_input(self._pivot)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
for index in xrange(len(op.inputs)):
|
||||
x = op.inputs[index]
|
||||
real_x = self.AddValue(x)
|
||||
if real_x != x:
|
||||
op._update_input(index, real_x) # pylint: disable=protected-access
|
||||
|
||||
if external_control_inputs:
|
||||
# Use an identity to pull control inputs as data inputs. Note that we
|
||||
# ignore ops which don't have outputs. TODO(phawkins): fix that.
|
||||
with ops.control_dependencies(None):
|
||||
self.Enter()
|
||||
external_control_inputs = [
|
||||
array_ops.identity(x.outputs[0]).op
|
||||
for x in external_control_inputs
|
||||
if x.outputs
|
||||
]
|
||||
self.Exit()
|
||||
# pylint: disable=protected-access
|
||||
op._add_control_inputs(external_control_inputs)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Mark op's outputs as seen by this context and any outer contexts.
|
||||
output_names = [x.name for x in op.outputs]
|
||||
context = self
|
||||
while context is not None:
|
||||
# pylint: disable=protected-access
|
||||
context._values.update(output_names)
|
||||
context = context._outer_context
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if self._outer_context:
|
||||
self._outer_context.AddInnerOp(op)
|
||||
|
||||
def AddValue(self, val):
|
||||
"""Add `val` to the current context and its outer context recursively."""
|
||||
if val.name in self._values:
|
||||
# Use the real value if it comes from outer context.
|
||||
result = self._external_values.get(val.name)
|
||||
return val if result is None else result
|
||||
|
||||
result = val
|
||||
self._values.add(val.name)
|
||||
if self._outer_context:
|
||||
result = self._outer_context.AddValue(val)
|
||||
self._values.add(result.name)
|
||||
|
||||
self._external_values[val.name] = result
|
||||
|
||||
return result
|
||||
|
||||
def AddInnerOp(self, op):
|
||||
self.AddOp(op)
|
||||
if self._outer_context:
|
||||
self._outer_context.AddInnerOp(op)
|
||||
|
||||
@property
|
||||
def grad_state(self):
|
||||
# Define the gradient loop state associated with the XLACompileContext to
|
||||
# be None as the XLACompileContext does not get nested nor does the
|
||||
# grad_state outside the XLACompileContext affect the graph inside so the
|
||||
# grad_state should be as if this is the top-level gradient state.
|
||||
return None
|
||||
|
||||
@property
|
||||
def back_prop(self):
|
||||
"""Forwards to the enclosing while context, if any."""
|
||||
if self.GetWhileContext():
|
||||
return self.GetWhileContext().back_prop
|
||||
return False
|
||||
|
||||
|
||||
def _compile_internal(computation, inputs=None):
|
||||
"""Builds graph operators that compiles and symbolically executes computation.
|
||||
|
||||
Args:
|
||||
computation: A Python function that builds the computation to compile and
|
||||
execute.
|
||||
inputs: A list of inputs or `None` (equivalent to an empty list). Each input
|
||||
can be a nested structure containing values that are convertible to
|
||||
tensors. Note that passing an N-dimension list of compatible values will
|
||||
result in a N-dimension list of scalar tensors rather than a single Rank-N
|
||||
tensors. If you need different behavior, convert part of inputs to tensors
|
||||
with `tf.convert_to_tensor`.
|
||||
|
||||
Returns:
|
||||
Same data structure as if computation(*inputs) is called directly with some
|
||||
exceptions for correctness. Exceptions include: 1) None output 2) Single
|
||||
value output 3) Operation-only outputs
|
||||
Raises:
|
||||
ValueError: If any element in computation outputs is neither an operations
|
||||
or a value that can be converted to tensor.
|
||||
ValueError: If computation outputs is non-flat and contains any Operations.
|
||||
TypeError: If `inputs` is not a list or tuple.
|
||||
"""
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
|
||||
if not isinstance(inputs, collections.Sequence):
|
||||
raise TypeError('inputs must be a list')
|
||||
|
||||
# Flatten inputs.
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
# Converts inputs to Tensors.
|
||||
flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
|
||||
|
||||
cluster_name = ops.get_default_graph().unique_name('cluster')
|
||||
pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
|
||||
context = XLACompileContext(name=cluster_name, pivot=pivot)
|
||||
try:
|
||||
context.Enter()
|
||||
|
||||
# Add identity ops so even unused inputs are 'consumed' by the
|
||||
# computation.
|
||||
flat_inputs = [
|
||||
array_ops.identity(x, name='input_{}'.format(i))
|
||||
for i, x in enumerate(flat_inputs)
|
||||
]
|
||||
|
||||
# Re-pack flat_inputs in same structure as 'inputs'.
|
||||
computation_inputs = nest.pack_sequence_as(
|
||||
structure=inputs, flat_sequence=flat_inputs)
|
||||
|
||||
# Only resource variables work inside an XLA computation, so turn on
|
||||
# resource variables for the computation.
|
||||
vscope = variable_scope.get_variable_scope()
|
||||
saved_use_resource = vscope.use_resource
|
||||
vscope.set_use_resource(True)
|
||||
|
||||
with _disable_summary_context():
|
||||
outputs = computation(*computation_inputs)
|
||||
|
||||
# Restore variable scope after computation.
|
||||
vscope.set_use_resource(saved_use_resource)
|
||||
|
||||
outputs_is_flat = is_flat(outputs)
|
||||
if outputs_is_flat:
|
||||
output_tensors, control_deps = _postprocess_flat_outputs(outputs)
|
||||
else:
|
||||
output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
|
||||
|
||||
context.ExitResult(output_tensors)
|
||||
finally:
|
||||
context.report_unsupported_operations()
|
||||
context.Exit()
|
||||
|
||||
# When XLA computation returns only operations and no tensors, a NoOp
|
||||
# dependent on the operations in outputs is returned. Otherwise final
|
||||
# outputs would be empty and there is no way to trigger returned
|
||||
# operations.
|
||||
if not output_tensors:
|
||||
return control_flow_ops.group(control_deps, name='output_0')
|
||||
|
||||
output_tensors = [
|
||||
xla_ops.xla_cluster_output(o, name='output{}'.format(i))
|
||||
for i, o in enumerate(output_tensors)
|
||||
]
|
||||
|
||||
with ops.control_dependencies(control_deps):
|
||||
# Wraps the outputs in identity operators that carries control
|
||||
# dependencies.
|
||||
output_tensors = [
|
||||
array_ops.identity(o, name='output_%d' % i)
|
||||
for i, o in enumerate(output_tensors)
|
||||
]
|
||||
|
||||
# If `computation` returned non-flat output structure, pack output tensors
|
||||
# back into same structure.
|
||||
if not outputs_is_flat:
|
||||
output_tensors = nest.pack_sequence_as(
|
||||
structure=outputs, flat_sequence=output_tensors)
|
||||
|
||||
return output_tensors
|
||||
|
||||
|
||||
def is_flat(outputs):
|
||||
"""Checks if outputs is a flat structure.
|
||||
|
||||
Following structures and values are considered flat:
|
||||
1) None
|
||||
2) A single object
|
||||
3) A list or tuple of Tensors/Operations
|
||||
|
||||
The only structures that this function understands are sequences and
|
||||
dictionaries. E.g. this means that if outputs contains a single
|
||||
user-defined Object, it is considered to be flat. Errors are raised later on
|
||||
if that Object cannot be converted to a Tensor.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
A boolean indicates whether outputs is flat.
|
||||
"""
|
||||
# If outputs is a list or tuple, check if it has any nested structure. If
|
||||
# there is, then outputs is non-flat.
|
||||
if isinstance(outputs, collections.Sequence):
|
||||
for o in outputs:
|
||||
if isinstance(o, collections.Sequence) or isinstance(o, dict):
|
||||
return False
|
||||
|
||||
# If outputs is a dict, it is non-flat.
|
||||
if isinstance(outputs, dict):
|
||||
return False
|
||||
|
||||
# Getting here means either outputs itself is a single non-structured value
|
||||
# or it is a flat list of single non-structured values.
|
||||
return True
|
||||
|
||||
|
||||
def _postprocess_flat_outputs(outputs):
|
||||
"""Validates flat outputs and adds back device assignments.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
Tensors and Operations extracted from outputs.
|
||||
"""
|
||||
# Following code segment is to preserve legacy behavior. Previously we only
|
||||
# supported flat outputs and thus for consistency it was nice to convert even
|
||||
# single element into a tuple. But now that we support arbitrary output
|
||||
# structure, this is no longer necessary.
|
||||
# TODO(b/121383831): Migrate all legacy use cases and delete this special
|
||||
# case.
|
||||
# If the computation returns `None`, make it an empty tuple.
|
||||
if outputs is None:
|
||||
outputs = tuple()
|
||||
# If the computation only returned one value, make it a tuple.
|
||||
if not isinstance(outputs, collections.Sequence):
|
||||
outputs = (outputs,)
|
||||
|
||||
# Append `no_op` here so that return value of this function always contains
|
||||
# at least one op that can trigger XlaLaunch node.
|
||||
outputs += (control_flow_ops.no_op(),)
|
||||
try:
|
||||
outputs = [
|
||||
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
|
||||
for o in outputs
|
||||
]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
'XLA computation function return values must all either be Operations'
|
||||
' or convertible to Tensors. Got error: "%s"' % str(e))
|
||||
|
||||
# Separates the returned Operations and Tensors.
|
||||
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
|
||||
output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
|
||||
|
||||
if outputs != output_tensors + output_operations:
|
||||
raise ValueError(
|
||||
'XLA computation function must return zero or more Tensor values '
|
||||
'followed by zero or more Operations.')
|
||||
|
||||
new_output_tensors = []
|
||||
for t in output_tensors:
|
||||
with ops.device(t.device if t.device else ''):
|
||||
new_output_tensors.append(array_ops.identity(t))
|
||||
|
||||
return new_output_tensors, output_operations
|
||||
|
||||
|
||||
def _postprocess_non_flat_outputs(outputs):
|
||||
"""Validates non-flat outputs and adds back device assignments.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
Tensors extracted from outputs and an empty list because Operations are not
|
||||
allowed in non-flat outputs..
|
||||
"""
|
||||
# Convert all non-Operation outputs to Tensors.
|
||||
new_output_tensors = []
|
||||
for o in nest.flatten(outputs):
|
||||
if isinstance(o, ops.Operation):
|
||||
raise ValueError(
|
||||
'xla.compile does not support Operation as return value in non-flat '
|
||||
'output structure. You can set returned Operations as control '
|
||||
'dependencies of returned Tensors so Operations are triggered when '
|
||||
'Tensors are evaluated. Operation found: "%s"' % o.name)
|
||||
|
||||
try:
|
||||
o = ops.convert_to_tensor(o)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
'XLA computation function return values must all either be '
|
||||
'Operations or convertible to Tensors. Got error: "%s"' % str(e))
|
||||
|
||||
# Makes sure even pass-through inputs/outputs are touched in compile
|
||||
# context by creating an Identity node inside compile context.
|
||||
with ops.device(o.device if o.device else ''):
|
||||
new_output_tensors.append(array_ops.identity(o))
|
||||
|
||||
return new_output_tensors, []
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _disable_summary_context():
|
||||
"""Enters a context where all summary ops are skipped.
|
||||
|
||||
Summaries are not yet supported in xla.compile(). So we provide this context
|
||||
manager that can skip creating summary ops. This is a temporary workaround due
|
||||
to XLA not supporting summary ops.
|
||||
|
||||
Yields:
|
||||
None.
|
||||
"""
|
||||
original_skip_summary_func = summary_op_util.skip_summary
|
||||
summary_op_util.skip_summary = lambda: True
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
summary_op_util.skip_summary = original_skip_summary_func
|
||||
|
||||
|
||||
class _CapturedObject(object):
|
||||
"""A placeholder to capture an object."""
|
||||
|
||||
def __init__(self):
|
||||
self._object = None
|
||||
|
||||
def capture(self, o):
|
||||
if self._object:
|
||||
raise RuntimeError(
|
||||
'InternalError: _CapturedObject can capture only once. Please file '
|
||||
'bug.')
|
||||
|
||||
self._object = o
|
||||
|
||||
def get(self):
|
||||
return self._object
|
||||
|
||||
|
||||
def _get_scaffold(captured_scaffold_fn):
|
||||
"""Retrieves the Scaffold from `captured_scaffold_fn`."""
|
||||
scaffold_fn = captured_scaffold_fn.get()
|
||||
|
||||
if not scaffold_fn:
|
||||
return None
|
||||
|
||||
scaffold = scaffold_fn()
|
||||
if scaffold is None:
|
||||
raise ValueError(
|
||||
'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
|
||||
|
||||
return scaffold
|
||||
|
||||
|
||||
def check_function_argument_count(func, input_arity, infeed_queue):
|
||||
"""Validate the number of input arguments to an XLA function.
|
||||
|
||||
Args:
|
||||
func: the Python function that will be called to generate the body of an XLA
|
||||
computation graph.
|
||||
input_arity: the number of explicit arguments supplied by the caller.
|
||||
infeed_queue: if not None, the infeed queue that will supply
|
||||
additional arguments to the function.
|
||||
|
||||
Returns:
|
||||
None if function can be called with the supplied number of
|
||||
arguments, or an error string if it cannot.
|
||||
"""
|
||||
def format_error(complaint, quantity):
|
||||
return '%s %d argument%s' % (complaint, quantity, ''
|
||||
if quantity == 1 else 's')
|
||||
|
||||
num_args_supplied = input_arity
|
||||
if infeed_queue is not None:
|
||||
num_args_supplied += infeed_queue.number_of_tuple_elements
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
num_func_args = len(arg_spec.args)
|
||||
if arg_spec.defaults is None:
|
||||
num_func_defaults = 0
|
||||
else:
|
||||
num_func_defaults = len(arg_spec.defaults)
|
||||
min_func_args = num_func_args - num_func_defaults
|
||||
if num_args_supplied < min_func_args:
|
||||
# The required number of arguments is not enough to call the function.
|
||||
if num_func_defaults == 0 and arg_spec.varargs is None:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at least', min_func_args)
|
||||
if arg_spec.varargs is None and num_args_supplied > num_func_args:
|
||||
# The required number of arguments is too many to call the function.
|
||||
if num_func_defaults == 0:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at most', num_func_args)
|
||||
# Reaching here means either
|
||||
# 1) There are varargs, func can accept any number of arguments greater than
|
||||
# the minimum.
|
||||
# 2) Number of supplied arguments falls in range of acceptable argument count
|
||||
# of func.
|
||||
return None
|
267
tensorflow/python/compiler/xla/xla_test.py
Normal file
267
tensorflow/python/compiler/xla/xla_test.py
Normal file
@ -0,0 +1,267 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Tests for python.compiler.xla.xla."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
_TRAIN = model_fn_lib.ModeKeys.TRAIN
|
||||
_EVAL = model_fn_lib.ModeKeys.EVAL
|
||||
_EXPECTED_LOSS = 1
|
||||
_EXPECTED_FEATURE = 2
|
||||
_EXPECTED_LABEL = 3
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/128927195')
|
||||
class XLACompileContextTest(test.TestCase):
|
||||
|
||||
def create_test_xla_compile_context(self):
|
||||
computation_name = ops.get_default_graph().unique_name('computation')
|
||||
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
|
||||
return xla.XLACompileContext(name=computation_name, pivot=pivot)
|
||||
|
||||
def test_report_unsupported_operations(self):
|
||||
"""Tests that unsupported operations are detected."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
dummy_tensor = constant_op.constant(1.1)
|
||||
audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
|
||||
histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
|
||||
image_summary = summary.image('image_summary', dummy_tensor)
|
||||
scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
|
||||
tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor)
|
||||
summary.merge(
|
||||
[
|
||||
audio_summary, histogram_summary, image_summary, scalar_summary,
|
||||
tensor_summary
|
||||
],
|
||||
name='merge_summary')
|
||||
logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
|
||||
context.Exit()
|
||||
|
||||
unsupported_ops_names = [op.name for op in context._unsupported_ops]
|
||||
self.assertEqual(unsupported_ops_names, [
|
||||
u'audio_summary', u'histogram_summary', u'image_summary',
|
||||
u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
|
||||
u'print_op'
|
||||
])
|
||||
|
||||
def test_resource_variable(self):
|
||||
"""Tests that resource variable usage is allowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', shape=(1), use_resource=True)
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
state_ops.assign(a, a + 1)
|
||||
context.Exit()
|
||||
|
||||
def test_non_resource_variable_error(self):
|
||||
"""Tests that non-resource variable usage is disallowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', shape=(1), use_resource=False)
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, 'Non-resource Variables are not supported inside '
|
||||
r'XLA computations \(operator name: Assign\)'):
|
||||
state_ops.assign(a, a + 1)
|
||||
context.Exit()
|
||||
|
||||
def test_nested_xla_compile_error(self):
|
||||
"""Tests that nested XLA computation leads to fatal error."""
|
||||
context1 = self.create_test_xla_compile_context()
|
||||
context1.Enter()
|
||||
|
||||
context2 = self.create_test_xla_compile_context()
|
||||
context2.Enter()
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'XLA compiled computations cannot be nested'):
|
||||
constant_op.constant(1)
|
||||
context2.Exit()
|
||||
context1.Exit()
|
||||
|
||||
def test_xla_compile_attr(self):
|
||||
"""Tests that ops are tagged with XLA compile ID attribute."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertIn('_xla_compile_id', op.op.node_def.attr)
|
||||
|
||||
def test_op_without_input(self):
|
||||
"""Tests that ops without inputs depend on pivot correctly."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
|
||||
self.assertIn(context._pivot, op.op.control_inputs)
|
||||
|
||||
def test_external_control_edges(self):
|
||||
"""Tests that external control edges are handled correctly."""
|
||||
i = constant_op.constant(1)
|
||||
op1 = constant_op.constant(1)
|
||||
|
||||
with ops.control_dependencies([op1]):
|
||||
op2 = constant_op.constant(1)
|
||||
self.assertIn(op1.op, op2.op.control_inputs)
|
||||
|
||||
def while_body(i):
|
||||
del i # unused
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
with ops.control_dependencies([op1]):
|
||||
op3 = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertNotIn(op1.op, op3.op.control_inputs)
|
||||
return op3
|
||||
|
||||
control_flow_ops.while_loop(
|
||||
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
|
||||
|
||||
def test_op_output_marked_as_seen(self):
|
||||
"""Tests that any op output is marked as seen in context."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
|
||||
self.assertIn(op.name, context._values)
|
||||
|
||||
def testOpIsInContext(self):
|
||||
"""Tests that XLACompileContext is recognized as an XLA context."""
|
||||
op1 = constant_op.constant(1)
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op2 = constant_op.constant(2)
|
||||
context.Exit()
|
||||
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
|
||||
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
|
||||
|
||||
def testOpPreventFeeding(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fed."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertFalse(op.graph.is_feedable(op.op))
|
||||
|
||||
def testOpPreventFetching(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fetched."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
op = constant_op.constant(1)
|
||||
context.Exit()
|
||||
self.assertFalse(op.graph.is_fetchable(op.op))
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/128927195')
|
||||
class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
"""Tests that arg checker works for functions with no varargs or defaults.
|
||||
"""
|
||||
|
||||
def func(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual('exactly 3 arguments',
|
||||
xla.check_function_argument_count(func, 2, None))
|
||||
queue = tpu_feed.InfeedQueue(2)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual('exactly 3 arguments',
|
||||
xla.check_function_argument_count(func, 2, queue))
|
||||
|
||||
def testDefaultArgs(self):
|
||||
"""Tests that arg checker works for a function with no varargs."""
|
||||
|
||||
def func(x, y, z=17):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
self.assertEqual('at most 3 arguments',
|
||||
xla.check_function_argument_count(func, 4, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
self.assertEqual('at most 3 arguments',
|
||||
xla.check_function_argument_count(func, 4, queue))
|
||||
|
||||
def testVarArgs(self):
|
||||
"""Tests that arg checker works for a function with varargs."""
|
||||
|
||||
def func(x, y, *z):
|
||||
return x + y + len(z)
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
|
||||
def testVarArgsAndDefaults(self):
|
||||
"""Tests that arg checker works for a function with varargs and defaults."""
|
||||
|
||||
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
|
||||
return x + y + z + len(q)
|
||||
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -83,6 +83,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
||||
"train/queue_runner/__init__.py",
|
||||
"user_ops/__init__.py",
|
||||
"version/__init__.py",
|
||||
"xla/__init__.py",
|
||||
"xla/experimental/__init__.py",
|
||||
# END GENERATED FILES
|
||||
]
|
||||
|
||||
|
@ -668,6 +668,10 @@ tf_module {
|
||||
name: "version"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "xla"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "zeros_initializer"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,11 @@
|
||||
path: "tensorflow.xla.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "compile"
|
||||
argspec: "args=[\'computation\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "jit_scope"
|
||||
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
|
||||
}
|
||||
}
|
7
tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt
Normal file
7
tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt
Normal file
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.xla"
|
||||
tf_module {
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
@ -883,6 +883,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.nn.conv2d_transpose",
|
||||
"tf.test.compute_gradient":
|
||||
"tf.compat.v1.test.compute_gradient",
|
||||
"tf.xla.experimental.compile":
|
||||
"tf.compat.v1.xla.experimental.compile",
|
||||
"tf.xla.experimental.jit_scope":
|
||||
"tf.compat.v1.xla.experimental.jit_scope",
|
||||
}
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
@ -1547,6 +1547,17 @@ def _log_prob(self, x):
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testXlaExperimental(self):
|
||||
text = "tf.xla.experimental.jit_scope(0)"
|
||||
expected_text = "tf.compat.v1.xla.experimental.jit_scope(0)"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
text = "tf.xla.experimental.compile(0)"
|
||||
expected_text = "tf.compat.v1.xla.experimental.compile(0)"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -65,6 +65,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/autograph/core:test_lib",
|
||||
"//tensorflow/python/autograph/pyct/testing:test_modules",
|
||||
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
|
||||
"//tensorflow/python/compiler:compiler",
|
||||
"//tensorflow/python:cond_v2",
|
||||
"//tensorflow/python:distributed_framework_test_lib",
|
||||
"//tensorflow/python:meta_graph_testdata",
|
||||
|
Loading…
x
Reference in New Issue
Block a user