internal change only.
PiperOrigin-RevId: 212754752
This commit is contained in:
parent
f4d8442e13
commit
725dfe9cd0
@ -123,8 +123,8 @@ class NodeFilter {
|
||||
// We arbitrarily set this as the boundary between "large" and "small"
|
||||
// instructions.
|
||||
bool IsSmall(const HloInstruction* instr) {
|
||||
if (ShapeUtil::IsOpaque(instr->shape()) ||
|
||||
ShapeUtil::IsToken(instr->shape())) {
|
||||
if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
|
||||
ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
|
||||
return true;
|
||||
}
|
||||
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
|
||||
|
@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return count;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
|
||||
PrimitiveType primitive_type) {
|
||||
if (shape.element_type() == primitive_type) {
|
||||
return true;
|
||||
}
|
||||
for (const Shape& element_shape : shape.tuple_shapes()) {
|
||||
if (HasPrimitiveType(element_shape, primitive_type)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
|
||||
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
|
||||
}
|
||||
|
@ -180,6 +180,10 @@ class ShapeUtil {
|
||||
// As ElementsIn(), but recurses through tuples.
|
||||
static int64 ElementsInRecursive(const Shape& shape);
|
||||
|
||||
// Returns true if shape has the primitive type, recurses through tuples.
|
||||
static bool HasPrimitiveType(const Shape& shape,
|
||||
PrimitiveType primitive_type);
|
||||
|
||||
// Returns true if 'shape' is an array with zero elements.
|
||||
static bool IsZeroElementArray(const Shape& shape);
|
||||
|
||||
|
@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
|
||||
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, HasPrimitiveType) {
|
||||
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
|
||||
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
|
||||
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
|
||||
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
|
||||
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
|
||||
S32));
|
||||
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(S32, {}),
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
|
||||
S16));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IsZeroElementArray) {
|
||||
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
|
||||
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
|
||||
|
Loading…
Reference in New Issue
Block a user