From 8573a5abc86dad26c8b1965605847070b6c4997f Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 26 Mar 2019 15:15:34 -0700 Subject: [PATCH] Automated rollback of commit 302acb768a8f00605f84430fabd28b6daa2c4c77 PiperOrigin-RevId: 240438115 --- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/contrib/compiler/BUILD | 38 +- tensorflow/contrib/compiler/jit.py | 99 +-- tensorflow/contrib/compiler/xla.py | 548 +--------------- tensorflow/contrib/compiler/xla_test.py | 227 ------- tensorflow/python/__init__.py | 4 + tensorflow/python/compiler/BUILD | 4 +- tensorflow/python/compiler/xla/BUILD | 85 +++ tensorflow/python/compiler/xla/__init__.py | 24 + tensorflow/python/compiler/xla/jit.py | 120 ++++ .../compiler/xla}/jit_test.py | 7 +- tensorflow/python/compiler/xla/xla.py | 604 ++++++++++++++++++ tensorflow/python/compiler/xla/xla_test.py | 267 ++++++++ .../tools/api/generator/api_init_files_v1.bzl | 2 + .../tools/api/golden/v1/tensorflow.pbtxt | 4 + .../v1/tensorflow.xla.experimental.pbtxt | 11 + .../tools/api/golden/v1/tensorflow.xla.pbtxt | 7 + .../tools/compatibility/tf_upgrade_v2.py | 4 + .../tools/compatibility/tf_upgrade_v2_test.py | 11 + tensorflow/tools/pip_package/BUILD | 1 + 20 files changed, 1160 insertions(+), 908 deletions(-) create mode 100644 tensorflow/python/compiler/xla/BUILD create mode 100644 tensorflow/python/compiler/xla/__init__.py create mode 100644 tensorflow/python/compiler/xla/jit.py rename tensorflow/{contrib/compiler => python/compiler/xla}/jit_test.py (98%) create mode 100644 tensorflow/python/compiler/xla/xla.py create mode 100644 tensorflow/python/compiler/xla/xla_test.py create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.xla.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e1df032ba93..dda726198e6 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -10,6 +10,7 @@ package_group( "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", "//tensorflow/contrib/compiler/...", + "//tensorflow/python/compiler/...", ], ) diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 839682afdc6..773560fcd0b 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -24,49 +24,20 @@ py_library( srcs_version = "PY2AND3", deps = [ ":xla", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/compiler/xla:compiler_py", ], ) -cuda_py_test( - name = "jit_test", - size = "small", - srcs = ["jit_test.py"], - additional_deps = [ - ":compiler_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - xla_enabled = True, -) - py_library( name = "xla", srcs = ["xla.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/compiler/jit:xla_ops_py", - "//tensorflow/compiler/jit/ops:xla_ops_grad", "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/distribute:summary_op_util", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/estimator:estimator_py", ], ) @@ -79,17 +50,12 @@ cuda_py_test( "@absl_py//absl/testing:parameterized", "//tensorflow/compiler/tests:xla_test", "//tensorflow/contrib/tpu:tpu_estimator", - "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", ], tags = [ diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py index c516ab658d7..70898aeb974 100644 --- a/tensorflow/contrib/compiler/jit.py +++ b/tensorflow/contrib/compiler/jit.py @@ -18,101 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib +from tensorflow.python.compiler.xla import jit -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.framework import ops - - -_XLA_SCOPE_KEY = ("__xla_scope",) - - -class _XlaScope(object): - """Keeps track of previous XLA scope calls, and depth of current call.""" - - def __init__(self, count, depth): - self.count = count - self.depth = depth - - -@contextlib.contextmanager -def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False): - """Enable or disable JIT compilation of operators within the scope. - - NOTE: This is an experimental feature. - - The compilation is a hint and only supported on a best-effort basis. - - Example usage: - with tf.contrib.compiler.experimental_jit_scope(): - c = tf.matmul(a, b) # compiled - with tf.contrib.compiler.experimental_jit_scope(compile_ops=False): - d = tf.matmul(a, c) # not compiled - with tf.contrib.compiler.experimental_jit_scope( - compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): - e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. - - Example of separate_compiled_gradients: - # In the example below, the computations for f, g and h will all be compiled - # in separate scopes. - with tf.contrib.compiler.experimental_jit_scope( - separate_compiled_gradients=True): - f = tf.matmul(a, b) - g = tf.gradients([f], [a, b], name='mygrads1') - h = tf.gradients([f], [a, b], name='mygrads2') - - Args: - compile_ops: Whether to enable or disable compilation in the scope. - Either a Python bool, or a callable that accepts the parameter - `node_def` and returns a python bool. - separate_compiled_gradients: If true put each gradient subgraph into a - separate compilation scope. This gives fine-grained control over which - portions of the graph will be compiled as a single unit. Compiling - gradients separately may yield better performance for some graphs. - The scope is named based on the scope of the forward computation as well - as the name of the gradients. As a result, the gradients will be compiled - in a scope that is separate from both the forward computation, and from - other gradients. - Yields: - The current scope, enabling or disabling compilation. - - """ - if callable(compile_ops): - def xla_compile(node_def): - return attr_value_pb2.AttrValue(b=compile_ops(node_def)) - else: - xla_compile = attr_value_pb2.AttrValue(b=compile_ops) - - attrs = { - "_XlaCompile": - xla_compile, - "_XlaSeparateCompiledGradients": - attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients)) - } - - # Find the singleton counter for the current scoped graph. If it - # doesn't exist, create one. - xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) - if not xla_scope_counter: - xla_scope_counter = _XlaScope(0, 0) - ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) - else: - xla_scope_counter = xla_scope_counter[0] - - if xla_scope_counter.depth == 0: - # If we're at the root xla scope, we can increase the counter so - # future calls to jit_scope use a different scope value. - # If we're already within a scope, we'll be fusing using the scope - # controlled by the parent. - attrs["_XlaScope"] = attr_value_pb2.AttrValue( - s=("jit_scope_%d" % xla_scope_counter.count).encode()) - xla_scope_counter.count += 1 - - xla_scope_counter.depth += 1 - - # pylint: disable=protected-access - with ops.get_default_graph()._attr_scope(attrs): - yield - # pylint: enable=protected-access - - xla_scope_counter.depth -= 1 +experimental_jit_scope = jit.experimental_jit_scope diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 2ccb27da12f..b8ccbf18e68 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -18,511 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import contextlib -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.compiler.jit.ops import xla_ops -from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.distribute import summary_op_util +from tensorflow.python.compiler.xla import xla from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat from tensorflow.python.util import function_utils -from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect -_XLA_COMPILE_ATTR = '_xla_compile_id' -_MAX_WARNING_LINES = 5 - -# Operations that indicate some error in the users graph. For example, XLA -# computation should not have any Placeholder op. -_BLACKLISTED_OPS = set([ - 'Placeholder', -]) - -# XLA doesn't currently support reading of intermediate tensors, thus some ops -# are not supported. -_UNSUPPORTED_OPS = set([ - 'AudioSummary', - 'AudioSummaryV2', - 'HistogramSummary', - 'ImageSummary', - 'MergeSummary', - 'Print', - 'ScalarSummary', - 'TensorSummary', - 'TensorSummaryV2', -]) - - -def compile(computation, inputs=None): # pylint: disable=redefined-builtin - """Builds an operator that compiles and runs `computation` with XLA. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. - - `computation` may return a list of operations and tensors. Tensors must - come before operations in the returned list. The return value of - `compile` is a list of tensors corresponding to the tensors from the - output of `computation`. - - All `Operation`s returned from `computation` will be executed when - evaluating any of the returned output tensors. - inputs: A list of inputs or `None` (equivalent to an empty list). Each input - can be a nested structure containing values that are convertible to - tensors. Note that passing an N-dimension list of compatible values will - result in a N-dimention list of scalar tensors rather than a single Rank-N - tensors. If you need different behavior, convert part of inputs to tensors - with `tf.convert_to_tensor`. - - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. - 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - """ - # pylint: disable=protected-access - return _compile_internal(computation, inputs) - - -class XLACompileContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside an XLA computation cluster. - - THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. - - The primary role of `XLACompileContext` is to mark operators inside a - xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is - a unique name. - - `ControlFlowContext` is used to perform the annotation since it integrates - with Tensorflow constructs like ResourceVariables. For example, if a - `ResourceVariable` is constructed inside a xla.compile() block, the - `ResourceVariable` implementation can use - `with ops.control_dependencies(None)` to build the variable's definition - outside the compiled computation. - """ - - def __init__(self, name, pivot): - """Builds a new XLACompileContext. - - Args: - name: a unique name for the context, used to populate the - `_xla_compile_id` attribute. - pivot: a pivot node. Nodes in the XLACompileContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. - """ - super(XLACompileContext, self).__init__() - self._name = name - self._name_as_bytes = compat.as_bytes(name) - self._unsupported_ops = [] - self._pivot = pivot - - def report_unsupported_operations(self): - if self._unsupported_ops: - op_str = '\n'.join([ - ' %s (%s)' % (op.type, op.name) - for op in self._unsupported_ops[:_MAX_WARNING_LINES] - ]) - logging.warning('%d unsupported operations found: \n%s', - len(self._unsupported_ops), op_str) - if len(self._unsupported_ops) > _MAX_WARNING_LINES: - logging.warning('... and %d more', - len(self._unsupported_ops) - _MAX_WARNING_LINES) - - def _RemoveExternalControlEdges(self, op): - """Remove any external control dependency on this op.""" - internal_control_inputs = [] - external_control_inputs = [] - for x in op.control_inputs: - # pylint: disable=protected-access - is_internal_op = False - ctxt = x._get_control_flow_context() - while ctxt is not None: - if ctxt == self: - is_internal_op = True - break - ctxt = ctxt._outer_context - if is_internal_op: - internal_control_inputs.append(x) - else: - external_control_inputs.append(x) - # pylint: enable=protected-access - # pylint: disable=protected-access - op._remove_all_control_inputs() - op._add_control_inputs(internal_control_inputs) - # pylint: enable=protected-access - return internal_control_inputs, external_control_inputs - - def AddOp(self, op): - """Create op in XLACompileContext and notifies outer context recursively.""" - # pylint: disable=protected-access - if op.type in _BLACKLISTED_OPS: - logging.error( - 'Operation of type %s (%s) is not supported in XLA. Execution will ' - 'fail if this op is used in the graph. ', op.type, op.name) - - # TODO(ycao): Automatically disable summaries instead of reporting them. - if op.type in _UNSUPPORTED_OPS: - self._unsupported_ops.append(op) - - if any(x.dtype._is_ref_dtype for x in op.inputs): - raise NotImplementedError( - 'Non-resource Variables are not supported inside XLA computations ' - '(operator name: %s)' % op.name) - - if _XLA_COMPILE_ATTR in op.node_def.attr: - raise ValueError('XLA compiled computations cannot be nested, (operator ' - 'name: %s)' % op.name) - - op._set_attr( - _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) - - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) - - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. An example is when one of op's inputs is - # generated in a different While control flow context. - (internal_control_inputs, - external_control_inputs) = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. - if not internal_control_inputs: - # pylint: disable=protected-access - op._add_control_input(self._pivot) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_control_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() - # pylint: disable=protected-access - op._add_control_inputs(external_control_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - """Add `val` to the current context and its outer context recursively.""" - if val.name in self._values: - # Use the real value if it comes from outer context. - result = self._external_values.get(val.name) - return val if result is None else result - - result = val - self._values.add(val.name) - if self._outer_context: - result = self._outer_context.AddValue(val) - self._values.add(result.name) - - self._external_values[val.name] = result - - return result - - def AddInnerOp(self, op): - self.AddOp(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - @property - def grad_state(self): - # Define the gradient loop state associated with the XLACompileContext to - # be None as the XLACompileContext does not get nested nor does the - # grad_state outside the XLACompileContext affect the graph inside so the - # grad_state should be as if this is the top-level gradient state. - return None - - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - -def _compile_internal(computation, inputs=None): - """Builds graph operators that compiles and symbolically executes computation. - - Args: - computation: A Python function that builds the computation to compile and - execute. - inputs: A list of inputs or `None` (equivalent to an empty list). Each input - can be a nested structure containing values that are convertible to - tensors. Note that passing an N-dimension list of compatible values will - result in a N-dimension list of scalar tensors rather than a single Rank-N - tensors. If you need different behavior, convert part of inputs to tensors - with `tf.convert_to_tensor`. - - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: 1) None output 2) Single - value output 3) Operation-only outputs - Raises: - ValueError: If any element in computation outputs is neither an operations - or a value that can be converted to tensor. - ValueError: If computation outputs is non-flat and contains any Operations. - TypeError: If `inputs` is not a list or tuple. - """ - if inputs is None: - inputs = [] - - if not isinstance(inputs, collections.Sequence): - raise TypeError('inputs must be a list') - - # Flatten inputs. - flat_inputs = nest.flatten(inputs) - # Converts inputs to Tensors. - flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] - - cluster_name = ops.get_default_graph().unique_name('cluster') - pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') - context = XLACompileContext(name=cluster_name, pivot=pivot) - try: - context.Enter() - - # Add identity ops so even unused inputs are 'consumed' by the - # computation. - flat_inputs = [ - array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(flat_inputs) - ] - - # Re-pack flat_inputs in same structure as 'inputs'. - computation_inputs = nest.pack_sequence_as( - structure=inputs, flat_sequence=flat_inputs) - - # Only resource variables work inside an XLA computation, so turn on - # resource variables for the computation. - vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - vscope.set_use_resource(True) - - with _disable_summary_context(): - outputs = computation(*computation_inputs) - - # Restore variable scope after computation. - vscope.set_use_resource(saved_use_resource) - - outputs_is_flat = is_flat(outputs) - if outputs_is_flat: - output_tensors, control_deps = _postprocess_flat_outputs(outputs) - else: - output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - - context.ExitResult(output_tensors) - finally: - context.report_unsupported_operations() - context.Exit() - - # When XLA computation returns only operations and no tensors, a NoOp - # dependent on the operations in outputs is returned. Otherwise final - # outputs would be empty and there is no way to trigger returned - # operations. - if not output_tensors: - return control_flow_ops.group(control_deps, name='output_0') - - output_tensors = [ - xla_ops.xla_cluster_output(o, name='output{}'.format(i)) - for i, o in enumerate(output_tensors) - ] - - with ops.control_dependencies(control_deps): - # Wraps the outputs in identity operators that carries control - # dependencies. - output_tensors = [ - array_ops.identity(o, name='output_%d' % i) - for i, o in enumerate(output_tensors) - ] - - # If `computation` returned non-flat output structure, pack output tensors - # back into same structure. - if not outputs_is_flat: - output_tensors = nest.pack_sequence_as( - structure=outputs, flat_sequence=output_tensors) - - return output_tensors - - -def is_flat(outputs): - """Checks if outputs is a flat structure. - - Following structures and values are considered flat: - 1) None - 2) A single object - 3) A list or tuple of Tensors/Operations - - The only structures that this function understands are sequences and - dictionaries. E.g. this means that if outputs contains a single - user-defined Object, it is considered to be flat. Errors are raised later on - if that Object cannot be converted to a Tensor. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - A boolean indicates whether outputs is flat. - """ - # If outputs is a list or tuple, check if it has any nested structure. If - # there is, then outputs is non-flat. - if isinstance(outputs, collections.Sequence): - for o in outputs: - if isinstance(o, collections.Sequence) or isinstance(o, dict): - return False - - # If outputs is a dict, it is non-flat. - if isinstance(outputs, dict): - return False - - # Getting here means either outputs itself is a single non-structured value - # or it is a flat list of single non-structured values. - return True - - -def _postprocess_flat_outputs(outputs): - """Validates flat outputs and adds back device assignments. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - Tensors and Operations extracted from outputs. - """ - # Following code segment is to preserve legacy behavior. Previously we only - # supported flat outputs and thus for consistency it was nice to convert even - # single element into a tuple. But now that we support arbitrary output - # structure, this is no longer necessary. - # TODO(b/121383831): Migrate all legacy use cases and delete this special - # case. - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that return value of this function always contains - # at least one op that can trigger XlaLaunch node. - outputs += (control_flow_ops.no_op(),) - try: - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be Operations' - ' or convertible to Tensors. Got error: "%s"' % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - 'XLA computation function must return zero or more Tensor values ' - 'followed by zero or more Operations.') - - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else ''): - new_output_tensors.append(array_ops.identity(t)) - - return new_output_tensors, output_operations - - -def _postprocess_non_flat_outputs(outputs): - """Validates non-flat outputs and adds back device assignments. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - Tensors extracted from outputs and an empty list because Operations are not - allowed in non-flat outputs.. - """ - # Convert all non-Operation outputs to Tensors. - new_output_tensors = [] - for o in nest.flatten(outputs): - if isinstance(o, ops.Operation): - raise ValueError( - 'xla.compile does not support Operation as return value in non-flat ' - 'output structure. You can set returned Operations as control ' - 'dependencies of returned Tensors so Operations are triggered when ' - 'Tensors are evaluated. Operation found: "%s"' % o.name) - - try: - o = ops.convert_to_tensor(o) - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be ' - 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) - - # Makes sure even pass-through inputs/outputs are touched in compile - # context by creating an Identity node inside compile context. - with ops.device(o.device if o.device else ''): - new_output_tensors.append(array_ops.identity(o)) - - return new_output_tensors, [] - - -@contextlib.contextmanager -def _disable_summary_context(): - """Enters a context where all summary ops are skipped. - - Summaries are not yet supported in xla.compile(). So we provide this context - manager that can skip creating summary ops. This is a temporary workaround due - to XLA not supporting summary ops. - - Yields: - None. - """ - original_skip_summary_func = summary_op_util.skip_summary - summary_op_util.skip_summary = lambda: True - - try: - yield - finally: - summary_op_util.skip_summary = original_skip_summary_func +compile = xla.compile # pylint: disable=redefined-builtin +check_function_argument_count = xla.check_function_argument_count class _CapturedObject(object): """A placeholder to capture an object.""" @@ -788,51 +294,3 @@ def estimator_model_fn(target_model_fn=None): return tf_decorator.make_decorator(function, _ModelFnWrapper(function)) return decorated(target_model_fn) if target_model_fn else decorated - - -def check_function_argument_count(func, input_arity, infeed_queue): - """Validate the number of input arguments to an XLA function. - - Args: - func: the Python function that will be called to generate the body of an XLA - computation graph. - input_arity: the number of explicit arguments supplied by the caller. - infeed_queue: if not None, the infeed queue that will supply - additional arguments to the function. - - Returns: - None if function can be called with the supplied number of - arguments, or an error string if it cannot. - """ - def format_error(complaint, quantity): - return '%s %d argument%s' % (complaint, quantity, '' - if quantity == 1 else 's') - - num_args_supplied = input_arity - if infeed_queue is not None: - num_args_supplied += infeed_queue.number_of_tuple_elements - arg_spec = tf_inspect.getargspec(func) - num_func_args = len(arg_spec.args) - if arg_spec.defaults is None: - num_func_defaults = 0 - else: - num_func_defaults = len(arg_spec.defaults) - min_func_args = num_func_args - num_func_defaults - if num_args_supplied < min_func_args: - # The required number of arguments is not enough to call the function. - if num_func_defaults == 0 and arg_spec.varargs is None: - return format_error('exactly', num_func_args) - else: - return format_error('at least', min_func_args) - if arg_spec.varargs is None and num_args_supplied > num_func_args: - # The required number of arguments is too many to call the function. - if num_func_defaults == 0: - return format_error('exactly', num_func_args) - else: - return format_error('at most', num_func_args) - # Reaching here means either - # 1) There are varargs, func can accept any number of arguments greater than - # the minimum. - # 2) Number of supplied arguments falls in range of acceptable argument count - # of func. - return None diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index c4384dcde75..0df7c3706aa 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -23,20 +23,13 @@ from absl.testing import parameterized from tensorflow.contrib.compiler import xla from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.training.python.training import hparam from tensorflow.python import summary from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import training @@ -48,226 +41,6 @@ _EXPECTED_FEATURE = 2 _EXPECTED_LABEL = 3 -class XLACompileContextTest(test.TestCase): - - def create_test_xla_compile_context(self): - computation_name = ops.get_default_graph().unique_name('computation') - pivot = control_flow_ops.no_op(name=computation_name + '/pivot') - return xla.XLACompileContext(name=computation_name, pivot=pivot) - - def test_report_unsupported_operations(self): - """Tests that unsupported operations are detected.""" - context = self.create_test_xla_compile_context() - context.Enter() - dummy_tensor = constant_op.constant(1.1) - audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5) - histogram_summary = summary.histogram('histogram_summary', dummy_tensor) - image_summary = summary.image('image_summary', dummy_tensor) - scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) - summary.merge( - [ - audio_summary, histogram_summary, image_summary, scalar_summary, - tensor_summary - ], - name='merge_summary') - logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op') - context.Exit() - - unsupported_ops_names = [op.name for op in context._unsupported_ops] - self.assertEqual(unsupported_ops_names, [ - u'audio_summary', u'histogram_summary', u'image_summary', - u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary', - u'print_op' - ]) - - def test_resource_variable(self): - """Tests that resource variable usage is allowed.""" - a = variable_scope.get_variable( - name='variable_a', shape=(1), use_resource=True) - - context = self.create_test_xla_compile_context() - context.Enter() - state_ops.assign(a, a + 1) - context.Exit() - - def test_non_resource_variable_error(self): - """Tests that non-resource variable usage is disallowed.""" - a = variable_scope.get_variable( - name='variable_a', shape=(1), use_resource=False) - - context = self.create_test_xla_compile_context() - context.Enter() - with self.assertRaisesRegexp( - NotImplementedError, 'Non-resource Variables are not supported inside ' - r'XLA computations \(operator name: Assign\)'): - state_ops.assign(a, a + 1) - context.Exit() - - def test_nested_xla_compile_error(self): - """Tests that nested XLA computation leads to fatal error.""" - context1 = self.create_test_xla_compile_context() - context1.Enter() - - context2 = self.create_test_xla_compile_context() - context2.Enter() - with self.assertRaisesRegexp(ValueError, - 'XLA compiled computations cannot be nested'): - constant_op.constant(1) - context2.Exit() - context1.Exit() - - def test_xla_compile_attr(self): - """Tests that ops are tagged with XLA compile ID attribute.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertIn('_xla_compile_id', op.op.node_def.attr) - - def test_op_without_input(self): - """Tests that ops without inputs depend on pivot correctly.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - - self.assertIn(context._pivot, op.op.control_inputs) - - def test_external_control_edges(self): - """Tests that external control edges are handled correctly.""" - i = constant_op.constant(1) - op1 = constant_op.constant(1) - - with ops.control_dependencies([op1]): - op2 = constant_op.constant(1) - self.assertIn(op1.op, op2.op.control_inputs) - - def while_body(i): - del i # unused - context = self.create_test_xla_compile_context() - context.Enter() - with ops.control_dependencies([op1]): - op3 = constant_op.constant(1) - context.Exit() - self.assertNotIn(op1.op, op3.op.control_inputs) - return op3 - - control_flow_ops.while_loop( - cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i]) - - def test_op_output_marked_as_seen(self): - """Tests that any op output is marked as seen in context.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - - self.assertIn(op.name, context._values) - - def testOpIsInContext(self): - """Tests that XLACompileContext is recognized as an XLA context.""" - op1 = constant_op.constant(1) - context = self.create_test_xla_compile_context() - context.Enter() - op2 = constant_op.constant(2) - context.Exit() - self.assertFalse(control_flow_util.IsInXLAContext(op1.op)) - self.assertTrue(control_flow_util.IsInXLAContext(op2.op)) - - def testOpPreventFeeding(self): - """Tests that ops created inside XLACompileContext can not be fed.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertFalse(op.graph.is_feedable(op.op)) - - def testOpPreventFetching(self): - """Tests that ops created inside XLACompileContext can not be fetched.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertFalse(op.graph.is_fetchable(op.op)) - - -class CheckFunctionArgumentCountTest(test.TestCase): - - def testSimple(self): - """Tests that arg checker works for functions with no varargs or defaults. - """ - - def func(x, y, z): - return x + y + z - - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual('exactly 3 arguments', - xla.check_function_argument_count(func, 2, None)) - queue = tpu_feed.InfeedQueue(2) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual('exactly 3 arguments', - xla.check_function_argument_count(func, 2, queue)) - - def testDefaultArgs(self): - """Tests that arg checker works for a function with no varargs.""" - - def func(x, y, z=17): - return x + y + z - - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - self.assertEqual('at most 3 arguments', - xla.check_function_argument_count(func, 4, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - self.assertEqual('at most 3 arguments', - xla.check_function_argument_count(func, 4, queue)) - - def testVarArgs(self): - """Tests that arg checker works for a function with varargs.""" - - def func(x, y, *z): - return x + y + len(z) - - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - - def testVarArgsAndDefaults(self): - """Tests that arg checker works for a function with varargs and defaults.""" - - def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg - return x + y + z + len(q) - - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 5, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - - def _test_train_model_fn(features, labels, mode, params): """A dummy model_fn for testing purpose.""" del features, labels, params diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 0016b5beaa5..7d620fadd70 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -144,6 +144,10 @@ from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +# XLA JIT compiler APIs. +from tensorflow.python.compiler.xla import jit +from tensorflow.python.compiler.xla import xla + # Required due to `rnn` and `rnn_cell` not being imported in `nn` directly # (due to a circular dependency issue: rnn depends on layers). nn.dynamic_rnn = rnn.dynamic_rnn diff --git a/tensorflow/python/compiler/BUILD b/tensorflow/python/compiler/BUILD index 07209a9eca9..ccb95139f90 100644 --- a/tensorflow/python/compiler/BUILD +++ b/tensorflow/python/compiler/BUILD @@ -15,5 +15,7 @@ py_library( srcs_version = "PY2AND3", deps = if_not_windows([ "//tensorflow/python/compiler/tensorrt:init_py", - ]), + ]) + [ + "//tensorflow/python/compiler/xla:compiler_py", + ], ) diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD new file mode 100644 index 00000000000..90f178e0bb2 --- /dev/null +++ b/tensorflow/python/compiler/xla/BUILD @@ -0,0 +1,85 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "compiler_py", + srcs = [ + "__init__.py", + "jit.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":xla", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +cuda_py_test( + name = "jit_test", + size = "small", + srcs = ["jit_test.py"], + additional_deps = [ + ":compiler_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + xla_enabled = True, +) + +py_library( + name = "xla", + srcs = ["xla.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/compiler/jit:xla_ops_py", + "//tensorflow/compiler/jit/ops:xla_ops_grad", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:summary_op_util", + ], +) + +cuda_py_test( + name = "xla_test", + srcs = ["xla_test.py"], + additional_deps = [ + ":xla", + "@absl_py//absl/testing:parameterized", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:variable_scope", + ], + tags = [ + "no_mac", + "no_windows", + ], + xla_enabled = True, +) diff --git a/tensorflow/python/compiler/xla/__init__.py b/tensorflow/python/compiler/xla/__init__.py new file mode 100644 index 00000000000..eb395365828 --- /dev/null +++ b/tensorflow/python/compiler/xla/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module for controlling the Tensorflow/XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.python.compiler.xla import jit +from tensorflow.python.compiler.xla import xla +# pylint: enable=unused-import diff --git a/tensorflow/python/compiler/xla/jit.py b/tensorflow/python/compiler/xla/jit.py new file mode 100644 index 00000000000..24635ece389 --- /dev/null +++ b/tensorflow/python/compiler/xla/jit.py @@ -0,0 +1,120 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for controlling the Tensorflow/XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.util.tf_export import tf_export + + +_XLA_SCOPE_KEY = ("__xla_scope",) + + +class _XlaScope(object): + """Keeps track of previous XLA scope calls, and depth of current call.""" + + def __init__(self, count, depth): + self.count = count + self.depth = depth + + +@contextlib.contextmanager +@tf_export(v1=["xla.experimental.jit_scope"]) +def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False): + """Enable or disable JIT compilation of operators within the scope. + + NOTE: This is an experimental feature. + + The compilation is a hint and only supported on a best-effort basis. + + Example usage: + with tf.xla.experimental.jit_scope(): + c = tf.matmul(a, b) # compiled + with tf.xla.experimental.jit_scope(compile_ops=False): + d = tf.matmul(a, c) # not compiled + with tf.xla.experimental.jit_scope( + compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): + e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. + + Example of separate_compiled_gradients: + # In the example below, the computations for f, g and h will all be compiled + # in separate scopes. + with tf.xla.experimental.jit_scope( + separate_compiled_gradients=True): + f = tf.matmul(a, b) + g = tf.gradients([f], [a, b], name='mygrads1') + h = tf.gradients([f], [a, b], name='mygrads2') + + Args: + compile_ops: Whether to enable or disable compilation in the scope. + Either a Python bool, or a callable that accepts the parameter + `node_def` and returns a python bool. + separate_compiled_gradients: If true put each gradient subgraph into a + separate compilation scope. This gives fine-grained control over which + portions of the graph will be compiled as a single unit. Compiling + gradients separately may yield better performance for some graphs. + The scope is named based on the scope of the forward computation as well + as the name of the gradients. As a result, the gradients will be compiled + in a scope that is separate from both the forward computation, and from + other gradients. + Yields: + The current scope, enabling or disabling compilation. + + """ + if callable(compile_ops): + def xla_compile(node_def): + return attr_value_pb2.AttrValue(b=compile_ops(node_def)) + else: + xla_compile = attr_value_pb2.AttrValue(b=compile_ops) + + attrs = { + "_XlaCompile": + xla_compile, + "_XlaSeparateCompiledGradients": + attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients)) + } + + # Find the singleton counter for the current scoped graph. If it + # doesn't exist, create one. + xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) + if not xla_scope_counter: + xla_scope_counter = _XlaScope(0, 0) + ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) + else: + xla_scope_counter = xla_scope_counter[0] + + if xla_scope_counter.depth == 0: + # If we're at the root xla scope, we can increase the counter so + # future calls to jit_scope use a different scope value. + # If we're already within a scope, we'll be fusing using the scope + # controlled by the parent. + attrs["_XlaScope"] = attr_value_pb2.AttrValue( + s=("jit_scope_%d" % xla_scope_counter.count).encode()) + xla_scope_counter.count += 1 + + xla_scope_counter.depth += 1 + + # pylint: disable=protected-access + with ops.get_default_graph()._attr_scope(attrs): + yield + # pylint: enable=protected-access + + xla_scope_counter.depth -= 1 diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/python/compiler/xla/jit_test.py similarity index 98% rename from tensorflow/contrib/compiler/jit_test.py rename to tensorflow/python/compiler/xla/jit_test.py index 3e631b59094..aeeeb3acf0b 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/python/compiler/xla/jit_test.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for contrib.compiler.jit.""" +"""Tests for python.compiler.xla.jit.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.compiler import jit +from tensorflow.python.compiler.xla import jit from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -44,6 +45,7 @@ def enable_jit_nonstateful(node_def): raise ValueError("Unregistered op being created: %s" % node_def) +@test_util.run_v1_only("b/128927195") class JITTest(test.TestCase): def compute(self, use_jit, compute_fn): @@ -170,6 +172,7 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) +@test_util.run_v1_only("b/128927195") class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): diff --git a/tensorflow/python/compiler/xla/xla.py b/tensorflow/python/compiler/xla/xla.py new file mode 100644 index 00000000000..0d21ed72d0d --- /dev/null +++ b/tensorflow/python/compiler/xla/xla.py @@ -0,0 +1,604 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""xla is an experimental library that provides XLA support APIs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import contextlib +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.jit.ops import xla_ops +from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.distribute import summary_op_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export + +_XLA_COMPILE_ATTR = '_xla_compile_id' +_MAX_WARNING_LINES = 5 + +# Operations that indicate some error in the users graph. For example, XLA +# computation should not have any Placeholder op. +_BLACKLISTED_OPS = set([ + 'Placeholder', +]) + +# XLA doesn't currently support reading of intermediate tensors, thus some ops +# are not supported. +_UNSUPPORTED_OPS = set([ + 'AudioSummary', + 'AudioSummaryV2', + 'HistogramSummary', + 'ImageSummary', + 'MergeSummary', + 'Print', + 'ScalarSummary', + 'TensorSummary', + 'TensorSummaryV2', +]) + + +@tf_export(v1=['xla.experimental.compile']) +def compile(computation, inputs=None): # pylint: disable=redefined-builtin + """Builds an operator that compiles and runs `computation` with XLA. + + Args: + computation: A Python function that builds a computation to apply to the + input. If the function takes n inputs, 'inputs' should be a list of n + tensors. + + `computation` may return a list of operations and tensors. Tensors must + come before operations in the returned list. The return value of + `compile` is a list of tensors corresponding to the tensors from the + output of `computation`. + + All `Operation`s returned from `computation` will be executed when + evaluating any of the returned output tensors. + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. + + Returns: + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. + """ + # pylint: disable=protected-access + return _compile_internal(computation, inputs) + + +class XLACompileContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside an XLA computation cluster. + + THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. + + The primary role of `XLACompileContext` is to mark operators inside a + xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is + a unique name. + + `ControlFlowContext` is used to perform the annotation since it integrates + with Tensorflow constructs like ResourceVariables. For example, if a + `ResourceVariable` is constructed inside a xla.compile() block, the + `ResourceVariable` implementation can use + `with ops.control_dependencies(None)` to build the variable's definition + outside the compiled computation. + """ + + def __init__(self, name, pivot): + """Builds a new XLACompileContext. + + Args: + name: a unique name for the context, used to populate the + `_xla_compile_id` attribute. + pivot: a pivot node. Nodes in the XLACompileContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ + super(XLACompileContext, self).__init__() + self._name = name + self._name_as_bytes = compat.as_bytes(name) + self._unsupported_ops = [] + self._pivot = pivot + + def report_unsupported_operations(self): + if self._unsupported_ops: + op_str = '\n'.join([ + ' %s (%s)' % (op.type, op.name) + for op in self._unsupported_ops[:_MAX_WARNING_LINES] + ]) + logging.warning('%d unsupported operations found: \n%s', + len(self._unsupported_ops), op_str) + if len(self._unsupported_ops) > _MAX_WARNING_LINES: + logging.warning('... and %d more', + len(self._unsupported_ops) - _MAX_WARNING_LINES) + + def _RemoveExternalControlEdges(self, op): + """Remove any external control dependency on this op.""" + internal_control_inputs = [] + external_control_inputs = [] + for x in op.control_inputs: + # pylint: disable=protected-access + is_internal_op = False + ctxt = x._get_control_flow_context() + while ctxt is not None: + if ctxt == self: + is_internal_op = True + break + ctxt = ctxt._outer_context + if is_internal_op: + internal_control_inputs.append(x) + else: + external_control_inputs.append(x) + # pylint: enable=protected-access + # pylint: disable=protected-access + op._remove_all_control_inputs() + op._add_control_inputs(internal_control_inputs) + # pylint: enable=protected-access + return internal_control_inputs, external_control_inputs + + def AddOp(self, op): + """Create op in XLACompileContext and notifies outer context recursively.""" + # pylint: disable=protected-access + if op.type in _BLACKLISTED_OPS: + logging.error( + 'Operation of type %s (%s) is not supported in XLA. Execution will ' + 'fail if this op is used in the graph. ', op.type, op.name) + + # TODO(ycao): Automatically disable summaries instead of reporting them. + if op.type in _UNSUPPORTED_OPS: + self._unsupported_ops.append(op) + + if any(x.dtype._is_ref_dtype for x in op.inputs): + raise NotImplementedError( + 'Non-resource Variables are not supported inside XLA computations ' + '(operator name: %s)' % op.name) + + if _XLA_COMPILE_ATTR in op.node_def.attr: + raise ValueError('XLA compiled computations cannot be nested, (operator ' + 'name: %s)' % op.name) + + op._set_attr( + _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) + + op.graph.prevent_feeding(op) + op.graph.prevent_fetching(op) + + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. An example is when one of op's inputs is + # generated in a different While control flow context. + (internal_control_inputs, + external_control_inputs) = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not internal_control_inputs: + # pylint: disable=protected-access + op._add_control_input(self._pivot) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_control_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_control_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + """Add `val` to the current context and its outer context recursively.""" + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + + result = val + self._values.add(val.name) + if self._outer_context: + result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + + return result + + def AddInnerOp(self, op): + self.AddOp(op) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + @property + def grad_state(self): + # Define the gradient loop state associated with the XLACompileContext to + # be None as the XLACompileContext does not get nested nor does the + # grad_state outside the XLACompileContext affect the graph inside so the + # grad_state should be as if this is the top-level gradient state. + return None + + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False + + +def _compile_internal(computation, inputs=None): + """Builds graph operators that compiles and symbolically executes computation. + + Args: + computation: A Python function that builds the computation to compile and + execute. + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. + + Returns: + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: 1) None output 2) Single + value output 3) Operation-only outputs + Raises: + ValueError: If any element in computation outputs is neither an operations + or a value that can be converted to tensor. + ValueError: If computation outputs is non-flat and contains any Operations. + TypeError: If `inputs` is not a list or tuple. + """ + if inputs is None: + inputs = [] + + if not isinstance(inputs, collections.Sequence): + raise TypeError('inputs must be a list') + + # Flatten inputs. + flat_inputs = nest.flatten(inputs) + # Converts inputs to Tensors. + flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] + + cluster_name = ops.get_default_graph().unique_name('cluster') + pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') + context = XLACompileContext(name=cluster_name, pivot=pivot) + try: + context.Enter() + + # Add identity ops so even unused inputs are 'consumed' by the + # computation. + flat_inputs = [ + array_ops.identity(x, name='input_{}'.format(i)) + for i, x in enumerate(flat_inputs) + ] + + # Re-pack flat_inputs in same structure as 'inputs'. + computation_inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + + # Only resource variables work inside an XLA computation, so turn on + # resource variables for the computation. + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + vscope.set_use_resource(True) + + with _disable_summary_context(): + outputs = computation(*computation_inputs) + + # Restore variable scope after computation. + vscope.set_use_resource(saved_use_resource) + + outputs_is_flat = is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) + + context.ExitResult(output_tensors) + finally: + context.report_unsupported_operations() + context.Exit() + + # When XLA computation returns only operations and no tensors, a NoOp + # dependent on the operations in outputs is returned. Otherwise final + # outputs would be empty and there is no way to trigger returned + # operations. + if not output_tensors: + return control_flow_ops.group(control_deps, name='output_0') + + output_tensors = [ + xla_ops.xla_cluster_output(o, name='output{}'.format(i)) + for i, o in enumerate(output_tensors) + ] + + with ops.control_dependencies(control_deps): + # Wraps the outputs in identity operators that carries control + # dependencies. + output_tensors = [ + array_ops.identity(o, name='output_%d' % i) + for i, o in enumerate(output_tensors) + ] + + # If `computation` returned non-flat output structure, pack output tensors + # back into same structure. + if not outputs_is_flat: + output_tensors = nest.pack_sequence_as( + structure=outputs, flat_sequence=output_tensors) + + return output_tensors + + +def is_flat(outputs): + """Checks if outputs is a flat structure. + + Following structures and values are considered flat: + 1) None + 2) A single object + 3) A list or tuple of Tensors/Operations + + The only structures that this function understands are sequences and + dictionaries. E.g. this means that if outputs contains a single + user-defined Object, it is considered to be flat. Errors are raised later on + if that Object cannot be converted to a Tensor. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + A boolean indicates whether outputs is flat. + """ + # If outputs is a list or tuple, check if it has any nested structure. If + # there is, then outputs is non-flat. + if isinstance(outputs, collections.Sequence): + for o in outputs: + if isinstance(o, collections.Sequence) or isinstance(o, dict): + return False + + # If outputs is a dict, it is non-flat. + if isinstance(outputs, dict): + return False + + # Getting here means either outputs itself is a single non-structured value + # or it is a flat list of single non-structured values. + return True + + +def _postprocess_flat_outputs(outputs): + """Validates flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors and Operations extracted from outputs. + """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that return value of this function always contains + # at least one op that can trigger XlaLaunch node. + outputs += (control_flow_ops.no_op(),) + try: + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be Operations' + ' or convertible to Tensors. Got error: "%s"' % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + 'XLA computation function must return zero or more Tensor values ' + 'followed by zero or more Operations.') + + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else ''): + new_output_tensors.append(array_ops.identity(t)) + + return new_output_tensors, output_operations + + +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs.. + """ + # Convert all non-Operation outputs to Tensors. + new_output_tensors = [] + for o in nest.flatten(outputs): + if isinstance(o, ops.Operation): + raise ValueError( + 'xla.compile does not support Operation as return value in non-flat ' + 'output structure. You can set returned Operations as control ' + 'dependencies of returned Tensors so Operations are triggered when ' + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be ' + 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) + + # Makes sure even pass-through inputs/outputs are touched in compile + # context by creating an Identity node inside compile context. + with ops.device(o.device if o.device else ''): + new_output_tensors.append(array_ops.identity(o)) + + return new_output_tensors, [] + + +@contextlib.contextmanager +def _disable_summary_context(): + """Enters a context where all summary ops are skipped. + + Summaries are not yet supported in xla.compile(). So we provide this context + manager that can skip creating summary ops. This is a temporary workaround due + to XLA not supporting summary ops. + + Yields: + None. + """ + original_skip_summary_func = summary_op_util.skip_summary + summary_op_util.skip_summary = lambda: True + + try: + yield + finally: + summary_op_util.skip_summary = original_skip_summary_func + + +class _CapturedObject(object): + """A placeholder to capture an object.""" + + def __init__(self): + self._object = None + + def capture(self, o): + if self._object: + raise RuntimeError( + 'InternalError: _CapturedObject can capture only once. Please file ' + 'bug.') + + self._object = o + + def get(self): + return self._object + + +def _get_scaffold(captured_scaffold_fn): + """Retrieves the Scaffold from `captured_scaffold_fn`.""" + scaffold_fn = captured_scaffold_fn.get() + + if not scaffold_fn: + return None + + scaffold = scaffold_fn() + if scaffold is None: + raise ValueError( + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') + + return scaffold + + +def check_function_argument_count(func, input_arity, infeed_queue): + """Validate the number of input arguments to an XLA function. + + Args: + func: the Python function that will be called to generate the body of an XLA + computation graph. + input_arity: the number of explicit arguments supplied by the caller. + infeed_queue: if not None, the infeed queue that will supply + additional arguments to the function. + + Returns: + None if function can be called with the supplied number of + arguments, or an error string if it cannot. + """ + def format_error(complaint, quantity): + return '%s %d argument%s' % (complaint, quantity, '' + if quantity == 1 else 's') + + num_args_supplied = input_arity + if infeed_queue is not None: + num_args_supplied += infeed_queue.number_of_tuple_elements + arg_spec = tf_inspect.getargspec(func) + num_func_args = len(arg_spec.args) + if arg_spec.defaults is None: + num_func_defaults = 0 + else: + num_func_defaults = len(arg_spec.defaults) + min_func_args = num_func_args - num_func_defaults + if num_args_supplied < min_func_args: + # The required number of arguments is not enough to call the function. + if num_func_defaults == 0 and arg_spec.varargs is None: + return format_error('exactly', num_func_args) + else: + return format_error('at least', min_func_args) + if arg_spec.varargs is None and num_args_supplied > num_func_args: + # The required number of arguments is too many to call the function. + if num_func_defaults == 0: + return format_error('exactly', num_func_args) + else: + return format_error('at most', num_func_args) + # Reaching here means either + # 1) There are varargs, func can accept any number of arguments greater than + # the minimum. + # 2) Number of supplied arguments falls in range of acceptable argument count + # of func. + return None diff --git a/tensorflow/python/compiler/xla/xla_test.py b/tensorflow/python/compiler/xla/xla_test.py new file mode 100644 index 00000000000..ab5fd429a75 --- /dev/null +++ b/tensorflow/python/compiler/xla/xla_test.py @@ -0,0 +1,267 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for python.compiler.xla.xla.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.python import summary +from tensorflow.python.compiler.xla import xla +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +_TRAIN = model_fn_lib.ModeKeys.TRAIN +_EVAL = model_fn_lib.ModeKeys.EVAL +_EXPECTED_LOSS = 1 +_EXPECTED_FEATURE = 2 +_EXPECTED_LABEL = 3 + + +@test_util.run_v1_only('b/128927195') +class XLACompileContextTest(test.TestCase): + + def create_test_xla_compile_context(self): + computation_name = ops.get_default_graph().unique_name('computation') + pivot = control_flow_ops.no_op(name=computation_name + '/pivot') + return xla.XLACompileContext(name=computation_name, pivot=pivot) + + def test_report_unsupported_operations(self): + """Tests that unsupported operations are detected.""" + context = self.create_test_xla_compile_context() + context.Enter() + dummy_tensor = constant_op.constant(1.1) + audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5) + histogram_summary = summary.histogram('histogram_summary', dummy_tensor) + image_summary = summary.image('image_summary', dummy_tensor) + scalar_summary = summary.scalar('scalar_summary', dummy_tensor) + tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) + summary.merge( + [ + audio_summary, histogram_summary, image_summary, scalar_summary, + tensor_summary + ], + name='merge_summary') + logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op') + context.Exit() + + unsupported_ops_names = [op.name for op in context._unsupported_ops] + self.assertEqual(unsupported_ops_names, [ + u'audio_summary', u'histogram_summary', u'image_summary', + u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary', + u'print_op' + ]) + + def test_resource_variable(self): + """Tests that resource variable usage is allowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=True) + + context = self.create_test_xla_compile_context() + context.Enter() + state_ops.assign(a, a + 1) + context.Exit() + + def test_non_resource_variable_error(self): + """Tests that non-resource variable usage is disallowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=False) + + context = self.create_test_xla_compile_context() + context.Enter() + with self.assertRaisesRegexp( + NotImplementedError, 'Non-resource Variables are not supported inside ' + r'XLA computations \(operator name: Assign\)'): + state_ops.assign(a, a + 1) + context.Exit() + + def test_nested_xla_compile_error(self): + """Tests that nested XLA computation leads to fatal error.""" + context1 = self.create_test_xla_compile_context() + context1.Enter() + + context2 = self.create_test_xla_compile_context() + context2.Enter() + with self.assertRaisesRegexp(ValueError, + 'XLA compiled computations cannot be nested'): + constant_op.constant(1) + context2.Exit() + context1.Exit() + + def test_xla_compile_attr(self): + """Tests that ops are tagged with XLA compile ID attribute.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertIn('_xla_compile_id', op.op.node_def.attr) + + def test_op_without_input(self): + """Tests that ops without inputs depend on pivot correctly.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(context._pivot, op.op.control_inputs) + + def test_external_control_edges(self): + """Tests that external control edges are handled correctly.""" + i = constant_op.constant(1) + op1 = constant_op.constant(1) + + with ops.control_dependencies([op1]): + op2 = constant_op.constant(1) + self.assertIn(op1.op, op2.op.control_inputs) + + def while_body(i): + del i # unused + context = self.create_test_xla_compile_context() + context.Enter() + with ops.control_dependencies([op1]): + op3 = constant_op.constant(1) + context.Exit() + self.assertNotIn(op1.op, op3.op.control_inputs) + return op3 + + control_flow_ops.while_loop( + cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i]) + + def test_op_output_marked_as_seen(self): + """Tests that any op output is marked as seen in context.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(op.name, context._values) + + def testOpIsInContext(self): + """Tests that XLACompileContext is recognized as an XLA context.""" + op1 = constant_op.constant(1) + context = self.create_test_xla_compile_context() + context.Enter() + op2 = constant_op.constant(2) + context.Exit() + self.assertFalse(control_flow_util.IsInXLAContext(op1.op)) + self.assertTrue(control_flow_util.IsInXLAContext(op2.op)) + + def testOpPreventFeeding(self): + """Tests that ops created inside XLACompileContext can not be fed.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_feedable(op.op)) + + def testOpPreventFetching(self): + """Tests that ops created inside XLACompileContext can not be fetched.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_fetchable(op.op)) + + +@test_util.run_v1_only('b/128927195') +class CheckFunctionArgumentCountTest(test.TestCase): + + def testSimple(self): + """Tests that arg checker works for functions with no varargs or defaults. + """ + + def func(x, y, z): + return x + y + z + + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual('exactly 3 arguments', + xla.check_function_argument_count(func, 2, None)) + queue = tpu_feed.InfeedQueue(2) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual('exactly 3 arguments', + xla.check_function_argument_count(func, 2, queue)) + + def testDefaultArgs(self): + """Tests that arg checker works for a function with no varargs.""" + + def func(x, y, z=17): + return x + y + z + + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + self.assertEqual('at most 3 arguments', + xla.check_function_argument_count(func, 4, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + self.assertEqual('at most 3 arguments', + xla.check_function_argument_count(func, 4, queue)) + + def testVarArgs(self): + """Tests that arg checker works for a function with varargs.""" + + def func(x, y, *z): + return x + y + len(z) + + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + + def testVarArgsAndDefaults(self): + """Tests that arg checker works for a function with varargs and defaults.""" + + def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg + return x + y + z + len(q) + + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 5, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 66703c0f99e..9be2b2daf97 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -83,6 +83,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "train/queue_runner/__init__.py", "user_ops/__init__.py", "version/__init__.py", + "xla/__init__.py", + "xla/experimental/__init__.py", # END GENERATED FILES ] diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 462ae112882..c034a59ca2e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -668,6 +668,10 @@ tf_module { name: "version" mtype: "" } + member { + name: "xla" + mtype: "" + } member { name: "zeros_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.xla.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.xla.experimental.pbtxt new file mode 100644 index 00000000000..7b2eda66b5f --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.xla.experimental.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.xla.experimental" +tf_module { + member_method { + name: "compile" + argspec: "args=[\'computation\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "jit_scope" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt new file mode 100644 index 00000000000..9d4d777207e --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.xla.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.xla" +tf_module { + member { + name: "experimental" + mtype: "" + } +} diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 0306cf67dd8..3eddf8d4dd8 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -883,6 +883,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.nn.conv2d_transpose", "tf.test.compute_gradient": "tf.compat.v1.test.compute_gradient", + "tf.xla.experimental.compile": + "tf.compat.v1.xla.experimental.compile", + "tf.xla.experimental.jit_scope": + "tf.compat.v1.xla.experimental.jit_scope", } # pylint: enable=line-too-long diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 9bb16849756..23bca995413 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -1547,6 +1547,17 @@ def _log_prob(self, x): _, _, _, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) + def testXlaExperimental(self): + text = "tf.xla.experimental.jit_scope(0)" + expected_text = "tf.compat.v1.xla.experimental.jit_scope(0)" + _, _, _, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + + text = "tf.xla.experimental.compile(0)" + expected_text = "tf.compat.v1.xla.experimental.compile(0)" + _, _, _, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + class TestUpgradeFiles(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index f556b53dfb0..cd39cb9b1e3 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -65,6 +65,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/pyct/testing:test_modules", "//tensorflow/python/autograph/pyct/common_transformers:common_transformers", + "//tensorflow/python/compiler:compiler", "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata",