De-dup tpu/xla.py compiler/xla/xla.py
PiperOrigin-RevId: 241598282
This commit is contained in:
parent
6d6e671b67
commit
ebf22aecde
tensorflow/python/tpu
@ -137,7 +137,6 @@ py_library(
|
||||
"tpu_strategy_util.py",
|
||||
"tpu_system_metadata.py",
|
||||
"training_loop.py",
|
||||
"xla.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
@ -164,6 +163,7 @@ py_library(
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/compiler/xla",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/tpu/profiler",
|
||||
],
|
||||
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||
from tensorflow.python.compat import compat as api_compat
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
from tensorflow.python.tpu import xla
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
@ -19,12 +19,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.tpu import tensor_tracer
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
from tensorflow.python.tpu import xla
|
||||
|
||||
|
||||
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||
|
@ -1,106 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""XLA utility functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
def is_flat(outputs):
|
||||
"""Checks if outputs is a flat structure.
|
||||
|
||||
Following structures and values are considered flat:
|
||||
1) None
|
||||
2) A single object
|
||||
3) A list or tuple of Tensors/Operations
|
||||
|
||||
The only structures that this function understands are sequences and
|
||||
dictionaries. E.g. this means that if outputs contains a single
|
||||
user-defined Object, it is considered to be flat. Errors are raised later on
|
||||
if that Object cannot be converted to a Tensor.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
|
||||
Returns:
|
||||
A boolean indicates whether outputs is flat.
|
||||
"""
|
||||
# If outputs is a list or tuple, check if it has any nested structure. If
|
||||
# there is, then outputs is non-flat.
|
||||
if isinstance(outputs, collections.Sequence):
|
||||
for o in outputs:
|
||||
if isinstance(o, collections.Sequence) or isinstance(o, dict):
|
||||
return False
|
||||
|
||||
# If outputs is a dict, it is non-flat.
|
||||
if isinstance(outputs, dict):
|
||||
return False
|
||||
|
||||
# Getting here means either outputs itself is a single non-structured value
|
||||
# or it is a flat list of single non-structured values.
|
||||
return True
|
||||
|
||||
|
||||
def check_function_argument_count(func, input_arity, infeed_queue):
|
||||
"""Validate the number of input arguments to an XLA function.
|
||||
|
||||
Args:
|
||||
func: the Python function that will be called to generate the body of an XLA
|
||||
computation graph.
|
||||
input_arity: the number of explicit arguments supplied by the caller.
|
||||
infeed_queue: if not None, the infeed queue that will supply
|
||||
additional arguments to the function.
|
||||
|
||||
Returns:
|
||||
None if function can be called with the supplied number of
|
||||
arguments, or an error string if it cannot.
|
||||
"""
|
||||
def format_error(complaint, quantity):
|
||||
return '%s %d argument%s' % (complaint, quantity, ''
|
||||
if quantity == 1 else 's')
|
||||
|
||||
num_args_supplied = input_arity
|
||||
if infeed_queue is not None:
|
||||
num_args_supplied += infeed_queue.number_of_tuple_elements
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
num_func_args = len(arg_spec.args)
|
||||
if arg_spec.defaults is None:
|
||||
num_func_defaults = 0
|
||||
else:
|
||||
num_func_defaults = len(arg_spec.defaults)
|
||||
min_func_args = num_func_args - num_func_defaults
|
||||
if num_args_supplied < min_func_args:
|
||||
# The required number of arguments is not enough to call the function.
|
||||
if num_func_defaults == 0 and arg_spec.varargs is None:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at least', min_func_args)
|
||||
if arg_spec.varargs is None and num_args_supplied > num_func_args:
|
||||
# The required number of arguments is too many to call the function.
|
||||
if num_func_defaults == 0:
|
||||
return format_error('exactly', num_func_args)
|
||||
else:
|
||||
return format_error('at most', num_func_args)
|
||||
# Reaching here means either
|
||||
# 1) There are varargs, func can accept any number of arguments greater than
|
||||
# the minimum.
|
||||
# 2) Number of supplied arguments falls in range of acceptable argument count
|
||||
# of func.
|
||||
return None
|
Loading…
Reference in New Issue
Block a user