internal change only.

PiperOrigin-RevId: 212754752
This commit is contained in:
A. Unique TensorFlower 2018-09-12 21:22:34 -07:00 committed by TensorFlower Gardener
parent f4d8442e13
commit 725dfe9cd0
4 changed files with 35 additions and 2 deletions

View File

@ -123,8 +123,8 @@ class NodeFilter {
// We arbitrarily set this as the boundary between "large" and "small"
// instructions.
bool IsSmall(const HloInstruction* instr) {
if (ShapeUtil::IsOpaque(instr->shape()) ||
ShapeUtil::IsToken(instr->shape())) {
if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
return true;
}
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;

View File

@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return count;
}
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
PrimitiveType primitive_type) {
if (shape.element_type() == primitive_type) {
return true;
}
for (const Shape& element_shape : shape.tuple_shapes()) {
if (HasPrimitiveType(element_shape, primitive_type)) {
return true;
}
}
return false;
}
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}

View File

@ -180,6 +180,10 @@ class ShapeUtil {
// As ElementsIn(), but recurses through tuples.
static int64 ElementsInRecursive(const Shape& shape);
// Returns true if shape has the primitive type, recurses through tuples.
static bool HasPrimitiveType(const Shape& shape,
PrimitiveType primitive_type);
// Returns true if 'shape' is an array with zero elements.
static bool IsZeroElementArray(const Shape& shape);

View File

@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
TEST(ShapeUtilTest, HasPrimitiveType) {
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
S32));
EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
S16));
}
TEST(ShapeUtilTest, IsZeroElementArray) {
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));