Add TOKEN primitive type.

The token type will be threaded through side-effecting ops to order them. Subsequent CLs will add new opcodes and change side-effecting operations to support this ordering.

This CL also does some cleanup in shape_util and layout_util, where we had previously assumed that shapes are either arrays or tuples.

PiperOrigin-RevId: 199215963
Mark Heffernan 2018-06-04 16:41:46 -07:00 committed by TensorFlower Gardener
parent cf01d118ef
commit 14d4d1634d
6 changed files with 305 additions and 150 deletions
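
This change adds ShapeUtil::MakeTokenShape() and ShapeUtil::IsToken(), and teaches the shape, layout, and size helpers about the TOKEN primitive type. The snippet below is a minimal usage sketch, not part of the diff; the function name is illustrative and the assertions mirror the new tests further down.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
void TokenShapeSketch() {
  // A token is a dimensionless, zero-sized shape used purely for ordering.
  Shape token = ShapeUtil::MakeTokenShape();
  CHECK(ShapeUtil::IsToken(token));
  CHECK(!ShapeUtil::IsArray(token));          // Tokens are neither arrays...
  CHECK(!ShapeUtil::IsTuple(token));          // ...nor tuples.
  CHECK_EQ("token[]", ShapeUtil::HumanString(token));
  CHECK_EQ(0, ShapeUtil::ByteSizeOf(token));  // Tokens occupy no space.
}
}  // namespace xla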

tensorflow/compiler/xla/layout_util.cc

@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} // namespace
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
// Opaque and token types have empty layouts.
return Layout();
}
// A Layout proto corresponds to a single array, not a tuple.
DCHECK(!ShapeUtil::IsTuple(shape));
CHECK(ShapeUtil::IsArray(shape));
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
SetToDefaultLayout(&element_shape);
}
shape->clear_layout();
} else if (ShapeUtil::IsOpaque(*shape)) {
shape->clear_layout();
} else {
} else if (ShapeUtil::IsArray(*shape)) {
shape->mutable_layout()->set_format(DENSE);
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
minor_to_major->Resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major);
} else {
// Opaque, token types etc. have no layout.
shape->clear_layout();
}
}
@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
}
return Status::OK();
} else if (ShapeUtil::IsOpaque(shape)) {
if (shape.has_layout()) {
return InvalidArgument("opaque should not have a layout field");
}
return Status::OK();
} else {
// Array shape.
} else if (ShapeUtil::IsArray(shape)) {
if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout",
ShapeUtil::HumanString(shape).c_str());
}
return ValidateLayoutForShape(shape.layout(), shape);
} else {
// Token, opaque, etc. shape.
if (shape.has_layout()) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
}
return Status::OK();
}
}
@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return InvalidArgument("a single Layout is not valid for tuple shapes");
}
if (ShapeUtil::IsOpaque(shape)) {
return Status::OK();
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
}
if (layout.format() == INVALID_FORMAT) {
@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) {
return false;
}
@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
// Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); });
} else if (ShapeUtil::IsOpaque(shape)) {
} else if (!ShapeUtil::IsArray(shape)) {
// Opaque, token types etc. ignore layout.
return true;
}
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) {
if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
return false;
}
if (ShapeUtil::IsTuple(lhs)) {
if (ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) {
if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) {
return false;
}
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
}
return true;
} else {
} else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
LayoutUtil::Equal(lhs.layout(), rhs.layout());
} else {
// Layouts of non-array and non-tuple shapes are ignored.
return true;
}
}
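
As the hunks above show, LayoutUtil now treats token and opaque shapes as trivially laid out: HasLayout() returns true, the default layout is an empty Layout proto, and clearing or defaulting the layout is a no-op. A small sketch of that behavior follows; it is illustrative only and restates what the tests in the next file verify.

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
void TokenLayoutSketch() {
  Shape token = ShapeUtil::MakeTokenShape();
  CHECK(LayoutUtil::HasLayout(token));     // Trivially true for non-array shapes.
  LayoutUtil::SetToDefaultLayout(&token);  // Clears any layout; tokens get none.
  LayoutUtil::ClearLayout(&token);         // Likewise a no-op for HasLayout().
  CHECK(LayoutUtil::HasLayout(token));
  // Copying layouts between two token shapes is also a no-op.
  Shape dst = ShapeUtil::MakeTokenShape();
  CHECK(LayoutUtil::CopyLayoutBetweenShapes(token, &dst).ok());
  CHECK(LayoutUtil::LayoutsInShapesEqual(token, dst));
}
}  // namespace xla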

tensorflow/compiler/xla/layout_util_test.cc

@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
"elements, but shape is rank"));
}
TEST_F(LayoutUtilTest, CopyTokenLayout) {
Shape src = ShapeUtil::MakeTokenShape();
Shape dst = ShapeUtil::MakeTokenShape();
// Layouts are trivially the same for token types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
Shape src = ShapeUtil::MakeOpaqueShape();
Shape dst = ShapeUtil::MakeOpaqueShape();
// Layouts are trivially the same for opaque types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
Shape dst = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
}
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
// Opaque and token types trivially have layouts.
for (Shape shape :
{ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
LayoutUtil::ClearLayout(&shape);
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
}
}
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),

tensorflow/compiler/xla/shape_util.cc

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@ -42,17 +41,18 @@ limitations under the License.
namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
string ShapeIndex::ToString() const {
return tensorflow::strings::StrCat(
"{", tensorflow::str_util::Join(indices_, ","), "}");
return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
}
string ShapeIndexView::ToString() const {
return tensorflow::strings::StrCat(
"{",
tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
","),
"}");
return StrCat("{",
tensorflow::str_util::Join(
tensorflow::gtl::make_range(begin_, end_), ","),
"}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
@ -84,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
namespace {
// Returns whether the given primitive type corresponds to an array shape.
bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
primitive_type != OPAQUE && primitive_type != TOKEN;
}
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
if (!ShapeUtil::SameElementType(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
return false;
}
if (ShapeUtil::IsTuple(lhs)) {
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) {
return CompareShapes(l, r, compare_layouts);
});
} else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array types such as opaque and token types are trivially
// the same.
return true;
}
if (compare_layouts) {
@ -125,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
if (!ShapeUtil::SameElementType(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
return false;
}
return true;
}
@ -171,8 +179,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
CHECK(!ShapeUtil::IsTuple(shape))
<< "Tuples do not have a rank, shape: " << shape;
CHECK(ShapeUtil::IsArray(shape))
<< "Non-arrays do not have a rank, shape: " << shape;
return shape.dimensions_size();
}
@ -199,8 +207,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShape(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
DCHECK_NE(TUPLE, element_type);
DCHECK_NE(OPAQUE, element_type);
CHECK(IsArrayPrimitiveType(element_type));
Shape result;
PopulateShape(element_type, dimensions, &result);
return result;
@ -223,8 +230,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) {
DCHECK_NE(TUPLE, element_type);
DCHECK_NE(OPAQUE, element_type);
CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
@ -271,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return result;
}
/* static */ Shape ShapeUtil::MakeTokenShape() {
Shape result;
result.set_element_type(TOKEN);
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
return result;
}
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
Shape* tuple_shape) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
@ -294,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
if (!IsArray(shape)) {
return false;
}
return primitive_util::BitWidth(shape.element_type()) == bits;
@ -320,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
case C64:
case TUPLE:
case OPAQUE:
case TOKEN:
return false;
default:
@ -335,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return primitive_util::IsFloatingPointType(shape.element_type());
}
/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
return IsArrayPrimitiveType(shape.element_type());
}
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
shape.tuple_shapes().end(), IsTuple);
@ -388,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
CHECK_EQ(shape.dimensions_size(), Rank(shape));
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
@ -403,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.element_type() == F32 && Rank(shape) == 0;
}
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
} else {
return tensorflow::strings::StrCat(
tensorflow::str_util::Lowercase(
PrimitiveType_Name(shape.element_type())),
"[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
}
}
namespace {
@ -470,48 +471,56 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
} // namespace
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
}
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
tensorflow::str_util::Join(shape.dimensions(), ","), "]");
}
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
tensorflow::strings::StrAppend(&text, prefix,
HumanStringWithLayout(elem_shape));
StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
prefix = ", ";
}
text += ")";
return text;
} else {
string result = tensorflow::strings::StrCat(
LowercasePrimitiveTypeName(shape.element_type()), "[");
for (int i = 0; i < shape.dimensions().size(); i++) {
tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
shape.dimensions(i));
}
result += "]";
if (!IsScalar(shape) && !IsOpaque(shape)) {
if (LayoutUtil::HasLayout(shape)) {
tensorflow::strings::StrAppend(&result,
LayoutUtil::HumanString(shape.layout()));
}
}
return result;
}
string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
for (int i = 0; i < shape.dimensions().size(); i++) {
StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
}
result += "]";
if (!IsScalar(shape) && IsArray(shape)) {
if (LayoutUtil::HasLayout(shape)) {
StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
}
}
return result;
}
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
std::vector<string> parameters;
for (auto& shape : program_shape.parameters()) {
const int i = parameters.size();
parameters.push_back(
tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
? program_shape.parameter_names(i)
: "(unknown)",
": ", HumanString(shape)));
parameters.push_back(StrCat(i < program_shape.parameter_names_size()
? program_shape.parameter_names(i)
: "(unknown)",
": ", HumanString(shape)));
}
return tensorflow::strings::StrCat(
"(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
HumanString(program_shape.result()));
return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
HumanString(program_shape.result()));
}
namespace {
@ -581,14 +590,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the primitive element type.
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
StringToPrimitiveType(element_type_string));
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
primitive_type == OPAQUE) {
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}
Shape result;
if (format_string.empty() && layout_string.empty()) {
if (primitive_type == OPAQUE) {
result = ShapeUtil::MakeOpaqueShape();
} else if (primitive_type == TOKEN) {
result = ShapeUtil::MakeTokenShape();
} else if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else if (format_string == "sparse") {
@ -633,43 +645,44 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
if (lhs.element_type() == TUPLE) {
if (IsArray(lhs)) {
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
}
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
if (lhs.element_type() == TUPLE) {
if (IsArray(lhs)) {
return IsArray(rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
}
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
if (lhs.element_type() == TUPLE) {
if (IsArray(lhs)) {
return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) &&
CompatibleIgnoringElementType(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision);
} else {
// Opaque, token, etc types are vacuously compatible.
return true;
}
if (lhs.element_type() == OPAQUE) {
return rhs.element_type() == OPAQUE;
}
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return CompatibleIgnoringElementType(lhs, rhs);
}
return false;
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
@ -691,10 +704,6 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
switch (primitive_type) {
case PRED:
return sizeof(int8);
case TUPLE:
LOG(FATAL) << "tuples have no definitive size";
case OPAQUE:
LOG(FATAL) << "opaque have no definitive size";
case S8:
return sizeof(int8);
case S16:
@ -721,6 +730,13 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(double);
case C64:
return sizeof(complex64);
case TOKEN:
// Tokens require no space.
return 0;
case TUPLE:
case OPAQUE:
LOG(FATAL) << PrimitiveType_Name(primitive_type)
<< " primitive type has no definitive size";
default:
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
}
@ -729,28 +745,32 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size);
} else if (IsArray(shape)) {
int64 byte_size = ByteSizeOfElements(shape);
if (LayoutUtil::IsSparseArray(shape)) {
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
} else if (shape.element_type() == TOKEN) {
return 0;
}
int64 byte_size = ByteSizeOfElements(shape);
if (LayoutUtil::IsSparseArray(shape)) {
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
LOG(FATAL) << PrimitiveType_Name(shape.element_type())
<< " primitive type has no definitive size";
}
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK_EQ(TUPLE, shape.element_type());
CHECK_EQ(TUPLE, shape.element_type());
CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size();
}
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK(ShapeUtil::IsArray(shape));
CHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count;
if (LayoutUtil::IsSparseArray(shape)) {
@ -775,13 +795,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK(LayoutUtil::IsSparseArray(shape));
CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64);
}
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString().c_str());
}
if (shape.element_type() == TUPLE) {
if (shape.dimensions_size() != 0) {
return InvalidArgument("tuples must not have dimensions specified");
@ -797,10 +821,24 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape.tuple_shapes_size() > 0) {
return InvalidArgument("non-tuple shape has tuple_shapes field");
}
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString().c_str());
// Tokens and opaques should not have a layout or dimensions.
if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
if (shape.dimensions_size() != 0) {
return InvalidArgument(
"shape has %s element type, but has dimensions field: %s",
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
shape.ShortDebugString().c_str());
}
if (shape.has_layout()) {
return InvalidArgument(
"shape has %s element type, but has layout field: %s",
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
shape.ShortDebugString().c_str());
}
return Status::OK();
}
if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument(
"shape's rank is mismatched with dimension count; rank=%lld "
@ -902,6 +940,8 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
CHECK(IsArray(shape));
std::vector<int64> dimension_sizes;
std::vector<int64> degenerate_dimensions;
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
@ -1066,6 +1106,9 @@ Status ForEachMutableSubshapeHelper(
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post) {
CHECK(IsArray(shape_pre));
CHECK(IsArray(shape_post));
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
std::vector<int64> deleted_indices;
@ -1123,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
const Shape& output_shape) {
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
// Unmodified dimensions are merely common factors of rank 1.
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
@ -1176,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
const Shape& output_shape) {
CHECK(LayoutUtil::HasLayout(input_shape) &&
LayoutUtil::HasLayout(output_shape));
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
CHECK(LayoutUtil::HasLayout(input_shape));
CHECK(LayoutUtil::HasLayout(output_shape));
if (!SameElementType(input_shape, output_shape)) {
return false;
@ -1339,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
const Shape& input_shape, const Shape& output_shape) {
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
int64 input_rank = Rank(input_shape);
int64 output_rank = Rank(output_shape);
@ -1473,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
Shape shape) {
CHECK(IsArray(shape));
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout();
@ -1494,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) {
CHECK(IsArray(shape));
std::vector<int64> dims_to_delete;
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
if (!p(i)) {

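With the shape_util.cc changes above, ParseShapeString() accepts "token[]" and "opaque[]", and HumanString() prints tokens back in the same form, so the two round-trip. A short sketch, illustrative only:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
void ParseTokenSketch() {
  StatusOr<Shape> parsed = ShapeUtil::ParseShapeString("token[]");
  CHECK(parsed.ok());
  CHECK(ShapeUtil::Equal(parsed.ValueOrDie(), ShapeUtil::MakeTokenShape()));
  CHECK_EQ("token[]", ShapeUtil::HumanString(parsed.ValueOrDie()));
}
}  // namespace xla
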
tensorflow/compiler/xla/shape_util.h

@ -169,7 +169,7 @@ class ShapeUtil {
// may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape.
// Precondition: !IsTuple(shape)
// Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape);
// Returns true if 'shape' has zero elements.
@ -180,13 +180,11 @@ class ShapeUtil {
// shapes. This includes only the size of the top-level buffer. For example, a
// tuple is stored as an array of pointers to other buffers. In this case,
// this method only returns the size of the pointer array.
// Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
// !ShapeUtil::IsOpaque(shape)
static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
// Returns the number of bytes used to store the primitive_type.
//
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
// Precondition: ShapeUtil::IsArray(shape)
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
// Returns the number of bytes required to store the tuple member pointers for
@ -245,7 +243,7 @@ class ShapeUtil {
}
// Returns the higher-precision element type if a and b are both floating
// point types; otherwise, checks that they have the same element type
// point types; otherwise, checks that that they have the same element type
// and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) {
@ -293,10 +291,10 @@ class ShapeUtil {
// Scalar-specific
static bool IsScalar(const Shape& shape) {
return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
return IsArray(shape) && Rank(shape) == 0;
}
static bool IsEffectiveScalar(const Shape& shape) {
return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
return IsArray(shape) && TrueRank(shape) == 0;
}
static bool IsScalarF32(const Shape& shape);
@ -325,6 +323,10 @@ class ShapeUtil {
// into a custom operation.
static Shape MakeOpaqueShape();
// Creates a token shape. Values of this shape are used for ordering
// side-effecting operations.
static Shape MakeTokenShape();
// Appends a shape to the given tuple.
static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
@ -424,11 +426,15 @@ class ShapeUtil {
return shape.element_type() == OPAQUE;
}
// Returns whether the shape is a token value used for ordering
// side-effecting operations.
static bool IsToken(const Shape& shape) {
return shape.element_type() == TOKEN;
}
// Returns whether the shape is an array. Note that scalars are considered
// arrays.
static bool IsArray(const Shape& shape) {
return !IsTuple(shape) && !IsOpaque(shape);
}
static bool IsArray(const Shape& shape);
// Returns whether the shape is a tuple with at least one element which is
// also a tuple.
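
With IsArray() now keyed off the primitive type rather than defined as "not a tuple and not opaque", the predicates above partition shapes cleanly, and IsScalar()/IsEffectiveScalar() apply only to arrays. A brief sketch of the resulting classification, illustrative only:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
void ShapePredicateSketch() {
  Shape token = ShapeUtil::MakeTokenShape();
  Shape opaque = ShapeUtil::MakeOpaqueShape();
  Shape scalar = ShapeUtil::MakeShape(F32, {});

  CHECK(ShapeUtil::IsToken(token) && !ShapeUtil::IsArray(token));
  CHECK(ShapeUtil::IsOpaque(opaque) && !ShapeUtil::IsArray(opaque));
  CHECK(ShapeUtil::IsArray(scalar) && ShapeUtil::IsScalar(scalar));
  // IsScalar() requires IsArray(), so tokens and opaques are not scalars.
  CHECK(!ShapeUtil::IsScalar(token) && !ShapeUtil::IsScalar(opaque));
}
}  // namespace xla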

tensorflow/compiler/xla/shape_util_test.cc

@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
}
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2]), f32[3])";
string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseOpaqueType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString("opaque[]"));
Shape expected = ShapeUtil::MakeOpaqueShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseTokenType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
Shape expected = ShapeUtil::MakeTokenShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST(ShapeUtilTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN));
EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape()));
}
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) {
TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
ShapeUtil::HumanString(tuple));
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(nested_tuple));
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) {
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
ShapeUtil::HumanStringWithLayout(tuple));
EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})",
ShapeUtil::HumanStringWithLayout(nested_tuple));
EXPECT_EQ(
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
"token[])",
ShapeUtil::HumanStringWithLayout(nested_tuple));
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) {
"(unknown): u32[1,2], "
"(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0");
@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) {
"matrix: u32[1,2], "
"matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
"token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
}

tensorflow/compiler/xla/xla_data.proto

@ -66,11 +66,16 @@ enum PrimitiveType {
// in the dimensions field.
TUPLE = 13;
// An opaque type used for passing context specific data to a custom
// operation.
// An opaque type used for passing context-specific data to a custom
// operation. Shapes of this primitive type will have empty dimensions and
// tuple_shapes fields.
OPAQUE = 14;
// Next = 17
// A token type threaded between side-effecting operations. Shapes of this
// primitive type will have empty dimensions and tuple_shapes fields.
TOKEN = 17;
// Next = 18
}
// Describes the value held inside padding elements.
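
At the proto level a token is simply a Shape with element_type TOKEN and empty dimensions, tuple_shapes, and layout fields; the validation added in shape_util.cc enforces this. A hedged sketch of building the proto by hand, illustrative only:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
void RawTokenProtoSketch() {
  Shape token;
  token.set_element_type(TOKEN);
  // Equivalent to ShapeUtil::MakeTokenShape(); validates cleanly.
  CHECK(ShapeUtil::ValidateShape(token).ok());

  // Giving a token dimensions (or a layout) violates the new invariant and is
  // rejected by validation.
  token.add_dimensions(2);
  CHECK(!ShapeUtil::ValidateShape(token).ok());
}
}  // namespace xla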