Add support for ssize_t in TensorShape

PiperOrigin-RevId: 305053288
Change-Id: Ie63614b506444e186b8ad0e2ab0e2a655670f930
This commit is contained in:
Gaurav Jain 2020-04-06 09:56:35 -07:00 committed by TensorFlower Gardener
parent d422ab9a2d
commit 6db5faa3e2
8 changed files with 47 additions and 44 deletions

View File

@ -281,6 +281,7 @@ cc_library(
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"@com_google_absl//absl/types:span",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -53,23 +53,23 @@ class AbstractContextInterface {
// Tensor creation functions
virtual AbstractTensorInterface* CreateInt64Tensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateUint64Tensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateInt32Tensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateFloatTensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateDoubleTensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateHalfTensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateStringTensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateBoolTensor(
absl::Span<const int64> dim_sizes) = 0;
absl::Span<const ssize_t> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
@ -69,15 +70,9 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
tensorflow::TensorShape(absl::MakeSpan(dims, num_dims)), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {

View File

@ -162,47 +162,47 @@ AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
}
AbstractTensorInterface* EagerContext::CreateInt64Tensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_INT64, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateUint64Tensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_UINT64, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateInt32Tensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_INT32, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateFloatTensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_FLOAT, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateDoubleTensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateHalfTensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_HALF, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateStringTensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_STRING, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateBoolTensor(
absl::Span<const int64> dim_sizes) {
absl::Span<const ssize_t> dim_sizes) {
return new TensorInterface(Tensor(DT_BOOL, TensorShape(dim_sizes)));
}

View File

@ -161,23 +161,23 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
AbstractTensorInterface* CreateBoolScalar(bool value) override;
AbstractTensorInterface* CreateInt64Tensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateUint64Tensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateInt32Tensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateFloatTensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateDoubleTensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateHalfTensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateStringTensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorInterface* CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
absl::Span<const ssize_t> dim_sizes) override;
AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) override;

View File

@ -160,6 +160,12 @@ TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
InitDims(dim_sizes);
}
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<ssize_t> dim_sizes)
: TensorShapeBase(gtl::ArraySlice<int64>(
reinterpret_cast<const int64*>(dim_sizes.data()), dim_sizes.size())) {
}
// Returns true iff partial is true and val is < 0.
// REQUIRES: val < kMaxRep16
// REQUIRES: partial || val >= 0

View File

@ -167,6 +167,7 @@ class TensorShapeBase : public TensorShapeRep {
public:
/// \brief Construct a `TensorShapeBase` from the provided sizes.
/// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
explicit TensorShapeBase(gtl::ArraySlice<ssize_t> dim_sizes);
explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
TensorShapeBase(std::initializer_list<int64> dim_sizes)
: TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

View File

@ -72,7 +72,7 @@ bool IsPyFloat(PyObject* obj) {
struct ConverterState {
// The inferred tensor shape.
gtl::InlinedVector<int64, 4> inferred_shape;
gtl::InlinedVector<ssize_t, 4> inferred_shape;
// The inferred tensor data type.
DataType inferred_dtype;
@ -320,7 +320,7 @@ struct ConverterTraits<int64> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateInt64Tensor(dim_sizes);
}
@ -360,7 +360,7 @@ struct ConverterTraits<uint64> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateUint64Tensor(dim_sizes);
}
@ -397,7 +397,7 @@ struct ConverterTraits<int32> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateInt32Tensor(dim_sizes);
}
@ -504,7 +504,7 @@ struct ConverterTraits<float> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateFloatTensor(dim_sizes);
}
@ -520,7 +520,7 @@ struct ConverterTraits<double> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateDoubleTensor(dim_sizes);
}
@ -540,7 +540,7 @@ struct ConverterTraits<Eigen::half> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateHalfTensor(dim_sizes);
}
@ -561,7 +561,7 @@ struct ConverterTraits<tstring> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateStringTensor(dim_sizes);
}
@ -628,7 +628,7 @@ struct ConverterTraits<complex128> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateComplex128Tensor(dim_sizes);
}
@ -656,7 +656,7 @@ struct ConverterTraits<bool> {
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
return ctx->context->CreateBoolTensor(dim_sizes);
}