Fix for scalar tensors with BytesRequiredForTensor.

PiperOrigin-RevId: 313419954
Change-Id: Id6ce0fec1a1640d332bd55694ed23dc6b9da58da
This commit is contained in:
A. Unique TensorFlower 2020-05-27 10:48:43 -07:00 committed by TensorFlower Gardener
parent 1d0dfbde01
commit 5e6cb6e324

View File

@ -83,8 +83,12 @@ TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
size_t* bytes, size_t* type_size,
ErrorReporter* error_reporter) {
int element_count = 1;
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
element_count *= flatbuffer_tensor.shape()->Get(n);
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
// so has 1 element.
if (flatbuffer_tensor.shape() != nullptr) {
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
element_count *= flatbuffer_tensor.shape()->Get(n);
}
}
TfLiteType tf_lite_type;