Merge pull request #35832 from ROCmSoftwarePlatform:google_upstream_rocm_csb_fix_200113

PiperOrigin-RevId: 289548933
Change-Id: I09fbb769ea57c8b0ac7cf5098b449cbbaf641d07
This commit is contained in:
TensorFlower Gardener 2020-01-13 16:58:43 -08:00
commit 9fc8b64300
2 changed files with 37 additions and 3 deletions

View File

@ -249,4 +249,33 @@ int DataTypeSize(DataType dt) {
#undef CASE
}
// Define DataTypeToEnum<T>::value.
#define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
constexpr DataType DataTypeToEnum<TYPE>::value;
DEFINE_DATATYPETOENUM_VALUE(float);
DEFINE_DATATYPETOENUM_VALUE(double);
DEFINE_DATATYPETOENUM_VALUE(int32);
DEFINE_DATATYPETOENUM_VALUE(uint32);
DEFINE_DATATYPETOENUM_VALUE(uint16);
DEFINE_DATATYPETOENUM_VALUE(uint8);
DEFINE_DATATYPETOENUM_VALUE(int16);
DEFINE_DATATYPETOENUM_VALUE(int8);
DEFINE_DATATYPETOENUM_VALUE(tstring);
DEFINE_DATATYPETOENUM_VALUE(complex64);
DEFINE_DATATYPETOENUM_VALUE(complex128);
DEFINE_DATATYPETOENUM_VALUE(int64);
DEFINE_DATATYPETOENUM_VALUE(uint64);
DEFINE_DATATYPETOENUM_VALUE(bool);
DEFINE_DATATYPETOENUM_VALUE(qint8);
DEFINE_DATATYPETOENUM_VALUE(quint8);
DEFINE_DATATYPETOENUM_VALUE(qint16);
DEFINE_DATATYPETOENUM_VALUE(quint16);
DEFINE_DATATYPETOENUM_VALUE(qint32);
DEFINE_DATATYPETOENUM_VALUE(bfloat16);
DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
DEFINE_DATATYPETOENUM_VALUE(Variant);
#undef DEFINE_DATATYPETOENUM_VALUE
} // namespace tensorflow

View File

@ -224,11 +224,16 @@ struct SerializeGroups<T, Variant> {
int64 last_nonempty_group = -1;
// The "DataTypeToEnum<T>::value" member is static and defined but not
// declared. This leads to linker errors when a "DataTypeToEnum<T>::value"
// reference is passed to a routine. Creating a local variable here to
// workaround the linker errors.
DataType T_type = DataTypeToEnum<T>::value;
auto serialize_empty_element = [&](int64 b) {
serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
TensorShape({0, rank - 1}));
serialized_sparse_t(b, 1).emplace<Tensor>(DataTypeToEnum<T>::value,
TensorShape({0}));
serialized_sparse_t(b, 1).emplace<Tensor>(T_type, TensorShape({0}));
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
};
@ -256,7 +261,7 @@ struct SerializeGroups<T, Variant> {
Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
DT_INT64, TensorShape({num_entries, rank - 1}));
Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
DataTypeToEnum<T>::value, TensorShape({num_entries}));
T_type, TensorShape({num_entries}));
auto output_indices_t = output_indices.matrix<int64>();
auto output_values_t = output_values.vec<T>();