diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 0f7ce1a9894..fb8216d2859 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -213,11 +213,11 @@ class TensorProtoHelper : public std::true_type { } static T GetValue(size_t index, const TensorProto& proto) { + const size_t stride = is_complex<T>::value ? 2 : 1; T val; - if (is_complex<T>::value) index *= 2; - CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index, - FieldHelper::GetField(proto).begin() + index + 1, - &val); + CopyHelper<T>::ToArray( + FieldHelper::GetField(proto).begin() + stride * index, + FieldHelper::GetField(proto).begin() + stride * (index + 1), &val); return val; } diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc index fe988015e27..673e7c02187 100644 --- a/tensorflow/core/framework/tensor_util_test.cc +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -496,18 +496,31 @@ TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) { } template <typename T> -std::vector<T> VectorWithConstantTail(int size, int tail_length) { +void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) { CHECK_LE(tail_length, size); - std::vector<T> v(size, T(0)); - for (int i = 0; i < size - tail_length; ++i) { - v[i] = T(i + 1); + v->clear(); + for (int i = 0; i < size; ++i) { + T vi = (i >= size - tail_length) ? T() : T(i); + v->push_back(vi); + } +} + +template <> +void VectorWithConstantTail(int size, int tail_length, + std::vector<std::complex<float>>* v) { + CHECK_LE(tail_length, size); + v->clear(); + for (int i = 0; i < size; ++i) { + std::complex<float> vi( + 0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i)); + v->push_back(vi); } - return v; } template <typename T> TensorProto CreateAsProtoTensorContent(int size, int tail_length) { - auto values = VectorWithConstantTail<T>(size, tail_length); + std::vector<T> values; + VectorWithConstantTail<T>(size, tail_length, &values); Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size})); std::copy(values.begin(), values.end(), tensor.flat<T>().data()); TensorProto tensor_proto; @@ -517,7 +530,8 @@ TensorProto CreateAsProtoTensorContent(int size, int tail_length) { template <typename T> TensorProto CreateAsProtoField(int size, int tail_length) { - auto values = VectorWithConstantTail<T>(size, tail_length); + std::vector<T> values; + VectorWithConstantTail<T>(size, tail_length, &values); Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size})); std::copy(values.begin(), values.end(), tensor.flat<T>().data()); TensorProto tensor_proto;