These support mixed operand precision (bf16 @ f32 -> f32), (int8 @ int8 -> int32) and thus diverge significantly from the original Xla{Dot,Conv}. PiperOrigin-RevId: 358860920 Change-Id: I316f1b43268b268ce50528f7aa307182bce304f6
517 lines
17 KiB
Python
517 lines
17 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Experimental library that exposes XLA operations directly in TensorFlow.
|
|
|
|
It is sometimes useful to be able to build HLO programs directly from
|
|
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
|
|
HLO operators as closely as possible.
|
|
|
|
Note: There is no promise of backward or forward compatibility for operators
|
|
defined in this module. This is primarily because the underlying HLO operators
|
|
do not promise backward or forward compatibility.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
|
from tensorflow.core.framework import attr_value_pb2
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import bitwise_ops
|
|
from tensorflow.python.ops import gen_math_ops
|
|
from tensorflow.python.ops import gen_random_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import random_ops
|
|
from tensorflow.python.ops import special_math_ops
|
|
from tensorflow.python.ops.numpy_ops import np_utils
|
|
|
|
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
|
|
# ops include:
|
|
# infeed/outfeed (available via tf.contrib.tpu)
|
|
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
|
|
# conditional
|
|
# gather/scatter
|
|
# collapse
|
|
|
|
# This file reuses builtin names (following XLA's names, so we can call things
|
|
# like xla.max), so we capture the builtin versions here.
|
|
# pylint: disable=redefined-builtin
|
|
# Keep references to the Python builtins under private names; this module
# rebinds `max`, `min` and `slice` to XLA-style operators further down.
_max = max
|
|
_min = min
|
|
_slice = slice # pylint: disable=invalid-name
|
|
|
|
# Direct alias of constant_op.constant.
constant = constant_op.constant
|
|
|
|
# Unary operators.
|
|
|
|
# For most arithmetic operators there is a TensorFlow operator
|
|
# that exactly corresponds to each XLA operator. Rather than defining
|
|
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
|
|
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
|
|
# wrap every HLO operator, because that would allow us to be confident that the
|
|
# semantics match.
|
|
|
|
|
|
def _unary_op(fn):
|
|
"""Wrapper that restricts `fn` to have the correct signature."""
|
|
|
|
def unary_op_wrapper(x, name=None):
|
|
return fn(x, name=name)
|
|
|
|
return unary_op_wrapper
|
|
|
|
|
|
# Unary XLA-style operators. Each name is `_unary_op` applied to a TensorFlow
# op, so every wrapper below takes `(x, name=None)`. Several of these names
# (e.g. `abs`, `round`) intentionally shadow Python builtins.
abs = _unary_op(math_ops.abs)
|
|
# TODO(phawkins): implement clz.
|
|
conj = _unary_op(math_ops.conj)
|
|
cos = _unary_op(math_ops.cos)
|
|
ceil = _unary_op(math_ops.ceil)
|
|
digamma = _unary_op(math_ops.digamma)
|
|
erf = _unary_op(math_ops.erf)
|
|
erfc = _unary_op(math_ops.erfc)
|
|
erfinv = _unary_op(math_ops.erfinv)
|
|
ndtri = _unary_op(math_ops.ndtri)
|
|
exp = _unary_op(math_ops.exp)
|
|
expm1 = _unary_op(math_ops.expm1)
|
|
floor = _unary_op(math_ops.floor)
|
|
imag = _unary_op(math_ops.imag)
|
|
is_finite = _unary_op(math_ops.is_finite)
|
|
lgamma = _unary_op(math_ops.lgamma)
|
|
log = _unary_op(math_ops.log)
|
|
log1p = _unary_op(math_ops.log1p)
|
|
logical_not = _unary_op(math_ops.logical_not)
|
|
neg = _unary_op(math_ops.neg)
|
|
real = _unary_op(math_ops.real)
|
|
# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
|
|
# numbers halfway between two integers.
|
|
round = _unary_op(math_ops.round)
|
|
sin = _unary_op(math_ops.sin)
|
|
sign = _unary_op(math_ops.sign)
|
|
tanh = _unary_op(math_ops.tanh)
|
|
|
|
# Bessel
# (wrapped from special_math_ops rather than math_ops)
|
|
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
|
|
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
|
|
|
|
# Binary operators
|
|
|
|
# The main difference between TensorFlow and XLA binary ops is the broadcasting
|
|
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
|
|
# requires an explicit specification of which dimensions to broadcast if the
|
|
# arguments have different ranks.
|
|
|
|
|
|
def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable invoked as `fn(x, y, name=name)` on the operands returned
      by the XlaBroadcastHelper op.

  Returns:
    A function with signature `(x, y, broadcast_dims=None, name=None)` that
    broadcasts `x` and `y` via XlaBroadcastHelper and then applies `fn`.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` according to `broadcast_dims`, then calls `fn`."""
    # Compare against None explicitly. `broadcast_dims or []` evaluates the
    # truth value of the argument, which raises for multi-element NumPy arrays
    # and for symbolic tensors. An empty sequence still converts to the same
    # empty int64 tensor as the default.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
|
|
|
|
|
|
# Lookup from each signed TF integer dtype to the unsigned dtype of the same
# bit width.
_SIGNED_TO_UNSIGNED_TABLE = dict([
    (dtypes.int8, dtypes.uint8),
    (dtypes.int16, dtypes.uint16),
    (dtypes.int32, dtypes.uint32),
    (dtypes.int64, dtypes.uint64),
])
|
|
|
|
# Lookup from each unsigned TF integer dtype to the signed dtype of the same
# bit width.
_UNSIGNED_TO_SIGNED_TABLE = dict([
    (dtypes.uint8, dtypes.int8),
    (dtypes.uint16, dtypes.int16),
    (dtypes.uint32, dtypes.int32),
    (dtypes.uint64, dtypes.int64),
])
|
|
|
|
|
|
def _shift_right_logical_helper(x, y, name=None):
  """Right-shifts `x` by `y` logically, whatever the signedness of the dtype.

  Operands whose dtype is a key of `_SIGNED_TO_UNSIGNED_TABLE` are cast to the
  matching unsigned dtype before the shift, and the result is cast back.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    The shifted tensor, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not a signed dtype listed in the table: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, unsigned_dtype),
      math_ops.cast(y, unsigned_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
|
|
|
|
|
|
def _shift_right_arithmetic_helper(x, y, name=None):
  """Right-shifts `x` by `y` arithmetically, whatever the signedness of the dtype.

  Operands whose dtype is a key of `_UNSIGNED_TO_SIGNED_TABLE` are cast to the
  matching signed dtype before the shift, and the result is cast back.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    The shifted tensor, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not an unsigned dtype listed in the table: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, signed_dtype),
      math_ops.cast(y, signed_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
|
|
|
|
|
|
# Binary XLA-style operators. Each name is `_broadcasting_binary_op` applied to
# a TensorFlow op (or to one of the shift helpers above), so every wrapper
# below takes `(x, y, broadcast_dims=None, name=None)`. Several of these names
# (e.g. `max`, `min`, `pow`, `complex`) intentionally shadow Python builtins.
add = _broadcasting_binary_op(math_ops.add)
|
|
sub = _broadcasting_binary_op(math_ops.sub)
|
|
mul = _broadcasting_binary_op(math_ops.mul)
|
|
div = _broadcasting_binary_op(math_ops.div)
|
|
rem = _broadcasting_binary_op(gen_math_ops.mod)
|
|
max = _broadcasting_binary_op(math_ops.maximum)
|
|
min = _broadcasting_binary_op(math_ops.minimum)
|
|
atan2 = _broadcasting_binary_op(math_ops.atan2)
|
|
complex = _broadcasting_binary_op(math_ops.complex)
|
|
logical_and = _broadcasting_binary_op(math_ops.logical_and)
|
|
logical_or = _broadcasting_binary_op(math_ops.logical_or)
|
|
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
|
|
eq = _broadcasting_binary_op(math_ops.equal)
|
|
ne = _broadcasting_binary_op(math_ops.not_equal)
|
|
ge = _broadcasting_binary_op(math_ops.greater_equal)
|
|
gt = _broadcasting_binary_op(math_ops.greater)
|
|
le = _broadcasting_binary_op(math_ops.less_equal)
|
|
lt = _broadcasting_binary_op(math_ops.less)
|
|
pow = _broadcasting_binary_op(math_ops.pow)
|
|
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
|
|
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
|
|
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
|
|
|
|
# Special functions taking two operands.
igamma = _broadcasting_binary_op(math_ops.igamma)
|
|
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
|
|
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
|
|
igammac = _broadcasting_binary_op(math_ops.igammac)
|
|
polygamma = _broadcasting_binary_op(math_ops.polygamma)
|
|
zeta = _broadcasting_binary_op(math_ops.zeta)
|
|
|
|
|
|
def _binary_op(fn):
|
|
"""Wrapper that restricts `fn` to have the correct signature."""
|
|
|
|
def binary_op_wrapper(x, y, name=None):
|
|
return fn(x, y, name=name)
|
|
|
|
return binary_op_wrapper
|
|
|
|
|
|
# `transpose` and `rev` forward `(x, y, name=None)` to array_ops.transpose and
# array_ops.reverse respectively.
transpose = _binary_op(array_ops.transpose)
|
|
rev = _binary_op(array_ops.reverse)
|
|
|
|
# Direct alias of array_ops.bitcast.
bitcast_convert_type = array_ops.bitcast
|
|
|
|
|
|
def broadcast(x, dims, name=None):
  """Broadcasts `x` to a shape made of `dims` followed by the shape of `x`.

  Args:
    x: the value to broadcast; converted to a tensor.
    dims: the leading dimension sizes to prepend to the shape of `x`. May be
      empty, in which case the target shape is the shape of `x`.
    name: an optional name for the broadcast op.

  Returns:
    The result of `array_ops.broadcast_to` for the concatenated shape.
  """
  x = ops.convert_to_tensor(x)
  # Pin the dtype so it matches the int32 output of array_ops.shape. Without
  # it, an empty `dims` is inferred as a float32 constant and the concat below
  # fails on mismatched dtypes.
  leading_dims = constant_op.constant(dims, dtype=dtypes.int32)
  shape = array_ops.concat([leading_dims, array_ops.shape(x)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
|
|
|
|
|
|
def clamp(a, x, b, name=None):
  """Returns `min(max(a, x), b)` using this module's `max` and `min` operators.

  Args:
    a: the first operand passed to `max`.
    x: the second operand passed to `max`.
    b: the second operand passed to `min`.
    name: an optional name, passed to both the `max` and the `min` op.

  Returns:
    The result of applying `max` and then `min`.
  """
  raised = max(a, x, name=name)
  return min(raised, b, name=name)
|
|
|
|
|
|
# Direct alias of array_ops.concat.
concatenate = array_ops.concat
|
|
|
|
|
|
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto; it is serialized
      before being handed to the op.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto; when falsy, an empty
      string is passed to the op instead of a serialized proto.
    preferred_element_type: the result `dtype`. When None, it is computed with
      `np_utils.result_type` from the dtypes of `lhs` and `rhs`.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_conv_v2(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      preferred_element_type=preferred_element_type,
      name=name)
|
|
|
|
|
|
# Direct alias of math_ops.cast.
convert_element_type = math_ops.cast
|
|
|
|
|
|
def dot(lhs, rhs, name=None):
  """Computes `math_ops.tensordot(lhs, rhs, axes=1)`.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    name: an optional name for the operator.

  Returns:
    The tensordot result.
  """
  product = math_ops.tensordot(lhs, rhs, axes=1, name=name)
  return product
|
|
|
|
|
|
def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None):
  """Calls the generated XlaDotV2 op wrapper.

  Args:
    lhs: the left operand tensor.
    rhs: the right operand tensor.
    dimension_numbers: a proto object; it is serialized with
      `SerializeToString()` before being handed to the op.
    precision_config: an optional proto object; when falsy, an empty string is
      passed to the op instead of a serialized proto.
    preferred_element_type: the result `dtype`. When None, it is computed with
      `np_utils.result_type` from the dtypes of `lhs` and `rhs`.
    name: an optional name for the operator.

  Returns:
    The output of the XlaDotV2 op.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_dot_v2(
      lhs,
      rhs,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      preferred_element_type=preferred_element_type,
      name=name)
|
|
|
|
|
|
def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Calls the generated XlaSelfAdjointEig op wrapper.

  All four arguments are passed through positionally, in the same order.

  Args:
    a: first argument of the op.
    lower: second argument of the op.
    max_iter: third argument of the op.
    epsilon: fourth argument of the op.

  Returns:
    Whatever the XlaSelfAdjointEig op wrapper returns.
  """
  outputs = gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
  return outputs
|
|
|
|
|
|
def svd(a, max_iter, epsilon, precision_config=None):
  """Calls the generated XlaSvd op wrapper.

  Args:
    a: first argument of the op.
    max_iter: second argument of the op.
    epsilon: third argument of the op.
    precision_config: an optional proto object; it is serialized with
      `SerializeToString()`, or replaced by an empty string when falsy.

  Returns:
    Whatever the XlaSvd op wrapper returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)
|
|
|
|
|
|
# Direct aliases of the generated XLA op wrappers in gen_xla_ops; no argument
# adaptation is performed.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
|
|
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
|
|
einsum = gen_xla_ops.xla_einsum
|
|
|
|
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
|
|
# the XLA-specific pad operator.
|
|
pad = gen_xla_ops.xla_pad
|
|
|
|
|
|
def random_normal(mu, sigma, dims, name=None):
  """Draws normal samples of shape `dims` via `random_ops.random_normal`.

  Args:
    mu: the mean; converted to a tensor, and its dtype is used for the output.
    sigma: the standard deviation.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    The result of `random_ops.random_normal`.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name)
|
|
|
|
|
|
def random_uniform(minval, maxval, dims, name=None):
  """Draws uniform samples of shape `dims` via `random_ops.random_uniform`.

  Args:
    minval: the lower bound; converted to a tensor, and its dtype is used for
      the output.
    maxval: the upper bound.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    The result of `random_ops.random_uniform`.
  """
  lower = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, lower, maxval, dtype=lower.dtype, name=name)
|
|
|
|
|
|
# Direct aliases of the generated XLA op wrappers in gen_xla_ops.
recv = gen_xla_ops.xla_recv
|
|
reduce = gen_xla_ops.xla_reduce
|
|
variadic_reduce = gen_xla_ops.xla_variadic_reduce
|
|
|
|
# Register the XlaVariadicReduce op type through ops.no_gradient.
ops.no_gradient("XlaVariadicReduce")
|
|
|
|
|
|
def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: passed to the op as `base_dilations`. Optional; if omitted,
      defaults to 1 for each window dimension.
    window_dilations: passed to the op as `window_dilations`. Optional; if
      omitted, defaults to 1 for each window dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Each optional argument that is missing (or otherwise falsy) is replaced by
  # a per-window-dimension default.
  if not window_strides:
    window_strides = [1] * len(window_dimensions)
  if not base_dilations:
    base_dilations = [1] * len(window_dimensions)
  if not window_dilations:
    window_dilations = [1] * len(window_dimensions)
  if not padding:
    padding = [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)
|
|
|
|
|
|
# Direct alias of the generated XlaReplicaId op wrapper.
replica_id = gen_xla_ops.xla_replica_id
|
|
|
|
# Set a static bound for the given input value as a hint to Xla compiler,
|
|
# returns the same value.
|
|
# Usage:
|
|
# def f(t, p):
|
|
# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
|
|
# return t[:p] # xla knows the bound of the slice is 3.
|
|
# Direct alias of the generated XlaSetBound op wrapper.
set_bound = gen_xla_ops.xla_set_bound
|
|
|
|
|
|
# Make a static dimension into a xla bounded dynamic dimension. The current
|
|
# static dimension size will become the bound and the second operand becomes the
|
|
# dynamic size of the dimension.
|
|
#
|
|
# This should mostly be used for testing.
|
|
#
|
|
# def f():
|
|
# array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
|
|
# # Tells xla the valid size of the array is 3.
|
|
# dim = 1  # the axis of length 5 in the (1, 5) array above
|
|
# p = xla_set_dynamic_dimension_size(array, dim, 3)
|
|
# assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
|
|
# Direct alias of the generated XlaSetDynamicDimensionSize op wrapper.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
|
|
|
|
|
|
def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape, passed to `array_ops.reshape`.
    dimensions: an optional permutation; when not None, `x` is transposed with
      it before the reshape.
    name: an optional name for the reshape op.

  Returns:
    The reshaped tensor.
  """
  permuted = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(permuted, new_sizes, name=name)
|
|
|
|
|
|
def select(condition, x, y, name=None):
  """Calls `array_ops.where(condition, x, y, name)`.

  Args:
    condition: first argument of `array_ops.where`.
    x: second argument of `array_ops.where`.
    y: third argument of `array_ops.where`.
    name: an optional name, passed positionally as the fourth argument.

  Returns:
    The result of `array_ops.where`.
  """
  chosen = array_ops.where(condition, x, y, name)
  return chosen
|
|
|
|
|
|
# Direct aliases of the generated XlaSelectAndScatter and XlaSend op wrappers.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
|
|
send = gen_xla_ops.xla_send
|
|
|
|
|
|
def slice(x, start_dims, limit_dims, strides):
  """Indexes `x` with one builtin slice object per dimension.

  Args:
    x: the object to index.
    start_dims: the start value of each slice.
    limit_dims: the stop value of each slice.
    strides: the step value of each slice.

  Returns:
    `x` indexed by the tuple of `start:limit:stride` slices. The tuple has as
    many entries as the shortest of the three sequences.
  """
  # `_slice` is the builtin slice type; the module-level name `slice` is this
  # function.
  index = tuple(map(_slice, start_dims, limit_dims, strides))
  return x[index]
|
|
|
|
|
|
# Direct alias of the generated XlaSharding op wrapper.
sharding = gen_xla_ops.xla_sharding
|
|
|
|
|
|
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient function registered for the XlaSharding op.

  Applies an XlaSharding op to `grad` using the `sharding` attribute read from
  the forward op, and also stores that value in the new op's `_XlaSharding`
  attribute.

  Args:
    op: the forward XlaSharding operation.
    grad: the incoming gradient.

  Returns:
    A single-element list holding the sharded gradient.
  """
  forward_sharding = op.get_attr("sharding")
  sharded_grad = gen_xla_ops.xla_sharding(grad, sharding=forward_sharding)
  attr_value = attr_value_pb2.AttrValue(s=forward_sharding)
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding", attr_value)
  return [sharded_grad]
|
|
|
|
|
|
# Direct aliases of the generated XlaSpmdFullToShardShape and
# XlaSpmdShardToFullShape op wrappers.
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
|
|
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
|
|
|
|
|
|
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient function registered for the XlaSpmdFullToShardShape op.

  Applies XlaSpmdShardToFullShape to `grad`, reusing the forward op's
  `manual_sharding` attribute and the static shape of its first input.

  Args:
    op: the forward XlaSpmdFullToShardShape operation.
    grad: the incoming gradient.

  Returns:
    A single-element list holding the converted gradient.
  """
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  full_grad = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad, manual_sharding=manual_sharding, full_shape=full_shape)
  return [full_grad]
|
|
|
|
|
|
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient function registered for the XlaSpmdShardToFullShape op.

  Applies XlaSpmdFullToShardShape to `grad`, reusing the forward op's
  `manual_sharding` attribute.

  Args:
    op: the forward XlaSpmdShardToFullShape operation.
    grad: the incoming gradient.

  Returns:
    A single-element list holding the converted gradient.
  """
  manual_sharding = op.get_attr("manual_sharding")
  shard_grad = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad, manual_sharding=manual_sharding)
  return [shard_grad]
|
|
|
|
|
|
# Direct aliases of the generated XLA op wrappers in gen_xla_ops; no argument
# adaptation is performed.
sort = gen_xla_ops.xla_sort
|
|
key_value_sort = gen_xla_ops.xla_key_value_sort
|
|
variadic_sort = gen_xla_ops.xla_variadic_sort
|
|
while_loop = gen_xla_ops.xla_while
|
|
dequantize = gen_xla_ops.xla_dequantize
|
|
|
|
|
|
def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Calls the generated XlaGather op wrapper.

  Args:
    operand: first positional argument of the op.
    start_indices: second positional argument of the op.
    dimension_numbers: a proto object; it is serialized with
      `SerializeToString()` before being handed to the op.
    slice_sizes: passed to the op as `slice_sizes`.
    indices_are_sorted: passed to the op as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    Whatever the XlaGather op wrapper returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
|
|
|
|
|
|
def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Calls the generated XlaScatter op wrapper.

  Args:
    operand: first positional argument of the op.
    scatter_indices: second positional argument of the op.
    updates: third positional argument of the op.
    update_computation: passed to the op as `update_computation`.
    dimension_numbers: a proto object; it is serialized with
      `SerializeToString()` before being handed to the op.
    indices_are_sorted: passed to the op as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    Whatever the XlaScatter op wrapper returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
|