Add FooWithStatus methods for all Foo methods with CHECK fails.

This is part of eliminating CHECK fails from `tensor_shape.cc`.

The goal is to provide options for the callers to either test the `CHECK` conditions before calling the `Foo` methods or to call the `FooWithStatus` methods and receive a `Status` to act on instead of just crashing.
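As an illustration of the two options for one such method, `AddDim` (a minimal sketch; the helper names are assumptions for illustration, not part of this change):

    #include "tensorflow/core/framework/tensor_shape.h"

    using tensorflow::Status;
    using tensorflow::TensorShape;

    // Option 1: test the CHECK condition before calling the CHECK-ing method.
    void AddDimChecked(TensorShape* shape, tensorflow::int64 size) {
      if (size >= 0) {
        shape->AddDim(size);  // size precondition verified above
      }
    }

    // Option 2: call the WithStatus variant and act on the returned Status.
    Status AddDimOrError(TensorShape* shape, tensorflow::int64 size) {
      return shape->AddDimWithStatus(size);  // Status instead of a crash
    }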

PiperOrigin-RevId: 350218073
Change-Id: I4238ba3a2eff16a224e61bafb58b6553d706bb4b
Author: Mihai Maruseac, 2021-01-05 14:17:34 -08:00; committed by TensorFlower Gardener
parent 671c78343c
commit d5a1371029
5 changed files with 423 additions and 19 deletions


@@ -681,7 +681,9 @@ cc_library(
"//tensorflow/core/lib/gtl:inlined_vector",
"//tensorflow/core/lib/strings:str_util",
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
"//tensorflow/core/util:overflow",
"//third_party/eigen3",
],


@@ -65,6 +65,19 @@ TEST(PartialTensorShapeTest, Concatenate) {
EXPECT_EQ(-1, s4.num_elements());
}
TEST(PartialTensorShapeTest, ConcatenateWithStatus) {
PartialTensorShape s({10, 5, 20});
Status status = s.ConcatenateWithStatus(400, &s);
EXPECT_TRUE(status.ok());
EXPECT_EQ(400000, s.num_elements());
ASSERT_EQ(4, s.dims());
status = s.ConcatenateWithStatus(-10, &s);
EXPECT_TRUE(status.ok());
EXPECT_EQ(-1, s.num_elements());
ASSERT_EQ(5, s.dims());
}
TEST(PartialTensorShapeTest, InvalidShapeProto) {
TensorShapeProto proto;
EXPECT_TRUE(PartialTensorShape::IsValid(proto));


@@ -20,7 +20,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/overflow.h"
namespace tensorflow {
@@ -153,11 +155,44 @@ TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
}
}
template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
const TensorShapeProto& proto, TensorShapeBase* out) {
out->set_tag(REP16);
out->set_data_type(DT_INVALID);
// NOTE(irving): Unfortunately, TensorShape allows parsing protos with
// unknown_shape() set, and it seems hard to remove this without backwards
// compatibility issues.
if (kIsPartial && proto.unknown_rank()) {
out->set_ndims_byte(kUnknownRank);
out->set_num_elements(-1);
} else {
out->set_ndims_byte(0);
out->set_num_elements(1);
Status s = Status::OK();
for (const auto& d : proto.dim()) {
s = out->AddDimWithStatus(d.size());
if (!s.ok()) {
return s;
}
}
}
return Status::OK();
}
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
set_tag(REP16);
set_data_type(DT_INVALID);
- InitDims(dim_sizes);
+ TF_CHECK_OK(InitDims(dim_sizes));
}
template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
out->set_tag(REP16);
out->set_data_type(DT_INVALID);
return out->InitDims(dim_sizes);
}
// Returns true iff partial is true and val is < 0.
@@ -175,7 +210,7 @@ static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
}
template <class Shape>
- void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
+ Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
DCHECK_EQ(tag(), REP16);
// Allow sizes that are under kint64max^0.25 so that 4-way multiplication
@@ -191,11 +226,12 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
}
}
// TODO(mihaimaruseac): Remove this CHECK as the refactoring continues
// Temporarily moving the CHECK from Set16 here
if (!kIsPartial && !large_size) {
for (auto s : dim_sizes) {
- CHECK_GE(s, 0);
+ if (TF_PREDICT_FALSE(s < 0)) {
+ return errors::Internal(
+ "Expected shape dimensions to be non-negative, got ", s);
+ }
}
}
@@ -208,7 +244,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
const int64 size = dim_sizes[0];
const bool neg = Set16(kIsPartial, dst, 0, size);
set_num_elements(neg ? -1 : size);
- return;
+ return Status::OK();
}
case 2: {
set_ndims_byte(2);
@@ -217,7 +253,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
bool neg = Set16(kIsPartial, dst, 0, size0);
neg |= Set16(kIsPartial, dst, 1, size1);
set_num_elements(neg ? -1 : (size0 * size1));
- return;
+ return Status::OK();
}
case 3: {
set_ndims_byte(3);
@@ -228,7 +264,7 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
neg |= Set16(kIsPartial, dst, 1, size1);
neg |= Set16(kIsPartial, dst, 2, size2);
set_num_elements(neg ? -1 : (size0 * size1 * size2));
- return;
+ return Status::OK();
}
case 4: {
set_ndims_byte(4);
@@ -241,16 +277,22 @@ void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
neg |= Set16(kIsPartial, dst, 2, size2);
neg |= Set16(kIsPartial, dst, 3, size3);
set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
- return;
+ return Status::OK();
}
}
}
set_ndims_byte(0);
set_num_elements(1);
Status status = Status::OK();
for (int64 s : dim_sizes) {
- AddDim(internal::SubtleMustCopy(s));
+ status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
if (!status.ok()) {
return status;
}
}
return status;
}
template <class Shape>
@@ -365,6 +407,38 @@ void TensorShapeBase<Shape>::AddDim(int64 size) {
UnsafeAddDim(size, new_num_elements);
}
template <class Shape>
Status TensorShapeBase<Shape>::AddDimWithStatus(int64 size) {
if (!kIsPartial) {
if (TF_PREDICT_FALSE(size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
}
}
if (unknown_rank()) {
return Status::OK();
}
if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
return errors::Internal("Too many dimensions in tensor");
}
int64 new_num_elements;
if (kIsPartial && (num_elements() < 0 || size < 0)) {
new_num_elements = -1;
} else {
new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
if (TF_PREDICT_FALSE(new_num_elements < 0)) {
return errors::Internal("Encountered overflow when multiplying ",
num_elements(), " with ", size,
", result: ", new_num_elements);
}
}
UnsafeAddDim(size, new_num_elements);
return Status::OK();
}
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
const int nd = ndims_byte();
@@ -415,6 +489,19 @@ void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
for (auto d : shape) AddDim(d.size);
}
template <class Shape>
Status TensorShapeBase<Shape>::AppendShapeWithStatus(
const TensorShapeBase& shape) {
Status s = Status::OK();
for (auto d : shape) {
s.Update(AddDimWithStatus(d.size));
if (!s.ok()) {
return s;
}
}
return s;
}
template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
CHECK_GE(d, 0);
@@ -430,6 +517,42 @@ void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
}
}
template <class Shape>
Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64 size) {
if (!kIsPartial) {
if (TF_PREDICT_FALSE(size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
}
}
if (TF_PREDICT_FALSE(d < 0)) {
return errors::Internal("The insertion index must be non-negative, got ",
d);
}
if (TF_PREDICT_FALSE(d > dims())) {
return errors::Internal("The insertion index must be at most ", dims(),
" got ", d);
}
if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
return errors::Internal("Shape has ", dims(),
" dimensions which is the maximum allowed");
}
gtl::InlinedVector<int64, 8> vals;
AppendTo(*this, &vals);
vals.insert(vals.begin() + d, size);
ClearAllButDataType();
Status s = Status::OK();
for (auto dval : vals) {
s.Update(AddDimWithStatus(dval));
if (!s.ok()) {
return s;
}
}
return s;
}
template <class Shape>
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
gtl::InlinedVector<int64, 4> result;
@@ -465,6 +588,45 @@ void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
TF_CHECK_OK(RecomputeNumElements());
}
template <class Shape>
Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64 size) {
if (TF_PREDICT_FALSE(d < 0)) {
return errors::Internal("Index must be non-negative, got ", d);
}
if (TF_PREDICT_FALSE(d >= dims())) {
return errors::Internal("Index must be less than ", dims(), ", got ", d);
}
if (TF_PREDICT_FALSE(size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
}
if (tag() == REP16 && size < kMaxRep16) {
as16()->dims_[d] =
kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
} else if (tag() == REP32 && size < kMaxRep32) {
as32()->dims_[d] =
kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
} else if (tag() == REP_OUT_OF_LINE) {
(*as64()->dims_)[d] = size;
} else {
// Must upgrade
gtl::InlinedVector<int64, 8> vals;
AppendTo(*this, &vals);
vals[d] = size;
ClearAllButDataType();
Status s = Status::OK();
for (auto dval : vals) {
s.Update(AddDimWithStatus(dval));
if (!s.ok()) {
return s;
}
}
}
return RecomputeNumElements();
}
template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
if (unknown_rank()) return;
@@ -485,6 +647,50 @@ void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
TF_CHECK_OK(RecomputeNumElements());
}
template <class Shape>
Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
if (unknown_rank()) {
return Status::OK();
}
begin = begin < 0 ? dims() + begin + 1 : begin;
end = end < 0 ? dims() + end + 1 : end;
if (TF_PREDICT_FALSE(begin < 0)) {
return errors::Internal("Start index must be non-negative, got ", begin);
}
if (TF_PREDICT_FALSE(begin > dims())) {
return errors::Internal("Start index must be less than ", dims(), ", got ",
begin);
}
if (TF_PREDICT_FALSE(end < 0)) {
return errors::Internal("End index must be non-negative, got ", end);
}
if (TF_PREDICT_FALSE(end > dims())) {
return errors::Internal("End index must be less than ", dims(), ", got ",
end);
}
if (begin >= end) {
return Status::OK();
}
gtl::InlinedVector<int64, 8> vals;
AppendTo(*this, &vals);
vals.erase(vals.begin() + begin, vals.begin() + end);
ClearAllButDataType();
Status s = Status::OK();
for (auto dval : vals) {
s.Update(AddDimWithStatus(dval));
if (!s.ok()) {
return s;
}
}
return RecomputeNumElements();
}
bool TensorShape::IsSameSize(const TensorShape& b) const {
if (b.dims() != dims()) return false;
for (int d = 0; d < dims(); d++) {
@@ -646,6 +852,12 @@ PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
return out;
}
Status PartialTensorShape::ConcatenateWithStatus(
int64 size, PartialTensorShape* out) const {
*out = *this;
return out->AddDimWithStatus(size);
}
PartialTensorShape PartialTensorShape::Concatenate(
const PartialTensorShape& shape) const {
if (unknown_rank() || shape.unknown_rank()) {
@@ -656,6 +868,21 @@ PartialTensorShape PartialTensorShape::Concatenate(
return out;
}
Status PartialTensorShape::ConcatenateWithStatus(
const PartialTensorShape& shape, PartialTensorShape* out) const {
if (unknown_rank() || shape.unknown_rank()) {
*out = PartialTensorShape();
return Status::OK();
}
*out = *this;
for (auto dim : shape) {
Status s = out->AddDimWithStatus(dim.size);
if (!s.ok()) return s;
}
return Status::OK();
}
Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
PartialTensorShape* result) const {
if (unknown_rank()) {
@@ -672,8 +899,14 @@ Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
"PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
shape.dims());
}
- CHECK(result != this);
+ if (result == this) {
+ return errors::Internal(
+ "PartialTensorShape::MergeWith: cannot merge shape with itself");
+ }
result->Clear();
Status s = Status::OK();
for (int i = 0; i < dims_; ++i) {
const int64 dim0 = dim_size(i);
const int64 dim1 = shape.dim_size(i);
@@ -682,7 +915,10 @@ Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
"PartialTensorShape: Incompatible shapes during merge: ",
DebugString(), " vs. ", shape.DebugString());
}
- result->AddDim(dim0 >= 0 ? dim0 : dim1);
+ s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
+ if (!s.ok()) {
+ return s;
+ }
}
return Status::OK();
}


@@ -26,7 +26,9 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -172,8 +174,22 @@ class TensorShapeBase : public TensorShapeRep {
/// Construct an empty TensorShape, or an unknown rank PartialTensorShape
TensorShapeBase();
// TODO(mihaimaruseac): Mark this explicit in a subsequent change
TensorShapeBase(const TensorShapeProto& proto);
// These factory methods should be used instead of the constructors that take
// an array of sizes if calling code cannot validate that the sizes specify a
// valid `TensorShape`.
// The value in `*out` is valid iff the returned value is `Status::OK()`.
static Status BuildTensorShapeBase(gtl::ArraySlice<int64> dim_sizes,
TensorShapeBase* out);
static Status BuildTensorShapeBase(std::initializer_list<int64> dim_sizes,
TensorShapeBase* out) {
return BuildTensorShapeBase(gtl::ArraySlice<int64>(dim_sizes), out);
}
static Status BuildTensorShapeBase(const TensorShapeProto& proto,
TensorShapeBase* out);
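// Illustrative usage of the factory (a sketch, not part of the original
// change; the sizes shown are arbitrary):
//   TensorShape shape;
//   Status s = TensorShape::BuildTensorShapeBase({5, 7, 11}, &shape);
//   if (!s.ok()) { /* invalid sizes surfaced as a Status, not a crash */ }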
/// Returns `true` iff `proto` is a valid tensor shape.
// For TensorShape, the proto shape must be fully defined.
static bool IsValid(const TensorShapeProto& proto);
@@ -189,19 +205,37 @@ class TensorShapeBase : public TensorShapeRep {
/// REQUIRES: `size >= 0`
void AddDim(int64 size);
/// Same as `AddDim` but returns a `Status`.
/// Use if unsure whether `size >= 0`, to prevent `CHECK`-crashes.
Status AddDimWithStatus(int64 size);
/// Appends all the dimensions from `shape`.
void AppendShape(const TensorShapeBase& shape);
/// Same as `AppendShape` but returns a `Status`.
/// Use if you cannot validate all invariants, to prevent `CHECK`-fails.
Status AppendShapeWithStatus(const TensorShapeBase& shape);
/// \brief Insert a dimension somewhere in the `TensorShape`.
/// REQUIRES: `0 <= d <= dims()`
/// REQUIRES: `size >= 0`
void InsertDim(int d, int64 size);
/// Same as `InsertDim` but returns a `Status`.
/// Use if unsure whether the requirements in `InsertDim` are satisfied, to prevent
/// `CHECK`-fail crashes.
Status InsertDimWithStatus(int d, int64 size);
/// \brief Modifies the size of the dimension `d` to be `size`
/// REQUIRES: `0 <= d < dims()`
/// REQUIRES: `size >= 0`
void set_dim(int d, int64 size);
/// Same as `set_dim` but returns a `Status`.
/// Use if unsure whether the requirements in `set_dim` are satisfied, to prevent
/// `CHECK`-fail crashes.
Status SetDimWithStatus(int d, int64 size);
/// \brief Removes dimension `d` from the `TensorShape`.
/// REQUIRES: `0 <= d < dims()`
void RemoveDim(int d) {
@@ -209,6 +243,16 @@ class TensorShapeBase : public TensorShapeRep {
RemoveDimRange(d, d + 1);
}
/// Same as `RemoveDim` but returns a `Status`.
/// Use if unsure whether `0 <= d < dims()`, to prevent `CHECK`-crashes.
Status RemoveDimWithStatus(int64 d) {
if (TF_PREDICT_FALSE(d < 0)) {
return errors::Internal(
"Expected dimension index to be non-negative, got ", d);
}
return RemoveDimRangeWithStatus(d, d + 1);
}
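// Illustrative usage (a sketch, not part of the original change):
//   Status s = shape.RemoveDimWithStatus(d);  // `d` possibly unvalidated
//   if (!s.ok()) return s;  // propagate instead of CHECK-failing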
/// \brief Removes last `n` dimensions from the `TensorShape`.
/// REQUIRES: `0 <= n <= dims()`
void RemoveLastDims(int n) {
@@ -216,12 +260,28 @@ class TensorShapeBase : public TensorShapeRep {
RemoveDimRange(dims() - n, dims());
}
/// Same as `RemoveLastDims` but returns a `Status`.
/// Use if unsure whether `0 <= n <= dims()`, to prevent `CHECK`-crashes.
Status RemoveLastDimsWithStatus(int64 n) {
if (TF_PREDICT_FALSE(n > dims())) {
return errors::Internal("Expected dimension index to be at most ", dims(),
" got ", n);
}
return RemoveDimRangeWithStatus(dims() - n, dims());
}
/// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
/// Negative values of `end` are interpreted as `dims() + end + 1` (as in
- /// Python). The same is true for negative values of `begin`. REQUIRES:
- /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()`
+ /// Python). The same is true for negative values of `begin`.
+ /// REQUIRES: `-(dims()+1) <= begin <= dims()`
+ /// REQUIRES: `-(dims()+1) <= end <= dims()`
void RemoveDimRange(int begin, int end);
/// Same as `RemoveDimRange` but returns a `Status`.
/// Use if unsure whether the requirements in `RemoveDimRange` are satisfied, to
/// prevent `CHECK`-fail crashes.
Status RemoveDimRangeWithStatus(int begin, int end);
/// Return whether the rank is unknown
bool unknown_rank() const {
return kIsPartial && ndims_byte() == kUnknownRank;
@@ -264,7 +324,7 @@
private:
Status RecomputeNumElements();
- void InitDims(gtl::ArraySlice<int64> dim_sizes);
+ Status InitDims(gtl::ArraySlice<int64> dim_sizes);
// True for PartialTensorShape, false for TensorShape
static constexpr bool kIsPartial =
@@ -314,6 +374,13 @@ class TensorShape : public TensorShapeBase<TensorShape> {
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
// Same as `AsEigenDSizes()` but returns a `Status` instead.
// Use this method to surface errors to the user instead of crashing if
// `NDIMS` does not equal `dims()`.
// The caller retains ownership of `out`.
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;
/// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
/// which case we pad the rest of the sizes with 1.
/// Notice: Using IndexType=int32 in combination with To32Bit() can
@@ -321,6 +388,14 @@ class TensorShape : public TensorShapeBase<TensorShape> {
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
// Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
// Use this method to surface errors to the user instead of crashing if
// `NDIMS` is smaller than `dims()`.
// The caller retains ownership of `out`.
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Status AsEigenDSizesWithPaddingWithStatus(
Eigen::DSizes<IndexType, NDIMS>* out) const;
private:
// These CHECK fail to ease debugging.
// REQUIRES: dims() == NDIMS
@@ -328,6 +403,18 @@ class TensorShape : public TensorShapeBase<TensorShape> {
// REQUIRES: dims() >= NDIMS
void CheckDimsAtLeast(int NDIMS) const;
// Fill output from `*this`.
// Helper method for common code between `AsEigenDSizes()` and
// `AsEigenDSizesWithStatus()`.
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;
// Fill output from `*this`.
// Helper method for common code between `AsEigenDSizesWithPadding()` and
// `AsEigenDSizesWithPaddingWithStatus()`.
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;
// For access to TensorShapeBase(DataType).
friend class Tensor;
};
@@ -426,10 +513,21 @@ class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
/// REQUIRES: `size >= -1`, where -1 means unknown
PartialTensorShape Concatenate(int64 size) const;
/// Similar to `Concatenate` but returning `Status`.
/// Use if calling code cannot validate all requirements and if `CHECK`-fails
/// are to be avoided.
Status ConcatenateWithStatus(int64 size, PartialTensorShape* out) const;
/// Appends all the dimensions from `shape`. Returns a new
/// PartialTensorShape.
PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
/// Similar to `Concatenate` but returning `Status`.
/// Use if calling code cannot validate all requirements and if `CHECK`-fails
/// are to be avoided.
Status ConcatenateWithStatus(const PartialTensorShape& shape,
PartialTensorShape* out) const;
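// Illustrative usage (a sketch, not part of the original change; `a` and
// `b` are arbitrary shapes):
//   PartialTensorShape combined;
//   TF_RETURN_IF_ERROR(a.ConcatenateWithStatus(b, &combined));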
/// Merges all the dimensions from `shape`. Returns
/// `InvalidArgument` error if either `shape` has a different rank
/// or if any of the dimensions are incompatible.
@@ -481,8 +579,7 @@ class PartialTensorShapeUtils {
// ----------------------------------------------------------------------------
template <int NDIMS, typename IndexType>
- Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
- CheckDimsEqual(NDIMS);
+ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
Eigen::DSizes<IndexType, NDIMS> dsizes;
for (int d = 0; d < NDIMS; d++) {
dsizes[d] = static_cast<IndexType>(dim_size(d));
@@ -491,8 +588,7 @@ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
}
template <int NDIMS, typename IndexType>
- Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
- CheckDimsAtLeast(NDIMS);
+ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
Eigen::DSizes<IndexType, NDIMS> dsizes;
for (int d = 0; d < dims(); d++) {
@@ -504,6 +600,40 @@ Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
return dsizes;
}
template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
CheckDimsEqual(NDIMS);
return AsEigenDSizesCopy<NDIMS, IndexType>();
}
template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithStatus(
Eigen::DSizes<IndexType, NDIMS>* out) const {
if (TF_PREDICT_FALSE(NDIMS != dims())) {
return errors::Internal("Asking for tensor of ", NDIMS,
" dimensions from a tensor of ", dims(),
" dimensions");
}
*out = AsEigenDSizesCopy<NDIMS, IndexType>();
return Status::OK();
}
template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
CheckDimsAtLeast(NDIMS);
return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
}
template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
Eigen::DSizes<IndexType, NDIMS>* out) const {
if (TF_PREDICT_FALSE(NDIMS < dims())) {
return errors::Internal("Asking for tensor of at least ", NDIMS,
" dimensions from a tensor of ", dims(),
" dimensions");
}
*out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
return Status::OK();
}
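// Illustrative usage (a sketch, not part of the original change):
//   Eigen::DSizes<Eigen::DenseIndex, 3> dsizes;
//   Status s = shape.AsEigenDSizesWithStatus<3>(&dsizes);
//   if (!s.ok()) { /* rank mismatch reported as a Status, not a crash */ }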
// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------


@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
class TensorShapeTestHelper {
@@ -205,6 +206,28 @@ TEST(TensorShapeTest, ostream) {
EXPECT_EQ(ss.str(), "[10,5,4]");
}
TEST(TensorShapeTest, AddDimWithStatus) {
TensorShape s({10, 5, 20});
Status status = s.AddDimWithStatus(400);
EXPECT_TRUE(status.ok());
EXPECT_EQ(400000, s.num_elements());
ASSERT_EQ(4, s.dims());
status = s.AddDimWithStatus(-1);
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
}
TEST(TensorShapeTest, Factory) {
TensorShape s;
Status status = TensorShape::BuildTensorShapeBase({10, 5, 20}, &s);
EXPECT_TRUE(status.ok());
EXPECT_EQ(1000, s.num_elements());
ASSERT_EQ(3, s.dims());
status = TensorShape::BuildTensorShapeBase({-10, 5, 20}, &s);
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
}
// -----------------------------------------------------------------------
// An old implementation of TensorShape using a different representation,
// preserved here in the unittest to allow us to have a randomized unittest