Merge pull request #35832 from ROCmSoftwarePlatform:google_upstream_rocm_csb_fix_200113
PiperOrigin-RevId: 289548933 Change-Id: I09fbb769ea57c8b0ac7cf5098b449cbbaf641d07
This commit is contained in:
commit
9fc8b64300
@ -249,4 +249,33 @@ int DataTypeSize(DataType dt) {
|
||||
#undef CASE
|
||||
}
|
||||
|
||||
// Define DataTypeToEnum<T>::value.
|
||||
#define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
|
||||
constexpr DataType DataTypeToEnum<TYPE>::value;
|
||||
|
||||
DEFINE_DATATYPETOENUM_VALUE(float);
|
||||
DEFINE_DATATYPETOENUM_VALUE(double);
|
||||
DEFINE_DATATYPETOENUM_VALUE(int32);
|
||||
DEFINE_DATATYPETOENUM_VALUE(uint32);
|
||||
DEFINE_DATATYPETOENUM_VALUE(uint16);
|
||||
DEFINE_DATATYPETOENUM_VALUE(uint8);
|
||||
DEFINE_DATATYPETOENUM_VALUE(int16);
|
||||
DEFINE_DATATYPETOENUM_VALUE(int8);
|
||||
DEFINE_DATATYPETOENUM_VALUE(tstring);
|
||||
DEFINE_DATATYPETOENUM_VALUE(complex64);
|
||||
DEFINE_DATATYPETOENUM_VALUE(complex128);
|
||||
DEFINE_DATATYPETOENUM_VALUE(int64);
|
||||
DEFINE_DATATYPETOENUM_VALUE(uint64);
|
||||
DEFINE_DATATYPETOENUM_VALUE(bool);
|
||||
DEFINE_DATATYPETOENUM_VALUE(qint8);
|
||||
DEFINE_DATATYPETOENUM_VALUE(quint8);
|
||||
DEFINE_DATATYPETOENUM_VALUE(qint16);
|
||||
DEFINE_DATATYPETOENUM_VALUE(quint16);
|
||||
DEFINE_DATATYPETOENUM_VALUE(qint32);
|
||||
DEFINE_DATATYPETOENUM_VALUE(bfloat16);
|
||||
DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
|
||||
DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
|
||||
DEFINE_DATATYPETOENUM_VALUE(Variant);
|
||||
#undef DEFINE_DATATYPETOENUM_VALUE
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -224,11 +224,16 @@ struct SerializeGroups<T, Variant> {
|
||||
|
||||
int64 last_nonempty_group = -1;
|
||||
|
||||
// The "DataTypeToEnum<T>::value" member is static and defined but not
|
||||
// declared. This leads to linker errors when a "DataTypeToEnum<T>::value"
|
||||
// reference is passed to a routine. Creating a local variable here to
|
||||
// workaround the linker errors.
|
||||
DataType T_type = DataTypeToEnum<T>::value;
|
||||
|
||||
auto serialize_empty_element = [&](int64 b) {
|
||||
serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
|
||||
TensorShape({0, rank - 1}));
|
||||
serialized_sparse_t(b, 1).emplace<Tensor>(DataTypeToEnum<T>::value,
|
||||
TensorShape({0}));
|
||||
serialized_sparse_t(b, 1).emplace<Tensor>(T_type, TensorShape({0}));
|
||||
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
|
||||
};
|
||||
|
||||
@ -256,7 +261,7 @@ struct SerializeGroups<T, Variant> {
|
||||
Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
|
||||
DT_INT64, TensorShape({num_entries, rank - 1}));
|
||||
Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
|
||||
DataTypeToEnum<T>::value, TensorShape({num_entries}));
|
||||
T_type, TensorShape({num_entries}));
|
||||
|
||||
auto output_indices_t = output_indices.matrix<int64>();
|
||||
auto output_values_t = output_values.vec<T>();
|
||||
|
Loading…
x
Reference in New Issue
Block a user