Change unordered containers to Swiss table.

PiperOrigin-RevId: 325332451
Change-Id: I5349d9b9e9227b62752f21e0b2c777bfcc59d3eb
This commit is contained in:
Robert David 2020-08-06 16:16:29 -07:00 committed by TensorFlower Gardener
parent 0c3334857d
commit 0a0a9eeb6b
38 changed files with 127 additions and 100 deletions

View File

@ -251,6 +251,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/gl:api2",
],
}) + [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"//tensorflow/lite:kernel_api",

View File

@ -388,6 +388,8 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],
)
@ -454,6 +456,7 @@ cc_library(
":compiled_program_cache_cc_fbs",
":util",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
"@farmhash_archive//:farmhash",
"@flatbuffers",

View File

@ -21,9 +21,10 @@ limitations under the License.
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
@ -49,7 +50,7 @@ namespace gpu {
namespace cl {
namespace {
bool IsReady(const std::unordered_set<ValueId>& ready_tensors,
bool IsReady(const absl::flat_hash_set<ValueId>& ready_tensors,
const CLNode& node) {
for (const ValueId in_id : node.inputs) {
if (ready_tensors.find(in_id) == ready_tensors.end()) {
@ -325,7 +326,7 @@ absl::Status InferenceContext::ConvertOperations(
inputs, outputs, node,
&gpu_subgraph));
}
std::unordered_map<int, ValueId> mapping_to_global_ids;
absl::flat_hash_map<int, ValueId> mapping_to_global_ids;
for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
const auto& t = gpu_subgraph.new_tensors[j];
auto global_id = tensor_reserver_.Add({t.first, t.second});
@ -364,7 +365,7 @@ absl::Status InferenceContext::ConvertOperations(
}
void InferenceContext::Merge() {
std::unordered_set<ValueId> ready_tensors;
absl::flat_hash_set<ValueId> ready_tensors;
for (const auto& input_id : input_ids_) {
ready_tensors.insert(input_id);
}

View File

@ -20,9 +20,9 @@ limitations under the License.
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
@ -160,7 +160,7 @@ class InferenceContext {
DummyTensor Get(ValueId id) { return reservations_[id]; }
private:
std::unordered_map<ValueId, DummyTensor> reservations_;
absl::flat_hash_map<ValueId, DummyTensor> reservations_;
ValueId next_;
};
TensorReserver tensor_reserver_;

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
@ -93,8 +93,8 @@ class ProgramCache {
// There is a low probability of a hash collision when the cache is deserialized,
// because only fingerprints are serialized instead of the full source code.
bool use_fingerprints_ = false;
std::unordered_map<ProgramDescriptor, CLProgram, ProgramDescriptorHasher,
ProgramDescriptorEqual>
absl::flat_hash_map<ProgramDescriptor, CLProgram, ProgramDescriptorHasher,
ProgramDescriptorEqual>
programs_;
};

View File

@ -114,6 +114,7 @@ cc_library(
":shape",
":status",
":tensor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"//tensorflow/lite/delegates:utils",
"//tensorflow/lite:context",
@ -169,6 +170,7 @@ cc_library(
hdrs = ["model_transformer.h"],
deps = [
":model",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@ -186,6 +188,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/delegates:utils",
"//tensorflow/lite/kernels:kernel_util",
"@com_google_absl//absl/container:flat_hash_map",
],
)
@ -198,6 +201,7 @@ cc_library(
":model",
":shape",
":status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:variant",
],
)
@ -212,6 +216,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:types",
"@com_google_absl//absl/container:flat_hash_map",
],
)

View File

@ -22,10 +22,10 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
@ -2884,8 +2884,8 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
// guarantee that the order will match the source model tensors order.
absl::Status PrecreateIOTensors(
TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
std::unordered_map<int, int>* quant_conversion_map,
std::unordered_map<int, Value*>* tensor_to_value) {
absl::flat_hash_map<int, int>* quant_conversion_map,
absl::flat_hash_map<int, Value*>* tensor_to_value) {
for (int i = 0; i < io_tensors->size; ++i) {
const int tensor_index = io_tensors->data[i];
const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
@ -2899,7 +2899,7 @@ absl::Status PrecreateIOTensors(
absl::Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
std::unordered_map<int, int>* quant_conversion_map) {
absl::flat_hash_map<int, int>* quant_conversion_map) {
std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
std::vector<int> tflite_nodes;
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
@ -2925,7 +2925,7 @@ absl::Status BuildModel(TfLiteContext* context,
operations.push_back(std::move(op_parser));
tflite_nodes.push_back(i);
}
std::unordered_map<int, Value*> tensor_to_value;
absl::flat_hash_map<int, Value*> tensor_to_value;
RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
delegate_params->input_tensors,
quant_conversion_map, &tensor_to_value));
@ -2952,7 +2952,7 @@ absl::Status BuildModel(TfLiteContext* context,
absl::Status BuildFinalModel(
TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph, std::unordered_map<int, int>* quant_conversion_map) {
GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
RETURN_IF_ERROR(
BuildModel(context, delegate_params, graph, quant_conversion_map));

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <cstdint>
#include <string>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -48,7 +48,7 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context,
absl::Status BuildModel(
TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
std::unordered_map<int, int>* quant_conversion_map = nullptr);
absl::flat_hash_map<int, int>* quant_conversion_map = nullptr);
// Same as above but also apply all transformations on the final graph.
// Prefer using this method instead of BuildModel.
@ -62,7 +62,7 @@ absl::Status BuildModel(
absl::Status BuildFinalModel(
TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
std::unordered_map<int, int>* quant_conversion_map = nullptr);
absl::flat_hash_map<int, int>* quant_conversion_map = nullptr);
// Module-internal converter, exposed for unit testing purpose only.
absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <deque>
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
namespace tflite {
@ -126,7 +126,7 @@ class ModelTransformer {
TransformationReporter* reporter_;
std::deque<NodeId> to_process_;
std::unordered_set<NodeId> processed_;
absl::flat_hash_set<NodeId> processed_;
};
class NullTransformationReporter : public TransformationReporter {

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include <cstdint>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
@ -28,8 +28,8 @@ namespace tflite {
namespace gpu {
absl::Status ObjectReader::ReadNonConstantTensor(
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value** value) {
if (tensor_idx >= context->tensors_size) {
return absl::OutOfRangeError(

View File

@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OBJECT_READER_H_
#include <cstdint>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
@ -34,14 +34,14 @@ namespace gpu {
class ObjectReader {
public:
static absl::Status ReadNonConstantTensor(
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value** value = nullptr);
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* node,
std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map = nullptr)
absl::flat_hash_map<int, Value*>* tensor_to_value,
absl::flat_hash_map<int, int>* quant_conversion_map = nullptr)
: graph_(graph),
context_(context),
node_(node),
@ -98,8 +98,8 @@ class ObjectReader {
GraphFloat32* graph_;
TfLiteContext* context_;
const TfLiteNode* node_;
std::unordered_map<int, Value*>* tensor_to_value_;
std::unordered_map<int, int>* quant_conversion_map_;
absl::flat_hash_map<int, Value*>* tensor_to_value_;
absl::flat_hash_map<int, int>* quant_conversion_map_;
};
} // namespace gpu

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include <cstdint>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -165,7 +165,7 @@ std::string ToString(enum OperationType op) {
OperationType OperationTypeFromString(const std::string& name) {
static const auto operations =
new std::unordered_map<std::string, OperationType>({
new absl::flat_hash_map<std::string, OperationType>({
{"abs", OperationType::ABS},
{"add", OperationType::ADD},
{"batch_normalization", OperationType::BATCH_NORMALIZATION},

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"
@ -22,8 +23,9 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace {
void DequantizeInput(TfLiteContext* context, int input_index,
const std::unordered_map<int, int>& quant_conversion_map) {
void DequantizeInput(
TfLiteContext* context, int input_index,
const absl::flat_hash_map<int, int>& quant_conversion_map) {
if (quant_conversion_map.find(input_index) == quant_conversion_map.end()) {
return;
}
@ -50,7 +52,7 @@ void DequantizeInput(TfLiteContext* context, int input_index,
}
void QuantizeOutput(TfLiteContext* context, int output_index,
const std::unordered_map<int, int>& quant_conversion_map) {
const absl::flat_hash_map<int, int>& quant_conversion_map) {
if (quant_conversion_map.find(output_index) == quant_conversion_map.end()) {
return;
}
@ -80,7 +82,7 @@ void QuantizeOutput(TfLiteContext* context, int output_index,
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<uint32_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
const absl::flat_hash_map<int, int>& quant_conversion_map) {
for (auto index : input_indices) {
DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
}
@ -89,7 +91,7 @@ absl::Status DequantizeInputs(
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<int64_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
const absl::flat_hash_map<int, int>& quant_conversion_map) {
for (auto index : input_indices) {
DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
}
@ -98,7 +100,7 @@ absl::Status DequantizeInputs(
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<uint32_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
const absl::flat_hash_map<int, int>& quant_conversion_map) {
for (auto index : output_indices) {
QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
}
@ -108,7 +110,7 @@ absl::Status QuantizeOutputs(
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<int64_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
const absl::flat_hash_map<int, int>& quant_conversion_map) {
for (auto index : output_indices) {
QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
}

View File

@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -32,11 +32,11 @@ namespace gpu {
// tensor and its original quantized one.
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<uint32_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map);
const absl::flat_hash_map<int, int>& quant_conversion_map);
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<int64_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map);
const absl::flat_hash_map<int, int>& quant_conversion_map);
// Quantizes output tensors post-inference, leaving float tensors intact.
// output_indices contains (fp32) inputs to be quantized, which are outputs of
@ -45,11 +45,11 @@ absl::Status DequantizeInputs(
// tensor and its original quantized one.
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<uint32_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map);
const absl::flat_hash_map<int, int>& quant_conversion_map);
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<int64_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map);
const absl::flat_hash_map<int, int>& quant_conversion_map);
} // namespace gpu
} // namespace tflite

View File

@ -151,7 +151,7 @@ TEST(DequantizeInputs, Int8) {
PopulateContext(tensors, context);
std::vector<uint32_t> input_indices = {1};
std::unordered_map<int, int> quant_conversion_map = {{1, 0}};
absl::flat_hash_map<int, int> quant_conversion_map = {{1, 0}};
auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
@ -176,7 +176,7 @@ TEST(DequantizeInputs, UInt8) {
PopulateContext(tensors, context);
std::vector<int64_t> input_indices = {1};
std::unordered_map<int, int> quant_conversion_map = {{1, 0}};
absl::flat_hash_map<int, int> quant_conversion_map = {{1, 0}};
auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
@ -199,7 +199,7 @@ TEST(QuantizeOutputs, Int8) {
PopulateContext(tensors, context);
std::vector<uint32_t> output_indices = {0};
std::unordered_map<int, int> quant_conversion_map = {{0, 1}};
absl::flat_hash_map<int, int> quant_conversion_map = {{0, 1}};
auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
@ -221,7 +221,7 @@ TEST(QuantizeOutputs, UInt8) {
PopulateContext(tensors, context);
std::vector<int64_t> output_indices = {0};
std::unordered_map<int, int> quant_conversion_map = {{0, 1}};
absl::flat_hash_map<int, int> quant_conversion_map = {{0, 1}};
auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <thread> // NOLINT(build/c++11)
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
@ -350,7 +350,7 @@ class DelegateKernel {
// Whenever quantized inference is enabled, this maps the tensor index of each
// originally quantized (8-bit) tensor to its float version added in
// model_builder - and vice versa.
std::unordered_map<int, int> quant_conversion_map_;
absl::flat_hash_map<int, int> quant_conversion_map_;
std::thread::id thread_id_prepare_; // thread id used for Prepare()
bool enforce_same_thread_ = false; // flag to enforce same thread for Invoke
};

View File

@ -29,6 +29,7 @@ cc_library(
":runtime_options",
":stats",
":variable",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"//tensorflow/lite/delegates/gpu/common:model",
@ -66,6 +67,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/gl/kernels:converter",
"//tensorflow/lite/delegates/gpu/gl/kernels:registry",
"//tensorflow/lite/delegates/gpu/gl/workgroups:default_calculator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@ -125,6 +127,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/gl/compiler:fuse_inplace",
"//tensorflow/lite/delegates/gpu/gl/compiler:shader_code",
"//tensorflow/lite/delegates/gpu/gl/compiler:shader_codegen",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:any",
],

View File

@ -19,10 +19,10 @@ limitations under the License.
#include <cstdint>
#include <deque>
#include <mutex> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
@ -46,7 +46,7 @@ namespace gpu {
namespace gl {
namespace {
using ObjectsSizes = std::unordered_map<ValueId, size_t>;
using ObjectsSizes = absl::flat_hash_map<ValueId, size_t>;
enum class InferenceContextState {
NOT_STARTED,
@ -313,7 +313,7 @@ class CompiledModelImpl
full_shaders[shader.second] = shader.first;
}
std::unordered_map<std::string, size_t> partial_shader_to_index;
absl::flat_hash_map<std::string, size_t> partial_shader_to_index;
std::vector<std::string> partial_shaders;
for (const auto& program : programs_) {
// Remove a header from a shader.
@ -366,16 +366,16 @@ class CompiledModelImpl
std::vector<GlShader> shaders_;
// Shaders are serialized in order of their indices.
std::unordered_map<std::string, size_t> shader_to_index_;
absl::flat_hash_map<std::string, size_t> shader_to_index_;
std::deque<ProgramParameters> programs_;
std::unordered_map<ValueId, size_t> object_sizes_;
absl::flat_hash_map<ValueId, size_t> object_sizes_;
CompilerStats stats_;
};
} // namespace
absl::Status Compile(const CompilationOptions& options,
const GraphFloat32& model,
const std::unordered_set<int>& tflite_graph_io,
const std::unordered_set<int>& tflite_graph_io, // NOLINT
const NodeShader& node_shader,
const WorkgroupsCalculator& workgroup_calculator,
std::unique_ptr<CompiledModel>* compiled_model) {

View File

@ -67,7 +67,7 @@ class CompiledModel {
// Turns the given model into "compiled" form that is suitable for inference.
absl::Status Compile(const CompilationOptions& options,
const GraphFloat32& model,
const std::unordered_set<int>& tflite_graph_io,
const std::unordered_set<int>& tflite_graph_io, // NOLINT
const NodeShader& node_shader,
const WorkgroupsCalculator& workgroup_calculator,
std::unique_ptr<CompiledModel>* compiled_model);

View File

@ -18,10 +18,10 @@ limitations under the License.
#include <algorithm>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
@ -542,7 +542,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
auto workgroup_calculator = NewDefaultWorkgroupsCalculator(*gpu_info_);
auto external_objects = absl::make_unique<ObjectManager>();
std::vector<GlShader> shaders;
std::unordered_map<std::string, size_t> shader_to_index;
absl::flat_hash_map<std::string, size_t> shader_to_index;
RuntimeOptions runtime_options;
auto runtime =
absl::make_unique<Runtime>(runtime_options, *gpu_info_,

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
@ -102,9 +103,10 @@ class CompilerImpl : public Compiler {
}
}
absl::Status Compile(const GraphFloat32& graph,
const std::unordered_set<int>& tflite_graph_io,
const ShaderCodeCallback& callback) final {
absl::Status Compile(
const GraphFloat32& graph,
const std::unordered_set<int>& tflite_graph_io, // NOLINT
const ShaderCodeCallback& callback) final {
// It is important to have ids in a compiled graph identical to the given
// graph.
RETURN_IF_ERROR(graph.MakeExactCopy(&compiled_graph_));
@ -158,7 +160,7 @@ class CompilerImpl : public Compiler {
}
// Prepare internal objects.
std::unordered_map<ValueId, Object> objects;
absl::flat_hash_map<ValueId, Object> objects;
for (auto value : compiled_graph_.values()) {
Object object = MakePHWC4Ref(value->id, value->tensor.shape);
object.data_type = value->tensor.type;

View File

@ -40,9 +40,10 @@ class Compiler {
// Goes over a graph and generates OpenGL shaders for the given graph.
// Callback is called for every generated shader. Callback may execute shaders
// as they come or store them elsewhere to execute later.
virtual absl::Status Compile(const GraphFloat32& graph,
const std::unordered_set<int>& tflite_graph_io,
const ShaderCodeCallback& callback) = 0;
virtual absl::Status Compile(
const GraphFloat32& graph,
const std::unordered_set<int>& tflite_graph_io, // NOLINT
const ShaderCodeCallback& callback) = 0;
};
std::unique_ptr<Compiler> NewCompiler(

View File

@ -38,6 +38,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/gl:object",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:variant",
@ -101,6 +102,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"//tensorflow/lite/delegates/gpu/gl:object",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@ -150,6 +152,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"//tensorflow/lite/delegates/gpu/gl:object",
"//tensorflow/lite/delegates/gpu/gl:variable",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
@ -164,6 +167,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:types",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:any",
"@com_google_absl//absl/types:variant",
@ -193,6 +197,7 @@ cc_library(
":preprocessor",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/gl:variable",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:variant",

View File

@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
#include <unordered_set>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"
@ -28,7 +27,7 @@ namespace gl {
absl::Status MergeCode(CompiledNodeAttributes* attr,
CompiledNodeAttributes* merged_attr) {
// Build a set of known names.
std::unordered_set<std::string> known_names;
absl::flat_hash_set<std::string> known_names;
for (const auto& parameter : merged_attr->code.parameters) {
known_names.insert(parameter.name);
}

View File

@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/types/any.h"
@ -102,7 +102,7 @@ TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
// Skip fusions which will result in duplicate inputs, e.g. diamond shapes.
{
std::unordered_set<ValueId> all_inputs;
absl::flat_hash_set<ValueId> all_inputs;
for (const auto& node_to_fuse : nodes_to_fuse) {
for (const auto& input : graph->FindInputs(node_to_fuse.first->id)) {
if (all_inputs.find(input->id) != all_inputs.end()) {

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
@ -85,7 +85,7 @@ class ObjectAccessor : public InlineRewrite {
RewriteStatus RewriteWrite(absl::string_view location,
absl::string_view value, std::string* output);
std::unordered_map<std::string, Object> name_to_object_;
absl::flat_hash_map<std::string, Object> name_to_object_;
const bool is_mali_;
const bool sampler_textures_;

View File

@ -16,10 +16,10 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
@ -86,7 +86,7 @@ class VariableRewriter : public InlineRewrite {
const std::string inline_delimiter_;
const NameFunctor name_func_;
std::unordered_map<std::string, Variable> name_to_variable_;
absl::flat_hash_map<std::string, Variable> name_to_variable_;
};
// Rewrites names of all objects according to returned values from the
@ -168,7 +168,7 @@ class ObjectRewriter : public InlineRewrite {
const std::string inline_delimiter_;
const NameFunctor name_func_;
std::unordered_map<std::string, std::pair<std::string, Object>>
absl::flat_hash_map<std::string, std::pair<std::string, Object>>
name_to_object_;
};

View File

@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_VARIABLE_ACCESSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_VARIABLE_ACCESSOR_H_
#include <string>
#include <unordered_map>
#include <set>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
@ -72,7 +72,7 @@ class VariableAccessor : public InlineRewrite {
private:
const bool inline_values_;
const bool vulkan_support_;
std::unordered_map<std::string, Variable> name_to_variable_;
absl::flat_hash_map<std::string, Variable> name_to_variable_;
std::set<std::string> shared_variables_;
std::set<std::string> uniform_parameters_;
};

View File

@ -155,7 +155,10 @@ cc_library(
name = "custom_registry",
srcs = ["custom_registry.cc"],
hdrs = ["custom_registry.h"],
deps = ["//tensorflow/lite/delegates/gpu/gl:node_shader"],
deps = [
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
@ -774,6 +777,7 @@ cc_library(
"//conditions:default": NON_TFLITE_GPU_BINARY_RELEASE_OPERATORS,
}) + [
":custom_registry",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/gl:node_shader",

View File

@ -17,15 +17,16 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
namespace tflite {
namespace gpu {
namespace gl {
void RegisterCustomOps(
std::unordered_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
shaders) {}
} // namespace gl

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
namespace tflite {
@ -29,7 +29,7 @@ namespace gl {
// Registers custom operations.
void RegisterCustomOps(
std::unordered_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
shaders_);
} // namespace gl

View File

@ -18,10 +18,10 @@ limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -139,7 +139,7 @@ class Registry : public NodeShader {
}
private:
std::unordered_map<std::string, std::vector<std::unique_ptr<NodeShader>>>
absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>
shaders_;
};

View File

@ -17,10 +17,10 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -78,7 +78,7 @@ absl::Status SingleOpModel::Invoke(const CompilationOptions& compile_options,
// Create buffers for input tensors.
{
std::unordered_map<int, uint32_t> tensor_to_id;
absl::flat_hash_map<int, uint32_t> tensor_to_id;
for (const auto* input : graph_.inputs()) {
tensor_to_id[input->tensor.ref] = input->id;
}
@ -101,9 +101,9 @@ absl::Status SingleOpModel::Invoke(const CompilationOptions& compile_options,
GpuInfo gpu_info;
RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
std::unique_ptr<CompiledModel> compiled_model;
RETURN_IF_ERROR(Compile(
compile_options, graph_, /*tflite_graph_io=*/std::unordered_set<int>(),
shader, *NewDefaultWorkgroupsCalculator(gpu_info), &compiled_model));
RETURN_IF_ERROR(Compile(compile_options, graph_, /*tflite_graph_io=*/{},
shader, *NewDefaultWorkgroupsCalculator(gpu_info),
&compiled_model));
// Get inference context.
auto command_queue = NewCommandQueue(gpu_info);

View File

@ -17,7 +17,6 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>
#include "absl/strings/str_cat.h"

View File

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TFLITE_GPU_BINARY_RELEASE
#include <memory>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
@ -62,7 +62,7 @@ class WorkgroupsCalculatorFromMetadata : public WorkgroupsCalculator {
}
private:
std::unordered_map<NodeId, uint3> workgroups_;
absl::flat_hash_map<NodeId, uint3> workgroups_;
std::unique_ptr<WorkgroupsCalculator> default_calculator_;
};

View File

@ -160,7 +160,7 @@ class Delegate {
tensors_[value->id] = {value->tensor.shape, 0};
}
std::unordered_set<int> tflite_graph_io;
std::unordered_set<int> tflite_graph_io; // NOLINT
// Prepare graph inputs.
//

View File

@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
#include <cstddef>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
@ -32,7 +32,7 @@ namespace metal {
namespace {
std::string OneInputFunctor(OperationType op_type, const std::string& value) {
const std::unordered_map<OperationType, std::string> functors{
const absl::flat_hash_map<OperationType, std::string> functors{
{OperationType::ABS, "abs($0)"},
{OperationType::SIN, "sin($0)"},
{OperationType::HARD_SWISH,
@ -62,7 +62,7 @@ std::string OneInputFunctor(OperationType op_type, const std::string& value) {
std::string TwoInputFunctor(OperationType op_type, const std::string& value0,
const std::string& value1) {
const std::unordered_map<OperationType, std::string> functors{
const absl::flat_hash_map<OperationType, std::string> functors{
{OperationType::ADD, "$0 + $1"},
{OperationType::DIV, "$0 / $1"},
{OperationType::MAXIMUM, "max($0, $1)"},

View File

@ -26,6 +26,7 @@ limitations under the License.
#include <thread>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
@ -613,7 +614,7 @@ class Delegate {
// Whenever quantized inference is enabled, this maps the tensor index of each
// originally quantized (8-bit) tensor to its float version added in
// model_builder - and vice versa.
std::unordered_map<int, int> quant_conversion_map_;
absl::flat_hash_map<int, int> quant_conversion_map_;
TFLInferenceContext* inference_context_;
// input and output buffers are passed into Metal inference engine