Merge pull request #32264 from nouiz:utils_pr_hlo_creation
PiperOrigin-RevId: 268091319
This commit is contained in:
commit
c64d415330
@ -185,6 +185,12 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
|||||||
broadcast_shape, operand, broadcast_dimensions));
|
broadcast_shape, operand, broadcast_dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
const Shape& shape) {
|
||||||
|
return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
|
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
|
||||||
int64 index) {
|
int64 index) {
|
||||||
HloComputation* computation = operand->parent();
|
HloComputation* computation = operand->parent();
|
||||||
@ -224,6 +230,22 @@ HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
|
|||||||
return hlo;
|
return hlo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
|
||||||
|
PrimitiveType type) {
|
||||||
|
CHECK_NE(hlo->shape().element_type(), type);
|
||||||
|
Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
|
||||||
|
hlo = hlo->parent()->AddInstruction(
|
||||||
|
HloInstruction::CreateBitcastConvert(shape, hlo));
|
||||||
|
CHECK_EQ(hlo->shape().element_type(), type);
|
||||||
|
return hlo;
|
||||||
|
}
|
||||||
|
|
||||||
|
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
|
||||||
|
int64 iota_dimension) {
|
||||||
|
return computation->AddInstruction(
|
||||||
|
HloInstruction::CreateIota(shape, iota_dimension));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
||||||
const DotDimensionNumbers& dim_numbers,
|
const DotDimensionNumbers& dim_numbers,
|
||||||
const PrecisionConfig& precision_config) {
|
const PrecisionConfig& precision_config) {
|
||||||
|
@ -91,6 +91,9 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
|
|||||||
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
absl::Span<const int64> result_shape_bounds);
|
absl::Span<const int64> result_shape_bounds);
|
||||||
|
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
const Shape& shape);
|
||||||
|
|
||||||
// Creates a GetTupleElement HLO instruction and adds it to the computation
|
// Creates a GetTupleElement HLO instruction and adds it to the computation
|
||||||
// containing `operand`.
|
// containing `operand`.
|
||||||
@ -107,6 +110,14 @@ StatusOr<HloInstruction*> MakeConcatHlo(
|
|||||||
// the given primitive type.
|
// the given primitive type.
|
||||||
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);
|
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);
|
||||||
|
|
||||||
|
// Creates a BitcastConvert HLO instruction.
|
||||||
|
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
|
||||||
|
PrimitiveType type);
|
||||||
|
|
||||||
|
// Creates an Iota HLO instruction.
|
||||||
|
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
|
||||||
|
int64 iota_dimension);
|
||||||
|
|
||||||
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
|
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
|
||||||
// and `rhs` (both must be in the same computation).
|
// and `rhs` (both must be in the same computation).
|
||||||
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
|
||||||
|
@ -41,6 +41,21 @@ class HloCreationUtilsTest : public HloTestBase {
|
|||||||
*param = (*entry_computation)->parameter_instruction(0);
|
*param = (*entry_computation)->parameter_instruction(0);
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
|
||||||
|
PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
|
||||||
|
absl::Span<const int64> output_shape_dims, HloInstruction** param,
|
||||||
|
HloComputation** entry_computation, PrimitiveType primitive_type_output) {
|
||||||
|
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
|
||||||
|
Shape output_shape =
|
||||||
|
ShapeUtil::MakeShape(primitive_type_output, output_shape_dims);
|
||||||
|
auto module = CreateNewVerifiedModule("test");
|
||||||
|
*entry_computation = module->AddEntryComputation(
|
||||||
|
CreateComputationWithSignature({&input_shape}, output_shape, "entry")
|
||||||
|
.ValueOrDie());
|
||||||
|
*param = (*entry_computation)->parameter_instruction(0);
|
||||||
|
return module;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
|
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
|
||||||
@ -222,5 +237,85 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
|
|||||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) {
|
||||||
|
HloInstruction* param;
|
||||||
|
HloComputation* entry_computation;
|
||||||
|
|
||||||
|
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2},
|
||||||
|
/*output_shape_dims=*/{2, 2},
|
||||||
|
¶m, &entry_computation, F32);
|
||||||
|
auto* input = module->entry_computation()->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})));
|
||||||
|
|
||||||
|
HloInstruction* output = MakeBitcastConvertToHlo(input, F32);
|
||||||
|
entry_computation->set_root_instruction(output);
|
||||||
|
|
||||||
|
HloEvaluator evaluator;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal result_literal,
|
||||||
|
evaluator.Evaluate(*module,
|
||||||
|
{LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})}));
|
||||||
|
CHECK_EQ(result_literal,
|
||||||
|
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) {
|
||||||
|
HloInstruction* param;
|
||||||
|
HloComputation* entry_computation;
|
||||||
|
|
||||||
|
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
|
||||||
|
/*output_shape_dims=*/{2, 2},
|
||||||
|
¶m, &entry_computation, F32);
|
||||||
|
HloInstruction* output = MakeIotaHlo(module->entry_computation(),
|
||||||
|
ShapeUtil::MakeShape(F32, {2, 2}), 0);
|
||||||
|
entry_computation->set_root_instruction(output);
|
||||||
|
|
||||||
|
HloEvaluator evaluator;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal result_literal,
|
||||||
|
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
|
||||||
|
CHECK_EQ(result_literal,
|
||||||
|
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {1.0f, 1.0f}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) {
|
||||||
|
HloInstruction* param;
|
||||||
|
HloComputation* entry_computation;
|
||||||
|
|
||||||
|
auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
|
||||||
|
/*output_shape_dims=*/{2, 2},
|
||||||
|
¶m, &entry_computation);
|
||||||
|
auto* input = MakeR0ConstantHlo<float>(module->entry_computation(), 0);
|
||||||
|
HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2});
|
||||||
|
entry_computation->set_root_instruction(output);
|
||||||
|
|
||||||
|
HloEvaluator evaluator;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal result_literal,
|
||||||
|
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
|
||||||
|
CHECK_EQ(result_literal,
|
||||||
|
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) {
|
||||||
|
HloInstruction* param;
|
||||||
|
HloComputation* entry_computation;
|
||||||
|
|
||||||
|
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
|
||||||
|
/*output_shape_dims=*/{2, 2},
|
||||||
|
¶m, &entry_computation);
|
||||||
|
auto* input = MakeR0ConstantHlo<int32>(module->entry_computation(), 0);
|
||||||
|
HloInstruction* output =
|
||||||
|
MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2}));
|
||||||
|
entry_computation->set_root_instruction(output);
|
||||||
|
|
||||||
|
HloEvaluator evaluator;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal result_literal,
|
||||||
|
evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0.0)}));
|
||||||
|
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user