Add support for ssize_t in TensorShape
PiperOrigin-RevId: 305053288 Change-Id: Ie63614b506444e186b8ad0e2ab0e2a655670f930
This commit is contained in:
parent
d422ab9a2d
commit
6db5faa3e2
@ -281,6 +281,7 @@ cc_library(
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -53,23 +53,23 @@ class AbstractContextInterface {
|
||||
|
||||
// Tensor creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
absl::Span<const ssize_t> dim_sizes) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
@ -69,15 +70,9 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
namespace {
|
||||
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims, size_t len) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
tensorflow::TensorShape(absl::MakeSpan(dims, num_dims)), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
|
@ -162,47 +162,47 @@ AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_INT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_UINT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_INT32, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_FLOAT, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_HALF, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_STRING, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
absl::Span<const ssize_t> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_BOOL, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
|
@ -161,23 +161,23 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
AbstractTensorInterface* CreateBoolScalar(bool value) override;
|
||||
|
||||
AbstractTensorInterface* CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
absl::Span<const ssize_t> dim_sizes) override;
|
||||
|
||||
AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) override;
|
||||
|
@ -160,6 +160,12 @@ TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
|
||||
InitDims(dim_sizes);
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<ssize_t> dim_sizes)
|
||||
: TensorShapeBase(gtl::ArraySlice<int64>(
|
||||
reinterpret_cast<const int64*>(dim_sizes.data()), dim_sizes.size())) {
|
||||
}
|
||||
|
||||
// Returns true iff partial is true and val is < 0.
|
||||
// REQUIRES: val < kMaxRep16
|
||||
// REQUIRES: partial || val >= 0
|
||||
|
@ -167,6 +167,7 @@ class TensorShapeBase : public TensorShapeRep {
|
||||
public:
|
||||
/// \brief Construct a `TensorShapeBase` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
|
||||
explicit TensorShapeBase(gtl::ArraySlice<ssize_t> dim_sizes);
|
||||
explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
|
||||
TensorShapeBase(std::initializer_list<int64> dim_sizes)
|
||||
: TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
@ -72,7 +72,7 @@ bool IsPyFloat(PyObject* obj) {
|
||||
|
||||
struct ConverterState {
|
||||
// The inferred tensor shape.
|
||||
gtl::InlinedVector<int64, 4> inferred_shape;
|
||||
gtl::InlinedVector<ssize_t, 4> inferred_shape;
|
||||
|
||||
// The inferred tensor data type.
|
||||
DataType inferred_dtype;
|
||||
@ -320,7 +320,7 @@ struct ConverterTraits<int64> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateInt64Tensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -360,7 +360,7 @@ struct ConverterTraits<uint64> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateUint64Tensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -397,7 +397,7 @@ struct ConverterTraits<int32> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateInt32Tensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -504,7 +504,7 @@ struct ConverterTraits<float> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateFloatTensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -520,7 +520,7 @@ struct ConverterTraits<double> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateDoubleTensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -540,7 +540,7 @@ struct ConverterTraits<Eigen::half> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateHalfTensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -561,7 +561,7 @@ struct ConverterTraits<tstring> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateStringTensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -628,7 +628,7 @@ struct ConverterTraits<complex128> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateComplex128Tensor(dim_sizes);
|
||||
}
|
||||
|
||||
@ -656,7 +656,7 @@ struct ConverterTraits<bool> {
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
TFE_Context* ctx, absl::Span<const ssize_t> dim_sizes) {
|
||||
return ctx->context->CreateBoolTensor(dim_sizes);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user