Fix bug in compression of tensor protos containing complex values stored in the repeated fields scomplex or dcomplex. This bug could cause the imaginary parts being set to zero for a tensor with a tail of values with equal real parts.
PiperOrigin-RevId: 274063439
This commit is contained in:
parent
94ac0af6c4
commit
4945e9899a
@ -213,11 +213,11 @@ class TensorProtoHelper : public std::true_type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static T GetValue(size_t index, const TensorProto& proto) {
|
static T GetValue(size_t index, const TensorProto& proto) {
|
||||||
|
const size_t stride = is_complex<T>::value ? 2 : 1;
|
||||||
T val;
|
T val;
|
||||||
if (is_complex<T>::value) index *= 2;
|
CopyHelper<T>::ToArray(
|
||||||
CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index,
|
FieldHelper::GetField(proto).begin() + stride * index,
|
||||||
FieldHelper::GetField(proto).begin() + index + 1,
|
FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
|
||||||
&val);
|
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -496,18 +496,31 @@ TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> VectorWithConstantTail(int size, int tail_length) {
|
void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
|
||||||
CHECK_LE(tail_length, size);
|
CHECK_LE(tail_length, size);
|
||||||
std::vector<T> v(size, T(0));
|
v->clear();
|
||||||
for (int i = 0; i < size - tail_length; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
v[i] = T(i + 1);
|
T vi = (i >= size - tail_length) ? T() : T(i);
|
||||||
|
v->push_back(vi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void VectorWithConstantTail(int size, int tail_length,
|
||||||
|
std::vector<std::complex<float>>* v) {
|
||||||
|
CHECK_LE(tail_length, size);
|
||||||
|
v->clear();
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
std::complex<float> vi(
|
||||||
|
0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
|
||||||
|
v->push_back(vi);
|
||||||
}
|
}
|
||||||
return v;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
|
TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
|
||||||
auto values = VectorWithConstantTail<T>(size, tail_length);
|
std::vector<T> values;
|
||||||
|
VectorWithConstantTail<T>(size, tail_length, &values);
|
||||||
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
||||||
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
||||||
TensorProto tensor_proto;
|
TensorProto tensor_proto;
|
||||||
@ -517,7 +530,8 @@ TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TensorProto CreateAsProtoField(int size, int tail_length) {
|
TensorProto CreateAsProtoField(int size, int tail_length) {
|
||||||
auto values = VectorWithConstantTail<T>(size, tail_length);
|
std::vector<T> values;
|
||||||
|
VectorWithConstantTail<T>(size, tail_length, &values);
|
||||||
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
||||||
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
||||||
TensorProto tensor_proto;
|
TensorProto tensor_proto;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user