Use absl functions instead of str_util within tf2xla.
PiperOrigin-RevId: 210040583
This commit is contained in:
parent
52acfd0811
commit
9d51d59140
@ -72,6 +72,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support", # fixdeps: keep
|
||||
"@llvm//:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
@ -100,6 +101,7 @@ cc_library(
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
|
||||
@ -30,7 +31,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
}
|
||||
rewrites->push_back({"{{I}}", strings::StrCat(i)});
|
||||
rewrites->push_back({"{{TYPE}}", type});
|
||||
rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||
rewrites->push_back({"{{INDICES}}", indices});
|
||||
return Status::OK();
|
||||
@ -159,7 +159,8 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
string RewriteWithName(const string& name, string code,
|
||||
const std::vector<std::pair<string, string>>& rewrites) {
|
||||
absl::StrReplaceAll(rewrites, &code);
|
||||
return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
|
||||
absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
|
||||
return code;
|
||||
}
|
||||
|
||||
// Generate methods for args (inputs).
|
||||
@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
|
||||
{"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
|
||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
{"{{DECLS_FROM_OBJ_FILE}}",
|
||||
str_util::Join(metadata_result.header_variable_decls, "\n")},
|
||||
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
||||
{"{{ENTRY}}", compile_result.entry_point},
|
||||
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
||||
metadata_result.hlo_profile_printer_data_access_shim},
|
||||
@ -595,7 +596,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
|
||||
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
|
||||
{"{{BUFFER_INFOS_AS_STRING}}",
|
||||
str_util::Join(buffer_infos_as_strings, ",\n")}};
|
||||
absl::StrJoin(buffer_infos_as_strings, ",\n")}};
|
||||
absl::StrReplaceAll(rewrites, header);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -18,13 +18,13 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -34,9 +34,9 @@ namespace {
|
||||
|
||||
using ::tensorflow::cpu_function_runtime::BufferInfo;
|
||||
|
||||
void ExpectErrorContains(const Status& status, StringPiece str) {
|
||||
void ExpectErrorContains(const Status& status, absl::string_view str) {
|
||||
EXPECT_NE(Status::OK(), status);
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), str))
|
||||
<< "expected error: " << status.error_message() << " to contain: " << str;
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/compile.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
@ -34,7 +36,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -55,7 +56,7 @@ const char kUsageHeader[] =
|
||||
"\n";
|
||||
|
||||
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
||||
if (str_util::EndsWith(fname, ".pbtxt")) {
|
||||
if (absl::EndsWith(fname, ".pbtxt")) {
|
||||
return ReadTextProto(Env::Default(), fname, proto);
|
||||
} else {
|
||||
return ReadBinaryProto(Env::Default(), fname, proto);
|
||||
@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
|
||||
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
||||
nodes.insert(fetch.id().node_name());
|
||||
}
|
||||
std::cout << str_util::Join(nodes, ",");
|
||||
std::cout << absl::StrJoin(nodes, ",");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -320,6 +320,7 @@ cc_library(
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -347,6 +348,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -389,6 +391,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:bounds_check",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -488,6 +491,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
@ -153,7 +154,7 @@ class AndPredicate : public Predicate {
|
||||
std::back_inserter(operands_str),
|
||||
[](Predicate* pred) { return pred->ToString(); });
|
||||
|
||||
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
|
||||
return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kAnd; }
|
||||
@ -182,7 +183,7 @@ class OrPredicate : public Predicate {
|
||||
std::back_inserter(operands_str),
|
||||
[](Predicate* pred) { return pred->ToString(); });
|
||||
|
||||
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
|
||||
return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kOr; }
|
||||
|
@ -32,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -44,7 +44,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
@ -25,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
|
||||
@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
|
||||
std::unordered_set<string> control_input_a;
|
||||
std::unordered_set<string> control_input_b;
|
||||
for (int i = 0; i < a.input_size(); ++i) {
|
||||
if (str_util::StartsWith(a.input(i), "^")) {
|
||||
if (!str_util::StartsWith(b.input(i), "^")) {
|
||||
if (absl::StartsWith(a.input(i), "^")) {
|
||||
if (!absl::StartsWith(b.input(i), "^")) {
|
||||
if (diff) {
|
||||
*diff = strings::StrCat(
|
||||
diff_preamble, " mismatch for node ", a.name(), " input ", i,
|
||||
@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
|
||||
Graph* graph = graph_ptr->get();
|
||||
for (const Node* n : graph->nodes()) {
|
||||
if (n->type_string() == "_Arg" &&
|
||||
str_util::StartsWith(n->name(), "const")) {
|
||||
absl::StartsWith(n->name(), "const")) {
|
||||
++guaranteed_consts;
|
||||
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
|
||||
} else {
|
||||
@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
|
||||
Graph* graph = graph_ptr->get();
|
||||
for (const Node* n : graph->nodes()) {
|
||||
if (n->type_string() == "_Arg" &&
|
||||
str_util::StartsWith(n->name(), "const")) {
|
||||
absl::StartsWith(n->name(), "const")) {
|
||||
++guaranteed_consts;
|
||||
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
|
||||
} else {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
@ -32,7 +33,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -656,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
|
||||
|
||||
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(status.ToString(),
|
||||
"Edge from c to a would create a cycle.\n"
|
||||
"+-> a\n"
|
||||
"| b\n"
|
||||
"+-- c\n"));
|
||||
EXPECT_TRUE(absl::StrContains(status.ToString(),
|
||||
"Edge from c to a would create a cycle.\n"
|
||||
"+-> a\n"
|
||||
"| b\n"
|
||||
"+-- c\n"));
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, Retval) {
|
||||
|
@ -32,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -82,6 +82,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
@ -327,7 +328,7 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
|
||||
std::vector<string> elements_debug_string;
|
||||
std::transform(resource_op_set.begin(), resource_op_set.end(),
|
||||
std::back_inserter(elements_debug_string), ResourceOpToString);
|
||||
return strings::StrCat("{", str_util::Join(elements_debug_string, ","), "}");
|
||||
return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
|
||||
}
|
||||
|
||||
string NodeToString(const Node& n, ResourceOpKind resource_op_kind) {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
@ -33,7 +34,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -575,7 +575,7 @@ TEST(ResourceOperationSafetyAnalysisTest, HaveAllResourceOps) {
|
||||
|
||||
EXPECT_TRUE(unnecessary_resource_ops.empty())
|
||||
<< "Stale resource ops:\n"
|
||||
<< str_util::Join(unnecessary_resource_ops, "\n");
|
||||
<< absl::StrJoin(unnecessary_resource_ops, "\n");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -39,6 +39,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -88,6 +89,7 @@ cc_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -254,6 +256,7 @@ cc_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -305,6 +308,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -372,6 +376,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -444,6 +449,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
@ -43,11 +44,11 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
|
||||
template <typename T>
|
||||
string NodesToString(const T& nodes) {
|
||||
return strings::StrCat("{",
|
||||
str_util::Join(nodes, ",",
|
||||
[](string* output, const Node* node) {
|
||||
strings::StrAppend(output,
|
||||
node->name());
|
||||
}),
|
||||
absl::StrJoin(nodes, ",",
|
||||
[](string* output, const Node* node) {
|
||||
strings::StrAppend(output,
|
||||
node->name());
|
||||
}),
|
||||
"}");
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// XLA-specific Ops for broadcasting used in gradient
|
||||
// code.
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel {
|
||||
BCast bcast(shapes[0], shapes[1]);
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"Incompatible shapes: [", str_util::Join(shapes[0], ","),
|
||||
"] vs. [", str_util::Join(shapes[1], ","), "]"));
|
||||
"Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
|
||||
"] vs. [", absl::StrJoin(shapes[1], ","), "]"));
|
||||
|
||||
const int64 len = bcast.output_shape().size();
|
||||
Tensor output(DT_INT32, TensorShape({len}));
|
||||
@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel {
|
||||
BCast bcast(shapes[0], shapes[1]);
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"Incompatible shapes: [", str_util::Join(shapes[0], ","),
|
||||
"] vs. [", str_util::Join(shapes[1], ","), "]"));
|
||||
"Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
|
||||
"] vs. [", absl::StrJoin(shapes[1], ","), "]"));
|
||||
Output(ctx, 0, bcast.grad_x_reduce_idx());
|
||||
Output(ctx, 1, bcast.grad_y_reduce_idx());
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA-specific reduction Ops.
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -66,7 +67,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes));
|
||||
|
||||
VLOG(1) << "data shape: " << data_shape.DebugString();
|
||||
VLOG(1) << "axes : " << str_util::Join(axes, ",");
|
||||
VLOG(1) << "axes : " << absl::StrJoin(axes, ",");
|
||||
|
||||
gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
|
||||
std::vector<int64> xla_axes;
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA-specific Ops for softmax.
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
@ -25,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -33,7 +33,7 @@ namespace {
|
||||
class SoftmaxOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
log_ = str_util::StartsWith(type_string(), "Log");
|
||||
log_ = absl::StartsWith(type_string(), "Log");
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
|
@ -14,9 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
@ -65,8 +65,8 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||
if (explicit_sharding.has_value()) {
|
||||
return explicit_sharding;
|
||||
} else if (!parsed_device.has_type || !parsed_device.has_id ||
|
||||
!str_util::StrContains(parsed_device.type,
|
||||
kDeviceSuffixReplicatedCore)) {
|
||||
!absl::StrContains(parsed_device.type,
|
||||
kDeviceSuffixReplicatedCore)) {
|
||||
return absl::optional<xla::OpSharding>();
|
||||
} else {
|
||||
const int core = parsed_device.id;
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
@ -40,7 +41,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -197,8 +197,8 @@ Status RewriteAndPruneGraph(
|
||||
if (!missing_feeds.empty() || !missing_fetches.empty()) {
|
||||
return errors::Aborted(
|
||||
"Post graph-pruning",
|
||||
", missing feeds: ", str_util::Join(missing_feeds, ", "),
|
||||
", missing fetches: ", str_util::Join(missing_fetches, ", "));
|
||||
", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
|
||||
", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -20,11 +20,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) {
|
||||
}
|
||||
std::sort(types.begin(), types.end());
|
||||
constraints.push_back("`" + constraint.name() + "={" +
|
||||
str_util::Join(types, ",") + "}`");
|
||||
absl::StrJoin(types, ",") + "}`");
|
||||
}
|
||||
std::cout << "`" << kdef->op() << "` | "
|
||||
<< str_util::Join(constraints, "<br>") << std::endl;
|
||||
<< absl::StrJoin(constraints, "<br>") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
|
||||
@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
|
||||
{"device", &device,
|
||||
"Name of the compilation device for which to print supported ops, "
|
||||
"one of: " +
|
||||
str_util::Join(device_names, ",")},
|
||||
absl::StrJoin(device_names, ",")},
|
||||
};
|
||||
string usage = Flags::Usage(argv[0], flag_list);
|
||||
bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
@ -25,16 +26,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void ExpectErrorContains(const Status& status, StringPiece str) {
|
||||
void ExpectErrorContains(const Status& status, absl::string_view str) {
|
||||
EXPECT_NE(Status::OK(), status);
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), str))
|
||||
<< "expected error: " << status.error_message() << " to contain: " << str;
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
@ -38,7 +39,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
@ -309,10 +309,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
|
||||
std::move(graph), args, &result);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_TRUE(
|
||||
str_util::StrContains(status.error_message(), "depends on a parameter"))
|
||||
absl::StrContains(status.error_message(), "depends on a parameter"))
|
||||
<< status.error_message();
|
||||
EXPECT_TRUE(
|
||||
str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
|
||||
absl::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
@ -727,8 +727,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
|
||||
compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
|
||||
/*args=*/{}, &result);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
|
||||
"is not defined."))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
@ -807,12 +806,10 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
|
||||
|
||||
ASSERT_FALSE(status.ok());
|
||||
// Flib lookup failure.
|
||||
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
|
||||
"is not defined."))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
|
||||
<< status.error_message();
|
||||
// Local flib lookup failure.
|
||||
EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
|
||||
"Attr T is not found"))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
@ -1078,9 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
|
||||
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
|
||||
std::move(graph), args, &result);
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
|
||||
<< status.error_message();
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}"))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
@ -1103,10 +1100,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
|
||||
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
|
||||
std::move(graph), args, &result);
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
||||
"is not in the list of allowed values"))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||
"is not in the list of allowed values"))
|
||||
<< status.error_message();
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}"))
|
||||
EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
@ -1130,9 +1127,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
|
||||
std::move(graph_copy), args, &result);
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_TRUE(
|
||||
str_util::StrContains(status.error_message(),
|
||||
"The following nodes are unreachable "
|
||||
"from the source in the graph: {{node NoOp}}"))
|
||||
absl::StrContains(status.error_message(),
|
||||
"The following nodes are unreachable "
|
||||
"from the source in the graph: {{node NoOp}}"))
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user