Add TOKEN primitive type.
The token type will be threaded through side-effecting ops to order them. Subsequent cls will add new opcodes and change side effecting operations to support this ordering. This CL also does some cleanup in shape_util and layout_util where we have assumed that shapes are either arrays or tuples. PiperOrigin-RevId: 199215963
This commit is contained in:
parent
cf01d118ef
commit
14d4d1634d
@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
} // namespace
|
||||
|
||||
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
|
||||
if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
|
||||
// Opaque and token types have empty layouts.
|
||||
return Layout();
|
||||
}
|
||||
|
||||
// A Layout proto corresponds to a single array, not a tuple.
|
||||
DCHECK(!ShapeUtil::IsTuple(shape));
|
||||
CHECK(ShapeUtil::IsArray(shape));
|
||||
return CreateDefaultLayoutForRank(shape.dimensions_size());
|
||||
}
|
||||
|
||||
@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
SetToDefaultLayout(&element_shape);
|
||||
}
|
||||
shape->clear_layout();
|
||||
} else if (ShapeUtil::IsOpaque(*shape)) {
|
||||
shape->clear_layout();
|
||||
} else {
|
||||
} else if (ShapeUtil::IsArray(*shape)) {
|
||||
shape->mutable_layout()->set_format(DENSE);
|
||||
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
|
||||
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
|
||||
minor_to_major->Resize(shape->dimensions_size(), 0);
|
||||
SetDefaultLayoutToContainer(minor_to_major);
|
||||
} else {
|
||||
// Opaque, token types etc. have no layout.
|
||||
shape->clear_layout();
|
||||
}
|
||||
}
|
||||
|
||||
@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (ShapeUtil::IsOpaque(shape)) {
|
||||
if (shape.has_layout()) {
|
||||
return InvalidArgument("opaque should not have a layout field");
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
// Array shape.
|
||||
} else if (ShapeUtil::IsArray(shape)) {
|
||||
if (!shape.has_layout()) {
|
||||
return InvalidArgument("shape %s does not have a layout",
|
||||
ShapeUtil::HumanString(shape).c_str());
|
||||
}
|
||||
return ValidateLayoutForShape(shape.layout(), shape);
|
||||
} else {
|
||||
// Token, opaque, etc. shape.
|
||||
if (shape.has_layout()) {
|
||||
return InvalidArgument(
|
||||
"shape of primitive type %s should not have a layout",
|
||||
PrimitiveType_Name(shape.element_type()).c_str());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
return InvalidArgument("a single Layout is not valid for tuple shapes");
|
||||
}
|
||||
|
||||
if (ShapeUtil::IsOpaque(shape)) {
|
||||
return Status::OK();
|
||||
if (!ShapeUtil::IsArray(shape)) {
|
||||
return InvalidArgument(
|
||||
"shape of primitive type %s should not have a layout",
|
||||
PrimitiveType_Name(shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (layout.format() == INVALID_FORMAT) {
|
||||
@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
|
||||
if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
|
||||
if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
|
||||
shape.layout().padded_dimensions_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
// Tuple shape: all subshapes must have a layout.
|
||||
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
|
||||
[](const Shape& s) { return HasLayout(s); });
|
||||
} else if (ShapeUtil::IsOpaque(shape)) {
|
||||
} else if (!ShapeUtil::IsArray(shape)) {
|
||||
// Opaque, token types etc. ignore layout.
|
||||
return true;
|
||||
}
|
||||
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
|
||||
@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
||||
|
||||
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
|
||||
return false;
|
||||
}
|
||||
if (ShapeUtil::IsTuple(lhs)) {
|
||||
if (ShapeUtil::TupleElementCount(lhs) !=
|
||||
ShapeUtil::TupleElementCount(rhs)) {
|
||||
if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
|
||||
ShapeUtil::TupleElementCount(rhs)) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
|
||||
@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
} else if (ShapeUtil::IsArray(lhs)) {
|
||||
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
|
||||
LayoutUtil::Equal(lhs.layout(), rhs.layout());
|
||||
} else {
|
||||
// Layouts of non-array and non-tuple shapes is ignored.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
|
||||
"elements, but shape is rank"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyTokenLayout) {
|
||||
Shape src = ShapeUtil::MakeTokenShape();
|
||||
Shape dst = ShapeUtil::MakeTokenShape();
|
||||
|
||||
// Layouts are trivially the same for token types and copying layouts should
|
||||
// be a nop.
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
|
||||
Shape src = ShapeUtil::MakeOpaqueShape();
|
||||
Shape dst = ShapeUtil::MakeOpaqueShape();
|
||||
|
||||
// Layouts are trivially the same for opaque types and copying layouts should
|
||||
// be a nop.
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
|
||||
Shape src = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
|
||||
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
|
||||
MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
|
||||
Shape dst = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
|
||||
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
|
||||
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
|
||||
|
||||
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
|
||||
Shape shape = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
|
||||
@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
|
||||
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
|
||||
// Opaque and token types trivially have layouts.
|
||||
for (Shape shape :
|
||||
{ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
|
||||
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
|
||||
LayoutUtil::ClearLayout(&shape);
|
||||
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
|
||||
Shape shape = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/iterator_range.h"
|
||||
@ -42,17 +41,18 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
string ShapeIndex::ToString() const {
|
||||
return tensorflow::strings::StrCat(
|
||||
"{", tensorflow::str_util::Join(indices_, ","), "}");
|
||||
return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
|
||||
}
|
||||
|
||||
string ShapeIndexView::ToString() const {
|
||||
return tensorflow::strings::StrCat(
|
||||
"{",
|
||||
tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
|
||||
","),
|
||||
"}");
|
||||
return StrCat("{",
|
||||
tensorflow::str_util::Join(
|
||||
tensorflow::gtl::make_range(begin_, end_), ","),
|
||||
"}");
|
||||
}
|
||||
|
||||
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
|
||||
@ -84,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns whether the given primitive type corresponds to an array shape.
|
||||
bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
|
||||
return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
|
||||
primitive_type != OPAQUE && primitive_type != TOKEN;
|
||||
}
|
||||
|
||||
// Recursive helper for comparing the equality of two shapes. Returns true if
|
||||
// the shapes are the same. If compare_layouts is true, then layouts must also
|
||||
// match.
|
||||
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
|
||||
return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
if (!ShapeUtil::SameElementType(lhs, rhs)) {
|
||||
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ShapeUtil::IsTuple(lhs)) {
|
||||
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
[=](const Shape& l, const Shape& r) {
|
||||
return CompareShapes(l, r, compare_layouts);
|
||||
});
|
||||
} else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
|
||||
return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
|
||||
} else if (!ShapeUtil::IsArray(lhs)) {
|
||||
// Non-tuple, non-array tupes such as opaque and token types are trivially
|
||||
// the same.
|
||||
return true;
|
||||
}
|
||||
|
||||
if (compare_layouts) {
|
||||
@ -125,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
|
||||
return false;
|
||||
}
|
||||
if (!ShapeUtil::SameElementType(lhs, rhs)) {
|
||||
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -171,8 +179,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
|
||||
CHECK(!ShapeUtil::IsTuple(shape))
|
||||
<< "Tuples do not have a rank, shape: " << shape;
|
||||
CHECK(ShapeUtil::IsArray(shape))
|
||||
<< "Non-arrays do not have a rank, shape: " << shape;
|
||||
return shape.dimensions_size();
|
||||
}
|
||||
|
||||
@ -199,8 +207,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShape(
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
DCHECK_NE(TUPLE, element_type);
|
||||
DCHECK_NE(OPAQUE, element_type);
|
||||
CHECK(IsArrayPrimitiveType(element_type));
|
||||
Shape result;
|
||||
PopulateShape(element_type, dimensions, &result);
|
||||
return result;
|
||||
@ -223,8 +230,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
int64 max_sparse_elements) {
|
||||
DCHECK_NE(TUPLE, element_type);
|
||||
DCHECK_NE(OPAQUE, element_type);
|
||||
CHECK(IsArrayPrimitiveType(element_type));
|
||||
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
|
||||
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
|
||||
@ -271,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeTokenShape() {
|
||||
Shape result;
|
||||
result.set_element_type(TOKEN);
|
||||
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
|
||||
Shape* tuple_shape) {
|
||||
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
|
||||
@ -294,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
|
||||
if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
|
||||
if (!IsArray(shape)) {
|
||||
return false;
|
||||
}
|
||||
return primitive_util::BitWidth(shape.element_type()) == bits;
|
||||
@ -320,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
case C64:
|
||||
case TUPLE:
|
||||
case OPAQUE:
|
||||
case TOKEN:
|
||||
return false;
|
||||
|
||||
default:
|
||||
@ -335,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return primitive_util::IsFloatingPointType(shape.element_type());
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
|
||||
return IsArrayPrimitiveType(shape.element_type());
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
|
||||
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
|
||||
shape.tuple_shapes().end(), IsTuple);
|
||||
@ -388,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
|
||||
CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
|
||||
CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
|
||||
CHECK_EQ(shape.dimensions_size(), Rank(shape));
|
||||
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
|
||||
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
|
||||
@ -403,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return shape.element_type() == F32 && Rank(shape) == 0;
|
||||
}
|
||||
|
||||
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
|
||||
if (IsTuple(shape)) {
|
||||
string text = "(";
|
||||
const char* prefix = "";
|
||||
for (const Shape& elem_shape : shape.tuple_shapes()) {
|
||||
tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
|
||||
prefix = ", ";
|
||||
}
|
||||
text += ")";
|
||||
return text;
|
||||
} else {
|
||||
return tensorflow::strings::StrCat(
|
||||
tensorflow::str_util::Lowercase(
|
||||
PrimitiveType_Name(shape.element_type())),
|
||||
"[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@ -470,48 +471,56 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
|
||||
if (IsTuple(shape)) {
|
||||
string text = "(";
|
||||
const char* prefix = "";
|
||||
for (const Shape& elem_shape : shape.tuple_shapes()) {
|
||||
StrAppend(&text, prefix, HumanString(elem_shape));
|
||||
prefix = ", ";
|
||||
}
|
||||
text += ")";
|
||||
return text;
|
||||
}
|
||||
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
|
||||
tensorflow::str_util::Join(shape.dimensions(), ","), "]");
|
||||
}
|
||||
|
||||
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
|
||||
if (IsTuple(shape)) {
|
||||
string text = "(";
|
||||
const char* prefix = "";
|
||||
for (const Shape& elem_shape : shape.tuple_shapes()) {
|
||||
tensorflow::strings::StrAppend(&text, prefix,
|
||||
HumanStringWithLayout(elem_shape));
|
||||
StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
|
||||
prefix = ", ";
|
||||
}
|
||||
text += ")";
|
||||
return text;
|
||||
} else {
|
||||
string result = tensorflow::strings::StrCat(
|
||||
LowercasePrimitiveTypeName(shape.element_type()), "[");
|
||||
for (int i = 0; i < shape.dimensions().size(); i++) {
|
||||
tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
|
||||
shape.dimensions(i));
|
||||
}
|
||||
result += "]";
|
||||
if (!IsScalar(shape) && !IsOpaque(shape)) {
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
tensorflow::strings::StrAppend(&result,
|
||||
LayoutUtil::HumanString(shape.layout()));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
|
||||
for (int i = 0; i < shape.dimensions().size(); i++) {
|
||||
StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
|
||||
}
|
||||
result += "]";
|
||||
if (!IsScalar(shape) && IsArray(shape)) {
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
|
||||
std::vector<string> parameters;
|
||||
for (auto& shape : program_shape.parameters()) {
|
||||
const int i = parameters.size();
|
||||
parameters.push_back(
|
||||
tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
|
||||
? program_shape.parameter_names(i)
|
||||
: "(unknown)",
|
||||
": ", HumanString(shape)));
|
||||
parameters.push_back(StrCat(i < program_shape.parameter_names_size()
|
||||
? program_shape.parameter_names(i)
|
||||
: "(unknown)",
|
||||
": ", HumanString(shape)));
|
||||
}
|
||||
return tensorflow::strings::StrCat(
|
||||
"(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
|
||||
HumanString(program_shape.result()));
|
||||
return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
|
||||
HumanString(program_shape.result()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -581,14 +590,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
// Extract the primitive element type.
|
||||
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
|
||||
StringToPrimitiveType(element_type_string));
|
||||
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
|
||||
primitive_type == OPAQUE) {
|
||||
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
|
||||
return InvalidArgument("Invalid element type string: \"%s\".",
|
||||
element_type_string.c_str());
|
||||
}
|
||||
|
||||
Shape result;
|
||||
if (format_string.empty() && layout_string.empty()) {
|
||||
if (primitive_type == OPAQUE) {
|
||||
result = ShapeUtil::MakeOpaqueShape();
|
||||
} else if (primitive_type == TOKEN) {
|
||||
result = ShapeUtil::MakeTokenShape();
|
||||
} else if (format_string.empty() && layout_string.empty()) {
|
||||
// Create a shape without a layout set.
|
||||
result = ShapeUtil::MakeShape(primitive_type, dimensions);
|
||||
} else if (format_string == "sparse") {
|
||||
@ -633,43 +645,44 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
|
||||
if (lhs.element_type() == TUPLE) {
|
||||
if (IsArray(lhs)) {
|
||||
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
|
||||
} else if (lhs.element_type() == TUPLE) {
|
||||
return rhs.element_type() == TUPLE &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
|
||||
} else {
|
||||
// Opaque, token, etc types are vacuously compatible.
|
||||
return true;
|
||||
}
|
||||
if (lhs.element_type() == OPAQUE) {
|
||||
return rhs.element_type() == OPAQUE;
|
||||
}
|
||||
return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
if (lhs.element_type() == TUPLE) {
|
||||
if (IsArray(lhs)) {
|
||||
return IsArray(rhs) && SameDimensions(lhs, rhs);
|
||||
} else if (lhs.element_type() == TUPLE) {
|
||||
return rhs.element_type() == TUPLE &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
CompatibleIgnoringElementType);
|
||||
} else {
|
||||
// Opaque, token, etc types are vacuously compatible.
|
||||
return true;
|
||||
}
|
||||
if (lhs.element_type() == OPAQUE) {
|
||||
return rhs.element_type() == OPAQUE;
|
||||
}
|
||||
return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
if (lhs.element_type() == TUPLE) {
|
||||
if (IsArray(lhs)) {
|
||||
return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) &&
|
||||
CompatibleIgnoringElementType(lhs, rhs);
|
||||
} else if (lhs.element_type() == TUPLE) {
|
||||
return rhs.element_type() == TUPLE &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
CompatibleIgnoringFpPrecision);
|
||||
} else {
|
||||
// Opaque, token, etc types are vacuously compatible.
|
||||
return true;
|
||||
}
|
||||
if (lhs.element_type() == OPAQUE) {
|
||||
return rhs.element_type() == OPAQUE;
|
||||
}
|
||||
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||
return CompatibleIgnoringElementType(lhs, rhs);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
|
||||
@ -691,10 +704,6 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
switch (primitive_type) {
|
||||
case PRED:
|
||||
return sizeof(int8);
|
||||
case TUPLE:
|
||||
LOG(FATAL) << "tuples have no definitive size";
|
||||
case OPAQUE:
|
||||
LOG(FATAL) << "opaque have no definitive size";
|
||||
case S8:
|
||||
return sizeof(int8);
|
||||
case S16:
|
||||
@ -721,6 +730,13 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
return sizeof(double);
|
||||
case C64:
|
||||
return sizeof(complex64);
|
||||
case TOKEN:
|
||||
// Tokens require no space.
|
||||
return 0;
|
||||
case TUPLE:
|
||||
case OPAQUE:
|
||||
LOG(FATAL) << PrimitiveType_Name(primitive_type)
|
||||
<< " primitive type has no definitive size";
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
|
||||
}
|
||||
@ -729,28 +745,32 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
|
||||
int64 pointer_size) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK_NE(OPAQUE, shape.element_type());
|
||||
if (shape.element_type() == TUPLE) {
|
||||
return ByteSizeOfTupleIndexTable(shape, pointer_size);
|
||||
} else if (IsArray(shape)) {
|
||||
int64 byte_size = ByteSizeOfElements(shape);
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
byte_size += ByteSizeOfSparseIndices(shape);
|
||||
}
|
||||
return byte_size;
|
||||
} else if (shape.element_type() == TOKEN) {
|
||||
return 0;
|
||||
}
|
||||
int64 byte_size = ByteSizeOfElements(shape);
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
byte_size += ByteSizeOfSparseIndices(shape);
|
||||
}
|
||||
return byte_size;
|
||||
LOG(FATAL) << PrimitiveType_Name(shape.element_type())
|
||||
<< " primitive type has no definitive size";
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
|
||||
int64 pointer_size) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK_EQ(TUPLE, shape.element_type());
|
||||
CHECK_EQ(TUPLE, shape.element_type());
|
||||
CHECK_GT(pointer_size, 0);
|
||||
return pointer_size * shape.tuple_shapes_size();
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK(ShapeUtil::IsArray(shape));
|
||||
CHECK(ShapeUtil::IsArray(shape));
|
||||
int64 allocated_element_count;
|
||||
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
@ -775,13 +795,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK(LayoutUtil::IsSparseArray(shape));
|
||||
CHECK(LayoutUtil::IsSparseArray(shape));
|
||||
return LayoutUtil::MaxSparseElements(shape.layout()) *
|
||||
ShapeUtil::Rank(shape) * sizeof(int64);
|
||||
}
|
||||
|
||||
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
|
||||
const Shape& shape) {
|
||||
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
|
||||
return InvalidArgument("shape has invalid element type: %s",
|
||||
shape.ShortDebugString().c_str());
|
||||
}
|
||||
if (shape.element_type() == TUPLE) {
|
||||
if (shape.dimensions_size() != 0) {
|
||||
return InvalidArgument("tuples must not have dimensions specified");
|
||||
@ -797,10 +821,24 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
if (shape.tuple_shapes_size() > 0) {
|
||||
return InvalidArgument("non-tuple shape has tuple_shapes field");
|
||||
}
|
||||
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
|
||||
return InvalidArgument("shape has invalid element type: %s",
|
||||
shape.ShortDebugString().c_str());
|
||||
|
||||
// Tokens and opaques can should not have layout or dimensions.
|
||||
if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
|
||||
if (shape.dimensions_size() != 0) {
|
||||
return InvalidArgument(
|
||||
"shape has %s element type, but has dimensions field: %s",
|
||||
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
|
||||
shape.ShortDebugString().c_str());
|
||||
}
|
||||
if (shape.has_layout()) {
|
||||
return InvalidArgument(
|
||||
"shape has %s element type, but has layout field: %s",
|
||||
LowercasePrimitiveTypeName(shape.element_type()).c_str(),
|
||||
shape.ShortDebugString().c_str());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (Rank(shape) != shape.dimensions_size()) {
|
||||
return InvalidArgument(
|
||||
"shape's rank is mismatched with dimension count; rank=%lld "
|
||||
@ -902,6 +940,8 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
|
||||
CHECK(IsArray(shape));
|
||||
|
||||
std::vector<int64> dimension_sizes;
|
||||
std::vector<int64> degenerate_dimensions;
|
||||
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
|
||||
@ -1066,6 +1106,9 @@ Status ForEachMutableSubshapeHelper(
|
||||
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
|
||||
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
|
||||
const Shape& shape_post) {
|
||||
CHECK(IsArray(shape_pre));
|
||||
CHECK(IsArray(shape_post));
|
||||
|
||||
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
|
||||
|
||||
std::vector<int64> deleted_indices;
|
||||
@ -1123,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
|
||||
/* static */ std::vector<std::pair<int64, int64>>
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
const Shape& output_shape) {
|
||||
CHECK(IsArray(input_shape));
|
||||
CHECK(IsArray(output_shape));
|
||||
|
||||
// Unmodified dimensions are merely common factors of rank 1.
|
||||
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
|
||||
AsInt64Slice(output_shape.dimensions()));
|
||||
@ -1176,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
|
||||
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
|
||||
const Shape& output_shape) {
|
||||
CHECK(LayoutUtil::HasLayout(input_shape) &&
|
||||
LayoutUtil::HasLayout(output_shape));
|
||||
CHECK(IsArray(input_shape));
|
||||
CHECK(IsArray(output_shape));
|
||||
CHECK(LayoutUtil::HasLayout(input_shape));
|
||||
CHECK(LayoutUtil::HasLayout(output_shape));
|
||||
|
||||
if (!SameElementType(input_shape, output_shape)) {
|
||||
return false;
|
||||
@ -1339,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
|
||||
/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
|
||||
const Shape& input_shape, const Shape& output_shape) {
|
||||
CHECK(IsArray(input_shape));
|
||||
CHECK(IsArray(output_shape));
|
||||
|
||||
int64 input_rank = Rank(input_shape);
|
||||
int64 output_rank = Rank(output_shape);
|
||||
|
||||
@ -1473,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
|
||||
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
|
||||
Shape shape) {
|
||||
CHECK(IsArray(shape));
|
||||
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
Layout* layout = shape.mutable_layout();
|
||||
@ -1494,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
|
||||
/* static */ Shape ShapeUtil::FilterDimensions(
|
||||
const std::function<bool(int64)>& p, Shape shape) {
|
||||
CHECK(IsArray(shape));
|
||||
std::vector<int64> dims_to_delete;
|
||||
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
|
||||
if (!p(i)) {
|
||||
|
@ -169,7 +169,7 @@ class ShapeUtil {
|
||||
// may not actually be able to store this number of elements. See
|
||||
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
|
||||
// elements that can be stored in a sparse shape.
|
||||
// Precondition: !IsTuple(shape)
|
||||
// Precondition: IsArray(shape)
|
||||
static int64 ElementsIn(const Shape& shape);
|
||||
|
||||
// Returns true if 'shape' has zero elements.
|
||||
@ -180,13 +180,11 @@ class ShapeUtil {
|
||||
// shapes. This includes only the size of the top-level buffer. For example, a
|
||||
// tuple is stored as an array of pointers to other buffers. In this case,
|
||||
// this method only returns the size of the pointer array.
|
||||
// Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
|
||||
// !ShapeUtil::IsOpaque(shape)
|
||||
static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
|
||||
|
||||
// Returns the number of bytes used to store the primitive_type.
|
||||
//
|
||||
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
|
||||
// Precondition: ShapeUtil::IsArray(shape)
|
||||
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
|
||||
|
||||
// Returns the number of bytes required to store the tuple member pointers for
|
||||
@ -245,7 +243,7 @@ class ShapeUtil {
|
||||
}
|
||||
|
||||
// Returns the higher-precision element type if a and b are both floating
|
||||
// point types; otherwise, checks that they have the same element type
|
||||
// point types; otherwise, checks that that they have the same element type
|
||||
// and returns it.
|
||||
static PrimitiveType HigherPrecisionElementType(const Shape& a,
|
||||
const Shape& b) {
|
||||
@ -293,10 +291,10 @@ class ShapeUtil {
|
||||
// Scalar-specific
|
||||
|
||||
static bool IsScalar(const Shape& shape) {
|
||||
return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
|
||||
return IsArray(shape) && Rank(shape) == 0;
|
||||
}
|
||||
static bool IsEffectiveScalar(const Shape& shape) {
|
||||
return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
|
||||
return IsArray(shape) && TrueRank(shape) == 0;
|
||||
}
|
||||
static bool IsScalarF32(const Shape& shape);
|
||||
|
||||
@ -325,6 +323,10 @@ class ShapeUtil {
|
||||
// into a custom operation.
|
||||
static Shape MakeOpaqueShape();
|
||||
|
||||
// Creates a token shape. Values of this shape are used for ordering
|
||||
// side-effecting operations.
|
||||
static Shape MakeTokenShape();
|
||||
|
||||
// Appends a shape to the given tuple.
|
||||
static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
|
||||
|
||||
@ -424,11 +426,15 @@ class ShapeUtil {
|
||||
return shape.element_type() == OPAQUE;
|
||||
}
|
||||
|
||||
// Returns whether the shape is an token value used for ordering
|
||||
// side-effecting operations.
|
||||
static bool IsToken(const Shape& shape) {
|
||||
return shape.element_type() == TOKEN;
|
||||
}
|
||||
|
||||
// Returns whether the shape is an array. Note that scalars are considered
|
||||
// arrays.
|
||||
static bool IsArray(const Shape& shape) {
|
||||
return !IsTuple(shape) && !IsOpaque(shape);
|
||||
}
|
||||
static bool IsArray(const Shape& shape);
|
||||
|
||||
// Returns whether the shape is a tuple with at least one element which is
|
||||
// also a tuple.
|
||||
|
@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
|
||||
string shape_string = "(f32[1],(f32[2]), f32[3])";
|
||||
string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
|
||||
ShapeUtil::ParseShapeString(shape_string));
|
||||
Shape expected = ShapeUtil::MakeTupleShape({
|
||||
ShapeUtil::MakeShape(F32, {1}),
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
|
||||
ShapeUtil::MakeOpaqueShape(),
|
||||
ShapeUtil::MakeShape(F32, {3}),
|
||||
});
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseOpaqueType) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
|
||||
ShapeUtil::ParseShapeString("opaque[]"));
|
||||
Shape expected = ShapeUtil::MakeOpaqueShape();
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
<< "expected: " << ShapeUtil::HumanString(expected)
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseTokenType) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
|
||||
Shape expected = ShapeUtil::MakeTokenShape();
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
<< "expected: " << ShapeUtil::HumanString(expected)
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ParseInvalidShapeString) {
|
||||
string shape_strings[] = {
|
||||
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
|
||||
@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
|
||||
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
|
||||
|
||||
EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN));
|
||||
EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape()));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
|
||||
@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) {
|
||||
|
||||
TEST(ShapeUtilTest, HumanString) {
|
||||
Shape opaque = ShapeUtil::MakeOpaqueShape();
|
||||
Shape token = ShapeUtil::MakeTokenShape();
|
||||
Shape scalar = ShapeUtil::MakeShape(F32, {});
|
||||
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
|
||||
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
|
||||
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
|
||||
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
|
||||
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
|
||||
|
||||
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
|
||||
EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
|
||||
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
|
||||
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
|
||||
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
|
||||
ShapeUtil::HumanString(tuple));
|
||||
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
|
||||
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(nested_tuple));
|
||||
|
||||
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
|
||||
@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) {
|
||||
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
|
||||
ShapeUtil::HumanStringWithLayout(tuple));
|
||||
EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})",
|
||||
ShapeUtil::HumanStringWithLayout(nested_tuple));
|
||||
EXPECT_EQ(
|
||||
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
|
||||
"token[])",
|
||||
ShapeUtil::HumanStringWithLayout(nested_tuple));
|
||||
|
||||
ProgramShape prog = ShapeUtil::MakeProgramShape(
|
||||
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
|
||||
@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) {
|
||||
"(unknown): u32[1,2], "
|
||||
"(unknown): s32[3,4], "
|
||||
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
|
||||
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
|
||||
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
|
||||
"-> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(prog));
|
||||
|
||||
prog.add_parameter_names("arg0");
|
||||
@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) {
|
||||
"matrix: u32[1,2], "
|
||||
"matrix2: s32[3,4], "
|
||||
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
|
||||
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
|
||||
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
|
||||
"token[])) "
|
||||
"-> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(prog));
|
||||
}
|
||||
|
||||
|
@ -66,11 +66,16 @@ enum PrimitiveType {
|
||||
// in the dimensions field.
|
||||
TUPLE = 13;
|
||||
|
||||
// An opaque type used for passing context specific data to a custom
|
||||
// operation.
|
||||
// An opaque type used for passing context-specific data to a custom
|
||||
// operation. Shapes of this primitive type will have empty dimensions and
|
||||
// tuple_shapes fields.
|
||||
OPAQUE = 14;
|
||||
|
||||
// Next = 17
|
||||
// A token type threaded between side-effecting operations. Shapes of this
|
||||
// primitive type will have empty dimensions and tuple_shapes fields.
|
||||
TOKEN = 17;
|
||||
|
||||
// Next = 18
|
||||
}
|
||||
|
||||
// Describes the value held inside padding elements.
|
||||
|
Loading…
Reference in New Issue
Block a user