Justin Lebar aaade421b3 Switch tensorflow::bit_cast -> absl::bit_cast.
PiperOrigin-RevId: 219228749
2018-10-29 18:37:33 -07:00

549 lines
18 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
// Test fixture shared by all ConvertElementType tests in this file.
class ConvertTest : public ClientLibraryTestBase {
 public:
  // Forwards `platform` to ClientLibraryTestBase and registers the "algsimp"
  // and "inline" HLO passes as disabled in the debug options.
  // NOTE(review): presumably this keeps the convert ops under test from being
  // simplified away before execution -- confirm.
  explicit ConvertTest(se::Platform* platform = nullptr)
      : ClientLibraryTestBase(platform) {
    auto* debug_options = mutable_debug_options();
    debug_options->add_xla_disable_hlo_passes("algsimp");
    debug_options->add_xla_disable_hlo_passes("inline");
  }
};
// S32 -> S32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
  const std::vector<int32> expected{42, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), S32);
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
// S32 -> U32 on small non-negative values: the numeric values are preserved.
TEST_F(ConvertTest, ConvertR1S32ToR1U32) {
  const std::vector<uint32> expected{42, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
// S32 -> PRED: zero becomes false; any other value (positive or negative)
// becomes true.
TEST_F(ConvertTest, ConvertR1S32ToR1PRED) {
  // std::array rather than std::vector, as in the other PRED tests.
  const std::array<bool, 3> expected{{true, false, true}};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 0, -64}), PRED);
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
// U32 -> U32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1U32ToR1U32) {
  const std::vector<uint32> expected{42, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 64}), U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
// U32 -> S32 on values that fit in S32: the numeric values are preserved.
TEST_F(ConvertTest, ConvertR1U32ToR1S32) {
  const std::vector<int32> expected{42, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 64}), S32);
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
// U32 -> PRED: zero becomes false, non-zero becomes true.
TEST_F(ConvertTest, ConvertR1U32ToR1PRED) {
  const std::array<bool, 3> expected{{true, false, true}};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 0, 64}), PRED);
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
// F32 -> F32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
  const std::vector<float> expected{42.0f, 64.0f};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.0f, 64.0f}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// F32 -> PRED: 0.0f becomes false, non-zero values become true.
