Merge JAX and TF bfloat16 numpy extensions.

Some years back, I forked the TF bfloat16 numpy extension to create a JAX version of the same NumPy extension. The TF version has not been actively maintained, whereas the JAX version is substantially more feature-complete (e.g., it implements most of the NumPy ufuncs).

However, having two different NumPy extensions that register the same type causes problems, e.g., if someone loads the (less complete) TF implementation first it takes priority over the (more complete) JAX implementation. Fix this by merging the two implementations and replacing the TF bfloat16 implementation with the JAX version.

The best case would be to go one step further and move the bfloat16 code into its own pip package that can be shared by TF and JAX (and other systems), but we leave this for future work.

A side effect of this change is that calls to numpy.testing.assert_allclose require an explicit cast to a non-bfloat16 type.

PiperOrigin-RevId: 346350783
Change-Id: Ic4d26457f9c9f50ef4c31b4adc3e938101c8e037
This commit is contained in:
Peter Hawkins 2020-12-08 10:04:00 -08:00 committed by TensorFlower Gardener
parent 8206491e82
commit 24ffe9f729
12 changed files with 1458 additions and 2493 deletions

View File

@ -97,13 +97,13 @@ cc_library(
name = "types",
srcs = ["types.cc"],
hdrs = ["types.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":bfloat16",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@ -113,6 +113,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core:lib",
"//tensorflow/python:bfloat16_lib",
"//third_party/py/numpy:headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
@ -158,42 +159,6 @@ cc_library(
],
)
cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
hdrs = ["bfloat16.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core/platform:bfloat16",
"//tensorflow/core/platform:logging",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/strings",
"@pybind11",
],
)
py_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.py"],
main = "bfloat16_test.py",
python_version = "PY3",
tags = ["no_oss"],
deps = [
":xla_client",
":xla_extension",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
] + xla_py_test_deps(),
)
cc_library(
name = "py_client",
srcs = [
@ -206,6 +171,7 @@ cc_library(
"py_client.h",
"py_executable.h",
],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -232,6 +198,7 @@ cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -263,6 +230,7 @@ cc_library(
name = "jax_jit",
srcs = ["jax_jit.cc"],
hdrs = ["jax_jit.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -292,6 +260,7 @@ cc_library(
name = "ops",
srcs = ["ops.cc"],
hdrs = ["ops.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -356,6 +325,7 @@ cc_library(
name = "outfeed_receiver_py",
srcs = ["outfeed_receiver_py.cc"],
hdrs = ["outfeed_receiver_py.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -435,6 +405,7 @@ cc_library(
name = "xla_compiler",
srcs = ["xla_compiler.cc"],
hdrs = ["xla_compiler.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -481,7 +452,6 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "xla_extension",
deps = [
":bfloat16",
":dlpack",
":jax_jit",
":ops",
@ -534,6 +504,7 @@ pybind_extension(
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
# not require Tensorflow.
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/python:bfloat16_lib",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:platform",
] + select({

File diff suppressed because it is too large Load Diff

View File

@ -1,28 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
xla::StatusOr<pybind11::object> Bfloat16Dtype();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_

View File

@ -1,440 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import math
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.xla.python import xla_client
bfloat16 = xla_client.bfloat16
def numpy_assert_allclose(a, b, **kwargs):
a = a.astype(np.float32) if a.dtype == bfloat16 else a
b = b.astype(np.float32) if b.dtype == bfloat16 else b
return np.testing.assert_allclose(a, b, **kwargs)
epsilon = float.fromhex("1.0p-7")
# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
class Bfloat16Test(parameterized.TestCase):
"""Tests the non-numpy Python methods of the bfloat16 type."""
def testRoundTripToFloat(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, float(bfloat16(v)))
def testRoundTripNumpyTypes(self):
for dtype in [np.float16, np.float32, np.float64]:
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
np.testing.assert_equal(
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
def testRoundTripToInt(self):
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
self.assertEqual(v, int(bfloat16(v)))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + dtype.__name__,
"dtype": dtype
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
def testRoundTripToNumpy(self, dtype):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, bfloat16(dtype(v)))
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
if dtype != bfloat16:
np.testing.assert_equal(
np.array(FLOAT_VALUES, dtype),
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
def testStr(self):
self.assertEqual("0", str(bfloat16(0.0)))
self.assertEqual("1", str(bfloat16(1.0)))
self.assertEqual("-3.5", str(bfloat16(-3.5)))
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", str(bfloat16(float("inf"))))
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
self.assertEqual("nan", str(bfloat16(float("nan"))))
def testRepr(self):
self.assertEqual("0", repr(bfloat16(0)))
self.assertEqual("1", repr(bfloat16(1)))
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", repr(bfloat16(float("inf"))))
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
def testHash(self):
self.assertEqual(0, hash(bfloat16(0.0)))
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
# Tests for Python operations
def testNegate(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(-v, float(-bfloat16(v)))
def testAdd(self):
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
# Test type promotion against Numpy scalar values.
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32,
type(bfloat16(3.5) + np.array(2.25, np.float32)))
self.assertEqual(np.float32,
type(np.array(3.5, np.float32) + bfloat16(2.25)))
def testSub(self):
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
np.testing.assert_equal(
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
def testMul(self):
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
def testDiv(self):
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
a = np.isnan(bfloat16(float("nan")))
self.assertTrue(a)
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
a = np.array([bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
b = np.array(
[bfloat16(1.3359375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
UNARY_UFUNCS = [
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]
BINARY_UFUNCS = [
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]
BINARY_PREDICATE_UFUNCS = [
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]
class Bfloat16NumPyTest(parameterized.TestCase):
"""Tests the NumPy integration of the bfloat16 type."""
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)
self.assertEqual("[[1 2 3]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
x = np.array([401408, 7, -32], dtype=np.float32)
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
np.testing.assert_equal(x < y, bx < by)
np.testing.assert_equal(x > y, bx > by)
np.testing.assert_equal(x <= y, bx <= by)
np.testing.assert_equal(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b))
def testCasts(self):
for dtype in [
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
]:
x = np.array([[1, 2, 3]], dtype=dtype)
y = x.astype(bfloat16)
z = y.astype(dtype)
self.assertTrue(np.all(x == y))
self.assertEqual(bfloat16, y.dtype)
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)
def testConformNumpyComplex(self):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
y_np = x.astype(np.float32)
y_tf = x.astype(bfloat16)
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
z_np = y_np.astype(dtype)
z_tf = y_tf.astype(dtype)
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
def testArange(self):
np.testing.assert_equal(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_PREDICATE_UFUNCS))
def testBinaryPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
np.testing.assert_equal(
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1
vals = rng.randn(*shape)
vals = np.where(posinf_flips, np.inf, vals)
vals = np.where(neginf_flips, -np.inf, vals)
vals = np.where(nan_flips, np.nan, vals)
vals = vals.astype(bfloat16)
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
def testDivmod(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
o1, o2 = np.divmod(x, y)
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
numpy_assert_allclose(o1, e1, rtol=1e-2)
numpy_assert_allclose(o2, e2, rtol=1e-2)
def testModf(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
o1, o2 = np.modf(x)
e1, e2 = np.modf(x.astype(np.float32))
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
def testLdexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randint(-50, 50, (1, 7))
numpy_assert_allclose(
np.ldexp(x, y).astype(np.float32),
np.ldexp(x.astype(np.float32), y),
rtol=1e-2,
atol=1e-6)
def testFrexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
mant1, exp1 = np.frexp(x)
mant2, exp2 = np.frexp(x.astype(np.float32))
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
absltest.main()

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"
namespace xla {
@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
case U64:
return py::dtype::of<uint64>();
case BF16: {
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
return py::dtype::from_args(bfloat16);
py::handle bfloat16(tensorflow::Bfloat16Dtype());
return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
}
case F16:
return py::dtype("e"); // PEP 3118 code for "float16
@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
// We requested an array of uint16 since NumPy doesn't know how
// to produce our custom bfloat16 type. Reinterpret the array as bfloat16
// before handing it back to the caller.
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
py::handle bfloat16(tensorflow::Bfloat16Dtype());
bfloat16.inc_ref();
array = py::reinterpret_steal<py::array>(
PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
static_cast<PyTypeObject*>(nullptr)));
}
return array;

View File

@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/ops.h"
@ -59,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
throw std::runtime_error("Unable to initialize Numpy API");
}
CHECK(tensorflow::RegisterNumpyBfloat16());
// Types
py::enum_<PrimitiveType>(m, "PrimitiveType")
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
.value("OPAQUE_TYPE", OPAQUE_TYPE)
.value("TOKEN", TOKEN);
m.def("bfloat16_dtype", Bfloat16Dtype);
m.def("bfloat16_dtype",
[]() { return py::handle(tensorflow::Bfloat16Dtype()); });
// Must be before PyClient.compile.
BuildXlaCompilerSubmodule(m);

View File

@ -387,16 +387,19 @@ cc_library(
],
)
# bfloat16_lib is shared with JAX, and must not depend on any other parts of
# TensorFlow.
# TODO(phawkins): move bfloat16 into its own pip package.
cc_library(
name = "bfloat16_lib",
srcs = ["lib/core/bfloat16.cc"],
hdrs = ["lib/core/bfloat16.h"],
deps = [
":numpy_lib",
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//third_party/eigen3",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
],
)

View File

@ -2579,6 +2579,12 @@ class TensorFlowTestCase(googletest.TestCase):
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
msgs = [msg]
# np.allclose does not always work for our custom bfloat16 extension type
# when type promotions are involved, so we first cast any bfloat16 arrays
# to float32.
a_dtype = a.dtype
a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Adds more details to np.testing.assert_allclose.
#
@ -2602,7 +2608,7 @@ class TensorFlowTestCase(googletest.TestCase):
msgs.append("not close rhs = {}".format(y))
msgs.append("not close dif = {}".format(np.abs(x - y)))
msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
# TODO(xpan): There seems to be a bug:
# tensorflow/compiler/tests:binary_ops_test pass with float32
# nan even though the equal_nan is False by default internally.

File diff suppressed because it is too large Load Diff

View File

@ -20,11 +20,11 @@ limitations under the License.
namespace tensorflow {
// Register the bfloat16 numpy type.
void RegisterNumpyBfloat16();
// Register the bfloat16 numpy type. Returns true on success.
bool RegisterNumpyBfloat16();
// Returns the PyObject for the bfloat16 type.
PyObject* Bfloat16PyType();
// Returns a pointer to the bfloat16 dtype object.
PyObject* Bfloat16Dtype();
// Returns the id number of the bfloat16 numpy type.
int Bfloat16NumpyType();

View File

@ -12,54 +12,80 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import math
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
# pylint: disable=unused-import,g-bad-import-order
from tensorflow.python import _pywrap_bfloat16
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
def float_values():
"""Returns values that should round trip exactly to float and back."""
epsilon = float.fromhex("1.0p-7")
return [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
def numpy_assert_allclose(a, b, **kwargs):
a = a.astype(np.float32) if a.dtype == bfloat16 else a
b = b.astype(np.float32) if b.dtype == bfloat16 else b
return np.testing.assert_allclose(a, b, **kwargs)
class Bfloat16Test(test.TestCase):
epsilon = float.fromhex("1.0p-7")
def _assertFloatIdentical(self, v, w):
if math.isnan(v):
self.assertTrue(math.isnan(w))
else:
self.assertEqual(v, w)
# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
class Bfloat16Test(parameterized.TestCase):
"""Tests the non-numpy Python methods of the bfloat16 type."""
def testRoundTripToFloat(self):
for v in float_values():
self._assertFloatIdentical(v, float(bfloat16(v)))
for v in FLOAT_VALUES:
np.testing.assert_equal(v, float(bfloat16(v)))
def testRoundTripNumpyTypes(self):
for dtype in [np.float16, np.float32, np.float64]:
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
np.testing.assert_equal(
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
def testRoundTripToInt(self):
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
self.assertEqual(v, int(bfloat16(v)))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + dtype.__name__,
"dtype": dtype
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
def testRoundTripToNumpy(self, dtype):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, bfloat16(dtype(v)))
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
if dtype != bfloat16:
np.testing.assert_equal(
np.array(FLOAT_VALUES, dtype),
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
def testStr(self):
self.assertEqual("0", str(bfloat16(0.0)))
self.assertEqual("1", str(bfloat16(1.0)))
@ -70,14 +96,13 @@ class Bfloat16Test(test.TestCase):
self.assertEqual("nan", str(bfloat16(float("nan"))))
def testRepr(self):
self.assertEqual("bfloat16(0)", repr(bfloat16(0)))
self.assertEqual("bfloat16(1)", repr(bfloat16(1)))
self.assertEqual("bfloat16(-3.5)", repr(bfloat16(-3.5)))
self.assertEqual("bfloat16(0.0078125)",
repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("bfloat16(inf)", repr(bfloat16(float("inf"))))
self.assertEqual("bfloat16(-inf)", repr(bfloat16(float("-inf"))))
self.assertEqual("bfloat16(nan)", repr(bfloat16(float("nan"))))
self.assertEqual("0", repr(bfloat16(0)))
self.assertEqual("1", repr(bfloat16(1)))
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", repr(bfloat16(float("inf"))))
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
def testHash(self):
self.assertEqual(0, hash(bfloat16(0.0)))
@ -86,115 +111,166 @@ class Bfloat16Test(test.TestCase):
# Tests for Python operations
def testNegate(self):
for v in float_values():
self._assertFloatIdentical(-v, float(-bfloat16(v)))
for v in FLOAT_VALUES:
np.testing.assert_equal(-v, float(-bfloat16(v)))
def testAdd(self):
self._assertFloatIdentical(0, float(bfloat16(0) + bfloat16(0)))
self._assertFloatIdentical(1, float(bfloat16(1) + bfloat16(0)))
self._assertFloatIdentical(0, float(bfloat16(1) + bfloat16(-1)))
self._assertFloatIdentical(5.5, float(bfloat16(2) + bfloat16(3.5)))
self._assertFloatIdentical(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
self._assertFloatIdentical(float("inf"),
float(bfloat16(float("inf")) + bfloat16(-2.25)))
self._assertFloatIdentical(float("-inf"),
float(bfloat16(float("-inf")) + bfloat16(-2.25)))
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
# Test type promotion against Numpy scalar values.
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32,
type(bfloat16(3.5) + np.array(2.25, np.float32)))
self.assertEqual(np.float32,
type(np.array(3.5, np.float32) + bfloat16(2.25)))
def testSub(self):
self._assertFloatIdentical(0, float(bfloat16(0) - bfloat16(0)))
self._assertFloatIdentical(1, float(bfloat16(1) - bfloat16(0)))
self._assertFloatIdentical(2, float(bfloat16(1) - bfloat16(-1)))
self._assertFloatIdentical(-1.5, float(bfloat16(2) - bfloat16(3.5)))
self._assertFloatIdentical(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
self._assertFloatIdentical(float("-inf"),
float(bfloat16(-2.25) - bfloat16(float("inf"))))
self._assertFloatIdentical(float("inf"),
float(bfloat16(-2.25) - bfloat16(float("-inf"))))
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
np.testing.assert_equal(
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
def testMul(self):
self._assertFloatIdentical(0, float(bfloat16(0) * bfloat16(0)))
self._assertFloatIdentical(0, float(bfloat16(1) * bfloat16(0)))
self._assertFloatIdentical(-1, float(bfloat16(1) * bfloat16(-1)))
self._assertFloatIdentical(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
self._assertFloatIdentical(float("-inf"),
float(bfloat16(float("inf")) * bfloat16(-2.25)))
self._assertFloatIdentical(float("inf"),
float(bfloat16(float("-inf")) * bfloat16(-2.25)))
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
def testDiv(self):
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
self._assertFloatIdentical(float("inf"), float(bfloat16(1) / bfloat16(0)))
self._assertFloatIdentical(-1, float(bfloat16(1) / bfloat16(-1)))
self._assertFloatIdentical(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
self._assertFloatIdentical(float("-inf"),
float(bfloat16(float("inf")) / bfloat16(-2.25)))
self._assertFloatIdentical(float("inf"),
float(bfloat16(float("-inf")) / bfloat16(-2.25)))
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in float_values():
for w in float_values():
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
a = np.isnan(bfloat16(float("nan")))
self.assertTrue(a)
np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
a = np.array(
[bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=dtypes.bfloat16.as_numpy_dtype)
a = np.array([bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
b = np.array(
[bfloat16(1.3359375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=dtypes.bfloat16.as_numpy_dtype)
np.testing.assert_allclose(
dtype=bfloat16)
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
class Bfloat16NumPyTest(test.TestCase):
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
UNARY_UFUNCS = [
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]
BINARY_UFUNCS = [
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]
BINARY_PREDICATE_UFUNCS = [
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]
class Bfloat16NumPyTest(parameterized.TestCase):
"""Tests the NumPy integration of the bfloat16 type."""
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)
self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
self.assertAllEqual(x, x)
self.assertAllClose(x, x)
self.assertEqual("[[1 2 3]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
@ -202,12 +278,12 @@ class Bfloat16NumPyTest(test.TestCase):
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
self.assertAllEqual(x == y, bx == by)
self.assertAllEqual(x != y, bx != by)
self.assertAllEqual(x < y, bx < by)
self.assertAllEqual(x > y, bx > by)
self.assertAllEqual(x <= y, bx <= by)
self.assertAllEqual(x >= y, bx >= by)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
np.testing.assert_equal(x < y, bx < by)
np.testing.assert_equal(x > y, bx > by)
np.testing.assert_equal(x <= y, bx <= by)
np.testing.assert_equal(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
@ -216,8 +292,10 @@ class Bfloat16NumPyTest(test.TestCase):
def testCasts(self):
for dtype in [
np.float16, np.float32, np.float64, np.int32, np.int64,
np.complex64, np.complex128]:
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
]:
x = np.array([[1, 2, 3]], dtype=dtype)
y = x.astype(bfloat16)
z = y.astype(dtype)
@ -231,44 +309,133 @@ class Bfloat16NumPyTest(test.TestCase):
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
y_np = x.astype(np.float32)
y_tf = x.astype(bfloat16)
self.assertAllClose(y_np, y_tf, atol=2e-2)
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
z_np = y_np.astype(dtype)
z_tf = y_tf.astype(dtype)
self.assertAllClose(z_np, z_tf, atol=2e-2)
def testAdd(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
y = np.array([[4, 5, 6]], dtype=bfloat16)
self.assertAllClose(np.array([[5, 7, 9]]), x + y)
def testLogSumExp(self):
x = np.array([[1, 2, 3]], dtype=np.float32)
y = np.array([[4, 5, 6]], dtype=np.float32)
self.assertAllClose(np.logaddexp(x, y),
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
atol=2e-2)
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
def testArange(self):
self.assertAllEqual(
np.testing.assert_equal(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
self.assertAllEqual(
np.testing.assert_equal(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
self.assertAllEqual(
np.testing.assert_equal(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
self.assertAllEqual(
np.testing.assert_equal(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
def testSort(self):
values_to_sort = np.float32(float_values())
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_PREDICATE_UFUNCS))
def testBinaryPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
np.testing.assert_equal(
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1
vals = rng.randn(*shape)
vals = np.where(posinf_flips, np.inf, vals)
vals = np.where(neginf_flips, -np.inf, vals)
vals = np.where(nan_flips, np.nan, vals)
vals = vals.astype(bfloat16)
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
def testDivmod(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
o1, o2 = np.divmod(x, y)
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
numpy_assert_allclose(o1, e1, rtol=1e-2)
numpy_assert_allclose(o2, e2, rtol=1e-2)
def testModf(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
o1, o2 = np.modf(x)
e1, e2 = np.modf(x.astype(np.float32))
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
def testLdexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randint(-50, 50, (1, 7))
numpy_assert_allclose(
np.ldexp(x, y).astype(np.float32),
np.ldexp(x.astype(np.float32), y),
rtol=1e-2,
atol=1e-6)
def testFrexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
mant1, exp1 = np.frexp(x)
mant2, exp2 = np.frexp(x.astype(np.float32))
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
test.main()
absltest.main()

View File

@ -20,5 +20,5 @@ PYBIND11_MODULE(_pywrap_bfloat16, m) {
tensorflow::RegisterNumpyBfloat16();
m.def("TF_bfloat16_type",
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
[] { return pybind11::handle(tensorflow::Bfloat16Dtype()); });
}