Eliminate unnecessary packing/unpacking of string tensors across the C-API and language bindings.

With the std::string -> tensorflow::tstring migration in TensorFlow core complete, we now have an ABI-stable string type with associated C accessors/modifiers. The packing/unpacking of string tensors across the C-API is now superfluous and can be removed. This is an ABI-breaking change; updates to the language bindings in tensorflow/ are included in this CL.

PiperOrigin-RevId: 318933001
Change-Id: I3c99cf70834ba0b6cefc4ba39a35d6c168e880db

commit 24f835217f (parent 2ab7873606)
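Note: for readers unfamiliar with the new layout, here is a minimal sketch of creating a scalar TF_STRING tensor under the new ABI. It uses only the ctstring entry points exercised in this CL (TF_TString_Init/Copy/GetDataPointer/GetSize/Dealloc); the helper name and the omitted error handling are illustrative, not part of this change.

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/tf_tstring.h"

    // A scalar TF_STRING tensor is now a buffer holding exactly one TF_TString.
    TF_Tensor* MakeScalarStringTensor(const char* s, size_t len) {
      TF_TString* tstr = new TF_TString[1];
      TF_TString_Init(&tstr[0]);
      TF_TString_Copy(&tstr[0], s, len);  // deep-copies s into the tstring
      auto deallocator = [](void* data, size_t, void*) {
        TF_TString* t = static_cast<TF_TString*>(data);
        TF_TString_Dealloc(&t[0]);  // releases any heap storage owned by the tstring
        delete[] t;
      };
      return TF_NewTensor(TF_STRING, /*dims=*/nullptr, /*num_dims=*/0, tstr,
                          sizeof(TF_TString), deallocator, nullptr);
    }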
RELEASE.md

@@ -6,6 +6,11 @@
 * <DOCUMENT BREAKING CHANGES HERE>
 * <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
+* The byte layout for string tensors across the C-API has been updated to match
+  TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
+* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
+  `TF_StringEncodedSize` are no longer relevant and have been removed; see
+  core/platform/ctstring.h for string access/modification in C.
 
 ## Known Caveats
 
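Note: the buffer-size arithmetic changes with the layout. Under the old encoding a TF_STRING tensor needed an 8-byte offset per element plus the varint-framed payload; under the new one the buffer is fixed-width. A sketch, with an illustrative helper name (the `sizeof(TF_TString) * num_elements` form appears verbatim in the JNI changes below):

    #include <cstddef>
    #include "tensorflow/c/tf_tstring.h"

    // Old layout (removed): 8 * num_elements + sum over elements of
    // TF_StringEncodedSize(len_i). New layout: one fixed-size handle each.
    size_t StringTensorBytes(size_t num_elements) {
      return num_elements * sizeof(TF_TString);
    }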
tensorflow/c/BUILD

@@ -29,6 +29,8 @@ filegroup(
         "tf_file_statistics.h",
         "tf_status.h",
         "tf_tensor.h",
+        "tf_tstring.h",
+        "//tensorflow/core/platform:ctstring",
     ],
     visibility = ["//tensorflow:__subpackages__"],
 )
@@ -48,6 +50,7 @@ filegroup(
             "*test*",
         ],
     ) + [
+        "//tensorflow/core/platform:ctstring",
         "//tensorflow/cc:srcs_no_runtime",
         "//tensorflow/core/distributed_runtime:server_lib.h",
     ],
@@ -78,6 +81,7 @@ tf_cuda_library(
         "c_api_internal.h",
         "tf_datatype.h",
         "tf_tensor.h",
+        "tf_tstring.h",
     ],
     visibility = [
         "//tensorflow:internal",
tensorflow/c/c_api.h

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/c/tf_tstring.h"
 
 // --------------------------------------------------------------------------
 // C API for TensorFlow.
tensorflow/c/c_api_test.cc

@@ -2286,14 +2286,15 @@ TEST_F(CApiAttributesTest, Tensor) {
 
 TEST_F(CApiAttributesTest, StringTensor) {
   // Create the string-Tensor "attribute" value.
-  char encoded[] = {
-      0, 0, 0, 0, 0, 0, 0, 0,  // array[uint64] offsets
-      1,                       // varint encoded string length
-      'A',
-  };
+  const char test_string[] =
+      "borkborkborkborkborkborkborkbork";  // >24bytes to force heap alloc
+  TF_TString tstr[1];
+  TF_TString_Init(&tstr[0]);
+  TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
 
   auto deallocator = [](void* data, size_t len, void* arg) {};
-  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
-                                      sizeof(encoded), deallocator, nullptr),
+  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
+                                      sizeof(tstr), deallocator, nullptr),
                          TF_DeleteTensor);
 
   // Create a TF_Operation with the attribute t_in
@@ -2312,9 +2313,17 @@ TEST_F(CApiAttributesTest, StringTensor) {
   EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
   EXPECT_EQ(0, TF_NumDims(t_out));
   ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
-  EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
-                      TF_TensorByteSize(t_out)));
+  TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
+  TF_TString* t_out_tstr = static_cast<TF_TString*>(TF_TensorData(t_out));
+  EXPECT_EQ(absl::string_view(test_string),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
+  EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
+                              TF_TString_GetSize(t_in_tstr)),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
   TF_DeleteTensor(t_out);
+  TF_TString_Dealloc(&tstr[0]);
 }
 
 TEST_F(CApiAttributesTest, TensorList) {
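Note: the ">24bytes to force heap alloc" comment above refers to the small-string optimization in TF_TString: per the ConstantTest change further down, a TF_TSTR_SMALL string stores up to 22 bytes inline, with the length in the upper six bits of the first byte and the low two bits as a type tag. A rough illustration under that assumption (function name illustrative):

    #include <cassert>
    #include <cstring>
    #include "tensorflow/c/tf_tstring.h"

    void SmallVsLarge() {
      TF_TString small_str, large_str;
      TF_TString_Init(&small_str);
      TF_TString_Init(&large_str);
      TF_TString_Copy(&small_str, "hello", 5);  // fits the inline capacity
      const char* big = "borkborkborkborkborkborkborkbork";  // 32 bytes
      TF_TString_Copy(&large_str, big, strlen(big));  // spills to the heap
      assert(TF_TString_GetSize(&small_str) == 5);
      assert(TF_TString_GetSize(&large_str) == 32);
      TF_TString_Dealloc(&small_str);  // safe for both representations
      TF_TString_Dealloc(&large_str);
    }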
tensorflow/c/tf_tensor.cc

@@ -228,57 +228,6 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
 }  // namespace tensorflow
 
 // --------------------------------------------------------------------------
-void StringEncode(const char* src, size_t src_len, char* dst) {
-  dst = tensorflow::core::EncodeVarint64(dst, src_len);
-  memcpy(dst, src, src_len);
-}
-
-size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
-                       size_t dst_len, TF_Status* status) {
-  const size_t sz = TF_StringEncodedSize(src_len);
-  if (sz < src_len) {
-    Set_TF_Status_from_Status(
-        status, InvalidArgument("src string is too large to encode"));
-    return 0;
-  }
-  if (dst_len < sz) {
-    Set_TF_Status_from_Status(
-        status,
-        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
-                        src_len, "-byte string"));
-    return 0;
-  }
-  StringEncode(src, src_len, dst);
-  return sz;
-}
-
-static Status TF_StringDecode_Impl(const char* src, size_t src_len,
-                                   const char** dst, size_t* dst_len) {
-  tensorflow::uint64 len64 = 0;
-  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
-  if (p == nullptr) {
-    return InvalidArgument("invalid string encoding or truncated src buffer");
-  }
-  if (len64 > std::numeric_limits<size_t>::max()) {
-    return InvalidArgument("encoded string is ", len64,
-                           "-bytes, which is too large for this architecture");
-  }
-  *dst = p;
-  *dst_len = static_cast<size_t>(len64);
-  return Status::OK();
-}
-
-size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
-                       size_t* dst_len, TF_Status* status) {
-  Set_TF_Status_from_Status(status,
-                            TF_StringDecode_Impl(src, src_len, dst, dst_len));
-  if (TF_GetCode(status) != TF_OK) return 0;
-  return static_cast<size_t>(*dst - src) + *dst_len;
-}
-
-size_t TF_StringEncodedSize(size_t len) {
-  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
-}
-
 static void DeleteArray(void* data, size_t size, void* arg) {
   DCHECK_EQ(data, arg);
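Note: for code migrating off the removed functions, the rough correspondence is: a per-element TF_StringEncodedSize(len) becomes a flat sizeof(TF_TString); TF_StringEncode becomes TF_TString_Init plus TF_TString_Copy; TF_StringDecode becomes TF_TString_GetDataPointer plus TF_TString_GetSize, with no TF_Status to thread through. A sketch of the read side (helper name illustrative):

    #include "tensorflow/c/tf_tensor.h"
    #include "tensorflow/c/tf_tstring.h"

    // Element i of a TF_STRING tensor, new style: a constant-time view with
    // no varint decoding and no status plumbing.
    void ReadElement(const TF_Tensor* t, int i, const char** data, size_t* len) {
      const TF_TString* strs = static_cast<const TF_TString*>(TF_TensorData(t));
      *data = TF_TString_GetDataPointer(&strs[i]);
      *len = TF_TString_GetSize(&strs[i]);
    }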
@@ -334,58 +283,12 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
     std::memcpy(TF_TensorData(t), str.c_str(), str.size());
     return t;
   }
-  if (src.dtype() != tensorflow::DT_STRING) {
   Tensor tensor;
   if (!tensor.CopyFrom(src, src.shape())) {
     return nullptr;
   }
   return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
-  }
-  // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
-  // encoded sequence of strings.
-
-  // Compute bytes needed for encoding.
-  size_t size = 0;
-  const auto& srcarray = src.flat<tstring>();
-  for (int i = 0; i < srcarray.size(); ++i) {
-    const string& s = srcarray(i);
-    // uint64 starting_offset, TF_StringEncode-d string.
-    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
-  }
-
-  // Encode all strings.
-  char* base = new char[size];
-  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
-  char* dst = data_start;  // Where next string is encoded.
-  size_t dst_len = size - static_cast<size_t>(data_start - base);
-  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
-  for (int i = 0; i < srcarray.size(); ++i) {
-    *offsets = (dst - data_start);
-    offsets++;
-    const string& s = srcarray(i);
-    const size_t consumed = TF_StringEncodedSize(s.size());
-    StringEncode(s.data(), s.size(), dst);
-    dst += consumed;
-    dst_len -= consumed;
-  }
-  if (dst != base + size) {
-    *status = InvalidArgument(
-        "invalid string tensor encoding (decoded ", (dst - base),
-        " bytes, but the tensor is encoded in ", size, " bytes");
-    delete[] base;
-    return nullptr;
-  }
-
-  auto dims = src.shape().dim_sizes();
-  std::vector<tensorflow::int64> dimvec(dims.size());
-  for (size_t i = 0; i < dims.size(); ++i) {
-    dimvec[i] = dims[i];
-  }
-  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
-                "64-bit int types should match in size");
-  return TF_NewTensor(TF_STRING,
-                      reinterpret_cast<const int64_t*>(dimvec.data()),
-                      dimvec.size(), base, size, DeleteArray, base);
 }
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
@@ -409,40 +312,8 @@ Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
     }
     return Status::OK();
   }
-  if (tensor_.dtype() != DT_STRING) {
   *dst = tensor_;
   return Status::OK();
-  }
-  // TF_STRING tensors require copying since Tensor class expects a sequence of
-  // string objects.
-  const tensorflow::int64 num_elements = tensor_.NumElements();
-  const char* input = reinterpret_cast<const char*>(Data());
-  const size_t src_size = ByteSize();
-  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
-      num_elements) {
-    return InvalidArgument(
-        "Malformed TF_STRING tensor; too short to hold number of elements");
-  }
-  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
-  const char* limit = input + src_size;
-
-  *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
-  auto dstarray = dst->flat<tstring>();
-  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
-    tensorflow::uint64 offset =
-        reinterpret_cast<const tensorflow::uint64*>(input)[i];
-    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
-      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
-                             " out of range");
-    }
-    size_t len;
-    const char* p;
-    const char* srcp = data_start + offset;
-    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
-    if (!status.ok()) return status;
-    dstarray(i).assign(p, len);
-  }
-  return Status::OK();
 }
 
 bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
tensorflow/c/tf_tensor.h

@@ -148,34 +148,6 @@ TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from,
                                                 int num_new_dims,
                                                 TF_Status* status);
 
-// --------------------------------------------------------------------------
-// Encode the string `src` (`src_len` bytes long) into `dst` in the format
-// required by TF_STRING tensors. Does not write to memory more than `dst_len`
-// bytes beyond `*dst`. `dst_len` should be at least
-// TF_StringEncodedSize(src_len).
-//
-// On success returns the size in bytes of the encoded string.
-// Returns an error into `status` otherwise.
-TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len,
-                                             char* dst, size_t dst_len,
-                                             TF_Status* status);
-
-// Decode a string encoded using TF_StringEncode.
-//
-// On success, sets `*dst` to the start of the decoded string and `*dst_len` to
-// its length. Returns the number of bytes starting at `src` consumed while
-// decoding. `*dst` points to memory within the encoded buffer. On failure,
-// `*dst` and `*dst_len` are undefined and an error is set in `status`.
-//
-// Does not read memory more than `src_len` bytes beyond `src`.
-TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len,
-                                             const char** dst, size_t* dst_len,
-                                             TF_Status* status);
-
-// Return the size in bytes required to encode a string `len` bytes long into a
-// TF_STRING tensor.
-TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
-
 // Returns bool iff this tensor is aligned.
 TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);
 
tensorflow/c/tf_tstring.h (new file, +20)

@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_TF_TSTRING_H_
+#define TENSORFLOW_C_TF_TSTRING_H_
+
+#include "tensorflow/core/platform/ctstring.h"
+
+#endif  // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_
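Note: the new header is deliberately a thin forwarding header, so C-API users pick up the ctstring accessors without including core/platform directly. TF_TString is a plain C struct with no constructor or destructor, which is what makes it ABI-stable; a sketch of the resulting lifetime contract:

    #include "tensorflow/c/tf_tstring.h"  // forwards to core/platform/ctstring.h

    void Lifetime() {
      TF_TString s;          // plain C struct: nothing has initialized it yet
      TF_TString_Init(&s);   // must precede any other TF_TString_* call
      TF_TString_Dealloc(&s);
    }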
tensorflow/core/platform/BUILD

@@ -762,6 +762,14 @@ cc_library(
     ],
 )
 
+filegroup(
+    name = "ctstring",
+    srcs = [
+        "ctstring.h",
+        "ctstring_internal.h",
+    ],
+)
+
 cc_library(
     name = "types",
     hdrs = ["types.h"],
tensorflow/go/tensor.go

@@ -16,14 +16,20 @@ limitations under the License.
 
 package tensorflow
 
-// #include <stdlib.h>
-// #include <string.h>
-// #include "tensorflow/c/c_api.h"
+/*
+#include <stdlib.h>
+#include <string.h>
+#include "tensorflow/c/c_api.h"
+
+void toNewTString(_GoString_ gstr, TF_TString *tstr) {
+	TF_TString_Init(tstr);
+	TF_TString_Copy(tstr, _GoStringPtr(gstr), _GoStringLen(gstr));
+}
+*/
 import "C"
 
 import (
 	"bytes"
-	"encoding/binary"
 	"fmt"
 	"io"
 	"math/bits"
@@ -79,9 +85,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
 	nflattened := numElements(shape)
 	nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
 	if dataType == String {
-		// TF_STRING tensors are encoded as an array of 8-byte offsets
-		// followed by string data. See c_api.h.
-		nbytes = uintptr(nflattened*8 + int64(byteSizeOfEncodedStrings(val)))
+		nbytes = uintptr(nflattened) * C.sizeof_TF_TString
 	}
 	var shapePtr *C.int64_t
 	if len(shape) > 0 {
@@ -94,7 +98,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
 	runtime.SetFinalizer(t, (*Tensor).finalize)
 	raw := tensorData(t.c)
 	buf := bytes.NewBuffer(raw[:0:len(raw)])
-	if dataType != String {
 	if isAllArray(val.Type()) {
 		// We have arrays all the way down, or just primitive types. We can
 		// just copy the memory in as it is all contiguous.

@@ -114,15 +118,6 @@ func NewTensor(value interface{}) (*Tensor, error) {
 	if uintptr(buf.Len()) != nbytes {
 		return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
 	}
-	} else {
-		e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
-		if err := e.encode(reflect.ValueOf(value), shape); err != nil {
-			return nil, err
-		}
-		if int64(buf.Len()) != nflattened*8 {
-			return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
-		}
-	}
 	return t, nil
 }
 
@@ -131,6 +126,8 @@ func NewTensor(value interface{}) (*Tensor, error) {
 // contiguous in RAM.
 func isAllArray(typ reflect.Type) bool {
 	switch typ.Kind() {
+	case reflect.String:
+		return false
 	case reflect.Slice:
 		return false
 	case reflect.Array:
@@ -312,58 +309,16 @@ func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
 
 // decodeOneDimString decodes a string tensor into a one-dimensional []string.
 func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
-	// Start by making an array of all the strings
 	strs := make([]string, nStrings)
-	// The first nStrings * 8 bytes of raw are offsets into the second half of
-	// the raw data. This second half is where the strings are encoded.
-	offsets := (*(*[]int64)(unsafe.Pointer(&raw)))[:nStrings]
+	tstrs := (*(*[]C.TF_TString)(unsafe.Pointer(&raw)))[:nStrings]
 
-	// Reset raw after the offsets. Now the offsets will work relative to raw
-	raw = raw[nStrings*8:]
-	// Next we work out the final length of the string data so we can copy the
-	// good data out of raw (which is owned by the C tensor and won't be safe
-	// to access if the tensor is freed)
-	r := bytes.NewReader(raw)
-	var totalLength int
-	for _, offset := range offsets {
-		// At each offset we should find a varint length of a string.
-		// Errors here should mean the tensor is corrupt.
-		if _, err := r.Seek(offset, io.SeekStart); err != nil {
-			return nil, err
-		}
-		l, err := binary.ReadUvarint(r)
-		if err != nil {
-			return nil, err
-		}
-		totalLength += int(l)
+	for i, tstr := range tstrs {
+		dst := C.TF_TString_GetDataPointer(&tstr)
+		dstLen := C.TF_TString_GetSize(&tstr)
+		strs[i] = C.GoStringN(dst, C.int(dstLen))
 	}
 
-	// Lets allocate a big buffer to carry our string data.
-	stringData := make([]byte, 0, totalLength)
-	// Now copy the string data across into our new buffer, keeping track of the
-	// location of each string in the strs slice.
-	var cursor int
-	for i, offset := range offsets {
-		// At each offset we should find a varint length. Read it
-		if _, err := r.Seek(offset, io.SeekStart); err != nil {
-			return nil, err
-		}
-		l, err := binary.ReadUvarint(r)
-		if err != nil {
-			return nil, err
-		}
-
-		// Then copy the actual string into our large buffer
-		target := stringData[cursor : cursor+int(l)]
-		if _, err := r.Read(target); err != nil {
-			return nil, err
-		}
-		// Track where this string data is.
-		strs[i] = *(*string)(unsafe.Pointer(&target))
-		cursor += int(l)
-	}
-
-	// So now we have a big slice of strings
 	return strs, nil
 }
 
@@ -469,26 +424,6 @@ func numElements(shape []int64) int64 {
 	return n
 }
 
-// byteSizeOfEncodedStrings returns the size of the encoded strings in val.
-// val MUST be a string, or a container (array/slice etc.) of strings.
-// Tensorflow encodes strings as the varint encoded length followed by the
-// string bytes. We could call into the C library to do this but cgo has a heavy
-// overhead. So we just do that calculation in Go
-func byteSizeOfEncodedStrings(val reflect.Value) int {
-	if val.Kind() == reflect.String {
-		return sizeVarUint(uint64(val.Len())) + val.Len()
-	}
-	if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
-		panic(fmt.Sprintf("unexpected type %s", val.Type()))
-	}
-	// Otherwise must be an array or slice.
-	var size int
-	for i := 0; i < val.Len(); i++ {
-		size += byteSizeOfEncodedStrings(val.Index(i))
-	}
-	return size
-}
-
 // sizeVarUint determines how many bytes it would take to encode the int v as
 // an unsigned varint
 func sizeVarUint(v uint64) int {
@@ -509,12 +444,18 @@ func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) err
 		if v.Len() != expected {
 			return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
 		}
+	} else if v.Kind() == reflect.String {
+		s := v.Interface().(string)
+		var tstr C.TF_TString
+		C.toNewTString(s, &tstr)
+		ptr := unsafe.Pointer(&tstr)
+		return copyPtr(w, ptr, C.sizeof_TF_TString)
 	} else if v.Kind() != reflect.Array {
 		return fmt.Errorf("unsupported type %v", v.Type())
 	}
 
 	// Once we have just a single dimension we can just copy the data
-	if len(shape) == 1 && v.Len() > 0 {
+	if len(shape) == 1 && v.Len() > 0 && v.Index(0).Kind() != reflect.String {
 		elt := v.Index(0)
 		if !elt.CanAddr() {
 			panic("cannot take address")

@@ -556,45 +497,6 @@ func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
 	return err
 }
 
-type stringEncoder struct {
-	offsets *bytes.Buffer
-	data    []byte
-	offset  uint64
-	status  *status
-}
-
-func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
-	if v.Kind() == reflect.String {
-		if err := copyPtr(e.offsets, unsafe.Pointer(&e.offset), int(unsafe.Sizeof(e.offset))); err != nil {
-			return err
-		}
-		// A string is encoded as the varint length followed by the string bytes.
-		// We do this in Go to avoid the considerable overhead of a cgo call into
-		// the tensorflow library
-		s := v.String()
-		n := binary.PutUvarint(e.data[e.offset:], uint64(len(s)))
-		e.offset += uint64(n)
-		n = copy(e.data[e.offset:], s)
-		e.offset += uint64(n)
-		return nil
-	}
-
-	if v.Kind() == reflect.Slice {
-		expected := int(shape[0])
-		if v.Len() != expected {
-			return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
-		}
-	}
-
-	subShape := shape[1:]
-	for i := 0; i < v.Len(); i++ {
-		if err := e.encode(v.Index(i), subShape); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 func bug(format string, args ...interface{}) error {
 	return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
 }
tensorflow/java/src/main/native/tensor_jni.cc

@@ -220,16 +220,13 @@ size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
   }
 }
 
-jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
-                                       size_t src_len, TF_Status* status) {
-  const char* dst = nullptr;
-  size_t dst_len = 0;
-  TF_StringDecode(src, src_len, &dst, &dst_len, status);
-  if (TF_GetCode(status) != TF_OK) {
-    return nullptr;
-  }
+jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const TF_TString* src) {
+  const char* dst = TF_TString_GetDataPointer(src);
+  size_t dst_len = TF_TString_GetSize(src);
   jbyteArray ret = env->NewByteArray(dst_len);
   jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
   memcpy(cpy, dst, dst_len);
   env->ReleaseByteArrayElements(ret, cpy, 0);
   return ret;
@@ -238,69 +235,32 @@ jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
 class StringTensorWriter {
  public:
   StringTensorWriter(TF_Tensor* t, int num_elements)
-      : offset_(0),
-        poffsets_(static_cast<char*>(TF_TensorData(t))),
-        pdata_(poffsets_ + 8 * num_elements),
-        plimit_(poffsets_ + TF_TensorByteSize(t)) {}
+      : index_(0), data_(static_cast<TF_TString*>(TF_TensorData(t))) {}
 
   void Add(const char* src, size_t len, TF_Status* status) {
     if (TF_GetCode(status) != TF_OK) return;
-    if (plimit_ - poffsets_ < sizeof(offset_)) {
-      TF_SetStatus(status, TF_OUT_OF_RANGE,
-                   "TF_STRING tensor encoding ran out of space for offsets, "
-                   "this is likely a bug, please file an issue at "
-                   "https://github.com/tensorflow/tensorflow/issues/new");
-      return;
-    }
-    memcpy(poffsets_, &offset_, sizeof(offset_));
-    size_t written =
-        TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
-    offset_ += written;
-    poffsets_ += 8;
-    pdata_ += written;
+    TF_TString_Init(&data_[index_]);
+    TF_TString_Copy(&data_[index_++], src, len);
   }
 
  private:
-  uint64_t offset_;
-  char* poffsets_;
-  char* pdata_;
-  const char* plimit_;
+  int index_;
+  TF_TString* data_;
 };
 
 class StringTensorReader {
  public:
   StringTensorReader(const TF_Tensor* t, int num_elements)
-      : index_(0),
-        offsets_(static_cast<const char*>(TF_TensorData(t))),
-        data_(offsets_ + 8 * num_elements),
-        limit_(offsets_ + TF_TensorByteSize(t)) {}
+      : index_(0), data_(static_cast<const TF_TString*>(TF_TensorData(t))) {}
 
   jbyteArray Next(JNIEnv* env, TF_Status* status) {
     if (TF_GetCode(status) != TF_OK) return nullptr;
-    uint64_t offset = 0;
-    const char* poffset = offsets_ + sizeof(offset) * index_;
-    if (poffset >= limit_) {
-      TF_SetStatus(
-          status, TF_INTERNAL,
-          "Invalid TF_STRING tensor, offsets table seems to be too small");
-      return nullptr;
-    }
-    memcpy(&offset, poffset, sizeof(offset));
-    const char* pdata = data_ + offset;
-    if (pdata >= limit_) {
-      TF_SetStatus(status, TF_INTERNAL,
-                   "Invalid TF_STRING tensor, invalid entry in offset table");
-      return nullptr;
-    }
-    ++index_;
-    return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
+    return TF_StringDecodeTojbyteArray(env, &data_[index_++]);
   }
 
  private:
   int index_;
-  const char* offsets_;
-  const char* data_;
-  const char* limit_;
+  const TF_TString* data_;
 };
 
 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
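Note: the simplified writer assumes the tensor buffer was sized as one TF_TString per element, as done in the allocation sites below. A sketch of that allocation pattern (helper name illustrative; every slot still needs TF_TString_Init before use, since TF_AllocateTensor does not initialize the memory):

    #include <cstdint>
    #include "tensorflow/c/tf_tensor.h"
    #include "tensorflow/c/tf_tstring.h"

    // Allocate a 1-D TF_STRING tensor of n elements and initialize each slot.
    TF_Tensor* AllocStringVector(int64_t n) {
      TF_Tensor* t = TF_AllocateTensor(TF_STRING, &n, 1,
                                       n * sizeof(TF_TString));
      TF_TString* data = static_cast<TF_TString*>(TF_TensorData(t));
      for (int64_t i = 0; i < n; ++i) TF_TString_Init(&data[i]);
      return t;
    }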
@@ -367,17 +327,16 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
   // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
   // TF_StringEncode-encoded bytes.
   size_t src_len = static_cast<int>(env->GetArrayLength(value));
-  size_t dst_len = TF_StringEncodedSize(src_len);
-  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
-  char* dst = static_cast<char*>(TF_TensorData(t));
-  memset(dst, 0, 8);  // The offset table
+  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, sizeof(TF_TString));
+  TF_TString* dst = static_cast<TF_TString*>(TF_TensorData(t));
 
   TF_Status* status = TF_NewStatus();
   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
   // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
   // reinterpret_cast<> for this conversion should be safe.
-  TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
-                  dst_len, status);
+  TF_TString_Init(&dst[0]);
+  TF_TString_Copy(&dst[0], reinterpret_cast<const char*>(jsrc), src_len);
 
   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
@@ -388,27 +347,18 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
 }
 
 namespace {
-size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
-  if (num_dims == 0) {
-    // This is the last dimension, i.e., value should correspond to a jbyteArray
-    // encoding the string.
-    return TF_StringEncodedSize(
-        static_cast<size_t>(env->GetArrayLength(value)));
-  }
+void checkForNullEntries(JNIEnv* env, jarray value, int num_dims) {
   jsize len = env->GetArrayLength(value);
-  size_t ret = 0;
   for (jsize i = 0; i < len; ++i) {
     jarray elem = static_cast<jarray>(
         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
     if (elem == nullptr) {
       throwException(env, kNullPointerException,
                      "null entries in provided array");
-      return ret;
+      return;
     }
-    ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
-    if (env->ExceptionCheck()) return ret;
+    if (env->ExceptionCheck()) return;
   }
-  return ret;
 }
 
 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
@@ -448,11 +398,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
     }
     env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
   }
-  const size_t encoded_size =
-      nonScalarTF_STRINGTensorSize(env, value, num_dims);
+  checkForNullEntries(env, value, num_dims);
   if (env->ExceptionCheck()) return 0;
   TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
-                                   8 * num_elements + encoded_size);
+                                   sizeof(TF_TString) * num_elements);
   if (t == nullptr) {
     delete[] dims;
     throwException(env, kNullPointerException,
@@ -572,20 +521,8 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
                    "Tensor is not a string/bytes scalar");
     return nullptr;
   }
-  const char* data = static_cast<const char*>(TF_TensorData(t));
-  const char* src = data + 8;
-  size_t src_len = TF_TensorByteSize(t) - 8;
-  uint64_t offset = 0;
-  memcpy(&offset, data, sizeof(offset));
-  if (offset >= src_len) {
-    throwException(env, kIllegalArgumentException,
-                   "invalid tensor encoding: bad offsets");
-    return nullptr;
-  }
-  TF_Status* status = TF_NewStatus();
-  jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
-  throwExceptionIfNotOK(env, status);
-  TF_DeleteStatus(status);
+  const TF_TString* data = static_cast<const TF_TString*>(TF_TensorData(t));
+  jbyteArray ret = TF_StringDecodeTojbyteArray(env, &data[0]);
   return ret;
 }
 
ConstantTest.java

@@ -27,7 +27,6 @@ import java.nio.DoubleBuffer;
 import java.nio.FloatBuffer;
 import java.nio.IntBuffer;
 import java.nio.LongBuffer;
-
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -179,17 +178,13 @@ public class ConstantTest {
     byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
     long[] shape = {};
 
-    // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
-    // followed by a varint encoded size, followed by the data.
     ByteArrayOutputStream baout = new ByteArrayOutputStream();
     DataOutputStream out = new DataOutputStream(baout);
-    // Offset in array.
-    out.writeLong(0L);
-    // Varint encoded length of buffer.
-    // For any number < 0x80, the varint encoding is simply the number itself.
-    // https://developers.google.com/protocol-buffers/docs/encoding#varints
-    assertTrue(data.length < 0x80);
-    out.write(data.length);
+    // We construct a TF_TString_Small tstring, which has the capacity for a 22 byte string.
+    // The first 6 most significant bits of the first byte represent length; the remaining
+    // 2-bits are type indicators, and are left as 0b00 to denote a TF_TSTR_SMALL type.
+    assertTrue(data.length <= 22);
+    out.writeByte(data.length << 2);
     out.write(data);
     out.close();
     byte[] content = baout.toByteArray();
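Note: a worked example of the metadata byte the test now writes, assuming the TF_TSTR_SMALL layout described in the comments above: for data.length == 4 the byte is 4 << 2 == 16 == 0b00010000, i.e. length 4 in the upper six bits and 0b00 (small/inline) in the low two type bits.

    #include <cassert>
    #include <cstdint>

    // First byte of a TF_TSTR_SMALL tstring: (length << 2) | type tag, where
    // the tag 0b00 denotes the small (inline) representation.
    uint8_t SmallStringHeader(uint8_t length) {
      assert(length <= 22);  // inline capacity per the comment above
      return length << 2;    // e.g. 4 -> 0b00010000 == 16
    }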
tensorflow/python/lib/core/ndarray_tensor.cc

@@ -245,29 +245,17 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) {
 // the buffer. The caller takes ownership of the buffer.
 Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
                           size_t* size, void** buffer) {
-  // Compute bytes needed for encoding.
-  *size = 0;
-  TF_RETURN_IF_ERROR(
-      PyBytesArrayMap(array, [&size](const char* ptr, Py_ssize_t len) {
-        *size += sizeof(tensorflow::uint64) +
-                 tensorflow::core::VarintLength(len) + len;
-      }));
   // Encode all strings.
-  std::unique_ptr<char[]> base_ptr(new char[*size]);
-  char* base = base_ptr.get();
-  char* data_start = base + sizeof(tensorflow::uint64) * nelems;
-  char* dst = data_start;  // Where next string is encoded.
-  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
+  *size = nelems * sizeof(tensorflow::tstring);
+  std::unique_ptr<tensorflow::tstring[]> base_ptr(
+      new tensorflow::tstring[nelems]);
+  tensorflow::tstring* dst = base_ptr.get();
 
-  TF_RETURN_IF_ERROR(PyBytesArrayMap(
-      array, [&data_start, &dst, &offsets](const char* ptr, Py_ssize_t len) {
-        *offsets = (dst - data_start);
-        offsets++;
-        dst = tensorflow::core::EncodeVarint64(dst, len);
-        memcpy(dst, ptr, len);
-        dst += len;
+  TF_RETURN_IF_ERROR(
+      PyBytesArrayMap(array, [&dst](const char* ptr, Py_ssize_t len) {
+        dst->assign(ptr, len);
+        dst++;
       }));
-  CHECK_EQ(dst, base + *size);
   *buffer = base_ptr.release();
   return Status::OK();
 }
@@ -275,37 +263,18 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
 Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
                                      PyArrayObject* dst) {
   const void* tensor_data = TF_TensorData(src);
-  const size_t tensor_size = TF_TensorByteSize(src);
-  const char* limit = static_cast<const char*>(tensor_data) + tensor_size;
   DCHECK(tensor_data != nullptr);
   DCHECK_EQ(TF_STRING, TF_TensorType(src));
 
-  const uint64* offsets = static_cast<const uint64*>(tensor_data);
-  const size_t offsets_size = sizeof(uint64) * nelems;
-  const char* data = static_cast<const char*>(tensor_data) + offsets_size;
-
-  const size_t expected_tensor_size =
-      (limit - static_cast<const char*>(tensor_data));
-  if (expected_tensor_size - tensor_size) {
-    return errors::InvalidArgument(
-        "Invalid/corrupt TF_STRING tensor: expected ", expected_tensor_size,
-        " bytes of encoded strings for the tensor containing ", nelems,
-        " strings, but the tensor is encoded in ", tensor_size, " bytes");
-  }
+  const tstring* tstr = static_cast<const tstring*>(tensor_data);
 
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
   for (int64 i = 0; i < nelems; ++i) {
-    const char* start = data + offsets[i];
-    const char* ptr = nullptr;
-    size_t len = 0;
-
-    TF_StringDecode(start, limit - start, &ptr, &len, status.get());
-    if (TF_GetCode(status.get()) != TF_OK) {
-      return errors::InvalidArgument(TF_Message(status.get()));
-    }
-
-    auto py_string = make_safe(PyBytes_FromStringAndSize(ptr, len));
+    const tstring& tstr_i = tstr[i];
+    auto py_string =
+        make_safe(PyBytes_FromStringAndSize(tstr_i.data(), tstr_i.size()));
     if (py_string == nullptr) {
       return errors::Internal(
           "failed to create a python byte array when converting element #", i,
@@ -551,14 +520,14 @@ Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
         static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
         encoded, size, convert_string,
         [](void* data, size_t len, void* arg) {
-          delete[] reinterpret_cast<char*>(data);
+          delete[] reinterpret_cast<tensorflow::tstring*>(data);
         },
         nullptr)});
   } else {
     *ret = make_safe(TF_NewTensor(
         dtype, dims.data(), dims.size(), encoded, size,
         [](void* data, size_t len, void* arg) {
-          delete[] reinterpret_cast<char*>(data);
+          delete[] reinterpret_cast<tensorflow::tstring*>(data);
         },
         nullptr));
   }
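Note: one subtlety in the Python deallocators above: the buffer handed to TF_NewTensor is now an array of tensorflow::tstring, so delete[] must go through the correct pointer type for the element destructors to run; deleting through char* (the old code's type) would free the array storage but leak any heap-allocated string payloads. A sketch of the pairing (helper names illustrative):

    #include "tensorflow/core/platform/tstring.h"

    // Allocation and deallocation must agree on the element type.
    void* MakeStringBuffer(size_t n) { return new tensorflow::tstring[n]; }

    void DeleteStringBuffer(void* data, size_t len, void* arg) {
      // Runs ~tstring() on each element, releasing heap payloads.
      delete[] reinterpret_cast<tensorflow::tstring*>(data);
    }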