Add additional concat test.
PiperOrigin-RevId: 157844113
This commit is contained in:
parent
f661128dbf
commit
d5421cf58e
@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
Array3D<float> arr0(9, 17, 1);
|
||||
arr0.Fill(1);
|
||||
|
||||
Array3D<float> arr1(9, 17, 256);
|
||||
arr1.Fill(2);
|
||||
|
||||
Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
|
||||
for (int64 i = 0; i < expected.n1(); ++i) {
|
||||
for (int64 j = 0; j < expected.n2(); ++j) {
|
||||
int64 kk = 0;
|
||||
for (const Array3D<float>& arr : {arr0, arr1}) {
|
||||
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
|
||||
expected(i, j, kk) = arr(i, j, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ComputationDataHandle h0;
|
||||
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
|
||||
&builder, &h0);
|
||||
ComputationDataHandle h1;
|
||||
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
|
||||
&builder, &h1);
|
||||
|
||||
auto concatenated = builder.ConcatInDim({h0, h1}, 2);
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
|
||||
}
|
||||
|
||||
// Describes a binary rank-2 concatenation test.
|
||||
struct R2BinarySpec {
|
||||
int64 lhs_dim0;
|
||||
|
Loading…
Reference in New Issue
Block a user