From 9d51d5914084f4cd65ea18c4da45dfb64f2945b6 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 23 Aug 2018 19:17:40 -0700 Subject: [PATCH] Use absl functions instead of str_util within tf2xla. PiperOrigin-RevId: 210040583 --- tensorflow/compiler/aot/BUILD | 2 ++ tensorflow/compiler/aot/codegen.cc | 13 ++++---- tensorflow/compiler/aot/codegen_test.cc | 6 ++-- .../compiler/aot/tests/tfcompile_test.cc | 1 - tensorflow/compiler/aot/tfcompile_main.cc | 7 +++-- tensorflow/compiler/jit/BUILD | 4 +++ tensorflow/compiler/jit/deadness_analysis.cc | 5 +-- .../compiler/jit/deadness_analysis_test.cc | 1 - .../jit/encapsulate_subgraphs_pass.cc | 1 - .../jit/encapsulate_subgraphs_pass_test.cc | 10 +++--- .../jit/mark_for_compilation_pass_test.cc | 12 +++---- .../jit/partially_decluster_pass_test.cc | 1 - .../jit/resource_operation_safety_analysis.cc | 3 +- ...resource_operation_safety_analysis_test.cc | 4 +-- .../compiler/jit/xla_cluster_util_test.cc | 1 - tensorflow/compiler/tf2xla/BUILD | 6 ++++ .../tf2xla/functionalize_control_flow_util.h | 11 ++++--- .../compiler/tf2xla/kernels/bcast_ops.cc | 9 +++--- .../tf2xla/kernels/reduction_ops_common.cc | 3 +- .../compiler/tf2xla/kernels/softmax_op.cc | 4 +-- tensorflow/compiler/tf2xla/sharding_util.cc | 6 ++-- tensorflow/compiler/tf2xla/tf2xla.cc | 6 ++-- .../compiler/tf2xla/tf2xla_supported_ops.cc | 8 ++--- .../compiler/tf2xla/tf2xla_util_test.cc | 6 ++-- .../compiler/tf2xla/xla_compiler_test.cc | 31 +++++++++---------- 25 files changed, 86 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 58dc91f74dd..59b961cdd9d 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -72,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -100,6 +101,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index d9edbb28496..e77a8fecf09 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, } rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -159,7 +159,8 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { absl::StrReplaceAll(rewrites, &code); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). @@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, - {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -595,7 +596,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", - str_util::Join(buffer_infos_as_strings, ",\n")}}; + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; absl::StrReplaceAll(rewrites, header); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 60d59ae996e..e3a53edb736 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -34,9 +34,9 @@ namespace { using ::tensorflow::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index a2b824da172..dd2b151098f 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7b..f3c44e9dda8 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 07a22018179..1c09ee8fb1a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -320,6 +320,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -347,6 +348,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -389,6 +391,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", ], ) @@ -488,6 +491,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0ca0f949dcd..fe28502f69d 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -153,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -182,7 +183,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index cc9f1023985..28a56044d5e 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 4dc035c9cac..d0e74d0bee9 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c0543a00792..b3600fc48b9 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, std::unordered_set control_input_a; std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (str_util::StartsWith(a.input(i), "^")) { - if (!str_util::StartsWith(b.input(i), "^")) { + if (absl::StartsWith(a.input(i), "^")) { + if (!absl::StartsWith(b.input(i), "^")) { if (diff) { *diff = strings::StrCat( diff_preamble, " mismatch for node ", a.name(), " input ", i, @@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index dd789e4451c..807ab51fd3c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -656,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.ToString(), - "Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(absl::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 08a956e4c64..f61a955c222 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index c62d7230bc3..f18d988a5ca 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -327,7 +328,7 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { std::vector elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); - return strings::StrCat("{", str_util::Join(elements_debug_string, ","), "}"); + return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } string NodeToString(const Node& n, ResourceOpKind resource_op_kind) { diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index c52c9308c99..c774bdf5ffc 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -575,7 +575,7 @@ TEST(ResourceOperationSafetyAnalysisTest, HaveAllResourceOps) { EXPECT_TRUE(unnecessary_resource_ops.empty()) << "Stale resource ops:\n" - << str_util::Join(unnecessary_resource_ops, "\n"); + << absl::StrJoin(unnecessary_resource_ops, "\n"); } } // namespace diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 2cb351e1ecd..65bbf3efe85 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 30398a5e379..b67e717f828 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -88,6 +89,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -254,6 +256,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -305,6 +308,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -372,6 +376,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -444,6 +449,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index a0544b69e9e..61940e3586c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph.h" @@ -43,11 +44,11 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); template string NodesToString(const T& nodes) { return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), "}"); } diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ba3b1c9dab7..2e383b14735 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); const int64 len = bcast.output_shape().size(); Tensor output(DT_INT32, TensorShape({len})); @@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx()); } diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a71b8ca362..598248563bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -66,7 +67,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << str_util::Join(axes, ","); + VLOG(1) << "axes : " << absl::StrJoin(axes, ","); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 025ba827410..d6bd927135c 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -33,7 +33,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = str_util::StartsWith(type_string(), "Log"); + log_ = absl::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 66835e69b23..2d7eb8b915b 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +65,8 @@ xla::StatusOr> ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !str_util::StrContains(parsed_device.type, - kDeviceSuffixReplicatedCore)) { + !absl::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return absl::optional(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 48568c825b7..f34af2d67de 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -197,8 +197,8 @@ Status RewriteAndPruneGraph( if (!missing_feeds.empty() || !missing_fetches.empty()) { return errors::Aborted( "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), + ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 7aca889a266..567d212b5ee 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } std::sort(types.begin(), types.end()); constraints.push_back("`" + constraint.name() + "={" + - str_util::Join(types, ",") + "}`"); + absl::StrJoin(types, ",") + "}`"); } std::cout << "`" << kdef->op() << "` | " - << str_util::Join(constraints, "
") << std::endl; + << absl::StrJoin(constraints, "
") << std::endl; } std::cout << "\nTo regenerate this table, run:\n\n```shell\n" @@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + - str_util::Join(device_names, ",")}, + absl::StrJoin(device_names, ",")}, }; string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ae51446204b..2b1f724dc7b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,16 +26,15 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7227df96499..6e5a0198f6b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -309,10 +309,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "depends on a parameter")) + absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -727,8 +727,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); } @@ -807,12 +806,10 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "Attr T is not found")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) << status.error_message(); } @@ -1078,9 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1103,10 +1100,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "is not in the list of allowed values")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "is not in the list of allowed values")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) << status.error_message(); } @@ -1130,9 +1127,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); }