Fix for scalar tensors with BytesRequiredForTensor.
PiperOrigin-RevId: 313419954 Change-Id: Id6ce0fec1a1640d332bd55694ed23dc6b9da58da
This commit is contained in:
parent
1d0dfbde01
commit
5e6cb6e324
@ -83,8 +83,12 @@ TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
|
||||
size_t* bytes, size_t* type_size,
|
||||
ErrorReporter* error_reporter) {
|
||||
int element_count = 1;
|
||||
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
|
||||
element_count *= flatbuffer_tensor.shape()->Get(n);
|
||||
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
|
||||
// so has 1 element.
|
||||
if (flatbuffer_tensor.shape() != nullptr) {
|
||||
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
|
||||
element_count *= flatbuffer_tensor.shape()->Get(n);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteType tf_lite_type;
|
||||
|
Loading…
x
Reference in New Issue
Block a user