TEST_F(ConvertTest, ConvertR1F32ToR1PRED) {
  const std::array<bool, 3> expected{{true, false, true}};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f}), PRED);
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
// S32 -> F32 on small integers: results are compared without an error
// tolerance.
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
  const std::vector<float> expected{42.0f, 64.0f};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// PRED -> S32: true maps to 1 and false maps to 0.
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
  const std::vector<int32> expected{1, 0, 1};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), S32);
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
// PRED -> U32: true maps to 1 and false maps to 0.
TEST_F(ConvertTest, ConvertR1PREDToR1U32) {
  const std::vector<uint32> expected{1, 0, 1};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
// PRED -> F32: true maps to 1.0 and false maps to 0.0.
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
  const std::vector<float> expected{1., 0., 1.};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// Zero-element S32 vector -> F32: the result must also be empty.
XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
  const std::vector<float> expected;
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// F32 -> S32: the fractional part is dropped (42.6 -> 42, 64.4 -> 64).
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
  const std::vector<int32> expected{42, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.6, 64.4}), S32);
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
// S64 -> F32 on a set of boundary and rounding-sensitive inputs. The values
// are fed in as a runtime parameter and each result is compared against the
// host compiler's static_cast<float>(int64).
XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
  XlaBuilder builder(TestName());
  std::vector<int64> arg{
      -9223371216516022272,
      -2,
      -1,
      -0x7FFFFFFF,
      // The LL suffix matters: an unsuffixed 0x80000000 has type unsigned int
      // (with 32-bit int), so negating it wraps back to +2147483648 instead of
      // producing -2147483648.
      -0x80000000LL,
      0,
      1,
      2,
      1073742145,
      1073742656,
      0x7FFFFFFF,
      0x80000000,
      826720496944058148,
      4296062029846194332,
      0x0007FB72E4000000LL,
      0x0007FB72E4000001LL,
      0x0007FB72E6000000LL,
      0x0007FB72E7000000LL,
      0x0007FB72E7FFFFFFLL,
      0x0007FB72E8000000LL,
      0x0007FB72E8000001LL,
      0x0007FB72EA000000LL,
      0x0007FB72EB000000LL,
      0x0007FB72EBFFFFFFLL,
      0x0007FB72EC000000LL,
      0x7FFFFF0000000000LL,
      0x7FFFFF8000000000LL,
      0x7FFFFFFFFFFFFF00,
      static_cast<int64>(0xFFFFFFFFFFFFFFFF),
      static_cast<int64>(0x0000f234e67e0001LL),
      static_cast<int64>(0x8000000000000000),
      static_cast<int64>(0x8000000000000000LL),
      static_cast<int64>(0x8000000000000001LL),
      static_cast<int64>(0x8000008000000000LL),
      static_cast<int64>(0x8000010000000000LL),
  };
  Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, F32);

  // Reference results come from the host's own int64 -> float conversion.
  std::vector<float> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](int64 value) { return static_cast<float>(value); });
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
// U32 -> F32, including values at and above 2^31 and values that are not
// exactly representable in F32. Inputs are fed in as a runtime parameter and
// results are compared against the host's static_cast<float>(uint32).
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
  XlaBuilder builder(TestName());
  std::vector<uint32> arg{0,          1,          0x1000,     0x7fffffff,
                          0x80000000, 0x80000001, 0x80000002, 0x80000003,
                          0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, F32);

  // Reference results come from the host's own uint32 -> float conversion.
  std::vector<float> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](uint32 value) { return static_cast<float>(value); });
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
// F32 -> U32 on whole-number inputs up to 4294967040 (2^32 - 256). Inputs are
// fed in as a runtime parameter and results are compared against the host's
// static_cast<uint32>(float).
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
  XlaBuilder builder(TestName());
  std::vector<float> arg{0.0f,        1.0f,          16777216.0f,
                         16777218.0f, 2147483647.0f, 4294967040.0f};
  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, U32);

  // Reference results come from the host's own float -> uint32 conversion.
  std::vector<uint32> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](float value) { return static_cast<uint32>(value); });
  ComputeAndCompareR1<uint32>(&builder, expected, {arg_data.get()});
}
// U32 -> S64 (widening), including inputs with the top bit set. Inputs are fed
// in as a runtime parameter and results are compared against the host's
// static_cast<int64>(uint32).
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
  XlaBuilder builder(TestName());
  std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Reference results come from the host's own uint32 -> int64 conversion.
  std::vector<int64> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](uint32 value) { return static_cast<int64>(value); });
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
// S32 -> S64 (widening), including negative inputs. Inputs are fed in as a
// runtime parameter and results are compared against the host's
// static_cast<int64>(int32).
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
  XlaBuilder builder(TestName());
  std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
  Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Reference results come from the host's own int32 -> int64 conversion.
  std::vector<int64> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](int32 value) { return static_cast<int64>(value); });
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
// F32 -> S64 on fractional, negative and near-int64-limit inputs. Inputs are
// fed in as a runtime parameter and results are compared against the host's
// static_cast<int64>(float).
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
  XlaBuilder builder(TestName());
  // Test cases from compiler_rt library.
  std::vector<float> arg{0.0f,
                         0.5f,
                         0.99f,
                         1.0f,
                         1.5f,
                         1.99f,
                         2.0f,
                         2.01f,
                         2147483648.f,
                         -0.5f,
                         -0.99f,
                         -1.0f,
                         -1.5f,
                         -1.99f,
                         -2.0f,
                         -2.01f,
                         9223371487098961920.f,
                         9223370937343148032.f,
                         -9223371487098961920.f,
                         -9223370937343148032.f};
  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Reference results come from the host's own float -> int64 conversion.
  std::vector<int64> expected(arg.size());
  absl::c_transform(arg, expected.begin(),
                    [](float value) { return static_cast<int64>(value); });
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
// U8 -> F32 on small values: results are compared without an error tolerance.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
  const std::vector<float> expected{32.0, 64.0};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// U8 -> S32 (widening): the numeric values are preserved.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
  const std::vector<int32_t> expected{32, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), S32);
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
// U8 -> U32 (widening): the numeric values are preserved.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
  const std::vector<uint32_t> expected{32, 64};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), U32);
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
// F32 -> F64 (widening): results are compared without an error tolerance.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
  const std::vector<double> expected{32.0, 64.0};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {32.0f, 64.0f}), F64);
  ComputeAndCompareR1<double>(&builder, expected, {});
}
// F64 -> F32 on values that F32 represents exactly: results are compared
// without an error tolerance.
XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
  const std::vector<float> expected{32.0f, 64.0f};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<double>(&builder, {32.0, 64.0}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
// S32 -> F32 at the two ends of the int32 range. The expected values are the
// host's static_cast<float> of the same limits, compared with an ErrorSpec of
// 0.0001.
TEST_F(ConvertTest, ConvertS32Extremes) {
  constexpr int32 kLowest = std::numeric_limits<int32>::min();
  constexpr int32 kHighest = std::numeric_limits<int32>::max();
  const std::vector<float> expected{static_cast<float>(kLowest),
                                    static_cast<float>(kHighest)};

  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {kLowest, kHighest}), F32);
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
// Applies an F32 -> S32 convert through Map: the convert lives in a scalar
// sub-computation that is mapped over dimension 0 of an F32 vector.
TEST_F(ConvertTest, ConvertMapToS32) {
  XlaBuilder builder(TestName());

  // Scalar computation: in (F32) -> convert -> S32.
  auto convert_builder = builder.CreateSubBuilder("convert");
  ConvertElementType(
      Parameter(convert_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"),
      S32);

  auto operand = ConstantR1<float>(&builder, {42.0f, 64.0f});
  Map(&builder, {operand}, convert_builder->BuildAndNoteError(), {0});

  const std::vector<int32> expected{42, 64};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
// Applies an S32 -> F32 convert through Map: the convert lives in a scalar
// sub-computation that is mapped over dimension 0 of an S32 vector.
TEST_F(ConvertTest, ConvertMapToF32) {
  XlaBuilder builder(TestName());

  // Scalar computation: in (S32) -> convert -> F32.
  auto convert_builder = builder.CreateSubBuilder("convert");
  ConvertElementType(
      Parameter(convert_builder.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"),
      F32);

  auto operand = ConstantR1<int32>(&builder, {42, 64});
  Map(&builder, {operand}, convert_builder->BuildAndNoteError(), {0});

  const std::vector<float> expected{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
// Regression test for b/31758660. ReshapeMover may rewrite
//   input -> reshape -> convert
// into
//   input -> convert -> reshape
// and the convert it creates must keep the element type of the original
// convert.
TEST_F(ConvertTest, ConvertReshape) {
  XlaBuilder builder(TestName());
  // One-element S32 vector, reshaped to a scalar, then converted to F32.
  ConvertElementType(Reshape(ConstantR1<int32>(&builder, {42}),
                             /*dimensions=*/{0}, /*new_sizes=*/{}),
                     F32);
  ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
}
// Returns F32 inputs around the edges of the F16 range: +/-infinity, values
// beyond the largest finite half, the largest finite half, the half
// normal/subnormal boundaries, signed zeros and a few ordinary values.
std::vector<float> GetInterestingF16ConversionTestCases() {
  const float inf = std::numeric_limits<float>::infinity();
  // F32 bit patterns for the F16 boundary values.
  const float min_normal = absl::bit_cast<float, uint32>(0x38800000);  // 2^-14
  const float max_subnormal =
      absl::bit_cast<float, uint32>(0x387fc000);  // 2^-14 - 2^-24
  const float min_subnormal =
      absl::bit_cast<float, uint32>(0x33800000);  // 2^-24
  const float max_finite = 65504.0f;
  const float beyond_max = max_finite * 2 + 1;

  return {-inf,
          -beyond_max,
          -max_finite,
          -42.0f,
          -1.0f,
          -min_subnormal,
          -max_subnormal,
          -min_normal,
          -0.0f,
          0.0f,
          min_subnormal,
          max_subnormal,
          min_normal,
          1.0f,
          42.0f,
          max_finite,
          beyond_max,
          inf};
}
// F16 -> F32. Each interesting F32 case is first rounded to Eigen::half to
// form the input; the expected output is that half widened back to float on
// the host.
XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
  std::vector<half> f16_input;
  std::vector<float> f32_expected;
  for (float test_case : GetInterestingF16ConversionTestCases()) {
    const Eigen::half as_half(test_case);
    f16_input.push_back(as_half);
    f32_expected.push_back(static_cast<float>(as_half));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<half>(f16_input)));

  XlaBuilder builder(TestName());
  const auto param_shape =
      ShapeUtil::MakeShape(F16, {static_cast<int64>(f16_input.size())});
  ConvertElementType(Parameter(&builder, 0, param_shape, "param"), F32);

  ComputeAndCompareR1<float>(&builder, f32_expected, {input_handle.get()});
}
// F32 -> F16. The expected output is each interesting F32 case converted with
// the Eigen::half constructor on the host.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
  const std::vector<float> f32_input = GetInterestingF16ConversionTestCases();
  std::vector<half> f16_expected;
  for (float value : f32_input) {
    f16_expected.push_back(Eigen::half(value));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<float>(f32_input)));

  XlaBuilder builder(TestName());
  const auto param_shape =
      ShapeUtil::MakeShape(F32, {static_cast<int64>(f32_input.size())});
  ConvertElementType(Parameter(&builder, 0, param_shape, "param"), F16);

  ComputeAndCompareR1<half>(&builder, f16_expected, {input_handle.get()});
}
// C64 -> C64: converting to the same element type must return the single
// complex input (42 + 64i) within an ErrorSpec of 0.0001.
XLA_TEST_F(ConvertTest, ConvertC64ToC64) {
  const std::vector<complex64> values{complex64(42.0f, 64.0f)};
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<complex64>(&builder, values);
  ConvertElementType(operand, C64);
  ComputeAndCompareR1<complex64>(&builder, values, {}, ErrorSpec(0.0001));
}
// S64 -> S64: converting to the same element type must leave values unchanged.
XLA_TEST_F(ConvertTest, ConvertS64S64) {
  const std::vector<int64> values{-42, 64};
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int64>(&builder, values);
  ConvertElementType(operand, S64);
  ComputeAndCompareR1<int64>(&builder, values, {});
}
// U64 -> U64: converting to the same element type must leave values unchanged.
XLA_TEST_F(ConvertTest, ConvertU64U64) {
  const std::vector<uint64> values{42, 64};
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint64>(&builder, values);
  ConvertElementType(operand, U64);
  ComputeAndCompareR1<uint64>(&builder, values, {});
}
// U64 -> S64: an in-range value is preserved and UINT64_MAX is expected to
// come back as -1.
XLA_TEST_F(ConvertTest, ConvertU64S64) {
  const std::vector<uint64> unsigned_input{42, UINT64_MAX};
  const std::vector<int64> signed_expected{42, -1};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint64>(&builder, unsigned_input), S64);
  ComputeAndCompareR1<int64>(&builder, signed_expected, {});
}
// S64 -> U64: an in-range value is preserved, -1 is expected to come back as
// UINT64_MAX, and INT64_MIN as 2^63.
XLA_TEST_F(ConvertTest, ConvertS64U64) {
  const std::vector<int64> signed_input{42, -1, INT64_MIN};
  const std::vector<uint64> unsigned_expected{
      42, UINT64_MAX, tensorflow::MathUtil::IPow<uint64>(2, 63)};
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int64>(&builder, signed_input), U64);
  ComputeAndCompareR1<uint64>(&builder, unsigned_expected, {});
}
// BF16 -> F32 over every one of the 2^16 bfloat16 bit patterns. The F32
// result is bitcast to U32 and must equal the bfloat16 bits placed in the
// upper 16 bits of the word, with the lower 16 bits zero.
XLA_TEST_F(ConvertTest, ConvertBF16F32) {
  XlaBuilder builder(TestName());

  constexpr uint32 kNumBf16Patterns = 1U << 16;
  std::vector<bfloat16> all_bfloats(kNumBf16Patterns);
  std::vector<uint32> expected(kNumBf16Patterns);
  // Unsigned index: avoids comparing a signed counter against size() and makes
  // the narrowing into the 16-bit `value` field explicit.
  for (uint32 bits = 0; bits < kNumBf16Patterns; ++bits) {
    all_bfloats[bits].value = static_cast<uint16_t>(bits);
    expected[bits] = bits << 16;
  }

  // Exhaustively test all bf16 to f32 conversions.
  xla::XlaOp all_bfloats_bf16 = ConstantR1<bfloat16>(&builder, all_bfloats);
  xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32);
  BitcastConvertType(all_bfloats_f32, U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
} // namespace
} // namespace xla