diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 33ebba07b9f..c83bd81705b 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/shape_inference.h" @@ -236,7 +237,8 @@ class ShapeRefiner { GraphRunner graph_runner_; // Stores a map from a node to its ExtendedInferenceContext. - std::unordered_map> + absl::flat_hash_map, + hash> node_to_context_; // Holds a cache from 'tensor name' to the tensor that is @@ -257,9 +259,10 @@ class ShapeRefiner { // shape inference. const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr; - // Cache the graph corresponding to each functin definition for which shapes + // Cache the graph corresponding to each function definition for which shapes // are refined. - std::unordered_map> + absl::flat_hash_map, + hash> functions_; TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);