Use absl functions instead of str_util within tf2xla.

PiperOrigin-RevId: 210040583
This commit is contained in:
Justin Lebar 2018-08-23 19:17:40 -07:00 committed by TensorFlower Gardener
parent 52acfd0811
commit 9d51d59140
25 changed files with 86 additions and 75 deletions

View File

@ -72,6 +72,7 @@ tf_cc_test(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep "@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep
], ],
@ -100,6 +101,7 @@ cc_library(
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow { namespace tensorflow {
@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
} }
rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{I}}", strings::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type}); rewrites->push_back({"{{TYPE}}", type});
rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices}); rewrites->push_back({"{{INDICES}}", indices});
return Status::OK(); return Status::OK();
@ -159,7 +159,8 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
string RewriteWithName(const string& name, string code, string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) { const std::vector<std::pair<string, string>>& rewrites) {
absl::StrReplaceAll(rewrites, &code); absl::StrReplaceAll(rewrites, &code);
return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
return code;
} }
// Generate methods for args (inputs). // Generate methods for args (inputs).
@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
{"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name}, {"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}", {"{{DECLS_FROM_OBJ_FILE}}",
str_util::Join(metadata_result.header_variable_decls, "\n")}, absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point}, {"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim}, metadata_result.hlo_profile_printer_data_access_shim},
@ -595,7 +596,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}", {"{{BUFFER_INFOS_AS_STRING}}",
str_util::Join(buffer_infos_as_strings, ",\n")}}; absl::StrJoin(buffer_infos_as_strings, ",\n")}};
absl::StrReplaceAll(rewrites, header); absl::StrReplaceAll(rewrites, header);
return Status::OK(); return Status::OK();
} }

View File

@ -18,13 +18,13 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/strings/match.h"
#include "llvm/Support/TargetSelect.h" #include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -34,9 +34,9 @@ namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo; using ::tensorflow::cpu_function_runtime::BufferInfo;
void ExpectErrorContains(const Status& status, StringPiece str) { void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status); EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str; << "expected error: " << status.error_message() << " to contain: " << str;
} }

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
@ -34,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -55,7 +56,7 @@ const char kUsageHeader[] =
"\n"; "\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) { Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (str_util::EndsWith(fname, ".pbtxt")) { if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto); return ReadTextProto(Env::Default(), fname, proto);
} else { } else {
return ReadBinaryProto(Env::Default(), fname, proto); return ReadBinaryProto(Env::Default(), fname, proto);
@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
for (const tf2xla::Fetch& fetch : config.fetch()) { for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name()); nodes.insert(fetch.id().node_name());
} }
std::cout << str_util::Join(nodes, ","); std::cout << absl::StrJoin(nodes, ",");
return Status::OK(); return Status::OK();
} }

View File

