Enable half precision convolution for the CPU and GPU backends.
Enhance the CPU IR emitter to support F16 dot operation and convolution operation. Add a CPU runtime implementation for F16 convolution. Enhance the GPU backend to handle F16 convolution thunk. Convert some F32 xla convolution tests to support both F32 and F16 and disable the tests for the CPU backend due to b/72509305. PiperOrigin-RevId: 185862438
This commit is contained in:
parent
c356d28001
commit
b91155edb6
@ -121,6 +121,23 @@ class Array {
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
// Creates a 2D array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array(std::initializer_list<std::initializer_list<T2>> values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size()})) {
|
||||
int64 idx = 0;
|
||||
for (const auto& it1 : values) {
|
||||
for (const auto& it2 : it1) {
|
||||
values_[idx] = static_cast<T>(it2);
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
// Creates a 3D array from the given nested initializer list. The outer
|
||||
// initializer list is the first dimension, and so on.
|
||||
Array(InitializerList3D values)
|
||||
@ -138,6 +155,27 @@ class Array {
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
// Creates a 3D array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
|
||||
values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size(),
|
||||
values.begin()->begin()->size()})) {
|
||||
int64 idx = 0;
|
||||
for (const auto& it1 : values) {
|
||||
for (const auto& it2 : it1) {
|
||||
for (const auto& it3 : it2) {
|
||||
values_[idx] = static_cast<T>(it3);
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
// Creates a 4D array from the given nested initializer list. The outer
|
||||
// initializer list is the first dimension, and so on.
|
||||
Array(InitializerList4D values)
|
||||
@ -158,6 +196,31 @@ class Array {
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
// Creates a 4D array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array(std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
|
||||
values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size(),
|
||||
values.begin()->begin()->size(),
|
||||
values.begin()->begin()->begin()->size()})) {
|
||||
int64 idx = 0;
|
||||
for (const auto& it1 : values) {
|
||||
for (const auto& it2 : it1) {
|
||||
for (const auto& it3 : it2) {
|
||||
for (const auto& it4 : it3) {
|
||||
values_[idx] = static_cast<T>(it4);
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(idx == num_elements());
|
||||
}
|
||||
|
||||
Array(const Array<T>& other)
|
||||
: sizes_(other.sizes_), values_(new T[num_elements()]) {
|
||||
std::copy(&other.values_[0], &other.values_[0] + num_elements(),
|
||||
@ -185,7 +248,7 @@ class Array {
|
||||
// Fills the array with the sequence i*multiplier for i=0,1,...
|
||||
void FillWithMultiples(const T& multiplier) {
|
||||
for (int64 i = 0; i < num_elements(); ++i) {
|
||||
values_[i] = i * multiplier;
|
||||
values_[i] = static_cast<T>(i) * multiplier;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,14 @@ class Array2D : public Array<T> {
|
||||
Array2D(std::initializer_list<std::initializer_list<T>> values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
// Creates an array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array2D(std::initializer_list<std::initializer_list<T2>> values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
Array2D(const Array2D<T>& other) : Array<T>(other) {}
|
||||
|
||||
int64 n1() const { return this->dim(0); }
|
||||
|
@ -63,6 +63,20 @@ TEST(Array2dTest, InitializerListCtor) {
|
||||
EXPECT_EQ(arr(1, 2), 6);
|
||||
}
|
||||
|
||||
TEST(Array2dTest, InitializerListCtorHalf) {
|
||||
Array2D<Eigen::half> arr = {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}};
|
||||
|
||||
EXPECT_EQ(arr.n1(), 2);
|
||||
EXPECT_EQ(arr.n2(), 3);
|
||||
|
||||
EXPECT_EQ(arr(0, 0), static_cast<Eigen::half>(1));
|
||||
EXPECT_EQ(arr(0, 1), static_cast<Eigen::half>(2));
|
||||
EXPECT_EQ(arr(0, 2), static_cast<Eigen::half>(3));
|
||||
EXPECT_EQ(arr(1, 0), static_cast<Eigen::half>(4));
|
||||
EXPECT_EQ(arr(1, 1), static_cast<Eigen::half>(5));
|
||||
EXPECT_EQ(arr(1, 2), static_cast<Eigen::half>(6));
|
||||
}
|
||||
|
||||
TEST(Array2dTest, Accessors) {
|
||||
Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
|
||||
|
||||
|
@ -57,6 +57,16 @@ class Array3D : public Array<T> {
|
||||
values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
// Creates an array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array3D(
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
|
||||
values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
int64 n1() const { return this->dim(0); }
|
||||
int64 n2() const { return this->dim(1); }
|
||||
int64 n3() const { return this->dim(2); }
|
||||
|
@ -69,6 +69,29 @@ TEST(Array3dTest, InitializerListCtor) {
|
||||
EXPECT_EQ(arr(2, 3, 1), 24);
|
||||
}
|
||||
|
||||
TEST(Array3dTest, InitializerListCtorHalf) {
|
||||
Array3D<Eigen::half> arr = {
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}},
|
||||
{{9.0f, 10.0f}, {11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}},
|
||||
{{17.0f, 18.0f}, {19.0f, 20.0f}, {21.0f, 22.0f}, {23.0f, 24.0f}}};
|
||||
|
||||
EXPECT_EQ(arr.n1(), 3);
|
||||
EXPECT_EQ(arr.n2(), 4);
|
||||
EXPECT_EQ(arr.n3(), 2);
|
||||
EXPECT_EQ(arr.num_elements(), 24);
|
||||
|
||||
EXPECT_EQ(arr(0, 0, 0), static_cast<Eigen::half>(1));
|
||||
EXPECT_EQ(arr(0, 0, 1), static_cast<Eigen::half>(2));
|
||||
EXPECT_EQ(arr(0, 1, 0), static_cast<Eigen::half>(3));
|
||||
EXPECT_EQ(arr(0, 3, 1), static_cast<Eigen::half>(8));
|
||||
EXPECT_EQ(arr(1, 0, 0), static_cast<Eigen::half>(9));
|
||||
EXPECT_EQ(arr(1, 1, 1), static_cast<Eigen::half>(12));
|
||||
EXPECT_EQ(arr(2, 0, 0), static_cast<Eigen::half>(17));
|
||||
EXPECT_EQ(arr(2, 1, 1), static_cast<Eigen::half>(20));
|
||||
EXPECT_EQ(arr(2, 2, 0), static_cast<Eigen::half>(21));
|
||||
EXPECT_EQ(arr(2, 3, 1), static_cast<Eigen::half>(24));
|
||||
}
|
||||
|
||||
TEST(Array3dTest, Fill) {
|
||||
Array3D<int> fullof7(2, 3, 4, 7);
|
||||
for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
|
||||
|
@ -82,6 +82,16 @@ class Array4D : public Array<T> {
|
||||
values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
// Creates an array of Eigen::half from the given nested initializer list of
|
||||
// float values.
|
||||
template <typename T2, typename = typename std::enable_if<
|
||||
std::is_same<T, Eigen::half>::value &&
|
||||
std::is_same<T2, float>::value>::type>
|
||||
Array4D(std::initializer_list<std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<T2>>>>
|
||||
values)
|
||||
: Array<T>(values) {}
|
||||
|
||||
// Numerically-named aliases for the various dimensions. This matches the
|
||||
// dimension names used in array3d.
|
||||
int64 n4() const { return this->dim(3); }
|
||||
|
@ -97,6 +97,36 @@ TEST(Array3dTest, InitializerListCtor) {
|
||||
EXPECT_EQ(arr(2, 3, 1, 0), 24);
|
||||
}
|
||||
|
||||
TEST(Array3dTest, InitializerListCtorHalf) {
|
||||
Array4D<Eigen::half> arr = {
|
||||
{{{1.0f}, {2.0f}}, {{3.0f}, {4.0f}}, {{5.0f}, {6.0f}}, {{7.0f}, {8.0f}}},
|
||||
{{{9.0f}, {10.0f}},
|
||||
{{11.0f}, {12.0f}},
|
||||
{{13.0f}, {14.0f}},
|
||||
{{15.0f}, {16.0f}}},
|
||||
{{{17.0f}, {18.0f}},
|
||||
{{19.0f}, {20.0f}},
|
||||
{{21.0f}, {22.0f}},
|
||||
{{23.0f}, {24.0f}}}};
|
||||
|
||||
EXPECT_EQ(arr.n1(), 3);
|
||||
EXPECT_EQ(arr.n2(), 4);
|
||||
EXPECT_EQ(arr.n3(), 2);
|
||||
EXPECT_EQ(arr.n4(), 1);
|
||||
EXPECT_EQ(arr.num_elements(), 24);
|
||||
|
||||
EXPECT_EQ(arr(0, 0, 0, 0), static_cast<Eigen::half>(1));
|
||||
EXPECT_EQ(arr(0, 0, 1, 0), static_cast<Eigen::half>(2));
|
||||
EXPECT_EQ(arr(0, 1, 0, 0), static_cast<Eigen::half>(3));
|
||||
EXPECT_EQ(arr(0, 3, 1, 0), static_cast<Eigen::half>(8));
|
||||
EXPECT_EQ(arr(1, 0, 0, 0), static_cast<Eigen::half>(9));
|
||||
EXPECT_EQ(arr(1, 1, 1, 0), static_cast<Eigen::half>(12));
|
||||
EXPECT_EQ(arr(2, 0, 0, 0), static_cast<Eigen::half>(17));
|
||||
EXPECT_EQ(arr(2, 1, 1, 0), static_cast<Eigen::half>(20));
|
||||
EXPECT_EQ(arr(2, 2, 0, 0), static_cast<Eigen::half>(21));
|
||||
EXPECT_EQ(arr(2, 3, 1, 0), static_cast<Eigen::half>(24));
|
||||
}
|
||||
|
||||
TEST(Array4dTest, Fill) {
|
||||
Array4D<int> fullof7(2, 3, 4, 5, 7);
|
||||
fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
|
||||
|
@ -60,6 +60,25 @@ TEST(ArrayTest, InitializerListCtor) {
|
||||
EXPECT_EQ(arr(1, 2), 6);
|
||||
}
|
||||
|
||||
TEST(ArrayTest, InitializerListCtorHalf) {
|
||||
Array<Eigen::half> d2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
|
||||
EXPECT_EQ(d2.dim(0), 2);
|
||||
EXPECT_EQ(d2.dim(1), 3);
|
||||
|
||||
Array<Eigen::half> d3({{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}});
|
||||
EXPECT_EQ(d3.dim(0), 3);
|
||||
EXPECT_EQ(d3.dim(1), 2);
|
||||
EXPECT_EQ(d3.dim(2), 1);
|
||||
|
||||
Array<Eigen::half> d4(
|
||||
{{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}},
|
||||
{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}});
|
||||
EXPECT_EQ(d4.dim(0), 2);
|
||||
EXPECT_EQ(d4.dim(1), 3);
|
||||
EXPECT_EQ(d4.dim(2), 2);
|
||||
EXPECT_EQ(d4.dim(3), 1);
|
||||
}
|
||||
|
||||
TEST(ArrayTest, IndexingReadWrite) {
|
||||
Array<int> arr({2, 3});
|
||||
|
||||
|
@ -35,6 +35,8 @@ extern const char* const kEigenMatMulF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulF32";
|
||||
extern const char* const kEigenMatMulF64SymbolName =
|
||||
"__xla_cpu_runtime_EigenMatMulF64";
|
||||
extern const char* const kEigenConvF16SymbolName =
|
||||
"__xla_cpu_runtime_EigenConvF16";
|
||||
extern const char* const kEigenConvF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenConvF32";
|
||||
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
|
||||
@ -42,6 +44,8 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
|
||||
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
|
||||
extern const char* const kEigenSingleThreadedConvF16SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedConvF16";
|
||||
extern const char* const kEigenSingleThreadedConvF32SymbolName =
|
||||
"__xla_cpu_runtime_EigenSingleThreadedConvF32";
|
||||
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
|
||||
|
@ -43,10 +43,12 @@ namespace runtime {
|
||||
// because it is a symbol in the cpu_runtime library.
|
||||
extern const char* const kEigenMatMulF32SymbolName;
|
||||
extern const char* const kEigenMatMulF64SymbolName;
|
||||
extern const char* const kEigenConvF16SymbolName;
|
||||
extern const char* const kEigenConvF32SymbolName;
|
||||
extern const char* const kEigenFftSymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
|
||||
extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
|
||||
extern const char* const kEigenSingleThreadedConvF16SymbolName;
|
||||
extern const char* const kEigenSingleThreadedConvF32SymbolName;
|
||||
extern const char* const kAcquireInfeedBufferForDequeueSymbolName;
|
||||
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
|
||||
|
@ -549,7 +549,7 @@ DotOpEmitter::DotOpEmitter(
|
||||
const HloModuleConfig& hlo_module_config,
|
||||
const TargetMachineFeatures& target_machine_features) {
|
||||
PrimitiveType type = target_array.GetShape().element_type();
|
||||
TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
|
||||
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
|
||||
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
|
||||
lhs_array, rhs_array, addend_array,
|
||||
executable_run_options_value, ir_builder,
|
||||
|
@ -801,7 +801,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
|
||||
auto rhs = dot->operand(1);
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F32, F64, C64}));
|
||||
/*supported_types=*/{F16, F32, F64, C64}));
|
||||
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
|
||||
if (dnums.lhs_batch_dimensions_size() > 0 ||
|
||||
dnums.rhs_batch_dimensions_size() > 0) {
|
||||
@ -849,7 +849,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||
const auto& window = convolution->window();
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F32, C64}));
|
||||
/*supported_types=*/{F16, F32, C64}));
|
||||
|
||||
const ConvolutionDimensionNumbers& dnums =
|
||||
convolution->convolution_dimension_numbers();
|
||||
@ -928,25 +928,30 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||
int64 rhs_col_dilation =
|
||||
one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
|
||||
|
||||
// Args have been computed, make the call.
|
||||
llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo();
|
||||
PrimitiveType primitive_type = lhs->shape().element_type();
|
||||
llvm::Type* ir_ptr_type = primitive_type == F16
|
||||
? ir_builder_.getHalfTy()->getPointerTo()
|
||||
: ir_builder_.getFloatTy()->getPointerTo();
|
||||
llvm::Type* int64_type = ir_builder_.getInt64Ty();
|
||||
llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
|
||||
llvm::FunctionType* conv_type = llvm::FunctionType::get(
|
||||
ir_builder_.getVoidTy(),
|
||||
{int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
|
||||
int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type},
|
||||
{int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type},
|
||||
/*isVarArg=*/false);
|
||||
bool multi_threaded_eigen =
|
||||
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
|
||||
const char* fn_name =
|
||||
(multi_threaded_eigen
|
||||
? runtime::kEigenConvF32SymbolName
|
||||
: runtime::kEigenSingleThreadedConvF32SymbolName);
|
||||
primitive_type == F16
|
||||
? (multi_threaded_eigen
|
||||
? runtime::kEigenConvF16SymbolName
|
||||
: runtime::kEigenSingleThreadedConvF16SymbolName)
|
||||
: (multi_threaded_eigen
|
||||
? runtime::kEigenConvF32SymbolName
|
||||
: runtime::kEigenSingleThreadedConvF32SymbolName);
|
||||
llvm::Function* conv_func = llvm::cast<llvm::Function>(
|
||||
module_->getOrInsertFunction(fn_name, conv_type));
|
||||
conv_func->setCallingConv(llvm::CallingConv::C);
|
||||
@ -956,9 +961,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||
conv_func, {
|
||||
GetExecutableRunOptionsArgument(),
|
||||
ir_builder_.CreateBitCast(
|
||||
GetEmittedValueFor(convolution), float_ptr_type),
|
||||
ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
|
||||
ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
|
||||
GetEmittedValueFor(convolution), ir_ptr_type),
|
||||
ir_builder_.CreateBitCast(lhs_address, ir_ptr_type),
|
||||
ir_builder_.CreateBitCast(rhs_address, ir_ptr_type),
|
||||
ir_builder_.getInt64(input_batch),
|
||||
ir_builder_.getInt64(input_rows),
|
||||
ir_builder_.getInt64(input_cols),
|
||||
|
@ -34,7 +34,26 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32(
|
||||
int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
|
||||
const xla::ExecutableRunOptions* run_options =
|
||||
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
|
||||
tensorflow::xla::EigenConvF32Impl(
|
||||
tensorflow::xla::EigenConvImpl(
|
||||
*run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
|
||||
input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
|
||||
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
|
||||
col_stride, padding_top, padding_bottom, padding_left, padding_right,
|
||||
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16(
|
||||
const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
|
||||
Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
|
||||
int64 input_channels, int64 kernel_rows, int64 kernel_cols,
|
||||
int64 kernel_channels, int64 kernel_filters, int64 output_rows,
|
||||
int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
|
||||
int64 padding_bottom, int64 padding_left, int64 padding_right,
|
||||
int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
|
||||
int64 rhs_col_dilation) {
|
||||
const xla::ExecutableRunOptions* run_options =
|
||||
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
|
||||
tensorflow::xla::EigenConvImpl(
|
||||
*run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
|
||||
input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
|
||||
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
|
||||
|
@ -34,6 +34,20 @@ extern void __xla_cpu_runtime_EigenConvF32(
|
||||
tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
|
||||
tensorflow::int64 rhs_col_dilation);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenConvF16(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,
|
||||
tensorflow::int64 input_batch, tensorflow::int64 input_rows,
|
||||
tensorflow::int64 input_cols, tensorflow::int64 input_channels,
|
||||
tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
|
||||
tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
|
||||
tensorflow::int64 output_rows, tensorflow::int64 output_cols,
|
||||
tensorflow::int64 row_stride, tensorflow::int64 col_stride,
|
||||
tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
|
||||
tensorflow::int64 padding_left, tensorflow::int64 padding_right,
|
||||
tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
|
||||
tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
|
||||
|
||||
} // extern "C"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_
|
||||
|
@ -24,26 +24,27 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace xla {
|
||||
|
||||
template <typename EigenDevice>
|
||||
void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
|
||||
float* rhs, int64 input_batch, int64 input_rows,
|
||||
int64 input_cols, int64 input_channels, int64 kernel_rows,
|
||||
int64 kernel_cols, int64 kernel_channels,
|
||||
int64 kernel_filters, int64 output_rows,
|
||||
int64 output_cols, int64 row_stride, int64 col_stride,
|
||||
int64 padding_top, int64 padding_bottom,
|
||||
int64 padding_left, int64 padding_right,
|
||||
int64 lhs_row_dilation, int64 lhs_col_dilation,
|
||||
int64 rhs_row_dilation, int64 rhs_col_dilation) {
|
||||
const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
|
||||
template <typename EigenDevice, typename ScalarType>
|
||||
void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
|
||||
ScalarType* rhs, int64 input_batch, int64 input_rows,
|
||||
int64 input_cols, int64 input_channels, int64 kernel_rows,
|
||||
int64 kernel_cols, int64 kernel_channels,
|
||||
int64 kernel_filters, int64 output_rows, int64 output_cols,
|
||||
int64 row_stride, int64 col_stride, int64 padding_top,
|
||||
int64 padding_bottom, int64 padding_left,
|
||||
int64 padding_right, int64 lhs_row_dilation,
|
||||
int64 lhs_col_dilation, int64 rhs_row_dilation,
|
||||
int64 rhs_col_dilation) {
|
||||
const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
input(lhs, input_batch, input_rows, input_cols, input_channels);
|
||||
|
||||
const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
|
||||
const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>, Eigen::Aligned>
|
||||
Eigen::TensorMap<Eigen::Tensor<ScalarType, 4, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
output(out, input_batch, output_rows, output_cols, kernel_filters);
|
||||
|
||||
Eigen::array<Eigen::IndexPair<int64>, 1> contract_dims;
|
||||
@ -75,7 +76,7 @@ void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
|
||||
row_stride, rhs_col_dilation, rhs_row_dilation,
|
||||
lhs_col_dilation, lhs_row_dilation,
|
||||
padding_left, padding_right, padding_top,
|
||||
padding_bottom, 0.0f)
|
||||
padding_bottom, static_cast<ScalarType>(0.0f))
|
||||
.reshape(pre_contract_dims)
|
||||
.contract(kernel.reshape(kernel_dims), contract_dims)
|
||||
.reshape(post_contract_dims);
|
||||
|
@ -21,6 +21,24 @@ limitations under the License.
|
||||
|
||||
using tensorflow::int64;
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
|
||||
__xla_cpu_runtime_EigenSingleThreadedConvF16(
|
||||
const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
|
||||
Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
|
||||
int64 input_channels, int64 kernel_rows, int64 kernel_cols,
|
||||
int64 kernel_channels, int64 kernel_filters, int64 output_rows,
|
||||
int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
|
||||
int64 padding_bottom, int64 padding_left, int64 padding_right,
|
||||
int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
|
||||
int64 rhs_col_dilation) {
|
||||
tensorflow::xla::EigenConvImpl(
|
||||
Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
|
||||
input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
|
||||
kernel_filters, output_rows, output_cols, row_stride, col_stride,
|
||||
padding_top, padding_bottom, padding_left, padding_right,
|
||||
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
|
||||
__xla_cpu_runtime_EigenSingleThreadedConvF32(
|
||||
const void* run_options_ptr, float* out, float* lhs, float* rhs,
|
||||
@ -30,7 +48,7 @@ __xla_cpu_runtime_EigenSingleThreadedConvF32(
|
||||
int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
|
||||
int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
|
||||
int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
|
||||
tensorflow::xla::EigenConvF32Impl(
|
||||
tensorflow::xla::EigenConvImpl(
|
||||
Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
|
||||
input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
|
||||
kernel_filters, output_rows, output_cols, row_stride, col_stride,
|
||||
|
@ -20,6 +20,20 @@ limitations under the License.
|
||||
|
||||
extern "C" {
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedConvF16(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,
|
||||
tensorflow::int64 input_batch, tensorflow::int64 input_rows,
|
||||
tensorflow::int64 input_cols, tensorflow::int64 input_channels,
|
||||
tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
|
||||
tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
|
||||
tensorflow::int64 output_rows, tensorflow::int64 output_cols,
|
||||
tensorflow::int64 row_stride, tensorflow::int64 col_stride,
|
||||
tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
|
||||
tensorflow::int64 padding_left, tensorflow::int64 padding_right,
|
||||
tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
|
||||
tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedConvF32(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
|
||||
float* lhs, float* rhs, tensorflow::int64 input_batch,
|
||||
|
@ -208,10 +208,12 @@ bool RegisterKnownJITSymbols() {
|
||||
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
||||
|
@ -63,12 +63,12 @@ ConvolutionThunk::ConvolutionThunk(
|
||||
|
||||
Status ConvolutionThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream) {
|
||||
se::DeviceMemory<float> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
se::DeviceMemory<float> filter_data(
|
||||
buffer_allocations.GetDeviceAddress(filter_buffer_));
|
||||
se::DeviceMemory<float> output_data(
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_));
|
||||
se::DeviceMemoryBase input_data =
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_);
|
||||
se::DeviceMemoryBase filter_data =
|
||||
buffer_allocations.GetDeviceAddress(filter_buffer_);
|
||||
se::DeviceMemoryBase output_data =
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_);
|
||||
se::DeviceMemoryBase scratch =
|
||||
buffer_allocations.GetDeviceAddress(scratch_buffer_);
|
||||
|
||||
@ -80,8 +80,8 @@ Status ConvolutionThunk::ExecuteOnStream(
|
||||
filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
|
||||
stream));
|
||||
|
||||
// Figure out which of output/input/filter is the result produced by this op,
|
||||
// and write the result tuple.
|
||||
// Figure out which of output/input/filter is the result produced by
|
||||
// this op, and write the result tuple.
|
||||
void* result_ptr = [&] {
|
||||
switch (convolution_kind_) {
|
||||
case CudnnConvKind::kForward:
|
||||
|
@ -135,15 +135,6 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
|
||||
break;
|
||||
}
|
||||
|
||||
// Remove any algorithms with tensor math enabled. These have lower precision
|
||||
// than regular algorithms, and we don't yet have a way to turn this on/off in
|
||||
// XLA.
|
||||
algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(),
|
||||
[&](const AlgorithmDesc& a) {
|
||||
return a.tensor_ops_enabled();
|
||||
}),
|
||||
algorithms.end());
|
||||
|
||||
return algorithms;
|
||||
}
|
||||
|
||||
@ -222,6 +213,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
|
||||
ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums);
|
||||
se::dnn::ProfileResult best_result;
|
||||
int64 best_result_bytes_used = 0;
|
||||
|
||||
for (const AlgorithmDesc& alg :
|
||||
GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
|
||||
ScratchAllocator scratch_allocator(device_ordinal, allocator);
|
||||
@ -229,14 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
|
||||
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
|
||||
<< instr->ToString();
|
||||
|
||||
bool launch_ok =
|
||||
RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
|
||||
se::DeviceMemory<float>(input_buf.ValueOrDie()),
|
||||
se::DeviceMemory<float>(filter_buf.ValueOrDie()),
|
||||
se::DeviceMemory<float>(output_buf.ValueOrDie()),
|
||||
&scratch_allocator, window, dnums,
|
||||
AlgorithmConfig(alg), &stream, &profile_result)
|
||||
.ok();
|
||||
bool launch_ok = RunCudnnConvolution(
|
||||
kind, input_shape, filter_shape, output_shape,
|
||||
input_buf.ValueOrDie(), filter_buf.ValueOrDie(),
|
||||
output_buf.ValueOrDie(), &scratch_allocator, window,
|
||||
dnums, AlgorithmConfig(alg), &stream, &profile_result)
|
||||
.ok();
|
||||
|
||||
if (launch_ok && profile_result.is_valid()) {
|
||||
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
|
||||
|
@ -70,39 +70,11 @@ class ScratchBufAllocator : public se::ScratchAllocator {
|
||||
bool allocated_ = false;
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
string CudnnConvKindToString(CudnnConvKind kind) {
|
||||
switch (kind) {
|
||||
case CudnnConvKind::kForward:
|
||||
return "forward";
|
||||
case CudnnConvKind::kBackwardFilter:
|
||||
return "backward_filter";
|
||||
case CudnnConvKind::kBackwardInput:
|
||||
return "backward_input";
|
||||
}
|
||||
}
|
||||
|
||||
Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape,
|
||||
const Shape& filter_shape, const Shape& output_shape,
|
||||
DeviceMemory<float> input_buf,
|
||||
DeviceMemory<float> filter_buf,
|
||||
DeviceMemory<float> output_buf,
|
||||
DeviceMemoryBase scratch_buf, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums,
|
||||
AlgorithmConfig algorithm, Stream* stream,
|
||||
ProfileResult* profile_result /*= nullptr*/) {
|
||||
ScratchBufAllocator scratch_allocator(scratch_buf);
|
||||
return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
|
||||
input_buf, filter_buf, output_buf,
|
||||
&scratch_allocator, window, dnums, algorithm,
|
||||
stream, profile_result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status RunCudnnConvolution(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, DeviceMemory<float> input_buf,
|
||||
DeviceMemory<float> filter_buf, DeviceMemory<float> output_buf,
|
||||
const Shape& output_shape, DeviceMemory<T> input_buf,
|
||||
DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
|
||||
se::ScratchAllocator* scratch_allocator, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
|
||||
Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
|
||||
@ -124,8 +96,16 @@ Status RunCudnnConvolution(
|
||||
// tensorflow/python/ops/nn_ops.py).
|
||||
const int effective_num_dimensions = std::max(2, num_dimensions);
|
||||
|
||||
CHECK_EQ(F32, output_shape.element_type())
|
||||
<< ShapeUtil::HumanString(output_shape);
|
||||
if (std::is_same<T, float>::value) {
|
||||
CHECK_EQ(F32, output_shape.element_type())
|
||||
<< ShapeUtil::HumanString(output_shape);
|
||||
} else if (std::is_same<T, Eigen::half>::value) {
|
||||
CHECK_EQ(F16, output_shape.element_type())
|
||||
<< ShapeUtil::HumanString(output_shape);
|
||||
} else {
|
||||
LOG(FATAL) << ShapeUtil::HumanString(output_shape);
|
||||
}
|
||||
|
||||
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
|
||||
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
|
||||
CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size());
|
||||
@ -220,5 +200,63 @@ Status RunCudnnConvolution(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
string CudnnConvKindToString(CudnnConvKind kind) {
|
||||
switch (kind) {
|
||||
case CudnnConvKind::kForward:
|
||||
return "forward";
|
||||
case CudnnConvKind::kBackwardFilter:
|
||||
return "backward_filter";
|
||||
case CudnnConvKind::kBackwardInput:
|
||||
return "backward_input";
|
||||
}
|
||||
}
|
||||
|
||||
Status RunCudnnConvolution(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
|
||||
perftools::gputools::DeviceMemoryBase filter_buf,
|
||||
perftools::gputools::DeviceMemoryBase output_buf,
|
||||
perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums,
|
||||
perftools::gputools::dnn::AlgorithmConfig algorithm,
|
||||
perftools::gputools::Stream* stream,
|
||||
perftools::gputools::dnn::ProfileResult* profile_result) {
|
||||
ScratchBufAllocator scratch_allocator(scratch_buf);
|
||||
return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
|
||||
input_buf, filter_buf, output_buf,
|
||||
&scratch_allocator, window, dnums, algorithm,
|
||||
stream, profile_result);
|
||||
}
|
||||
|
||||
Status RunCudnnConvolution(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
|
||||
perftools::gputools::DeviceMemoryBase filter_buf,
|
||||
perftools::gputools::DeviceMemoryBase output_buf,
|
||||
perftools::gputools::ScratchAllocator* scratch_allocator,
|
||||
const Window& window, const ConvolutionDimensionNumbers& dnums,
|
||||
perftools::gputools::dnn::AlgorithmConfig algorithm,
|
||||
perftools::gputools::Stream* stream,
|
||||
perftools::gputools::dnn::ProfileResult* profile_result) {
|
||||
PrimitiveType output_primitive_type = output_shape.element_type();
|
||||
CHECK(output_primitive_type == F32 || output_primitive_type == F16)
|
||||
<< ShapeUtil::HumanString(output_shape);
|
||||
if (output_primitive_type == F32) {
|
||||
return RunCudnnConvolution(
|
||||
kind, input_shape, filter_shape, output_shape,
|
||||
se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
|
||||
se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
|
||||
algorithm, stream, profile_result);
|
||||
}
|
||||
return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
|
||||
se::DeviceMemory<Eigen::half>(input_buf),
|
||||
se::DeviceMemory<Eigen::half>(filter_buf),
|
||||
se::DeviceMemory<Eigen::half>(output_buf),
|
||||
scratch_allocator, window, dnums, algorithm,
|
||||
stream, profile_result);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -55,7 +55,10 @@ string CudnnConvKindToString(CudnnConvKind kind);
|
||||
// Note that depending on the value of CudnnConvKind, the result of this call
|
||||
// may be written into input_buf, filter_buf, or output_buf!
|
||||
//
|
||||
// At the moment we only support cudnn convolutions over floats.
|
||||
// At the moment we only support cudnn convolutions over float and half, and
|
||||
// convolution with half data type is implemented with cudnn PSEUDO_HALF
|
||||
// configuration, that is, the input values are half and the internal
|
||||
// computation type is float.
|
||||
//
|
||||
// We provide one overload which takes a scratch buffer, and another which takes
|
||||
// an allocator which is responsible for allocating the scratch space. In
|
||||
@ -69,10 +72,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
|
||||
// that size, if you like.
|
||||
Status RunCudnnConvolution(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape,
|
||||
perftools::gputools::DeviceMemory<float> input_buf,
|
||||
perftools::gputools::DeviceMemory<float> filter_buf,
|
||||
perftools::gputools::DeviceMemory<float> output_buf,
|
||||
const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
|
||||
perftools::gputools::DeviceMemoryBase filter_buf,
|
||||
perftools::gputools::DeviceMemoryBase output_buf,
|
||||
perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums,
|
||||
perftools::gputools::dnn::AlgorithmConfig algorithm,
|
||||
@ -81,10 +83,9 @@ Status RunCudnnConvolution(
|
||||
|
||||
Status RunCudnnConvolution(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape,
|
||||
perftools::gputools::DeviceMemory<float> input_buf,
|
||||
perftools::gputools::DeviceMemory<float> filter_buf,
|
||||
perftools::gputools::DeviceMemory<float> output_buf,
|
||||
const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
|
||||
perftools::gputools::DeviceMemoryBase filter_buf,
|
||||
perftools::gputools::DeviceMemoryBase output_buf,
|
||||
perftools::gputools::ScratchAllocator* scratch_allocator,
|
||||
const Window& window, const ConvolutionDimensionNumbers& dnums,
|
||||
perftools::gputools::dnn::AlgorithmConfig algorithm,
|
||||
|
@ -1403,6 +1403,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
|
||||
break;
|
||||
}
|
||||
case F16: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
|
||||
MapImpl<Eigen::half>(map));
|
||||
break;
|
||||
}
|
||||
case F32: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
|
||||
break;
|
||||
@ -2041,9 +2046,7 @@ HloEvaluator::HloEvaluator() {
|
||||
});
|
||||
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
|
||||
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
|
||||
typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
|
||||
});
|
||||
typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
|
||||
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
|
||||
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
|
||||
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
|
||||
|
@ -53,157 +53,200 @@ class ConvolutionTest : public ClientLibraryTestBase {
|
||||
#endif
|
||||
};
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
|
||||
const int kInputActivationSizeY = 3;
|
||||
const int kInputActivationSizeX = 3;
|
||||
const int kInputActivationSizeZ = 256;
|
||||
const int kKernelSizeX = 2;
|
||||
const int kKernelSizeY = 2;
|
||||
const int kOutputActivationSizeZ = 256;
|
||||
const int kMiniBatchSize = 4;
|
||||
auto alhs =
|
||||
MakeUnique<Array4D<float>>(kMiniBatchSize, kInputActivationSizeZ,
|
||||
kInputActivationSizeY, kInputActivationSizeX);
|
||||
alhs->FillWithMultiples(1.0f);
|
||||
ASSERT_EQ(3, alhs->width());
|
||||
ASSERT_EQ(3, alhs->height());
|
||||
// TODO(b/72509305): Enable half data type tests for CPU
|
||||
#if (XLA_TEST_BACKEND_GPU)
|
||||
using TestTypes = ::testing::Types<float, Eigen::half>;
|
||||
#else
|
||||
using TestTypes = ::testing::Types<float>;
|
||||
#endif
|
||||
|
||||
auto arhs =
|
||||
MakeUnique<Array4D<float>>(kOutputActivationSizeZ, kInputActivationSizeZ,
|
||||
kKernelSizeY, kKernelSizeX);
|
||||
Array2D<float> rhs_raster({
|
||||
{1.0f, 0.0f}, // row 0
|
||||
{0.0f, 0.0f}, // row 1
|
||||
});
|
||||
arhs->FillWithYX(rhs_raster);
|
||||
ASSERT_EQ(2, arhs->width());
|
||||
ASSERT_EQ(2, arhs->height());
|
||||
template <typename T>
|
||||
Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice<int64> dimensions);
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
|
||||
auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
|
||||
auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
|
||||
|
||||
ComputeAndCompare(&builder, conv, {}, error_spec_);
|
||||
template <>
|
||||
Shape MakeShapeWrapper<float>(tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
return ShapeUtil::MakeShape(F32, dimensions);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<float> input_data(1, 1, 1, 2);
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2},
|
||||
}));
|
||||
Array4D<float> filter_data(1, 1, 1, 2);
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
}));
|
||||
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
template <>
|
||||
Shape MakeShapeWrapper<Eigen::half>(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
return ShapeUtil::MakeShape(F16, dimensions);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
const int kInputActivationSizeY = 3;
|
||||
const int kInputActivationSizeX = 3;
|
||||
const int kInputActivationSizeZ = 256;
|
||||
const int kKernelSizeX = 2;
|
||||
const int kKernelSizeY = 2;
|
||||
const int kOutputActivationSizeZ = 256;
|
||||
const int kMiniBatchSize = 4;
|
||||
auto alhs =
|
||||
MakeUnique<Array4D<T>>(kMiniBatchSize, kInputActivationSizeZ,
|
||||
kInputActivationSizeY, kInputActivationSizeX);
|
||||
alhs->FillWithMultiples(static_cast<T>(1.0f));
|
||||
ASSERT_EQ(3, alhs->width());
|
||||
ASSERT_EQ(3, alhs->height());
|
||||
|
||||
auto arhs =
|
||||
MakeUnique<Array4D<T>>(kOutputActivationSizeZ, kInputActivationSizeZ,
|
||||
kKernelSizeY, kKernelSizeX);
|
||||
Array2D<T> rhs_raster({
|
||||
{1.0f, 0.0f}, // row 0
|
||||
{0.0f, 0.0f}, // row 1
|
||||
});
|
||||
arhs->FillWithYX(rhs_raster);
|
||||
ASSERT_EQ(2, arhs->width());
|
||||
ASSERT_EQ(2, arhs->height());
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto lhs = builder.ConstantR4FromArray4D<T>(*alhs);
|
||||
auto rhs = builder.ConstantR4FromArray4D<T>(*arhs);
|
||||
auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
|
||||
|
||||
ComputeAndCompare(&builder, conv, {}, error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
|
||||
XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
|
||||
this->RunTest();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<T> input_data(1, 1, 1, 2);
|
||||
input_data.FillWithYX(Array2D<T>({
|
||||
{1.0f, 2.0f},
|
||||
}));
|
||||
Array4D<T> filter_data(1, 1, 1, 2);
|
||||
filter_data.FillWithYX(Array2D<T>({
|
||||
{5.0f, 6.0f},
|
||||
}));
|
||||
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
|
||||
TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
|
||||
|
||||
// Tests valid padding for 2D convolution in raster space.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
template <typename T>
|
||||
class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter_data(1, 1, 2, 2);
|
||||
// clang-format off
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
{7, 8},
|
||||
}));
|
||||
// clang-format on
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
Array4D<T> input_data(1, 1, 4, 4);
|
||||
input_data.FillWithYX(Array2D<T>({
|
||||
{1.0f, 2.0f, 3.0f, 4.0f},
|
||||
{5.0f, 6.0f, 7.0f, 8.0f},
|
||||
{9.0f, 10.0f, 11.0f, 12.0f},
|
||||
{13.0f, 14.0f, 15.0f, 16.0f},
|
||||
}));
|
||||
Array4D<T> filter_data(1, 1, 2, 2);
|
||||
filter_data.FillWithYX(Array2D<T>({
|
||||
{5.0f, 6.0f},
|
||||
{7.0f, 8.0f},
|
||||
}));
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
|
||||
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
|
||||
|
||||
// Tests same padding for 2D convolution in raster space.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
template <typename T>
|
||||
class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter_data(1, 1, 2, 2);
|
||||
// clang-format off
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
{7, 8},
|
||||
}));
|
||||
// clang-format on
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
Array4D<T> input_data(1, 1, 4, 4);
|
||||
input_data.FillWithYX(Array2D<T>({
|
||||
{1.0f, 2.0f, 3.0f, 4.0f},
|
||||
{5.0f, 6.0f, 7.0f, 8.0f},
|
||||
{9.0f, 10.0f, 11.0f, 12.0f},
|
||||
{13.0f, 14.0f, 15.0f, 16.0f},
|
||||
}));
|
||||
Array4D<T> filter_data(1, 1, 2, 2);
|
||||
filter_data.FillWithYX(Array2D<T>({
|
||||
{5.0f, 6.0f},
|
||||
{7.0f, 8.0f},
|
||||
}));
|
||||
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
|
||||
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
|
||||
|
||||
// Tests same padding for 2D convolution in raster space with an odd sized
|
||||
// kernel.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
template <typename T>
|
||||
class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 1, 3, 3});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter_data(1, 1, 3, 3);
|
||||
// clang-format off
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{ 5, 6, 7},
|
||||
{ 8, 9, 10},
|
||||
{11, 12, 13},
|
||||
}));
|
||||
// clang-format on
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
Array4D<T> input_data(1, 1, 4, 4);
|
||||
input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
|
||||
{5.0f, 6.0f, 7.0f, 8.0f},
|
||||
{9.0f, 10.0f, 11.0f, 12.0f},
|
||||
{13.0f, 14.0f, 15.0f, 16.0f}}));
|
||||
Array4D<T> filter_data(1, 1, 3, 3);
|
||||
filter_data.FillWithYX(Array2D<T>(
|
||||
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
|
||||
// clang-format on
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{std::move(*Literal::CreateFromArray(input_data)),
|
||||
std::move(*Literal::CreateFromArray(filter_data))},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
|
||||
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -232,36 +275,44 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithRHSDilation) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
builder.ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
builder.ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{570, 670, 770}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -325,36 +376,45 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithPadding) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
builder.ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
|
||||
Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
builder.ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected(
|
||||
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{0, 260, 510, 610, 710, 810, 350, 0}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -389,12 +449,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
||||
}
|
||||
|
||||
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
|
||||
std::iota(input_elems.begin(), input_elems.end(), 1.0f);
|
||||
iota(input_elems.begin(), input_elems.end(), 1.0f);
|
||||
auto input_r1 = Literal::CreateR1<float>(input_elems);
|
||||
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
|
||||
std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
|
||||
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
|
||||
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
|
||||
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
@ -412,57 +472,74 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
std::vector<int64> input_dims = {1, 3, 3, 5};
|
||||
std::vector<int64> filter_dims = {3, 3, 5, 3};
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
|
||||
{
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
|
||||
// Tensorflow dimension numbers for 2D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.add_input_spatial_dimensions(2);
|
||||
dnums.add_output_spatial_dimensions(2);
|
||||
dnums.set_input_feature_dimension(3);
|
||||
dnums.set_output_feature_dimension(3);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.add_kernel_spatial_dimensions(1);
|
||||
dnums.set_kernel_input_feature_dimension(2);
|
||||
dnums.set_kernel_output_feature_dimension(3);
|
||||
|
||||
builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
|
||||
dnums);
|
||||
}
|
||||
|
||||
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
|
||||
std::iota(input_elems.begin(), input_elems.end(), 1.0f);
|
||||
auto input_r1 = Literal::CreateR1<float>(input_elems);
|
||||
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
|
||||
std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
|
||||
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
|
||||
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
auto expected_r1 = Literal::CreateR1<float>({92115, 93150, 94185});
|
||||
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
|
||||
|
||||
auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareLiteral(&builder, *expected_r4,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
// std::iota doesn't work when init_value has a type Eigen::half in some build
|
||||
// servers. The error message is missing the operator ++.
|
||||
template <typename T>
|
||||
void iota_int_init_value(std::vector<T>& values, int init_value) {
|
||||
std::for_each(values.begin(), values.end(),
|
||||
[&](T& value) { value = static_cast<T>(init_value++); });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
std::vector<int64> input_dims = {1, 3, 3, 5};
|
||||
std::vector<int64> filter_dims = {3, 3, 5, 3};
|
||||
Shape input_shape = MakeShapeWrapper<T>(input_dims);
|
||||
Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
|
||||
{
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
|
||||
// Tensorflow dimension numbers for 2D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.add_input_spatial_dimensions(2);
|
||||
dnums.add_output_spatial_dimensions(2);
|
||||
dnums.set_input_feature_dimension(3);
|
||||
dnums.set_output_feature_dimension(3);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.add_kernel_spatial_dimensions(1);
|
||||
dnums.set_kernel_input_feature_dimension(2);
|
||||
dnums.set_kernel_output_feature_dimension(3);
|
||||
|
||||
builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
|
||||
dnums);
|
||||
}
|
||||
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
|
||||
iota_int_init_value(input_elems, 1);
|
||||
auto input_r1 = Literal::CreateR1<T>(input_elems);
|
||||
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
|
||||
iota_int_init_value(filter_elems, 1);
|
||||
auto filter_r1 = Literal::CreateR1<T>(filter_elems);
|
||||
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
auto expected_r1 = Literal::CreateR1<T>(
|
||||
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
|
||||
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareLiteral(&builder, *expected_r4,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes);
|
||||
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); }
|
||||
|
||||
// Test fixture to run convolution tests with and without convolution
|
||||
// canonicalization enabled.
|
||||
class ConvolveWithAndWithoutCanonicalization
|
||||
@ -519,67 +596,78 @@ struct Convolve1DTestParam {
|
||||
int64 num_windows;
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTest
|
||||
class Convolve1D1WindowTestBase
|
||||
: public ConvolutionTest,
|
||||
public ::testing::WithParamInterface<Convolve1DTestParam> {};
|
||||
public ::testing::WithParamInterface<Convolve1DTestParam> {
|
||||
protected:
|
||||
template <typename T>
|
||||
void TestImpl() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
int64 input_feature = GetParam().input_feature;
|
||||
int64 output_feature = GetParam().output_feature;
|
||||
int64 batch = GetParam().batch;
|
||||
int64 num_windows = GetParam().num_windows;
|
||||
int64 window_size = GetParam().window_size;
|
||||
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
||||
input_feature};
|
||||
std::vector<int64> filter_dims = {window_size, input_feature,
|
||||
output_feature};
|
||||
Shape input_shape = MakeShapeWrapper<T>(input_dims);
|
||||
Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
|
||||
{
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
int64 input_feature = GetParam().input_feature;
|
||||
int64 output_feature = GetParam().output_feature;
|
||||
int64 batch = GetParam().batch;
|
||||
int64 num_windows = GetParam().num_windows;
|
||||
int64 window_size = GetParam().window_size;
|
||||
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
||||
input_feature};
|
||||
std::vector<int64> filter_dims = {window_size, input_feature, output_feature};
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
|
||||
{
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
// Tensorflow dimension numbers for 1D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.set_input_feature_dimension(2);
|
||||
dnums.set_output_feature_dimension(2);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.set_kernel_input_feature_dimension(1);
|
||||
dnums.set_kernel_output_feature_dimension(2);
|
||||
|
||||
// Tensorflow dimension numbers for 1D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.set_input_feature_dimension(2);
|
||||
dnums.set_output_feature_dimension(2);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.set_kernel_input_feature_dimension(1);
|
||||
dnums.set_kernel_output_feature_dimension(2);
|
||||
builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
|
||||
dnums);
|
||||
}
|
||||
|
||||
builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
|
||||
dnums);
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
|
||||
static_cast<T>(1.0f));
|
||||
auto input_r1 = Literal::CreateR1<T>(input_elems);
|
||||
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
|
||||
static_cast<T>(1.0f));
|
||||
|
||||
auto filter_r1 = Literal::CreateR1<T>(filter_elems);
|
||||
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> expect_elems(batch * output_feature * num_windows,
|
||||
static_cast<T>(window_size * input_feature));
|
||||
auto expected_r1 = Literal::CreateR1<T>(expect_elems);
|
||||
auto expected_r3 =
|
||||
expected_r1->Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, *expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape), 1.0);
|
||||
auto input_r1 = Literal::CreateR1<float>(input_elems);
|
||||
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
|
||||
|
||||
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0);
|
||||
|
||||
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
|
||||
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> expect_elems(batch * output_feature * num_windows,
|
||||
window_size * input_feature);
|
||||
auto expected_r1 = Literal::CreateR1<float>(expect_elems);
|
||||
auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, *expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest,
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
@ -608,6 +696,48 @@ INSTANTIATE_TEST_CASE_P(
|
||||
|
||||
);
|
||||
|
||||
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
|
||||
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
|
||||
|
||||
// TODO(b/72509305): Enable half data type tests for CPU.
|
||||
XLA_TEST_P(Convolve1D1WindowTestHalf,
|
||||
DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(Convolve1D1Window))) {
|
||||
TestImpl<Eigen::half>();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 3},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
// TODO(b/72566306): the following three tests fail on CPU
|
||||
// backend due to result miscompare.
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1})
|
||||
|
||||
);
|
||||
#endif
|
||||
|
||||
TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
|
||||
|
@ -161,4 +161,31 @@ string PrependDisabledIfIndicated(const string& test_case_name,
|
||||
|
||||
#define XLA_TEST_P(test_case_name, test_name) \
|
||||
XLA_TEST_P_IMPL_(test_case_name, test_name)
|
||||
|
||||
// This is identical to the TEST_F macro from "gtest", but it potentially
|
||||
// disables the test based on an external manifest file, DISABLED_MANIFEST.
|
||||
#define XLA_TYPED_TEST(CaseName, TestName) \
|
||||
template <typename gtest_TypeParam_> \
|
||||
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
|
||||
: public CaseName<gtest_TypeParam_> { \
|
||||
private: \
|
||||
typedef CaseName<gtest_TypeParam_> TestFixture; \
|
||||
typedef gtest_TypeParam_ TypeParam; \
|
||||
virtual void TestBody(); \
|
||||
}; \
|
||||
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
|
||||
::testing::internal::TypeParameterizedTest< \
|
||||
CaseName, \
|
||||
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
|
||||
TestName)>, \
|
||||
GTEST_TYPE_PARAMS_(CaseName)>:: \
|
||||
Register( \
|
||||
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
|
||||
#CaseName, \
|
||||
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
|
||||
0); \
|
||||
template <typename gtest_TypeParam_> \
|
||||
void GTEST_TEST_CLASS_NAME_(CaseName, \
|
||||
TestName)<gtest_TypeParam_>::TestBody()
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user