[XLA] Small test that performs A*B+A and A*B+B.
PiperOrigin-RevId: 157492992
This commit is contained in:
parent
367ec84f8c
commit
5cf4845840
@ -158,6 +158,59 @@ TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); }
|
|||||||
|
|
||||||
TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
|
TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
|
||||||
|
|
||||||
|
class MatOpsDotAddTest
|
||||||
|
: public ClientLibraryTestBase,
|
||||||
|
public ::testing::WithParamInterface<std::tuple<bool, bool>> {};
|
||||||
|
|
||||||
|
TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) {
|
||||||
|
bool row_major = std::get<0>(GetParam());
|
||||||
|
bool add_lhs = std::get<1>(GetParam());
|
||||||
|
Array2D<float> lhs({{1.0, 2.0}, {3.0, 4.0}});
|
||||||
|
Array2D<float> rhs({{10.0, 11.0}, {12.0, 13.0}});
|
||||||
|
|
||||||
|
auto minor_to_major = [](bool row_major) -> std::vector<int64> {
|
||||||
|
return {row_major ? 1 : 0, row_major ? 0 : 1};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto prim_type = primitive_util::NativeToPrimitiveType<float>();
|
||||||
|
Shape lhs_shape =
|
||||||
|
ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()});
|
||||||
|
Shape rhs_shape =
|
||||||
|
ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()});
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(
|
||||||
|
auto lhs_handle,
|
||||||
|
client_->TransferToServer(
|
||||||
|
*LiteralUtil::CreateR2FromArray2DWithLayout<float>(
|
||||||
|
lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(
|
||||||
|
auto rhs_handle,
|
||||||
|
client_->TransferToServer(
|
||||||
|
*LiteralUtil::CreateR2FromArray2DWithLayout<float>(
|
||||||
|
rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
|
||||||
|
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
|
||||||
|
auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs");
|
||||||
|
auto result = builder.Dot(lhs_arg, rhs_arg);
|
||||||
|
Array2D<float> expected;
|
||||||
|
if (add_lhs) {
|
||||||
|
result = builder.Add(result, lhs_arg);
|
||||||
|
expected = Array2D<float>({{35, 39}, {81, 89}});
|
||||||
|
} else {
|
||||||
|
result = builder.Add(result, rhs_arg);
|
||||||
|
expected = Array2D<float>({{44, 48}, {90, 98}});
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected,
|
||||||
|
{lhs_handle.get(), rhs_handle.get()},
|
||||||
|
ErrorSpec(1e-6));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest,
|
||||||
|
::testing::Combine(::testing::Bool(),
|
||||||
|
::testing::Bool()));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user