STT-tensorflow/tensorflow/compiler/xla/tests/collective_ops_test.cc
George Karpenkov 134dcd1302 [XLA:CPU] CollectivePermute support on CPU
Dummy implementation, main thread performs all the work.

PiperOrigin-RevId: 305184516
Change-Id: Ib4af0b7fda920fe08b551cb0782884ba92947ba7
2020-04-06 22:07:56 -07:00

645 lines
23 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
// Tests cross-device collective operations (all-reduce, replica-id and
// collective-permute).
//
// When run on GPU, this test requires at least four GPUs. For instructions on
// running this within Google, see go/multi-gpu-unit-test.
namespace xla {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
class CollectiveOpsTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> MakeCrsModule(
const Shape& shape, std::vector<std::vector<int64>> replica_groups,
const HloModuleConfig& config, std::string op = "add",
std::string datatype = "f32") {
std::string hlo_template = R"(
HloModule test
apply_op {
x = DATATYPE[] parameter(0)
y = DATATYPE[] parameter(1)
ROOT apply_op = DATATYPE[] OP(x, y)
}
ENTRY test_computation {
p = SHAPE parameter(0)
p2 = SHAPE bitcast(p)
crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
copy = SHAPE copy(crs)
ROOT out = SHAPE bitcast(copy)
}
)";
std::vector<string> replica_group_strs;
for (const auto& g : replica_groups) {
replica_group_strs.push_back(
absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
}
std::string shape_str = shape.ToString(/*print_layout=*/false);
if (shape_str == "f32[1]") {
// Exercise the scalar codepath.
hlo_template = absl::StrReplaceAll(
hlo_template,
{{"DATATYPE[SHAPE] bitcast(p)", "DATATYPE[] bitcast(p)"},
{"DATATYPE[SHAPE] all-reduce", "DATATYPE[] all-reduce"},
{"DATATYPE[SHAPE] copy", "DATATYPE[] copy"}});
}
std::string parameterized_hlo = absl::StrReplaceAll(
hlo_template,
{{"SHAPE", shape_str},
{"REPLICA_GROUPS",
absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
{"OP", op},
{"DATATYPE", datatype}});
return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
}
template <typename LiteralType>
void TestTwoReplicasOneOperand(std::string op, Literal input_value,
Literal expected_value) {
const int kNumReplicas = 2;
std::string dtype = primitive_util::LowercasePrimitiveTypeName(
primitive_util::NativeToPrimitiveType<LiteralType>());
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
auto module = MakeCrsModule(
/*shape_str=*/input_value.shape(),
/*replica_groups=*/{}, config,
/*op=*/op, /*datatype=*/dtype);
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input_value},
/*num_replicas=*/kNumReplicas,
/*use_threads=*/true));
for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
}
}
template <typename LiteralType>
void TestAllOps() {
auto cast = [&](int value) { return static_cast<LiteralType>(value); };
auto to_literal = [&](absl::Span<const LiteralType> values) {
return LiteralUtil::CreateR1<LiteralType>(values);
};
Literal input_value = to_literal({cast(1), cast(2), cast(3)});
TestTwoReplicasOneOperand<LiteralType>(
"add",
/*input_value=*/input_value.Clone(),
/*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
TestTwoReplicasOneOperand<LiteralType>(
"multiply",
/*input_value=*/input_value.Clone(),
/*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
TestTwoReplicasOneOperand<LiteralType>(
"maximum",
/*input_value=*/input_value.Clone(),
/*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
TestTwoReplicasOneOperand<LiteralType>(
"minimum",
/*input_value=*/input_value.Clone(),
/*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
}
};
// Returns the non-empty subsets of {0, 1, ..., n-1}, ordered by the bitmask
// that encodes each subset (bit j set <=> j is in the subset). For example,
// PowerSetOfIota(3) = {{0}, {1}, {0,1}, {2}, {0,2}, {1,2}, {0,1,2}}.
std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
  std::vector<std::vector<int64>> power_set;
  // Shift a 64-bit one: `1 << n` is an int shift and overflows for n >= 31.
  const int64 num_masks = int64{1} << n;
  for (int64 mask = 1; mask < num_masks; ++mask) {
    power_set.emplace_back();
    for (int64 j = 0; j < n; ++j) {
      if (mask & (int64{1} << j)) {
        power_set.back().push_back(j);
      }
    }
  }
  return power_set;
}
// Makes a single-computation DeviceAssignment that assigns replica-id i to
// devices[i].
DeviceAssignment MakeDeviceAssn(const std::vector<int64>& devices) {
  DeviceAssignment assn(/*replica_count=*/devices.size(),
                        /*computation_count=*/1);
  // size_t index avoids a signed/unsigned comparison with devices.size().
  for (size_t replica = 0; replica < devices.size(); ++replica) {
    assn(replica, 0) = devices[replica];
  }
  return assn;
}
// Short alias for gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels():
// returns the devices that thunk reports as having an open NCCL channel.
absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
  absl::flat_hash_set<GlobalDeviceId> open_channels =
      gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
  return open_channels;
}
// Converts `value` to Eigen::half with a static_cast.
// NOTE(review): not referenced in the visible part of this file.
template <typename T>
static Eigen::half ToHalf(T value) {
  const Eigen::half as_half = static_cast<Eigen::half>(value);
  return as_half;
}
// Sums a 2x2 f32 matrix across two replicas; since both replicas contribute
// the same matrix, every element is doubled.
XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
  Literal input = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  Literal expected = LiteralUtil::CreateR2<float>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<float>("add", std::move(input),
                                   std::move(expected));
}
// Sums a single-element f32 vector across two replicas. MakeCrsModule
// special-cases f32[1] shapes to exercise the scalar all-reduce path.
XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
  Literal input = LiteralUtil::CreateR1<float>({1});
  Literal expected = LiteralUtil::CreateR1<float>({2});
  TestTwoReplicasOneOperand<float>("add", std::move(input),
                                   std::move(expected));
}
// Each test below runs the add/multiply/maximum/minimum all-reduce suite from
// CollectiveOpsTest::TestAllOps for one element type.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
  this->TestAllOps<int8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
  this->TestAllOps<uint8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
  this->TestAllOps<uint32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
  this->TestAllOps<int32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
  this->TestAllOps<int64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
  this->TestAllOps<uint64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
  this->TestAllOps<float>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
  this->TestAllOps<double>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
  this->TestAllOps<Eigen::half>();
}
// Logical-AND all-reduce over pred values on two replicas.
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  constexpr int kNumReplicas = 2;

  // Both replicas contribute {true, false}, so AND-ing element-wise leaves
  // the input unchanged.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Each replica contributes (replica-id == 0): replica 0 sends true and
  // replica 1 sends false, so the AND over replicas is false.
  const char* hlo_module = R"(
