From 5cf48458408f51202ff43fe748c6a98385f1be69 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 30 May 2017 12:00:47 -0700 Subject: [PATCH] [XLA] Small test that performs A*B+A and A*B+B. PiperOrigin-RevId: 157492992 --- .../xla/tests/matrix_ops_simple_test.cc | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 8aa40294406..e2dd12bf066 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -158,6 +158,59 @@ TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); } TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); } +class MatOpsDotAddTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + Array2D lhs({{1.0, 2.0}, {3.0, 4.0}}); + Array2D rhs({{10.0, 11.0}, {12.0, 13.0}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSIGN_OR_ASSERT_OK( + auto lhs_handle, + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSIGN_OR_ASSERT_OK( + auto rhs_handle, + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + ComputationBuilder builder(client_, TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + expected = Array2D({{35, 39}, {81, 89}}); + } else { + result = builder.Add(result, rhs_arg); + expected = Array2D({{44, 48}, {90, 98}}); + } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); +} + +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, + ::testing::Combine(::testing::Bool(), + ::testing::Bool())); + } // namespace } // namespace xla