Eliminate unnecessary packing/unpacking of string tensors across the C-API and language bindings.
With the std::string -> tensorflow::tstring migration in TensorFlow core complete, we now have an ABI-stable string type with associated C accessors/modifiers. The packing/unpacking of string tensors across the C-API is now superfluous and can be removed. This is an ABI-breaking change; updates to the language bindings in tensorflow/ are included in this CL.

PiperOrigin-RevId: 318933001
Change-Id: I3c99cf70834ba0b6cefc4ba39a35d6c168e880db
This commit is contained in: parent 2ab7873606, commit 24f835217f
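To make the new access pattern concrete, here is a minimal C sketch of the accessors/modifiers the message refers to (only `TF_TString` functions that appear in this change are used; error handling is elided):

```c
#include <stdio.h>
#include <string.h>

#include "tensorflow/c/tf_tstring.h"  // re-exports core/platform/ctstring.h

int main(void) {
  // Each element of a TF_STRING tensor is now a single TF_TString value
  // instead of an offset into a varint-encoded buffer.
  TF_TString s;
  TF_TString_Init(&s);
  TF_TString_Copy(&s, "hello", strlen("hello"));

  // Reads go through the C accessors rather than varint decoding.
  printf("%.*s (%zu bytes)\n", (int)TF_TString_GetSize(&s),
         TF_TString_GetDataPointer(&s), TF_TString_GetSize(&s));

  TF_TString_Dealloc(&s);  // long strings may heap-allocate, so free explicitly
  return 0;
}
```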
@@ -6,6 +6,11 @@
 * <DOCUMENT BREAKING CHANGES HERE>
 * <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
+* The byte layout for string tensors across the C-API has been updated to match
+  TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
+* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
+  `TF_StringEncodedSize` are no longer relevant and have been removed; see
+  core/platform/ctstring.h for string access/modification in C.
 
 ## Known Caveats
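To illustrate the layout change described above: previously a caller had to reserve 8*n offset bytes plus `TF_StringEncodedSize(len)` per element; now the buffer is simply `n * sizeof(TF_TString)` and each element is written in place. A minimal sketch (the `MakeStringVector` helper is our name for illustration, not a TF API; per-element cleanup on tensor deletion is elided):

```c
#include <stdint.h>
#include <string.h>

#include "tensorflow/c/c_api.h"

// Build a rank-1 TF_STRING tensor of n elements under the new byte layout:
// a contiguous array of TF_TString structs.
static TF_Tensor* MakeStringVector(const char** strs, int64_t n) {
  int64_t dims[1] = {n};
  TF_Tensor* t =
      TF_AllocateTensor(TF_STRING, dims, 1, (size_t)n * sizeof(TF_TString));
  TF_TString* data = (TF_TString*)TF_TensorData(t);
  for (int64_t i = 0; i < n; ++i) {
    TF_TString_Init(&data[i]);  // every element starts as an empty tstring
    TF_TString_Copy(&data[i], strs[i], strlen(strs[i]));
  }
  return t;
}
```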
@@ -29,6 +29,8 @@ filegroup(
         "tf_file_statistics.h",
         "tf_status.h",
         "tf_tensor.h",
+        "tf_tstring.h",
+        "//tensorflow/core/platform:ctstring",
     ],
     visibility = ["//tensorflow:__subpackages__"],
 )
@@ -48,6 +50,7 @@ filegroup(
             "*test*",
         ],
     ) + [
+        "//tensorflow/core/platform:ctstring",
         "//tensorflow/cc:srcs_no_runtime",
         "//tensorflow/core/distributed_runtime:server_lib.h",
     ],
@@ -78,6 +81,7 @@ tf_cuda_library(
         "c_api_internal.h",
         "tf_datatype.h",
         "tf_tensor.h",
+        "tf_tstring.h",
     ],
     visibility = [
         "//tensorflow:internal",
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/c/tf_tstring.h"
 
 // --------------------------------------------------------------------------
 // C API for TensorFlow.
@@ -2286,14 +2286,15 @@ TEST_F(CApiAttributesTest, Tensor) {
 
 TEST_F(CApiAttributesTest, StringTensor) {
   // Create the string-Tensor "attribute" value.
-  char encoded[] = {
-      0, 0, 0, 0, 0, 0, 0, 0,  // array[uint64] offsets
-      1,                       // varint encoded string length
-      'A',
-  };
+  const char test_string[] =
+      "borkborkborkborkborkborkborkbork";  // >24bytes to force heap alloc
+  TF_TString tstr[1];
+  TF_TString_Init(&tstr[0]);
+  TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
+
   auto deallocator = [](void* data, size_t len, void* arg) {};
-  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
-                                      sizeof(encoded), deallocator, nullptr),
+  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
+                                      sizeof(tstr), deallocator, nullptr),
                          TF_DeleteTensor);
 
   // Create a TF_Operation with the attribute t_in
@@ -2312,9 +2313,17 @@ TEST_F(CApiAttributesTest, StringTensor) {
   EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
   EXPECT_EQ(0, TF_NumDims(t_out));
   ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
-  EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
-                      TF_TensorByteSize(t_out)));
+  TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
+  TF_TString* t_out_tstr = static_cast<TF_TString*>(TF_TensorData(t_out));
+  EXPECT_EQ(absl::string_view(test_string),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
+  EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
+                              TF_TString_GetSize(t_in_tstr)),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
   TF_DeleteTensor(t_out);
+  TF_TString_Dealloc(&tstr[0]);
 }
 
 TEST_F(CApiAttributesTest, TensorList) {
@@ -228,57 +228,6 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
 }  // namespace tensorflow
 
 // --------------------------------------------------------------------------
-void StringEncode(const char* src, size_t src_len, char* dst) {
-  dst = tensorflow::core::EncodeVarint64(dst, src_len);
-  memcpy(dst, src, src_len);
-}
-
-size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
-                       size_t dst_len, TF_Status* status) {
-  const size_t sz = TF_StringEncodedSize(src_len);
-  if (sz < src_len) {
-    Set_TF_Status_from_Status(
-        status, InvalidArgument("src string is too large to encode"));
-    return 0;
-  }
-  if (dst_len < sz) {
-    Set_TF_Status_from_Status(
-        status,
-        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
-                        src_len, "-byte string"));
-    return 0;
-  }
-  StringEncode(src, src_len, dst);
-  return sz;
-}
-
-static Status TF_StringDecode_Impl(const char* src, size_t src_len,
-                                   const char** dst, size_t* dst_len) {
-  tensorflow::uint64 len64 = 0;
-  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
-  if (p == nullptr) {
-    return InvalidArgument("invalid string encoding or truncated src buffer");
-  }
-  if (len64 > std::numeric_limits<size_t>::max()) {
-    return InvalidArgument("encoded string is ", len64,
-                           "-bytes, which is too large for this architecture");
-  }
-  *dst = p;
-  *dst_len = static_cast<size_t>(len64);
-  return Status::OK();
-}
-
-size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
-                       size_t* dst_len, TF_Status* status) {
-  Set_TF_Status_from_Status(status,
-                            TF_StringDecode_Impl(src, src_len, dst, dst_len));
-  if (TF_GetCode(status) != TF_OK) return 0;
-  return static_cast<size_t>(*dst - src) + *dst_len;
-}
-
-size_t TF_StringEncodedSize(size_t len) {
-  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
-}
-
 static void DeleteArray(void* data, size_t size, void* arg) {
   DCHECK_EQ(data, arg);
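For reference, the wire format these deleted helpers implemented (reconstructed from the code above): a table of `uint64` offsets at the head of the buffer, then per element a varint-encoded length followed by the raw bytes. A standalone restatement of the per-element encoding step, in plain C rather than the removed TensorFlow API:

```c
#include <stdint.h>
#include <string.h>

// Minimal varint writer, equivalent to the EncodeVarint64 call in the
// deleted StringEncode() above (LEB128, 7 bits per byte).
static char* EncodeVarint64(char* dst, uint64_t v) {
  while (v >= 0x80) {
    *dst++ = (char)((v & 0x7f) | 0x80);
    v >>= 7;
  }
  *dst++ = (char)v;
  return dst;
}

// One element in the legacy TF_STRING encoding: varint length, then bytes.
// Returns the number of bytes written, as the removed TF_StringEncode did.
static size_t LegacyStringEncode(const char* src, size_t src_len, char* dst) {
  char* p = EncodeVarint64(dst, (uint64_t)src_len);
  memcpy(p, src, src_len);
  return (size_t)(p - dst) + src_len;
}
```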
@@ -334,58 +283,12 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
     std::memcpy(TF_TensorData(t), str.c_str(), str.size());
     return t;
   }
-  if (src.dtype() != tensorflow::DT_STRING) {
-    Tensor tensor;
-    if (!tensor.CopyFrom(src, src.shape())) {
-      return nullptr;
-    }
-    return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
-  }
-  // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
-  // encoded sequence of strings.
-
-  // Compute bytes needed for encoding.
-  size_t size = 0;
-  const auto& srcarray = src.flat<tstring>();
-  for (int i = 0; i < srcarray.size(); ++i) {
-    const string& s = srcarray(i);
-    // uint64 starting_offset, TF_StringEncode-d string.
-    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
-  }
-
-  // Encode all strings.
-  char* base = new char[size];
-  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
-  char* dst = data_start;  // Where next string is encoded.
-  size_t dst_len = size - static_cast<size_t>(data_start - base);
-  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
-  for (int i = 0; i < srcarray.size(); ++i) {
-    *offsets = (dst - data_start);
-    offsets++;
-    const string& s = srcarray(i);
-    const size_t consumed = TF_StringEncodedSize(s.size());
-    StringEncode(s.data(), s.size(), dst);
-    dst += consumed;
-    dst_len -= consumed;
-  }
-  if (dst != base + size) {
-    *status = InvalidArgument(
-        "invalid string tensor encoding (decoded ", (dst - base),
-        " bytes, but the tensor is encoded in ", size, " bytes");
-    delete[] base;
+  Tensor tensor;
+  if (!tensor.CopyFrom(src, src.shape())) {
     return nullptr;
   }
-
-  auto dims = src.shape().dim_sizes();
-  std::vector<tensorflow::int64> dimvec(dims.size());
-  for (size_t i = 0; i < dims.size(); ++i) {
-    dimvec[i] = dims[i];
-  }
-  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
-                "64-bit int types should match in size");
-  return TF_NewTensor(TF_STRING,
-                      reinterpret_cast<const int64_t*>(dimvec.data()),
-                      dimvec.size(), base, size, DeleteArray, base);
+  return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
 }
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
@@ -409,39 +312,7 @@ Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
     }
     return Status::OK();
   }
-  if (tensor_.dtype() != DT_STRING) {
-    *dst = tensor_;
-    return Status::OK();
-  }
-  // TF_STRING tensors require copying since Tensor class expects a sequence of
-  // string objects.
-  const tensorflow::int64 num_elements = tensor_.NumElements();
-  const char* input = reinterpret_cast<const char*>(Data());
-  const size_t src_size = ByteSize();
-  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
-      num_elements) {
-    return InvalidArgument(
-        "Malformed TF_STRING tensor; too short to hold number of elements");
-  }
-  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
-  const char* limit = input + src_size;
-
-  *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
-  auto dstarray = dst->flat<tstring>();
-  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
-    tensorflow::uint64 offset =
-        reinterpret_cast<const tensorflow::uint64*>(input)[i];
-    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
-      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
-                             " out of range");
-    }
-    size_t len;
-    const char* p;
-    const char* srcp = data_start + offset;
-    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
-    if (!status.ok()) return status;
-    dstarray(i).assign(p, len);
-  }
+  *dst = tensor_;
   return Status::OK();
 }
 
@@ -148,34 +148,6 @@ TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from,
                                                 int num_new_dims,
                                                 TF_Status* status);
 
-// --------------------------------------------------------------------------
-// Encode the string `src` (`src_len` bytes long) into `dst` in the format
-// required by TF_STRING tensors. Does not write to memory more than `dst_len`
-// bytes beyond `*dst`. `dst_len` should be at least
-// TF_StringEncodedSize(src_len).
-//
-// On success returns the size in bytes of the encoded string.
-// Returns an error into `status` otherwise.
-TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len,
-                                             char* dst, size_t dst_len,
-                                             TF_Status* status);
-
-// Decode a string encoded using TF_StringEncode.
-//
-// On success, sets `*dst` to the start of the decoded string and `*dst_len` to
-// its length. Returns the number of bytes starting at `src` consumed while
-// decoding. `*dst` points to memory within the encoded buffer. On failure,
-// `*dst` and `*dst_len` are undefined and an error is set in `status`.
-//
-// Does not read memory more than `src_len` bytes beyond `src`.
-TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len,
-                                             const char** dst, size_t* dst_len,
-                                             TF_Status* status);
-
-// Return the size in bytes required to encode a string `len` bytes long into a
-// TF_STRING tensor.
-TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
-
 // Returns bool iff this tensor is aligned.
 TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);
tensorflow/c/tf_tstring.h (new file, 20 lines)
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_TF_TSTRING_H_
+#define TENSORFLOW_C_TF_TSTRING_H_
+
+#include "tensorflow/core/platform/ctstring.h"
+
+#endif  // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_
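The new public header is a thin re-export of core/platform/ctstring.h. A small usage sketch; the capacity figures are assumptions taken from the tests in this change (the Java test notes a 22-byte small-string capacity, and the C-API test uses a 32-byte string to force a heap allocation):

```c
#include <assert.h>
#include <string.h>

#include "tensorflow/c/tf_tstring.h"

int main(void) {
  TF_TString small, large;
  TF_TString_Init(&small);
  TF_TString_Init(&large);

  // Short strings are stored inline in the struct (TF_TSTR_SMALL in the
  // tests); this one is well under the 22-byte small capacity.
  TF_TString_Copy(&small, "tiny", 4);
  // The C-API test uses this 32-byte string to force a heap allocation.
  const char* big = "borkborkborkborkborkborkborkbork";
  TF_TString_Copy(&large, big, strlen(big));

  assert(TF_TString_GetSize(&small) == 4);
  assert(TF_TString_GetSize(&large) == 32);

  // Dealloc is required either way; heap-backed storage is released here.
  TF_TString_Dealloc(&small);
  TF_TString_Dealloc(&large);
  return 0;
}
```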
@@ -762,6 +762,14 @@ cc_library(
     ],
 )
 
+filegroup(
+    name = "ctstring",
+    srcs = [
+        "ctstring.h",
+        "ctstring_internal.h",
+    ],
+)
+
 cc_library(
     name = "types",
     hdrs = ["types.h"],
@@ -16,14 +16,20 @@ limitations under the License.
 
 package tensorflow
 
-// #include <stdlib.h>
-// #include <string.h>
-// #include "tensorflow/c/c_api.h"
+/*
+#include <stdlib.h>
+#include <string.h>
+#include "tensorflow/c/c_api.h"
+
+void toNewTString(_GoString_ gstr, TF_TString *tstr) {
+    TF_TString_Init(tstr);
+    TF_TString_Copy(tstr, _GoStringPtr(gstr), _GoStringLen(gstr));
+}
+*/
 import "C"
 
 import (
     "bytes"
     "encoding/binary"
     "fmt"
     "io"
     "math/bits"
@@ -79,9 +85,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
     nflattened := numElements(shape)
     nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
     if dataType == String {
-        // TF_STRING tensors are encoded as an array of 8-byte offsets
-        // followed by string data. See c_api.h.
-        nbytes = uintptr(nflattened*8 + int64(byteSizeOfEncodedStrings(val)))
+        nbytes = uintptr(nflattened) * C.sizeof_TF_TString
     }
     var shapePtr *C.int64_t
     if len(shape) > 0 {
@@ -94,35 +98,26 @@ func NewTensor(value interface{}) (*Tensor, error) {
     runtime.SetFinalizer(t, (*Tensor).finalize)
     raw := tensorData(t.c)
     buf := bytes.NewBuffer(raw[:0:len(raw)])
-    if dataType != String {
-        if isAllArray(val.Type()) {
-            // We have arrays all the way down, or just primitive types. We can
-            // just copy the memory in as it is all contiguous.
-            if err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
-                return nil, err
-            }
-        } else {
-            // When there are slices involved the memory for each leaf slice may
-            // not be contiguous with the others or in the order we might
-            // expect, so we need to work our way down to each slice of
-            // primitives and copy them individually
-            if err := encodeTensorWithSlices(buf, val, shape); err != nil {
-                return nil, err
-            }
-        }
-
-        if uintptr(buf.Len()) != nbytes {
-            return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
-        }
-    } else {
-        e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
-        if err := e.encode(reflect.ValueOf(value), shape); err != nil {
-            return nil, err
-        }
-        if int64(buf.Len()) != nflattened*8 {
-            return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
-        }
+    if isAllArray(val.Type()) {
+        // We have arrays all the way down, or just primitive types. We can
+        // just copy the memory in as it is all contiguous.
+        if err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
+            return nil, err
+        }
+    } else {
+        // When there are slices involved the memory for each leaf slice may
+        // not be contiguous with the others or in the order we might
+        // expect, so we need to work our way down to each slice of
+        // primitives and copy them individually
+        if err := encodeTensorWithSlices(buf, val, shape); err != nil {
+            return nil, err
+        }
+    }
+
+    if uintptr(buf.Len()) != nbytes {
+        return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
     }
     return t, nil
 }
@@ -131,6 +126,8 @@ func NewTensor(value interface{}) (*Tensor, error) {
 // contiguous in RAM.
 func isAllArray(typ reflect.Type) bool {
     switch typ.Kind() {
+    case reflect.String:
+        return false
     case reflect.Slice:
         return false
     case reflect.Array:
@@ -312,58 +309,16 @@ func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
 
 // decodeOneDimString decodes a string tensor into a one-dimensional []string.
 func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
     // Start by making an array of all the strings
     strs := make([]string, nStrings)
-    // The first nStrings * 8 bytes of raw are offsets into the second half of
-    // the raw data. This second half is where the strings are encoded.
-    offsets := (*(*[]int64)(unsafe.Pointer(&raw)))[:nStrings]
+    tstrs := (*(*[]C.TF_TString)(unsafe.Pointer(&raw)))[:nStrings]
 
-    // Reset raw after the offsets. Now the offsets will work relative to raw
-    raw = raw[nStrings*8:]
-    // Next we work out the final length of the string data so we can copy the
-    // good data out of raw (which is owned by the C tensor and won't be safe
-    // to access if the tensor is freed)
-    r := bytes.NewReader(raw)
-    var totalLength int
-    for _, offset := range offsets {
-        // At each offset we should find a varint length of a string.
-        // Errors here should mean the tensor is corrupt.
-        if _, err := r.Seek(offset, io.SeekStart); err != nil {
-            return nil, err
-        }
-        l, err := binary.ReadUvarint(r)
-        if err != nil {
-            return nil, err
-        }
-        totalLength += int(l)
+    for i, tstr := range tstrs {
+        dst := C.TF_TString_GetDataPointer(&tstr)
+        dstLen := C.TF_TString_GetSize(&tstr)
+
+        strs[i] = C.GoStringN(dst, C.int(dstLen))
     }
 
-    // Lets allocate a big buffer to carry our string data.
-    stringData := make([]byte, 0, totalLength)
-    // Now copy the string data across into our new buffer, keeping track of the
-    // location of each string in the strs slice.
-    var cursor int
-    for i, offset := range offsets {
-        // At each offset we should find a varint length. Read it
-        if _, err := r.Seek(offset, io.SeekStart); err != nil {
-            return nil, err
-        }
-        l, err := binary.ReadUvarint(r)
-        if err != nil {
-            return nil, err
-        }
-
-        // Then copy the actual string into our large buffer
-        target := stringData[cursor : cursor+int(l)]
-        if _, err := r.Read(target); err != nil {
-            return nil, err
-        }
-        // Track where this string data is.
-        strs[i] = *(*string)(unsafe.Pointer(&target))
-        cursor += int(l)
-    }
-
-    // So now we have a big slice of strings
     return strs, nil
 }
@@ -469,26 +424,6 @@ func numElements(shape []int64) int64 {
     return n
 }
 
-// byteSizeOfEncodedStrings returns the size of the encoded strings in val.
-// val MUST be a string, or a container (array/slice etc.) of strings.
-// Tensorflow encodes strings as the varint encoded length followed by the
-// string bytes. We could call into the C library to do this but cgo has a heavy
-// overhead. So we just do that calculation in Go
-func byteSizeOfEncodedStrings(val reflect.Value) int {
-    if val.Kind() == reflect.String {
-        return sizeVarUint(uint64(val.Len())) + val.Len()
-    }
-    if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
-        panic(fmt.Sprintf("unexpected type %s", val.Type()))
-    }
-    // Otherwise must be an array or slice.
-    var size int
-    for i := 0; i < val.Len(); i++ {
-        size += byteSizeOfEncodedStrings(val.Index(i))
-    }
-    return size
-}
-
 // sizeVarUint determines how many bytes it would take to encode the int v as
 // an unsigned varint
 func sizeVarUint(v uint64) int {
@@ -509,12 +444,18 @@ func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) err
         if v.Len() != expected {
             return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
         }
+    } else if v.Kind() == reflect.String {
+        s := v.Interface().(string)
+        var tstr C.TF_TString
+        C.toNewTString(s, &tstr)
+        ptr := unsafe.Pointer(&tstr)
+        return copyPtr(w, ptr, C.sizeof_TF_TString)
     } else if v.Kind() != reflect.Array {
         return fmt.Errorf("unsupported type %v", v.Type())
     }
 
     // Once we have just a single dimension we can just copy the data
-    if len(shape) == 1 && v.Len() > 0 {
+    if len(shape) == 1 && v.Len() > 0 && v.Index(0).Kind() != reflect.String {
         elt := v.Index(0)
         if !elt.CanAddr() {
             panic("cannot take address")
@@ -556,45 +497,6 @@ func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
     return err
 }
 
-type stringEncoder struct {
-    offsets *bytes.Buffer
-    data    []byte
-    offset  uint64
-    status  *status
-}
-
-func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
-    if v.Kind() == reflect.String {
-        if err := copyPtr(e.offsets, unsafe.Pointer(&e.offset), int(unsafe.Sizeof(e.offset))); err != nil {
-            return err
-        }
-        // A string is encoded as the varint length followed by the string bytes.
-        // We do this in Go to avoid the considerable overhead of a cgo call into
-        // the tensorflow library
-        s := v.String()
-        n := binary.PutUvarint(e.data[e.offset:], uint64(len(s)))
-        e.offset += uint64(n)
-        n = copy(e.data[e.offset:], s)
-        e.offset += uint64(n)
-        return nil
-    }
-
-    if v.Kind() == reflect.Slice {
-        expected := int(shape[0])
-        if v.Len() != expected {
-            return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
-        }
-    }
-
-    subShape := shape[1:]
-    for i := 0; i < v.Len(); i++ {
-        if err := e.encode(v.Index(i), subShape); err != nil {
-            return err
-        }
-    }
-    return nil
-}
-
 func bug(format string, args ...interface{}) error {
     return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
 }
@@ -220,16 +220,13 @@ size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
   }
 }
 
-jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
-                                       size_t src_len, TF_Status* status) {
-  const char* dst = nullptr;
-  size_t dst_len = 0;
-  TF_StringDecode(src, src_len, &dst, &dst_len, status);
-  if (TF_GetCode(status) != TF_OK) {
-    return nullptr;
-  }
+jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const TF_TString* src) {
+  const char* dst = TF_TString_GetDataPointer(src);
+  size_t dst_len = TF_TString_GetSize(src);
 
   jbyteArray ret = env->NewByteArray(dst_len);
   jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
 
   memcpy(cpy, dst, dst_len);
   env->ReleaseByteArrayElements(ret, cpy, 0);
   return ret;
@@ -238,69 +235,32 @@ jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
 class StringTensorWriter {
  public:
   StringTensorWriter(TF_Tensor* t, int num_elements)
-      : offset_(0),
-        poffsets_(static_cast<char*>(TF_TensorData(t))),
-        pdata_(poffsets_ + 8 * num_elements),
-        plimit_(poffsets_ + TF_TensorByteSize(t)) {}
+      : index_(0), data_(static_cast<TF_TString*>(TF_TensorData(t))) {}
 
   void Add(const char* src, size_t len, TF_Status* status) {
     if (TF_GetCode(status) != TF_OK) return;
-    if (plimit_ - poffsets_ < sizeof(offset_)) {
-      TF_SetStatus(status, TF_OUT_OF_RANGE,
-                   "TF_STRING tensor encoding ran out of space for offsets, "
-                   "this is likely a bug, please file an issue at "
-                   "https://github.com/tensorflow/tensorflow/issues/new");
-      return;
-    }
-    memcpy(poffsets_, &offset_, sizeof(offset_));
-    size_t written =
-        TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
-    offset_ += written;
-    poffsets_ += 8;
-    pdata_ += written;
+    TF_TString_Init(&data_[index_]);
+    TF_TString_Copy(&data_[index_++], src, len);
   }
 
  private:
-  uint64_t offset_;
-  char* poffsets_;
-  char* pdata_;
-  const char* plimit_;
+  int index_;
+  TF_TString* data_;
 };
 
 class StringTensorReader {
  public:
   StringTensorReader(const TF_Tensor* t, int num_elements)
-      : index_(0),
-        offsets_(static_cast<const char*>(TF_TensorData(t))),
-        data_(offsets_ + 8 * num_elements),
-        limit_(offsets_ + TF_TensorByteSize(t)) {}
+      : index_(0), data_(static_cast<const TF_TString*>(TF_TensorData(t))) {}
 
   jbyteArray Next(JNIEnv* env, TF_Status* status) {
     if (TF_GetCode(status) != TF_OK) return nullptr;
-    uint64_t offset = 0;
-    const char* poffset = offsets_ + sizeof(offset) * index_;
-    if (poffset >= limit_) {
-      TF_SetStatus(
-          status, TF_INTERNAL,
-          "Invalid TF_STRING tensor, offsets table seems to be too small");
-      return nullptr;
-    }
-    memcpy(&offset, poffset, sizeof(offset));
-    const char* pdata = data_ + offset;
-    if (pdata >= limit_) {
-      TF_SetStatus(status, TF_INTERNAL,
-                   "Invalid TF_STRING tensor, invalid entry in offset table");
-      return nullptr;
-    }
-    ++index_;
-    return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
+    return TF_StringDecodeTojbyteArray(env, &data_[index_++]);
   }
 
  private:
   int index_;
-  const char* offsets_;
-  const char* data_;
-  const char* limit_;
+  const TF_TString* data_;
 };
 
 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
@@ -367,17 +327,16 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
   // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
   // TF_StringEncode-encoded bytes.
   size_t src_len = static_cast<int>(env->GetArrayLength(value));
-  size_t dst_len = TF_StringEncodedSize(src_len);
-  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
-  char* dst = static_cast<char*>(TF_TensorData(t));
-  memset(dst, 0, 8);  // The offset table
+  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, sizeof(TF_TString));
+  TF_TString* dst = static_cast<TF_TString*>(TF_TensorData(t));
 
   TF_Status* status = TF_NewStatus();
   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
-  // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
-  // reinterpret_cast<> for this conversion should be safe.
-  TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
-                  dst_len, status);
+  TF_TString_Init(&dst[0]);
+  TF_TString_Copy(&dst[0], reinterpret_cast<const char*>(jsrc), src_len);
 
   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
@@ -388,27 +347,18 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
 }
 
 namespace {
-size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
-  if (num_dims == 0) {
-    // This is the last dimension, i.e., value should correspond to a jbyteArray
-    // encoding the string.
-    return TF_StringEncodedSize(
-        static_cast<size_t>(env->GetArrayLength(value)));
-  }
+void checkForNullEntries(JNIEnv* env, jarray value, int num_dims) {
   jsize len = env->GetArrayLength(value);
-  size_t ret = 0;
   for (jsize i = 0; i < len; ++i) {
     jarray elem = static_cast<jarray>(
         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
     if (elem == nullptr) {
       throwException(env, kNullPointerException,
                      "null entries in provided array");
-      return ret;
+      return;
     }
-    ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
-    if (env->ExceptionCheck()) return ret;
+    if (env->ExceptionCheck()) return;
   }
-  return ret;
 }
 
 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
@@ -448,11 +398,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
     }
     env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
   }
-  const size_t encoded_size =
-      nonScalarTF_STRINGTensorSize(env, value, num_dims);
+  checkForNullEntries(env, value, num_dims);
   if (env->ExceptionCheck()) return 0;
   TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
-                                   8 * num_elements + encoded_size);
+                                   sizeof(TF_TString) * num_elements);
   if (t == nullptr) {
     delete[] dims;
     throwException(env, kNullPointerException,
@@ -572,20 +521,8 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
                    "Tensor is not a string/bytes scalar");
     return nullptr;
   }
-  const char* data = static_cast<const char*>(TF_TensorData(t));
-  const char* src = data + 8;
-  size_t src_len = TF_TensorByteSize(t) - 8;
-  uint64_t offset = 0;
-  memcpy(&offset, data, sizeof(offset));
-  if (offset >= src_len) {
-    throwException(env, kIllegalArgumentException,
-                   "invalid tensor encoding: bad offsets");
-    return nullptr;
-  }
-  TF_Status* status = TF_NewStatus();
-  jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
-  throwExceptionIfNotOK(env, status);
-  TF_DeleteStatus(status);
+  const TF_TString* data = static_cast<const TF_TString*>(TF_TensorData(t));
+  jbyteArray ret = TF_StringDecodeTojbyteArray(env, &data[0]);
   return ret;
 }
 
@@ -27,7 +27,6 @@ import java.nio.DoubleBuffer;
 import java.nio.FloatBuffer;
 import java.nio.IntBuffer;
 import java.nio.LongBuffer;
-
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -179,17 +178,13 @@ public class ConstantTest {
     byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
     long[] shape = {};
 
-    // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
-    // followed by a varint encoded size, followed by the data.
     ByteArrayOutputStream baout = new ByteArrayOutputStream();
     DataOutputStream out = new DataOutputStream(baout);
-    // Offset in array.
-    out.writeLong(0L);
-    // Varint encoded length of buffer.
-    // For any number < 0x80, the varint encoding is simply the number itself.
-    // https://developers.google.com/protocol-buffers/docs/encoding#varints
-    assertTrue(data.length < 0x80);
-    out.write(data.length);
+    // We construct a TF_TString_Small tstring, which has the capacity for a 22 byte string.
+    // The first 6 most significant bits of the first byte represent length; the remaining
+    // 2-bits are type indicators, and are left as 0b00 to denote a TF_TSTR_SMALL type.
+    assertTrue(data.length <= 22);
+    out.writeByte(data.length << 2);
     out.write(data);
     out.close();
     byte[] content = baout.toByteArray();
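The arithmetic behind the updated test, spelled out: for data.length = 4 the header byte is 4 << 2 = 16 = 0x10, i.e. the length sits in the upper six bits and the low two type bits stay 0b00 for a small tstring. A minimal C check of the same packing (based only on the layout stated in the test's comment):

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
  uint8_t data_length = 4;                       // the test's {1, 2, 3, 4} payload
  uint8_t header = (uint8_t)(data_length << 2);  // == 0x10
  assert((header >> 2) == data_length);          // length from the upper 6 bits
  assert((header & 0x3) == 0);                   // 0b00 type tag: small tstring
  return 0;
}
```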
@@ -245,29 +245,17 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) {
 // the buffer. The caller takes ownership of the buffer.
 Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
                           size_t* size, void** buffer) {
-  // Compute bytes needed for encoding.
-  *size = 0;
-  TF_RETURN_IF_ERROR(
-      PyBytesArrayMap(array, [&size](const char* ptr, Py_ssize_t len) {
-        *size += sizeof(tensorflow::uint64) +
-                 tensorflow::core::VarintLength(len) + len;
-      }));
-  // Encode all strings.
-  std::unique_ptr<char[]> base_ptr(new char[*size]);
-  char* base = base_ptr.get();
-  char* data_start = base + sizeof(tensorflow::uint64) * nelems;
-  char* dst = data_start;  // Where next string is encoded.
-  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
+  *size = nelems * sizeof(tensorflow::tstring);
+  std::unique_ptr<tensorflow::tstring[]> base_ptr(
+      new tensorflow::tstring[nelems]);
+  tensorflow::tstring* dst = base_ptr.get();
 
-  TF_RETURN_IF_ERROR(PyBytesArrayMap(
-      array, [&data_start, &dst, &offsets](const char* ptr, Py_ssize_t len) {
-        *offsets = (dst - data_start);
-        offsets++;
-        dst = tensorflow::core::EncodeVarint64(dst, len);
-        memcpy(dst, ptr, len);
-        dst += len;
+  TF_RETURN_IF_ERROR(
+      PyBytesArrayMap(array, [&dst](const char* ptr, Py_ssize_t len) {
+        dst->assign(ptr, len);
+        dst++;
       }));
-  CHECK_EQ(dst, base + *size);
   *buffer = base_ptr.release();
   return Status::OK();
 }
@@ -275,37 +263,18 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
 Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
                                      PyArrayObject* dst) {
   const void* tensor_data = TF_TensorData(src);
-  const size_t tensor_size = TF_TensorByteSize(src);
-  const char* limit = static_cast<const char*>(tensor_data) + tensor_size;
   DCHECK(tensor_data != nullptr);
   DCHECK_EQ(TF_STRING, TF_TensorType(src));
 
-  const uint64* offsets = static_cast<const uint64*>(tensor_data);
-  const size_t offsets_size = sizeof(uint64) * nelems;
-  const char* data = static_cast<const char*>(tensor_data) + offsets_size;
+  const tstring* tstr = static_cast<const tstring*>(tensor_data);
 
-  const size_t expected_tensor_size =
-      (limit - static_cast<const char*>(tensor_data));
-  if (expected_tensor_size - tensor_size) {
-    return errors::InvalidArgument(
-        "Invalid/corrupt TF_STRING tensor: expected ", expected_tensor_size,
-        " bytes of encoded strings for the tensor containing ", nelems,
-        " strings, but the tensor is encoded in ", tensor_size, " bytes");
-  }
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
   auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
   for (int64 i = 0; i < nelems; ++i) {
-    const char* start = data + offsets[i];
-    const char* ptr = nullptr;
-    size_t len = 0;
-
-    TF_StringDecode(start, limit - start, &ptr, &len, status.get());
-    if (TF_GetCode(status.get()) != TF_OK) {
-      return errors::InvalidArgument(TF_Message(status.get()));
-    }
-
-    auto py_string = make_safe(PyBytes_FromStringAndSize(ptr, len));
+    const tstring& tstr_i = tstr[i];
+    auto py_string =
+        make_safe(PyBytes_FromStringAndSize(tstr_i.data(), tstr_i.size()));
     if (py_string == nullptr) {
       return errors::Internal(
           "failed to create a python byte array when converting element #", i,
@@ -551,14 +520,14 @@ Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
         static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
         encoded, size, convert_string,
         [](void* data, size_t len, void* arg) {
-          delete[] reinterpret_cast<char*>(data);
+          delete[] reinterpret_cast<tensorflow::tstring*>(data);
         },
         nullptr)});
   } else {
     *ret = make_safe(TF_NewTensor(
         dtype, dims.data(), dims.size(), encoded, size,
        [](void* data, size_t len, void* arg) {
-          delete[] reinterpret_cast<char*>(data);
+          delete[] reinterpret_cast<tensorflow::tstring*>(data);
        },
        nullptr));
  }