internal change only.
PiperOrigin-RevId: 212754752
This commit is contained in:
parent
f4d8442e13
commit
725dfe9cd0
@ -123,8 +123,8 @@ class NodeFilter {
|
|||||||
// We arbitrarily set this as the boundary between "large" and "small"
|
// We arbitrarily set this as the boundary between "large" and "small"
|
||||||
// instructions.
|
// instructions.
|
||||||
bool IsSmall(const HloInstruction* instr) {
|
bool IsSmall(const HloInstruction* instr) {
|
||||||
if (ShapeUtil::IsOpaque(instr->shape()) ||
|
if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
|
||||||
ShapeUtil::IsToken(instr->shape())) {
|
ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
|
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
|
||||||
|
@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
|
||||||
|
PrimitiveType primitive_type) {
|
||||||
|
if (shape.element_type() == primitive_type) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
for (const Shape& element_shape : shape.tuple_shapes()) {
|
||||||
|
if (HasPrimitiveType(element_shape, primitive_type)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
|
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
|
||||||
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
|
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
|
||||||
}
|
}
|
||||||
|
@ -180,6 +180,10 @@ class ShapeUtil {
|
|||||||
// As ElementsIn(), but recurses through tuples.
|
// As ElementsIn(), but recurses through tuples.
|
||||||
static int64 ElementsInRecursive(const Shape& shape);
|
static int64 ElementsInRecursive(const Shape& shape);
|
||||||
|
|
||||||
|
// Returns true if shape has the primitive type, recurses through tuples.
|
||||||
|
static bool HasPrimitiveType(const Shape& shape,
|
||||||
|
PrimitiveType primitive_type);
|
||||||
|
|
||||||
// Returns true if 'shape' is an array with zero elements.
|
// Returns true if 'shape' is an array with zero elements.
|
||||||
static bool IsZeroElementArray(const Shape& shape);
|
static bool IsZeroElementArray(const Shape& shape);
|
||||||
|
|
||||||
|
@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
|
|||||||
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
|
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ShapeUtilTest, HasPrimitiveType) {
|
||||||
|
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
|
||||||
|
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
|
||||||
|
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
|
||||||
|
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
|
||||||
|
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
|
||||||
|
ShapeUtil::MakeTupleShape(
|
||||||
|
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
|
||||||
|
S32));
|
||||||
|
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
|
||||||
|
ShapeUtil::MakeTupleShape(
|
||||||
|
{ShapeUtil::MakeShape(S32, {}),
|
||||||
|
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
|
||||||
|
S16));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ShapeUtilTest, IsZeroElementArray) {
|
TEST(ShapeUtilTest, IsZeroElementArray) {
|
||||||
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
|
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
|
||||||
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
|
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
|
||||||
|
Loading…
Reference in New Issue
Block a user