From 8903e5fc72aba856c5567d09b41340a5e32d4f8f Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Sat, 19 Aug 2017 16:33:35 -0700 Subject: [PATCH] [XLA] Make ShapeUtil::ParseShapeString more complete. Handle tuples, nested tuples, more element types. PiperOrigin-RevId: 165826211 --- tensorflow/compiler/xla/shape_util.cc | 88 +++++++++++++++++----- tensorflow/compiler/xla/shape_util.h | 2 +- tensorflow/compiler/xla/shape_util_test.cc | 24 ++++++ 3 files changed, 93 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b84494d34a3..b71b3a9e131 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -425,13 +425,42 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { HumanString(program_shape.result())); } -/* static */ StatusOr ShapeUtil::ParseShapeString(const string& s) { +namespace { +// Parses shapes with simple recursive descent structure -- consumes from the +// front of s and passes that view recursively as required. +StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { + tensorflow::str_util::RemoveLeadingWhitespace(s); + + if (s->Consume("(")) { // Tuple. + std::vector shapes; + bool must_end = false; + while (true) { + if (s->Consume(")")) { + break; + } else if (must_end) { + return InvalidArgument("Expected end of tuple; got: \"%s\"", + s->ToString().c_str()); + } + shapes.emplace_back(); + TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); + tensorflow::str_util::RemoveLeadingWhitespace(s); + must_end = !s->Consume(","); + } + return ShapeUtil::MakeTupleShape(shapes); + } + string element_type_string; string dimensions_string; string layout_string; - if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?", - &element_type_string, &dimensions_string, - &layout_string)) { + // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so + // we convert in to the RE2-consumable type and then consume the corresponding + // amount from our StringPiece type. + tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); + if (RE2::Consume(&s_consumable, + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?", + &element_type_string, &dimensions_string, &layout_string)) { + size_t consumed = s->size() - s_consumable.size(); + s->remove_prefix(consumed); auto comma_list_to_int64s = [&s](const string& input) -> StatusOr> { std::vector results; @@ -439,39 +468,58 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { int64 element; if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) { return InvalidArgument( - "invalid value in parsed shape string: \"%s\" in \"%s\"", - piece.c_str(), s.c_str()); + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", + piece.c_str(), s->ToString().c_str()); } results.push_back(element); } return results; }; + + // Extract the dimensions. TF_ASSIGN_OR_RETURN(std::vector dimensions, comma_list_to_int64s(dimensions_string)); - PrimitiveType primitive_type; - if (element_type_string == "f32") { - primitive_type = F32; - } else if (element_type_string == "s32") { - primitive_type = S32; - } else if (element_type_string == "u32") { - primitive_type = U32; - } else { - LOG(FATAL) << "unhandled element type string: " << element_type_string; + + // Extract the primitive element type. + PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; + for (PrimitiveType i = + static_cast(PRIMITIVE_TYPE_INVALID + 1); + i < TUPLE; i = static_cast(i + 1)) { + if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == + element_type_string) { + primitive_type = i; + break; + } } + if (primitive_type == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("Invalid element type string: \"%s\".", + element_type_string.c_str()); + } + Shape result; if (layout_string.empty()) { - result = MakeShape(primitive_type, dimensions); + // Create a shape without a layout set. + result = ShapeUtil::MakeShape(primitive_type, dimensions); } else { + // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = MakeShapeWithLayout(primitive_type, dimensions, min2maj); + result = + ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); } - TF_DCHECK_OK(ValidateShape(result)); - return result; + TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + return std::move(result); } - return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", + s->ToString().c_str()); +} +} // namespace + +/* static */ StatusOr ShapeUtil::ParseShapeString( + tensorflow::StringPiece s) { + return ParseShapeStringInternal(&s); } /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index f0058a6ed39..e3473138376 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -125,7 +125,7 @@ class ShapeUtil { // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. - static StatusOr ParseShapeString(const string& s); + static StatusOr ParseShapeString(tensorflow::StringPiece s); // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 69ef6175ccd..9635e5ad2eb 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { + string shape_string = "(f32[1],(f32[2]), f32[3])"; + Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});