Migrated tensorflow/core/framework/ to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 262737181
This commit is contained in:
Dero Gharibian 2019-08-10 11:26:24 -07:00 committed by TensorFlower Gardener
parent 5e307a0fbb
commit b4bf76a431
7 changed files with 70 additions and 39 deletions

View File

@ -68,7 +68,12 @@ class SerializationContext;
class IteratorStateReader { class IteratorStateReader {
public: public:
virtual Status ReadScalar(StringPiece key, int64* val) = 0; virtual Status ReadScalar(StringPiece key, int64* val) = 0;
#ifdef USE_TSTRING
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
// migration.
virtual Status ReadScalar(StringPiece key, string* val) = 0; virtual Status ReadScalar(StringPiece key, string* val) = 0;
#endif
virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
virtual bool Contains(StringPiece key) = 0; virtual bool Contains(StringPiece key) = 0;
@ -80,7 +85,12 @@ class IteratorStateReader {
class IteratorStateWriter { class IteratorStateWriter {
public: public:
virtual Status WriteScalar(StringPiece key, const int64 val) = 0; virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
#ifdef USE_TSTRING
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
// migration.
virtual Status WriteScalar(StringPiece key, const string& val) = 0; virtual Status WriteScalar(StringPiece key, const string& val) = 0;
#endif
virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
virtual ~IteratorStateWriter() {} virtual ~IteratorStateWriter() {}
@ -115,7 +125,7 @@ class GraphDefBuilderWrapper {
Status AddVector(const std::vector<T>& val, Node** output) { Status AddVector(const std::vector<T>& val, Node** output) {
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
TensorShape({static_cast<int64>(val.size())})); TensorShape({static_cast<int64>(val.size())}));
for (int i = 0; i < val.size(); i++) { for (size_t i = 0; i < val.size(); i++) {
val_t.flat<T>()(i) = val[i]; val_t.flat<T>()(i) = val[i];
} }
AddTensorInternal(val_t, output); AddTensorInternal(val_t, output);
@ -125,6 +135,23 @@ class GraphDefBuilderWrapper {
return Status::OK(); return Status::OK();
} }
#ifdef USE_TSTRING
// TODO(dero): Temp guard to prevent duplicate declaration during tstring
// migration.
Status AddVector(const std::vector<string>& val, Node** output) {
Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
TensorShape({static_cast<int64>(val.size())}));
for (size_t i = 0; i < val.size(); i++) {
val_t.flat<tstring>()(i) = val[i];
}
AddTensorInternal(val_t, output);
if (*output == nullptr) {
return errors::Internal("AddVector: Failed to build Const op.");
}
return Status::OK();
}
#endif // USE_TSTRING
// Adds a `Const` node for the given tensor value to the graph. // Adds a `Const` node for the given tensor value to the graph.
// //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be // `*output` contains a pointer to the output `Node`. It is guaranteed to be

View File

@ -67,7 +67,8 @@ limitations under the License.
#define TF_CALL_int16(m) m(::tensorflow::int16) #define TF_CALL_int16(m) m(::tensorflow::int16)
#define TF_CALL_int8(m) m(::tensorflow::int8) #define TF_CALL_int8(m) m(::tensorflow::int8)
#define TF_CALL_string(m) m(string) #define TF_CALL_string(m) m(tstring)
#define TF_CALL_tstring(m) m(tstring)
#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle) #define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
#define TF_CALL_variant(m) m(::tensorflow::Variant) #define TF_CALL_variant(m) m(::tensorflow::Variant)
#define TF_CALL_complex64(m) m(::tensorflow::complex64) #define TF_CALL_complex64(m) m(::tensorflow::complex64)
@ -98,7 +99,8 @@ limitations under the License.
#define TF_CALL_int16(m) #define TF_CALL_int16(m)
#define TF_CALL_int8(m) #define TF_CALL_int8(m)
#define TF_CALL_string(m) m(string) #define TF_CALL_string(m) m(tstring)
#define TF_CALL_tstring(m) m(tstring)
#define TF_CALL_resource(m) #define TF_CALL_resource(m)
#define TF_CALL_variant(m) #define TF_CALL_variant(m)
#define TF_CALL_complex64(m) #define TF_CALL_complex64(m)
@ -129,6 +131,7 @@ limitations under the License.
#define TF_CALL_int8(m) #define TF_CALL_int8(m)
#define TF_CALL_string(m) #define TF_CALL_string(m)
#define TF_CALL_tstring(m)
#define TF_CALL_resource(m) #define TF_CALL_resource(m)
#define TF_CALL_variant(m) #define TF_CALL_variant(m)
#define TF_CALL_complex64(m) #define TF_CALL_complex64(m)
@ -188,10 +191,10 @@ limitations under the License.
// Call "m" on all types. // Call "m" on all types.
#define TF_CALL_ALL_TYPES(m) \ #define TF_CALL_ALL_TYPES(m) \
TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m) TF_CALL_POD_TYPES(m) TF_CALL_tstring(m) TF_CALL_resource(m) TF_CALL_variant(m)
// Call "m" on POD and string types. // Call "m" on POD and string types.
#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_tstring(m)
// Call "m" on all number types supported on GPU. // Call "m" on all number types supported on GPU.
#define TF_CALL_GPU_NUMBER_TYPES(m) \ #define TF_CALL_GPU_NUMBER_TYPES(m) \
@ -213,7 +216,7 @@ limitations under the License.
#define TF_CALL_SAVE_RESTORE_TYPES(m) \ #define TF_CALL_SAVE_RESTORE_TYPES(m) \
TF_CALL_INTEGRAL_TYPES(m) \ TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \ TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m) \ TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_tstring(m) \
TF_CALL_QUANTIZED_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
#ifdef TENSORFLOW_SYCL_NO_DOUBLE #ifdef TENSORFLOW_SYCL_NO_DOUBLE

View File

@ -168,7 +168,7 @@ struct Helper {
// Helper specialization for string (the only non-simple type we // Helper specialization for string (the only non-simple type we
// support). // support).
template <> template <>
struct Helper<string> { struct Helper<tstring> {
// Proto message uses RepeatedFieldType to hold repeated T. // Proto message uses RepeatedFieldType to hold repeated T.
typedef protobuf::RepeatedPtrField<string> RepeatedFieldType; typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
@ -176,7 +176,7 @@ struct Helper<string> {
// "out", which is usually the TensorProto::tensor_content. // "out", which is usually the TensorProto::tensor_content.
template <typename Destination> template <typename Destination>
static void Encode(TensorBuffer* in, int64 n, Destination* out) { static void Encode(TensorBuffer* in, int64 n, Destination* out) {
port::EncodeStringList(in->base<const string>(), n, out); port::EncodeStringList(in->base<const tstring>(), n, out);
} }
// Decodes "n" elements of type string from "in" and constructs a // Decodes "n" elements of type string from "in" and constructs a
@ -184,8 +184,8 @@ struct Helper<string> {
// usually the TensorProto::tensor_content. // usually the TensorProto::tensor_content.
template <typename Source> template <typename Source>
static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
Buffer<string>* buf = new Buffer<string>(a, n); Buffer<tstring>* buf = new Buffer<tstring>(a, n);
string* strings = buf->template base<string>(); tstring* strings = buf->template base<tstring>();
if (strings == nullptr || !port::DecodeStringList(in, strings, n)) { if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
buf->Unref(); buf->Unref();
return nullptr; return nullptr;
@ -197,8 +197,8 @@ struct Helper<string> {
// stored in buffer "in". // stored in buffer "in".
static int64 TotalBytes(TensorBuffer* in, int n) { static int64 TotalBytes(TensorBuffer* in, int n) {
int64 tot = in->size(); int64 tot = in->size();
DCHECK_EQ(tot, sizeof(string) * n); DCHECK_EQ(tot, sizeof(tstring) * n);
const string* p = in->base<const string>(); const tstring* p = in->base<const tstring>();
for (int i = 0; i < n; ++i, ++p) tot += p->size(); for (int i = 0; i < n; ++i, ++p) tot += p->size();
return tot; return tot;
} }
@ -302,7 +302,7 @@ PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int); PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int); PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool); PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string); PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int); PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int); PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int); PROTO_TRAITS(qint16, int32, int);
@ -713,7 +713,7 @@ bool Tensor::RefCountIsOne() const {
CASE(uint64, SINGLE_ARG(STMTS)) \ CASE(uint64, SINGLE_ARG(STMTS)) \
CASE(int16, SINGLE_ARG(STMTS)) \ CASE(int16, SINGLE_ARG(STMTS)) \
CASE(int8, SINGLE_ARG(STMTS)) \ CASE(int8, SINGLE_ARG(STMTS)) \
CASE(string, SINGLE_ARG(STMTS)) \ CASE(tstring, SINGLE_ARG(STMTS)) \
CASE(complex64, SINGLE_ARG(STMTS)) \ CASE(complex64, SINGLE_ARG(STMTS)) \
CASE(complex128, SINGLE_ARG(STMTS)) \ CASE(complex128, SINGLE_ARG(STMTS)) \
CASE(int64, SINGLE_ARG(STMTS)) \ CASE(int64, SINGLE_ARG(STMTS)) \
@ -968,7 +968,7 @@ inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
bool print_v2) { bool print_v2) {
return a; return a;
} }
inline string PrintOneElement(const string& a, bool print_v2) { inline string PrintOneElement(const tstring& a, bool print_v2) {
if (print_v2) { if (print_v2) {
return "\"" + absl::CEscape(a) + "\""; return "\"" + absl::CEscape(a) + "\"";
} else { } else {
@ -1164,7 +1164,7 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2); return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break; break;
case DT_STRING: case DT_STRING:
return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2); return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
break; break;
default: { default: {
// All irregular cases // All irregular cases

View File

@ -150,7 +150,7 @@ class Tensor {
: Tensor(scalar_value, host_scalar_tag{}) {} : Tensor(scalar_value, host_scalar_tag{}) {}
explicit Tensor(int8 scalar_value) explicit Tensor(int8 scalar_value)
: Tensor(scalar_value, host_scalar_tag{}) {} : Tensor(scalar_value, host_scalar_tag{}) {}
explicit Tensor(string scalar_value) explicit Tensor(tstring scalar_value)
: Tensor(std::move(scalar_value), host_scalar_tag{}) {} : Tensor(std::move(scalar_value), host_scalar_tag{}) {}
explicit Tensor(complex64 scalar_value) explicit Tensor(complex64 scalar_value)
: Tensor(scalar_value, host_scalar_tag{}) {} : Tensor(scalar_value, host_scalar_tag{}) {}
@ -183,7 +183,7 @@ class Tensor {
// convenience because otherwise passing a string literal would surprisingly // convenience because otherwise passing a string literal would surprisingly
// construct a DT_BOOL tensor. // construct a DT_BOOL tensor.
explicit Tensor(const char* scalar_value) explicit Tensor(const char* scalar_value)
: Tensor(string(scalar_value), host_scalar_tag{}) {} : Tensor(tstring(scalar_value), host_scalar_tag{}) {}
/// Copy constructor. /// Copy constructor.
Tensor(const Tensor& other); Tensor(const Tensor& other);

View File

@ -94,6 +94,7 @@ TEST(TensorTest, DataType_Traits) {
EXPECT_TRUE(std::is_trivial<int8>::value); EXPECT_TRUE(std::is_trivial<int8>::value);
EXPECT_TRUE(std::is_trivial<int64>::value); EXPECT_TRUE(std::is_trivial<int64>::value);
EXPECT_TRUE(std::is_trivial<bool>::value); EXPECT_TRUE(std::is_trivial<bool>::value);
EXPECT_FALSE(std::is_trivial<tstring>::value);
EXPECT_FALSE(std::is_trivial<string>::value); EXPECT_FALSE(std::is_trivial<string>::value);
EXPECT_EQ(sizeof(bool), 1); EXPECT_EQ(sizeof(bool), 1);
@ -903,15 +904,15 @@ TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
} }
TEST(Tensor_String, Simple) { TEST(Tensor_String, Simple) {
Tensor t = test::AsTensor<string>( Tensor t = test::AsTensor<tstring>(
{"hello", "world", "machine", "learning", "new", "york"}, {"hello", "world", "machine", "learning", "new", "york"},
TensorShape({3, 2})); TensorShape({3, 2}));
auto s = t.shape(); auto s = t.shape();
ASSERT_EQ(s.dims(), 2); ASSERT_EQ(s.dims(), 2);
ASSERT_EQ(s.dim_size(0), 3); ASSERT_EQ(s.dim_size(0), 3);
ASSERT_EQ(s.dim_size(1), 2); ASSERT_EQ(s.dim_size(1), 2);
auto m = t.matrix<string>(); auto m = t.matrix<tstring>();
EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4); EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(tstring) + 5 + 5 + 7 + 8 + 3 + 4);
EXPECT_EQ(m(0, 0), "hello"); EXPECT_EQ(m(0, 0), "hello");
EXPECT_EQ(m(0, 1), "world"); EXPECT_EQ(m(0, 1), "world");
@ -920,7 +921,7 @@ TEST(Tensor_String, Simple) {
EXPECT_EQ(m(2, 0), "new"); EXPECT_EQ(m(2, 0), "new");
EXPECT_EQ(m(2, 1), "york"); EXPECT_EQ(m(2, 1), "york");
TestCopies<string>(t); TestCopies<tstring>(t);
} }
TEST(Tensor_Float, SimpleWithHelper) { TEST(Tensor_Float, SimpleWithHelper) {
@ -976,7 +977,7 @@ TEST(Tensor_Int64, SimpleWithHelper) {
} }
TEST(Tensor_String, SimpleWithHelper) { TEST(Tensor_String, SimpleWithHelper) {
Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3}); Tensor t1 = test::AsTensor<tstring>({"0", "1", "2", "3", "4", "5"}, {2, 3});
Tensor t2(DT_STRING, {2, 3}); Tensor t2(DT_STRING, {2, 3});
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) { for (int j = 0; j < 3; ++j) {
@ -985,7 +986,7 @@ TEST(Tensor_String, SimpleWithHelper) {
} }
// Test with helper. // Test with helper.
test::ExpectTensorEqual<string>(t1, t2); test::ExpectTensorEqual<tstring>(t1, t2);
} }
TEST(Tensor_Bool, SimpleWithHelper) { TEST(Tensor_Bool, SimpleWithHelper) {
@ -1365,11 +1366,11 @@ TEST(SummarizeValue, BOOL) {
} }
TEST(SummarizeValue, STRING) { TEST(SummarizeValue, STRING) {
Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}), Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
{"one", "two", "three", "four", "five"}); {"one", "two", "three", "four", "five"});
EXPECT_EQ("one two three four five", x.SummarizeValue(16)); EXPECT_EQ("one two three four five", x.SummarizeValue(16));
x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}), x = MkTensor<tstring>(DT_STRING, TensorShape({5, 1, 5}),
{"one", "two", "three", "four", "five"}); {"one", "two", "three", "four", "five"});
EXPECT_EQ("[[one two three four five]][[one...]]...", x.SummarizeValue(6)); EXPECT_EQ("[[one two three four five]][[one...]]...", x.SummarizeValue(6));
} }
@ -1421,16 +1422,16 @@ TEST(SummarizeValue, BOOL_PRINT_V2) {
} }
TEST(SummarizeValue, STRING_PRINT_V2) { TEST(SummarizeValue, STRING_PRINT_V2) {
Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}), Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
{"one", "two", "three", "four", "five"}); {"one", "two", "three", "four", "five"});
EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]", EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]",
x.SummarizeValue(16, true)); x.SummarizeValue(16, true));
EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]", EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]",
x.SummarizeValue(-1, true)); x.SummarizeValue(-1, true));
EXPECT_EQ("[\"one\" \"two\" ... \"four\" \"five\"]", EXPECT_EQ("[\"one\" \"two\" ... \"four\" \"five\"]",
x.SummarizeValue(2, true)); x.SummarizeValue(2, true));
x = MkTensor<string>(DT_STRING, TensorShape({2, 2}), x = MkTensor<tstring>(DT_STRING, TensorShape({2, 2}),
{"one", "two", "three", "four", "five"}); {"one", "two", "three", "four", "five"});
EXPECT_EQ("[[\"one\" \"two\"]\n [\"three\" \"four\"]]", EXPECT_EQ("[[\"one\" \"two\"]\n [\"three\" \"four\"]]",
x.SummarizeValue(16, true)); x.SummarizeValue(16, true));
} }

