[XLA] Make ShapeUtil::ParseShapeString more complete.

Handle tuples, nested tuples, more element types.

PiperOrigin-RevId: 165826211
This commit is contained in:
Chris Leary 2017-08-19 16:33:35 -07:00 committed by TensorFlower Gardener
parent c572ca84f4
commit 8903e5fc72
3 changed files with 93 additions and 21 deletions

View File

@ -425,13 +425,42 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
HumanString(program_shape.result()));
}
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(const string& s) {
namespace {
// Parses a single shape from the front of *s with a simple recursive descent
// structure, advancing *s past whatever text was consumed.
//
// A leading "(" starts a tuple whose elements are parsed by recursive calls on
// the same view; otherwise the text is matched against an array-shape pattern
// of the form <element type>[<dims>] with an optional {<layout>} suffix.
// Returns an InvalidArgument status when the text cannot be parsed.
//
// NOTE(review): this listing looks like a rendered diff with pre- and
// post-change lines interleaved and no +/- markers -- confirm statement order
// against the real file before relying on it.
StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
tensorflow::str_util::RemoveLeadingWhitespace(s);
if (s->Consume("(")) { // Tuple.
std::vector<Shape> shapes;
// Set when the most recent element was not followed by a comma; the only
// acceptable next token is then ")".
bool must_end = false;
while (true) {
// The closing paren is tested before any element is parsed, so "()" yields a
// tuple with no elements and a comma directly before ")" is accepted.
if (s->Consume(")")) {
break;
} else if (must_end) {
return InvalidArgument("Expected end of tuple; got: \"%s\"",
s->ToString().c_str());
}
// Parse the next element directly into a fresh slot at the back.
shapes.emplace_back();
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
tensorflow::str_util::RemoveLeadingWhitespace(s);
must_end = !s->Consume(",");
}
return ShapeUtil::MakeTupleShape(shapes);
}
// Capture targets for the array-shape pattern below.
string element_type_string;
string dimensions_string;
string layout_string;
// NOTE(review): two pattern matches appear back to back here (FullMatch, then
// Consume); presumably the FullMatch call is the pre-change side of the diff --
// confirm.
if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?",
&element_type_string, &dimensions_string,
&layout_string)) {
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
// we convert it to the RE2-consumable type and then consume the corresponding
// amount from our StringPiece type.
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
// Capture groups: (1) element type name, (2) comma-separated dimension list,
// (3) optional comma-separated list inside braces (the layout).
if (RE2::Consume(&s_consumable,
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?",
&element_type_string, &dimensions_string, &layout_string)) {
// The size difference is how much text the pattern match consumed.
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
// Converts a comma-separated list of integers into a vector of int64,
// returning InvalidArgument for any piece that is not a valid int64.
// NOTE(review): the loop header iterating over the pieces of `input` falls in
// the gap between diff hunks and is not visible here.
auto comma_list_to_int64s =
[&s](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
@ -439,39 +468,58 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
int64 element;
if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
return InvalidArgument(
"invalid value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s.c_str());
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
piece.c_str(), s->ToString().c_str());
}
results.push_back(element);
}
return results;
};
// Extract the dimensions.
TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
comma_list_to_int64s(dimensions_string));
PrimitiveType primitive_type;
if (element_type_string == "f32") {
primitive_type = F32;
} else if (element_type_string == "s32") {
primitive_type = S32;
} else if (element_type_string == "u32") {
primitive_type = U32;
} else {
LOG(FATAL) << "unhandled element type string: " << element_type_string;
// Extract the primitive element type by scanning the enum values strictly
// between PRIMITIVE_TYPE_INVALID and TUPLE for one whose lowercased name
// equals the captured element type string.
PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID;
for (PrimitiveType i =
static_cast<PrimitiveType>(PRIMITIVE_TYPE_INVALID + 1);
i < TUPLE; i = static_cast<PrimitiveType>(i + 1)) {
if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) ==
element_type_string) {
primitive_type = i;
break;
}
}
// Still the sentinel value: no enum name matched.
if (primitive_type == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}
Shape result;
if (layout_string.empty()) {
result = MakeShape(primitive_type, dimensions);
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
// The layout must list exactly one entry per dimension.
TF_RET_CHECK(dimensions.size() == min2maj.size());
result = MakeShapeWithLayout(primitive_type, dimensions, min2maj);
result =
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
}
TF_DCHECK_OK(ValidateShape(result));
return result;
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
return std::move(result);
}
// Neither the tuple branch nor the array-shape pattern matched.
return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str());
return InvalidArgument("Invalid shape string to parse: \"%s\"",
s->ToString().c_str());
}
} // namespace
// Parses a ShapeUtil::HumanString-format shape string into a Shape.
//
// ParseShapeStringInternal consumes exactly one shape from the front of `s`
// and stops there, so any non-whitespace text left over afterwards means the
// input as a whole was not a well-formed shape string (e.g. "f32[1]junk").
// Such input is rejected with InvalidArgument instead of being silently
// accepted. Errors from the internal parser are propagated unchanged.
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
    tensorflow::StringPiece s) {
  TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
  // Trailing whitespace is harmless; anything else is an error.
  tensorflow::str_util::RemoveLeadingWhitespace(&s);
  if (!s.empty()) {
    return InvalidArgument("Unexpected trailing text after shape: \"%s\"",
                           s.ToString().c_str());
  }
  // Explicit move mirrors the internal parser's return: the local is a Shape
  // while the return type is StatusOr<Shape>.
  return std::move(shape);
}
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,

View File

@ -125,7 +125,7 @@ class ShapeUtil {
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
//
// Accepts array shapes such as "f32[2,3]", optionally followed by a
// brace-enclosed layout list, and parenthesized, possibly nested, tuples such
// as "(f32[1],(f32[2]), f32[3])". Returns an error status if the string cannot
// be parsed.
static StatusOr<Shape> ParseShapeString(const string& s);
static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.

View File

@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
// Parses a flat tuple of two arrays (f32 and s8) and checks that the result
// equals the same tuple shape built by hand.
TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
  const Shape f32_vector = ShapeUtil::MakeShape(F32, {1572864});
  const Shape s8_matrix = ShapeUtil::MakeShape(S8, {5120, 1024});
  const Shape expected = ShapeUtil::MakeTupleShape({f32_vector, s8_matrix});

  const string text = "(f32[1572864],s8[5120,1024])";
  const Shape actual = ShapeUtil::ParseShapeString(text).ValueOrDie();

  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << "actual: " << ShapeUtil::HumanString(actual);
}
// Parses a tuple whose middle element is itself a one-element tuple (with a
// space after the second comma) and checks it against the hand-built shape.
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
  const Shape first = ShapeUtil::MakeShape(F32, {1});
  const Shape inner_tuple =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})});
  const Shape last = ShapeUtil::MakeShape(F32, {3});
  const Shape expected = ShapeUtil::MakeTupleShape({first, inner_tuple, last});

  const string text = "(f32[1],(f32[2]), f32[3])";
  const Shape actual = ShapeUtil::ParseShapeString(text).ValueOrDie();

  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});