[XLA] Make ShapeUtil::ParseShapeString more complete.
Handle tuples, nested tuples, more element types. PiperOrigin-RevId: 165826211
This commit is contained in:
parent
c572ca84f4
commit
8903e5fc72
@ -425,13 +425,42 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
|
||||
HumanString(program_shape.result()));
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(const string& s) {
|
||||
namespace {
|
||||
// Parses shapes with simple recursive descent structure -- consumes from the
|
||||
// front of s and passes that view recursively as required.
|
||||
StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
tensorflow::str_util::RemoveLeadingWhitespace(s);
|
||||
|
||||
if (s->Consume("(")) { // Tuple.
|
||||
std::vector<Shape> shapes;
|
||||
bool must_end = false;
|
||||
while (true) {
|
||||
if (s->Consume(")")) {
|
||||
break;
|
||||
} else if (must_end) {
|
||||
return InvalidArgument("Expected end of tuple; got: \"%s\"",
|
||||
s->ToString().c_str());
|
||||
}
|
||||
shapes.emplace_back();
|
||||
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
|
||||
tensorflow::str_util::RemoveLeadingWhitespace(s);
|
||||
must_end = !s->Consume(",");
|
||||
}
|
||||
return ShapeUtil::MakeTupleShape(shapes);
|
||||
}
|
||||
|
||||
string element_type_string;
|
||||
string dimensions_string;
|
||||
string layout_string;
|
||||
if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?",
|
||||
&element_type_string, &dimensions_string,
|
||||
&layout_string)) {
|
||||
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
|
||||
// we convert in to the RE2-consumable type and then consume the corresponding
|
||||
// amount from our StringPiece type.
|
||||
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
|
||||
if (RE2::Consume(&s_consumable,
|
||||
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?",
|
||||
&element_type_string, &dimensions_string, &layout_string)) {
|
||||
size_t consumed = s->size() - s_consumable.size();
|
||||
s->remove_prefix(consumed);
|
||||
auto comma_list_to_int64s =
|
||||
[&s](const string& input) -> StatusOr<std::vector<int64>> {
|
||||
std::vector<int64> results;
|
||||
@ -439,39 +468,58 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) {
|
||||
int64 element;
|
||||
if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
|
||||
return InvalidArgument(
|
||||
"invalid value in parsed shape string: \"%s\" in \"%s\"",
|
||||
piece.c_str(), s.c_str());
|
||||
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
|
||||
piece.c_str(), s->ToString().c_str());
|
||||
}
|
||||
results.push_back(element);
|
||||
}
|
||||
return results;
|
||||
};
|
||||
|
||||
// Extract the dimensions.
|
||||
TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
|
||||
comma_list_to_int64s(dimensions_string));
|
||||
PrimitiveType primitive_type;
|
||||
if (element_type_string == "f32") {
|
||||
primitive_type = F32;
|
||||
} else if (element_type_string == "s32") {
|
||||
primitive_type = S32;
|
||||
} else if (element_type_string == "u32") {
|
||||
primitive_type = U32;
|
||||
} else {
|
||||
LOG(FATAL) << "unhandled element type string: " << element_type_string;
|
||||
|
||||
// Extract the primitive element type.
|
||||
PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID;
|
||||
for (PrimitiveType i =
|
||||
static_cast<PrimitiveType>(PRIMITIVE_TYPE_INVALID + 1);
|
||||
i < TUPLE; i = static_cast<PrimitiveType>(i + 1)) {
|
||||
if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) ==
|
||||
element_type_string) {
|
||||
primitive_type = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (primitive_type == PRIMITIVE_TYPE_INVALID) {
|
||||
return InvalidArgument("Invalid element type string: \"%s\".",
|
||||
element_type_string.c_str());
|
||||
}
|
||||
|
||||
Shape result;
|
||||
if (layout_string.empty()) {
|
||||
result = MakeShape(primitive_type, dimensions);
|
||||
// Create a shape without a layout set.
|
||||
result = ShapeUtil::MakeShape(primitive_type, dimensions);
|
||||
} else {
|
||||
// Extract the layout minor-to-major and set it.
|
||||
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
|
||||
comma_list_to_int64s(layout_string));
|
||||
TF_RET_CHECK(dimensions.size() == min2maj.size());
|
||||
result = MakeShapeWithLayout(primitive_type, dimensions, min2maj);
|
||||
result =
|
||||
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
|
||||
}
|
||||
TF_DCHECK_OK(ValidateShape(result));
|
||||
return result;
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str());
|
||||
return InvalidArgument("Invalid shape string to parse: \"%s\"",
|
||||
s->ToString().c_str());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
|
||||
tensorflow::StringPiece s) {
|
||||
return ParseShapeStringInternal(&s);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
|
||||
|
@ -125,7 +125,7 @@ class ShapeUtil {
|
||||
|
||||
// Parses a ShapeUtil::HumanString-format shape string back into a shape
|
||||
// object.
|
||||
static StatusOr<Shape> ParseShapeString(const string& s);
|
||||
static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
|
||||
|
||||
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
|
||||
// not check element type.
|
||||
|
@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) {
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
|
||||
string shape_string = "(f32[1572864],s8[5120,1024])";
|
||||
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
|
||||
Shape expected =
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
|
||||
ShapeUtil::MakeShape(S8, {5120, 1024})});
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
<< "expected: " << ShapeUtil::HumanString(expected)
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
|
||||
string shape_string = "(f32[1],(f32[2]), f32[3])";
|
||||
Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
|
||||
Shape expected = ShapeUtil::MakeTupleShape({
|
||||
ShapeUtil::MakeShape(F32, {1}),
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
|
||||
ShapeUtil::MakeShape(F32, {3}),
|
||||
});
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
<< "expected: " << ShapeUtil::HumanString(expected)
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
|
||||
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
|
||||
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user