Optimize calls to std::string::find() and friends for a single char.

The character literal overload is more efficient.

PiperOrigin-RevId: 348126864
Change-Id: I12485209607a957ecb17a4ba1087473bb0c4dd06
This commit is contained in:
Chris Kennelly 2020-12-17 17:59:05 -08:00 committed by TensorFlower Gardener
parent 343ee97148
commit 4692525ffa
24 changed files with 32 additions and 32 deletions

View File

@ -81,7 +81,7 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
size_t bucket_end = fname.find('/', scheme_end + 1);
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");

View File

@ -38,7 +38,7 @@ void ParseHadoopPath(const std::string& fname, std::string* scheme,
size_t scheme_end = fname.find("://") + 2;
// We don't want `://` in scheme.
*scheme = fname.substr(0, scheme_end - 2);
size_t nn_end = fname.find("/", scheme_end + 1);
size_t nn_end = fname.find('/', scheme_end + 1);
if (nn_end == std::string::npos) {
*namenode = fname.substr(scheme_end + 1);
*path = "";

View File

@ -60,7 +60,7 @@ string GetPath(const string& dot_h_fname) {
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);
pos = result.find("/");
pos = result.find('/');
if (pos != string::npos) {
result = result.substr(pos + 1);
}

View File

@ -40,7 +40,7 @@ class XlaEinsumOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp lhs = ctx->Input(0);
if (equation_.find(",") == equation_.npos) {
if (equation_.find(',') == equation_.npos) {
ctx->SetOutput(0, xla::Einsum(lhs, equation_));
} else {
xla::XlaOp rhs = ctx->Input(1);
@ -68,7 +68,7 @@ class EinsumOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
ctx->InputList("inputs", &input_handles, &input_shapes));
if (equation_.find(",") == equation_.npos) {
if (equation_.find(',') == equation_.npos) {
OP_REQUIRES(
ctx, input_handles.size() == 1,
errors::InvalidArgument(

View File

@ -327,7 +327,7 @@ void StepStatsCollector::BuildCostModel(
for (const auto& node_stats : dev_stats.hardware_stats->node_stats()) {
string node_name = node_stats.node_name();
// Remove the part of op name (e.g. :Conv2D) in the end of a node name.
size_t pos = node_name.find_first_of(":");
size_t pos = node_name.find_first_of(':');
if (pos != std::string::npos) {
node_name = node_name.substr(0, pos);
}

View File

@ -378,9 +378,9 @@ Status DebugIO::PublishDebugMetadata(
// Determine the path (if any) in the grpc:// URL, and add it as a field
// of the JSON string.
const string address = url.substr(strlen(DebugIO::kFileURLScheme));
const string path = address.find("/") == string::npos
const string path = address.find('/') == string::npos
? ""
: address.substr(address.find("/"));
: address.substr(address.find('/'));
grpc_event.set_wall_time(event.wall_time());
LogMessage* log_message_grpc = grpc_event.mutable_log_message();
log_message_grpc->set_message(

View File

@ -29,7 +29,7 @@ namespace {
template <typename T>
void OutputToLog(const T& proto) {
string type_name = proto.GetTypeName();
const size_t index = type_name.find_last_of(".");
const size_t index = type_name.find_last_of('.');
if (index != string::npos) type_name = type_name.substr(index + 1);
LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { "
<< proto.ShortDebugString() << " }";

View File

@ -177,7 +177,7 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
BIO_free_all(bio);
// Now check the content of the header and the claim.
int dot = header_dot_claim.find_last_of(".");
int dot = header_dot_claim.find_last_of('.');
string header_encoded = header_dot_claim.substr(0, dot);
string claim_encoded = header_dot_claim.substr(dot + 1);

View File

@ -114,7 +114,7 @@ class FunctionTable {
func_pb->set_id(function_table_.size());
string file_base(io::Basename(file_path));
file_base = file_base.substr(0, file_base.find_last_of("."));
file_base = file_base.substr(0, file_base.find_last_of('.'));
func_pb->set_name(
string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
func_pb->set_filename(string_table_->GetIndex(file_path));

View File

@ -48,13 +48,13 @@ void TFScope::AddNode(TFGraphNode* node) {
nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
}
auto last_slash = name.find_last_of("/");
auto last_slash = name.find_last_of('/');
while (last_slash != name.npos) {
name = name.substr(0, last_slash);
if (nodes_map_.find(name) == nodes_map_.end()) {
CHECK(CreateParentNode(name));
}
last_slash = name.find_last_of("/");
last_slash = name.find_last_of('/');
}
}
@ -65,7 +65,7 @@ void TFScope::Build() {
// Found roots, which are nodes without "/".
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
ScopeNode* node = it->second.get();
auto last_slash = node->name().find_last_of("/");
auto last_slash = node->name().find_last_of('/');
if (last_slash == string::npos) {
roots.push_back(node);
} else {

View File

@ -212,7 +212,7 @@ void TFStats::AddGraph(std::unique_ptr<GraphDef> graph) {
int output_idx = 0;
// input name format can be: "^node:src_output"
// if not :src_output, then it's the first one (further verify?)
auto prefix_pos = node_input.find(":");
auto prefix_pos = node_input.find(':');
if (prefix_pos != node_input.npos) {
std::vector<string> input_parts = absl::StrSplit(node_input, ':');
DCHECK(input_parts.size() == 2)
@ -287,7 +287,7 @@ void TFStats::AddRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
for (const NodeExecStats& node_stat : dev_stat.node_stats()) {
string name = node_stat.node_name();
// Sometimes the node_name is suffixed with unnecessary information.
auto split_pos = node_stat.node_name().find(":");
auto split_pos = node_stat.node_name().find(':');
if (split_pos != node_stat.node_name().npos) {
name = node_stat.node_name().substr(0, split_pos);
}

View File

@ -45,7 +45,7 @@ namespace tensorflow {
namespace tfprof {
void completion(const char* buf, linenoiseCompletions* lc) {
string buf_str = buf;
if (buf_str.find(" ") == buf_str.npos) {
if (buf_str.find(' ') == buf_str.npos) {
for (const char* opt : kCmds) {
if (string(opt).find(buf_str) == 0) {
linenoiseAddCompletion(lc, opt);

View File

@ -43,7 +43,7 @@ tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
std::set<string> output_types(kOutput,
kOutput + sizeof(kOutput) / sizeof(*kOutput));
auto opt_split = output_opt.find(":");
auto opt_split = output_opt.find(':');
std::vector<string> kv_split;
if (opt_split == output_opt.npos) {
if (output_types.find(output_opt) == output_types.end()) {

View File

@ -121,7 +121,7 @@ std::string OpType(const DeviceStepStats& ds, const NodeExecStats& ns) {
std::string::size_type start = label.find(sep);
if (start == std::string::npos) return "<>";
start += sep.size();
std::string::size_type end = label.find("(", start);
std::string::size_type end = label.find('(', start);
if (end == std::string::npos) return "<>";
return label.substr(start, end - start);
}

View File

@ -346,7 +346,7 @@ std::string ProfilingInfo::GetDetailedReport() const {
result += " " + dispatch.label + " - " +
std::to_string(absl::ToDoubleMilliseconds(dispatch.duration)) +
" ms\n";
auto name = dispatch.label.substr(0, dispatch.label.find(" "));
auto name = dispatch.label.substr(0, dispatch.label.find(' '));
if (statistics.find(name) != statistics.end()) {
statistics[name].count++;
statistics[name].total_time +=

View File

@ -255,7 +255,7 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir,
size_t pos = 0;
int added = 0;
while (true) {
size_t end_pos = manifest.find("\n", pos);
size_t end_pos = manifest.find('\n', pos);
if (end_pos == string::npos) break;
string filename = manifest.substr(pos, end_pos - pos);
test_paths->push_back(dir + "/" + filename);
@ -294,13 +294,13 @@ TEST_P(OpsTest, RunZipTests) {
string test_path_and_label = GetParam();
string test_path = test_path_and_label;
string label = test_path_and_label;
size_t end_pos = test_path_and_label.find(" ");
size_t end_pos = test_path_and_label.find(' ');
if (end_pos != string::npos) {
test_path = test_path_and_label.substr(0, end_pos);
label = test_path_and_label.substr(end_pos + 1);
}
string tflite_test_case = test_path + "_tests.txt";
string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
string tflite_dir = test_path.substr(0, test_path.find_last_of('/'));
string test_name = label.substr(label.find_last_of('/'));
std::ifstream tflite_stream(tflite_test_case);

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/lite/testing/tflite_driver.h"
std::string dirname(const std::string& s) {
return s.substr(0, s.find_last_of("/"));
return s.substr(0, s.find_last_of('/'));
}
bool Interpret(const char* examples_filename, bool use_nnapi) {

View File

@ -214,13 +214,13 @@ std::string SanitizeErrorMessage(const std::string& error_message) {
size_t pos = error_message.find(s1);
if (pos != std::string::npos) {
// Find the terminate point for flex op list.
auto end = error_message.find(".", pos);
auto end = error_message.find('.', pos);
pruned_message.append(error_message.substr(pos, end - pos + 1));
}
pos = error_message.find(s2);
if (pos != std::string::npos) {
// Find the terminate point for custom op list.
auto end = error_message.find(".", pos);
auto end = error_message.find('.', pos);
pruned_message.append(error_message.substr(pos, end - pos + 1));
}
return pruned_message;

View File

@ -265,7 +265,7 @@ std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
// Assuming the node name has a pattern like:
// "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
// CELLNAME as the cluster name.
size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1;
size_t cell_pos = node.name().rfind('/', weights_pos - 2) + 1;
std::string cell_name =
node.name().substr(cell_pos, weights_pos - cell_pos - 1);
cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);

View File

@ -1074,7 +1074,7 @@ void CheckEachArray(const Model& model) {
// Check name. Either "name_with_suffix_8", "name_with_port:3", but not
// "name_with_both:3_8".
const std::string& name = array_entry.first;
auto colon_pos = name.find_first_of(":");
auto colon_pos = name.find_first_of(':');
if (colon_pos != std::string::npos) {
CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
std::string::npos)

View File

@ -75,7 +75,7 @@ TfLiteStatus GetSortedFileNames(
while ((ent = readdir(dir)) != nullptr) {
if (ent->d_type == DT_DIR) continue;
std::string filename(std::string(ent->d_name));
size_t lastdot = filename.find_last_of(".");
size_t lastdot = filename.find_last_of('.');
std::string ext = lastdot != std::string::npos ? filename.substr(lastdot)
: std::string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);

View File

@ -239,7 +239,7 @@ PYBIND11_MODULE(_pywrap_file_io, m) {
py::gil_scoped_release release;
auto* env = tensorflow::Env::Default();
std::unique_ptr<WritableFile> self;
const auto status = mode.find("a") == std::string::npos
const auto status = mode.find('a') == std::string::npos
? env->NewWritableFile(filename, &self)
: env->NewAppendableFile(filename, &self);
py::gil_scoped_acquire acquire;

View File

@ -295,7 +295,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
std::string version_and_rest = driver_version_file_contents.substr(
offset + strlen(kDriverFilePrelude), std::string::npos);
size_t space_index = version_and_rest.find(" ");
size_t space_index = version_and_rest.find(' ');
auto kernel_version = version_and_rest.substr(0, space_index);
// TODO(b/22689637): Eliminate the explicit namespace if possible.
auto stripped_kernel_version = absl::StripSuffix(kernel_version, ".ld64");

View File

@ -105,7 +105,7 @@ int MainImpl(int argc, char** argv) {
const tensorflow::protobuf::FileDescriptor* fd =
importer.Import(proto_path);
const int index = proto_path.find_last_of(".");
const int index = proto_path.find_last_of('.');
string proto_path_no_suffix = proto_path.substr(0, index);
proto_path_no_suffix =