[XLA] Make tensorflow/compiler use absl::{StrCat,string_view,InlinedVector} consistently
StringPiece is an alias for absl::string_view, and InlinedVector is aliased to absl::InlinedVector, so those substitutions are type-identical. strings::StrCat is compatible with absl::StrCat, so swapping it out is safe as well.

PiperOrigin-RevId: 211691840
parent c9c8de4402
commit 11caab3c13
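To make the safety argument concrete, here is a minimal sketch (not part of the patch; a hypothetical illustration assuming only the aliases and compatibility the message describes) of why the rewrite is purely mechanical:

    #include <string>
    #include <type_traits>

    #include "absl/container/inlined_vector.h"
    #include "absl/strings/str_cat.h"
    #include "absl/strings/string_view.h"
    #include "tensorflow/core/lib/core/stringpiece.h"
    #include "tensorflow/core/lib/gtl/inlined_vector.h"
    #include "tensorflow/core/lib/strings/strcat.h"

    // Sketch, not from the patch: both spellings name the same type, so call
    // sites can be rewritten one at a time without changing behavior.
    static_assert(
        std::is_same<tensorflow::StringPiece, absl::string_view>::value,
        "StringPiece is an alias for absl::string_view");
    static_assert(
        std::is_same<tensorflow::gtl::InlinedVector<int, 4>,
                     absl::InlinedVector<int, 4>>::value,
        "gtl::InlinedVector is aliased to absl::InlinedVector");

    // StrCat is compatible rather than aliased: the two overload sets accept
    // the same basic argument kinds and produce the same string for them.
    inline std::string StrCatEquivalenceExample() {
      std::string a = tensorflow::strings::StrCat("cluster_", 7);
      std::string b = absl::StrCat("cluster_", 7);
      return a == b ? a : "unexpected";  // a == b holds for these arguments.
    }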
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_replace.h"
 #include "absl/types/span.h"
@@ -31,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 namespace tfcompile {
@@ -135,12 +135,12 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
     indices = "[0]";
   } else {
     for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
-      dim_vars.push_back(strings::StrCat("size_t dim", dim));
-      dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
-      indices += strings::StrCat("[dim", dim, "]");
+      dim_vars.push_back(absl::StrCat("size_t dim", dim));
+      dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
+      indices += absl::StrCat("[dim", dim, "]");
     }
   }
-  rewrites->push_back({"{{I}}", strings::StrCat(i)});
+  rewrites->push_back({"{{I}}", absl::StrCat(i)});
   rewrites->push_back({"{{TYPE}}", type});
   rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
   rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
@@ -194,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
         arg_data({{I}}))){{INDICES}};
   }
 )";
-    *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     if (!config.feed(i).name().empty()) {
       *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
     }
@@ -235,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config,
         result_data({{I}}))){{INDICES}};
   }
 )";
-    *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.fetch(i).name().empty()) {
      *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
    }
@@ -304,8 +304,8 @@ std::vector<string> BufferInfosToCppExpression(
                    string encoded_second_as_str =
                        encoded.second == ~0ULL
                            ? "~0ULL"
-                           : strings::StrCat(encoded.second, "ULL");
-                   return strings::StrCat(
+                           : absl::StrCat(encoded.second, "ULL");
+                   return absl::StrCat(
                        "::tensorflow::cpu_function_runtime::BufferInfo({",
                        encoded.first, "ULL, ", encoded_second_as_str, "})");
                  });
@@ -352,13 +352,13 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
   // Create rewrite strings for namespace start and end.
   string ns_start;
   for (const string& n : opts.namespaces) {
-    ns_start += strings::StrCat("namespace ", n, " {\n");
+    ns_start += absl::StrCat("namespace ", n, " {\n");
   }
   ns_start += "\n";
   string ns_end("\n");
   for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
     const string& n = opts.namespaces[i];
-    ns_end += strings::StrCat("} // end namespace ", n, "\n");
+    ns_end += absl::StrCat("} // end namespace ", n, "\n");
   }
 
   // Generate metadata.
@@ -568,10 +568,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
 )";
   // The replacement strategy is naive, but good enough for our purposes.
   const std::vector<std::pair<string, string>> rewrites = {
-      {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
-      {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
+      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
       {"{{ARG_NAMES_CODE}}", arg_names_code},
-      {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
+      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
       {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
      {"{{CLASS}}", opts.class_name},
@@ -590,11 +590,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
      {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
       metadata_result.program_shape_access_shim},
-      {"{{RESULT_INDEX}}", strings::StrCat(result_index)},
+      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
       {"{{RESULT_NAMES_CODE}}", result_names_code},
-      {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
-      {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
-      {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
+      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
+      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
+      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
       {"{{BUFFER_INFOS_AS_STRING}}",
        absl::StrJoin(buffer_infos_as_strings, ",\n")}};
   absl::StrReplaceAll(rewrites, header);
@@ -602,13 +602,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
 }
 
 static string CreateUniqueIdentifier(const CodegenOpts& opts,
-                                     StringPiece suffix) {
+                                     absl::string_view suffix) {
   string result = "__tfcompile";
   for (const string& n : opts.namespaces) {
-    strings::StrAppend(&result, "_", n);
+    absl::StrAppend(&result, "_", n);
   }
 
-  strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
+  absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
   return result;
 }
 
@@ -678,7 +678,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
   return Status::OK();
 }
 
-Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
   if (ident.empty()) {
     return errors::InvalidArgument("empty identifier: ", msg);
   }
@@ -19,9 +19,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/aot/compile.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 
 namespace tensorflow {
 namespace tfcompile {
@@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
 
 // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
 // appended to error messages.
-Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg);
 
 }  // namespace tfcompile
 }  // namespace tensorflow
@@ -19,11 +19,11 @@ limitations under the License.
 #include <vector>
 
 #include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
 #include "llvm/Support/TargetSelect.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
@@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef;
 
 static void AddEmbeddedProtocolBufferToLlvmModule(
     llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
-    StringPiece unique_identifier, string* protobuf_array_symbol_name,
+    absl::string_view unique_identifier, string* protobuf_array_symbol_name,
     int64* protobuf_array_size) {
   string protobuf_array_contents = proto.SerializeAsString();
   *protobuf_array_symbol_name =
-      strings::StrCat(unique_identifier, "_protobuf_array_contents");
+      absl::StrCat(unique_identifier, "_protobuf_array_contents");
   *protobuf_array_size = protobuf_array_contents.size();
 
   llvm::Constant* protobuf_array_initializer =
@@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule(
       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
 }
 
-static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
-                                      StringPiece protobuf_array_symbol_name,
-                                      int64 protobuf_array_size) {
+static string CreateCPPShimExpression(
+    absl::string_view qualified_cpp_protobuf_name,
+    absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) {
   string code =
       "[]() {\n"
       "    {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
@@ -68,9 +68,9 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
   return absl::StrReplaceAll(
       code,
       {
-          {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)},
-          {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)},
-          {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)},
+          {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
+          {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
+          {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
       });
 }
 
@@ -93,7 +93,7 @@ static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
 }
 
 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
-GetTargetMachineFromTriple(StringPiece target_triple) {
+GetTargetMachineFromTriple(absl::string_view target_triple) {
   std::string error;
   std::string normalized_triple =
       llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
@@ -110,7 +110,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
 }
 
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
-    StringPiece target_triple,
+    absl::string_view target_triple,
     absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
                       GetTargetMachineFromTriple(target_triple));
@@ -135,8 +135,8 @@ StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
           protobuf_to_embed.qualified_cpp_protobuf_name,
           protobuf_array_symbol_name, protobuf_array_size);
 
-      cpp_variable_decl = strings::StrCat("extern \"C\" char ",
-                                          protobuf_array_symbol_name, "[];");
+      cpp_variable_decl =
+          absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
     } else {
       cpp_shim = "nullptr";
     }
@@ -83,7 +83,7 @@ struct ProtobufToEmbed {
 // is stored in the object_file_data field in the returned
 // EmbeddedProtocolBuffers instance.
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
-    StringPiece target_triple,
+    absl::string_view target_triple,
     absl::Span<const ProtobufToEmbed> protobufs_to_embed);
 
 }  // namespace tfcompile
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/strings/match.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/aot/codegen.h"
 #include "tensorflow/compiler/aot/compile.h"
 #include "tensorflow/compiler/aot/flags.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -92,8 +92,9 @@ Status Main(const MainFlags& flags) {
   // Write output files.
   Env* env = Env::Default();
   const std::vector<char>& obj = compile_result.aot->object_file_data();
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
-                                       StringPiece(obj.data(), obj.size())));
+  TF_RETURN_IF_ERROR(
+      WriteStringToFile(env, flags.out_function_object,
+                        absl::string_view(obj.data(), obj.size())));
   CodegenOpts codegen_opts;
   codegen_opts.gen_name_to_index = flags.gen_name_to_index;
   codegen_opts.gen_program_shape = flags.gen_program_shape;
@@ -410,6 +410,7 @@ cc_library(
         "//tensorflow/core:graph",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
 )
@@ -566,6 +567,7 @@ cc_library(
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "@com_google_absl//absl/strings",
    ],
 )
 
@@ -154,7 +154,7 @@ class AndPredicate : public Predicate {
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
 
-    return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
+    return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
   }
 
   Kind kind() const override { return Kind::kAnd; }
@@ -185,7 +185,7 @@ class OrPredicate : public Predicate {
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
 
-    return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
+    return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
   }
 
   Kind kind() const override { return Kind::kOr; }
@@ -206,7 +206,7 @@ class NotPredicate : public Predicate {
         operands_({operand}) {}
 
   string ToString() const override {
-    return strings::StrCat("~", operand()->ToString());
+    return absl::StrCat("~", operand()->ToString());
   }
 
   Kind kind() const override { return Kind::kNot; }
@@ -240,7 +240,7 @@ class AndRecurrencePredicate : public Predicate {
   Predicate* step() const { return operands_[1]; }
 
   string ToString() const override {
-    return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
+    return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
                         "}");
   }
 
@@ -267,7 +267,7 @@ class SymbolPredicate : public Predicate {
         must_be_true_(must_be_true) {}
 
   string ToString() const override {
-    return must_be_true() ? strings::StrCat("*", tensor_id_.ToString())
+    return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
                           : tensor_id_.ToString();
   }
 
@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
@@ -45,7 +46,6 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -755,7 +755,7 @@ Status Encapsulator::Subgraph::RecordArg(
   if (inserted) {
     NodeDef arg_def;
     NodeDefBuilder builder(
-        strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
+        absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
     DataType dtype = edge->dst()->input_type(edge->dst_input());
     builder.Attr("T", dtype);
     builder.Attr("index", arg_index);
@@ -790,7 +790,7 @@ Status Encapsulator::Subgraph::RecordResult(
   if (inserted) {
     NodeDef ret_def;
     NodeDefBuilder builder(
-        strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
+        absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
     DataType dtype = src_node->output_type(src_slot);
     builder.Attr("T", dtype);
     builder.Attr("index", ret_index);
@@ -950,16 +950,15 @@ Status Encapsulator::Subgraph::AddHostComputes(
     }
 
     NodeDef host_compute_def;
-    NodeDefBuilder builder(strings::StrCat("outside_compilation_",
+    NodeDefBuilder builder(absl::StrCat("outside_compilation_",
                                            oc_subgraph_name, "_host_compute"),
                            kHostComputeOp);
    builder.Input(inputs);
    builder.Attr("Tinputs", input_dtypes);
    builder.Attr("Toutputs", output_dtypes);
    builder.Attr("ancestors", host_compute_ancestors);
-    builder.Attr("key",
-                 strings::StrCat("host_compute_channel_", subgraph_name, "_",
-                                 oc_subgraph_name));
+    builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name,
+                                     "_", oc_subgraph_name));
    builder.Attr("_outside_compilation_subgraph", oc_subgraph_name);
    Status s = builder.Finalize(&host_compute_def);
    if (!s.ok()) return s;
@@ -1017,8 +1016,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
                                                   Graph* graph_out) {
   if (sequencer_ == nullptr) {
     NodeDef seq_def;
-    NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
-                           "NoOp");
+    NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
     builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
     builder.Device(device_);
     Status s = builder.Finalize(&seq_def);
@@ -1091,10 +1089,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
 
   if (VLOG_IS_ON(1)) {
     VLOG(2) << "Build function def " << name;
-    dump_graph::DumpGraphToFile(
-        strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
-    dump_graph::DumpFunctionDefToFile(
-        strings::StrCat("encapsulate_fdef_", name), fdef);
+    dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name),
+                                *graph_, library);
+    dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name),
+                                      fdef);
   }
 
   if (!reuse_existing_functions || library->Find(name) == nullptr) {
@@ -1130,7 +1128,7 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
     host_compute->AddAttr("shapes", shapes);
   } else {
     string inference_graph_name =
-        strings::StrCat("_outside_compilation_shape_inference_", subgraph_name,
+        absl::StrCat("_outside_compilation_shape_inference_", subgraph_name,
                      "_", outside_compilation_subgraph_name);
     FunctionDef fdef;
     TF_RETURN_IF_ERROR(
@@ -1155,10 +1153,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
   if (VLOG_IS_ON(1)) {
     VLOG(2) << "Replace function def " << name;
     dump_graph::DumpGraphToFile(
-        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+        absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
         library);
     dump_graph::DumpFunctionDefToFile(
-        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+        absl::StrCat("replace_encapsulate_fdef_", name), fdef);
   }
 
   TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
@@ -1186,8 +1184,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder(
   GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
   NodeDef key_def;
   NodeDefBuilder builder(
-      strings::StrCat(call_node_def_.name(), "_key_placeholder"),
-      "Placeholder");
+      absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder");
   builder.Attr("dtype", DT_STRING);
   builder.Attr("shape", shape_proto);
   builder.Attr("_host_compute_call_node", call_node_def_.name());
@@ -1221,7 +1218,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
   }
 
   NodeDef recv_def;
-  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
+  NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_recv"),
                          kRecvAtHostOp);
   builder.Device(device_);
@@ -1229,8 +1226,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
   // The correct device_ordinal will be inserted during replication in a
   // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
-  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
-                                      "_", oc_subgraph_name));
+  builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
   builder.Attr(group_attribute, subgraph_name);
   builder.Attr(outside_compilation_attribute, oc_subgraph_name);
   builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
@@ -1276,13 +1273,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
   }
 
   NodeDef send_def;
-  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
+  NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_send"),
                          kSendFromHostOp);
   builder.Device(device_);
   builder.Attr("Tinputs", dtypes);
-  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
-                                      "_", oc_subgraph_name));
+  builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
   // The correct device_ordinal will be inserted during replication in a
   // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
@@ -1516,7 +1513,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
     // Dump subgraphs.
     for (auto& entry : subgraphs_) {
       dump_graph::DumpGraphToFile(
-          strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
+          absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
           *entry.second.GetGraph(), library);
     }
   }
@@ -2052,7 +2049,7 @@ struct PathDetails {
 struct SubgraphAndClusterHash {
   inline std::size_t operator()(const SubgraphAndCluster& v) const {
     return hash<string>()(
-        strings::StrCat(v.subgraph, v.outside_compilation_cluster));
+        absl::StrCat(v.subgraph, v.outside_compilation_cluster));
   }
 };
 
@@ -16,6 +16,7 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 
 #include "absl/strings/match.h"
@@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
   FunctionDef* fdef = library->add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
       *graph,
-      strings::StrCat("_outside_compilation_shape_inference_", name_suffix),
+      absl::StrCat("_outside_compilation_shape_inference_", name_suffix),
       fdef));
   return Status::OK();
 }
@@ -65,15 +66,15 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
     const auto iter = b.find(elt_a.first);
     if (iter == b.end()) {
       if (diff) {
-        *diff = strings::StrCat(
-            map_name, " expected: contains element with key '",
-            key_to_string(elt_a.first), "' got: map has no such element");
+        *diff = absl::StrCat(map_name, " expected: contains element with key '",
+                             key_to_string(elt_a.first),
+                             "' got: map has no such element");
       }
       return false;
     }
     if (!compare(elt_a.first, elt_a.second, iter->second)) {
       if (diff) {
-        *diff = strings::StrCat(map_name, " expected: element with key '",
+        *diff = absl::StrCat(map_name, " expected: element with key '",
                              key_to_string(elt_a.first), "' has value '",
                              value_to_string(elt_a.second), "' got: '",
                              value_to_string(iter->second), "'");
@@ -85,7 +86,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
     const auto iter = a.find(elt_b.first);
     if (iter == a.end()) {
       if (diff) {
-        *diff = strings::StrCat(map_name, " got: contains element with key '",
+        *diff = absl::StrCat(map_name, " got: contains element with key '",
                              key_to_string(elt_b.first),
                              "' expected: map has no such element");
       }
@@ -99,14 +100,14 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
                           const string& diff_preamble, string* diff) {
   if (a.op() != b.op()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                            ", expected op '", a.op(), "' got '", b.op());
     }
     return false;
   }
   if (a.device() != b.device()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                            ", expected device '", a.device(), "' got '",
                            b.device());
     }
@@ -114,7 +115,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
   }
   if (a.input_size() != b.input_size()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                            ", expected ", a.input_size(), " inputs got ",
                            b.input_size(), " expected:\n", a.DebugString(),
                            "\ngot:\n", b.DebugString());
@@ -127,10 +128,10 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
     if (absl::StartsWith(a.input(i), "^")) {
       if (!absl::StartsWith(b.input(i), "^")) {
         if (diff) {
-          *diff = strings::StrCat(
-              diff_preamble, " mismatch for node ", a.name(), " input ", i,
-              ", expected control input ", a.input(i), " got ", b.input(i),
-              " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString());
+          *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                               " input ", i, ", expected control input ",
+                               a.input(i), " got ", b.input(i), " expected:\n",
+                               a.DebugString(), "\ngot:\n", b.DebugString());
         }
         return false;
       }
@@ -138,17 +139,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
       control_input_b.insert(b.input(i));
     } else if (a.input(i) != b.input(i)) {
       if (diff) {
-        *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                                " input ", i, ", expected ", a.input(i),
-                                " got ", b.input(i), " expected:\n",
-                                a.DebugString(), "\ngot:\n", b.DebugString());
+        *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                             " input ", i, ", expected ", a.input(i), " got ",
+                             b.input(i), " expected:\n", a.DebugString(),
+                             "\ngot:\n", b.DebugString());
       }
       return false;
     }
   }
   if (control_input_a != control_input_b) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                            " control inputs differ expected:\n",
                            a.DebugString(), "\ngot:\n", b.DebugString());
     }
@@ -170,17 +171,16 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
           return av.DebugString() == bv.DebugString();
         }
       },
-      strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
-      diff);
+      absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
 }
 
 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
                       string* diff) {
   if (a.signature().DebugString() != b.signature().DebugString()) {
     if (diff) {
-      *diff = strings::StrCat("Signature mismatch for function ",
-                              a.signature().name(), ", expected:\n",
-                              a.signature().DebugString(), "\ngot:\n",
+      *diff =
+          absl::StrCat("Signature mismatch for function ", a.signature().name(),
+                       ", expected:\n", a.signature().DebugString(), "\ngot:\n",
                        b.signature().DebugString());
     }
     return false;
@@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
       [](const string& key, const AttrValue& av, const AttrValue& bv) {
         return av.DebugString() == bv.DebugString();
       },
-      strings::StrCat("attr mismatch for function ", a.signature().name()),
+      absl::StrCat("attr mismatch for function ", a.signature().name()),
       diff)) {
     return false;
   }
@@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
       [](const string& key, const string& av, const string& bv) {
         return av == bv;
       },
-      strings::StrCat("ret mismatch for function ", a.signature().name()),
+      absl::StrCat("ret mismatch for function ", a.signature().name()),
      diff)) {
    return false;
  }
@@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
       if (a.node_def(i).name() == b.node_def(j).name()) {
         if (!EqualFunctionNodeDef(
                 a.node_def(i), b.node_def(j),
-                strings::StrCat("Function ", a.signature().name()), diff)) {
+                absl::StrCat("Function ", a.signature().name()), diff)) {
           return false;
         }
         found = true;
@@ -220,7 +220,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
     }
     if (!found) {
       if (diff) {
-        *diff = strings::StrCat("Function ", a.signature().name(),
+        *diff = absl::StrCat("Function ", a.signature().name(),
                              ", expected: has node '", a.node_def(i).name(),
                              "' got: no node of that name");
       }
@@ -237,7 +237,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
     }
     if (!found) {
       if (diff) {
-        *diff = strings::StrCat("Function ", a.signature().name(),
+        *diff = absl::StrCat("Function ", a.signature().name(),
                              ", got: has node '", b.node_def(i).name(),
                              "' expected: no node of that name");
      }
@@ -258,7 +258,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
     auto it = actual_index.find(expected_function.signature().name());
     if (it == actual_index.end()) {
       if (diff) {
-        *diff = strings::StrCat("Did not find expected function '",
+        *diff = absl::StrCat("Did not find expected function '",
                              expected_function.signature().name(), "'");
       }
       return false;
@@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
 
   if (!actual_index.empty()) {
     if (diff != nullptr) {
-      *diff = strings::StrCat("Found unexpected function '",
-                              actual_index.begin()->second->signature().name(),
-                              "'");
+      *diff =
+          absl::StrCat("Found unexpected function '",
+                       actual_index.begin()->second->signature().name(), "'");
    }
    return false;
  }
@@ -420,10 +420,9 @@ Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
                  const string& oc_cluster, absl::Span<const DataType> dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  string key =
-      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
-  string name = strings::StrCat("outside_compilation_", cluster, "_",
-                                oc_cluster, "_recv");
+  string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name =
+      absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv");
   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
                            "_XlaRecvAtHost", opts.op_registry());
   node_builder.Input(std::move(key_input));
@@ -440,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
                    const std::vector<ops::NodeOut>& inputs,
                    const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  string key =
-      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
-  string name = strings::StrCat("outside_compilation_", cluster, "_",
-                                oc_cluster, "_send");
+  string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name =
+      absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send");
   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
                            "_XlaSendFromHost", opts.op_registry());
   node_builder.Input(inputs);
@@ -682,8 +680,8 @@ std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
   for (const Edge* edge : graph.edges()) {
     if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
     edges.emplace_back(
-        strings::StrCat(edge->src()->name(), ":", edge->src_output()),
-        strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
+        absl::StrCat(edge->src()->name(), ":", edge->src_output()),
+        absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
   }
   std::sort(edges.begin(), edges.end());
   return edges;
@@ -14,6 +14,7 @@ cc_library(
     hdrs = ["graphcycles.h"],
     deps = [
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:inlined_vector",
     ],
 )
 
@@ -34,7 +34,7 @@ limitations under the License.
 #include <algorithm>
 #include <unordered_set>
 
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
@@ -44,7 +44,7 @@ namespace {
 typedef std::unordered_set<int32> NodeSet;
 template <typename T>
 struct VecStruct {
-  typedef gtl::InlinedVector<T, 4> type;
+  typedef absl::InlinedVector<T, 4> type;
 };
 template <typename T>
 using Vec = typename VecStruct<T>::type;
@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/public/version.h"
 
@@ -617,7 +616,7 @@ Status MarkForCompilationPass::Run(
 }
 
 static string RatioToString(int numerator, int denominator) {
-  return strings::Printf("%d / %d (%.2f%%)", numerator, denominator,
+  return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
                          (100.0 * numerator) / denominator);
 }
 
@@ -626,14 +625,14 @@ static void VLogClusteringSummary(const Graph& g) {
     return;
   }
 
-  std::map<StringPiece, int> cluster_name_to_size;
-  std::map<StringPiece, std::map<StringPiece, int>>
+  std::map<absl::string_view, int> cluster_name_to_size;
+  std::map<absl::string_view, std::map<absl::string_view, int>>
       cluster_name_to_op_histogram;
-  std::map<StringPiece, int> unclustered_op_histogram;
+  std::map<absl::string_view, int> unclustered_op_histogram;
   int clustered_node_count = 0;
 
   for (Node* n : g.nodes()) {
-    absl::optional<StringPiece> cluster_name = GetXlaClusterForNode(*n);
+    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
     if (cluster_name) {
       clustered_node_count++;
       cluster_name_to_size[*cluster_name]++;
@@ -650,7 +649,7 @@ static void VLogClusteringSummary(const Graph& g) {
           << RatioToString(clustered_node_count, g.num_nodes());
 
   for (const auto& cluster_name_size_pair : cluster_name_to_size) {
-    StringPiece cluster_name = cluster_name_size_pair.first;
+    absl::string_view cluster_name = cluster_name_size_pair.first;
     int size = cluster_name_size_pair.second;
     VLOG(2) << "  " << cluster_name << " "
             << RatioToString(size, g.num_nodes());
@@ -670,14 +669,15 @@ static void VLogClusteringSummary(const Graph& g) {
   }
 
   struct EdgeInfo {
-    StringPiece node_name;
-    absl::optional<StringPiece> cluster_name;
+    absl::string_view node_name;
+    absl::optional<absl::string_view> cluster_name;
 
-    StringPiece GetClusterName() const {
+    absl::string_view GetClusterName() const {
       return cluster_name ? *cluster_name : "[none]";
     }
 
-    std::pair<StringPiece, absl::optional<StringPiece>> AsPair() const {
+    std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
+        const {
       return {node_name, cluster_name};
     }
 
@@ -686,19 +686,21 @@ static void VLogClusteringSummary(const Graph& g) {
     }
   };
 
-  using EdgeInfoMap = std::map<StringPiece, std::map<EdgeInfo, int64>>;
+  using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
 
   EdgeInfoMap incoming_edge_infos;
   EdgeInfoMap outgoing_edge_infos;
 
-  std::set<StringPiece> cluster_names_to_print;
+  std::set<absl::string_view> cluster_names_to_print;
 
   for (const Edge* e : g.edges()) {
     const Node* from = e->src();
-    absl::optional<StringPiece> from_cluster_name = GetXlaClusterForNode(*from);
+    absl::optional<absl::string_view> from_cluster_name =
+        GetXlaClusterForNode(*from);
 
     const Node* to = e->dst();
-    absl::optional<StringPiece> to_cluster_name = GetXlaClusterForNode(*to);
+    absl::optional<absl::string_view> to_cluster_name =
+        GetXlaClusterForNode(*to);
 
     if (to_cluster_name == from_cluster_name) {
       continue;
@@ -721,9 +723,9 @@ static void VLogClusteringSummary(const Graph& g) {
     VLOG(2) << "  [none]";
   }
 
-  auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name,
+  auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
                                              const EdgeInfoMap& edge_info_map,
-                                             StringPiece desc) {
+                                             absl::string_view desc) {
     auto it = edge_info_map.find(cluster_name);
     if (it != edge_info_map.end()) {
       VLOG(2) << " " << it->second.size() << " " << desc << " edges";
@@ -737,7 +739,7 @@ static void VLogClusteringSummary(const Graph& g) {
    }
  };
 
-  for (StringPiece cluster_name : cluster_names_to_print) {
+  for (absl::string_view cluster_name : cluster_names_to_print) {
     VLOG(2) << " ** Cluster " << cluster_name;
     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
                                     "incoming");
@@ -966,7 +968,7 @@ Status MarkForCompilationPass::RunImpl(
     string& name = cluster_names[cluster];
 
     if (name.empty()) {
-      name = strings::StrCat("cluster_", cluster_sequence_num++);
+      name = absl::StrCat("cluster_", cluster_sequence_num++);
     }
     n->AddAttr(kXlaClusterAttr, name);
     VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
@@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   Scope root = Scope::NewRootScope().ExitOnError();
   {
-    auto BuildNoopNode = [](StringPiece name, Graph* graph) {
+    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
       NodeDefBuilder builder(name, "NoOp");
       NodeDef def;
       TF_CHECK_OK(builder.Finalize(&def));
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -30,7 +31,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
   MemoryTypeVector input_mtypes, output_mtypes;
 
   for (Node* n : post_order) {
-    absl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
+    absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
     if (!from_cluster) {
       continue;
     }
@@ -79,7 +80,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
       // Check if `dst` is in a different cluster, unclustered, or about to be
      // partially declustered (here we rely on the post-order traversal order).
      // If yes, decluster `n` to avoid the device-to-host memcpy.
-      absl::optional<StringPiece> dst_cluster =
+      absl::optional<absl::string_view> dst_cluster =
          result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst);
      if (from_cluster != dst_cluster) {
        CHECK(result->insert(n).second);
@@ -91,15 +92,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
 }
 
 Status PartiallyDeclusterNode(Graph* graph, Node* n) {
-  StringPiece cluster_name = *GetXlaClusterForNode(*n);
-  gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
+  absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+  absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
   for (const Edge* out_edge : n->out_edges()) {
     if (out_edge->IsControlEdge()) {
       continue;
     }
 
     Node* dst = out_edge->dst();
-    absl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
+    absl::optional<absl::string_view> dst_cluster_name =
+        GetXlaClusterForNode(*dst);
     if (dst_cluster_name != cluster_name) {
       out_edges_to_clone.push_back(out_edge);
     }
@@ -108,7 +110,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
   CHECK(!out_edges_to_clone.empty()) << n->DebugString();
 
   NodeDef ndef = n->def();
-  ndef.set_name(strings::StrCat(n->name(), "/declustered"));
+  ndef.set_name(absl::StrCat(n->name(), "/declustered"));
   RemoveFromXlaCluster(&ndef);
   Status s;
   Node* cloned_node = graph->AddNode(ndef, &s);
@@ -165,7 +165,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
 using ResourceOp = std::pair<int, XlaResourceOpKind>;
 
 string ResourceOpToString(const ResourceOp& resource_op) {
-  return strings::StrCat(
+  return absl::StrCat(
       resource_op.first, ": ",
       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
 }
@@ -257,11 +257,11 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
   std::vector<string> elements_debug_string;
   std::transform(resource_op_set.begin(), resource_op_set.end(),
                  std::back_inserter(elements_debug_string), ResourceOpToString);
-  return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
+  return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
 }
 
 string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
-  return strings::StrCat(
+  return absl::StrCat(
       "[", n.name(), ": ", n.type_string(), "(",
       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
 }
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <unordered_map>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/graph/control_flow.h"
@@ -52,7 +53,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
   };
 
   string description;
-  strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
+  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                      node_name(dst), " would create a cycle.\n");
   path.resize(path_size);
   for (int32 node_id : path) {
@@ -64,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
     } else {
       ascii_art = "+-- ";
     }
-    strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
+    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
   }
   return description;
 }
@@ -186,7 +187,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
   return Status::OK();
 }
 
-absl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
   const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
   if (attr_value == nullptr) {
     return absl::nullopt;
@@ -47,7 +47,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
 
 // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
 // otherwise returns nullopt.
-absl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
 
 // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute).
 void RemoveFromXlaCluster(NodeDef* node_def);
@@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() {
 string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
   string result = sig.name;
   for (const auto& a : sig.arg_types) {
-    strings::StrAppend(&result, ",", DataTypeString(a.first),
+    absl::StrAppend(&result, ",", DataTypeString(a.first),
                        a.second.DebugString());
   }
 
   for (const auto& v : sig.arg_values) {
-    strings::StrAppend(&result, "; ", v.DebugString());
+    absl::StrAppend(&result, "; ", v.DebugString());
   }
   return result;
 }
@@ -148,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
   }
 
   const DeviceAttributes attrs = Device::BuildDeviceAttributes(
-      strings::StrCat(name_prefix, "/device:", device_name, ":",
-                      device_ordinal),
+      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
-      strings::StrCat("device: ", device_name, " device"));
+      absl::StrCat("device: ", device_name, " device"));
 
   device->reset(
       new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
@@ -203,7 +203,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
 }
 
 void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                                               StringPiece tensor_name,
+                                               absl::string_view tensor_name,
                                                Device* device,
                                                Tensor* cpu_tensor,
                                                StatusCallback done) {
@@ -339,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
 }
 
 void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                                             StringPiece tensor_name,
+                                             absl::string_view tensor_name,
                                              Device* device, Tensor* cpu_tensor,
                                              StatusCallback done) {
   manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
@@ -57,7 +57,7 @@ class XlaTransferManager {
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor, StatusCallback done) const;
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                             StringPiece tensor_name, Device* device,
+                             absl::string_view tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done);
 
   void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
@@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext {
                              Tensor* device_tensor,
                              StatusCallback done) const override;
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                             StringPiece tensor_name, Device* device,
+                             absl::string_view tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done) override;
   void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
                                 const StatusCallback& done);
@@ -20,6 +20,7 @@ limitations under the License.
 #include <unordered_map>
 #include <unordered_set>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -326,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
     string& name = cluster_names[cluster];
 
     if (name.empty()) {
-      name = strings::StrCat("cluster_", cluster_sequence_num++);
+      name = absl::StrCat("cluster_", cluster_sequence_num++);
     }
     n->AddAttr(kXlaClusterAttr, name);
     VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
@@ -122,7 +122,7 @@ class XlaTensor {
   std::shared_ptr<se::Event> definition_event_;
   // A list of all streams for which the tensor's content is defined for any
   // newly enqueued command.
-  gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+  absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
   mutex mu_;
 };
 
@@ -1103,6 +1103,7 @@ cc_library(
         "//tensorflow/core:test",
        "//tensorflow/core:testlib",
        "//tensorflow/core/kernels:ops_util",
+        "@com_google_absl//absl/strings",
    ],
 )
 
@ -45,6 +45,8 @@ limitations under the License.
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -61,7 +63,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main()
|
||||
bool tf_xla_test_use_jit = true;
|
||||
|
||||
string LocalDeviceToFullDeviceName(const string& device) {
|
||||
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
||||
return absl::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
||||
}
|
||||
|
||||
constexpr std::array<DataType, 5> kAllXlaTypes = {
|
||||
@ -107,11 +108,12 @@ class OpTestBuilder {
|
||||
|
||||
// Sets an attribute.
|
||||
template <class T>
|
||||
OpTestBuilder& Attr(StringPiece attr_name, T&& value);
|
||||
OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
|
||||
|
||||
// Overload needed to allow {...} expressions for value.
|
||||
template <class T>
|
||||
OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
|
||||
OpTestBuilder& Attr(absl::string_view attr_name,
|
||||
std::initializer_list<T> value);
|
||||
|
||||
// Adds nodes that executes the operator under test on 'device' to 'graphdef'.
|
||||
// If 'use_jit' is true, marks the operator under test to be compiled by XLA.
|
||||
@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
|
||||
}
|
||||
|
||||
template <class T>
|
||||
OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
|
||||
OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
|
||||
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name,
|
||||
OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
|
||||
std::initializer_list<T> value) {
|
||||
Attr<std::initializer_list<T>>(attr_name, std::move(value));
|
||||
return *this;
|
||||
@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,

NodeDef* test_def = graphdef->add_node();
*test_def = node_def_;
test_def->set_name(strings::StrCat(name_prefix, "_op_under_test"));
test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
test_def->set_device(device);
AddDefaultsToNodeDef(*op_def, test_def);
if (use_jit) {
@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
// Build feed and fetch nodes.
for (int i = 0; i < input_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
string name = strings::StrCat(name_prefix, "_input_", i);
string name = absl::StrCat(name_prefix, "_input_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
.Device(device)
.Attr("dtype", input_types[i])
@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,

for (int i = 0; i < output_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
string name = strings::StrCat(name_prefix, "_output_", i);
string name = absl::StrCat(name_prefix, "_output_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
.Device(device)
.Attr("T", output_types[i])
@ -726,11 +728,11 @@ bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,

template <typename T>
string Str(T x) {
return strings::StrCat(x);
return absl::StrCat(x);
}
template <>
string Str<complex64>(complex64 x) {
return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
}

template <typename T>
@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
return errors::InvalidArgument(strings::StrCat(
i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
"atol = ", atol, " rtol = ", rtol,
" tol = ", atol + rtol * Abs(Tx(i))));
return errors::InvalidArgument(
absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
" vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
"y = ", y.DebugString(), "atol = ", atol,
" rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
}
}
return Status::OK();
@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (Tx(i) != Ty(i)) {
return errors::InvalidArgument(strings::StrCat(
return errors::InvalidArgument(absl::StrCat(
i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i),
". x = ", x.DebugString(), "y = ", y.DebugString()));
}
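absl::StrCat accepts the same mix of integral, floating-point, and string arguments as the strings::StrCat it replaces, which is why each rewrite above only touches the namespace. A small self-contained sketch (the message text mirrors the error above, but the program is illustrative only):

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"

int main() {
  int i = 3;
  double atol = 1e-6;
  // Numeric arguments are formatted inline, no explicit conversion needed.
  std::string msg =
      absl::StrCat(i, "-th tensor element isn't close: atol = ", atol);
  std::cout << msg << "\n";
  return 0;
}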
@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
double rtol) {
if (a.dtype() != b.dtype()) {
return errors::InvalidArgument(strings::StrCat(
return errors::InvalidArgument(absl::StrCat(
"Tensors have different types: ", DataTypeString(a.dtype()), " and ",
DataTypeString(b.dtype())));
}
if (!a.IsSameSize(b)) {
return errors::InvalidArgument(strings::StrCat(
"Tensors have different shapes: ", a.shape().DebugString(), " and ",
b.shape().DebugString()));
return errors::InvalidArgument(
absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
" and ", b.shape().DebugString()));
}

switch (a.dtype()) {
@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}

string cpu_device =
LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0"));
string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);

DeviceNameUtils::ParsedName parsed_name;
@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
std::vector<string> expected_inputs, test_inputs;
std::vector<string> expected_fetches, test_fetches;
Status status = builder.BuildGraph(
strings::StrCat("test", num_tests_, "_expected"), cpu_device,
absl::StrCat("test", num_tests_, "_expected"), cpu_device,
/* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
&expected_inputs, &expected_fetches);
if (!status.ok()) {
@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}

NodeDef* node_def;
status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
test_device, tf_xla_test_use_jit, &graph,
&node_def, &test_inputs, &test_fetches);
if (!status.ok()) {
@ -291,6 +291,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@ -433,6 +434,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

@ -609,11 +611,10 @@ cc_library(
srcs = ["resource_operation_table.cc"],
hdrs = ["resource_operation_table.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)

@ -18,8 +18,8 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/dump_graph.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) {

string filename = name;
if (count > 0) {
strings::StrAppend(&filename, "_", count);
absl::StrAppend(&filename, "_", count);
}
strings::StrAppend(&filename, ".pbtxt");
absl::StrAppend(&filename, ".pbtxt");
return filename;
}

@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile(
<< proto_type << ": " << status;
return "(unavailable)";
}
string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name));
string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name));
status = WriteTextProto(Env::Default(), filepath, proto);
if (!status.ok()) {
LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
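absl::StrAppend, declared in the same str_cat.h header, is likewise a drop-in for strings::StrAppend: it appends its formatted arguments to a string in place. A hedged sketch of the filename-construction pattern above (standalone, with a hard-coded count):

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"  // absl::StrAppend is declared here too.

int main() {
  std::string filename = "graph";
  int count = 1;  // Pretend one file with this name already exists.
  if (count > 0) {
    absl::StrAppend(&filename, "_", count);
  }
  absl::StrAppend(&filename, ".pbtxt");
  std::cout << filename << "\n";  // Prints: graph_1.pbtxt
  return 0;
}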
@ -42,7 +42,7 @@ namespace functionalize_cond {

// TODO(jpienaar): Move to OutputTensor.
string DebugString(const OutputTensor& tensor) {
return strings::StrCat(tensor.node->name(), ":", tensor.index);
return absl::StrCat(tensor.node->name(), ":", tensor.index);
}

string Branch_Name(BranchType b) {
@ -61,16 +61,16 @@ string Branch_Name(BranchType b) {
string DebugString(StateMap::CondId cond_state) {
if (cond_state == nullptr || cond_state->empty()) return "{}";
using value_type = StateMap::CondState::value_type;
return strings::StrCat(
return absl::StrCat(
"{",
absl::StrJoin(*cond_state, ", ",
[](string* output, const value_type& pred_branch) {
const OutputTensor& pred = pred_branch.first;
const BranchType& branch = pred_branch.second;
if (branch == BranchType::kNeither)
strings::StrAppend(output, "d");
absl::StrAppend(output, "d");
else
strings::StrAppend(output, "s(", DebugString(pred), ",",
absl::StrAppend(output, "s(", DebugString(pred), ",",
Branch_Name(branch), ")");
}),
"}");
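The CondState printer above combines absl::StrCat with absl::StrJoin and a formatter lambda; the lambda receives an output string to append to plus the current element. A minimal sketch of that shape (the pair data is invented):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<std::pair<std::string, int>> tensors = {{"add", 0}, {"mul", 1}};
  // The formatter appends each element to *output, as in the printer above.
  std::string s = absl::StrCat(
      "{",
      absl::StrJoin(tensors, ", ",
                    [](std::string* output,
                       const std::pair<std::string, int>& t) {
                      absl::StrAppend(output, t.first, ":", t.second);
                    }),
      "}");
  std::cout << s << "\n";  // Prints: {add:0, mul:1}
  return 0;
}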
@ -159,7 +159,7 @@ struct CondArgNode {
: src(src), src_output(src_output) {}

string ToString() const {
return strings::StrCat("src=", src->name(), ":", src_output,
return absl::StrCat("src=", src->name(), ":", src_output,
" switches=", NodesToString(switches));
}

@ -171,11 +171,11 @@ struct CondArgNode {
using CondArgNodes = std::vector<CondArgNode>;

string DebugString(const CondArgNodes& nodes) {
return strings::StrCat(
return absl::StrCat(
"[",
absl::StrJoin(nodes, ", ",
[](string* output, const CondArgNode& node) {
strings::StrAppend(output, node.ToString());
absl::StrAppend(output, node.ToString());
}),
"]");
}
@ -373,7 +373,7 @@ Status Conditional::BuildArgumentNodes() {
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast<int>(branch);
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_Arg", arg_count),
NodeBuilder(absl::StrCat("_Arg", arg_count),
FunctionLibraryDefinition::kArgOp)
.Attr("T", dtype)
.Attr("index", arg_count)
@ -441,7 +441,7 @@ Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
Node* src = edge->src();
int src_output = edge->src_output();
TF_RETURN_IF_ERROR(
NodeBuilder(graph->NewName(strings::StrCat(src->name(), "_added_switch")),
NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
"Switch")
.Input(src, src_output)
.Input(const_cast<Node*>(predicate_.node), predicate_.index)
@ -650,8 +650,8 @@ Status Conditional::BuildIfNode(Graph* graph,
int64 id = ++sequence_num;

NameAttrList body_name;
body_name.set_name(strings::StrCat("_functionalize_if_",
branch_name[branch_index], "_", id));
body_name.set_name(
absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id));

VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
<< "): "
@ -804,7 +804,7 @@ Status Conditional::BuildAndReplace(Graph* graph,

string Conditional::name() const {
CHECK(!merges_.empty());
return strings::StrCat((*merges_.begin())->name(), "_if");
return absl::StrCat((*merges_.begin())->name(), "_if");
}

Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
@ -1327,12 +1327,12 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
for (Node* n : graph_->nodes()) {
n->ClearAttr(kCondGroupDebugAttr);
n->AddAttr(kCondGroupDebugAttr,
strings::StrCat(state_map_.CondStateToString(n), "_",
absl::StrCat(state_map_.CondStateToString(n), "_",
state_map_.AncestorStateToString(n)));
}
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
<< dump_graph::DumpGraphToFile(
strings::StrCat("functionalize_", name), *graph_, library_);
<< dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name),
*graph_, library_);
}

Status FunctionalizeCond::Functionalize(Graph* graph,
|
||||
const char* const kRetValOp = "_Retval";
|
||||
NodeDef ret_def;
|
||||
ret_def.set_op(kRetValOp);
|
||||
ret_def.set_name(strings::StrCat(kRetValOp, index));
|
||||
ret_def.set_name(absl::StrCat(kRetValOp, index));
|
||||
AddNodeAttr("T", type, &ret_def);
|
||||
AddNodeAttr("index", index, &ret_def);
|
||||
return AddNodeDefToGraph(ret_def, graph);
|
||||
|
@ -43,11 +43,10 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
|
||||
// Returns a textual representation of the names of the nodes in the input.
|
||||
template <typename T>
|
||||
string NodesToString(const T& nodes) {
|
||||
return strings::StrCat("{",
|
||||
return absl::StrCat("{",
|
||||
absl::StrJoin(nodes, ",",
|
||||
[](string* output, const Node* node) {
|
||||
strings::StrAppend(output,
|
||||
node->name());
|
||||
absl::StrAppend(output, node->name());
|
||||
}),
|
||||
"}");
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame,
|
||||
StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
|
||||
const char* const kArgOp = "_Arg";
|
||||
NodeDef arg_def;
|
||||
NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
|
||||
NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
|
||||
builder.Attr("T", type);
|
||||
builder.Attr("index", index);
|
||||
TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
|
||||
@ -487,9 +487,9 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
static std::atomic<int64> sequence_num(0LL);
|
||||
int64 id = ++sequence_num;
|
||||
NameAttrList cond_name;
|
||||
cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
|
||||
cond_name.set_name(absl::StrCat("_functionalize_cond_", id));
|
||||
NameAttrList body_name;
|
||||
body_name.set_name(strings::StrCat("_functionalize_body_", id));
|
||||
body_name.set_name(absl::StrCat("_functionalize_body_", id));
|
||||
FunctionDef cond_fdef;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
|
||||
|
@ -127,7 +127,7 @@ Status GraphCompiler::Compile() {
|
||||
TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
|
||||
<< "Not supported node: " << n->DebugString();
|
||||
params.op_kernel = op_kernel.get();
|
||||
gtl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
|
||||
absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
|
||||
params.output_attr_array = output_attr.data();
|
||||
|
||||
// tensor_inputs_ is a buffer reused across graph traversal. We clean up and
|
||||
|
@ -89,7 +89,7 @@ class GraphCompiler {
|
||||
ScopedStepContainer* step_container_;
|
||||
// A buffer to hold tensor inputs to a node, this is reused across the graph
|
||||
// traversal.
|
||||
gtl::InlinedVector<TensorValue, 4> tensor_inputs_;
|
||||
absl::InlinedVector<TensorValue, 4> tensor_inputs_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
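gtl::InlinedVector is aliased to absl::InlinedVector, so swapping the spelling preserves the template arguments and the small-buffer behavior: elements up to the inline capacity live inside the object, and larger sizes spill to the heap. A short sketch:

#include <iostream>

#include "absl/container/inlined_vector.h"

int main() {
  // Capacity 4 is stored inline; pushing a 5th element moves storage to the
  // heap, transparently to the caller.
  absl::InlinedVector<int, 4> dims;
  for (int i = 0; i < 6; ++i) dims.push_back(i);
  std::cout << "size=" << dims.size() << " first=" << dims[0] << "\n";
  return 0;
}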
@ -26,7 +26,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
absl::Span<const int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();

@ -39,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel {
OP_REQUIRES(
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
gtl::InlinedVector<BCast::Vec, 2> shapes;
absl::InlinedVector<BCast::Vec, 2> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const TensorShape in_shape = ctx->InputShape(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
@ -88,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel {
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));

gtl::InlinedVector<BCast::Vec, 4> shapes;
absl::InlinedVector<BCast::Vec, 4> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const TensorShape in_shape = ctx->InputShape(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),

@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel {
OP_REQUIRES(ctx, kRequiredDims == input_rank,
errors::InvalidArgument("Input rank should be ", kRequiredDims,
"; got: ", input_rank));
const gtl::InlinedVector<int64, 4> input_shape =
const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();

xla::XlaOp input = ctx->Input(0);

@ -138,7 +138,7 @@ xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
spatial_dimensions[spatial_dim] =
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);

@ -69,7 +69,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "data shape: " << data_shape.DebugString();
VLOG(1) << "axes : " << absl::StrJoin(axes, ",");

gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
absl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
std::vector<int64> xla_axes;
int64 num_elements_reduced = 1LL;
for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
@ -103,7 +103,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {

xla::XlaBuilder* const b = ctx->builder();
// Construct the builder for the reduction lambda.
xla::XlaBuilder r(strings::StrCat(desc, "-reduction"));
xla::XlaBuilder r(absl::StrCat(desc, "-reduction"));
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));

@ -97,7 +97,7 @@ class ReverseV2Op : public XlaOpKernel {

// witnessed_axes is used to ensure that the same axis is not marked to be
// reversed multiple times.
gtl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);
absl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);

for (int d = 0; d < axes.size(); ++d) {
OP_REQUIRES(

@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel {
// accept legacy scalars, even when they should be forbidden by the graphdef
// version.
OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
errors::InvalidArgument(strings::StrCat(
errors::InvalidArgument(absl::StrCat(
"dim input to ExpandDims must be a scalar; got ",
dim_shape.DebugString())));

@ -26,7 +26,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
absl::Span<const int64> block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();

@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel {
OP_REQUIRES(ctx, kRequiredDims == input_rank,
errors::InvalidArgument("Input rank should be ", kRequiredDims,
"; got ", input_rank));
const gtl::InlinedVector<int64, 4> input_shape =
const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();

xla::XlaOp input = ctx->Input(0);
@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel {
xla::XlaOp value;
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* resource;
string name = strings::StrCat("Stack: ", stack_name_);
string name = absl::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
TensorShape(), value, /*tensor_array_size=*/size,

@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);

TensorShape final_shape;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
absl::InlinedVector<int64, 4> begin;
absl::InlinedVector<int64, 4> end;
absl::InlinedVector<int64, 4> strides;

xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel {
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides));

gtl::InlinedVector<int64, 4> dimensions_to_reverse;
gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
absl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;

for (int i = 0; i < begin.size(); ++i) {
if (strides[i] > 0) {
@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel {

void Compile(XlaOpKernelContext* ctx) override {
TensorShape processing_shape, final_shape;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
absl::InlinedVector<int64, 4> begin;
absl::InlinedVector<int64, 4> end;
absl::InlinedVector<int64, 4> strides;

TensorShape input_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel {
grad = xla::Reshape(grad, processing_shape.dim_sizes());

// Pad the input gradients.
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> dimensions_to_reverse;
xla::PaddingConfig padding_config;

for (int i = 0; i < processing_shape.dims(); ++i) {
@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel {

void Compile(XlaOpKernelContext* ctx) override {
TensorShape final_shape;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
absl::InlinedVector<int64, 4> begin;
absl::InlinedVector<int64, 4> end;
absl::InlinedVector<int64, 4> strides;

xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel {

xla::XlaOp rhs = ctx->Input(4);

gtl::InlinedVector<int64, 4> dimensions_to_reverse;
gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
absl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> slice_begin, slice_dims;
for (int i = 0; i < begin.size(); ++i) {
// TODO(phawkins): implement strides != 1
OP_REQUIRES(

@ -167,7 +167,7 @@ class TensorArrayOp : public XlaOpKernel {

XlaContext& xc = XlaContext::Get(ctx);
XlaResource* var;
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
string name = absl::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
dtype_, shape, value, /*tensor_array_size=*/size,

@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel {

std::vector<int64> transposed_order;
// Check whether permutation is a permutation of integers of [0 .. dims).
gtl::InlinedVector<bool, 8> bits(dims);
absl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true;
for (int i = 0; i < dims; ++i) {
const int32 d = perm[i];

@ -205,7 +205,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -24,7 +24,7 @@ namespace tensorflow {
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, StringPiece name,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
@ -47,7 +47,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(

// Build the condition.
std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
{
auto parameter =
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
@ -61,7 +61,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(

// Build the body.
std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
builder->CreateSubBuilder(absl::StrCat(name, "_body"));
{
auto parameter =
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
@ -84,7 +84,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, StringPiece name,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
auto while_cond_fn =
[&](absl::Span<const xla::XlaOp> values,

@ -19,11 +19,11 @@ limitations under the License.
#include <functional>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"

namespace tensorflow {

@ -50,7 +50,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, StringPiece name,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);

// Builds an XLA loop that repeats a computation `num_iterations` times.
@ -65,7 +65,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, StringPiece name,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);

} // namespace tensorflow
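A convenient side effect of the new signatures is that absl::StrCat consumes absl::string_view directly, so names flow from parameter to sub-builder without an intermediate std::string. A hedged sketch of the naming pattern used by XlaWhileLoop above (SubBuilderName is an invented helper, not part of this change):

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

// `name` may be a literal, a std::string, or anything else convertible to
// string_view; the concatenation allocates only the final string.
std::string SubBuilderName(absl::string_view name, absl::string_view suffix) {
  return absl::StrCat(name, suffix);
}

int main() {
  std::cout << SubBuilderName("while_loop", "_condition") << "\n";
  std::cout << SubBuilderName("while_loop", "_body") << "\n";
  return 0;
}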
@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace tensorflow {
/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString(
/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
XlaResourceOpKind op_kind) {
switch (op_kind) {
case XlaResourceOpKind::kRead:
@ -30,11 +30,11 @@ namespace tensorflow {
}
}

static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
gtl::FlatMap<StringPiece, XlaResourceOpInfo>* result =
new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap() {
auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;

auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
auto insert_result =
result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
@ -103,23 +103,23 @@ static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
return result;
}

static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
CreateResourceOpInfoMap();
return *op_info_map;
}

const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) {
const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
auto it = op_infos.find(op);
return it == op_infos.end() ? nullptr : &it->second;
}

namespace resource_op_table_internal {
std::vector<StringPiece> GetKnownResourceOps() {
std::vector<StringPiece> result;
std::vector<absl::string_view> GetKnownResourceOps() {
std::vector<absl::string_view> result;
for (const auto& p : GetStaticResourceOpInfoMap()) {
result.push_back(p.first);
}
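One caveat with string_view-keyed maps like the one above: the map stores only views into the key data, so every key must outlive the map. The op table is safe because its keys are string literals with static storage duration. A sketch of the same pattern, using absl::flat_hash_map purely as a stand-in for TF's internal gtl::FlatMap:

#include <iostream>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"

int main() {
  // Keys here are string literals, so the views remain valid for the
  // program's lifetime, mirroring the static op-info map above.
  absl::flat_hash_map<absl::string_view, int> op_kinds;
  op_kinds.insert({"VarIsInitializedOp", 0});
  op_kinds.insert({"AssignVariableOp", 1});
  auto it = op_kinds.find("AssignVariableOp");
  if (it != op_kinds.end()) std::cout << it->second << "\n";
  return 0;
}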
@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/stringpiece.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"

// Exposes information about the resource operations supported by tf2xla in a
@ -47,7 +47,7 @@ class XlaResourceOpInfo {
XlaResourceOpKind kind() const { return op_kind_; }
XlaResourceKind resource_kind() const { return resource_kind_; }

static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind);
static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind);

private:
XlaResourceOpKind op_kind_;
@ -57,13 +57,13 @@ class XlaResourceOpInfo {
// Returns a XlaResourceOpInfo describing `op` if it is a resource operation
// supported by tf2xla, otherwise returns null (i.e. if this returns null then
// `op` is either not a resource operation or is unsupported by XLA).
const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op);
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op);

namespace resource_op_table_internal {
// NB! Implementation detail exposed for unit testing, do not use.
//
// Returns the set of resource operations known by this module.
std::vector<StringPiece> GetKnownResourceOps();
std::vector<absl::string_view> GetKnownResourceOps();
} // namespace resource_op_table_internal

} // namespace tensorflow

@ -34,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) {

TEST(ResourceOperationTableTest, HaveAllResourceOps) {
gtl::FlatMap<string, bool> known_resource_ops;
for (StringPiece known_resource_op :
for (absl::string_view known_resource_op :
resource_op_table_internal::GetKnownResourceOps()) {
ASSERT_TRUE(
known_resource_ops.insert({string(known_resource_op), false}).second);

@ -17,7 +17,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

@ -75,7 +75,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
auto node_it = node_map.find(remap_it->second);
if (node_it == node_map.end()) {
// Strip off the aot_feed_#/ prefix.
StringPiece name(remap_it->second);
absl::string_view name(remap_it->second);
const auto index = name.find('/');
if (index > 0) name.remove_prefix(index + 1);
return errors::InvalidArgument(
@ -89,7 +89,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
// explicitly specify or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
@ -136,7 +136,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
// Connects fetch_node -> retval_node.
Node* retval_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp)
.Input(fetch_node, id.output_index())
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
.Attr("index", ret_index)
@ -256,7 +256,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
XlaOpRegistry::RegisterCompilationKernels();
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(
strings::StrCat("/device:", DEVICE_CPU_XLA_JIT));
absl::StrCat("/device:", DEVICE_CPU_XLA_JIT));
}
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));

@ -20,6 +20,7 @@ limitations under the License.
#include <set>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

@ -112,8 +112,8 @@ Status AddPlaceholdersForFeeds(
const string name_port = TensorIdToString(feed->id());
PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = strings::StrCat(
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
"/", feed->id().node_name());
(*feed_remapping)[name_port] = info.placeholder_name;
}

@ -258,7 +258,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
}

string TensorIdToString(const tf2xla::TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
return absl::StrCat(id.node_name(), ":", id.output_index());
}

Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
@ -289,7 +289,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
return Status::OK();
}

void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
KernelDef* kdef) {
for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
if (constraint.name() == name) {

@ -53,7 +53,7 @@ string TensorIdToString(const tf2xla::TensorId& id);
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);

// Add an allowed data type to the AttrConstraint with the given name.
void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
KernelDef* kdef);

// Returns the next random seed to use for seeding xla rng.
@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
@ -25,8 +27,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
@ -153,7 +153,7 @@ static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
tf2xla::Config config;
for (const auto& fetch_node_name : fetches) {
auto* fetch = config.add_fetch();
fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
fetch->mutable_id()->set_node_name(fetch_node_name);
}
return config;

@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator {

XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
DeviceType type)
: LocalDevice(
options,
Device::BuildDeviceAttributes(
strings::StrCat("/device:", type.type(), ":0"), type,
Bytes(256 << 20), DeviceLocality(),
strings::StrCat("device: XLA compilation device ", type.type()))),
: LocalDevice(options, Device::BuildDeviceAttributes(
absl::StrCat("/device:", type.type(), ":0"),
type, Bytes(256 << 20), DeviceLocality(),
absl::StrCat("device: XLA compilation device ",
type.type()))),
allocator_(new XlaCompilationAllocator()) {}

XlaCompilationDevice::~XlaCompilationDevice() {}

@ -198,14 +198,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// lowest-numbered core that consumes the argument. We choose the
// lowest-numbered core so the assignment is deterministic.
for (Node* n : graph->nodes()) {
if (StringPiece(n->type_string()) == "_Arg") {
if (absl::string_view(n->type_string()) == "_Arg") {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
}
}
// Do _Retval as a second loop, in case the retval's input is an _Arg (which
// may have gotten a device assignment from the first loop).
for (Node* n : graph->nodes()) {
if (StringPiece(n->type_string()) == "_Retval") {
if (absl::string_view(n->type_string()) == "_Retval") {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
}
}
@ -213,8 +213,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileFunction: "
<< dump_graph::DumpGraphToFile(
strings::StrCat("xla_compile_function_", function_id),
*graph);
absl::StrCat("xla_compile_function_", function_id), *graph);
}

VLOG(1) << "====================================================";
@ -522,7 +521,7 @@ Status XlaCompiler::BuildArguments(

// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (StringPiece(n->type_string()) != "_Arg") continue;
if (absl::string_view(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
@ -581,7 +580,7 @@ Status XlaCompiler::BuildArguments(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
strings::StrCat("arg", i));
absl::StrCat("arg", i));
}
}

@ -644,7 +643,7 @@ Status XlaCompiler::CompileSingleOp(
// dependency edge to the _SOURCE node.
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
Node* node;
string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
Status status = NodeBuilder(name, "_Arg")
.ControlInput(graph->source_node())
.Attr("T", ctx->input_dtype(i))
@ -657,7 +656,7 @@ Status XlaCompiler::CompileSingleOp(
// Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < ctx->num_outputs(); ++i) {
Node* node;
string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
Status status = NodeBuilder(name, "_Retval")
.Input(main_node, i)
.Attr("T", ctx->expected_output_dtype(i))
@ -693,7 +692,7 @@ Status ValidateGraph(const Graph* graph,
const DeviceType& device_type, const string& name) {
auto maybe_error = [&](const Node* node, const Status& s) -> Status {
if (!s.ok()) {
return errors::InvalidArgument(strings::StrCat(
return errors::InvalidArgument(absl::StrCat(
"Detected unsupported operations when trying to compile graph ", name,
" on ", device_type.type_string(), ": ", node->def().op(), " (",
s.error_message(), ")", FormatNodeForError(*node)));
@ -734,7 +733,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
strings::StrCat("xla_compile_graph_", name), *graph);
absl::StrCat("xla_compile_graph_", name), *graph);
}

// Report the error here if initialization failed.
@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}

const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
return GetComputationFromTensor(GetInputTensorByName(name));
}

@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}

TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
return GetInputTensorByName(name).shape();
}

@ -100,7 +100,7 @@ Status XlaOpKernelContext::ConstantInput(int index,
}

static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
StringPiece name) {
absl::string_view name) {
int start, stop;
TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
if (stop != start + 1) {
@ -112,7 +112,7 @@ static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
return start;
}

Status XlaOpKernelContext::ConstantInput(StringPiece name,
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
xla::Literal* constant_literal) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
return ConstantInput(index, constant_literal);
@ -265,7 +265,7 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
return LiteralToInt64Scalar(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name,
Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
int64* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
return ConstantInputAsIntScalar(index, out);
@ -305,7 +305,7 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name,
Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
std::vector<int64>* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
return ConstantInputAsIntVector(index, out);
@ -344,7 +344,7 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
}
}

Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name,
Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
xla::Literal* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
return ConstantInputAsInt64Literal(index, out);
@ -361,7 +361,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
return Status::OK();
}

Status XlaOpKernelContext::InputList(StringPiece name,
Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
@ -376,7 +376,7 @@ Status XlaOpKernelContext::InputList(StringPiece name,
}

Status XlaOpKernelContext::ConstantInputList(
StringPiece name, std::vector<xla::Literal>* outputs) {
absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
outputs->resize(stop - start);
@ -429,8 +429,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
value);
}

Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
TensorShape* shape,
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
DataType type, TensorShape* shape,
xla::XlaOp* value) {
return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
shape, value);
@ -564,7 +564,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
handle, builder());
}

Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
xla::XlaOp handle) {
TF_RET_CHECK(handle.valid());
return AssignVariableTensor(GetInputTensorByName(name), type, context_,
@ -610,7 +610,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}

const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
CHECK(context_->input(name, &tensor).ok());
return *tensor;

@ -80,14 +80,14 @@ class XlaOpKernelContext {
TensorShape InputShape(int index);

// Returns the shape of input `name`.
TensorShape InputShape(StringPiece name);
TensorShape InputShape(absl::string_view name);

// Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
// Returns input `name` as a XlaOp.
const xla::XlaOp& Input(StringPiece name);
const xla::XlaOp& Input(absl::string_view name);

// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@ -97,7 +97,7 @@ class XlaOpKernelContext {
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list.
Status InputList(StringPiece name, std::vector<xla::XlaOp>* handles,
Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes);

// Helper methods for constant inputs.
@ -106,7 +106,7 @@ class XlaOpKernelContext {
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
Status ConstantInput(StringPiece name, xla::Literal* constant_literal);
Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);

// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
@ -118,14 +118,15 @@ class XlaOpKernelContext {

// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
Status ConstantInputAsIntScalar(StringPiece name, int64* out);
Status ConstantInputAsIntScalar(absl::string_view name, int64* out);

// Converts a constant scalar float32 or float64 tensor into a float64.
Status ConstantInputAsFloatScalar(int index, double* out);

// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out);
Status ConstantInputAsIntVector(absl::string_view name,
std::vector<int64>* out);

// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
@ -133,7 +134,7 @@ class XlaOpKernelContext {

// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out);
Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);

// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);
@ -141,7 +142,7 @@ class XlaOpKernelContext {
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list.
Status ConstantInputList(StringPiece name,
Status ConstantInputList(absl::string_view name,
std::vector<xla::Literal>* literals);

// Outputs
@ -190,8 +191,8 @@ class XlaOpKernelContext {
xla::XlaOp* value);
// Reads the current value of the resource variable referred to by input
// `name`.
Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
xla::XlaOp* value);
Status ReadVariableInput(absl::string_view name, DataType type,
TensorShape* shape, xla::XlaOp* value);

// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
@ -199,7 +200,8 @@ class XlaOpKernelContext {
// different shape.
Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
// Assigns the value `handle` to the variable referenced by input `name`.
Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
Status AssignVariable(absl::string_view name, DataType type,
xla::XlaOp handle);

// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@ -248,7 +250,7 @@ class XlaOpKernelContext {

private:
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(StringPiece name);
const Tensor& GetInputTensorByName(absl::string_view name);

OpKernelContext* const context_;
};
@ -371,26 +371,28 @@ XlaOpRegistry& XlaOpRegistry::Instance() {
|
||||
return *r;
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) {
|
||||
XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
|
||||
registration_.reset(new XlaOpRegistry::OpRegistration);
|
||||
registration_->name = string(name);
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
|
||||
XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
|
||||
absl::string_view name) {
|
||||
XlaOpRegistrationBuilder registration(name);
|
||||
return registration;
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||
absl::Span<const StringPiece> devices) {
|
||||
absl::Span<const absl::string_view> devices) {
|
||||
registration_->has_device_whitelist = true;
|
||||
for (StringPiece device : devices) {
|
||||
for (absl::string_view device : devices) {
|
||||
registration_->device_whitelist.emplace(device);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) {
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||
absl::string_view device) {
|
||||
registration_->has_device_whitelist = true;
|
||||
registration_->device_whitelist.emplace(device);
|
||||
return *this;
|
||||
@ -407,7 +409,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() {
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
StringPiece attr_name, DataType allowed) {
|
||||
absl::string_view attr_name, DataType allowed) {
|
||||
std::set<DataType>& types =
|
||||
registration_->type_constraints[string(attr_name)];
|
||||
types.insert(allowed);
|
||||
@ -415,7 +417,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
StringPiece attr_name, absl::Span<const DataType> allowed) {
|
||||
absl::string_view attr_name, absl::Span<const DataType> allowed) {
|
||||
std::set<DataType>& types =
|
||||
registration_->type_constraints[string(attr_name)];
|
||||
for (DataType t : allowed) {
|
||||
@ -425,7 +427,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
|
||||
StringPiece input_name) {
|
||||
absl::string_view input_name) {
|
||||
registration_->compile_time_constant_inputs.emplace(input_name);
|
||||
return *this;
|
||||
}
|
||||
@ -452,7 +454,7 @@ XlaOpRegistrar::XlaOpRegistrar(
|
||||
}
|
||||
|
||||
XlaBackendRegistrar::XlaBackendRegistrar(
|
||||
StringPiece name, absl::Span<const DataType> types,
|
||||
absl::string_view name, absl::Span<const DataType> types,
|
||||
XlaOpRegistry::BackendOpFilter op_filter) {
|
||||
XlaOpRegistry& registry = XlaOpRegistry::Instance();
|
||||
registry.RegisterBackend(string(name), types, op_filter);
|
||||
|
@ -232,18 +232,18 @@ class XlaOpRegistry {
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(StringPiece name);
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(StringPiece devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const StringPiece> devices);
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
@ -254,13 +254,13 @@ class XlaOpRegistrationBuilder {
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
  XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);

  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(StringPiece name);
  XlaOpRegistrationBuilder(absl::string_view name);

  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
@ -288,7 +288,7 @@ class XlaOpRegistrar {

class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(StringPiece name, absl::Span<const DataType> types,
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
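For context, a sketch of why the StringPiece to absl::string_view change is source-compatible for callers of a builder like the one above. This is a standalone toy, not the TensorFlow class; string literals, std::string, and string_view all bind to a string_view parameter without call-site changes:

    #include <iostream>
    #include <string>

    #include "absl/strings/string_view.h"

    // Toy builder (not XlaOpRegistrationBuilder) illustrating the parameter
    // type change; all names here are made up for the sketch.
    class RegistrationBuilder {
     public:
      static RegistrationBuilder Name(absl::string_view name) {
        RegistrationBuilder b;
        b.name_ = std::string(name);  // Explicit copy: string_view is non-owning.
        return b;
      }
      RegistrationBuilder& Device(absl::string_view device) {
        device_ = std::string(device);
        return *this;
      }
      std::string name_, device_;
    };

    int main() {
      std::string device = "XLA_CPU_JIT";  // A std::string binds implicitly...
      auto b = RegistrationBuilder::Name("MyOp").Device(device);  // ...as does a literal.
      std::cout << b.name_ << " on " << b.device_ << "\n";
    }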
@ -43,7 +43,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(new XlaResource(
        /*kind=*/kTensorArray, /*arg_num=*/-1,
        /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}));
  }
}
@ -135,7 +135,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
        xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, gradient_value, tensor_array_size_,
                        /*tensor_array_gradients=*/{}));
  }
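The rename is mechanical because absl::StrCat, like the old strings::StrCat, concatenates string-likes and numeric values directly. A minimal sketch mirroring the TensorArrayGrad naming above (the name and size values are made up):

    #include <cstdint>
    #include <iostream>
    #include <string>

    #include "absl/strings/str_cat.h"

    int main() {
      std::string name = "ta0";  // Made-up resource name for the sketch.
      // absl::StrCat accepts string-likes and numbers without explicit
      // conversions, matching the "TensorArrayGrad: " + name_ pattern above.
      std::string grad_name = absl::StrCat("TensorArrayGrad: ", name);
      int64_t size = 128;
      std::cout << grad_name << " size=" << absl::StrCat(size) << "\n";
    }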
@ -2520,6 +2520,7 @@ cc_library(
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

@ -3187,6 +3188,7 @@ cc_library(
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)
@ -259,7 +259,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
  // Fusing a reduce into a loop fusion would require changing the fusion kind.
  // That's not supported yet.
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[6400]{0} parameter(0)
      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@ -277,7 +277,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
}

TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[6400]{0} parameter(0)
      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@ -301,7 +301,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
}

TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
@ -324,7 +324,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
}

TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
@ -358,7 +358,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {

TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
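The test changes above all follow one pattern: a shared module prefix spliced onto a raw-string HLO body before parsing. A standalone sketch of that splice, assuming a stand-in kModulePrefix rather than the constant from the test file:

    #include <iostream>
    #include <string>

    #include "absl/strings/str_cat.h"

    // Stand-in for the test's kModulePrefix; the real constant differs.
    constexpr char kModulePrefix[] = R"(HloModule test_module)";

    int main() {
      // StrCat happily takes a raw string literal as its second argument.
      std::string hlo = absl::StrCat(kModulePrefix, R"(
        fused_computation_1 {
          p0.1 = f32[6400]{0} parameter(0)
          ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
        })");
      std::cout << hlo << "\n";  // The test would hand this to ParseHloString.
    }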
@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {
@ -15,10 +15,10 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace xla {
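Swapping the gtl header for the absl one is behavior-preserving here, since tensorflow::gtl::InlinedVector is an alias for absl::InlinedVector. A short sketch of the container's inline-storage semantics, with made-up dimension values:

    #include <cstdint>
    #include <iostream>

    #include "absl/container/inlined_vector.h"

    int main() {
      // Up to 4 elements live inside the object itself; pushing a 5th spills
      // to a heap allocation. The API otherwise matches std::vector.
      absl::InlinedVector<int64_t, 4> dims = {8, 1, 5, 16};
      dims.push_back(1);  // Exceeds the inline capacity; allocates.
      std::cout << dims.size() << "\n";  // Prints 5.
    }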
@ -56,6 +56,7 @@ cc_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/stream_executor",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
    ],
)
@ -46,19 +46,15 @@ cc_library(
    deps = [
        ":xrt_state_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:hlo_proto",
        "//tensorflow/compiler/xrt:xrt_proto",
        "//tensorflow/compiler/xrt:xrt_utils",
        "//tensorflow/core:core_cpu_internal",
@ -67,6 +63,7 @@ cc_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor:stream_executor_headers_lib",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)
@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/types.h"

@ -70,7 +70,7 @@ Status CompilationCacheKey(const xrt::XLAComputation& computation,
  string serialized;
  TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
  uint64 fingerprint = Fingerprint64(serialized);
  *key = strings::StrCat(fingerprint);
  *key = absl::StrCat(fingerprint);
  return Status::OK();
}
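The compilation cache key is simply the decimal rendering of a 64-bit fingerprint of the serialized proto. A sketch of that keying pattern, with std::hash standing in for TensorFlow's Fingerprint64 so the snippet is self-contained; only the StrCat rendering is the point:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>

    #include "absl/strings/str_cat.h"

    // std::hash is a stand-in for Fingerprint64, purely for this sketch.
    std::string CacheKey(const std::string& serialized) {
      uint64_t fingerprint = std::hash<std::string>{}(serialized);
      // absl::StrCat formats unsigned 64-bit values in decimal.
      return absl::StrCat(fingerprint);
    }

    int main() { std::cout << CacheKey("some serialized proto") << "\n"; }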
@ -24,6 +24,7 @@ limitations under the License.
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"

@ -201,14 +201,14 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {

/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
                                             XRTTupleAllocation** allocation) {
  string key_string = strings::StrCat(key);
  string key_string = absl::StrCat(key);
  TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
                                                                int64 key) {
  string key_string = strings::StrCat(key);
  string key_string = absl::StrCat(key);
  return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
}

@ -410,7 +410,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;

Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
  *key = get_uid();
  string key_string = strings::StrCat(*key);
  string key_string = absl::StrCat(*key);
  return rm->Create(kTupleContainer, key_string, this);
}
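All three XRT call sites stringify an int64 uid so it can key a string-addressed resource manager. A sketch of that pattern with a plain std::map standing in for ResourceMgr (a stand-in, not the TensorFlow class):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>

    #include "absl/strings/str_cat.h"

    // std::map stands in for ResourceMgr; entries are addressed by the
    // decimal rendering of an int64 uid, as in Lookup/Intern above.
    std::map<std::string, std::string> store;

    void Intern(int64_t key, const std::string& value) {
      store[absl::StrCat(key)] = value;  // Keyed as "42", not the raw integer.
    }

    int main() {
      Intern(42, "tuple allocation");
      std::cout << store.at("42") << "\n";
    }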