Fix bug in compression of tensor protos containing complex values stored in the repeated fields scomplex or dcomplex. This bug could cause the imaginary parts being set to zero for a tensor with a tail of values with equal real parts.

PiperOrigin-RevId: 274063439
This commit is contained in:
A. Unique TensorFlower 2019-10-10 16:26:31 -07:00 committed by TensorFlower Gardener
parent 94ac0af6c4
commit 4945e9899a
2 changed files with 25 additions and 11 deletions

View File

@ -213,11 +213,11 @@ class TensorProtoHelper : public std::true_type {
}
static T GetValue(size_t index, const TensorProto& proto) {
const size_t stride = is_complex<T>::value ? 2 : 1;
T val;
if (is_complex<T>::value) index *= 2;
CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index,
FieldHelper::GetField(proto).begin() + index + 1,
&val);
CopyHelper<T>::ToArray(
FieldHelper::GetField(proto).begin() + stride * index,
FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
return val;
}

View File

@ -496,18 +496,31 @@ TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
}
template <typename T>
std::vector<T> VectorWithConstantTail(int size, int tail_length) {
void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
CHECK_LE(tail_length, size);
std::vector<T> v(size, T(0));
for (int i = 0; i < size - tail_length; ++i) {
v[i] = T(i + 1);
v->clear();
for (int i = 0; i < size; ++i) {
T vi = (i >= size - tail_length) ? T() : T(i);
v->push_back(vi);
}
}
template <>
void VectorWithConstantTail(int size, int tail_length,
std::vector<std::complex<float>>* v) {
CHECK_LE(tail_length, size);
v->clear();
for (int i = 0; i < size; ++i) {
std::complex<float> vi(
0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
v->push_back(vi);
}
return v;
}
template <typename T>
TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
auto values = VectorWithConstantTail<T>(size, tail_length);
std::vector<T> values;
VectorWithConstantTail<T>(size, tail_length, &values);
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
TensorProto tensor_proto;
@ -517,7 +530,8 @@ TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
template <typename T>
TensorProto CreateAsProtoField(int size, int tail_length) {
auto values = VectorWithConstantTail<T>(size, tail_length);
std::vector<T> values;
VectorWithConstantTail<T>(size, tail_length, &values);
Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
std::copy(values.begin(), values.end(), tensor.flat<T>().data());
TensorProto tensor_proto;