Add TOKEN primitive type.

The token type will be threaded through side-effecting ops to order them. Subsequent cls will add new opcodes and change side effecting operations to support this ordering.

This CL also does some cleanup in shape_util and layout_util where we have assumed that shapes are either arrays or tuples.

PiperOrigin-RevId: 199215963
This commit is contained in:
Mark Heffernan 2018-06-04 16:41:46 -07:00 committed by TensorFlower Gardener
parent cf01d118ef
commit 14d4d1634d
6 changed files with 305 additions and 150 deletions

View File

@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} // namespace } // namespace
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
// Opaque and token types have empty layouts.
return Layout();
}
// A Layout proto corresponds to a single array, not a tuple. // A Layout proto corresponds to a single array, not a tuple.
DCHECK(!ShapeUtil::IsTuple(shape)); CHECK(ShapeUtil::IsArray(shape));
return CreateDefaultLayoutForRank(shape.dimensions_size()); return CreateDefaultLayoutForRank(shape.dimensions_size());
} }
@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
SetToDefaultLayout(&element_shape); SetToDefaultLayout(&element_shape);
} }
shape->clear_layout(); shape->clear_layout();
} else if (ShapeUtil::IsOpaque(*shape)) { } else if (ShapeUtil::IsArray(*shape)) {
shape->clear_layout();
} else {
shape->mutable_layout()->set_format(DENSE); shape->mutable_layout()->set_format(DENSE);
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
minor_to_major->Resize(shape->dimensions_size(), 0); minor_to_major->Resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major); SetDefaultLayoutToContainer(minor_to_major);
} else {
// Opaque, token types etc. have no layout.
shape->clear_layout();
} }
} }
@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
} }
return Status::OK(); return Status::OK();
} else if (ShapeUtil::IsOpaque(shape)) { } else if (ShapeUtil::IsArray(shape)) {
if (shape.has_layout()) {
return InvalidArgument("opaque should not have a layout field");
}
return Status::OK();
} else {
// Array shape.
if (!shape.has_layout()) { if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout", return InvalidArgument("shape %s does not have a layout",
ShapeUtil::HumanString(shape).c_str()); ShapeUtil::HumanString(shape).c_str());
} }
return ValidateLayoutForShape(shape.layout(), shape); return ValidateLayoutForShape(shape.layout(), shape);
} else {
// Token, opaque, etc. shape.
if (shape.has_layout()) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
}
return Status::OK();
} }
} }
@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return InvalidArgument("a single Layout is not valid for tuple shapes"); return InvalidArgument("a single Layout is not valid for tuple shapes");
} }
if (ShapeUtil::IsOpaque(shape)) { if (!ShapeUtil::IsArray(shape)) {
return Status::OK(); return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
} }
if (layout.format() == INVALID_FORMAT) { if (layout.format() == INVALID_FORMAT) {
@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} }
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) { /* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) { shape.layout().padded_dimensions_size() == 0) {
return false; return false;
} }
@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
// Tuple shape: all subshapes must have a layout. // Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); }); [](const Shape& s) { return HasLayout(s); });
} else if (ShapeUtil::IsOpaque(shape)) { } else if (!ShapeUtil::IsArray(shape)) {
// Opaque, token types etc. ignore layout.
return true; return true;
} }
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) { const Shape& rhs) {
if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
return false;
}
if (ShapeUtil::IsTuple(lhs)) { if (ShapeUtil::IsTuple(lhs)) {
if (ShapeUtil::TupleElementCount(lhs) != if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) { ShapeUtil::TupleElementCount(rhs)) {
return false; return false;
} }
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
} }
} }
return true; return true;
} else { } else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
LayoutUtil::Equal(lhs.layout(), rhs.layout()); LayoutUtil::Equal(lhs.layout(), rhs.layout());
} else {
// Layouts of non-array and non-tuple shapes is ignored.
return true;
} }
} }

View File

@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
"elements, but shape is rank")); "elements, but shape is rank"));
} }
TEST_F(LayoutUtilTest, CopyTokenLayout) {
Shape src = ShapeUtil::MakeTokenShape();
Shape dst = ShapeUtil::MakeTokenShape();
// Layouts are trivially the same for token types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
Shape src = ShapeUtil::MakeOpaqueShape();
Shape dst = ShapeUtil::MakeOpaqueShape();
// Layouts are trivially the same for opaque types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
Shape dst = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, ClearLayoutTuple) { TEST_F(LayoutUtilTest, ClearLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape( Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}), {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
} }
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
// Opaque and token types trivially have layouts.
for (Shape shape :
{ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
LayoutUtil::ClearLayout(&shape);
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
}
}
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape( Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/gtl/iterator_range.h"
@ -42,17 +41,18 @@ limitations under the License.
namespace xla { namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
string ShapeIndex::ToString() const { string ShapeIndex::ToString() const {
return tensorflow::strings::StrCat( return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
"{", tensorflow::str_util::Join(indices_, ","), "}");
} }
string ShapeIndexView::ToString() const { string ShapeIndexView::ToString() const {
return tensorflow::strings::StrCat( return StrCat("{",
"{", tensorflow::str_util::Join(
tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), tensorflow::gtl::make_range(begin_, end_), ","),
","), "}");
"}");
} }
bool ShapeIndexView::operator==(const ShapeIndexView& other) const { bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
@ -84,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
namespace { namespace {
// Returns whether the given primitive type corresponds to an array shape.
bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
primitive_type != OPAQUE && primitive_type != TOKEN;
}
// Recursive helper for comparing the equality of two shapes. Returns true if // Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also // the shapes are the same. If compare_layouts is true, then layouts must also
// match. // match.
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { if (!ShapeUtil::SameElementType(lhs, rhs)) {
return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && VLOG(3) << "CompareShapes: lhs element type != rhs element type";
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), return false;
}
if (ShapeUtil::IsTuple(lhs)) {
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) { [=](const Shape& l, const Shape& r) {
return CompareShapes(l, r, compare_layouts); return CompareShapes(l, r, compare_layouts);
}); });
} else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { } else if (!ShapeUtil::IsArray(lhs)) {
return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); // Non-tuple, non-array tupes such as opaque and token types are trivially
// the same.
return true;
} }
if (compare_layouts) { if (compare_layouts) {
@ -125,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false; return false;
} }
if (!ShapeUtil::SameElementType(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
return false;
}
return true; return true;
} }
@ -171,8 +179,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
} }
/* static */ int64 ShapeUtil::Rank(const Shape& shape) { /* static */ int64 ShapeUtil::Rank(const Shape& shape) {
CHECK(!ShapeUtil::IsTuple(shape)) CHECK(ShapeUtil::IsArray(shape))
<< "Tuples do not have a rank, shape: " << shape; << "Non-arrays do not have a rank, shape: " << shape;
return shape.dimensions_size(); return shape.dimensions_size();
} }
@ -199,8 +207,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShape( /* static */ Shape ShapeUtil::MakeShape(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) { PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
DCHECK_NE(TUPLE, element_type); CHECK(IsArrayPrimitiveType(element_type));
DCHECK_NE(OPAQUE, element_type);
Shape result; Shape result;
PopulateShape(element_type, dimensions, &result); PopulateShape(element_type, dimensions, &result);
return result; return result;
@ -223,8 +230,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) { int64 max_sparse_elements) {
DCHECK_NE(TUPLE, element_type); CHECK(IsArrayPrimitiveType(element_type));
DCHECK_NE(OPAQUE, element_type);
Shape shape = ShapeUtil::MakeShape(element_type, dimensions); Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
@ -271,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return result; return result;
} }
/* static */ Shape ShapeUtil::MakeTokenShape() {
Shape result;
result.set_element_type(TOKEN);
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
return result;
}
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
Shape* tuple_shape) { Shape* tuple_shape) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
@ -294,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
} }
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { if (!IsArray(shape)) {
return false; return false;
} }
return primitive_util::BitWidth(shape.element_type()) == bits; return primitive_util::BitWidth(shape.element_type()) == bits;
@ -320,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
case C64: case C64:
case TUPLE: case TUPLE:
case OPAQUE: case OPAQUE:
case TOKEN:
return false; return false;
default: default:
@ -335,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return primitive_util::IsFloatingPointType(shape.element_type()); return primitive_util::IsFloatingPointType(shape.element_type());
} }
/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
return IsArrayPrimitiveType(shape.element_type());
}
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
shape.tuple_shapes().end(), IsTuple); shape.tuple_shapes().end(), IsTuple);
@ -388,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
} }
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
CHECK_EQ(shape.dimensions_size(), Rank(shape)); CHECK_EQ(shape.dimensions_size(), Rank(shape));
return std::accumulate<decltype(shape.dimensions().begin()), int64>( return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL, shape.dimensions().begin(), shape.dimensions().end(), 1LL,
@ -403,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.element_type() == F32 && Rank(shape) == 0; return shape.element_type() == F32 && Rank(shape) == 0;
} }
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
} else {
return tensorflow::strings::StrCat(
tensorflow::str_util::Lowercase(
PrimitiveType_Name(shape.element_type())),
"[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
}
}
namespace { namespace {
@ -470,48 +471,56 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
} // namespace } // namespace
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
}
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
tensorflow::str_util::Join(shape.dimensions(), ","), "]");
}
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
if (IsTuple(shape)) { if (IsTuple(shape)) {
string text = "("; string text = "(";
const char* prefix = ""; const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) { for (const Shape& elem_shape : shape.tuple_shapes()) {
tensorflow::strings::StrAppend(&text, prefix, StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
HumanStringWithLayout(elem_shape));
prefix = ", "; prefix = ", ";
} }
text += ")"; text += ")";
return text; return text;
} else {
string result = tensorflow::strings::StrCat(
LowercasePrimitiveTypeName(shape.element_type()), "[");
for (int i = 0; i < shape.dimensions().size(); i++) {
tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
shape.dimensions(i));
}
result += "]";
if (!IsScalar(shape) && !IsOpaque(shape)) {
if (LayoutUtil::HasLayout(shape)) {
tensorflow::strings::StrAppend(&result,
LayoutUtil::HumanString(shape.layout()));
}
}
return result;
} }
string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
for (int i = 0; i < shape.dimensions().size(); i++) {
StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
}
result += "]";
if (!IsScalar(shape) && IsArray(shape)) {
if (LayoutUtil::HasLayout(shape)) {
StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
}
}
return result;
} }
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
std::vector<string> parameters; std::vector<string> parameters;
for (auto& shape : program_shape.parameters()) { for (auto& shape : program_shape.parameters()) {
const int i = parameters.size(); const int i = parameters.size();
parameters.push_back( parameters.push_back(StrCat(i < program_shape.parameter_names_size()
tensorflow::strings::StrCat(i < program_shape.parameter_names_size() ? program_shape.parameter_names(i)
? program_shape.parameter_names(i) : "(unknown)",
: "(unknown)", ": ", HumanString(shape)));
": ", HumanString(shape)));
} }
return tensorflow::strings::StrCat( return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
"(", tensorflow::str_util::Join(parameters, ", "), ") -> ", HumanString(program_shape.result()));
HumanString(program_shape.result()));
} }
namespace { namespace {
@ -581,14 +590,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the primitive element type. // Extract the primitive element type.
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
StringToPrimitiveType(element_type_string)); StringToPrimitiveType(element_type_string));
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
primitive_type == OPAQUE) {
return InvalidArgument("Invalid element type string: \"%s\".", return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str()); element_type_string.c_str());
} }
Shape result; Shape result;
if (format_string.empty() && layout_string.empty()) { if (primitive_type == OPAQUE) {
result = ShapeUtil::MakeOpaqueShape();
} else if (primitive_type == TOKEN) {
result = ShapeUtil::MakeTokenShape();
} else if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set. // Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions); result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else if (format_string == "sparse") { } else if (format_string == "sparse") {
@ -633,43 +645,44 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
} }
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
if (lhs.element_type() == TUPLE) { if (IsArray(lhs)) {
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE && return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
} }
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
} }
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) { const Shape& rhs) {
if (lhs.element_type() == TUPLE) { if (IsArray(lhs)) {
return IsArray(rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE && return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType); CompatibleIgnoringElementType);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
} }
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
} }
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) { const Shape& rhs) {
if (lhs.element_type() == TUPLE) { if (IsArray(lhs)) {
return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) &&
CompatibleIgnoringElementType(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE && return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision); CompatibleIgnoringFpPrecision);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
} }
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return CompatibleIgnoringElementType(lhs, rhs);
}
return false;
} }
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape, /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
@ -691,10 +704,6 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
switch (primitive_type) { switch (primitive_type) {
case PRED: case PRED:
return sizeof(int8); return sizeof(int8);
case TUPLE:
LOG(FATAL) << "tuples have no definitive size";
case OPAQUE:
LOG(FATAL) << "opaque have no definitive size";
case S8: case S8:
return sizeof(int8); return sizeof(int8);
case S16: case S16:
@ -721,6 +730,13 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(double); return sizeof(double);
case C64: case C64:
return sizeof(complex64); return sizeof(complex64);
case TOKEN:
// Tokens require no space.
return 0;
case TUPLE:
case OPAQUE:
LOG(FATAL) << PrimitiveType_Name(primitive_type)
<< " primitive type has no definitive size";
default: default:
LOG(FATAL) << "Unhandled primitive type " << primitive_type; LOG(FATAL) << "Unhandled primitive type " << primitive_type;
} }
@ -729,28 +745,32 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
int64 pointer_size) { int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape)); TF_DCHECK_OK(ValidateShape(shape));
DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) { if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size); return ByteSizeOfTupleIndexTable(shape, pointer_size);
} else if (IsArray(shape)) {
int64 byte_size = ByteSizeOfElements(shape);
if (LayoutUtil::IsSparseArray(shape)) {
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
} else if (shape.element_type() == TOKEN) {
return 0;
} }
int64 byte_size = ByteSizeOfElements(shape); LOG(FATAL) << PrimitiveType_Name(shape.element_type())
if (LayoutUtil::IsSparseArray(shape)) { << " primitive type has no definitive size";
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
} }
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size) { int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape)); TF_DCHECK_OK(ValidateShape(shape));
DCHECK_EQ(TUPLE, shape.element_type()); CHECK_EQ(TUPLE, shape.element_type());
CHECK_GT(pointer_size, 0); CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size(); return pointer_size * shape.tuple_shapes_size();
} }
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape)); TF_DCHECK_OK(ValidateShape(shape));
DCHECK(ShapeUtil::IsArray(shape)); CHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count; int64 allocated_element_count;
if (LayoutUtil::IsSparseArray(shape)) { if (LayoutUtil::IsSparseArray(shape)) {
@ -775,13 +795,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape)); TF_DCHECK_OK(ValidateShape(shape));
DCHECK(LayoutUtil::IsSparseArray(shape)); CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) * return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64); ShapeUtil::Rank(shape) * sizeof(int64);
} }
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) { const Shape& shape) {
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString().c_str());
}
if (shape.element_type() == TUPLE) { if (shape.element_type() == TUPLE) {
if (shape.dimensions_size() != 0) { if (shape.dimensions_size() != 0) {
return InvalidArgument("tuples must not have dimensions specified"); return InvalidArgument("tuples must not have dimensions specified");
@ -797,10 +821,24 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape.tuple_shapes_size() > 0) { if (shape.tuple_shapes_size() > 0) {
return InvalidArgument("non-tuple shape has tuple_shapes field"); return InvalidArgument("non-tuple shape has tuple_shapes field");
} }
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("shape has invalid element type: %s", // Tokens and opaques can should not have layout or dimensions.
shape.ShortDebugString().c_str()); if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
if (shape.dimensions_size() != 0) {
return InvalidArgument(
"shape has %s element type, but has dimensions field: %s",
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
shape.ShortDebugString().c_str());
}
if (shape.has_layout()) {
return InvalidArgument(
"shape has %s element type, but has layout field: %s",
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
shape.ShortDebugString().c_str());
}
return Status::OK();
} }
if (Rank(shape) != shape.dimensions_size()) { if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument( return InvalidArgument(
"shape's rank is mismatched with dimension count; rank=%lld " "shape's rank is mismatched with dimension count; rank=%lld "
@ -902,6 +940,8 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
} }
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
CHECK(IsArray(shape));
std::vector<int64> dimension_sizes; std::vector<int64> dimension_sizes;
std::vector<int64> degenerate_dimensions; std::vector<int64> degenerate_dimensions;
for (int64 i = 0; i < shape.dimensions_size(); ++i) { for (int64 i = 0; i < shape.dimensions_size(); ++i) {
@ -1066,6 +1106,9 @@ Status ForEachMutableSubshapeHelper(
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>> /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post) { const Shape& shape_post) {
CHECK(IsArray(shape_pre));
CHECK(IsArray(shape_post));
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>()); auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
std::vector<int64> deleted_indices; std::vector<int64> deleted_indices;
@ -1123,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
/* static */ std::vector<std::pair<int64, int64>> /* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
const Shape& output_shape) { const Shape& output_shape) {
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
// Unmodified dimensions are merely common factors of rank 1. // Unmodified dimensions are merely common factors of rank 1.
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions())); AsInt64Slice(output_shape.dimensions()));
@ -1176,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
const Shape& output_shape) { const Shape& output_shape) {
CHECK(LayoutUtil::HasLayout(input_shape) && CHECK(IsArray(input_shape));
LayoutUtil::HasLayout(output_shape)); CHECK(IsArray(output_shape));
CHECK(LayoutUtil::HasLayout(input_shape));
CHECK(LayoutUtil::HasLayout(output_shape));
if (!SameElementType(input_shape, output_shape)) { if (!SameElementType(input_shape, output_shape)) {
return false; return false;
@ -1339,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts( /* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
const Shape& input_shape, const Shape& output_shape) { const Shape& input_shape, const Shape& output_shape) {
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
int64 input_rank = Rank(input_shape); int64 input_rank = Rank(input_shape);
int64 output_rank = Rank(output_shape); int64 output_rank = Rank(output_shape);
@ -1473,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
Shape shape) { Shape shape) {
CHECK(IsArray(shape));
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
if (LayoutUtil::HasLayout(shape)) { if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout(); Layout* layout = shape.mutable_layout();
@ -1494,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::FilterDimensions( /* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) { const std::function<bool(int64)>& p, Shape shape) {
CHECK(IsArray(shape));
std::vector<int64> dims_to_delete; std::vector<int64> dims_to_delete;
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
if (!p(i)) { if (!p(i)) {

View File

@ -169,7 +169,7 @@ class ShapeUtil {
// may not actually be able to store this number of elements. See // may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape. // elements that can be stored in a sparse shape.
// Precondition: !IsTuple(shape) // Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape); static int64 ElementsIn(const Shape& shape);
// Returns true if 'shape' has zero elements. // Returns true if 'shape' has zero elements.
@ -180,13 +180,11 @@ class ShapeUtil {
// shapes. This includes only the size of the top-level buffer. For example, a // shapes. This includes only the size of the top-level buffer. For example, a
// tuple is stored as an array of pointers to other buffers. In this case, // tuple is stored as an array of pointers to other buffers. In this case,
// this method only returns the size of the pointer array. // this method only returns the size of the pointer array.
// Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
// !ShapeUtil::IsOpaque(shape)
static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
// Returns the number of bytes used to store the primitive_type. // Returns the number of bytes used to store the primitive_type.
// //
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) // Precondition: ShapeUtil::IsArray(shape)
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
// Returns the number of bytes required to store the tuple member pointers for // Returns the number of bytes required to store the tuple member pointers for
@ -245,7 +243,7 @@ class ShapeUtil {
} }
// Returns the higher-precision element type if a and b are both floating // Returns the higher-precision element type if a and b are both floating
// point types; otherwise, checks that they have the same element type // point types; otherwise, checks that that they have the same element type
// and returns it. // and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a, static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) { const Shape& b) {
@ -293,10 +291,10 @@ class ShapeUtil {
// Scalar-specific // Scalar-specific
static bool IsScalar(const Shape& shape) { static bool IsScalar(const Shape& shape) {
return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; return IsArray(shape) && Rank(shape) == 0;
} }
static bool IsEffectiveScalar(const Shape& shape) { static bool IsEffectiveScalar(const Shape& shape) {
return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; return IsArray(shape) && TrueRank(shape) == 0;
} }
static bool IsScalarF32(const Shape& shape); static bool IsScalarF32(const Shape& shape);
@ -325,6 +323,10 @@ class ShapeUtil {
// into a custom operation. // into a custom operation.
static Shape MakeOpaqueShape(); static Shape MakeOpaqueShape();
// Creates a token shape. Values of this shape are used for ordering
// side-effecting operations.
static Shape MakeTokenShape();
// Appends a shape to the given tuple. // Appends a shape to the given tuple.
static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
@ -424,11 +426,15 @@ class ShapeUtil {
return shape.element_type() == OPAQUE; return shape.element_type() == OPAQUE;
} }
// Returns whether the shape is an token value used for ordering
// side-effecting operations.
static bool IsToken(const Shape& shape) {
return shape.element_type() == TOKEN;
}
// Returns whether the shape is an array. Note that scalars are considered // Returns whether the shape is an array. Note that scalars are considered
// arrays. // arrays.
static bool IsArray(const Shape& shape) { static bool IsArray(const Shape& shape);
return !IsTuple(shape) && !IsOpaque(shape);
}
// Returns whether the shape is a tuple with at least one element which is // Returns whether the shape is a tuple with at least one element which is
// also a tuple. // also a tuple.

View File

@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
} }
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2]), f32[3])"; string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string)); ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({ Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}), ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}),
}); });
ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
<< "actual: " << ShapeUtil::HumanString(actual); << "actual: " << ShapeUtil::HumanString(actual);
} }
TEST(ShapeUtilTest, ParseOpaqueType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString("opaque[]"));
Shape expected = ShapeUtil::MakeOpaqueShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseTokenType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
Shape expected = ShapeUtil::MakeTokenShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseInvalidShapeString) { TEST(ShapeUtilTest, ParseInvalidShapeString) {
string shape_strings[] = { string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN));
EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape()));
} }
TEST(ShapeUtilTest, ByteSizeOfWithPadding) { TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) {
TEST(ShapeUtilTest, HumanString) { TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape(); Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
ShapeUtil::HumanString(tuple)); ShapeUtil::HumanString(tuple));
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(nested_tuple)); ShapeUtil::HumanString(nested_tuple));
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) {
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
ShapeUtil::HumanStringWithLayout(tuple)); ShapeUtil::HumanStringWithLayout(tuple));
EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", EXPECT_EQ(
ShapeUtil::HumanStringWithLayout(nested_tuple)); "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
"token[])",
ShapeUtil::HumanStringWithLayout(nested_tuple));
ProgramShape prog = ShapeUtil::MakeProgramShape( ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) {
"(unknown): u32[1,2], " "(unknown): u32[1,2], "
"(unknown): s32[3,4], " "(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", "-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog)); ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0"); prog.add_parameter_names("arg0");
@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) {
"matrix: u32[1,2], " "matrix: u32[1,2], "
"matrix2: s32[3,4], " "matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", "token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog)); ShapeUtil::HumanString(prog));
} }

View File

@ -66,11 +66,16 @@ enum PrimitiveType {
// in the dimensions field. // in the dimensions field.
TUPLE = 13; TUPLE = 13;
// An opaque type used for passing context specific data to a custom // An opaque type used for passing context-specific data to a custom
// operation. // operation. Shapes of this primitive type will have empty dimensions and
// tuple_shapes fields.
OPAQUE = 14; OPAQUE = 14;
// Next = 17 // A token type threaded between side-effecting operations. Shapes of this
// primitive type will have empty dimensions and tuple_shapes fields.
TOKEN = 17;
// Next = 18
} }
// Describes the value held inside padding elements. // Describes the value held inside padding elements.