STT-tensorflow/tensorflow/compiler/tf2xla/ops/xla_ops.cc
A. Unique TensorFlower 9baeb353e1 [TF:XLA] Implement XlaSvdOp
PiperOrigin-RevId: 236936547
2019-03-05 15:59:36 -08:00

511 lines
17 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
if (c->RankKnown(c->input(0))) {
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
} else {
c->set_output(0, c->input(0));
}
return Status::OK();
}
REGISTER_OP("XlaBroadcastHelper")
.Input("lhs: T")
.Input("rhs: T")
.Input("broadcast_dims: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Output("lhs_output: T")
.Output("rhs_output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Helper operator for performing XLA-style broadcasts
Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.
lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
REGISTER_OP("XlaSelfAdjointEig")
.Input("a: T")
.Attr("lower: bool")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Output("w: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).
Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
i=0...N-1.
a: the input tensor.
lower: a boolean specifies whether the calculation is done with the lower
triangular part or the upper triangular part.
max_iter: maximum number of sweep update, i.e., the whole lower triangular
part or upper triangular part based on parameter lower. Heuristically, it has
been argued that approximatly logN sweeps are needed in practice (Ref: Golub &
van Loan "Matrix Computation").
epsilon: the tolerance ratio.
w: The eigenvalues in ascending order, each repeated according to its
multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");
REGISTER_OP("XlaSvd")
.Input("a: T")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Attr("precision_config: string")
.Output("s: T")
.Output("u: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).
Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
a: the input tensor.
max_iter: maximum number of sweep update, i.e., the whole lower triangular
part or upper triangular part based on parameter lower. Heuristically, it has
been argued that approximatly log(min (M, N)) sweeps are needed in practice
(Ref: Golub & van Loan "Matrix Computation").
epsilon: the tolerance ratio.
precision_config: a serialized xla::PrecisionConfig proto.
s: Singular values. The values are sorted in reverse order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");
REGISTER_OP("XlaConv")
.Input("lhs: T")
.Input("rhs: T")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("lhs_dilation: Tindices")
.Input("rhs_dilation: Tindices")
.Input("feature_group_count: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.
lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimensions
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDot")
.Input("lhs: T")
.Input("rhs: T")
.Attr("T: numbertype")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Output("output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.
lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDynamicSlice")
.Input("input: T")
.Input("start_indices: Tindices")
.Input("size_indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.
DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension -- [start, start + size). The shape of start_indices must have rank 1,
with dimension size equal to the rank of operand.
input: A `Tensor` of type T.
start_indices: Rank 1 tensor of N integers containing the starting indices of
the slice for each dimension. Value must be greater than or equal to zero.
start_indices: List of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
must be less than or equal to the size of the dimension to avoid
implementation defined behavior.
)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
.Input("indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.
XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of indices must be rank == 1, with dimension size equal to the rank of `input`.
Handling of out-of-bounds slice indices is implementation-defined.
input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
`input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");
// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
.Input("cond: Tcond")
.Input("inputs: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("then_branch: func")
.Attr("else_branch: func")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).
cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function takes 'inputs' and returns a list of tensors,
whose types are the same as what else_branch returns.
else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
REGISTER_OP("XlaPad")
.Input("input: T")
.Input("padding_value: T")
.Input("padding_low: Tindices")
.Input("padding_high: Tindices")
.Input("padding_interior: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA Pad operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#pad
.
input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimensions
padding_high: the padding to apply at the end of each input dimension.
padding_interior: the padding to apply between each input element.
output: A `Tensor` of type T.
)doc");
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return Status::OK();
})
.Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
REGISTER_OP("XlaReduce")
.Input("input: T")
.Input("init_value: T")
.Attr("T: numbertype")
.Attr("dimensions_to_reduce: list(int)")
.Attr("reducer: func")
.Output("output: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
if (c->RankKnown(c->input(0))) {
int rank = c->Rank(c->input(0));
std::vector<int64> dimensions_to_reduce;
TF_RETURN_IF_ERROR(
c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
std::set<int64> dims_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
auto dim_in_range = [rank](int64 dim) {
return dim >= 0 && dim < rank;
};
if (rank < dimensions_to_reduce.size() ||
dims_set.size() != dimensions_to_reduce.size() ||
!absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
return errors::InvalidArgument(
"Invalid dimensions_to_reduce argument to XlaReduce");
}
c->set_output(
0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
} else {
c->set_output(0, c->input(0));
}
return Status::OK();
})
.Doc(R"doc(
Wraps the XLA Reduce operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
.Input("base_dilations: Tindices")
.Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("computation: func")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimensions
)doc");
REGISTER_OP("XlaSelectAndScatter")
.Input("operand: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("source: T")
.Input("init_value: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("select: func")
.Attr("scatter: func")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.
operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimensions
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");
REGISTER_OP("XlaSend")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .
tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");
REGISTER_OP("XlaSort")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts a tensor. Currently only sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
REGISTER_OP("XlaKeyValueSort")
.Input("keys: K")
.Input("values: V")
.Output("sorted_keys: K")
.Output("sorted_values: V")
.Attr("K: realnumbertype")
.Attr("V: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output(1, c->input(1));
return Status::OK();
})
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts a tensor. Currently only sorts in ascending order are supported.
keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
REGISTER_OP("XlaDequantize")
.Input("input: uint32")
.Output("output: bfloat16")
.Attr("min_range: float")
.Attr("max_range: float")
.Attr("mode: string")
.Attr("transpose_output: bool")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Takes the packed uint32 input and unpacks the input to uint8 to do
Dequantization on deivce.
input: Input tensors whose types is uint32, shape is [d0, ..., dn].
output: Output tensors whose types is bloat16. If transpose_output is true,
output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
is false, output shape is [d0,..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
is faster when input is large and rank of input is higher than 1.
)doc");
} // namespace
} // namespace tensorflow