Merge pull request #32264 from nouiz:utils_pr_hlo_creation
PiperOrigin-RevId: 268091319
This commit is contained in:
commit
c64d415330
@ -185,6 +185,12 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||
broadcast_shape, operand, broadcast_dimensions));
|
||||
}
|
||||
|
||||
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
const Shape& shape) {
|
||||
return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
|
||||
int64 index) {
|
||||
HloComputation* computation = operand->parent();
|
||||
@ -224,6 +230,22 @@ HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
|
||||
return hlo;
|
||||
}
|
||||
|
||||
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
|
||||
PrimitiveType type) {
|
||||
CHECK_NE(hlo->shape().element_type(), type);
|
||||
Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
|
||||
hlo = hlo->parent()->AddInstruction(
|
||||
HloInstruction::CreateBitcastConvert(shape, hlo));
|
||||
CHECK_EQ(hlo->shape().element_type(), type);
|
||||
return hlo;
|
||||
}
|
||||
|
||||
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
|
||||
int64 iota_dimension) {
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateIota(shape, iota_dimension));
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
||||
const DotDimensionNumbers& dim_numbers,
|
||||
const PrecisionConfig& precision_config) {
|
||||
|
@ -91,6 +91,9 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
|
||||
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::Span<const int64> result_shape_bounds);
|
||||
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
const Shape& shape);
|
||||
|
||||
// Creates a GetTupleElement HLO instruction and adds it to the computation
|
||||
// containing `operand`.
|
||||
@ -107,6 +110,14 @@ StatusOr<HloInstruction*> MakeConcatHlo(
|
||||
// the given primitive type.
|
||||
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);
|
||||
|
||||
// Creates a BitcastConvert HLO instruction.
|
||||
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
|
||||
PrimitiveType type);
|
||||
|
||||
// Creates an Iota HLO instruction.
|
||||
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
|
||||
int64 iota_dimension);
|
||||
|
||||
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
|
||||
// and `rhs` (both must be in the same computation).
|
||||
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
||||
|
@ -41,6 +41,21 @@ class HloCreationUtilsTest : public HloTestBase {
|
||||
*param = (*entry_computation)->parameter_instruction(0);
|
||||
return module;
|
||||
}
|
||||
|
||||
std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
|
||||
PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
|
||||
absl::Span<const int64> output_shape_dims, HloInstruction** param,
|
||||
HloComputation** entry_computation, PrimitiveType primitive_type_output) {
|
||||
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
|
||||
Shape output_shape =
|
||||
ShapeUtil::MakeShape(primitive_type_output, output_shape_dims);
|
||||
auto module = CreateNewVerifiedModule("test");
|
||||
*entry_computation = module->AddEntryComputation(
|
||||
CreateComputationWithSignature({&input_shape}, output_shape, "entry")
|
||||
.ValueOrDie());
|
||||
*param = (*entry_computation)->parameter_instruction(0);
|
||||
return module;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
|
||||
@ -222,5 +237,85 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
|
||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) {
|
||||
HloInstruction* param;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
|
||||
/*output_shape_dims=*/{2, 2},
|
||||
¶m, &entry_computation, F32);
|
||||
auto* input = module->entry_computation()->AddInstruction(
|
||||
HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})));
|
||||
|
||||
HloInstruction* output = MakeBitcastConvertToHlo(input, F32);
|
||||
entry_computation->set_root_instruction(output);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate(*module,
|
||||
{LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})}));
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) {
|
||||
HloInstruction* param;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
|
||||
/*output_shape_dims=*/{2, 2},
|
||||
¶m, &entry_computation, F32);
|
||||
HloInstruction* output = MakeIotaHlo(module->entry_computation(),
|
||||
ShapeUtil::MakeShape(F32, {2, 2}), 0);
|
||||
entry_computation->set_root_instruction(output);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {1.0f, 1.0f}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) {
|
||||
HloInstruction* param;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
|
||||
/*output_shape_dims=*/{2, 2},
|
||||
¶m, &entry_computation);
|
||||
auto* input = MakeR0ConstantHlo<float>(module->entry_computation(), 0);
|
||||
HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2});
|
||||
entry_computation->set_root_instruction(output);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) {
|
||||
HloInstruction* param;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
|
||||
/*output_shape_dims=*/{2, 2},
|
||||
¶m, &entry_computation);
|
||||
auto* input = MakeR0ConstantHlo<int32>(module->entry_computation(), 0);
|
||||
HloInstruction* output =
|
||||
MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2}));
|
||||
entry_computation->set_root_instruction(output);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user