[XLA] Make overly-specific ShapeUtil predicate a little more general.
PiperOrigin-RevId: 216263039
This commit is contained in:
parent
eb0f862ba6
commit
cb057ea640
@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) {
|
||||
auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");
|
||||
|
||||
EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
|
||||
EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape()));
|
||||
EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
|
||||
EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
|
||||
EXPECT_EQ(0, parameter->operand_count());
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ namespace hlo_query {
|
||||
|
||||
bool IsConstantR0F32(HloInstruction* instruction, float* out) {
|
||||
if (instruction->opcode() == HloOpcode::kConstant &&
|
||||
ShapeUtil::IsScalarF32(instruction->shape())) {
|
||||
ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
|
||||
*out = instruction->literal().Get<float>({});
|
||||
return true;
|
||||
}
|
||||
|
@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
|
||||
return shape.element_type() == F32 && Rank(shape) == 0;
|
||||
/* static */ bool ShapeUtil::IsScalarWithElementType(
|
||||
const Shape& shape, PrimitiveType element_type) {
|
||||
return IsScalar(shape) && shape.element_type() == element_type;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -312,7 +312,10 @@ class ShapeUtil {
|
||||
static bool IsEffectiveScalar(const Shape& shape) {
|
||||
return IsArray(shape) && TrueRank(shape) == 0;
|
||||
}
|
||||
static bool IsScalarF32(const Shape& shape);
|
||||
|
||||
// Returns whether "shape" is a scalar (array) with the given element_type.
|
||||
static bool IsScalarWithElementType(const Shape& shape,
|
||||
PrimitiveType element_type);
|
||||
|
||||
// Extracts the size of the shape's dimension at dimension number
|
||||
// GetDimensionNumber(dimension_number).
|
||||
|
Loading…
x
Reference in New Issue
Block a user