HloModule test
apply_op {
x = pred[] parameter(0)
y = pred[] parameter(1)
ROOT apply_op = pred[] and(x, y)
}
ENTRY test_computation {
id = u32[] replica-id()
c = u32[] constant(0)
p = pred[] compare(id, c), direction=EQ
p2 = pred[1] bitcast(p)
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
copy = pred[1] copy(crs)
ROOT out = pred[1] bitcast(copy)
}
)";
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  const Literal expected = LiteralUtil::CreateR1<bool>({false});
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, results[replica]));
  }
}
// Logical-OR all-reduce over pred values on two replicas.
XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  constexpr int kNumReplicas = 2;

  // Both replicas contribute {true, false}, so OR-ing element-wise leaves the
  // input unchanged.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Each replica contributes (replica-id == 0): replica 0 sends true and
  // replica 1 sends false, so the OR over replicas is true.
  const char* hlo_module = R"(
HloModule test
apply_op {
x = pred[] parameter(0)
y = pred[] parameter(1)
ROOT apply_op = pred[] or(x, y)
}
ENTRY test_computation {
id = u32[] replica-id()
c = u32[] constant(0)
p = pred[] compare(id, c), direction=EQ
p2 = pred[1] bitcast(p)
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
copy = pred[1] copy(crs)
ROOT out = pred[1] bitcast(copy)
}
)";
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  const Literal expected = LiteralUtil::CreateR1<bool>({true});
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, results[replica]));
  }
}
// Runs an all-reduce on each of the 2^kNumDevices - 1 non-empty subsets of the
// devices, one subset at a time, and checks every replica's result.
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
  const int64 kNumDevices = 4;
  const int64 kNumElems = 1024;
  for (const std::vector<int64>& devices : PowerSetOfIota(kNumDevices)) {
    SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
                                 absl::StrJoin(devices, ", ")));
    DeviceAssignment device_assn = MakeDeviceAssn(devices);
    auto config = GetModuleConfigForTest();
    config.set_replica_count(devices.size());
    config.set_static_device_assignment(device_assn);
    std::vector<float> input_vec(kNumElems);
    absl::c_iota(input_vec, 0);
    auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
    auto module = MakeCrsModule(input_literal.shape(),
                                /*replica_groups=*/{}, config);
    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<Literal> results,
        ExecuteReplicated(std::move(module), {&input_literal},
                          /*num_replicas=*/devices.size(), &device_assn,
                          /*run_hlo_passes=*/true, /*use_threads=*/true));

    // Every replica contributes the same input, so the sum over all
    // participating replicas is the input scaled by the replica count. The
    // values are small integers, so the float results are exact.
    const float scale = static_cast<float>(devices.size());
    std::vector<float> expected_vec;
    expected_vec.reserve(input_vec.size());
    for (float x : input_vec) {
      expected_vec.push_back(x * scale);
    }
    auto expected_literal = LiteralUtil::CreateR1<float>(expected_vec);
    ASSERT_EQ(results.size(), devices.size());
    for (const Literal& result : results) {
      EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_literal, result,
                                               ErrorSpec{1e-5, 1e-5}));
    }
  }
}
// Check that the NCCL data structures in our all-reduce implementation are
// cached as we expect.
//
// Compiles three executables over device sets {0,1}, {1,2} and {0,1,2}, runs
// them, then destroys them one at a time, checking after each step which
// devices OpenNcclChannels() reports. The test inspects
// gpu::NcclAllReduceThunk state and is disabled on CPU.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
const int64 kNumElems = 1024;
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
// Initially no NCCL channels should be open.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
// Each entry bundles an executable with the device assignment and options
// needed to run it later.
struct ExecutableInfo {
std::unique_ptr<Executable> executable;
DeviceAssignment device_assn;
HloRunner::ReplicatedExecuteOptions opts;
};
std::vector<ExecutableInfo> executables;
for (const auto& devices :
std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
executables.emplace_back();
auto& e = executables.back();
e.device_assn = MakeDeviceAssn(devices);
auto config = GetModuleConfigForTest();
config.set_replica_count(devices.size());
config.set_static_device_assignment(e.device_assn);
auto module = MakeCrsModule(input_literal.shape(),
/*replica_groups=*/{}, config);
e.executable =
test_runner_
.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
.ValueOrDie();
e.opts.num_replicas = devices.size();
e.opts.use_threads = true;
e.opts.arguments.push_back(&input_literal);
}
// Runs executables[i] with its stored options and device assignment.
// NOTE(review): TF_ASSERT_OK inside this lambda only returns from the lambda
// on failure; the enclosing test keeps going after a failed run.
auto run_executable = [&](int64 i) {
auto& e = executables[i];
TF_ASSERT_OK(
test_runner_
.ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
.status());
};
// Compiling executables above shouldn't cause us to open any channels.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Run the executables and check that channels are opened as we expect.
run_executable(0);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
run_executable(2);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
// Tear down the executables and check that channels are closed as we expect.
// Note that after we tear down an executable *all* the nccl channels may go
// away, so we rerun all of the executables that haven't been torn down.
executables[2].executable.reset();
run_executable(0);
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
executables[0].executable.reset();
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
executables[1].executable.reset();
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
}
// Runs the same executable many times concurrently. The all-reduces should not
// conflict with one another.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
  const int64 kNumElems = 1024;
  const int64 kNumThreads = 200;
  const int64 kRunsPerThread = 10;
  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto executable =
      test_runner_
          .CreateExecutable(MakeCrsModule(input_literal.shape(),
                                          /*replica_groups=*/{}, config),
                            /*run_hlo_passes=*/true)
          .ValueOrDie();
  std::vector<int64> devices = {0, 1};
  auto device_assn = MakeDeviceAssn(devices);
  HloRunner::ReplicatedExecuteOptions opts;
  opts.num_replicas = devices.size();
  opts.use_threads = true;
  opts.arguments.push_back(&input_literal);
  tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
                                      kNumThreads);
  for (int64 i = 0; i < kNumThreads * kRunsPerThread; ++i) {
    pool.Schedule([&] {
      // EXPECT, not ASSERT: an ASSERT failure would return from this closure
      // before DecrementCount(), leaving done.Wait() blocked forever.
      TF_EXPECT_OK(
          test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
              .status());
      done.DecrementCount();
    });
  }
  done.Wait();
}
// Runs two independent all-reduces over different parameters in one
// computation (returned as a tuple) and checks that every replica receives
// both sums.
XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
  std::string hlo_string = R"(
