Support multiple index-space constraints in synthetic input generator.
PiperOrigin-RevId: 208868489
This commit is contained in:
parent
bc646fd576
commit
75399bba46
@ -2076,6 +2076,8 @@ tf_cc_test(
|
||||
xla_test(
|
||||
name = "test_utils_test",
|
||||
srcs = ["test_utils_test.cc"],
|
||||
# There is nothing backend specific in this test, so just pick an arbitrary backend.
|
||||
backends = ["cpu"],
|
||||
deps = [
|
||||
":local_client_test_base",
|
||||
":test_utils",
|
||||
|
@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
|
||||
|
||||
// Generate random values that are constrained to the input_shape minus the
|
||||
// output_shape so as not to produce wrapping slices, for instance.
|
||||
std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
|
||||
const Shape& input_shape, const Shape& slice_shape,
|
||||
std::minstd_rand0* engine) {
|
||||
const int64 rank = ShapeUtil::Rank(input_shape);
|
||||
std::vector<int32> start_indices(rank);
|
||||
std::unique_ptr<Literal> MakeRandomIndex(
|
||||
tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
|
||||
std::vector<int32> start_indices(index_space.size());
|
||||
if (engine != nullptr) {
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
|
||||
ShapeUtil::GetDimension(slice_shape, i);
|
||||
std::uniform_int_distribution<int32> generator(0, upper_bound);
|
||||
for (int i = 0; i < index_space.size(); ++i) {
|
||||
std::uniform_int_distribution<int32> generator(0, index_space[i]);
|
||||
start_indices[i] = generator(*engine);
|
||||
}
|
||||
}
|
||||
@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
|
||||
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
|
||||
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
|
||||
const HloInstruction& param, std::minstd_rand0* engine) {
|
||||
HloInstruction* needs_index = nullptr;
|
||||
HloInstruction* needs_constant = nullptr;
|
||||
std::vector<int64> index_space;
|
||||
bool needs_constant = false;
|
||||
ConstantType constant_type = ConstantType::kUnknown;
|
||||
for (HloInstruction* use : constrained_uses) {
|
||||
switch (use->opcode()) {
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
if (needs_index != nullptr) {
|
||||
auto needs_index_shape = needs_index->shape();
|
||||
auto use_shape = use->shape();
|
||||
if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
|
||||
needs_index_shape = needs_index->operand(0)->shape();
|
||||
case HloOpcode::kDynamicUpdateSlice: {
|
||||
const Shape& indexed_shape = use->operand(0)->shape();
|
||||
const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
|
||||
? use->shape()
|
||||
: use->operand(1)->shape();
|
||||
const int64 rank = ShapeUtil::Rank(indexed_shape);
|
||||
if (!index_space.empty()) {
|
||||
TF_RET_CHECK(rank == index_space.size());
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
index_space[i] = std::min(
|
||||
index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
|
||||
ShapeUtil::GetDimension(slice_shape, i));
|
||||
}
|
||||
if (use->opcode() == HloOpcode::kDynamicSlice) {
|
||||
use_shape = use->operand(0)->shape();
|
||||
}
|
||||
if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
|
||||
return Unimplemented(
|
||||
"Conflicting operand generation slice index constraints\n");
|
||||
} else {
|
||||
index_space.resize(rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
|
||||
ShapeUtil::GetDimension(slice_shape, i);
|
||||
}
|
||||
}
|
||||
needs_index = use;
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
needs_constant = use;
|
||||
needs_constant = true;
|
||||
constant_type = GetInitValue(*use->to_apply());
|
||||
break;
|
||||
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
needs_constant = use;
|
||||
needs_constant = true;
|
||||
constant_type = GetInitValue(*use->scatter());
|
||||
break;
|
||||
|
||||
@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
|
||||
use->ToString().c_str());
|
||||
}
|
||||
}
|
||||
if (needs_index != nullptr && needs_constant != nullptr) {
|
||||
if (!index_space.empty() && needs_constant) {
|
||||
return Unimplemented(
|
||||
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
|
||||
"constant: %s\n",
|
||||
needs_index->ToString().c_str(), needs_constant->ToString().c_str());
|
||||
"Conflicting operand generation constraints. Dynamically indexes a "
|
||||
"shape and is the init value of a reduction.");
|
||||
}
|
||||
if (needs_index != nullptr) {
|
||||
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
|
||||
needs_index->shape(), engine);
|
||||
} else if (needs_constant != nullptr) {
|
||||
if (!index_space.empty()) {
|
||||
return MakeRandomIndex(index_space, engine);
|
||||
} else if (needs_constant) {
|
||||
switch (constant_type) {
|
||||
case ConstantType::kZero:
|
||||
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
|
||||
@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
|
||||
std::vector<std::unique_ptr<Literal>> arguments(params.size());
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
|
||||
*dataflow, *params[i], engine.get()));
|
||||
arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
|
||||
.ValueOrDie();
|
||||
}
|
||||
return std::move(arguments);
|
||||
}
|
||||
|
@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
|
||||
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
||||
auto module = ParseHloString(
|
||||
R"(HloModule index_space_module
|
||||
|
||||
ENTRY IndexSpace {
|
||||
index_param = s32[3]{0} parameter(0)
|
||||
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
|
||||
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
|
||||
dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
|
||||
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
|
||||
MakeFakeArguments(module.get()));
|
||||
ASSERT_EQ(args.size(), 3);
|
||||
const Literal& index_arg = *args[0];
|
||||
|
||||
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
|
||||
|
||||
EXPECT_GE(index_arg.Get<int32>({1}), 0);
|
||||
EXPECT_LE(index_arg.Get<int32>({1}), 2);
|
||||
|
||||
EXPECT_GE(index_arg.Get<int32>({2}), 0);
|
||||
EXPECT_LE(index_arg.Get<int32>({2}), 3);
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
||||
auto module = ParseHloString(
|
||||
R"(HloModule index_space_module
|
||||
|
||||
ENTRY IndexSpace {
|
||||
index_param = s32[3]{0} parameter(0)
|
||||
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
|
||||
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
|
||||
update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
|
||||
update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
|
||||
|
||||
dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
|
||||
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
|
||||
MakeFakeArguments(module.get()));
|
||||
ASSERT_EQ(args.size(), 5);
|
||||
const Literal& index_arg = *args[0];
|
||||
|
||||
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
|
||||
|
||||
EXPECT_GE(index_arg.Get<int32>({1}), 0);
|
||||
EXPECT_LE(index_arg.Get<int32>({1}), 2);
|
||||
|
||||
EXPECT_GE(index_arg.Get<int32>({2}), 0);
|
||||
EXPECT_LE(index_arg.Get<int32>({2}), 3);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user