diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index ca378af4c5c..9862dbe99d5 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -173,7 +173,8 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_float_data = output_array->GetMutableBuffer<ArrayDataType::kFloat>().data; - output_float_data.resize(input_flat_size); + output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()), + 0.f); if (input_tensor.float_val_size() == 1) { for (int i = 0; i < input_flat_size; i++) { output_float_data[i] = input_tensor.float_val(0); @@ -203,7 +204,7 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kUint8>().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -229,7 +230,7 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kInt32>().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -255,7 +256,7 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kInt64>().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int64_val_size()) { for (int i = 0; i < input_tensor.int64_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); @@ -281,7 +282,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_string_data = output_array->GetMutableBuffer<ArrayDataType::kString>().data; - output_string_data.resize(input_flat_size); + output_string_data.resize(RequiredBufferSizeForShape(output_array->shape())); if (input_flat_size != input_tensor.string_val_size()) { LOG(FATAL) << "Input_content string_val doesn't have the right " "dimensions for this string tensor.";