HloModule test
apply_op {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT apply_op = f32[] add(x, y)
}
ENTRY test_computation {
p0 = f32[5] parameter(0)
p1 = f32[5] parameter(1)
crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
}
)";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));
  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  // Both replicas feed identical inputs, so each sum is the input doubled.
  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  // Guard the indexing below so a short result fails cleanly.
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    ASSERT_EQ(rs.size(), 2u);
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}
// Runs an all-reduce with three replica groups:
//   {0}, {1,2}, {3}
// meaning the all-reduce is a no-op for replicas 0 and 3, and only replicas 1
// and 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
  // Test a prime number so it's not all powers of 2.
  const int64 kNumElems = 137;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(4);
  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
  auto module = MakeCrsModule(
      /*shape=*/input_literal.shape(),
      /*replica_groups=*/{{0}, {1, 2}, {3}}, config);
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
                        /*use_threads=*/true));
  ASSERT_EQ(results.size(), 4);
  // Replicas 1 and 2 sum two identical inputs, i.e. the input doubled.
  std::vector<float> input_vec_doubled;
  input_vec_doubled.reserve(input_vec.size());
  for (float n : input_vec) {
    input_vec_doubled.push_back(n * 2);
  }
  auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
}
// Checks that replica-id evaluates to 0, 1, ..., kNumReplicas - 1 on the
// respective replicas.
XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
  const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
