Merge pull request #32264 from nouiz:utils_pr_hlo_creation

PiperOrigin-RevId: 268091319
This commit is contained in:
TensorFlower Gardener 2019-09-09 16:28:58 -07:00
commit c64d415330
3 changed files with 128 additions and 0 deletions

View File

@ -185,6 +185,12 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
broadcast_shape, operand, broadcast_dimensions));
}
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
absl::Span<const int64> broadcast_dimensions,
const Shape& shape) {
return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
}
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
int64 index) {
HloComputation* computation = operand->parent();
@ -224,6 +230,22 @@ HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
return hlo;
}
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
PrimitiveType type) {
CHECK_NE(hlo->shape().element_type(), type);
Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
hlo = hlo->parent()->AddInstruction(
HloInstruction::CreateBitcastConvert(shape, hlo));
CHECK_EQ(hlo->shape().element_type(), type);
return hlo;
}
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
int64 iota_dimension) {
return computation->AddInstruction(
HloInstruction::CreateIota(shape, iota_dimension));
}
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config) {

View File

@ -91,6 +91,9 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
absl::Span<const int64> broadcast_dimensions,
absl::Span<const int64> result_shape_bounds);
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
absl::Span<const int64> broadcast_dimensions,
const Shape& shape);
// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
@ -107,6 +110,14 @@ StatusOr<HloInstruction*> MakeConcatHlo(
// the given primitive type.
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);
// Creates a BitcastConvert HLO instruction.
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
PrimitiveType type);
// Creates an Iota HLO instruction.
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
int64 iota_dimension);
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,

View File

@ -41,6 +41,21 @@ class HloCreationUtilsTest : public HloTestBase {
*param = (*entry_computation)->parameter_instruction(0);
return module;
}
std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
absl::Span<const int64> output_shape_dims, HloInstruction** param,
HloComputation** entry_computation, PrimitiveType primitive_type_output) {
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
Shape output_shape =
ShapeUtil::MakeShape(primitive_type_output, output_shape_dims);
auto module = CreateNewVerifiedModule("test");
*entry_computation = module->AddEntryComputation(
CreateComputationWithSignature({&input_shape}, output_shape, "entry")
.ValueOrDie());
*param = (*entry_computation)->parameter_instruction(0);
return module;
}
};
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
@ -222,5 +237,85 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) {
HloInstruction* param;
HloComputation* entry_computation;
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
/*output_shape_dims=*/{2, 2},
&param, &entry_computation, F32);
auto* input = module->entry_computation()->AddInstruction(
HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})));
HloInstruction* output = MakeBitcastConvertToHlo(input, F32);
entry_computation->set_root_instruction(output);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate(*module,
{LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})}));
CHECK_EQ(result_literal,
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) {
HloInstruction* param;
HloComputation* entry_computation;
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
/*output_shape_dims=*/{2, 2},
&param, &entry_computation, F32);
HloInstruction* output = MakeIotaHlo(module->entry_computation(),
ShapeUtil::MakeShape(F32, {2, 2}), 0);
entry_computation->set_root_instruction(output);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
CHECK_EQ(result_literal,
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {1.0f, 1.0f}}));
}
TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) {
HloInstruction* param;
HloComputation* entry_computation;
auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
/*output_shape_dims=*/{2, 2},
&param, &entry_computation);
auto* input = MakeR0ConstantHlo<float>(module->entry_computation(), 0);
HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2});
entry_computation->set_root_instruction(output);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
CHECK_EQ(result_literal,
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) {
HloInstruction* param;
HloComputation* entry_computation;
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
/*output_shape_dims=*/{2, 2},
&param, &entry_computation);
auto* input = MakeR0ConstantHlo<int32>(module->entry_computation(), 0);
HloInstruction* output =
MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2}));
entry_computation->set_root_instruction(output);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
} // namespace
} // namespace xla