Support multiple index-space constraints in synthetic input generator.

PiperOrigin-RevId: 208868489
This commit is contained in:
Mark Heffernan 2018-08-15 13:04:39 -07:00 committed by TensorFlower Gardener
parent bc646fd576
commit 75399bba46
3 changed files with 92 additions and 36 deletions

View File

@ -2076,6 +2076,8 @@ tf_cc_test(
xla_test(
name = "test_utils_test",
srcs = ["test_utils_test.cc"],
# There is nothing backend specific in this test, so just pick an arbitrary backend.
backends = ["cpu"],
deps = [
":local_client_test_base",
":test_utils",

View File

@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
const Shape& input_shape, const Shape& slice_shape,
std::minstd_rand0* engine) {
const int64 rank = ShapeUtil::Rank(input_shape);
std::vector<int32> start_indices(rank);
std::unique_ptr<Literal> MakeRandomIndex(
tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < rank; ++i) {
const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
std::uniform_int_distribution<int32> generator(0, upper_bound);
for (int i = 0; i < index_space.size(); ++i) {
std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
HloInstruction* needs_constant = nullptr;
std::vector<int64> index_space;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
if (needs_index != nullptr) {
auto needs_index_shape = needs_index->shape();
auto use_shape = use->shape();
if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
needs_index_shape = needs_index->operand(0)->shape();
case HloOpcode::kDynamicUpdateSlice: {
const Shape& indexed_shape = use->operand(0)->shape();
const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
? use->shape()
: use->operand(1)->shape();
const int64 rank = ShapeUtil::Rank(indexed_shape);
if (!index_space.empty()) {
TF_RET_CHECK(rank == index_space.size());
for (int64 i = 0; i < rank; ++i) {
index_space[i] = std::min(
index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i));
}
if (use->opcode() == HloOpcode::kDynamicSlice) {
use_shape = use->operand(0)->shape();
}
if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
return Unimplemented(
"Conflicting operand generation slice index constraints\n");
} else {
index_space.resize(rank);
for (int64 i = 0; i < rank; ++i) {
index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
}
}
needs_index = use;
break;
}
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;
needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
needs_constant = use;
needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
if (needs_index != nullptr && needs_constant != nullptr) {
if (!index_space.empty() && needs_constant) {
return Unimplemented(
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
"constant: %s\n",
needs_index->ToString().c_str(), needs_constant->ToString().c_str());
"Conflicting operand generation constraints. Dynamically indexes a "
"shape and is the init value of a reduction.");
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
needs_index->shape(), engine);
} else if (needs_constant != nullptr) {
if (!index_space.empty()) {
return MakeRandomIndex(index_space, engine);
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
*dataflow, *params[i], engine.get()));
arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
.ValueOrDie();
}
return std::move(arguments);
}

View File

@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
auto module = ParseHloString(
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
const Literal& index_arg = *args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
auto module = ParseHloString(
R"(HloModule index_space_module
ENTRY IndexSpace {
index_param = s32[3]{0} parameter(0)
array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
const Literal& index_arg = *args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
EXPECT_GE(index_arg.Get<int32>({1}), 0);
EXPECT_LE(index_arg.Get<int32>({1}), 2);
EXPECT_GE(index_arg.Get<int32>({2}), 0);
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
} // namespace
} // namespace xla