Merge JAX and TF bfloat16 numpy extensions.
Some years back, I forked the TF bfloat16 numpy extension to create a JAX version of the same NumPy extension. The TF version has not been actively maintained, whereas the JAX version is substantially more feature-complete (e.g., it implements most of the NumPy ufuncs). However, having two different NumPy extensions that register the same type causes problems, e.g., if someone loads the (less complete) TF implementation first it takes priority over the (more complete) JAX implementation. Fix this by merging the two implementations and replacing the TF bfloat16 implementation with the JAX version. The best case would be to go one step further and move the bfloat16 code into its own pip package that can be shared by TF and JAX (and other systems), but we leave this for future work. A side effect of this change is that calls to numpy.testing.assert_allclose require an explicit cast to a non-bfloat16 type. PiperOrigin-RevId: 346350783 Change-Id: Ic4d26457f9c9f50ef4c31b4adc3e938101c8e037
This commit is contained in:
parent
8206491e82
commit
24ffe9f729
@ -97,13 +97,13 @@ cc_library(
|
||||
name = "types",
|
||||
srcs = ["types.cc"],
|
||||
hdrs = ["types.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":bfloat16",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -113,6 +113,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/python:bfloat16_lib",
|
||||
"//third_party/py/numpy:headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@ -158,42 +159,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bfloat16",
|
||||
srcs = ["bfloat16.cc"],
|
||||
hdrs = ["bfloat16.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core/platform:bfloat16",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/strings",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bfloat16_test",
|
||||
srcs = ["bfloat16_test.py"],
|
||||
main = "bfloat16_test.py",
|
||||
python_version = "PY3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
":xla_extension",
|
||||
"@absl_py//absl/testing:absltest",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
] + xla_py_test_deps(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_client",
|
||||
srcs = [
|
||||
@ -206,6 +171,7 @@ cc_library(
|
||||
"py_client.h",
|
||||
"py_executable.h",
|
||||
],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -232,6 +198,7 @@ cc_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.cc"],
|
||||
hdrs = ["dlpack.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -263,6 +230,7 @@ cc_library(
|
||||
name = "jax_jit",
|
||||
srcs = ["jax_jit.cc"],
|
||||
hdrs = ["jax_jit.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -292,6 +260,7 @@ cc_library(
|
||||
name = "ops",
|
||||
srcs = ["ops.cc"],
|
||||
hdrs = ["ops.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -356,6 +325,7 @@ cc_library(
|
||||
name = "outfeed_receiver_py",
|
||||
srcs = ["outfeed_receiver_py.cc"],
|
||||
hdrs = ["outfeed_receiver_py.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -435,6 +405,7 @@ cc_library(
|
||||
name = "xla_compiler",
|
||||
srcs = ["xla_compiler.cc"],
|
||||
hdrs = ["xla_compiler.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -481,7 +452,6 @@ pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "xla_extension",
|
||||
deps = [
|
||||
":bfloat16",
|
||||
":dlpack",
|
||||
":jax_jit",
|
||||
":ops",
|
||||
@ -534,6 +504,7 @@ pybind_extension(
|
||||
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
|
||||
# not require Tensorflow.
|
||||
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
|
||||
"//tensorflow/python:bfloat16_lib",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
] + select({
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,28 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
xla::StatusOr<pybind11::object> Bfloat16Dtype();
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
|
@ -1,440 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for the bfloat16 Python type."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.xla.python import xla_client
|
||||
|
||||
bfloat16 = xla_client.bfloat16
|
||||
|
||||
|
||||
def numpy_assert_allclose(a, b, **kwargs):
|
||||
a = a.astype(np.float32) if a.dtype == bfloat16 else a
|
||||
b = b.astype(np.float32) if b.dtype == bfloat16 else b
|
||||
return np.testing.assert_allclose(a, b, **kwargs)
|
||||
|
||||
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
|
||||
# Values that should round trip exactly to float and back.
|
||||
FLOAT_VALUES = [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"),
|
||||
float("-inf"),
|
||||
float("nan")
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16Test(parameterized.TestCase):
|
||||
"""Tests the non-numpy Python methods of the bfloat16 type."""
|
||||
|
||||
def testRoundTripToFloat(self):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, float(bfloat16(v)))
|
||||
|
||||
def testRoundTripNumpyTypes(self):
|
||||
for dtype in [np.float16, np.float32, np.float64]:
|
||||
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
|
||||
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
|
||||
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
|
||||
np.testing.assert_equal(
|
||||
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
|
||||
|
||||
def testRoundTripToInt(self):
|
||||
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
|
||||
self.assertEqual(v, int(bfloat16(v)))
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + dtype.__name__,
|
||||
"dtype": dtype
|
||||
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
|
||||
def testRoundTripToNumpy(self, dtype):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, bfloat16(dtype(v)))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
|
||||
if dtype != bfloat16:
|
||||
np.testing.assert_equal(
|
||||
np.array(FLOAT_VALUES, dtype),
|
||||
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
|
||||
|
||||
def testStr(self):
|
||||
self.assertEqual("0", str(bfloat16(0.0)))
|
||||
self.assertEqual("1", str(bfloat16(1.0)))
|
||||
self.assertEqual("-3.5", str(bfloat16(-3.5)))
|
||||
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("inf", str(bfloat16(float("inf"))))
|
||||
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
|
||||
self.assertEqual("nan", str(bfloat16(float("nan"))))
|
||||
|
||||
def testRepr(self):
|
||||
self.assertEqual("0", repr(bfloat16(0)))
|
||||
self.assertEqual("1", repr(bfloat16(1)))
|
||||
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
|
||||
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("inf", repr(bfloat16(float("inf"))))
|
||||
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
|
||||
self.assertEqual("nan", repr(bfloat16(float("nan"))))
|
||||
|
||||
def testHash(self):
|
||||
self.assertEqual(0, hash(bfloat16(0.0)))
|
||||
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
|
||||
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
|
||||
|
||||
# Tests for Python operations
|
||||
def testNegate(self):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(-v, float(-bfloat16(v)))
|
||||
|
||||
def testAdd(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
|
||||
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
|
||||
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
|
||||
|
||||
# Test type promotion against Numpy scalar values.
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
|
||||
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
|
||||
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32,
|
||||
type(bfloat16(3.5) + np.array(2.25, np.float32)))
|
||||
self.assertEqual(np.float32,
|
||||
type(np.array(3.5, np.float32) + bfloat16(2.25)))
|
||||
|
||||
def testSub(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
|
||||
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
|
||||
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
|
||||
|
||||
def testMul(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
|
||||
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
|
||||
|
||||
def testDiv(self):
|
||||
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
|
||||
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
|
||||
|
||||
def testLess(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
|
||||
|
||||
def testLessEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
|
||||
|
||||
def testGreater(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
|
||||
|
||||
def testGreaterEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
|
||||
|
||||
def testEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
|
||||
|
||||
def testNotEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
|
||||
|
||||
def testNan(self):
|
||||
a = np.isnan(bfloat16(float("nan")))
|
||||
self.assertTrue(a)
|
||||
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
|
||||
|
||||
a = np.array([bfloat16(1.34375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=bfloat16)
|
||||
b = np.array(
|
||||
[bfloat16(1.3359375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=bfloat16)
|
||||
numpy_assert_allclose(
|
||||
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(FLOAT_VALUES)
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
|
||||
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
|
||||
|
||||
UNARY_UFUNCS = [
|
||||
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
|
||||
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
|
||||
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
|
||||
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
|
||||
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
|
||||
]
|
||||
|
||||
BINARY_UFUNCS = [
|
||||
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
|
||||
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
|
||||
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
|
||||
]
|
||||
|
||||
BINARY_PREDICATE_UFUNCS = [
|
||||
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
|
||||
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
"""Tests the NumPy integration of the bfloat16 type."""
|
||||
|
||||
def testDtype(self):
|
||||
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
||||
|
||||
def testDeepCopyDoesNotAlterHash(self):
|
||||
# For context, see https://github.com/google/jax/issues/4651. If the hash
|
||||
# value of the type descriptor is not initialized correctly, a deep copy
|
||||
# can change the type hash.
|
||||
dtype = np.dtype(bfloat16)
|
||||
h = hash(dtype)
|
||||
_ = copy.deepcopy(dtype)
|
||||
self.assertEqual(h, hash(dtype))
|
||||
|
||||
def testArray(self):
|
||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||
self.assertEqual(bfloat16, x.dtype)
|
||||
self.assertEqual("[[1 2 3]]", str(x))
|
||||
np.testing.assert_equal(x, x)
|
||||
numpy_assert_allclose(x, x)
|
||||
self.assertTrue((x == x).all())
|
||||
|
||||
def testComparisons(self):
|
||||
x = np.array([401408, 7, -32], dtype=np.float32)
|
||||
bx = x.astype(bfloat16)
|
||||
y = np.array([82432, 7, 0], dtype=np.float32)
|
||||
by = y.astype(bfloat16)
|
||||
np.testing.assert_equal(x == y, bx == by)
|
||||
np.testing.assert_equal(x != y, bx != by)
|
||||
np.testing.assert_equal(x < y, bx < by)
|
||||
np.testing.assert_equal(x > y, bx > by)
|
||||
np.testing.assert_equal(x <= y, bx <= by)
|
||||
np.testing.assert_equal(x >= y, bx >= by)
|
||||
|
||||
def testEqual2(self):
|
||||
a = np.array([401408], bfloat16)
|
||||
b = np.array([82432], bfloat16)
|
||||
self.assertFalse(a.__eq__(b))
|
||||
|
||||
def testCasts(self):
|
||||
for dtype in [
|
||||
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
||||
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
|
||||
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
|
||||
]:
|
||||
x = np.array([[1, 2, 3]], dtype=dtype)
|
||||
y = x.astype(bfloat16)
|
||||
z = y.astype(dtype)
|
||||
self.assertTrue(np.all(x == y))
|
||||
self.assertEqual(bfloat16, y.dtype)
|
||||
self.assertTrue(np.all(x == z))
|
||||
self.assertEqual(dtype, z.dtype)
|
||||
|
||||
def testConformNumpyComplex(self):
|
||||
for dtype in [np.complex64, np.complex128]:
|
||||
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
|
||||
y_np = x.astype(np.float32)
|
||||
y_tf = x.astype(bfloat16)
|
||||
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
|
||||
|
||||
z_np = y_np.astype(dtype)
|
||||
z_tf = y_tf.astype(dtype)
|
||||
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
|
||||
|
||||
def testArange(self):
|
||||
np.testing.assert_equal(
|
||||
np.arange(100, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(100, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-0., -7., -0.25, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-16384., 16384., 64., dtype=bfloat16))
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in UNARY_UFUNCS))
|
||||
def testUnaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_UFUNCS))
|
||||
def testBinaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x, y).astype(np.float32),
|
||||
op(x.astype(np.float32), y.astype(np.float32)),
|
||||
rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_PREDICATE_UFUNCS))
|
||||
def testBinaryPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
np.testing.assert_equal(
|
||||
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
|
||||
def testPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
shape = (3, 7, 10)
|
||||
posinf_flips = rng.rand(*shape) < 0.1
|
||||
neginf_flips = rng.rand(*shape) < 0.1
|
||||
nan_flips = rng.rand(*shape) < 0.1
|
||||
vals = rng.randn(*shape)
|
||||
vals = np.where(posinf_flips, np.inf, vals)
|
||||
vals = np.where(neginf_flips, -np.inf, vals)
|
||||
vals = np.where(nan_flips, np.nan, vals)
|
||||
vals = vals.astype(bfloat16)
|
||||
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
|
||||
|
||||
def testDivmod(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
o1, o2 = np.divmod(x, y)
|
||||
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
|
||||
numpy_assert_allclose(o1, e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2, e2, rtol=1e-2)
|
||||
|
||||
def testModf(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
o1, o2 = np.modf(x)
|
||||
e1, e2 = np.modf(x.astype(np.float32))
|
||||
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
|
||||
|
||||
def testLdexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randint(-50, 50, (1, 7))
|
||||
numpy_assert_allclose(
|
||||
np.ldexp(x, y).astype(np.float32),
|
||||
np.ldexp(x.astype(np.float32), y),
|
||||
rtol=1e-2,
|
||||
atol=1e-6)
|
||||
|
||||
def testFrexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
mant1, exp1 = np.frexp(x)
|
||||
mant2, exp2 = np.frexp(x.astype(np.float32))
|
||||
np.testing.assert_equal(exp1, exp2)
|
||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||
|
||||
def testNextAfter(self):
|
||||
one = np.array(1., dtype=bfloat16)
|
||||
two = np.array(2., dtype=bfloat16)
|
||||
zero = np.array(0., dtype=bfloat16)
|
||||
nan = np.array(np.nan, dtype=bfloat16)
|
||||
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||
smallest_denormal = float.fromhex("1.0p-133")
|
||||
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||
np.testing.assert_equal(
|
||||
np.nextafter(
|
||||
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||
np.nextafter(
|
||||
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/python/bfloat16.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
|
||||
case U64:
|
||||
return py::dtype::of<uint64>();
|
||||
case BF16: {
|
||||
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
|
||||
return py::dtype::from_args(bfloat16);
|
||||
py::handle bfloat16(tensorflow::Bfloat16Dtype());
|
||||
return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
|
||||
}
|
||||
case F16:
|
||||
return py::dtype("e"); // PEP 3118 code for "float16
|
||||
@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
|
||||
// We requested an array of uint16 since NumPy doesn't know how
|
||||
// to produce our custom bfloat16 type. Reinterpret the array as bfloat16
|
||||
// before handing it back to the caller.
|
||||
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
|
||||
py::handle bfloat16(tensorflow::Bfloat16Dtype());
|
||||
bfloat16.inc_ref();
|
||||
array = py::reinterpret_steal<py::array>(
|
||||
PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
|
||||
reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
|
||||
reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
|
||||
static_cast<PyTypeObject*>(nullptr)));
|
||||
}
|
||||
return array;
|
||||
|
@ -40,7 +40,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
|
||||
#include "tensorflow/compiler/xla/python/bfloat16.h"
|
||||
#include "tensorflow/compiler/xla/python/dlpack.h"
|
||||
#include "tensorflow/compiler/xla/python/jax_jit.h"
|
||||
#include "tensorflow/compiler/xla/python/ops.h"
|
||||
@ -59,6 +58,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace xla {
|
||||
@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
throw std::runtime_error("Unable to initialize Numpy API");
|
||||
}
|
||||
|
||||
CHECK(tensorflow::RegisterNumpyBfloat16());
|
||||
|
||||
// Types
|
||||
py::enum_<PrimitiveType>(m, "PrimitiveType")
|
||||
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
|
||||
@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.value("OPAQUE_TYPE", OPAQUE_TYPE)
|
||||
.value("TOKEN", TOKEN);
|
||||
|
||||
m.def("bfloat16_dtype", Bfloat16Dtype);
|
||||
m.def("bfloat16_dtype",
|
||||
[]() { return py::handle(tensorflow::Bfloat16Dtype()); });
|
||||
|
||||
// Must be before PyClient.compile.
|
||||
BuildXlaCompilerSubmodule(m);
|
||||
|
@ -387,16 +387,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# bfloat16_lib is shared with JAX, and must not depend on any other parts of
|
||||
# TensorFlow.
|
||||
# TODO(phawkins): move bfloat16 into its own pip package.
|
||||
cc_library(
|
||||
name = "bfloat16_lib",
|
||||
srcs = ["lib/core/bfloat16.cc"],
|
||||
hdrs = ["lib/core/bfloat16.h"],
|
||||
deps = [
|
||||
":numpy_lib",
|
||||
":safe_ptr",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//third_party/eigen3",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -2579,6 +2579,12 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
|
||||
|
||||
msgs = [msg]
|
||||
# np.allclose does not always work for our custom bfloat16 extension type
|
||||
# when type promotions are involved, so we first cast any bfloat16 arrays
|
||||
# to float32.
|
||||
a_dtype = a.dtype
|
||||
a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
|
||||
b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
|
||||
if not np.allclose(a, b, rtol=rtol, atol=atol):
|
||||
# Adds more details to np.testing.assert_allclose.
|
||||
#
|
||||
@ -2602,7 +2608,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
msgs.append("not close rhs = {}".format(y))
|
||||
msgs.append("not close dif = {}".format(np.abs(x - y)))
|
||||
msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
|
||||
msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
|
||||
msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
|
||||
# TODO(xpan): There seems to be a bug:
|
||||
# tensorflow/compiler/tests:binary_ops_test pass with float32
|
||||
# nan even though the equal_nan is False by default internally.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -20,11 +20,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Register the bfloat16 numpy type.
|
||||
void RegisterNumpyBfloat16();
|
||||
// Register the bfloat16 numpy type. Returns true on success.
|
||||
bool RegisterNumpyBfloat16();
|
||||
|
||||
// Returns the PyObject for the bfloat16 type.
|
||||
PyObject* Bfloat16PyType();
|
||||
// Returns a pointer to the bfloat16 dtype object.
|
||||
PyObject* Bfloat16Dtype();
|
||||
|
||||
// Returns the id number of the bfloat16 numpy type.
|
||||
int Bfloat16NumpyType();
|
||||
|
@ -12,54 +12,80 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Test cases for the bfloat16 Python type."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from tensorflow.python import _pywrap_bfloat16
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||
|
||||
|
||||
def float_values():
|
||||
"""Returns values that should round trip exactly to float and back."""
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
return [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"),
|
||||
float("-inf"),
|
||||
float("nan")
|
||||
]
|
||||
def numpy_assert_allclose(a, b, **kwargs):
|
||||
a = a.astype(np.float32) if a.dtype == bfloat16 else a
|
||||
b = b.astype(np.float32) if b.dtype == bfloat16 else b
|
||||
return np.testing.assert_allclose(a, b, **kwargs)
|
||||
|
||||
|
||||
class Bfloat16Test(test.TestCase):
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
|
||||
def _assertFloatIdentical(self, v, w):
|
||||
if math.isnan(v):
|
||||
self.assertTrue(math.isnan(w))
|
||||
else:
|
||||
self.assertEqual(v, w)
|
||||
# Values that should round trip exactly to float and back.
|
||||
FLOAT_VALUES = [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"),
|
||||
float("-inf"),
|
||||
float("nan")
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16Test(parameterized.TestCase):
|
||||
"""Tests the non-numpy Python methods of the bfloat16 type."""
|
||||
|
||||
def testRoundTripToFloat(self):
|
||||
for v in float_values():
|
||||
self._assertFloatIdentical(v, float(bfloat16(v)))
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, float(bfloat16(v)))
|
||||
|
||||
def testRoundTripNumpyTypes(self):
|
||||
for dtype in [np.float16, np.float32, np.float64]:
|
||||
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
|
||||
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
|
||||
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
|
||||
np.testing.assert_equal(
|
||||
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
|
||||
|
||||
def testRoundTripToInt(self):
|
||||
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
|
||||
self.assertEqual(v, int(bfloat16(v)))
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + dtype.__name__,
|
||||
"dtype": dtype
|
||||
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
|
||||
def testRoundTripToNumpy(self, dtype):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, bfloat16(dtype(v)))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
|
||||
if dtype != bfloat16:
|
||||
np.testing.assert_equal(
|
||||
np.array(FLOAT_VALUES, dtype),
|
||||
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
|
||||
|
||||
def testStr(self):
|
||||
self.assertEqual("0", str(bfloat16(0.0)))
|
||||
self.assertEqual("1", str(bfloat16(1.0)))
|
||||
@ -70,14 +96,13 @@ class Bfloat16Test(test.TestCase):
|
||||
self.assertEqual("nan", str(bfloat16(float("nan"))))
|
||||
|
||||
def testRepr(self):
|
||||
self.assertEqual("bfloat16(0)", repr(bfloat16(0)))
|
||||
self.assertEqual("bfloat16(1)", repr(bfloat16(1)))
|
||||
self.assertEqual("bfloat16(-3.5)", repr(bfloat16(-3.5)))
|
||||
self.assertEqual("bfloat16(0.0078125)",
|
||||
repr(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("bfloat16(inf)", repr(bfloat16(float("inf"))))
|
||||
self.assertEqual("bfloat16(-inf)", repr(bfloat16(float("-inf"))))
|
||||
self.assertEqual("bfloat16(nan)", repr(bfloat16(float("nan"))))
|
||||
self.assertEqual("0", repr(bfloat16(0)))
|
||||
self.assertEqual("1", repr(bfloat16(1)))
|
||||
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
|
||||
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("inf", repr(bfloat16(float("inf"))))
|
||||
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
|
||||
self.assertEqual("nan", repr(bfloat16(float("nan"))))
|
||||
|
||||
def testHash(self):
|
||||
self.assertEqual(0, hash(bfloat16(0.0)))
|
||||
@ -86,115 +111,166 @@ class Bfloat16Test(test.TestCase):
|
||||
|
||||
# Tests for Python operations
|
||||
def testNegate(self):
|
||||
for v in float_values():
|
||||
self._assertFloatIdentical(-v, float(-bfloat16(v)))
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(-v, float(-bfloat16(v)))
|
||||
|
||||
def testAdd(self):
|
||||
self._assertFloatIdentical(0, float(bfloat16(0) + bfloat16(0)))
|
||||
self._assertFloatIdentical(1, float(bfloat16(1) + bfloat16(0)))
|
||||
self._assertFloatIdentical(0, float(bfloat16(1) + bfloat16(-1)))
|
||||
self._assertFloatIdentical(5.5, float(bfloat16(2) + bfloat16(3.5)))
|
||||
self._assertFloatIdentical(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("inf"),
|
||||
float(bfloat16(float("inf")) + bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("-inf"),
|
||||
float(bfloat16(float("-inf")) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
|
||||
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
|
||||
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
|
||||
|
||||
# Test type promotion against Numpy scalar values.
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
|
||||
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
|
||||
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32,
|
||||
type(bfloat16(3.5) + np.array(2.25, np.float32)))
|
||||
self.assertEqual(np.float32,
|
||||
type(np.array(3.5, np.float32) + bfloat16(2.25)))
|
||||
|
||||
def testSub(self):
|
||||
self._assertFloatIdentical(0, float(bfloat16(0) - bfloat16(0)))
|
||||
self._assertFloatIdentical(1, float(bfloat16(1) - bfloat16(0)))
|
||||
self._assertFloatIdentical(2, float(bfloat16(1) - bfloat16(-1)))
|
||||
self._assertFloatIdentical(-1.5, float(bfloat16(2) - bfloat16(3.5)))
|
||||
self._assertFloatIdentical(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("-inf"),
|
||||
float(bfloat16(-2.25) - bfloat16(float("inf"))))
|
||||
self._assertFloatIdentical(float("inf"),
|
||||
float(bfloat16(-2.25) - bfloat16(float("-inf"))))
|
||||
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
|
||||
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
|
||||
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
|
||||
|
||||
def testMul(self):
|
||||
self._assertFloatIdentical(0, float(bfloat16(0) * bfloat16(0)))
|
||||
self._assertFloatIdentical(0, float(bfloat16(1) * bfloat16(0)))
|
||||
self._assertFloatIdentical(-1, float(bfloat16(1) * bfloat16(-1)))
|
||||
self._assertFloatIdentical(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("-inf"),
|
||||
float(bfloat16(float("inf")) * bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("inf"),
|
||||
float(bfloat16(float("-inf")) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
|
||||
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
|
||||
|
||||
def testDiv(self):
|
||||
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
|
||||
self._assertFloatIdentical(float("inf"), float(bfloat16(1) / bfloat16(0)))
|
||||
self._assertFloatIdentical(-1, float(bfloat16(1) / bfloat16(-1)))
|
||||
self._assertFloatIdentical(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
|
||||
self._assertFloatIdentical(float("-inf"),
|
||||
float(bfloat16(float("inf")) / bfloat16(-2.25)))
|
||||
self._assertFloatIdentical(float("inf"),
|
||||
float(bfloat16(float("-inf")) / bfloat16(-2.25)))
|
||||
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
|
||||
|
||||
def testLess(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
|
||||
|
||||
def testLessEqual(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
|
||||
|
||||
def testGreater(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
|
||||
|
||||
def testGreaterEqual(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
|
||||
|
||||
def testEqual(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
|
||||
|
||||
def testNotEqual(self):
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
|
||||
|
||||
def testNan(self):
|
||||
a = np.isnan(bfloat16(float("nan")))
|
||||
self.assertTrue(a)
|
||||
np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
|
||||
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
|
||||
|
||||
a = np.array(
|
||||
[bfloat16(1.34375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=dtypes.bfloat16.as_numpy_dtype)
|
||||
a = np.array([bfloat16(1.34375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=bfloat16)
|
||||
b = np.array(
|
||||
[bfloat16(1.3359375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=dtypes.bfloat16.as_numpy_dtype)
|
||||
np.testing.assert_allclose(
|
||||
dtype=bfloat16)
|
||||
numpy_assert_allclose(
|
||||
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(FLOAT_VALUES)
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
class Bfloat16NumPyTest(test.TestCase):
|
||||
|
||||
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
|
||||
|
||||
UNARY_UFUNCS = [
|
||||
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
|
||||
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
|
||||
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
|
||||
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
|
||||
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
|
||||
]
|
||||
|
||||
BINARY_UFUNCS = [
|
||||
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
|
||||
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
|
||||
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
|
||||
]
|
||||
|
||||
BINARY_PREDICATE_UFUNCS = [
|
||||
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
|
||||
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
"""Tests the NumPy integration of the bfloat16 type."""
|
||||
|
||||
def testDtype(self):
|
||||
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
||||
|
||||
def testDeepCopyDoesNotAlterHash(self):
|
||||
# For context, see https://github.com/google/jax/issues/4651. If the hash
|
||||
# value of the type descriptor is not initialized correctly, a deep copy
|
||||
# can change the type hash.
|
||||
dtype = np.dtype(bfloat16)
|
||||
h = hash(dtype)
|
||||
_ = copy.deepcopy(dtype)
|
||||
self.assertEqual(h, hash(dtype))
|
||||
|
||||
def testArray(self):
|
||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||
self.assertEqual(bfloat16, x.dtype)
|
||||
self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
|
||||
self.assertAllEqual(x, x)
|
||||
self.assertAllClose(x, x)
|
||||
self.assertEqual("[[1 2 3]]", str(x))
|
||||
np.testing.assert_equal(x, x)
|
||||
numpy_assert_allclose(x, x)
|
||||
self.assertTrue((x == x).all())
|
||||
|
||||
def testComparisons(self):
|
||||
@ -202,12 +278,12 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
bx = x.astype(bfloat16)
|
||||
y = np.array([82432, 7, 0], dtype=np.float32)
|
||||
by = y.astype(bfloat16)
|
||||
self.assertAllEqual(x == y, bx == by)
|
||||
self.assertAllEqual(x != y, bx != by)
|
||||
self.assertAllEqual(x < y, bx < by)
|
||||
self.assertAllEqual(x > y, bx > by)
|
||||
self.assertAllEqual(x <= y, bx <= by)
|
||||
self.assertAllEqual(x >= y, bx >= by)
|
||||
np.testing.assert_equal(x == y, bx == by)
|
||||
np.testing.assert_equal(x != y, bx != by)
|
||||
np.testing.assert_equal(x < y, bx < by)
|
||||
np.testing.assert_equal(x > y, bx > by)
|
||||
np.testing.assert_equal(x <= y, bx <= by)
|
||||
np.testing.assert_equal(x >= y, bx >= by)
|
||||
|
||||
def testEqual2(self):
|
||||
a = np.array([401408], bfloat16)
|
||||
@ -216,8 +292,10 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
|
||||
def testCasts(self):
|
||||
for dtype in [
|
||||
np.float16, np.float32, np.float64, np.int32, np.int64,
|
||||
np.complex64, np.complex128]:
|
||||
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
||||
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
|
||||
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
|
||||
]:
|
||||
x = np.array([[1, 2, 3]], dtype=dtype)
|
||||
y = x.astype(bfloat16)
|
||||
z = y.astype(dtype)
|
||||
@ -231,44 +309,133 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
|
||||
y_np = x.astype(np.float32)
|
||||
y_tf = x.astype(bfloat16)
|
||||
self.assertAllClose(y_np, y_tf, atol=2e-2)
|
||||
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
|
||||
|
||||
z_np = y_np.astype(dtype)
|
||||
z_tf = y_tf.astype(dtype)
|
||||
self.assertAllClose(z_np, z_tf, atol=2e-2)
|
||||
|
||||
def testAdd(self):
|
||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||
y = np.array([[4, 5, 6]], dtype=bfloat16)
|
||||
self.assertAllClose(np.array([[5, 7, 9]]), x + y)
|
||||
|
||||
def testLogSumExp(self):
|
||||
x = np.array([[1, 2, 3]], dtype=np.float32)
|
||||
y = np.array([[4, 5, 6]], dtype=np.float32)
|
||||
self.assertAllClose(np.logaddexp(x, y),
|
||||
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
|
||||
atol=2e-2)
|
||||
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
|
||||
|
||||
def testArange(self):
|
||||
self.assertAllEqual(
|
||||
np.testing.assert_equal(
|
||||
np.arange(100, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(100, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.testing.assert_equal(
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.testing.assert_equal(
|
||||
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-0., -7., -0.25, dtype=bfloat16))
|
||||
self.assertAllEqual(
|
||||
np.testing.assert_equal(
|
||||
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-16384., 16384., 64., dtype=bfloat16))
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(float_values())
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in UNARY_UFUNCS))
|
||||
def testUnaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_UFUNCS))
|
||||
def testBinaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x, y).astype(np.float32),
|
||||
op(x.astype(np.float32), y.astype(np.float32)),
|
||||
rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_PREDICATE_UFUNCS))
|
||||
def testBinaryPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
np.testing.assert_equal(
|
||||
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
|
||||
def testPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
shape = (3, 7, 10)
|
||||
posinf_flips = rng.rand(*shape) < 0.1
|
||||
neginf_flips = rng.rand(*shape) < 0.1
|
||||
nan_flips = rng.rand(*shape) < 0.1
|
||||
vals = rng.randn(*shape)
|
||||
vals = np.where(posinf_flips, np.inf, vals)
|
||||
vals = np.where(neginf_flips, -np.inf, vals)
|
||||
vals = np.where(nan_flips, np.nan, vals)
|
||||
vals = vals.astype(bfloat16)
|
||||
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
|
||||
|
||||
def testDivmod(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
o1, o2 = np.divmod(x, y)
|
||||
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
|
||||
numpy_assert_allclose(o1, e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2, e2, rtol=1e-2)
|
||||
|
||||
def testModf(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
o1, o2 = np.modf(x)
|
||||
e1, e2 = np.modf(x.astype(np.float32))
|
||||
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
|
||||
|
||||
def testLdexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randint(-50, 50, (1, 7))
|
||||
numpy_assert_allclose(
|
||||
np.ldexp(x, y).astype(np.float32),
|
||||
np.ldexp(x.astype(np.float32), y),
|
||||
rtol=1e-2,
|
||||
atol=1e-6)
|
||||
|
||||
def testFrexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
mant1, exp1 = np.frexp(x)
|
||||
mant2, exp2 = np.frexp(x.astype(np.float32))
|
||||
np.testing.assert_equal(exp1, exp2)
|
||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||
|
||||
def testNextAfter(self):
|
||||
one = np.array(1., dtype=bfloat16)
|
||||
two = np.array(2., dtype=bfloat16)
|
||||
zero = np.array(0., dtype=bfloat16)
|
||||
nan = np.array(np.nan, dtype=bfloat16)
|
||||
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||
smallest_denormal = float.fromhex("1.0p-133")
|
||||
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||
np.testing.assert_equal(
|
||||
np.nextafter(
|
||||
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||
np.nextafter(
|
||||
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
absltest.main()
|
||||
|
@ -20,5 +20,5 @@ PYBIND11_MODULE(_pywrap_bfloat16, m) {
|
||||
tensorflow::RegisterNumpyBfloat16();
|
||||
|
||||
m.def("TF_bfloat16_type",
|
||||
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
|
||||
[] { return pybind11::handle(tensorflow::Bfloat16Dtype()); });
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user