View File

@ -98,8 +98,8 @@ Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
if (dtype != DT_STRING) { if (dtype != DT_STRING) {
return errors::Internal("Unexpected data type"); return errors::Internal("Unexpected data type");
} }
string* to_strings = tstring* to_strings =
reinterpret_cast<string*>(const_cast<char*>(to_data.data())); reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));
int64 offset = 0; int64 offset = 0;
for (const Tensor& tensor : tensors) { for (const Tensor& tensor : tensors) {
@ -163,7 +163,7 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
shape.set_dim(0, size); shape.set_dim(0, size);
result->emplace_back(tensor.dtype(), shape); result->emplace_back(tensor.dtype(), shape);
Tensor& split = (*result)[result->size() - 1]; Tensor& split = (*result)[result->size() - 1];
string* to_strings = reinterpret_cast<string*>( tstring* to_strings = reinterpret_cast<tstring*>(
const_cast<char*>(split.tensor_data().data())); const_cast<char*>(split.tensor_data().data()));
CHECK_LE(offset + split.NumElements(), tensor.NumElements()); CHECK_LE(offset + split.NumElements(), tensor.NumElements());

View File

@ -77,19 +77,19 @@ class TypedAllocator {
template <> template <>
/* static */ /* static */
inline void TypedAllocator::RunCtor(Allocator* raw_allocator, string* p, inline void TypedAllocator::RunCtor(Allocator* raw_allocator, tstring* p,
size_t n) { size_t n) {
if (!raw_allocator->AllocatesOpaqueHandle()) { if (!raw_allocator->AllocatesOpaqueHandle()) {
for (size_t i = 0; i < n; ++p, ++i) new (p) string(); for (size_t i = 0; i < n; ++p, ++i) new (p) tstring();
} }
} }
template <> template <>
/* static */ /* static */
inline void TypedAllocator::RunDtor(Allocator* raw_allocator, string* p, inline void TypedAllocator::RunDtor(Allocator* raw_allocator, tstring* p,
size_t n) { size_t n) {
if (!raw_allocator->AllocatesOpaqueHandle()) { if (!raw_allocator->AllocatesOpaqueHandle()) {
for (size_t i = 0; i < n; ++p, ++i) p->~string(); for (size_t i = 0; i < n; ++p, ++i) p->~tstring();
} }
} }