diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index db27dd79062..ff82714ce3e 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -137,7 +137,6 @@ py_library(
         "tpu_strategy_util.py",
         "tpu_system_metadata.py",
         "training_loop.py",
-        "xla.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -164,6 +163,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/compiler/xla",
         "//tensorflow/python/ops/losses",
         "//tensorflow/python/tpu/profiler",
     ],
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index 55273a5203e..d5d7ea266da 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -25,6 +25,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
 from tensorflow.python.compat import compat as api_compat
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import tpu_function
-from tensorflow.python.tpu import xla
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
diff --git a/tensorflow/python/tpu/training_loop.py b/tensorflow/python/tpu/training_loop.py
index cffeb7e915a..06c84e56416 100644
--- a/tensorflow/python/tpu/training_loop.py
+++ b/tensorflow/python/tpu/training_loop.py
@@ -19,12 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.tpu import tensor_tracer
 from tensorflow.python.tpu import tpu_function
-from tensorflow.python.tpu import xla
 
 
 def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
diff --git a/tensorflow/python/tpu/xla.py b/tensorflow/python/tpu/xla.py
deleted file mode 100644
index 58476fae3d1..00000000000
--- a/tensorflow/python/tpu/xla.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-"""XLA utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-
-from tensorflow.python.util import tf_inspect
-
-
-def is_flat(outputs):
-  """Checks if outputs is a flat structure.
-
-  Following structures and values are considered flat:
-  1) None
-  2) A single object
-  3) A list or tuple of Tensors/Operations
-
-  The only structures that this function understands are sequences and
-  dictionaries. E.g. this means that if outputs contains a single
-  user-defined Object, it is considered to be flat. Errors are raised later on
-  if that Object cannot be converted to a Tensor.
-
-  Args:
-    outputs: Output from `computation` inside `xla.compile`.
-
-  Returns:
-    A boolean indicates whether outputs is flat.
-  """
-  # If outputs is a list or tuple, check if it has any nested structure. If
-  # there is, then outputs is non-flat.
-  if isinstance(outputs, collections.Sequence):
-    for o in outputs:
-      if isinstance(o, collections.Sequence) or isinstance(o, dict):
-        return False
-
-  # If outputs is a dict, it is non-flat.
-  if isinstance(outputs, dict):
-    return False
-
-  # Getting here means either outputs itself is a single non-structured value
-  # or it is a flat list of single non-structured values.
-  return True
-
-
-def check_function_argument_count(func, input_arity, infeed_queue):
-  """Validate the number of input arguments to an XLA function.
-
-  Args:
-    func: the Python function that will be called to generate the body of an XLA
-      computation graph.
-    input_arity: the number of explicit arguments supplied by the caller.
-    infeed_queue: if not None, the infeed queue that will supply
-      additional arguments to the function.
-
-  Returns:
-    None if function can be called with the supplied number of
-      arguments, or an error string if it cannot.
-  """
-  def format_error(complaint, quantity):
-    return '%s %d argument%s' % (complaint, quantity, ''
-                                 if quantity == 1 else 's')
-
-  num_args_supplied = input_arity
-  if infeed_queue is not None:
-    num_args_supplied += infeed_queue.number_of_tuple_elements
-  arg_spec = tf_inspect.getargspec(func)
-  num_func_args = len(arg_spec.args)
-  if arg_spec.defaults is None:
-    num_func_defaults = 0
-  else:
-    num_func_defaults = len(arg_spec.defaults)
-  min_func_args = num_func_args - num_func_defaults
-  if num_args_supplied < min_func_args:
-    # The required number of arguments is not enough to call the function.
-    if num_func_defaults == 0 and arg_spec.varargs is None:
-      return format_error('exactly', num_func_args)
-    else:
-      return format_error('at least', min_func_args)
-  if arg_spec.varargs is None and num_args_supplied > num_func_args:
-    # The required number of arguments is too many to call the function.
-    if num_func_defaults == 0:
-      return format_error('exactly', num_func_args)
-    else:
-      return format_error('at most', num_func_args)
-  # Reaching here means either
-  # 1) There are varargs, func can accept any number of arguments greater than
-  #    the minimum.
-  # 2) Number of supplied arguments falls in range of acceptable argument count
-  #    of func.
-  return None
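Usage note (not part of the diff): a minimal sketch of how callers pick up the relocated helpers, assuming the module at the new path exposes the same `is_flat` and `check_function_argument_count` functions that the deleted tensorflow/python/tpu/xla.py defined. The import path is taken from the updated imports in tpu.py and training_loop.py above; the `tower` function is a hypothetical computation body used only to exercise the argument-count check.

    # Import from the compiler package instead of tensorflow.python.tpu.
    from tensorflow.python.compiler.xla import xla

    def tower(a, b):  # hypothetical two-argument computation body
      return a + b, a - b

    # is_flat() reports whether outputs have no nested structure:
    # a plain tuple of values is flat, a dict is not.
    assert xla.is_flat((1, 2))
    assert not xla.is_flat({'sum': 3})

    # check_function_argument_count() returns None when the supplied arity can
    # call the function, otherwise an error string such as 'exactly 2 arguments'.
    assert xla.check_function_argument_count(tower, 2, None) is None
    assert xla.check_function_argument_count(tower, 3, None) == 'exactly 2 arguments'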