Migrated tensorflow/core/framework/ to use tstring.

This is part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 262737181
Author: Dero Gharibian, 2019-08-10 11:26:24 -07:00 (committed by TensorFlower Gardener)
parent 5e307a0fbb
commit b4bf76a431
7 changed files with 70 additions and 39 deletions
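
For context on the mechanics: tstring is meant to be a near drop-in replacement for std::string in tensor-facing APIs, so most call sites in this commit change only the type name. A minimal sketch of a migrated call site (the header path and surrounding API are assumed from the TF tree of this era, not shown in the diff):

#include "tensorflow/core/framework/tensor.h"

void FillGreeting(tensorflow::Tensor* t) {
  // Before this commit: auto flat = t->flat<std::string>();
  auto flat = t->flat<tensorflow::tstring>();  // after the migration
  flat(0) = "hello";  // tstring keeps a std::string-like interface, so
  flat(1) = "world";  // assignment from a string literal still works
}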

View File

@@ -68,7 +68,12 @@ class SerializationContext;
 class IteratorStateReader {
  public:
   virtual Status ReadScalar(StringPiece key, int64* val) = 0;
+#ifdef USE_TSTRING
+  // TODO(dero): Temp guard to prevent duplicate declaration during tstring
+  // migration.
   virtual Status ReadScalar(StringPiece key, string* val) = 0;
+#endif
+  virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
   virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
   virtual bool Contains(StringPiece key) = 0;
@@ -80,7 +85,12 @@ class IteratorStateReader {
 class IteratorStateWriter {
  public:
   virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
+#ifdef USE_TSTRING
+  // TODO(dero): Temp guard to prevent duplicate declaration during tstring
+  // migration.
   virtual Status WriteScalar(StringPiece key, const string& val) = 0;
+#endif
+  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
   virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
   virtual ~IteratorStateWriter() {}
@@ -115,7 +125,7 @@ class GraphDefBuilderWrapper {
   Status AddVector(const std::vector<T>& val, Node** output) {
     Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                           TensorShape({static_cast<int64>(val.size())}));
-    for (int i = 0; i < val.size(); i++) {
+    for (size_t i = 0; i < val.size(); i++) {
       val_t.flat<T>()(i) = val[i];
     }
     AddTensorInternal(val_t, output);
@@ -125,6 +135,23 @@ class GraphDefBuilderWrapper {
     return Status::OK();
   }
+#ifdef USE_TSTRING
+  // TODO(dero): Temp guard to prevent duplicate declaration during tstring
+  // migration.
+  Status AddVector(const std::vector<string>& val, Node** output) {
+    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
+                          TensorShape({static_cast<int64>(val.size())}));
+    for (size_t i = 0; i < val.size(); i++) {
+      val_t.flat<tstring>()(i) = val[i];
+    }
+    AddTensorInternal(val_t, output);
+    if (*output == nullptr) {
+      return errors::Internal("AddVector: Failed to build Const op.");
+    }
+    return Status::OK();
+  }
+#endif  // USE_TSTRING
 
   // Adds a `Const` node for the given tensor value to the graph.
   //
   // `*output` contains a pointer to the output `Node`. It is guaranteed to be

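A note on the temporary guard above: while the migration is in flight, tstring is presumably still an alias for std::string unless USE_TSTRING is defined (that is what the TODOs suggest), so an unguarded string overload next to the tstring one would declare the same member twice. A standalone sketch with toy types, not the real TF headers:

#include <string>

using tstring = std::string;  // assumed pre-USE_TSTRING state

struct Reader {
  void ReadScalar(tstring* val) {}
  // void ReadScalar(std::string* val) {}  // error: tstring and std::string
  //                                       // are one type here, so this would
  //                                       // redeclare the member above
};

struct tstring_class {};  // stand-in for the post-USE_TSTRING distinct class

struct ReaderLater {
  void ReadScalar(std::string* val) {}    // the overload kept behind the guard
  void ReadScalar(tstring_class* val) {}  // fine: genuinely different types
};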
View File

@@ -67,7 +67,8 @@ limitations under the License.
 #define TF_CALL_int16(m) m(::tensorflow::int16)
 #define TF_CALL_int8(m) m(::tensorflow::int8)
-#define TF_CALL_string(m) m(string)
+#define TF_CALL_string(m) m(tstring)
+#define TF_CALL_tstring(m) m(tstring)
 #define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
 #define TF_CALL_variant(m) m(::tensorflow::Variant)
 #define TF_CALL_complex64(m) m(::tensorflow::complex64)
@@ -98,7 +99,8 @@ limitations under the License.
 #define TF_CALL_int16(m)
 #define TF_CALL_int8(m)
-#define TF_CALL_string(m) m(string)
+#define TF_CALL_string(m) m(tstring)
+#define TF_CALL_tstring(m) m(tstring)
 #define TF_CALL_resource(m)
 #define TF_CALL_variant(m)
 #define TF_CALL_complex64(m)
@@ -129,6 +131,7 @@ limitations under the License.
 #define TF_CALL_int8(m)
 #define TF_CALL_string(m)
+#define TF_CALL_tstring(m)
 #define TF_CALL_resource(m)
 #define TF_CALL_variant(m)
 #define TF_CALL_complex64(m)
@@ -188,10 +191,10 @@ limitations under the License.
 // Call "m" on all types.
 #define TF_CALL_ALL_TYPES(m) \
-  TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m)
+  TF_CALL_POD_TYPES(m) TF_CALL_tstring(m) TF_CALL_resource(m) TF_CALL_variant(m)
 
 // Call "m" on POD and string types.
-#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
+#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_tstring(m)
 
 // Call "m" on all number types supported on GPU.
 #define TF_CALL_GPU_NUMBER_TYPES(m) \
@@ -213,7 +216,7 @@ limitations under the License.
 #define TF_CALL_SAVE_RESTORE_TYPES(m)                                     \
   TF_CALL_INTEGRAL_TYPES(m)                                               \
   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
-      TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m)             \
+      TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_tstring(m)            \
       TF_CALL_QUANTIZED_TYPES(m)
 
 #ifdef TENSORFLOW_SYCL_NO_DOUBLE

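The TF_CALL_* macros are TensorFlow's compile-time type dispatch: each expands its argument macro m once per dtype, which is how per-type kernel registrations and template instantiations get stamped out. Redefining TF_CALL_string(m) to expand to m(tstring) therefore switches every such expansion in one place. A reduced standalone sketch of the pattern (toy macro list, not the real one):

#include <cstddef>

#define MY_CALL_int(m) m(int)
#define MY_CALL_float(m) m(float)
#define MY_CALL_ALL_TYPES(m) MY_CALL_int(m) MY_CALL_float(m)

// A "kernel" template plus a registration-style macro to instantiate it.
template <typename T>
T Sum(const T* data, std::size_t n) {
  T acc = T();
  for (std::size_t i = 0; i < n; ++i) acc += data[i];
  return acc;
}

#define INSTANTIATE_SUM(T) template T Sum<T>(const T*, std::size_t);
MY_CALL_ALL_TYPES(INSTANTIATE_SUM)  // instantiates Sum<int> and Sum<float>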
View File

@@ -168,7 +168,7 @@ struct Helper {
 
 // Helper specialization for string (the only non-simple type we
 // support).
 template <>
-struct Helper<string> {
+struct Helper<tstring> {
   // Proto message uses RepeatedFieldType to hold repeated T.
   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
@@ -176,7 +176,7 @@ struct Helper<string> {
   // "out", which is usually the TensorProto::tensor_content.
   template <typename Destination>
   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
-    port::EncodeStringList(in->base<const string>(), n, out);
+    port::EncodeStringList(in->base<const tstring>(), n, out);
   }
 
   // Decodes "n" elements of type string from "in" and constructs a
@@ -184,8 +184,8 @@ struct Helper<string> {
   // usually the TensorProto::tensor_content.
   template <typename Source>
   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
-    Buffer<string>* buf = new Buffer<string>(a, n);
-    string* strings = buf->template base<string>();
+    Buffer<tstring>* buf = new Buffer<tstring>(a, n);
+    tstring* strings = buf->template base<tstring>();
     if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
       buf->Unref();
       return nullptr;
@@ -197,8 +197,8 @@ struct Helper<string> {
   // stored in buffer "in".
   static int64 TotalBytes(TensorBuffer* in, int n) {
     int64 tot = in->size();
-    DCHECK_EQ(tot, sizeof(string) * n);
-    const string* p = in->base<const string>();
+    DCHECK_EQ(tot, sizeof(tstring) * n);
+    const tstring* p = in->base<const tstring>();
     for (int i = 0; i < n; ++i, ++p) tot += p->size();
     return tot;
   }
@@ -302,7 +302,7 @@ PROTO_TRAITS(uint32, uint32, uint32);
 PROTO_TRAITS(int16, int32, int);
 PROTO_TRAITS(int8, int32, int);
 PROTO_TRAITS(bool, bool, bool);
-PROTO_TRAITS(string, string, string);
+PROTO_TRAITS(tstring, tstring, string);
 PROTO_TRAITS(qint8, int32, int);
 PROTO_TRAITS(quint8, int32, int);
 PROTO_TRAITS(qint16, int32, int);
@@ -713,7 +713,7 @@ bool Tensor::RefCountIsOne() const {
   CASE(uint64, SINGLE_ARG(STMTS))  \
   CASE(int16, SINGLE_ARG(STMTS))   \
   CASE(int8, SINGLE_ARG(STMTS))    \
-  CASE(string, SINGLE_ARG(STMTS))  \
+  CASE(tstring, SINGLE_ARG(STMTS)) \
   CASE(complex64, SINGLE_ARG(STMTS))  \
   CASE(complex128, SINGLE_ARG(STMTS)) \
   CASE(int64, SINGLE_ARG(STMTS))      \
@@ -968,7 +968,7 @@ inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                 bool print_v2) {
   return a;
 }
-inline string PrintOneElement(const string& a, bool print_v2) {
+inline string PrintOneElement(const tstring& a, bool print_v2) {
   if (print_v2) {
     return "\"" + absl::CEscape(a) + "\"";
   } else {
@@ -1164,7 +1164,7 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
       return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_STRING:
-      return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2);
+      return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
       break;
     default: {
       // All irregular cases

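The TotalBytes specialization above captures how DT_STRING tensors are sized: one fixed-size tstring object per element plus each element's variable-length payload. The same arithmetic as a standalone sketch (sizeof(std::string) stands in for sizeof(tstring), which is implementation-dependent):

#include <cstdint>
#include <string>
#include <vector>

// Mirrors Helper<tstring>::TotalBytes: fixed-size element objects plus the
// character payloads they own.
int64_t StringTensorTotalBytes(const std::vector<std::string>& elems) {
  int64_t tot = static_cast<int64_t>(sizeof(std::string) * elems.size());
  for (const auto& s : elems) tot += static_cast<int64_t>(s.size());
  return tot;
}

The Tensor_String.Simple test later in this commit checks exactly this formula: 3 * 2 * sizeof(tstring) plus the byte lengths of the six stored strings.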
View File

@@ -150,7 +150,7 @@ class Tensor {
       : Tensor(scalar_value, host_scalar_tag{}) {}
   explicit Tensor(int8 scalar_value)
       : Tensor(scalar_value, host_scalar_tag{}) {}
-  explicit Tensor(string scalar_value)
+  explicit Tensor(tstring scalar_value)
       : Tensor(std::move(scalar_value), host_scalar_tag{}) {}
   explicit Tensor(complex64 scalar_value)
       : Tensor(scalar_value, host_scalar_tag{}) {}
@@ -183,7 +183,7 @@ class Tensor {
   // convenience because otherwise passing a string literal would surprisingly
   // construct a DT_BOOL tensor.
   explicit Tensor(const char* scalar_value)
-      : Tensor(string(scalar_value), host_scalar_tag{}) {}
+      : Tensor(tstring(scalar_value), host_scalar_tag{}) {}
 
   /// Copy constructor.
   Tensor(const Tensor& other);

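The explicit const char* constructor above exists because of overload-resolution ranking: a string literal converts to bool through a standard pointer-to-bool conversion, which beats the user-defined conversion to tstring, so without that overload Tensor("foo") would build a DT_BOOL scalar. A standalone demonstration of the ranking:

#include <iostream>
#include <string>

void Pick(bool) { std::cout << "bool\n"; }
void Pick(const std::string&) { std::cout << "string\n"; }

void PickFixed(bool) { std::cout << "bool\n"; }
void PickFixed(const std::string&) { std::cout << "string\n"; }
void PickFixed(const char*) { std::cout << "const char*\n"; }

int main() {
  Pick("hello");       // prints "bool": the standard conversion wins over
                       // the user-defined conversion to std::string
  PickFixed("hello");  // prints "const char*": an exact-match overload wins,
                       // mirroring Tensor's explicit const char* constructor
}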
View File

@@ -94,6 +94,7 @@ TEST(TensorTest, DataType_Traits) {
   EXPECT_TRUE(std::is_trivial<int8>::value);
   EXPECT_TRUE(std::is_trivial<int64>::value);
   EXPECT_TRUE(std::is_trivial<bool>::value);
+  EXPECT_FALSE(std::is_trivial<tstring>::value);
   EXPECT_FALSE(std::is_trivial<string>::value);
 
   EXPECT_EQ(sizeof(bool), 1);
@@ -903,15 +904,15 @@ TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
 }
 
 TEST(Tensor_String, Simple) {
-  Tensor t = test::AsTensor<string>(
+  Tensor t = test::AsTensor<tstring>(
       {"hello", "world", "machine", "learning", "new", "york"},
       TensorShape({3, 2}));
   auto s = t.shape();
   ASSERT_EQ(s.dims(), 2);
   ASSERT_EQ(s.dim_size(0), 3);
   ASSERT_EQ(s.dim_size(1), 2);
-  auto m = t.matrix<string>();
-  EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4);
+  auto m = t.matrix<tstring>();
+  EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(tstring) + 5 + 5 + 7 + 8 + 3 + 4);
 
   EXPECT_EQ(m(0, 0), "hello");
   EXPECT_EQ(m(0, 1), "world");
@@ -920,7 +921,7 @@ TEST(Tensor_String, Simple) {
   EXPECT_EQ(m(2, 0), "new");
   EXPECT_EQ(m(2, 1), "york");
 
-  TestCopies<string>(t);
+  TestCopies<tstring>(t);
 }
 
 TEST(Tensor_Float, SimpleWithHelper) {
@@ -976,7 +977,7 @@ TEST(Tensor_Int64, SimpleWithHelper) {
 }
 
 TEST(Tensor_String, SimpleWithHelper) {
-  Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3});
+  Tensor t1 = test::AsTensor<tstring>({"0", "1", "2", "3", "4", "5"}, {2, 3});
   Tensor t2(DT_STRING, {2, 3});
   for (int i = 0; i < 2; ++i) {
     for (int j = 0; j < 3; ++j) {
@@ -985,7 +986,7 @@ TEST(Tensor_String, SimpleWithHelper) {
   }
 
   // Test with helper.
-  test::ExpectTensorEqual<string>(t1, t2);
+  test::ExpectTensorEqual<tstring>(t1, t2);
 }
 
 TEST(Tensor_Bool, SimpleWithHelper) {
@@ -1365,10 +1366,10 @@ TEST(SummarizeValue, BOOL) {
 }
 
 TEST(SummarizeValue, STRING) {
-  Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+  Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
                               {"one", "two", "three", "four", "five"});
   EXPECT_EQ("one two three four five", x.SummarizeValue(16));
-  x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+  x = MkTensor<tstring>(DT_STRING, TensorShape({5, 1, 5}),
                        {"one", "two", "three", "four", "five"});
   EXPECT_EQ("[[one two three four five]][[one...]]...", x.SummarizeValue(6));
 }
@@ -1421,7 +1422,7 @@ TEST(SummarizeValue, BOOL_PRINT_V2) {
 }
 
 TEST(SummarizeValue, STRING_PRINT_V2) {
-  Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+  Tensor x = MkTensor<tstring>(DT_STRING, TensorShape({5}),
                               {"one", "two", "three", "four", "five"});
   EXPECT_EQ("[\"one\" \"two\" \"three\" \"four\" \"five\"]",
             x.SummarizeValue(16, true));
@@ -1429,7 +1430,7 @@ TEST(SummarizeValue, STRING_PRINT_V2) {
             x.SummarizeValue(-1, true));
   EXPECT_EQ("[\"one\" \"two\" ... \"four\" \"five\"]",
             x.SummarizeValue(2, true));
-  x = MkTensor<string>(DT_STRING, TensorShape({2, 2}),
+  x = MkTensor<tstring>(DT_STRING, TensorShape({2, 2}),
                        {"one", "two", "three", "four", "five"});
   EXPECT_EQ("[[\"one\" \"two\"]\n [\"three\" \"four\"]]",
             x.SummarizeValue(16, true));

View File

@@ -98,8 +98,8 @@ Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
     if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
     }
-    string* to_strings =
-        reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
+    tstring* to_strings =
+        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));
 
     int64 offset = 0;
     for (const Tensor& tensor : tensors) {
@@ -163,7 +163,7 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
       shape.set_dim(0, size);
       result->emplace_back(tensor.dtype(), shape);
       Tensor& split = (*result)[result->size() - 1];
-      string* to_strings = reinterpret_cast<string*>(
+      tstring* to_strings = reinterpret_cast<tstring*>(
           const_cast<char*>(split.tensor_data().data()));
 
       CHECK_LE(offset + split.NumElements(), tensor.NumElements());

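Both functions above special-case DT_STRING because a string tensor's backing buffer holds live tstring objects rather than flat bytes: the memcpy path used for POD dtypes would copy each string's internal pointers, leaving two owners of one heap payload. Hence the element-wise copies through a properly typed pointer. A standalone sketch of the distinction:

#include <cstring>
#include <string>
#include <vector>

// POD path: raw bytes can be copied wholesale.
void ConcatPod(const std::vector<int>& a, const std::vector<int>& b,
               std::vector<int>* out) {
  out->resize(a.size() + b.size());
  std::memcpy(out->data(), a.data(), a.size() * sizeof(int));
  std::memcpy(out->data() + a.size(), b.data(), b.size() * sizeof(int));
}

// String path: element-wise copies keep ownership intact; memcpy of the
// string control blocks would double-own their heap payloads.
void ConcatStrings(const std::vector<std::string>& a,
                   const std::vector<std::string>& b,
                   std::vector<std::string>* out) {
  out->clear();
  out->reserve(a.size() + b.size());
  out->insert(out->end(), a.begin(), a.end());
  out->insert(out->end(), b.begin(), b.end());
}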
View File

@@ -77,19 +77,19 @@ class TypedAllocator {
 
 template <>
 /* static */
-inline void TypedAllocator::RunCtor(Allocator* raw_allocator, string* p,
+inline void TypedAllocator::RunCtor(Allocator* raw_allocator, tstring* p,
                                     size_t n) {
   if (!raw_allocator->AllocatesOpaqueHandle()) {
-    for (size_t i = 0; i < n; ++p, ++i) new (p) string();
+    for (size_t i = 0; i < n; ++p, ++i) new (p) tstring();
   }
 }
 
 template <>
 /* static */
-inline void TypedAllocator::RunDtor(Allocator* raw_allocator, string* p,
+inline void TypedAllocator::RunDtor(Allocator* raw_allocator, tstring* p,
                                     size_t n) {
   if (!raw_allocator->AllocatesOpaqueHandle()) {
-    for (size_t i = 0; i < n; ++p, ++i) p->~string();
+    for (size_t i = 0; i < n; ++p, ++i) p->~tstring();
   }
 }
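
The two specializations above are the classic placement-new / explicit-destructor pairing: the allocator returns raw, uninitialized memory, so a non-trivial element type like tstring must be constructed and destroyed in place by hand. A standalone sketch of the same pattern:

#include <cstddef>
#include <cstdlib>
#include <new>
#include <string>

int main() {
  using S = std::string;
  constexpr std::size_t n = 4;

  // Raw, uninitialized storage, as an allocator would hand back.
  void* raw = std::malloc(n * sizeof(S));
  if (!raw) return 1;
  S* p = static_cast<S*>(raw);

  for (std::size_t i = 0; i < n; ++i) new (p + i) S();  // RunCtor's loop
  p[0] = "constructed in place";
  for (std::size_t i = 0; i < n; ++i) p[i].~S();        // RunDtor's loop

  std::free(raw);
}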