From 725dfe9cd0eef3f4b858eaeda38728813c99a210 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Sep 2018 21:22:34 -0700 Subject: [PATCH] internal change only. PiperOrigin-RevId: 212754752 --- .../compiler/xla/service/hlo_graph_dumper.cc | 4 ++-- tensorflow/compiler/xla/shape_util.cc | 13 +++++++++++++ tensorflow/compiler/xla/shape_util.h | 4 ++++ tensorflow/compiler/xla/shape_util_test.cc | 16 ++++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 4826bff19e8..287ba84b3b2 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -123,8 +123,8 @@ class NodeFilter { // We arbitrarily set this as the boundary between "large" and "small" // instructions. bool IsSmall(const HloInstruction* instr) { - if (ShapeUtil::IsOpaque(instr->shape()) || - ShapeUtil::IsToken(instr->shape())) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { return true; } return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 9772c06bce3..96c80fd577e 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return count; } +/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type) { + if (shape.element_type() == primitive_type) { + return true; + } + for (const Shape& element_shape : shape.tuple_shapes()) { + if (HasPrimitiveType(element_shape, primitive_type)) { + return true; + } + } + return false; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8234fcdd3f5..623ae39de81 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -180,6 +180,10 @@ class ShapeUtil { // As ElementsIn(), but recurses through tuples. static int64 ElementsInRecursive(const Shape& shape); + // Returns true if shape has the primitive type, recurses through tuples. + static bool HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type); + // Returns true if 'shape' is an array with zero elements. static bool IsZeroElementArray(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 6ca4085aaf3..c622ecdca1f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } +TEST(ShapeUtilTest, HasPrimitiveType) { + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}), + S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}), + S16)); +} + TEST(ShapeUtilTest, IsZeroElementArray) { EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));