Fix bug in compression of tensor protos containing complex values stored in the repeated fields scomplex or dcomplex. This bug could cause the imaginary parts being set to zero for a tensor with a tail of values with equal real parts.
PiperOrigin-RevId: 274063439
This commit is contained in:
parent
94ac0af6c4
commit
4945e9899a
@ -213,11 +213,11 @@ class TensorProtoHelper : public std::true_type {
|
||||
}
|
||||
|
||||
static T GetValue(size_t index, const TensorProto& proto) {
|
||||
const size_t stride = is_complex<T>::value ? 2 : 1;
|
||||
T val;
|
||||
if (is_complex<T>::value) index *= 2;
|
||||
CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index,
|
||||
FieldHelper::GetField(proto).begin() + index + 1,
|
||||
&val);
|
||||
CopyHelper<T>::ToArray(
|
||||
FieldHelper::GetField(proto).begin() + stride * index,
|
||||
FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
|
||||
return val;
|
||||
}
|
||||
|
||||
|
@ -496,18 +496,31 @@ TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> VectorWithConstantTail(int size, int tail_length) {
|
||||
void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
|
||||
CHECK_LE(tail_length, size);
|
||||
std::vector<T> v(size, T(0));
|
||||
for (int i = 0; i < size - tail_length; ++i) {
|
||||
v[i] = T(i + 1);
|
||||
v->clear();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
T vi = (i >= size - tail_length) ? T() : T(i);
|
||||
v->push_back(vi);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void VectorWithConstantTail(int size, int tail_length,
|
||||
std::vector<std::complex<float>>* v) {
|
||||
CHECK_LE(tail_length, size);
|
||||
v->clear();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
std::complex<float> vi(
|
||||
0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
|
||||
v->push_back(vi);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
|
||||
auto values = VectorWithConstantTail<T>(size, tail_length);
|
||||
std::vector<T> values;
|
||||
VectorWithConstantTail<T>(size, tail_length, &values);
|
||||
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
||||
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
||||
TensorProto tensor_proto;
|
||||
@ -517,7 +530,8 @@ TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
|
||||
|
||||
template <typename T>
|
||||
TensorProto CreateAsProtoField(int size, int tail_length) {
|
||||
auto values = VectorWithConstantTail<T>(size, tail_length);
|
||||
std::vector<T> values;
|
||||
VectorWithConstantTail<T>(size, tail_length, &values);
|
||||
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
|
||||
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
|
||||
TensorProto tensor_proto;
|
||||
|
Loading…
Reference in New Issue
Block a user