Migrated tensorflow/core/framework/ to use tstring.
This is a part of a larger migration effort for tensorflow::tstring. See: https://github.com/tensorflow/community/pull/91 PiperOrigin-RevId: 262737181
This commit is contained in:
parent
5e307a0fbb
commit
b4bf76a431
@ -68,7 +68,12 @@ class SerializationContext;
|
||||
class IteratorStateReader {
|
||||
public:
|
||||
virtual Status ReadScalar(StringPiece key, int64* val) = 0;
|
||||
#ifdef USE_TSTRING
|
||||
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
|
||||
// migration.
|
||||
virtual Status ReadScalar(StringPiece key, string* val) = 0;
|
||||
#endif
|
||||
virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
|
||||
virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
|
||||
virtual bool Contains(StringPiece key) = 0;
|
||||
|
||||
@ -80,7 +85,12 @@ class IteratorStateReader {
|
||||
class IteratorStateWriter {
|
||||
public:
|
||||
virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
|
||||
#ifdef USE_TSTRING
|
||||
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
|
||||
// migration.
|
||||
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
|
||||
#endif
|
||||
virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
|
||||
virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
|
||||
|
||||
virtual ~IteratorStateWriter() {}
|
||||
@ -115,7 +125,7 @@ class GraphDefBuilderWrapper {
|
||||
Status AddVector(const std::vector<T>& val, Node** output) {
|
||||
Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
|
||||
TensorShape({static_cast<int64>(val.size())}));
|
||||
for (int i = 0; i < val.size(); i++) {
|
||||
for (size_t i = 0; i < val.size(); i++) {
|
||||
val_t.flat<T>()(i) = val[i];
|
||||
}
|
||||
AddTensorInternal(val_t, output);
|
||||
@ -125,6 +135,23 @@ class GraphDefBuilderWrapper {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
|
||||
// migration.
|
||||
Status AddVector(const std::vector<string>& val, Node** output) {
|
||||
Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
|
||||
TensorShape({static_cast<int64>(val.size())}));
|
||||
for (size_t i = 0; i < val.size(); i++) {
|
||||
val_t.flat<tstring>()(i) = val[i];
|
||||
}
|
||||
AddTensorInternal(val_t, output);
|
||||
if (*output == nullptr) {
|
||||
return errors::Internal("AddVector: Failed to build Const op.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // USE_TSTRING
|
||||
|
||||
// Adds a `Const` node for the given tensor value to the graph.
|
||||
//
|
||||
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
|
||||
|
@ -67,7 +67,8 @@ limitations under the License.
|
||||
#define TF_CALL_int16(m) m(::tensorflow::int16)
|
||||
|
||||
#define TF_CALL_int8(m) m(::tensorflow::int8)
|
||||
#define TF_CALL_string(m) m(string)
|
||||
#define TF_CALL_string(m) m(tstring)
|
||||
#define TF_CALL_tstring(m) m(tstring)
|
||||
#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
|
||||
#define TF_CALL_variant(m) m(::tensorflow::Variant)
|
||||
#define TF_CALL_complex64(m) m(::tensorflow::complex64)
|
||||
@ -98,7 +99,8 @@ limitations under the License.
|
||||
#define TF_CALL_int16(m)
|
||||
|
||||
#define TF_CALL_int8(m)
|
||||
#define TF_CALL_string(m) m(string)
|
||||
#define TF_CALL_string(m) m(tstring)
|
||||
#define TF_CALL_tstring(m) m(tstring)
|
||||
#define TF_CALL_resource(m)
|
||||
#define TF_CALL_variant(m)
|
||||
#define TF_CALL_complex64(m)
|
||||
@ -129,6 +131,7 @@ limitations under the License.
|
||||
|
||||
#define TF_CALL_int8(m)
|
||||
#define TF_CALL_string(m)
|
||||
#define TF_CALL_tstring(m)
|
||||
#define TF_CALL_resource(m)
|
||||
#define TF_CALL_variant(m)
|
||||
#define TF_CALL_complex64(m)
|
||||
@ -188,10 +191,10 @@ limitations under the License.
|
||||
|
||||
// Call "m" on all types.
|
||||
#define TF_CALL_ALL_TYPES(m) \
|
||||
TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m)
|
||||
TF_CALL_POD_TYPES(m) TF_CALL_tstring(m) TF_CALL_resource(m) TF_CALL_variant(m)
|
||||
|
||||
// Call "m" on POD and string types.
|
||||
#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
|
||||
#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_tstring(m)
|
||||
|
||||
// Call "m" on all number types supported on GPU.
|
||||
#define TF_CALL_GPU_NUMBER_TYPES(m) \
|
||||
@ -213,7 +216,7 @@ limitations under the License.
|
||||
#define TF_CALL_SAVE_RESTORE_TYPES(m) \
|
||||
TF_CALL_INTEGRAL_TYPES(m) \
|
||||
TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
|
||||
TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m) \
|
||||
TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_tstring(m) \
|
||||
TF_CALL_QUANTIZED_TYPES(m)
|
||||
|
||||
#ifdef TENSORFLOW_SYCL_NO_DOUBLE
|
||||
|
@ -168,7 +168,7 @@ struct Helper {
|
||||
// Helper specialization for string (the only non-simple type we
|
||||
// support).
|
||||
template <>
|
||||
struct Helper<string> {
|
||||
struct Helper<tstring> {
|
||||
// Proto message uses RepeatedFieldType to hold repeated T.
|
||||
typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
|
||||
|
||||
@ -176,7 +176,7 @@ struct Helper<string> {
|
||||
// "out", which is usually the TensorProto::tensor_content.
|
||||
template <typename Destination>
|
||||
static void Encode(TensorBuffer* in, int64 n, Destination* out) {
|
||||
port::EncodeStringList(in->base<const string>(), n, out);
|
||||
port::EncodeStringList(in->base<const tstring>(), n, out);
|
||||
}
|
||||
|
||||
// Decodes "n" elements of type string from "in" and constructs a
|
||||
@ -184,8 +184,8 @@ struct Helper<string> {
|
||||
// usually the TensorProto::tensor_content.
|
||||
template <typename Source>
|
||||
static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
|
||||
Buffer<string>* buf = new Buffer<string>(a, n);
|
||||
string* strings = buf->template base<string>();
|
||||
Buffer<tstring>* buf = new Buffer<tstring>(a, n);
|
||||
tstring* strings = buf->template base<tstring>();
|
||||
if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
|
||||
buf->Unref();
|
||||
return nullptr;
|
||||
@ -197,8 +197,8 @@ struct Helper<string> {
|
||||
// stored in buffer "in".
|
||||
static int64 TotalBytes(TensorBuffer* in, int n) {
|
||||
int64 tot = in->size();
|
||||
DCHECK_EQ(tot, sizeof(string) * n);
|
||||
const string* p = in->base<const string>();
|
||||
DCHECK_EQ(tot, sizeof(tstring) * n);
|
||||
const tstring* p = in->base<const tstring>();
|
||||
for (int i = 0; i < n; ++i, ++p) tot += p->size();
|
||||
return tot;
|
||||
}
|
||||
@ -302,7 +302,7 @@ PROTO_TRAITS(uint32, uint32, uint32);
|
||||
PROTO_TRAITS(int16, int32, int);
|
||||
PROTO_TRAITS(int8, int32, int);
|
||||
PROTO_TRAITS(bool, bool, bool);
|
||||
PROTO_TRAITS(string, string, string);
|
||||
PROTO_TRAITS(tstring, tstring, string);
|
||||
PROTO_TRAITS(qint8, int32, int);
|
||||
PROTO_TRAITS(quint8, int32, int);
|
||||
PROTO_TRAITS(qint16, int32, int);
|
||||
@ -713,7 +713,7 @@ bool Tensor::RefCountIsOne() const {
|
||||
CASE(uint64, SINGLE_ARG(STMTS)) \
|
||||
CASE(int16, SINGLE_ARG(STMTS)) \
|
||||
CASE(int8, SINGLE_ARG(STMTS)) \
|
||||
CASE(string, SINGLE_ARG(STMTS)) \
|
||||
CASE(tstring, SINGLE_ARG(STMTS)) \
|
||||
CASE(complex64, SINGLE_ARG(STMTS)) \
|
||||
CASE(complex128, SINGLE_ARG(STMTS)) \
|
||||
CASE(int64, SINGLE_ARG(STMTS)) \
|
||||
@ -968,7 +968,7 @@ inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
|
||||
bool print_v2) {
|
||||
return a;
|
||||
}
|
||||
inline string PrintOneElement(const string& a, bool print_v2) {
|
||||
inline string PrintOneElement(const tstring& a, bool print_v2) {
|
||||
if (print_v2) {
|
||||
return "\"" + absl::CEscape(a) + "\"";
|
||||
} else {
|
||||
@ -1164,7 +1164,7 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
|
||||
return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
|
||||
break;
|
||||
case DT_STRING:
|
||||
return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2);
|
||||
return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
|
||||
break;
|
||||
default: {
|
||||
// All irregular cases
|
||||
|
@ -150,7 +150,7 @@ class Tensor {
|
||||
: Tensor(scalar_value, host_scalar_tag{}) {}
|
||||
explicit Tensor(int8 scalar_value)
|
||||
: Tensor(scalar_value, host_scalar_tag{}) {}
|
||||
explicit Tensor(string scalar_value)
|
||||
explicit Tensor(tstring scalar_value)
|
||||
: Tensor(std::move(scalar_value), host_scalar_tag{}) {}
|
||||
explicit Tensor(complex64 scalar_value)
|
||||
: Tensor(scalar_value, host_scalar_tag{}) {}
|
||||
@ -183,7 +183,7 @@ class Tensor {
|
||||
// convenience because otherwise passing a string literal would surprisingly
|
||||
// construct a DT_BOOL tensor.
|
||||
explicit Tensor(const char* scalar_value)
|
||||
: Tensor(string(scalar_value), host_scalar_tag{}) {}
|
||||
: Tensor(tstring(scalar_value), host_scalar_tag{}) {}
|
||||
|
||||
/// Copy constructor.
|
||||
Tensor(const Tensor& other);
|
||||
|
@ -94,6 +94,7 @@ TEST(TensorTest, DataType_Traits) {
|
||||
EXPECT_TRUE(std::is_trivial<int8>::value);
|
||||
EXPECT_TRUE(std::is_trivial<int64>::value);
|
||||
EXPECT_TRUE(std::is_trivial<bool>::value);
|
||||
EXPECT_FALSE(std::is_trivial<tstring>::value);
|
||||
EXPECT_FALSE(std::is_trivial<string>::value);
|
||||
|
||||
EXPECT_EQ(sizeof(bool), 1);
|
||||
@ -903,15 +904,15 @@ TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
|
||||
}
|
||||
|
||||
TEST(Tensor_String, Simple) {
|
||||
Tensor t = test::AsTensor<string>(
|
||||
Tensor t = test::AsTensor<tstring>(
|
||||
{"hello", "world", "machine", "learning", "new", "york"},
|
||||
TensorShape({3, 2}));
|
||||
auto s = t.shape();
|
||||
ASSERT_EQ(s.dims(), 2);
|
||||
ASSERT_EQ(s.dim_size(0), 3);
|
||||
ASSERT_EQ(s.dim_size(1), 2);
|
||||
auto m = t.matrix<string>();
|
||||
EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4);
|
||||
auto m = t.matrix<tstring>();
|
||||
EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(tstring) + 5 + 5 + 7 + 8 + 3 + 4);
|
||||
|
||||
EXPECT_EQ(m(0, 0), "hello");
|
||||
EXPECT_EQ(m(0, 1), "world");
|
||||
@ -920,7 +921,7 @@ TEST(Tensor_String, Simple) {
|
||||
EXPECT_EQ(m(2, 0), "new");
|
||||
EXPECT_EQ(m(2, 1), "york");
|
||||
|
||||
TestCopies<string>(t);
|
||||
TestCopies<tstring>(t);
|
||||
}
|
||||
|
||||
TEST(Tensor_Float, SimpleWithHelper) {
|
||||
@ -976,7 +977,7 @@ TEST(Tensor_Int64, SimpleWithHelper) {
|
||||
}
|
||||
|
||||
TEST(Tensor_String, SimpleWithHelper) {
|
||||
Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3});
|
||||
Tensor t1 = test::AsTensor<tstring>({"0", "1", "2", "3", "4", "5"}, {2, 3});
|
||||
Tensor t2(DT_STRING, {2, 3});
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
@ -985,7 +986,7 @@ TEST(Tensor_String, SimpleWithHelper) {
|
||||
}
|
||||
|
||||
// Test with helper.
|
||||
test::ExpectTensorEqual<string>(t1, t2);
|
||||
test::ExpectTensorEqual<tstring>(t1, t2);
|
||||
}
|
||||
|
||||
TEST(Tensor_Bool, SimpleWithHelper) {
|
||||
@ -1365,10 +1366,10 @@ TEST(SummarizeValue, BOOL) {
|
||||
}
|
||||
|
||||
TEST(SummarizeValue, STRING) {
|
||||
Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
|
||||
Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
|
||||
{"one", "two", "three", "four", "five"});
|
||||
EXPECT_EQ("one two three four five", x.SummarizeValue(16));
|
||||
x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
|
||||
x = MkTensor<tstring>(DT_STRING, TensorShape({5, 1, 5}),
|
||||
{"one", "two", "three", "four", "five"});
|
||||
EXPECT_EQ("[[one two three four five]][[one...]]...", x.SummarizeValue(6));
|
||||
}
|
||||
@ -1421,7 +1422,7 @@ TEST(SummarizeValue, BOOL_PRINT_V2) {
|
||||
}
|
||||
|
||||
TEST(SummarizeValue, STRING_PRINT_V2) {
|
||||
Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
|
||||
Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
|
||||
{"one", "two", "three", "four", "five"});
|
||||
EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]",
|
||||
x.SummarizeValue(16, true));
|
||||
@ -1429,7 +1430,7 @@ TEST(SummarizeValue, STRING_PRINT_V2) {
|
||||
x.SummarizeValue(-1, true));
|
||||
EXPECT_EQ("[\"one\" \"two\" ... \"four\" \"five\"]",
|
||||
x.SummarizeValue(2, true));
|
||||
x = MkTensor<string>(DT_STRING, TensorShape({2, 2}),
|
||||
x = MkTensor<tstring>(DT_STRING, TensorShape({2, 2}),
|
||||
{"one", "two", "three", "four", "five"});
|
||||
EXPECT_EQ("[[\"one\" \"two\"]\n [\"three\" \"four\"]]",
|
||||
x.SummarizeValue(16, true));
|
||||
|
@ -98,8 +98,8 @@ Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
|
||||
if (dtype != DT_STRING) {
|
||||
return errors::Internal("Unexpected data type");
|
||||
}
|
||||
string* to_strings =
|
||||
reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
|
||||
tstring* to_strings =
|
||||
reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));
|
||||
|
||||
int64 offset = 0;
|
||||
for (const Tensor& tensor : tensors) {
|
||||
@ -163,7 +163,7 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
|
||||
shape.set_dim(0, size);
|
||||
result->emplace_back(tensor.dtype(), shape);
|
||||
Tensor& split = (*result)[result->size() - 1];
|
||||
string* to_strings = reinterpret_cast<string*>(
|
||||
tstring* to_strings = reinterpret_cast<tstring*>(
|
||||
const_cast<char*>(split.tensor_data().data()));
|
||||
|
||||
CHECK_LE(offset + split.NumElements(), tensor.NumElements());
|
||||
|
@ -77,19 +77,19 @@ class TypedAllocator {
|
||||
|
||||
template <>
|
||||
/* static */
|
||||
inline void TypedAllocator::RunCtor(Allocator* raw_allocator, string* p,
|
||||
inline void TypedAllocator::RunCtor(Allocator* raw_allocator, tstring* p,
|
||||
size_t n) {
|
||||
if (!raw_allocator->AllocatesOpaqueHandle()) {
|
||||
for (size_t i = 0; i < n; ++p, ++i) new (p) string();
|
||||
for (size_t i = 0; i < n; ++p, ++i) new (p) tstring();
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
/* static */
|
||||
inline void TypedAllocator::RunDtor(Allocator* raw_allocator, string* p,
|
||||
inline void TypedAllocator::RunDtor(Allocator* raw_allocator, tstring* p,
|
||||
size_t n) {
|
||||
if (!raw_allocator->AllocatesOpaqueHandle()) {
|
||||
for (size_t i = 0; i < n; ++p, ++i) p->~string();
|
||||
for (size_t i = 0; i < n; ++p, ++i) p->~tstring();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user