De-dup tpu/xla.py compiler/xla/xla.py

PiperOrigin-RevId: 241598282
This commit is contained in:
Yanan Cao 2019-04-02 14:00:16 -07:00 committed by TensorFlower Gardener
parent 6d6e671b67
commit ebf22aecde
4 changed files with 3 additions and 109 deletions

View File

@ -137,7 +137,6 @@ py_library(
"tpu_strategy_util.py",
"tpu_system_metadata.py",
"training_loop.py",
"xla.py",
],
srcs_version = "PY2AND3",
deps = [
@ -164,6 +163,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/compiler/xla",
"//tensorflow/python/ops/losses",
"//tensorflow/python/tpu/profiler",
],

View File

@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python.compat import compat as api_compat
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import xla
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest

View File

@ -19,12 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import xla
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):

View File

@ -1,106 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""XLA utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.util import tf_inspect
def is_flat(outputs):
"""Checks if outputs is a flat structure.
Following structures and values are considered flat:
1) None
2) A single object
3) A list or tuple of Tensors/Operations
The only structures that this function understands are sequences and
dictionaries. E.g. this means that if outputs contains a single
user-defined Object, it is considered to be flat. Errors are raised later on
if that Object cannot be converted to a Tensor.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
A boolean indicates whether outputs is flat.
"""
# If outputs is a list or tuple, check if it has any nested structure. If
# there is, then outputs is non-flat.
if isinstance(outputs, collections.Sequence):
for o in outputs:
if isinstance(o, collections.Sequence) or isinstance(o, dict):
return False
# If outputs is a dict, it is non-flat.
if isinstance(outputs, dict):
return False
# Getting here means either outputs itself is a single non-structured value
# or it is a flat list of single non-structured values.
return True
def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to an XLA function.
Args:
func: the Python function that will be called to generate the body of an XLA
computation graph.
input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
Returns:
None if function can be called with the supplied number of
arguments, or an error string if it cannot.
"""
def format_error(complaint, quantity):
return '%s %d argument%s' % (complaint, quantity, ''
if quantity == 1 else 's')
num_args_supplied = input_arity
if infeed_queue is not None:
num_args_supplied += infeed_queue.number_of_tuple_elements
arg_spec = tf_inspect.getargspec(func)
num_func_args = len(arg_spec.args)
if arg_spec.defaults is None:
num_func_defaults = 0
else:
num_func_defaults = len(arg_spec.defaults)
min_func_args = num_func_args - num_func_defaults
if num_args_supplied < min_func_args:
# The required number of arguments is not enough to call the function.
if num_func_defaults == 0 and arg_spec.varargs is None:
return format_error('exactly', num_func_args)
else:
return format_error('at least', min_func_args)
if arg_spec.varargs is None and num_args_supplied > num_func_args:
# The required number of arguments is too many to call the function.
if num_func_defaults == 0:
return format_error('exactly', num_func_args)
else:
return format_error('at most', num_func_args)
# Reaching here means either
# 1) There are varargs, func can accept any number of arguments greater than
# the minimum.
# 2) Number of supplied arguments falls in range of acceptable argument count
# of func.
return None