This avoids creating and destroying a temporary `Shape` in `add_tuple_shapes()`, then move-assigning and destroying each tuple-shape in the constructor. PiperOrigin-RevId: 298703231 Change-Id: I793742c163cf24780b0764aa6c3846715517f99d
228 lines
7.3 KiB
C++
228 lines
7.3 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/shape.h"
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
|
|
namespace xla {
|
|
|
|
// Constructs a Shape from its proto representation.
//
// Copies the element type, the dimensions, the per-dimension dynamic flags,
// the tuple element shapes (recursively) and, when present, the layout from
// `shape_proto`. A proto whose is_dynamic_dimension count disagrees with its
// dimension count is logged and tolerated (see below) rather than rejected.
Shape::Shape(const ShapeProto& shape_proto) {
  set_element_type(shape_proto.element_type());
  // add_dimensions() must also append an entry to dynamic_dimensions_ (the
  // loop further down indexes into it by dimension), so size both containers
  // up front instead of only dimensions_.
  dimensions_.reserve(shape_proto.dimensions_size());
  dynamic_dimensions_.reserve(shape_proto.dimensions_size());
  for (const int64 dimension : shape_proto.dimensions()) {
    add_dimensions(dimension);
  }
  // A malformed proto may have different is_dynamic_dimension_size and
  // dimensions_size. A constructor has no good way of reporting failure, so
  // conservatively copy only the overlapping prefix of is_dynamic_dimension.
  // TODO(b/120111794): Make this a hard error when we have a factory method
  // instead of a constructor.
  if (shape_proto.dimensions_size() !=
      shape_proto.is_dynamic_dimension_size()) {
    if (shape_proto.is_dynamic_dimension_size() != 0) {
      LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension "
                    "fields does not match number of dimension fields";
    } else {
      LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty";
    }
  }
  const int64 num_dynamic_dimension_fields = std::min(
      shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size());
  for (int i = 0; i < num_dynamic_dimension_fields; i++) {
    dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i);
  }
  // Construct tuple elements in place rather than via a temporary Shape.
  tuple_shapes_.reserve(shape_proto.tuple_shapes_size());
  for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) {
    tuple_shapes_.emplace_back(element_shape);
  }
  if (shape_proto.has_layout()) {
    *mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
  }
}
|
|
|
|
// Serializes this shape into a ShapeProto.
//
// Writes the element type, the dimensions, the per-dimension dynamic flags,
// the tuple element shapes (recursively) and, when this shape has one, the
// layout.
ShapeProto Shape::ToProto() const {
  ShapeProto proto;
  proto.set_element_type(element_type_);
  proto.mutable_dimensions()->Reserve(dimensions_size());
  for (const int64 dimension : dimensions()) {
    proto.add_dimensions(dimension);
  }
  // Reserve the dynamic-dimension flags as well, like the other repeated
  // fields here, so appending never has to regrow the field.
  proto.mutable_is_dynamic_dimension()->Reserve(
      static_cast<int>(dynamic_dimensions_.size()));
  for (const bool dynamic : dynamic_dimensions_) {
    proto.add_is_dynamic_dimension(dynamic);
  }
  proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
  for (const Shape& shape : tuple_shapes()) {
    *proto.add_tuple_shapes() = shape.ToProto();
  }
  if (has_layout()) {
    *proto.mutable_layout() = layout().ToProto();
  }
  return proto;
}
|
|
|
|
// Returns the human-readable rendering of this shape, produced by ShapeUtil.
// When `print_layout` is true, the layout is included in the output.
string Shape::ToString(bool print_layout) const {
  return print_layout ? ShapeUtil::HumanStringWithLayout(*this)
                      : ShapeUtil::HumanString(*this);
}
|
|
|
|
bool Shape::is_static() const {
|
|
if (IsTuple()) {
|
|
for (const Shape& subshape : tuple_shapes_) {
|
|
if (!subshape.is_static()) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
|
|
}
|
|
|
|
// Removes dimension `dim_to_delete` from this array shape.
//
// Both the dimension size and its dynamic flag are dropped. If the shape has
// a layout, the layout's format is set to DENSE. Every occurrence of the
// deleted dimension is then removed from minor_to_major, and each larger
// dimension number there is decremented so it still names the same
// dimension. CHECK-fails unless this is an array and `dim_to_delete` is in
// [0, rank).
void Shape::DeleteDimension(int64 dim_to_delete) {
  CHECK(IsArray());
  CHECK_GE(dim_to_delete, 0);
  CHECK_LT(dim_to_delete, dimensions_.size());
  dimensions_.erase(dimensions_.begin() + dim_to_delete);
  dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete);

  if (!LayoutUtil::HasLayout(*this)) {
    return;
  }
  layout_.set_format(DENSE);
  auto* minor_to_major = layout_.mutable_minor_to_major();
  int64 pos = 0;
  while (pos < layout_.minor_to_major().size()) {
    const int64 dim = layout_.minor_to_major(pos);
    if (dim == dim_to_delete) {
      // Erasing slides the next entry into `pos`, so do not advance.
      minor_to_major->erase(minor_to_major->begin() + pos);
    } else {
      if (dim > dim_to_delete) {
        (*minor_to_major)[pos] = dim - 1;
      }
      ++pos;
    }
  }
}
|
|
|
|
// Compares `lhs` and `rhs` under the ignore_* options configured on this
// Equal object.
//
// - Tuples: equal iff `rhs` is also a tuple and the elements are pairwise
//   equal under this same comparator.
// - Non-tuple, non-array shapes (e.g. opaque, token): equal iff the element
//   types match.
// - Arrays: compared on element type (optionally ignoring floating-point
//   precision), dimensions, layout (optionally ignoring tiles, element size
//   and memory space), and dynamic dimensions. Each of these checks can be
//   turned off by its ignore_* flag.
bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
  if (lhs.IsTuple()) {
    // Capture `this` explicitly: implicitly capturing `this` through `[=]`
    // is deprecated as of C++20.
    return rhs.IsTuple() &&
           absl::c_equal(
               lhs.tuple_shapes(), rhs.tuple_shapes(),
               [this](const Shape& l, const Shape& r) { return (*this)(l, r); });
  } else if (!lhs.IsArray()) {
    // Non-tuple, non-array types such as opaque and token types are trivially
    // the same.
    return lhs.element_type() == rhs.element_type();
  }

  if (!rhs.IsArray()) {
    return false;
  }

  if (!ignore_element_type_) {
    // Only the comparison matching ignore_fp_precision_ is evaluated.
    const bool same_element_type =
        ignore_fp_precision_
            ? ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)
            : ShapeUtil::SameElementType(lhs, rhs);
    if (!same_element_type) {
      VLOG(3) << "CompareShapes: lhs element type != rhs element type";
      return false;
    }
  }

  if (!ShapeUtil::SameDimensions(lhs, rhs)) {
    VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
    return false;
  }

  if (!ignore_layout_) {
    if (lhs.layout().format() != rhs.layout().format()) {
      VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
      return false;
    }
    if (LayoutUtil::IsDenseArray(lhs)) {
      Layout::Equal equal;
      if (ignore_tiles_in_layout_) {
        equal.IgnoreTiles();
      }
      if (ignore_element_size_in_layout_) {
        equal.IgnoreElementSize();
      }
      if (ignore_memory_space_in_layout_) {
        equal.IgnoreMemorySpace();
      }
      if (!equal(lhs.layout(), rhs.layout())) {
        VLOG(3) << "CompareShapes: lhs layout != rhs layout";
        return false;
      }
    }
  }

  if (!ignore_dynamic_dimension_) {
    // SameDimensions() passed above, so both shapes have the same rank.
    for (int i = 0; i < lhs.rank(); ++i) {
      if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
        VLOG(3)
            << "CompareShapes: lhs and rhs have different dynamic dimensions.";
        return false;
      }
    }
  }
  return true;
}
|
|
|
|
// Writes the human-readable form of `shape`, layout included, to `out` and
// returns `out` so insertions can be chained.
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
  return out << shape.ToString(/*print_layout=*/true);
}
|
|
|
|
// Constructs a ProgramShape from its proto representation: the parameter
// shapes (in proto order), the result shape, and the parameter names
// (in proto order).
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
  const int num_parameters = program_shape_proto.parameters_size();
  for (int index = 0; index < num_parameters; ++index) {
    *add_parameters() = Shape(program_shape_proto.parameters(index));
  }

  *mutable_result() = Shape(program_shape_proto.result());

  const int num_names = program_shape_proto.parameter_names_size();
  for (int index = 0; index < num_names; ++index) {
    add_parameter_names(program_shape_proto.parameter_names(index));
  }
}
|
|
|
|
// Serializes this program shape into a ProgramShapeProto: the parameter
// shapes, the result shape and the parameter names, preserving order.
ProgramShapeProto ProgramShape::ToProto() const {
  ProgramShapeProto proto;
  // Size the repeated fields up front, as Shape::ToProto does, so appending
  // never has to regrow them.
  proto.mutable_parameters()->Reserve(parameters_size());
  for (const Shape& shape : parameters()) {
    *proto.add_parameters() = shape.ToProto();
  }
  *proto.mutable_result() = result().ToProto();
  proto.mutable_parameter_names()->Reserve(parameter_names_size());
  for (const string& name : parameter_names()) {
    proto.add_parameter_names(name);
  }
  return proto;
}
|
|
|
|
// Renders the program shape as "(name0: shape0, name1: shape1, ...) -> result".
// Parameters with no corresponding name are labelled "(unknown)".
string ProgramShape::ToString() const {
  std::vector<string> rendered_params;
  rendered_params.reserve(parameters_size());
  for (int idx = 0; idx < parameters_size(); ++idx) {
    absl::string_view label = "(unknown)";
    if (idx < parameter_names_size()) {
      label = parameter_names(idx);
    }
    rendered_params.push_back(
        absl::StrCat(label, ": ", ShapeUtil::HumanString(parameters(idx))));
  }
  const string joined_params = absl::StrJoin(rendered_params, ", ");
  return absl::StrCat("(", joined_params, ") -> ",
                      ShapeUtil::HumanString(result()));
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
|
|
out << program_shape.ToString() << "\n";
|
|
return out;
|
|
}
|
|
|
|
} // namespace xla
|