[TF:XLA] Add a XlaSort operator that directly wraps the Sort HLO.
Merge XLA-specific operator registrations into a single file rather than having many tiny files. In passing, register a fill function for bfloat16 numpy type; needed for the np.arange() call in the sort unit test. PiperOrigin-RevId: 201005718
This commit is contained in:
parent
3db3e50bb0
commit
8ecf506fb8
@ -839,6 +839,18 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "sort_ops_test",
|
||||
size = "small",
|
||||
srcs = ["sort_ops_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "xla_device_test",
|
||||
size = "small",
|
||||
|
57
tensorflow/compiler/tests/sort_ops_test.py
Normal file
57
tensorflow/compiler/tests/sort_ops_test.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for XlaSort."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaSortOpTest(xla_test.XLATestCase):
|
||||
|
||||
def _assertOpOutputMatchesExpected(self, op, args, expected):
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
placeholders = [
|
||||
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
|
||||
for arg in args
|
||||
]
|
||||
feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
|
||||
output = op(*placeholders)
|
||||
result = session.run(output, feeds)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testSort(self):
|
||||
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
|
||||
if self.device in ["XLA_CPU", "XLA_GPU"]:
|
||||
return
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
x = np.arange(101, dtype=dtype)
|
||||
np.random.shuffle(x)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.sort, [x], expected=np.arange(101, dtype=dtype))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -79,6 +79,7 @@ tf_kernel_library(
|
||||
"shape_util.cc",
|
||||
"slice_op.cc",
|
||||
"softmax_op.cc",
|
||||
"sort_ops.cc",
|
||||
"spacetobatch_op.cc",
|
||||
"spacetodepth_op.cc",
|
||||
"split_op.cc",
|
||||
|
36
tensorflow/compiler/tf2xla/kernels/sort_ops.cc
Normal file
36
tensorflow/compiler/tf2xla/kernels/sort_ops.cc
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaSortOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
xla::XlaBuilder* const b = context->builder();
|
||||
context->SetOutput(0, b->Sort(context->Input(0)));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
||||
cc_library(
|
||||
name = "xla_ops",
|
||||
srcs = [
|
||||
"dynamic_slice_ops.cc",
|
||||
"functional_ops.cc",
|
||||
"reduce_window_op.cc",
|
||||
"sendrecv_ops.cc",
|
||||
],
|
||||
srcs = ["xla_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
|
@ -1,49 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("XlaDynamicUpdateSlice")
|
||||
.Input("input: T")
|
||||
.Input("update: T")
|
||||
.Input("indices: Tindices")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA DynamicUpdateSlice operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
|
||||
.
|
||||
|
||||
XlaDynamicUpdateSlice generates a result which is the value of the `input`
|
||||
operand, with a slice update overwritten at `indices`. The shape of `update`
|
||||
determines the shape of the sub-array of the result which is updated. The shape
|
||||
of indices must be rank == 1, with dimension size equal to the rank of `input`.
|
||||
|
||||
Handling of out-of-bounds slice indices is implementation-defined.
|
||||
|
||||
input: A `Tensor` of type T.
|
||||
indices: A vector of indices into `input`. Must have length equal to the rank of
|
||||
`input`.
|
||||
update: A `Tensor` of type T. Same rank as `input`.
|
||||
output: A `Tensor` of type T.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,74 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TODO(b/37549631) setting the While Op to always be stateful is too
|
||||
// conservative.
|
||||
REGISTER_OP("XlaWhile")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: list(type) >= 0")
|
||||
.Attr("cond: func")
|
||||
.Attr("body: func")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
output = input; While (Cond(output)) { output = Body(output) }
|
||||
|
||||
input: A list of input tensors whose types are T.
|
||||
output: A list of output tensors whose types are T.
|
||||
cond: A function takes 'input' and returns a tensor. If the tensor is
|
||||
a scalar of non-boolean, the scalar is converted to a boolean
|
||||
according to the following rule: if the scalar is a numerical
|
||||
value, non-zero means True and zero means False; if the scalar is
|
||||
a string, non-empty means True and empty means False. If the
|
||||
tensor is not a scalar, non-emptiness means True and False
|
||||
otherwise.
|
||||
body: A function that takes a list of tensors and returns another
|
||||
list of tensors. Both lists have the same types as specified by T.
|
||||
)doc");
|
||||
|
||||
// TODO(b/37549631) setting the If Op to always be stateful is too
|
||||
// conservative.
|
||||
REGISTER_OP("XlaIf")
|
||||
.Input("cond: Tcond")
|
||||
.Input("inputs: Tin")
|
||||
.Output("output: Tout")
|
||||
.Attr("Tcond: type")
|
||||
.Attr("then_branch: func")
|
||||
.Attr("else_branch: func")
|
||||
.Attr("Tin: list(type) >= 0")
|
||||
.Attr("Tout: list(type) >= 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
output = cond ? then_branch(inputs) : else_branch(inputs).
|
||||
|
||||
cond: A boolean scalar.
|
||||
inputs: A list of input tensors.
|
||||
output: A list of tensors returned by either then_branch(inputs) or
|
||||
else_branch(inputs). The input shapes of the then_branch and
|
||||
else_branch must match.
|
||||
then_branch: A function takes 'inputs' and returns a list of tensors,
|
||||
whose types are the same as what else_branch returns.
|
||||
else_branch: A function takes 'inputs' and returns a list of tensors.
|
||||
whose types are the same as what then_branch returns.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,45 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("XlaReduceWindow")
|
||||
.Input("input: T")
|
||||
.Input("init_value: T")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("computation: func")
|
||||
.Attr("window_dimensions: list(int)")
|
||||
.Attr("window_strides: list(int)")
|
||||
.Attr("padding_low: list(int)")
|
||||
.Attr("padding_high: list(int)")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA ReduceWindow operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
|
||||
|
||||
input: the input tensor
|
||||
init_value: a scalar representing the initial value for the reduction
|
||||
computation: a reducer function to apply
|
||||
window_dimensions: the shape of the window
|
||||
window_strides: the inter-window strides
|
||||
padding_low: the padding to apply at the start of each input dimensions
|
||||
padding_high: the padding to apply at the end of each input dimension.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,61 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("XlaSend")
|
||||
.Input("tensor: T")
|
||||
.Attr("T: type")
|
||||
.Attr("tensor_name: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
|
||||
documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#send .
|
||||
|
||||
tensor: The tensor to send.
|
||||
tensor_name: A string key that identifies the channel.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaRecv")
|
||||
.Output("tensor: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("tensor_name: string")
|
||||
.Attr("shape: shape")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
TensorShape shape_attr;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
|
||||
shape_inference::ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
|
||||
c->set_output(0, s);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor from another XLA computation. Wraps the XLA Recv
|
||||
operator documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
|
||||
|
||||
tensor: The tensor to receive.
|
||||
dtype: The type of the tensor.
|
||||
tensor_name: A string key that identifies the channel.
|
||||
shape: The shape of the tensor.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
182
tensorflow/compiler/tf2xla/ops/xla_ops.cc
Normal file
182
tensorflow/compiler/tf2xla/ops/xla_ops.cc
Normal file
@ -0,0 +1,182 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("XlaDynamicUpdateSlice")
|
||||
.Input("input: T")
|
||||
.Input("update: T")
|
||||
.Input("indices: Tindices")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA DynamicUpdateSlice operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
|
||||
.
|
||||
|
||||
XlaDynamicUpdateSlice generates a result which is the value of the `input`
|
||||
operand, with a slice update overwritten at `indices`. The shape of `update`
|
||||
determines the shape of the sub-array of the result which is updated. The shape
|
||||
of indices must be rank == 1, with dimension size equal to the rank of `input`.
|
||||
|
||||
Handling of out-of-bounds slice indices is implementation-defined.
|
||||
|
||||
input: A `Tensor` of type T.
|
||||
indices: A vector of indices into `input`. Must have length equal to the rank of
|
||||
`input`.
|
||||
update: A `Tensor` of type T. Same rank as `input`.
|
||||
output: A `Tensor` of type T.
|
||||
)doc");
|
||||
|
||||
// TODO(b/37549631) setting the If Op to always be stateful is too
|
||||
// conservative.
|
||||
REGISTER_OP("XlaIf")
|
||||
.Input("cond: Tcond")
|
||||
.Input("inputs: Tin")
|
||||
.Output("output: Tout")
|
||||
.Attr("Tcond: type")
|
||||
.Attr("then_branch: func")
|
||||
.Attr("else_branch: func")
|
||||
.Attr("Tin: list(type) >= 0")
|
||||
.Attr("Tout: list(type) >= 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
output = cond ? then_branch(inputs) : else_branch(inputs).
|
||||
|
||||
cond: A boolean scalar.
|
||||
inputs: A list of input tensors.
|
||||
output: A list of tensors returned by either then_branch(inputs) or
|
||||
else_branch(inputs). The input shapes of the then_branch and
|
||||
else_branch must match.
|
||||
then_branch: A function takes 'inputs' and returns a list of tensors,
|
||||
whose types are the same as what else_branch returns.
|
||||
else_branch: A function takes 'inputs' and returns a list of tensors.
|
||||
whose types are the same as what then_branch returns.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaRecv")
|
||||
.Output("tensor: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("tensor_name: string")
|
||||
.Attr("shape: shape")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
TensorShape shape_attr;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
|
||||
shape_inference::ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
|
||||
c->set_output(0, s);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor from another XLA computation. Wraps the XLA Recv
|
||||
operator documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#recv .
|
||||
|
||||
tensor: The tensor to receive.
|
||||
dtype: The type of the tensor.
|
||||
tensor_name: A string key that identifies the channel.
|
||||
shape: The shape of the tensor.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaReduceWindow")
|
||||
.Input("input: T")
|
||||
.Input("init_value: T")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("computation: func")
|
||||
.Attr("window_dimensions: list(int)")
|
||||
.Attr("window_strides: list(int)")
|
||||
.Attr("padding_low: list(int)")
|
||||
.Attr("padding_high: list(int)")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA ReduceWindow operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
|
||||
|
||||
input: the input tensor
|
||||
init_value: a scalar representing the initial value for the reduction
|
||||
computation: a reducer function to apply
|
||||
window_dimensions: the shape of the window
|
||||
window_strides: the inter-window strides
|
||||
padding_low: the padding to apply at the start of each input dimensions
|
||||
padding_high: the padding to apply at the end of each input dimension.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaSend")
|
||||
.Input("tensor: T")
|
||||
.Attr("T: type")
|
||||
.Attr("tensor_name: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
|
||||
documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#send .
|
||||
|
||||
tensor: The tensor to send.
|
||||
tensor_name: A string key that identifies the channel.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaSort")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA Sort operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#sort
|
||||
.
|
||||
|
||||
Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
|
||||
|
||||
input: A `Tensor` of type T.
|
||||
output: A `Tensor` of type T.
|
||||
)doc");
|
||||
|
||||
// TODO(b/37549631) setting the While Op to always be stateful is too
|
||||
// conservative.
|
||||
REGISTER_OP("XlaWhile")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: list(type) >= 0")
|
||||
.Attr("cond: func")
|
||||
.Attr("body: func")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
output = input; While (Cond(output)) { output = Body(output) }
|
||||
|
||||
input: A list of input tensors whose types are T.
|
||||
output: A list of output tensors whose types are T.
|
||||
cond: A function takes 'input' and returns a tensor. If the tensor is
|
||||
a scalar of non-boolean, the scalar is converted to a boolean
|
||||
according to the following rule: if the scalar is a numerical
|
||||
value, non-zero means True and zero means False; if the scalar is
|
||||
a string, non-empty means True and empty means False. If the
|
||||
tensor is not a scalar, non-emptiness means True and False
|
||||
otherwise.
|
||||
body: A function that takes a list of tensors and returns another
|
||||
list of tensors. Both lists have the same types as specified by T.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -77,4 +77,6 @@ def reduce_window(operand,
|
||||
recv = gen_xla_ops.xla_recv
|
||||
send = gen_xla_ops.xla_send
|
||||
|
||||
sort = gen_xla_ops.xla_sort
|
||||
|
||||
while_loop = gen_xla_ops.xla_while
|
||||
|
@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
|
||||
return x != static_cast<bfloat16>(0);
|
||||
}
|
||||
|
||||
int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
|
||||
bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
|
||||
const float start(buffer[0]);
|
||||
const float delta = static_cast<float>(buffer[1]) - start;
|
||||
for (npy_intp i = 2; i < length; ++i) {
|
||||
buffer[i] = static_cast<bfloat16>(start + i * delta);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// NumPy casts
|
||||
|
||||
// Performs a NumPy array cast from type 'From' to 'To'.
|
||||
@ -548,6 +558,7 @@ bool Initialize() {
|
||||
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
|
||||
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
|
||||
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
|
||||
NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
|
||||
|
||||
Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
|
||||
npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
|
||||
|
@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
|
||||
atol=2e-2)
|
||||
|
||||
def testArange(self):
|
||||
self.assertAllEqual(
|
||||
np.arange(100, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(100, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-0., -7., -0.25, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-16384., 16384., 64., dtype=bfloat16))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user