diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 98fab319d6f..af760b54167 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -839,6 +839,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sort_ops_test", + size = "small", + srcs = ["sort_ops_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py new file mode 100644 index 00000000000..5ff40edaa5d --- /dev/null +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XlaSort.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaSortOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + self.assertAllClose(result, expected, rtol=1e-3) + + def testSort(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=dtype) + np.random.shuffle(x) + self._assertOpOutputMatchesExpected( + xla.sort, [x], expected=np.arange(101, dtype=dtype)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index edd2ab6301e..e86b333e4bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -79,6 +79,7 @@ tf_kernel_library( "shape_util.cc", "slice_op.cc", "softmax_op.cc", + "sort_ops.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", "split_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc new file mode 100644 index 00000000000..204ae845821 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaSortOp : public XlaOpKernel { + public: + explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* const b = context->builder(); + context->SetOutput(0, b->Sort(context->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index bb9168fa358..ace6fd1d8ee 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( name = "xla_ops", - srcs = [ - "dynamic_slice_ops.cc", - "functional_ops.cc", - "reduce_window_op.cc", - "sendrecv_ops.cc", - ], + srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", ], diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc deleted file mode 100644 index d6c0edbb889..00000000000 --- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("XlaDynamicUpdateSlice") - .Input("input: T") - .Input("update: T") - .Input("indices: Tindices") - .Output("output: T") - .Attr("T: type") - .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Wraps the XLA DynamicUpdateSlice operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice -. - -XlaDynamicUpdateSlice generates a result which is the value of the `input` -operand, with a slice update overwritten at `indices`. The shape of `update` -determines the shape of the sub-array of the result which is updated. The shape -of indices must be rank == 1, with dimension size equal to the rank of `input`. - -Handling of out-of-bounds slice indices is implementation-defined. - -input: A `Tensor` of type T. -indices: A vector of indices into `input`. Must have length equal to the rank of - `input`. -update: A `Tensor` of type T. Same rank as `input`. -output: A `Tensor` of type T. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc deleted file mode 100644 index 4a669f8e6ea..00000000000 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. -REGISTER_OP("XlaWhile") - .Input("input: T") - .Output("output: T") - .Attr("T: list(type) >= 0") - .Attr("cond: func") - .Attr("body: func") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = input; While (Cond(output)) { output = Body(output) } - -input: A list of input tensors whose types are T. -output: A list of output tensors whose types are T. -cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. -body: A function that takes a list of tensors and returns another - list of tensors. Both lists have the same types as specified by T. -)doc"); - -// TODO(b/37549631) setting the If Op to always be stateful is too -// conservative. -REGISTER_OP("XlaIf") - .Input("cond: Tcond") - .Input("inputs: Tin") - .Output("output: Tout") - .Attr("Tcond: type") - .Attr("then_branch: func") - .Attr("else_branch: func") - .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type) >= 0") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = cond ? then_branch(inputs) : else_branch(inputs). - -cond: A boolean scalar. -inputs: A list of input tensors. -output: A list of tensors returned by either then_branch(inputs) or - else_branch(inputs). The input shapes of the then_branch and - else_branch must match. -then_branch: A function takes 'inputs' and returns a list of tensors, - whose types are the same as what else_branch returns. -else_branch: A function takes 'inputs' and returns a list of tensors. - whose types are the same as what then_branch returns. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc deleted file mode 100644 index d9af982adc0..00000000000 --- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaReduceWindow") - .Input("input: T") - .Input("init_value: T") - .Attr("T: numbertype") - .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") - .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Wraps the XLA ReduceWindow operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . - -input: the input tensor -init_value: a scalar representing the initial value for the reduction -computation: a reducer function to apply -window_dimensions: the shape of the window -window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc deleted file mode 100644 index 7ec7b50e905..00000000000 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaSend") - .Input("tensor: T") - .Attr("T: type") - .Attr("tensor_name: string") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Sends the named tensor to another XLA computation. Wraps the XLA Send operator -documented at - https://www.tensorflow.org/performance/xla/operation_semantics#send . - -tensor: The tensor to send. -tensor_name: A string key that identifies the channel. -)doc"); - -REGISTER_OP("XlaRecv") - .Output("tensor: dtype") - .Attr("dtype: type") - .Attr("tensor_name: string") - .Attr("shape: shape") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - TensorShape shape_attr; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); - shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); - c->set_output(0, s); - return Status::OK(); - }) - .Doc(R"doc( -Receives the named tensor from another XLA computation. Wraps the XLA Recv -operator documented at - https://www.tensorflow.org/performance/xla/operation_semantics#recv . - -tensor: The tensor to receive. -dtype: The type of the tensor. -tensor_name: A string key that identifies the channel. -shape: The shape of the tensor. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc new file mode 100644 index 00000000000..a59c77f5c3a --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . + +tensor: The tensor to receive. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +REGISTER_OP("XlaSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + +tensor: The tensor to send. +tensor_name: A string key that identifies the channel. +)doc"); + +REGISTER_OP("XlaSort") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. + +input: A `Tensor` of type T. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index e5ce65bec95..2fc47dffb8f 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -77,4 +77,6 @@ def reduce_window(operand, recv = gen_xla_ops.xla_recv send = gen_xla_ops.xla_send +sort = gen_xla_ops.xla_sort + while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 77fa2c1f66d..fde3a837702 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) { return x != static_cast(0); } +int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { + bfloat16* const buffer = reinterpret_cast(buffer_raw); + const float start(buffer[0]); + const float delta = static_cast(buffer[1]) - start; + for (npy_intp i = 2; i < length; ++i) { + buffer[i] = static_cast(start + i * delta); + } + return 0; +} + // NumPy casts // Performs a NumPy array cast from type 'From' to 'To'. @@ -548,6 +558,7 @@ bool Initialize() { NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; + NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 09d4b01fa43..bc928cd9e5e 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase): np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)), atol=2e-2) + def testArange(self): + self.assertAllEqual( + np.arange(100, dtype=np.float32).astype(bfloat16), + np.arange(100, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), + np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), + np.arange(-0., -7., -0.25, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), + np.arange(-16384., 16384., 64., dtype=bfloat16)) + if __name__ == "__main__": test.main()