[TF:XLA] Add a XlaSort operator that directly wraps the Sort HLO.

Merge XLA-specific operator registrations into a single file rather than having many tiny files.

In passing, register a fill function for bfloat16 numpy type; needed for the np.arange() call in the sort unit test.

PiperOrigin-RevId: 201005718
This commit is contained in:
Peter Hawkins 2018-06-18 09:16:09 -07:00 committed by TensorFlower Gardener
parent 3db3e50bb0
commit 8ecf506fb8
13 changed files with 316 additions and 235 deletions

View File

@ -839,6 +839,18 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "sort_ops_test",
size = "small",
srcs = ["sort_ops_test.py"],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
],
)
tf_xla_py_test(
name = "xla_device_test",
size = "small",

View File

@ -0,0 +1,57 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XlaSort."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaSortOpTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
with self.test_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
for arg in args
]
feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
output = op(*placeholders)
result = session.run(output, feeds)
self.assertAllClose(result, expected, rtol=1e-3)
def testSort(self):
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=np.arange(101, dtype=dtype))
if __name__ == "__main__":
test.main()

View File

@ -79,6 +79,7 @@ tf_kernel_library(
"shape_util.cc",
"slice_op.cc",
"softmax_op.cc",
"sort_ops.cc",
"spacetobatch_op.cc",
"spacetodepth_op.cc",
"split_op.cc",

View File

@ -0,0 +1,36 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
class XlaSortOp : public XlaOpKernel {
public:
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* const b = context->builder();
context->SetOutput(0, b->Sort(context->Input(0)));
}
};
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
} // namespace
} // namespace tensorflow

View File

@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
cc_library(
name = "xla_ops",
srcs = [
"dynamic_slice_ops.cc",
"functional_ops.cc",
"reduce_window_op.cc",
"sendrecv_ops.cc",
],
srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
],

View File

@ -1,49 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
.Input("indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.
XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of indices must be rank == 1, with dimension size equal to the rank of `input`.
Handling of out-of-bounds slice indices is implementation-defined.
input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
`input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");
} // namespace tensorflow

View File

@ -1,74 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
.Input("cond: Tcond")
.Input("inputs: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("then_branch: func")
.Attr("else_branch: func")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).
cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function takes 'inputs' and returns a list of tensors,
whose types are the same as what else_branch returns.
else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
} // namespace tensorflow

View File

@ -1,45 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
.Attr("T: numbertype")
.Attr("computation: func")
.Attr("window_dimensions: list(int)")
.Attr("window_strides: list(int)")
.Attr("padding_low: list(int)")
.Attr("padding_high: list(int)")
.Output("output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding_low: the padding to apply at the start of each input dimensions
padding_high: the padding to apply at the end of each input dimension.
)doc");
} // namespace tensorflow

View File

@ -1,61 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("XlaSend")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .
tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return Status::OK();
})
.Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
} // namespace tensorflow

View File

@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
.Input("indices: Tindices")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.
XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of indices must be rank == 1, with dimension size equal to the rank of `input`.
Handling of out-of-bounds slice indices is implementation-defined.
input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
`input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");
// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
.Input("cond: Tcond")
.Input("inputs: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("then_branch: func")
.Attr("else_branch: func")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).
cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function takes 'inputs' and returns a list of tensors,
whose types are the same as what else_branch returns.
else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
TensorShape shape_attr;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
c->set_output(0, s);
return Status::OK();
})
.Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
.Attr("T: numbertype")
.Attr("computation: func")
.Attr("window_dimensions: list(int)")
.Attr("window_strides: list(int)")
.Attr("padding_low: list(int)")
.Attr("padding_high: list(int)")
.Output("output: T")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding_low: the padding to apply at the start of each input dimensions
padding_high: the padding to apply at the end of each input dimension.
)doc");
REGISTER_OP("XlaSend")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .
tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");
REGISTER_OP("XlaSort")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
} // namespace tensorflow

View File

@ -77,4 +77,6 @@ def reduce_window(operand,
recv = gen_xla_ops.xla_recv
send = gen_xla_ops.xla_send
sort = gen_xla_ops.xla_sort
while_loop = gen_xla_ops.xla_while

View File

@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
return x != static_cast<bfloat16>(0);
}
int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
const float start(buffer[0]);
const float delta = static_cast<float>(buffer[1]) - start;
for (npy_intp i = 2; i < length; ++i) {
buffer[i] = static_cast<bfloat16>(start + i * delta);
}
return 0;
}
// NumPy casts
// Performs a NumPy array cast from type 'From' to 'To'.
@ -548,6 +558,7 @@ bool Initialize() {
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);

View File

@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase):
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
atol=2e-2)
def testArange(self):
self.assertAllEqual(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
self.assertAllEqual(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
self.assertAllEqual(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
self.assertAllEqual(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
if __name__ == "__main__":
test.main()