id = u32[] replica-id()
ROOT out = u32[] copy(id)
}
)";
  const int64 kNumReplicas = 4;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Parse with `config` so the module is configured for kNumReplicas
  // replicas, matching the number it is executed on.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (uint32 i = 0; i < kNumReplicas; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
  }
}
// Each replica broadcasts (replica-id + 10) into a u32[2] and sends it along
// the pairs 1->0, 0->1 and 2->2; checks what every replica receives.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
  const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
replica = u32[] replica-id()
ten = u32[] constant(10)
sum = u32[] add(replica, ten)
p = u32[2] broadcast(sum), dimensions={}
ROOT permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
}
)";
  const int64 kNumReplicas = 4;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // Value broadcast into each replica's output, indexed by replica:
  //   replica 0 receives from 1 (11), replica 1 receives from 0 (10),
  //   replica 2 receives from itself (12). Nothing writes to replica 3, so it
  //   is memzero'ed.
  const uint32 kExpected[kNumReplicas] = {11, 10, 12, 0};
  for (int64 replica = 0; replica < kNumReplicas; ++replica) {
    const uint32 value = kExpected[replica];
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR1<uint32>({value, value}), results[replica]));
  }
}
// A single variadic all-reduce over two operands of different lengths
// (f32[5] and f32[7]); every replica should receive both sums.
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  std::string hlo_string = R"(
HloModule test
apply_op {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT apply_op = f32[] add(x, y)
}
ENTRY test_computation {
p0 = f32[5] parameter(0)
p1 = f32[7] parameter(1)
ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
}
)";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  // Both replicas feed the same operands.
  auto first_input =
      LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f, 4.f, 5.f});
  auto second_input =
      LiteralUtil::CreateR1<float>({7.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&first_input, &second_input},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));

  // With two identical contributions, each sum is its operand doubled.
  auto first_expected =
      LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f, 8.f, 10.f});
  auto second_expected =
      LiteralUtil::CreateR1<float>({14.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f});
  const ErrorSpec kTolerance{1e-5, 1e-5};
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    auto parts = results[replica].DecomposeTuple();
    EXPECT_TRUE(
        LiteralTestUtil::NearOrEqual(first_expected, parts[0], kTolerance));
    EXPECT_TRUE(
        LiteralTestUtil::NearOrEqual(second_expected, parts[1], kTolerance));
  }
}
} // namespace
} // namespace xla