Fix TensorFlow Lite windows builds
PiperOrigin-RevId: 260219972
This commit is contained in:
parent
6d4d33d56e
commit
f4ac771313
@ -43,10 +43,11 @@ constexpr int kOutputTensor = 0;
|
|||||||
const int kMaxDim = 4;
|
const int kMaxDim = 4;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TfLiteStatus CalculateOutputShapeVector(
|
TfLiteStatus CalculateOutputShapeVector(TfLiteContext* context,
|
||||||
TfLiteContext* context, const TfLiteTensor* input,
|
const TfLiteTensor* input,
|
||||||
const TfLiteTensor* begin, const TfLiteTensor* size,
|
const TfLiteTensor* begin,
|
||||||
std::vector<int64_t>* output_shape_vector) {
|
const TfLiteTensor* size,
|
||||||
|
std::vector<int>* output_shape_vector) {
|
||||||
for (int idx = 0; idx < NumDimensions(input); ++idx) {
|
for (int idx = 0; idx < NumDimensions(input); ++idx) {
|
||||||
T size_value = GetTensorData<T>(size)[idx];
|
T size_value = GetTensorData<T>(size)[idx];
|
||||||
if (size_value < 0) {
|
if (size_value < 0) {
|
||||||
@ -62,7 +63,7 @@ TfLiteStatus CalculateOutputShapeVector(
|
|||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output_shape_vector->push_back(size_value);
|
output_shape_vector->push_back(static_cast<int>(size_value));
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
@ -81,7 +82,7 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context,
|
|||||||
const TfLiteTensor* input,
|
const TfLiteTensor* input,
|
||||||
const TfLiteTensor* begin,
|
const TfLiteTensor* begin,
|
||||||
const TfLiteTensor* size, TfLiteTensor* output) {
|
const TfLiteTensor* size, TfLiteTensor* output) {
|
||||||
std::vector<int64_t> output_shape_vector;
|
std::vector<int> output_shape_vector;
|
||||||
|
|
||||||
if (begin->type == kTfLiteInt32) {
|
if (begin->type == kTfLiteInt32) {
|
||||||
TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
|
TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
|
||||||
|
@ -70,10 +70,10 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename M>
|
||||||
void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
|
void CopyMultipleTimes(const T* in_data, int32_t in_size, M multiplier,
|
||||||
T* out_data) {
|
T* out_data) {
|
||||||
for (int i = 0; i < multiplier; ++i) {
|
for (M i = 0; i < multiplier; ++i) {
|
||||||
const T* in_end = in_data + in_size;
|
const T* in_end = in_data + in_size;
|
||||||
T* new_out_data = std::copy(in_data, in_end, out_data);
|
T* new_out_data = std::copy(in_data, in_end, out_data);
|
||||||
in_data = out_data;
|
in_data = out_data;
|
||||||
@ -109,8 +109,9 @@ std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
|
|||||||
CopyMultipleTimes(out_data, total_tiled_stride_size,
|
CopyMultipleTimes(out_data, total_tiled_stride_size,
|
||||||
multipliers[dimension] - 1,
|
multipliers[dimension] - 1,
|
||||||
out_data + total_tiled_stride_size);
|
out_data + total_tiled_stride_size);
|
||||||
return std::make_pair(total_stride_size,
|
return std::make_pair(
|
||||||
total_tiled_stride_size * multipliers[dimension]);
|
total_stride_size,
|
||||||
|
static_cast<int>(total_tiled_stride_size * multipliers[dimension]));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
Loading…
Reference in New Issue
Block a user