@ -320,6 +320,7 @@ cc_library(
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -347,6 +348,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@com_google_absl//absl/strings",
], ],
) )
@ -389,6 +391,7 @@ cc_library(
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/strings",
], ],
) )
@ -488,6 +491,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/graph/tensor_id.h"
@ -153,7 +154,7 @@ class AndPredicate : public Predicate {
std::back_inserter(operands_str), std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); }); [](Predicate* pred) { return pred->ToString(); });
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
} }
Kind kind() const override { return Kind::kAnd; } Kind kind() const override { return Kind::kAnd; }
@ -182,7 +183,7 @@ class OrPredicate : public Predicate {
std::back_inserter(operands_str), std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); }); [](Predicate* pred) { return pred->ToString(); });
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
} }
Kind kind() const override { return Kind::kOr; } Kind kind() const override { return Kind::kOr; }

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {

View File

@ -44,7 +44,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/core/util/equal_graph_def.h"
@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
std::unordered_set<string> control_input_a; std::unordered_set<string> control_input_a;
std::unordered_set<string> control_input_b; std::unordered_set<string> control_input_b;
for (int i = 0; i < a.input_size(); ++i) { for (int i = 0; i < a.input_size(); ++i) {
if (str_util::StartsWith(a.input(i), "^")) { if (absl::StartsWith(a.input(i), "^")) {
if (!str_util::StartsWith(b.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) {
if (diff) { if (diff) {
*diff = strings::StrCat( *diff = strings::StrCat(
diff_preamble, " mismatch for node ", a.name(), " input ", i, diff_preamble, " mismatch for node ", a.name(), " input ", i,
@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Graph* graph = graph_ptr->get(); Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) { for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" && if (n->type_string() == "_Arg" &&
str_util::StartsWith(n->name(), "const")) { absl::StartsWith(n->name(), "const")) {
++guaranteed_consts; ++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n)); EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else { } else {
@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
Graph* graph = graph_ptr->get(); Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) { for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" && if (n->type_string() == "_Arg" &&
str_util::StartsWith(n->name(), "const")) { absl::StartsWith(n->name(), "const")) {
++guaranteed_consts; ++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n)); EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else { } else {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h"
@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -656,7 +656,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.ToString(), EXPECT_TRUE(absl::StrContains(status.ToString(),
"Edge from c to a would create a cycle.\n" "Edge from c to a would create a cycle.\n"
"+-> a\n" "+-> a\n"
"| b\n" "| b\n"

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {

View File

@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/graph/tensor_id.h"
@ -327,7 +328,7 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
std::vector<string> elements_debug_string; std::vector<string> elements_debug_string;
std::transform(resource_op_set.begin(), resource_op_set.end(), std::transform(resource_op_set.begin(), resource_op_set.end(),
std::back_inserter(elements_debug_string), ResourceOpToString); std::back_inserter(elements_debug_string), ResourceOpToString);
return strings::StrCat("{", str_util::Join(elements_debug_string, ","), "}"); return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
} }
string NodeToString(const Node& n, ResourceOpKind resource_op_kind) { string NodeToString(const Node& n, ResourceOpKind resource_op_kind) {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h"
@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -575,7 +575,7 @@ TEST(ResourceOperationSafetyAnalysisTest, HaveAllResourceOps) {
EXPECT_TRUE(unnecessary_resource_ops.empty()) EXPECT_TRUE(unnecessary_resource_ops.empty())
<< "Stale resource ops:\n" << "Stale resource ops:\n"
<< str_util::Join(unnecessary_resource_ops, "\n"); << absl::StrJoin(unnecessary_resource_ops, "\n");
} }
} // namespace } // namespace

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {

View File

@ -39,6 +39,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:ops", "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -88,6 +89,7 @@ cc_library(
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -254,6 +256,7 @@ cc_library(
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -305,6 +308,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
], ],
) )
@ -372,6 +376,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@com_google_absl//absl/strings",
], ],
) )
@ -444,6 +449,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
@ -43,7 +44,7 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
template <typename T> template <typename T>
string NodesToString(const T& nodes) { string NodesToString(const T& nodes) {
return strings::StrCat("{", return strings::StrCat("{",
str_util::Join(nodes, ",", absl::StrJoin(nodes, ",",
[](string* output, const Node* node) { [](string* output, const Node* node) {
strings::StrAppend(output, strings::StrAppend(output,
node->name()); node->name());

View File

@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for broadcasting used in gradient // XLA-specific Ops for broadcasting used in gradient
// code. // code.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]); BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(), OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument( errors::InvalidArgument(
"Incompatible shapes: [", str_util::Join(shapes[0], ","), "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
"] vs. [", str_util::Join(shapes[1], ","), "]")); "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
const int64 len = bcast.output_shape().size(); const int64 len = bcast.output_shape().size();
Tensor output(DT_INT32, TensorShape({len})); Tensor output(DT_INT32, TensorShape({len}));
@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]); BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(), OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument( errors::InvalidArgument(
"Incompatible shapes: [", str_util::Join(shapes[0], ","), "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
"] vs. [", str_util::Join(shapes[1], ","), "]")); "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 0, bcast.grad_x_reduce_idx());
Output(ctx, 1, bcast.grad_y_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx());
} }

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific reduction Ops. // XLA-specific reduction Ops.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -66,7 +67,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes));
VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "data shape: " << data_shape.DebugString();
VLOG(1) << "axes : " << str_util::Join(axes, ","); VLOG(1) << "axes : " << absl::StrJoin(axes, ",");
gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false); gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
std::vector<int64> xla_axes; std::vector<int64> xla_axes;

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Ops for softmax. // XLA-specific Ops for softmax.
#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -33,7 +33,7 @@ namespace {
class SoftmaxOp : public XlaOpKernel { class SoftmaxOp : public XlaOpKernel {
public: public:
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
log_ = str_util::StartsWith(type_string(), "Log"); log_ = absl::StartsWith(type_string(), "Log");
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/device_name_utils.h"
@ -65,7 +65,7 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
if (explicit_sharding.has_value()) { if (explicit_sharding.has_value()) {
return explicit_sharding; return explicit_sharding;
} else if (!parsed_device.has_type || !parsed_device.has_id || } else if (!parsed_device.has_type || !parsed_device.has_id ||
!str_util::StrContains(parsed_device.type, !absl::StrContains(parsed_device.type,
kDeviceSuffixReplicatedCore)) { kDeviceSuffixReplicatedCore)) {
return absl::optional<xla::OpSharding>(); return absl::optional<xla::OpSharding>();
} else { } else {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -197,8 +197,8 @@ Status RewriteAndPruneGraph(
if (!missing_feeds.empty() || !missing_fetches.empty()) { if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted( return errors::Aborted(
"Post graph-pruning", "Post graph-pruning",
", missing feeds: ", str_util::Join(missing_feeds, ", "), ", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
", missing fetches: ", str_util::Join(missing_fetches, ", ")); ", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -20,11 +20,11 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) {
} }
std::sort(types.begin(), types.end()); std::sort(types.begin(), types.end());
constraints.push_back("`" + constraint.name() + "={" + constraints.push_back("`" + constraint.name() + "={" +
str_util::Join(types, ",") + "}`"); absl::StrJoin(types, ",") + "}`");
} }
std::cout << "`" << kdef->op() << "` | " std::cout << "`" << kdef->op() << "` | "
<< str_util::Join(constraints, "<br>") << std::endl; << absl::StrJoin(constraints, "<br>") << std::endl;
} }
std::cout << "\nTo regenerate this table, run:\n\n```shell\n" std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
{"device", &device, {"device", &device,
"Name of the compilation device for which to print supported ops, " "Name of the compilation device for which to print supported ops, "
"one of: " + "one of: " +
str_util::Join(device_names, ",")}, absl::StrJoin(device_names, ",")},
}; };
string usage = Flags::Usage(argv[0], flag_list); string usage = Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/function_ops.h"
@ -25,16 +26,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
void ExpectErrorContains(const Status& status, StringPiece str) { void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status); EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str; << "expected error: " << status.error_message() << " to contain: " << str;
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/function_ops.h"
@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
@ -309,10 +309,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
std::move(graph), args, &result); std::move(graph), args, &result);
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_TRUE( EXPECT_TRUE(
str_util::StrContains(status.error_message(), "depends on a parameter")) absl::StrContains(status.error_message(), "depends on a parameter"))
<< status.error_message(); << status.error_message();
EXPECT_TRUE( EXPECT_TRUE(
str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) absl::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
<< status.error_message(); << status.error_message();
} }
@ -727,8 +727,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
/*args=*/{}, &result); /*args=*/{}, &result);
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
"is not defined."))
<< status.error_message(); << status.error_message();
} }
@ -807,12 +806,10 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
// Flib lookup failure. // Flib lookup failure.
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
"is not defined."))
<< status.error_message(); << status.error_message();
// Local flib lookup failure. // Local flib lookup failure.
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
"Attr T is not found"))
<< status.error_message(); << status.error_message();
} }
@ -1078,9 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result); std::move(graph), args, &result);
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message(); << status.error_message();
EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
<< status.error_message(); << status.error_message();
} }
@ -1103,10 +1100,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
std::move(graph), args, &result); std::move(graph), args, &result);
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.error_message(), EXPECT_TRUE(absl::StrContains(status.error_message(),
"is not in the list of allowed values")) "is not in the list of allowed values"))
<< status.error_message(); << status.error_message();
EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
<< status.error_message(); << status.error_message();
} }
@ -1130,7 +1127,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::move(graph_copy), args, &result); std::move(graph_copy), args, &result);
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_TRUE( EXPECT_TRUE(
str_util::StrContains(status.error_message(), absl::StrContains(status.error_message(),
"The following nodes are unreachable " "The following nodes are unreachable "
"from the source in the graph: {{node NoOp}}")) "from the source in the graph: {{node NoOp}}"))
<< status.error_message(); << status.error_message();