Fix for scalar tensors with BytesRequiredForTensor.
PiperOrigin-RevId: 313419954 Change-Id: Id6ce0fec1a1640d332bd55694ed23dc6b9da58da
This commit is contained in:
parent
1d0dfbde01
commit
5e6cb6e324
@ -83,8 +83,12 @@ TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
|
|||||||
size_t* bytes, size_t* type_size,
|
size_t* bytes, size_t* type_size,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
int element_count = 1;
|
int element_count = 1;
|
||||||
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
|
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
|
||||||
element_count *= flatbuffer_tensor.shape()->Get(n);
|
// so has 1 element.
|
||||||
|
if (flatbuffer_tensor.shape() != nullptr) {
|
||||||
|
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
|
||||||
|
element_count *= flatbuffer_tensor.shape()->Get(n);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteType tf_lite_type;
|
TfLiteType tf_lite_type;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user