Add FooWithStatus
methods for all Foo
methods with CHECK
fails.
This is part of eliminating CHECK fails from `tensor_shape.cc`. The goal is to provide options for the callers to either test the `CHECK` conditions before calling the `Foo` methods or to call the `FooWithStatus` methods and receive a `Status` to act on instead of just crashing. PiperOrigin-RevId: 350218073 Change-Id: I4238ba3a2eff16a224e61bafb58b6553d706bb4b
This commit is contained in:
parent
671c78343c
commit
d5a1371029
@ -681,7 +681,9 @@ cc_library(
|
||||
"//tensorflow/core/lib/gtl:inlined_vector",
|
||||
"//tensorflow/core/lib/strings:str_util",
|
||||
"//tensorflow/core/lib/strings:strcat",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/util:overflow",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
|
@ -65,6 +65,19 @@ TEST(PartialTensorShapeTest, Concatenate) {
|
||||
EXPECT_EQ(-1, s4.num_elements());
|
||||
}
|
||||
|
||||
TEST(PartialTensorShapeTest, ConcatenateWithStatus) {
|
||||
PartialTensorShape s({10, 5, 20});
|
||||
Status status = s.ConcatenateWithStatus(400, &s);
|
||||
EXPECT_TRUE(status.ok());
|
||||
EXPECT_EQ(400000, s.num_elements());
|
||||
ASSERT_EQ(4, s.dims());
|
||||
|
||||
status = s.ConcatenateWithStatus(-10, &s);
|
||||
EXPECT_TRUE(status.ok());
|
||||
EXPECT_EQ(-1, s.num_elements());
|
||||
ASSERT_EQ(5, s.dims());
|
||||
}
|
||||
|
||||
TEST(PartialTensorShapeTest, InvalidShapeProto) {
|
||||
TensorShapeProto proto;
|
||||
EXPECT_TRUE(PartialTensorShape::IsValid(proto));
|
||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/overflow.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -153,11 +155,44 @@ TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
|
||||
}
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
|
||||
const TensorShapeProto& proto, TensorShapeBase* out) {
|
||||
out->set_tag(REP16);
|
||||
out->set_data_type(DT_INVALID);
|
||||
// NOTE(irving): Unfortunately, TensorShape allows parsing protos with
|
||||
// unknown_shape() set, and it seems hard to remove this without backwards
|
||||
// compatibility issues.
|
||||
if (kIsPartial && proto.unknown_rank()) {
|
||||
out->set_ndims_byte(kUnknownRank);
|
||||
out->set_num_elements(-1);
|
||||
} else {
|
||||
out->set_ndims_byte(0);
|
||||
out->set_num_elements(1);
|
||||
Status s = Status::OK();
|
||||
for (const auto& d : proto.dim()) {
|
||||
s = out->AddDimWithStatus(d.size());
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
|
||||
set_tag(REP16);
|
||||
set_data_type(DT_INVALID);
|
||||
InitDims(dim_sizes);
|
||||
TF_CHECK_OK(InitDims(dim_sizes));
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
|
||||
gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
|
||||
out->set_tag(REP16);
|
||||
out->set_data_type(DT_INVALID);
|
||||
return out->InitDims(dim_sizes);
|
||||
}
|
||||
|
||||
// Returns true iff partial is true and val is < 0.
|
||||
@ -175,7 +210,7 @@ static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
DCHECK_EQ(tag(), REP16);
|
||||
|
||||
// Allow sizes that are under kint64max^0.25 so that 4-way multiplication
|
||||
@ -191,11 +226,12 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(mihaimaruseac): Remove this CHECK as the refactoring continues
|
||||
// Temporaryly moving the CHECK from Set16 here
|
||||
if (!kIsPartial && !large_size) {
|
||||
for (auto s : dim_sizes) {
|
||||
CHECK_GE(s, 0);
|
||||
if (TF_PREDICT_FALSE(s < 0)) {
|
||||
return errors::Internal(
|
||||
"Expected shape dimensions to be non-negative, got ", s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -208,7 +244,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
const int64 size = dim_sizes[0];
|
||||
const bool neg = Set16(kIsPartial, dst, 0, size);
|
||||
set_num_elements(neg ? -1 : size);
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
case 2: {
|
||||
set_ndims_byte(2);
|
||||
@ -217,7 +253,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
bool neg = Set16(kIsPartial, dst, 0, size0);
|
||||
neg |= Set16(kIsPartial, dst, 1, size1);
|
||||
set_num_elements(neg ? -1 : (size0 * size1));
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
case 3: {
|
||||
set_ndims_byte(3);
|
||||
@ -228,7 +264,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
neg |= Set16(kIsPartial, dst, 1, size1);
|
||||
neg |= Set16(kIsPartial, dst, 2, size2);
|
||||
set_num_elements(neg ? -1 : (size0 * size1 * size2));
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
case 4: {
|
||||
set_ndims_byte(4);
|
||||
@ -241,16 +277,22 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
|
||||
neg |= Set16(kIsPartial, dst, 2, size2);
|
||||
neg |= Set16(kIsPartial, dst, 3, size3);
|
||||
set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set_ndims_byte(0);
|
||||
set_num_elements(1);
|
||||
Status status = Status::OK();
|
||||
for (int64 s : dim_sizes) {
|
||||
AddDim(internal::SubtleMustCopy(s));
|
||||
status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
@ -365,6 +407,38 @@ void TensorShapeBase<Shape>::AddDim(int64 size) {
|
||||
UnsafeAddDim(size, new_num_elements);
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::AddDimWithStatus(int64 size) {
|
||||
if (!kIsPartial) {
|
||||
if (TF_PREDICT_FALSE(size < 0)) {
|
||||
return errors::Internal("Expected a non-negative size, got ", size);
|
||||
}
|
||||
}
|
||||
|
||||
if (unknown_rank()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
|
||||
return errors::Internal("Too many dimensions in tensor");
|
||||
}
|
||||
|
||||
int64 new_num_elements;
|
||||
if (kIsPartial && (num_elements() < 0 || size < 0)) {
|
||||
new_num_elements = -1;
|
||||
} else {
|
||||
new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
|
||||
if (TF_PREDICT_FALSE(new_num_elements < 0)) {
|
||||
return errors::Internal("Encountered overflow when multiplying ",
|
||||
num_elements(), " with ", size,
|
||||
", result: ", new_num_elements);
|
||||
}
|
||||
}
|
||||
|
||||
UnsafeAddDim(size, new_num_elements);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
|
||||
const int nd = ndims_byte();
|
||||
@ -415,6 +489,19 @@ void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
|
||||
for (auto d : shape) AddDim(d.size);
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::AppendShapeWithStatus(
|
||||
const TensorShapeBase& shape) {
|
||||
Status s = Status::OK();
|
||||
for (auto d : shape) {
|
||||
s.Update(AddDimWithStatus(d.size));
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
|
||||
CHECK_GE(d, 0);
|
||||
@ -430,6 +517,42 @@ void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64 size) {
|
||||
if (!kIsPartial) {
|
||||
if (TF_PREDICT_FALSE(size < 0)) {
|
||||
return errors::Internal("Expected a non-negative size, got ", size);
|
||||
}
|
||||
}
|
||||
|
||||
if (TF_PREDICT_FALSE(d < 0)) {
|
||||
return errors::Internal("The insertion index must be non-negative, got ",
|
||||
d);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(d > dims())) {
|
||||
return errors::Internal("The insertion index must be at most ", dims(),
|
||||
" got ", d);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
|
||||
return errors::Internal("Shape has ", dims(),
|
||||
" dimensions which is the maximum allowed");
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 8> vals;
|
||||
AppendTo(*this, &vals);
|
||||
vals.insert(vals.begin() + d, size);
|
||||
ClearAllButDataType();
|
||||
|
||||
Status s = Status::OK();
|
||||
for (auto dval : vals) {
|
||||
s.Update(AddDimWithStatus(dval));
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
|
||||
gtl::InlinedVector<int64, 4> result;
|
||||
@ -465,6 +588,45 @@ void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
|
||||
TF_CHECK_OK(RecomputeNumElements());
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64 size) {
|
||||
if (TF_PREDICT_FALSE(d < 0)) {
|
||||
return errors::Internal("Index must be non-negative, got ", d);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(d >= dims())) {
|
||||
return errors::Internal("Index must be less than ", dims(), ", got ", d);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(size < 0)) {
|
||||
return errors::Internal("Expected a non-negative size, got ", size);
|
||||
}
|
||||
|
||||
if (tag() == REP16 && size < kMaxRep16) {
|
||||
as16()->dims_[d] =
|
||||
kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
|
||||
} else if (tag() == REP32 && size < kMaxRep32) {
|
||||
as32()->dims_[d] =
|
||||
kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
|
||||
} else if (tag() == REP_OUT_OF_LINE) {
|
||||
(*as64()->dims_)[d] = size;
|
||||
} else {
|
||||
// Must upgrade
|
||||
gtl::InlinedVector<int64, 8> vals;
|
||||
AppendTo(*this, &vals);
|
||||
vals[d] = size;
|
||||
ClearAllButDataType();
|
||||
|
||||
Status s = Status::OK();
|
||||
for (auto dval : vals) {
|
||||
s.Update(AddDimWithStatus(dval));
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return RecomputeNumElements();
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
|
||||
if (unknown_rank()) return;
|
||||
@ -485,6 +647,50 @@ void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
|
||||
TF_CHECK_OK(RecomputeNumElements());
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
|
||||
if (unknown_rank()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
begin = begin < 0 ? dims() + begin + 1 : begin;
|
||||
end = end < 0 ? dims() + end + 1 : end;
|
||||
|
||||
if (TF_PREDICT_FALSE(begin < 0)) {
|
||||
return errors::Internal("Start index must be non-negative, got ", begin);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(begin > dims())) {
|
||||
return errors::Internal("Start index must be less than ", dims(), ", got ",
|
||||
begin);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(end < 0)) {
|
||||
return errors::Internal("End index must be non-negative, got ", end);
|
||||
}
|
||||
if (TF_PREDICT_FALSE(end > dims())) {
|
||||
return errors::Internal("End index must be less than ", dims(), ", got ",
|
||||
end);
|
||||
}
|
||||
|
||||
if (begin >= end) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 8> vals;
|
||||
AppendTo(*this, &vals);
|
||||
vals.erase(vals.begin() + begin, vals.begin() + end);
|
||||
ClearAllButDataType();
|
||||
|
||||
Status s = Status::OK();
|
||||
for (auto dval : vals) {
|
||||
s.Update(AddDimWithStatus(dval));
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
return RecomputeNumElements();
|
||||
}
|
||||
|
||||
bool TensorShape::IsSameSize(const TensorShape& b) const {
|
||||
if (b.dims() != dims()) return false;
|
||||
for (int d = 0; d < dims(); d++) {
|
||||
@ -646,6 +852,12 @@ PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
|
||||
return out;
|
||||
}
|
||||
|
||||
Status PartialTensorShape::ConcatenateWithStatus(
|
||||
int64 size, PartialTensorShape* out) const {
|
||||
out = const_cast<PartialTensorShape*>(this);
|
||||
return out->AddDimWithStatus(size);
|
||||
}
|
||||
|
||||
PartialTensorShape PartialTensorShape::Concatenate(
|
||||
const PartialTensorShape& shape) const {
|
||||
if (unknown_rank() || shape.unknown_rank()) {
|
||||
@ -656,6 +868,21 @@ PartialTensorShape PartialTensorShape::Concatenate(
|
||||
return out;
|
||||
}
|
||||
|
||||
Status PartialTensorShape::ConcatenateWithStatus(
|
||||
const PartialTensorShape& shape, PartialTensorShape* out) const {
|
||||
if (unknown_rank() || shape.unknown_rank()) {
|
||||
*out = PartialTensorShape();
|
||||
return Status::OK();
|
||||
}
|
||||
out = const_cast<PartialTensorShape*>(this);
|
||||
for (auto dim : shape) {
|
||||
Status s = out->AddDimWithStatus(dim.size);
|
||||
if (!s.ok()) return s;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
|
||||
PartialTensorShape* result) const {
|
||||
if (unknown_rank()) {
|
||||
@ -672,8 +899,14 @@ Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
|
||||
"PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
|
||||
shape.dims());
|
||||
}
|
||||
CHECK(result != this);
|
||||
|
||||
if (result == this) {
|
||||
return errors::Internal(
|
||||
"PartialTensorShape::MergeWith: cannot merge shape with itself");
|
||||
}
|
||||
|
||||
result->Clear();
|
||||
Status s = Status::OK();
|
||||
for (int i = 0; i < dims_; ++i) {
|
||||
const int64 dim0 = dim_size(i);
|
||||
const int64 dim1 = shape.dim_size(i);
|
||||
@ -682,7 +915,10 @@ Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
|
||||
"PartialTensorShape: Incompatible shapes during merge: ",
|
||||
DebugString(), " vs. ", shape.DebugString());
|
||||
}
|
||||
result->AddDim(dim0 >= 0 ? dim0 : dim1);
|
||||
s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -26,7 +26,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -172,8 +174,22 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
/// Construct an empty TensorShape, or an unknown rank PartialTensorShape
|
||||
TensorShapeBase();
|
||||
|
||||
// TODO(mihaimaruseac): Mark this explicit in a subsequent change
|
||||
TensorShapeBase(const TensorShapeProto& proto);
|
||||
|
||||
// These factory methods should be used instead of the constructors that take
|
||||
// an array of sizes if calling code cannot validate that the sizes specify a
|
||||
// valid `TensorShape`.
|
||||
// The value in `*out` is valid iff the returned value is `Status::OK`.
|
||||
static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
|
||||
TensorShapeBase* out);
|
||||
static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
|
||||
TensorShapeBase* out) {
|
||||
return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
|
||||
}
|
||||
static Status BuildTensorShapeBase(const TensorShapeProto& proto,
|
||||
TensorShapeBase* out);
|
||||
|
||||
/// Returns `true` iff `proto` is a valid tensor shape.
|
||||
// For TensorShape, the proto shape must be fully defined.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
@ -189,19 +205,37 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
/// REQUIRES: `size >= 0`
|
||||
void AddDim(int64 size);
|
||||
|
||||
/// Same as `AddDim` but returns a `Status`.
|
||||
/// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes.
|
||||
Status AddDimWithStatus(int64 size);
|
||||
|
||||
/// Appends all the dimensions from `shape`.
|
||||
void AppendShape(const TensorShapeBase& shape);
|
||||
|
||||
/// Same as `RemoveDim` but returns a `Status`.
|
||||
/// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
|
||||
Status AppendShapeWithStatus(const TensorShapeBase& shape);
|
||||
|
||||
/// \brief Insert a dimension somewhere in the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d <= dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void InsertDim(int d, int64 size);
|
||||
|
||||
/// Same as `InsertDim` but returns a `Status`.
|
||||
/// Use if unsure if requirements in `InsertDim` are satistified, to prevent
|
||||
/// `CHECK`-fail crashes.
|
||||
Status InsertDimWithStatus(int d, int64 size);
|
||||
|
||||
/// \brief Modifies the size of the dimension `d` to be `size`
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void set_dim(int d, int64 size);
|
||||
|
||||
/// Same as `set_dim` but returns a `Status`.
|
||||
/// Use if unsure if requirements in `set_dim` are satistified, to prevent
|
||||
/// `CHECK`-fail crashes.
|
||||
Status SetDimWithStatus(int d, int64 size);
|
||||
|
||||
/// \brief Removes dimension `d` from the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
void RemoveDim(int d) {
|
||||
@ -209,6 +243,16 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
RemoveDimRange(d, d + 1);
|
||||
}
|
||||
|
||||
/// Same as `RemoveDim` but returns a `Status`.
|
||||
/// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes.
|
||||
Status RemoveDimWithStatus(int64 d) {
|
||||
if (TF_PREDICT_FALSE(d < 0)) {
|
||||
return errors::Internal(
|
||||
"Expected dimension index to be non-negative, got ", d);
|
||||
}
|
||||
return RemoveDimRangeWithStatus(d, d + 1);
|
||||
}
|
||||
|
||||
/// \brief Removes last `n` dimensions from the `TensorShape`.
|
||||
/// REQUIRES: `0 <= n <= dims()`
|
||||
void RemoveLastDims(int n) {
|
||||
@ -216,12 +260,28 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
RemoveDimRange(dims() - n, dims());
|
||||
}
|
||||
|
||||
/// Same as `RemoveLastDims` but returns a `Status`.
|
||||
/// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes.
|
||||
Status RemoveLastDimsWithStatus(int64 n) {
|
||||
if (TF_PREDICT_FALSE(n < dims())) {
|
||||
return errors::Internal("Expected dimension index to be at most ", dims(),
|
||||
" got ", n);
|
||||
}
|
||||
return RemoveDimRangeWithStatus(dims() - n, dims());
|
||||
}
|
||||
|
||||
/// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
|
||||
/// Negative values of `end` are interpreted as `dims() + end + 1` (as in
|
||||
/// Python). The same is true for negative values of `begin`. REQUIRES:
|
||||
/// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()`
|
||||
/// Python). The same is true for negative values of `begin`.
|
||||
/// REQUIRES: `-(dims()+1) <= begin <= dims()`
|
||||
/// REQUIRES: `-(dims()+1) <= end <= dims()`
|
||||
void RemoveDimRange(int begin, int end);
|
||||
|
||||
/// Same as `RemoveDimRange` but returns a `Status`.
|
||||
/// Use if unsure if requirements in `RemoveDimRange` are satistified, to
|
||||
/// prevent `CHECK`-fail crashes.
|
||||
Status RemoveDimRangeWithStatus(int begin, int end);
|
||||
|
||||
/// Return whether the rank is unknown
|
||||
bool unknown_rank() const {
|
||||
return kIsPartial && ndims_byte() == kUnknownRank;
|
||||
@ -264,7 +324,7 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
|
||||
private:
|
||||
Status RecomputeNumElements();
|
||||
void InitDims(gtl::ArraySlice<int64> dim_sizes);
|
||||
Status InitDims(gtl::ArraySlice<int64> dim_sizes);
|
||||
|
||||
// True for PartialTensorShape, false for TensorShape
|
||||
static constexpr bool kIsPartial =
|
||||
@ -314,6 +374,13 @@ class TensorShape : public TensorShapeBase<TensorShape> {
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
|
||||
|
||||
// Same as `AsEigenDSizes()` but returns a `Status` instead.
|
||||
// Use this method to surface error to user instead of crashing if `NDMIS` is
|
||||
// not equal to `dims()`.
|
||||
// Caller must take ownership of `out`.
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;
|
||||
|
||||
/// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
|
||||
/// which case we pad the rest of the sizes with 1.
|
||||
/// Notice: Using IndexType=int32 in combination with To32Bit() can
|
||||
@ -321,6 +388,14 @@ class TensorShape : public TensorShapeBase<TensorShape> {
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
|
||||
|
||||
// Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
|
||||
// Use this method to surface error to user instead of crashing if `NDMIS` is
|
||||
// not equal to `dims()`.
|
||||
// Caller must take ownership of `out`.
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Status AsEigenDSizesWithPaddingWithStatus(
|
||||
Eigen::DSizes<IndexType, NDIMS>* out) const;
|
||||
|
||||
private:
|
||||
// These CHECK fail to ease debugging.
|
||||
// REQUIRES: dims() == NDIMS
|
||||
@ -328,6 +403,18 @@ class TensorShape : public TensorShapeBase<TensorShape> {
|
||||
// REQUIRES: dims() >= NDIMS
|
||||
void CheckDimsAtLeast(int NDIMS) const;
|
||||
|
||||
// Fill output from `*this`.
|
||||
// Helper method for common code between `AsEigenDSize()` and
|
||||
// `AsEigenDSizeWithStatus()`.
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;
|
||||
|
||||
// Fill output from `*this`.
|
||||
// Helper method for common code between `AsEigenDSizesWithPadding()` and
|
||||
// `AsEigenDSizeWithPaddingWithStatus()`.
|
||||
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
|
||||
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;
|
||||
|
||||
// For access to TensorShapeBase(DataType).
|
||||
friend class Tensor;
|
||||
};
|
||||
@ -426,10 +513,21 @@ class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
|
||||
/// REQUIRES: `size >= -1`, where -1 means unknown
|
||||
PartialTensorShape Concatenate(int64 size) const;
|
||||
|
||||
/// Similar to `Concatenate` but returning `Status`.
|
||||
/// Use if calling code cannot validate all requirements and if `CHECK`-fails
|
||||
/// are to be avoided.
|
||||
Status ConcatenateWithStatus(int64 size, PartialTensorShape* out) const;
|
||||
|
||||
/// Appends all the dimensions from `shape`. Returns a new
|
||||
/// PartialTensorShape.
|
||||
PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
|
||||
|
||||
/// Similar to `Concatenate` but returning `Status`.
|
||||
/// Use if calling code cannot validate all requirements and if `CHECK`-fails
|
||||
/// are to be avoided.
|
||||
Status ConcatenateWithStatus(const PartialTensorShape& shape,
|
||||
PartialTensorShape* out) const;
|
||||
|
||||
/// Merges all the dimensions from `shape`. Returns
|
||||
/// `InvalidArgument` error if either `shape` has a different rank
|
||||
/// or if any of the dimensions are incompatible.
|
||||
@ -481,8 +579,7 @@ class PartialTensorShapeUtils {
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
|
||||
CheckDimsEqual(NDIMS);
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
|
||||
Eigen::DSizes<IndexType, NDIMS> dsizes;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
dsizes[d] = static_cast<IndexType>(dim_size(d));
|
||||
@ -491,8 +588,7 @@ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
|
||||
}
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
|
||||
CheckDimsAtLeast(NDIMS);
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
|
||||
static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
|
||||
Eigen::DSizes<IndexType, NDIMS> dsizes;
|
||||
for (int d = 0; d < dims(); d++) {
|
||||
@ -504,6 +600,40 @@ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
|
||||
return dsizes;
|
||||
}
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
|
||||
CheckDimsEqual(NDIMS);
|
||||
return AsEigenDSizesCopy<NDIMS, IndexType>();
|
||||
}
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Status TensorShape::AsEigenDSizesWithStatus(
|
||||
Eigen::DSizes<IndexType, NDIMS>* out) const {
|
||||
if (TF_PREDICT_FALSE(NDIMS != dims())) {
|
||||
return errors::Internal("Asking for tensor of ", NDIMS,
|
||||
" dimensions from a tensor of ", dims(),
|
||||
" dimensions");
|
||||
}
|
||||
*out = AsEigenDSizesCopy<NDIMS, IndexType>();
|
||||
}
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
|
||||
CheckDimsAtLeast(NDIMS);
|
||||
return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
|
||||
}
|
||||
|
||||
template <int NDIMS, typename IndexType>
|
||||
Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
|
||||
Eigen::DSizes<IndexType, NDIMS>* out) const {
|
||||
if (TF_PREDICT_FALSE(NDIMS < dims())) {
|
||||
return errors::Internal("Asking for tensor of at least ", NDIMS,
|
||||
" dimensions from a tensor of ", dims(),
|
||||
" dimensions");
|
||||
}
|
||||
*out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Inlining of some performance critical routines
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class TensorShapeTestHelper {
|
||||
@ -205,6 +206,28 @@ TEST(TensorShapeTest, ostream) {
|
||||
EXPECT_EQ(ss.str(), "[10,5,4]");
|
||||
}
|
||||
|
||||
TEST(TensorShapeTest, AddDimWithStatus) {
|
||||
TensorShape s({10, 5, 20});
|
||||
Status status = s.AddDimWithStatus(400);
|
||||
EXPECT_TRUE(status.ok());
|
||||
EXPECT_EQ(400000, s.num_elements());
|
||||
ASSERT_EQ(4, s.dims());
|
||||
|
||||
status = s.AddDimWithStatus(-1);
|
||||
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
|
||||
}
|
||||
|
||||
TEST(TensorShapeTest, Factory) {
|
||||
TensorShape s;
|
||||
Status status = TensorShape::BuildTensorShapeBase({10, 5, 20}, &s);
|
||||
EXPECT_TRUE(status.ok());
|
||||
EXPECT_EQ(1000, s.num_elements());
|
||||
ASSERT_EQ(3, s.dims());
|
||||
|
||||
status = TensorShape::BuildTensorShapeBase({-10, 5, 20}, &s);
|
||||
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// An old implementation of TensorShape using a different representation,
|
||||
// preserved here in the unittest to allow us to have a randomized unittest
|
||||
|
Loading…
x
Reference in New Issue
Block a user