Internal bug fix.
PiperOrigin-RevId: 358331167
Change-Id: I8d4e850023c3c2c6784cf3329d19ad3b9c25948d
commit c7b83e2b5e (parent 2a2b5ba633)
tensorflow/lite/kernels
@@ -431,152 +431,6 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
  return kTfLiteOk;
}

void EvalHybridImpl(TfLiteContext* context, TfLiteNode* node,
                    TfLiteFullyConnectedParams* params, OpData* data,
                    const TfLiteTensor* input, const TfLiteTensor* filter,
                    const TfLiteTensor* bias, int thread_start, int thread_end,
                    TfLiteTensor* input_quantized,
                    TfLiteTensor* scaling_factors, TfLiteTensor* accum_scratch,
                    TfLiteTensor* row_sums, TfLiteTensor* input_offsets,
                    TfLiteTensor* output) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("Hybrid Kernel");
  const auto& input_shape = GetTensorShape(input);
  const auto& output_shape = GetTensorShape(output);
  const auto& filter_shape = GetTensorShape(filter);
  const int input_dims_count = input_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int filter_dims_count = filter_shape.DimensionsCount();
  const int batch_size = thread_end - thread_start;
  const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
                                      input_shape, input_dims_count - 1);
  const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int per_thread_input_size = batch_size * input_depth;

  const bool is_sparse = filter->sparsity != nullptr;

  const float* per_thread_input =
      GetTensorData<float>(input) + thread_start * input_depth;
  float* per_thread_output =
      GetTensorData<float>(output) + thread_start * output_depth;

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
                                          output_depth, batch_size,
                                          per_thread_output);
  } else {
    std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
  }

  // Save matrix multiplication computation for all zero input.
  if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
    tensor_utils::ApplyActivationToVector(
        per_thread_output, batch_size * output_depth, params->activation,
        per_thread_output);
    return;
  }

  // Quantize input from float to uint8 + quantization params (scaling factor).
  float* scaling_factors_ptr =
      GetTensorData<float>(scaling_factors) + thread_start;
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data =
      GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
                                    quant_data, scaling_factors_ptr,
                                    input_offset_ptr,
                                    params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  // Compute output += weight * quantized_input
  int32_t* scratch =
      GetTensorData<int32_t>(accum_scratch) + thread_start * output_depth;
  if (is_sparse) {
    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
        output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
        per_thread_output);
  } else {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        filter_data, output_depth, input_depth, quant_data, scaling_factors_ptr,
        batch_size, per_thread_output, /*per_channel_scale=*/nullptr,
        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
        CpuBackendContext::GetFromContext(context));
  }

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(per_thread_output,
                                        batch_size * output_depth,
                                        params->activation, per_thread_output);
}
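As context for the hybrid path above: tensor_utils::BatchQuantizeFloats produces an int8 copy of each input row plus a per-row scaling factor, the filter scale is folded into that factor, and the integer accumulation is rescaled back to float. Below is a minimal standalone sketch of that arithmetic for the symmetric case (asymmetric_quantize_inputs == false); QuantizeRowSymmetric and HybridMatmulSketch are illustrative names, not TFLite APIs.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative helper: symmetric per-row quantization to int8, roughly what
// the hybrid kernel computes per batch row.
static float QuantizeRowSymmetric(const float* row, int depth, int8_t* out) {
  float max_abs = 0.f;
  for (int i = 0; i < depth; ++i) max_abs = std::max(max_abs, std::fabs(row[i]));
  const float scale = max_abs > 0.f ? max_abs / 127.f : 1.f;
  for (int i = 0; i < depth; ++i) {
    out[i] = static_cast<int8_t>(std::lround(row[i] / scale));
  }
  return scale;  // The "scaling factor" for this row.
}

// output[b][o] += (row_scale[b] * filter_scale) * sum_i filter_q[o][i] * input_q[b][i]
void HybridMatmulSketch(const float* input, const int8_t* filter_q,
                        float filter_scale, int batch_size, int input_depth,
                        int output_depth, float* output /* pre-filled with bias */) {
  std::vector<int8_t> quant(input_depth);
  for (int b = 0; b < batch_size; ++b) {
    float scale = QuantizeRowSymmetric(input + b * input_depth, input_depth,
                                       quant.data());
    scale *= filter_scale;  // "Incorporate scaling of the filter."
    for (int o = 0; o < output_depth; ++o) {
      int32_t acc = 0;
      for (int i = 0; i < input_depth; ++i) {
        acc += static_cast<int32_t>(filter_q[o * input_depth + i]) * quant[i];
      }
      output[b * output_depth + o] += scale * static_cast<float>(acc);
    }
  }
}

With batch_size == thread_end - thread_start and the data pointers offset by thread_start, this is the per-thread slice that EvalHybridImpl computes with the optimized tensor_utils routines.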

struct HybridFullyConnectedTask : cpu_backend_threadpool::Task {
  HybridFullyConnectedTask(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           const int thread_start, const int thread_end,
                           TfLiteTensor* input_quantized,
                           TfLiteTensor* scaling_factors,
                           TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                           TfLiteTensor* input_offsets, TfLiteTensor* output)
      : context(context),
        node(node),
        params(params),
        data(data),
        input(input),
        filter(filter),
        bias(bias),
        thread_start(thread_start),
        thread_end(thread_end),
        input_quantized(input_quantized),
        scaling_factors(scaling_factors),
        accum_scratch(accum_scratch),
        row_sums(row_sums),
        input_offsets(input_offsets),
        output(output) {}

  void Run() override {
    EvalHybridImpl(context, node, params, data, input, filter, bias,
                   thread_start, thread_end, input_quantized, scaling_factors,
                   accum_scratch, row_sums, input_offsets, output);
  }

 private:
  TfLiteContext* context;
  TfLiteNode* node;
  TfLiteFullyConnectedParams* params;
  OpData* data;
  const TfLiteTensor* input;
  const TfLiteTensor* filter;
  const TfLiteTensor* bias;
  const int thread_start;
  const int thread_end;
  TfLiteTensor* input_quantized;
  TfLiteTensor* scaling_factors;
  TfLiteTensor* accum_scratch;
  TfLiteTensor* row_sums;
  TfLiteTensor* input_offsets;
  TfLiteTensor* output;
};
|
||||
|
||||

// The multi-threaded kernel slices the workload along the batch dimension. If
// there's not enough batches of data, the number of threads used is equal to
// the batch size.
// TODO(b/173442777): If needed, we can improve this later with slicing along
// the row dimension of the weight.
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
@@ -584,10 +438,55 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
  const auto& output_shape = GetTensorShape(output);
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];
  const bool is_sparse = filter->sparsity != nullptr;

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Save matrix multiplication computation for all zero input.
  if (tensor_utils::IsZeroVector(GetTensorData<float>(input),
                                 total_input_size)) {
    tensor_utils::ApplyActivationToVector(
        GetTensorData<float>(output), batch_size * num_units,
        params->activation, GetTensorData<float>(output));
    return kTfLiteOk;
  }

  // Quantize input from float to uint8 + quantization params (scaling factor).
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  const float* input_ptr = GetTensorData<float>(input);
  tensor_utils::BatchQuantizeFloats(
      input_ptr, batch_size, input_size, quant_data, scaling_factors_ptr,
      input_offset_ptr, params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  // Compute output += weight * quantized_input
  int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
  if (is_sparse) {
    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
    if (!data->ledger_initialized) {
@@ -595,37 +494,22 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
          GetTensorData<uint8_t>(filter_ledger));
      data->ledger_initialized = true;
    }
    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
        num_units, input_size, quant_data, scaling_factors_ptr, batch_size,
        GetTensorData<float>(output));
  } else {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
        batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
        CpuBackendContext::GetFromContext(context));
  }

  const int max_threads = cpu_backend_context->max_num_threads();
  const int batches =
      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
  const int thread_count = std::max(1, std::min(batches, max_threads));
  if (thread_count == 1) {
    EvalHybridImpl(context, node, params, data, input, filter, bias, 0, batches,
                   input_quantized, scaling_factors, accum_scratch, row_sums,
                   input_offsets, output);
    return kTfLiteOk;
  }

  std::vector<HybridFullyConnectedTask> tasks;
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // This makes sure the workload is relatively balanced when batches is not
    // a multiple of thread_count. The first mod(batches, thread_count) tasks
    // need to process one more batch than the rest.
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;

    tasks.emplace_back(context, node, params, data, input, filter, bias,
                       thread_start, thread_end, input_quantized,
                       scaling_factors, accum_scratch, row_sums, input_offsets,
                       output);
    thread_start = thread_end;
  }
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));
  return kTfLiteOk;
}
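The balancing comment inside EvalHybrid ("the first mod(batches, thread_count) tasks need to process one more batch than the rest") can be checked with a small standalone sketch; SliceBatches and the example values below are illustrative only, not part of this change.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Standalone illustration of the batch-slicing policy: thread_count never
// exceeds the number of batches, and the first (batches % thread_count)
// threads take one extra batch.
std::vector<std::pair<int, int>> SliceBatches(int batches, int max_threads) {
  const int thread_count = std::max(1, std::min(batches, max_threads));
  std::vector<std::pair<int, int>> ranges;  // Half-open [thread_start, thread_end).
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;
    ranges.emplace_back(thread_start, thread_end);
    thread_start = thread_end;
  }
  return ranges;
}

int main() {
  // 10 batches over 4 threads -> [0, 3) [3, 6) [6, 8) [8, 10).
  for (const auto& r : SliceBatches(/*batches=*/10, /*max_threads=*/4)) {
    std::printf("[%d, %d) ", r.first, r.second);
  }
  std::printf("\n");
}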

@@ -296,8 +296,7 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
  HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
                              const TensorData& weights,
                              const TensorData& output = {TensorType_FLOAT32},
                              bool asymmetric_inputs = false,
                              int num_threads = 1)
                              bool asymmetric_inputs = false)
      : batches_(batches), units_(units) {
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -323,9 +322,7 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
    resolver_ = absl::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED,
        ops::builtin::Register_FULLY_CONNECTED_PIE());
    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
                     num_threads, /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
  }
  void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
  void SetWeights(const std::vector<float>& data) {
@@ -882,44 +879,6 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
                                 /*max_abs_error=*/1.3f)));
}

TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    HybridFullyConnectedOpModel m(
        /*units=*/3, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 10}},
        /*weights=*/
        {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
        /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
        /*num_threads=*/num_threads);  // Hybrid

    m.SetSignedWeights({
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
    });
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 3
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                   {
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                       24, 25, 26,  //
                                       58, 59, 60,  //
                                   },
                                   /*max_abs_error=*/1.3f)));
  }
}
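As a sanity check on the expected values above (not part of the change): for batch 0, the dot product of the input row {1, 2, ..., 8, -9, -10} with each weight row {1, 2, ..., 10} is 1+4+9+16+25+36+49+64-81-100 = 23, so adding the biases {1, 2, 3} gives {24, 25, 26}; for batch 1, the input {1, ..., 7, -8, 9, -10} gives a dot product of 57 and outputs {58, 59, 60}. The hybrid kernel is expected to reproduce these float-reference values within the 1.3 max_abs_error tolerance.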

TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  HybridFullyConnectedOpModel m(
      /*units=*/3, /*batches=*/2,
@@ -1454,76 +1413,6 @@ TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
              ElementsAreArray(ArrayFloatNear(
                  {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
}

TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/4, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
        /*num_threads=*/num_threads, /*symmetric_quantize_weights=*/true);
    m.SetBias({1, 2, 3, 4});
    m.SetInput({
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 0
        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
        -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
        0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
        -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
        1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 1
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 2
        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
        -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
        0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
        -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
        1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 3
    });

    m.Invoke();

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
    EXPECT_THAT(m.GetOutput(),
                ElementsAreArray(ArrayFloatNear(
                    {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0, 0,
                     7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0},
                    1e-3)));
  }
}
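The weight metadata above declares a 1x16 block-sparse layout (dense rows, CSR over 16-wide column blocks). Below is a rough standalone sketch of what such a layout stores; it illustrates the idea only and is not TFLite's exact dim_metadata encoding.

#include <vector>

// Simplified block-CSR view of a dense [rows x cols] matrix: each row is split
// into cols/16 blocks of 16 values, and only blocks containing at least one
// nonzero are kept, together with their block-column index.
struct BlockCsr {
  std::vector<int> row_starts;      // Size rows + 1; offsets into block_cols.
  std::vector<int> block_cols;      // Block-column index of each kept block.
  std::vector<float> block_values;  // 16 values per kept block.
};

BlockCsr ToBlockCsr(const std::vector<float>& dense, int rows, int cols) {
  constexpr int kBlock = 16;
  BlockCsr out;
  out.row_starts.push_back(0);
  for (int r = 0; r < rows; ++r) {
    for (int bc = 0; bc < cols / kBlock; ++bc) {
      const float* block = &dense[r * cols + bc * kBlock];
      bool nonzero = false;
      for (int i = 0; i < kBlock; ++i) nonzero |= (block[i] != 0.0f);
      if (nonzero) {
        out.block_cols.push_back(bc);
        out.block_values.insert(out.block_values.end(), block, block + kBlock);
      }
    }
    out.row_starts.push_back(static_cast<int>(out.block_cols.size()));
  }
  return out;
}

For the 4x48 weights in this test, each row splits into three 16-wide blocks and only the blocks containing nonzeros would be kept, which is the structure the sparse hybrid kernel exploits.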

// TODO(b/148391360): Add tests for unsupported sparsity format.
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)