De-dup tpu/xla.py compiler/xla/xla.py
PiperOrigin-RevId: 241598282
This commit is contained in:
parent
6d6e671b67
commit
ebf22aecde
@ -137,7 +137,6 @@ py_library(
|
|||||||
"tpu_strategy_util.py",
|
"tpu_strategy_util.py",
|
||||||
"tpu_system_metadata.py",
|
"tpu_system_metadata.py",
|
||||||
"training_loop.py",
|
"training_loop.py",
|
||||||
"xla.py",
|
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
@ -164,6 +163,7 @@ py_library(
|
|||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
|
"//tensorflow/python/compiler/xla",
|
||||||
"//tensorflow/python/ops/losses",
|
"//tensorflow/python/ops/losses",
|
||||||
"//tensorflow/python/tpu/profiler",
|
"//tensorflow/python/tpu/profiler",
|
||||||
],
|
],
|
||||||
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||||
from tensorflow.python.compat import compat as api_compat
|
from tensorflow.python.compat import compat as api_compat
|
||||||
|
from tensorflow.python.compiler.xla import xla
|
||||||
from tensorflow.python.framework import device as pydev
|
from tensorflow.python.framework import device as pydev
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.tpu import tpu_function
|
from tensorflow.python.tpu import tpu_function
|
||||||
from tensorflow.python.tpu import xla
|
|
||||||
from tensorflow.python.tpu.ops import tpu_ops
|
from tensorflow.python.tpu.ops import tpu_ops
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
@ -19,12 +19,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.compiler.xla import xla
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.tpu import tensor_tracer
|
from tensorflow.python.tpu import tensor_tracer
|
||||||
from tensorflow.python.tpu import tpu_function
|
from tensorflow.python.tpu import tpu_function
|
||||||
from tensorflow.python.tpu import xla
|
|
||||||
|
|
||||||
|
|
||||||
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||||
|
@ -1,106 +0,0 @@
|
|||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# =============================================================================
|
|
||||||
"""XLA utility functions."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import collections
|
|
||||||
|
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
|
|
||||||
|
|
||||||
def is_flat(outputs):
|
|
||||||
"""Checks if outputs is a flat structure.
|
|
||||||
|
|
||||||
Following structures and values are considered flat:
|
|
||||||
1) None
|
|
||||||
2) A single object
|
|
||||||
3) A list or tuple of Tensors/Operations
|
|
||||||
|
|
||||||
The only structures that this function understands are sequences and
|
|
||||||
dictionaries. E.g. this means that if outputs contains a single
|
|
||||||
user-defined Object, it is considered to be flat. Errors are raised later on
|
|
||||||
if that Object cannot be converted to a Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
outputs: Output from `computation` inside `xla.compile`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A boolean indicates whether outputs is flat.
|
|
||||||
"""
|
|
||||||
# If outputs is a list or tuple, check if it has any nested structure. If
|
|
||||||
# there is, then outputs is non-flat.
|
|
||||||
if isinstance(outputs, collections.Sequence):
|
|
||||||
for o in outputs:
|
|
||||||
if isinstance(o, collections.Sequence) or isinstance(o, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# If outputs is a dict, it is non-flat.
|
|
||||||
if isinstance(outputs, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Getting here means either outputs itself is a single non-structured value
|
|
||||||
# or it is a flat list of single non-structured values.
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def check_function_argument_count(func, input_arity, infeed_queue):
|
|
||||||
"""Validate the number of input arguments to an XLA function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: the Python function that will be called to generate the body of an XLA
|
|
||||||
computation graph.
|
|
||||||
input_arity: the number of explicit arguments supplied by the caller.
|
|
||||||
infeed_queue: if not None, the infeed queue that will supply
|
|
||||||
additional arguments to the function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None if function can be called with the supplied number of
|
|
||||||
arguments, or an error string if it cannot.
|
|
||||||
"""
|
|
||||||
def format_error(complaint, quantity):
|
|
||||||
return '%s %d argument%s' % (complaint, quantity, ''
|
|
||||||
if quantity == 1 else 's')
|
|
||||||
|
|
||||||
num_args_supplied = input_arity
|
|
||||||
if infeed_queue is not None:
|
|
||||||
num_args_supplied += infeed_queue.number_of_tuple_elements
|
|
||||||
arg_spec = tf_inspect.getargspec(func)
|
|
||||||
num_func_args = len(arg_spec.args)
|
|
||||||
if arg_spec.defaults is None:
|
|
||||||
num_func_defaults = 0
|
|
||||||
else:
|
|
||||||
num_func_defaults = len(arg_spec.defaults)
|
|
||||||
min_func_args = num_func_args - num_func_defaults
|
|
||||||
if num_args_supplied < min_func_args:
|
|
||||||
# The required number of arguments is not enough to call the function.
|
|
||||||
if num_func_defaults == 0 and arg_spec.varargs is None:
|
|
||||||
return format_error('exactly', num_func_args)
|
|
||||||
else:
|
|
||||||
return format_error('at least', min_func_args)
|
|
||||||
if arg_spec.varargs is None and num_args_supplied > num_func_args:
|
|
||||||
# The required number of arguments is too many to call the function.
|
|
||||||
if num_func_defaults == 0:
|
|
||||||
return format_error('exactly', num_func_args)
|
|
||||||
else:
|
|
||||||
return format_error('at most', num_func_args)
|
|
||||||
# Reaching here means either
|
|
||||||
# 1) There are varargs, func can accept any number of arguments greater than
|
|
||||||
# the minimum.
|
|
||||||
# 2) Number of supplied arguments falls in range of acceptable argument count
|
|
||||||
# of func.
|
|
||||||
return None
|
|
Loading…
Reference in New Issue
Block a user