Use absl instead of deprecated str_util
PiperOrigin-RevId: 249248716
This commit is contained in:
parent
d52b6ddef4
commit
4213d5c1bd
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/match.h"
|
||||||
// Required for IS_MOBILE_PLATFORM
|
// Required for IS_MOBILE_PLATFORM
|
||||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||||
|
|
||||||
@ -2495,8 +2496,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
|
|||||||
// used in this graph
|
// used in this graph
|
||||||
for (const auto& pair : g->name_map) {
|
for (const auto& pair : g->name_map) {
|
||||||
const string& name = pair.first;
|
const string& name = pair.first;
|
||||||
if (name.compare(prefix) == 0 ||
|
if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
|
||||||
tensorflow::str_util::StartsWith(name, prefix_cmp)) {
|
|
||||||
status->status = InvalidArgument(
|
status->status = InvalidArgument(
|
||||||
"prefix [", prefix,
|
"prefix [", prefix,
|
||||||
"] conflicts with existing node in the graph named [", name, "]");
|
"] conflicts with existing node in the graph named [", name, "]");
|
||||||
@ -2526,8 +2526,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
|
|||||||
// Adding the gradients to the graph can alter the prefix to prevent
|
// Adding the gradients to the graph can alter the prefix to prevent
|
||||||
// name collisions only if this prefix has not been provided explicitly
|
// name collisions only if this prefix has not been provided explicitly
|
||||||
// by the user. If it was provided, assert that it remained intact.
|
// by the user. If it was provided, assert that it remained intact.
|
||||||
if (prefix != nullptr &&
|
if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
|
||||||
!tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
|
|
||||||
status->status = tensorflow::errors::Internal(
|
status->status = tensorflow::errors::Internal(
|
||||||
"BUG: The gradients prefix have been unexpectedly altered when "
|
"BUG: The gradients prefix have been unexpectedly altered when "
|
||||||
"adding the nodes to the graph. This is a bug. Please file an "
|
"adding the nodes to the graph. This is a bug. Please file an "
|
||||||
|
@ -62,8 +62,8 @@ protocol: "grpc"
|
|||||||
TF_Buffer* null_result =
|
TF_Buffer* null_result =
|
||||||
TFE_GetServerDef(malformed_text_proto.c_str(), status);
|
TFE_GetServerDef(malformed_text_proto.c_str(), status);
|
||||||
EXPECT_NE(TF_GetCode(status), TF_OK);
|
EXPECT_NE(TF_GetCode(status), TF_OK);
|
||||||
EXPECT_TRUE(tensorflow::str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(TF_Message(status),
|
||||||
TF_Message(status), "Invalid text proto for ServerDef"));
|
"Invalid text proto for ServerDef"));
|
||||||
EXPECT_EQ(null_result, nullptr);
|
EXPECT_EQ(null_result, nullptr);
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
|
@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test {
|
|||||||
const std::unordered_set<string>& nodes) {
|
const std::unordered_set<string>& nodes) {
|
||||||
ASSERT_EQ(nodes.size(), fdef.node_def_size())
|
ASSERT_EQ(nodes.size(), fdef.node_def_size())
|
||||||
<< "Got unexpected number of nodes. Expected: ["
|
<< "Got unexpected number of nodes. Expected: ["
|
||||||
<< str_util::Join(nodes, ", ")
|
<< absl::StrJoin(nodes, ", ")
|
||||||
<< "] Actual nodes in fdef: " << fdef.DebugString();
|
<< "] Actual nodes in fdef: " << fdef.DebugString();
|
||||||
for (const NodeDef& node_def : fdef.node_def()) {
|
for (const NodeDef& node_def : fdef.node_def()) {
|
||||||
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
|
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
|
||||||
|
@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||||
<< "'" << s << "' does not contain '" << expected << "'";
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
|||||||
status->status = tensorflow::Status::OK();
|
status->status = tensorflow::Status::OK();
|
||||||
} else {
|
} else {
|
||||||
VLOG(3) << "Fully padded shape of ["
|
VLOG(3) << "Fully padded shape of ["
|
||||||
<< tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
|
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||||
<< padded_shape.DebugString();
|
<< padded_shape.DebugString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
|
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
|
||||||
bool ok = str_util::StrContains(base, substr);
|
bool ok = absl::StrContains(base, substr);
|
||||||
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
@ -638,6 +638,7 @@ cc_library(
|
|||||||
"//tensorflow/core:op_gen_lib",
|
"//tensorflow/core:op_gen_lib",
|
||||||
"//tensorflow/core:proto_text",
|
"//tensorflow/core:proto_text",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -657,6 +658,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/api_def.pb.h"
|
#include "tensorflow/core/framework/api_def.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string PrintString(const string& str) {
|
string PrintString(const string& str) {
|
||||||
return strings::StrCat("\"", str_util::CEscape(str), "\"");
|
return strings::StrCat("\"", absl::CEscape(str), "\"");
|
||||||
}
|
}
|
||||||
|
|
||||||
string PrintTensorShape(const TensorShapeProto& shape_proto) {
|
string PrintTensorShape(const TensorShapeProto& shape_proto) {
|
||||||
@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) {
|
|||||||
string ret;
|
string ret;
|
||||||
for (int64 i = 0; i < num_elts; ++i) {
|
for (int64 i = 0; i < num_elts; ++i) {
|
||||||
if (i > 0) strings::StrAppend(&ret, " ");
|
if (i > 0) strings::StrAppend(&ret, " ");
|
||||||
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
|
strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -62,12 +62,12 @@ op {
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||||
<< "'" << s << "' does not contain '" << expected << "'";
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
|
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
|
||||||
EXPECT_FALSE(str_util::StrContains(s, expected))
|
EXPECT_FALSE(absl::StrContains(s, expected))
|
||||||
<< "'" << s << "' contains '" << expected << "'";
|
<< "'" << s << "' contains '" << expected << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
|||||||
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
|
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
|
||||||
for (const string& entry : node_constraints) {
|
for (const string& entry : node_constraints) {
|
||||||
StringPiece s(entry);
|
StringPiece s(entry);
|
||||||
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
|
if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
|
||||||
current_constraints.emplace(s);
|
current_constraints.emplace(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
|||||||
const Status status = LoadSavedModelInternal(session_options, run_options,
|
const Status status = LoadSavedModelInternal(session_options, run_options,
|
||||||
export_dir, tags, bundle);
|
export_dir, tags, bundle);
|
||||||
auto log_and_count = [&](const string& status_str) {
|
auto log_and_count = [&](const string& status_str) {
|
||||||
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
|
LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
|
||||||
<< " }; Status: " << status_str << ". Took "
|
<< " }; Status: " << status_str << ". Took "
|
||||||
<< GetLatencyMicroseconds(start_microseconds) << " microseconds.";
|
<< GetLatencyMicroseconds(start_microseconds) << " microseconds.";
|
||||||
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
|
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
|
||||||
|
@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) {
|
|||||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||||
{"missing-tag"}, &bundle);
|
{"missing-tag"}, &bundle);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
st.error_message(),
|
st.error_message(),
|
||||||
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
||||||
<< st.error_message();
|
<< st.error_message();
|
||||||
@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
|
|||||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||||
{kSavedModelTagServe, "missing-tag"}, &bundle);
|
{kSavedModelTagServe, "missing-tag"}, &bundle);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
st.error_message(),
|
st.error_message(),
|
||||||
"Could not find meta graph def matching supplied tags: "))
|
"Could not find meta graph def matching supplied tags: "))
|
||||||
<< st.error_message();
|
<< st.error_message();
|
||||||
@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
|
|||||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||||
{kSavedModelTagServe}, &bundle);
|
{kSavedModelTagServe}, &bundle);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget))
|
EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
|
||||||
<< st.error_message();
|
<< st.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
|
|||||||
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||||
const std::unordered_set<string>& tags,
|
const std::unordered_set<string>& tags,
|
||||||
MetaGraphDef* meta_graph_def) {
|
MetaGraphDef* meta_graph_def) {
|
||||||
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
|
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
|
||||||
<< " }";
|
<< " }";
|
||||||
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
|
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
|
||||||
// Get tags from the graph_def.
|
// Get tags from the graph_def.
|
||||||
@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
|||||||
error::Code::NOT_FOUND,
|
error::Code::NOT_FOUND,
|
||||||
strings::StrCat(
|
strings::StrCat(
|
||||||
"Could not find meta graph def matching supplied tags: { ",
|
"Could not find meta graph def matching supplied tags: { ",
|
||||||
str_util::Join(tags, " "),
|
absl::StrJoin(tags, " "),
|
||||||
" }. To inspect available tag-sets in the SavedModel, please "
|
" }. To inspect available tag-sets in the SavedModel, please "
|
||||||
"use the SavedModel CLI: `saved_model_cli`"));
|
"use the SavedModel CLI: `saved_model_cli`"));
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
|||||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||||
&meta_graph_def);
|
&meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
st.error_message(),
|
st.error_message(),
|
||||||
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
||||||
<< st.error_message();
|
<< st.error_message();
|
||||||
@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
|||||||
Status st = ReadMetaGraphDefFromSavedModel(
|
Status st = ReadMetaGraphDefFromSavedModel(
|
||||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
st.error_message(),
|
st.error_message(),
|
||||||
"Could not find meta graph def matching supplied tags: "))
|
"Could not find meta graph def matching supplied tags: "))
|
||||||
<< st.error_message();
|
<< st.error_message();
|
||||||
|
@ -167,8 +167,7 @@ namespace {
|
|||||||
|
|
||||||
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||||
int32* dst) {
|
int32* dst) {
|
||||||
if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
|
if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
|
||||||
tensorflow::str_util::ConsumePrefix(&arg, "=")) {
|
|
||||||
char extra;
|
char extra;
|
||||||
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
|
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
|
||||||
}
|
}
|
||||||
@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
|||||||
|
|
||||||
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||||
bool* dst) {
|
bool* dst) {
|
||||||
if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
|
if (absl::ConsumePrefix(&arg, flag)) {
|
||||||
if (arg.empty()) {
|
if (arg.empty()) {
|
||||||
*dst = true;
|
*dst = true;
|
||||||
return true;
|
return true;
|
||||||
|
@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch(
|
|||||||
missing.push_back(entry.first);
|
missing.push_back(entry.first);
|
||||||
}
|
}
|
||||||
return errors::InvalidArgument("Missing shapes for nodes: ",
|
return errors::InvalidArgument("Missing shapes for nodes: ",
|
||||||
str_util::Join(missing, ","));
|
absl::StrJoin(missing, ","));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -543,10 +543,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
|
|||||||
std::map<string, Node*> io_nodes;
|
std::map<string, Node*> io_nodes;
|
||||||
int num_inputs = 0;
|
int num_inputs = 0;
|
||||||
for (auto n : sgraph.op_nodes()) {
|
for (auto n : sgraph.op_nodes()) {
|
||||||
if (str_util::StartsWith(n->name(), kInputPHName)) {
|
if (absl::StartsWith(n->name(), kInputPHName)) {
|
||||||
num_inputs++;
|
num_inputs++;
|
||||||
io_nodes.insert({n->name(), n});
|
io_nodes.insert({n->name(), n});
|
||||||
} else if (str_util::StartsWith(n->name(), kOutputPHName)) {
|
} else if (absl::StartsWith(n->name(), kOutputPHName)) {
|
||||||
io_nodes.insert({n->name(), n});
|
io_nodes.insert({n->name(), n});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1070,7 +1070,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
|
|||||||
int batch_size = -1;
|
int batch_size = -1;
|
||||||
for (const NodeDef& node : gdef.node()) {
|
for (const NodeDef& node : gdef.node()) {
|
||||||
absl::string_view node_name(node.name());
|
absl::string_view node_name(node.name());
|
||||||
if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
|
if (absl::ConsumePrefix(&node_name, kInputPHName)) {
|
||||||
int port = -1;
|
int port = -1;
|
||||||
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
|
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
|
||||||
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
|
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
|
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
|
||||||
|
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||||
@ -32,9 +34,9 @@ namespace tensorflow {
|
|||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
namespace convert {
|
namespace convert {
|
||||||
// TODO(sami): Remove VLOG messages once the code matures
|
// TODO(sami): Remove VLOG messages once the code matures
|
||||||
|
using absl::AsciiStrToUpper;
|
||||||
using absl::StrAppend;
|
using absl::StrAppend;
|
||||||
using absl::StrCat;
|
using absl::StrCat;
|
||||||
using str_util::Uppercase;
|
|
||||||
|
|
||||||
Status TRTOptimizationPass::Init(
|
Status TRTOptimizationPass::Init(
|
||||||
const RewriterConfig_CustomGraphOptimizer* config) {
|
const RewriterConfig_CustomGraphOptimizer* config) {
|
||||||
@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init(
|
|||||||
}
|
}
|
||||||
if (params.count("precision_mode")) {
|
if (params.count("precision_mode")) {
|
||||||
TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
|
TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
|
||||||
Uppercase(params.at("precision_mode").s()), &precision_mode_));
|
AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
|
||||||
}
|
}
|
||||||
if (params.count("use_calibration")) {
|
if (params.count("use_calibration")) {
|
||||||
use_calibration_ = params.at("use_calibration").b();
|
use_calibration_ = params.at("use_calibration").b();
|
||||||
|
@ -27,7 +27,7 @@ namespace {
|
|||||||
string RemoveSuffix(const string& name, const string& suffix) {
|
string RemoveSuffix(const string& name, const string& suffix) {
|
||||||
string output(name);
|
string output(name);
|
||||||
StringPiece piece(output);
|
StringPiece piece(output);
|
||||||
str_util::ConsumeSuffix(&piece, suffix);
|
absl::ConsumeSuffix(&piece, suffix);
|
||||||
return string(piece);
|
return string(piece);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
|
|||||||
|
|
||||||
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
|
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
|
||||||
StringPiece piece(name);
|
StringPiece piece(name);
|
||||||
str_util::ConsumePrefix(&piece, prefix_);
|
absl::ConsumePrefix(&piece, prefix_);
|
||||||
return string(piece);
|
return string(piece);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ string RegexFromStringSet(const std::vector<string>& strs) {
|
|||||||
if (uniq.size() == 1) {
|
if (uniq.size() == 1) {
|
||||||
return *uniq.begin();
|
return *uniq.begin();
|
||||||
}
|
}
|
||||||
return str_util::Join(uniq, "|");
|
return absl::StrJoin(uniq, "|");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -30,7 +30,7 @@ constexpr char kTestDataset[] = "test-dataset";
|
|||||||
constexpr char kTestTable[] = "test-table";
|
constexpr char kTestTable[] = "test-table";
|
||||||
|
|
||||||
bool HasSubstr(StringPiece base, StringPiece substr) {
|
bool HasSubstr(StringPiece base, StringPiece substr) {
|
||||||
bool ok = str_util::StrContains(base, substr);
|
bool ok = absl::StrContains(base, substr);
|
||||||
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
@ -137,17 +137,16 @@ class DecodeAudioOpV2 : public OpKernel {
|
|||||||
|
|
||||||
const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
|
const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
|
||||||
const string file_format =
|
const string file_format =
|
||||||
str_util::Lowercase(file_format_tensor.scalar<string>()());
|
absl::AsciiStrToLower(file_format_tensor.scalar<string>()());
|
||||||
const int32 samples_per_second =
|
const int32 samples_per_second =
|
||||||
samples_per_second_tensor.scalar<int32>()();
|
samples_per_second_tensor.scalar<int32>()();
|
||||||
const int32 channel_count = channel_count_tensor.scalar<int32>()();
|
const int32 channel_count = channel_count_tensor.scalar<int32>()();
|
||||||
|
|
||||||
const std::set<string> valid_file_formats(
|
const std::set<string> valid_file_formats(
|
||||||
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(context, valid_file_formats.count(file_format) == 1,
|
||||||
context, valid_file_formats.count(file_format) == 1,
|
|
||||||
errors::InvalidArgument("file_format must be one of {",
|
errors::InvalidArgument("file_format must be one of {",
|
||||||
str_util::Join(valid_file_formats, ", "),
|
absl::StrJoin(valid_file_formats, ", "),
|
||||||
"}, but was: \"", file_format, "\""));
|
"}, but was: \"", file_format, "\""));
|
||||||
OP_REQUIRES(context, samples_per_second > 0,
|
OP_REQUIRES(context, samples_per_second > 0,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -220,13 +219,12 @@ class DecodeAudioOp : public OpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
||||||
file_format_ = str_util::Lowercase(file_format_);
|
file_format_ = absl::AsciiStrToLower(file_format_);
|
||||||
const std::set<string> valid_file_formats(
|
const std::set<string> valid_file_formats(
|
||||||
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
|
||||||
context, valid_file_formats.count(file_format_) == 1,
|
|
||||||
errors::InvalidArgument("file_format must be one of {",
|
errors::InvalidArgument("file_format must be one of {",
|
||||||
str_util::Join(valid_file_formats, ", "),
|
absl::StrJoin(valid_file_formats, ", "),
|
||||||
"}, but was: \"", file_format_, "\""));
|
"}, but was: \"", file_format_, "\""));
|
||||||
|
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
|
OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
|
||||||
|
@ -95,7 +95,7 @@ class EncodeAudioOpV2 : public OpKernel {
|
|||||||
bits_per_second_tensor.shape().DebugString()));
|
bits_per_second_tensor.shape().DebugString()));
|
||||||
|
|
||||||
const string file_format =
|
const string file_format =
|
||||||
str_util::Lowercase(file_format_tensor.scalar<string>()());
|
absl::AsciiStrToLower(file_format_tensor.scalar<string>()());
|
||||||
const int32 samples_per_second =
|
const int32 samples_per_second =
|
||||||
samples_per_second_tensor.scalar<int32>()();
|
samples_per_second_tensor.scalar<int32>()();
|
||||||
const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()();
|
const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()();
|
||||||
@ -157,7 +157,7 @@ class EncodeAudioOp : public OpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
|
||||||
file_format_ = str_util::Lowercase(file_format_);
|
file_format_ = absl::AsciiStrToLower(file_format_);
|
||||||
OP_REQUIRES(context, file_format_ == "wav",
|
OP_REQUIRES(context, file_format_ == "wav",
|
||||||
errors::InvalidArgument("file_format arg must be \"wav\"."));
|
errors::InvalidArgument("file_format arg must be \"wav\"."));
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
|
|||||||
// is set with a truthy value.
|
// is set with a truthy value.
|
||||||
const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
|
const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
|
||||||
string load_config =
|
string load_config =
|
||||||
load_config_env ? str_util::Lowercase(load_config_env) : "";
|
load_config_env ? absl::AsciiStrToLower(load_config_env) : "";
|
||||||
if (load_config == "true" || load_config == "1") {
|
if (load_config == "true" || load_config == "1") {
|
||||||
Aws::String config_file;
|
Aws::String config_file;
|
||||||
// If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
|
// If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
|
||||||
|
@ -182,7 +182,7 @@ class StringCrosser {
|
|||||||
}
|
}
|
||||||
// TODO(zakaria): this will copy the string twice, might effect
|
// TODO(zakaria): this will copy the string twice, might effect
|
||||||
// performance.
|
// performance.
|
||||||
return str_util::Join(cross_vec, k_feature_separator);
|
return absl::StrJoin(cross_vec, k_feature_separator);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -240,7 +240,7 @@ TEST(LoadSessionBundleFromPath, BasicTestRunOptionsThreadPoolInvalid) {
|
|||||||
|
|
||||||
// Expect failed session run calls with invalid run-options.
|
// Expect failed session run calls with invalid run-options.
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Invalid inter_op_thread_pool: 2"))
|
"Invalid inter_op_thread_pool: 2"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -315,7 +315,7 @@ TEST_F(SessionBundleTest, ServingGraphEmpty) {
|
|||||||
});
|
});
|
||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status_.error_message(),
|
EXPECT_TRUE(absl::StrContains(status_.error_message(),
|
||||||
"Expected exactly one serving GraphDef"))
|
"Expected exactly one serving GraphDef"))
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
}
|
}
|
||||||
@ -332,7 +332,7 @@ TEST_F(SessionBundleTest, ServingGraphAnyIncorrectType) {
|
|||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status_.error_message(),
|
absl::StrContains(status_.error_message(),
|
||||||
"Expected Any type_url for: tensorflow.GraphDef"))
|
"Expected Any type_url for: tensorflow.GraphDef"))
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
}
|
}
|
||||||
@ -349,8 +349,7 @@ TEST_F(SessionBundleTest, ServingGraphAnyValueCorrupted) {
|
|||||||
});
|
});
|
||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status_.error_message(), "Failed to unpack"))
|
||||||
str_util::StrContains(status_.error_message(), "Failed to unpack"))
|
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -365,7 +364,7 @@ TEST_F(SessionBundleTest, AssetFileAnyIncorrectType) {
|
|||||||
});
|
});
|
||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status_.error_message(),
|
status_.error_message(),
|
||||||
"Expected Any type_url for: tensorflow.serving.AssetFile"))
|
"Expected Any type_url for: tensorflow.serving.AssetFile"))
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
@ -383,8 +382,7 @@ TEST_F(SessionBundleTest, AssetFileAnyValueCorrupted) {
|
|||||||
});
|
});
|
||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status_.error_message(), "Failed to unpack"))
|
||||||
str_util::StrContains(status_.error_message(), "Failed to unpack"))
|
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -399,7 +397,7 @@ TEST_F(SessionBundleTest, InitOpTooManyValues) {
|
|||||||
});
|
});
|
||||||
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
|
||||||
EXPECT_FALSE(status_.ok());
|
EXPECT_FALSE(status_.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status_.error_message(),
|
EXPECT_TRUE(absl::StrContains(status_.error_message(),
|
||||||
"Expected exactly one serving init op"))
|
"Expected exactly one serving init op"))
|
||||||
<< status_.error_message();
|
<< status_.error_message();
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ namespace serving {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool HasSubstr(StringPiece base, StringPiece substr) {
|
static bool HasSubstr(StringPiece base, StringPiece substr) {
|
||||||
bool ok = str_util::StrContains(base, substr);
|
bool ok = absl::StrContains(base, substr);
|
||||||
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
@ -70,7 +70,7 @@ TEST(GetClassificationSignature, MissingSignature) {
|
|||||||
ClassificationSignature signature;
|
ClassificationSignature signature;
|
||||||
const Status status = GetClassificationSignature(meta_graph_def, &signature);
|
const Status status = GetClassificationSignature(meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected a classification signature"))
|
"Expected a classification signature"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ TEST(GetClassificationSignature, WrongSignatureType) {
|
|||||||
ClassificationSignature signature;
|
ClassificationSignature signature;
|
||||||
const Status status = GetClassificationSignature(meta_graph_def, &signature);
|
const Status status = GetClassificationSignature(meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected a classification signature"))
|
"Expected a classification signature"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ TEST(GetNamedClassificationSignature, MissingSignature) {
|
|||||||
const Status status =
|
const Status status =
|
||||||
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
|
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Missing signature named \"foo\""))
|
"Missing signature named \"foo\""))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -142,8 +142,8 @@ TEST(GetNamedClassificationSignature, WrongSignatureType) {
|
|||||||
const Status status =
|
const Status status =
|
||||||
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
|
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(
|
||||||
status.error_message(),
|
absl::StrContains(status.error_message(),
|
||||||
"Expected a classification signature for name \"foo\""))
|
"Expected a classification signature for name \"foo\""))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -177,7 +177,7 @@ TEST(GetRegressionSignature, MissingSignature) {
|
|||||||
RegressionSignature signature;
|
RegressionSignature signature;
|
||||||
const Status status = GetRegressionSignature(meta_graph_def, &signature);
|
const Status status = GetRegressionSignature(meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected a regression signature"))
|
"Expected a regression signature"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -194,7 +194,7 @@ TEST(GetRegressionSignature, WrongSignatureType) {
|
|||||||
RegressionSignature signature;
|
RegressionSignature signature;
|
||||||
const Status status = GetRegressionSignature(meta_graph_def, &signature);
|
const Status status = GetRegressionSignature(meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected a regression signature"))
|
"Expected a regression signature"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -228,7 +228,7 @@ TEST(GetNamedSignature, MissingSignature) {
|
|||||||
Signature signature;
|
Signature signature;
|
||||||
const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
|
const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Missing signature named \"foo\""))
|
"Missing signature named \"foo\""))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -371,7 +371,7 @@ TEST(RunClassification, RunNotOk) {
|
|||||||
const Status status = RunClassification(signature, input_tensor, &session,
|
const Status status = RunClassification(signature, input_tensor, &session,
|
||||||
&classes_tensor, nullptr);
|
&classes_tensor, nullptr);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone"))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Data is gone"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,8 +387,7 @@ TEST(RunClassification, TooManyOutputs) {
|
|||||||
const Status status = RunClassification(signature, input_tensor, &session,
|
const Status status = RunClassification(signature, input_tensor, &session,
|
||||||
&classes_tensor, nullptr);
|
&classes_tensor, nullptr);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected 1 output"))
|
||||||
str_util::StrContains(status.error_message(), "Expected 1 output"))
|
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,7 +404,7 @@ TEST(RunClassification, WrongBatchOutputs) {
|
|||||||
&classes_tensor, nullptr);
|
&classes_tensor, nullptr);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(),
|
absl::StrContains(status.error_message(),
|
||||||
"Input batch size did not match output batch size"))
|
"Input batch size did not match output batch size"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -452,7 +451,7 @@ TEST_F(RunRegressionTest, RunNotOk) {
|
|||||||
const Status status =
|
const Status status =
|
||||||
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
|
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone"))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Data is gone"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -464,7 +463,7 @@ TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
|
|||||||
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
|
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(),
|
absl::StrContains(status.error_message(),
|
||||||
"Input batch size did not match output batch size"))
|
"Input batch size did not match output batch size"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
@ -491,8 +490,7 @@ TEST(GetSignatures, MissingSignature) {
|
|||||||
Signatures read_signatures;
|
Signatures read_signatures;
|
||||||
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
||||||
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected exactly one"))
|
||||||
str_util::StrContains(status.error_message(), "Expected exactly one"))
|
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -506,7 +504,7 @@ TEST(GetSignatures, WrongProtoInAny) {
|
|||||||
Signatures read_signatures;
|
Signatures read_signatures;
|
||||||
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
||||||
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected Any type_url for: "
|
"Expected Any type_url for: "
|
||||||
"tensorflow.serving.Signatures"))
|
"tensorflow.serving.Signatures"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
@ -523,7 +521,7 @@ TEST(GetSignatures, JunkInAny) {
|
|||||||
Signatures read_signatures;
|
Signatures read_signatures;
|
||||||
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
||||||
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "Failed to unpack"))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Failed to unpack"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -570,8 +568,7 @@ TEST(GetSignatures, MultipleSignaturesNotOK) {
|
|||||||
Signatures read_signatures;
|
Signatures read_signatures;
|
||||||
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
const auto status = GetSignatures(meta_graph_def, &read_signatures);
|
||||||
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected exactly one"))
|
||||||
str_util::StrContains(status.error_message(), "Expected exactly one"))
|
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -645,7 +642,7 @@ TEST(GetGenericSignature, WrongSignatureType) {
|
|||||||
const Status status =
|
const Status status =
|
||||||
GetGenericSignature("generic_bindings", meta_graph_def, &signature);
|
GetGenericSignature("generic_bindings", meta_graph_def, &signature);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"Expected a generic signature:"))
|
"Expected a generic signature:"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
|
|||||||
CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
|
CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
|
||||||
strings::safe_strto64(parts[5], &step_id);
|
strings::safe_strto64(parts[5], &step_id);
|
||||||
parts.pop_back(); // remove step_id
|
parts.pop_back(); // remove step_id
|
||||||
key.assign(str_util::Join(parts, ";")); // stitch them together
|
key.assign(absl::StrJoin(parts, ";")); // stitch them together
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1084,6 +1084,7 @@ cc_library(
|
|||||||
":lib_internal",
|
":lib_internal",
|
||||||
":protos_all_cc",
|
":protos_all_cc",
|
||||||
"//tensorflow/core/util/proto:proto_utils",
|
"//tensorflow/core/util/proto:proto_utils",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3774,6 +3775,7 @@ tf_cc_tests(
|
|||||||
":test",
|
":test",
|
||||||
":test_main",
|
":test_main",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@zlib_archive//:zlib",
|
"@zlib_archive//:zlib",
|
||||||
],
|
],
|
||||||
@ -3790,6 +3792,7 @@ tf_cc_test(
|
|||||||
":protos_all_cc",
|
":protos_all_cc",
|
||||||
":test",
|
":test",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -178,7 +178,7 @@ string RemoveDoc(const OpDef& op, const string& file_contents,
|
|||||||
}
|
}
|
||||||
// Remove .Doc call.
|
// Remove .Doc call.
|
||||||
auto before_doc = file_contents.substr(0, doc_start_location);
|
auto before_doc = file_contents.substr(0, doc_start_location);
|
||||||
str_util::StripTrailingWhitespace(&before_doc);
|
absl::StripTrailingAsciiWhitespace(&before_doc);
|
||||||
return before_doc +
|
return before_doc +
|
||||||
file_contents.substr(doc_end_location + sizeof(kDocEnd) - 1);
|
file_contents.substr(doc_end_location + sizeof(kDocEnd) - 1);
|
||||||
}
|
}
|
||||||
|
@ -645,7 +645,7 @@ Status ColocationGraph::ColocateAllNodes() {
|
|||||||
if (attr_value != nullptr && attr_value->has_list()) {
|
if (attr_value != nullptr && attr_value->has_list()) {
|
||||||
for (const string& class_spec : attr_value->list().s()) {
|
for (const string& class_spec : attr_value->list().s()) {
|
||||||
StringPiece spec(class_spec);
|
StringPiece spec(class_spec);
|
||||||
if (str_util::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
|
if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
|
||||||
found_spec = true;
|
found_spec = true;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ColocateNodeToGroup(&colocation_group_root, node, spec));
|
ColocateNodeToGroup(&colocation_group_root, node, spec));
|
||||||
@ -1098,7 +1098,7 @@ Status ColocationGraph::GetDevicesForNode(
|
|||||||
|
|
||||||
string gpu_msg = "";
|
string gpu_msg = "";
|
||||||
if (!IsGoogleCudaEnabled() &&
|
if (!IsGoogleCudaEnabled() &&
|
||||||
str_util::Lowercase(specified_device_name.type) == "gpu") {
|
absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
|
||||||
gpu_msg =
|
gpu_msg =
|
||||||
" The requested device appears to be a GPU, but CUDA is not "
|
" The requested device appears to be a GPU, but CUDA is not "
|
||||||
"enabled.";
|
"enabled.";
|
||||||
@ -1108,7 +1108,7 @@ Status ColocationGraph::GetDevicesForNode(
|
|||||||
errors::FormatNodeNameForError(node->name()),
|
errors::FormatNodeNameForError(node->name()),
|
||||||
"was explicitly assigned to ", node->requested_device(),
|
"was explicitly assigned to ", node->requested_device(),
|
||||||
" but available devices are [ ",
|
" but available devices are [ ",
|
||||||
str_util::Join(device_names, ", "), " ]. Make sure ",
|
absl::StrJoin(device_names, ", "), " ]. Make sure ",
|
||||||
"the device specification refers to a valid device.", gpu_msg);
|
"the device specification refers to a valid device.", gpu_msg);
|
||||||
} else if (specified_device_name.has_type) {
|
} else if (specified_device_name.has_type) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
@ -1315,7 +1315,7 @@ Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
|
|||||||
"with these attrs: [", node.attrs().DebugString(),
|
"with these attrs: [", node.attrs().DebugString(),
|
||||||
"]\n"
|
"]\n"
|
||||||
"Registered devices: [",
|
"Registered devices: [",
|
||||||
str_util::Join(registered_device_types, ", "), "]\n",
|
absl::StrJoin(registered_device_types, ", "), "]\n",
|
||||||
"Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
|
"Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
|||||||
device_names.push_back(itr.first);
|
device_names.push_back(itr.first);
|
||||||
}
|
}
|
||||||
VLOG(1) << "Unknown device: " << name
|
VLOG(1) << "Unknown device: " << name
|
||||||
<< " all devices: " << str_util::Join(device_names, ", ");
|
<< " all devices: " << absl::StrJoin(device_names, ", ");
|
||||||
return errors::InvalidArgument(name, " unknown device.");
|
return errors::InvalidArgument(name, " unknown device.");
|
||||||
}
|
}
|
||||||
*device = iter->second;
|
*device = iter->second;
|
||||||
|
@ -1346,8 +1346,8 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
|
|
||||||
// Fast lookup path, no sorting.
|
// Fast lookup path, no sorting.
|
||||||
const string key = strings::StrCat(
|
const string key = strings::StrCat(
|
||||||
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
|
absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
|
||||||
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
|
absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
|
||||||
"/", debug_tensor_watches_summary);
|
"/", debug_tensor_watches_summary);
|
||||||
// Set the handle, if it's needed to log memory or for partial run.
|
// Set the handle, if it's needed to log memory or for partial run.
|
||||||
if (handle_name_counter_value >= 0) {
|
if (handle_name_counter_value >= 0) {
|
||||||
@ -1379,8 +1379,8 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
std::sort(tn_sorted.begin(), tn_sorted.end());
|
std::sort(tn_sorted.begin(), tn_sorted.end());
|
||||||
|
|
||||||
const string sorted_key = strings::StrCat(
|
const string sorted_key = strings::StrCat(
|
||||||
str_util::Join(inputs_sorted, ","), "->",
|
absl::StrJoin(inputs_sorted, ","), "->",
|
||||||
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
|
absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
|
||||||
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
|
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
|
||||||
// Set the handle, if its needed to log memory or for partial run.
|
// Set the handle, if its needed to log memory or for partial run.
|
||||||
if (handle_name_counter_value >= 0) {
|
if (handle_name_counter_value >= 0) {
|
||||||
@ -1549,7 +1549,7 @@ Status DirectSession::CreateGraphs(
|
|||||||
"Creating a partition for ", local_partition_name,
|
"Creating a partition for ", local_partition_name,
|
||||||
" which doesn't exist in the list of available devices. Available "
|
" which doesn't exist in the list of available devices. Available "
|
||||||
"devices: ",
|
"devices: ",
|
||||||
str_util::Join(device_names, ","));
|
absl::StrJoin(device_names, ","));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,7 +169,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
|
|||||||
|
|
||||||
Status s = session->RunCallable(handle, {}, nullptr, nullptr);
|
Status s = session->RunCallable(handle, {}, nullptr, nullptr);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(),
|
EXPECT_TRUE(absl::StrContains(s.error_message(),
|
||||||
"`fetch_tensors` must be provided"));
|
"`fetch_tensors` must be provided"));
|
||||||
|
|
||||||
TF_ASSERT_OK(session->ReleaseCallable(handle));
|
TF_ASSERT_OK(session->ReleaseCallable(handle));
|
||||||
@ -177,14 +177,14 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
|
|||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
s = session->RunCallable(handle, {}, &outputs, nullptr);
|
s = session->RunCallable(handle, {}, &outputs, nullptr);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.error_message(),
|
s.error_message(),
|
||||||
"Attempted to run callable after handle was released"));
|
"Attempted to run callable after handle was released"));
|
||||||
|
|
||||||
s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
|
s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(s.error_message(), "No such callable handle"));
|
absl::StrContains(s.error_message(), "No such callable handle"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,8 +260,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "would create a cycle"));
|
||||||
str_util::StrContains(s.error_message(), "would create a cycle"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -275,7 +274,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "unknown node"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown node"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -290,7 +289,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "unknown edge"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown edge"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -305,7 +304,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsNotFound(s));
|
EXPECT_TRUE(errors::IsNotFound(s));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(s.error_message(), "unable to find feed output"));
|
absl::StrContains(s.error_message(), "unable to find feed output"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -322,7 +321,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -337,7 +336,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status s = session->MakeCallable(callable_options, &handle);
|
Status s = session->MakeCallable(callable_options, &handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -762,7 +761,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
|||||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
||||||
&outputs);
|
&outputs);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
|
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
|
||||||
@ -845,7 +844,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) {
|
|||||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
|
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
|
||||||
&handle);
|
&handle);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
|
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
|
||||||
@ -999,7 +998,7 @@ TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
|
|||||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
|
||||||
&outputs, nullptr);
|
&outputs, nullptr);
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
|
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
|
||||||
@ -1229,8 +1228,8 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
|
|||||||
s = session->PRun(handle, {{first_const->name(), value_11}},
|
s = session->PRun(handle, {{first_const->name(), value_11}},
|
||||||
{third_identity->name() + ":0"}, &outputs);
|
{third_identity->name() + ":0"}, &outputs);
|
||||||
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(),
|
EXPECT_TRUE(
|
||||||
"can't be computed from the feeds"));
|
absl::StrContains(s.error_message(), "can't be computed from the feeds"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
|
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
|
||||||
@ -1259,8 +1258,8 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
|
|||||||
// Fetch fourth_identity without feeds.
|
// Fetch fourth_identity without feeds.
|
||||||
s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
|
s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
|
||||||
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(),
|
EXPECT_TRUE(
|
||||||
"can't be computed from the feeds"));
|
absl::StrContains(s.error_message(), "can't be computed from the feeds"));
|
||||||
|
|
||||||
// Feed switch_node:1 and fetch fourth_identity.
|
// Feed switch_node:1 and fetch fourth_identity.
|
||||||
s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
|
s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
|
||||||
@ -2093,7 +2092,7 @@ void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status status = session->MakeCallable(opts, &handle);
|
Status status = session->MakeCallable(opts, &handle);
|
||||||
EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
|
EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status.error_message(),
|
status.error_message(),
|
||||||
strings::StrCat(
|
strings::StrCat(
|
||||||
"Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
|
"Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
|
||||||
@ -2109,7 +2108,7 @@ void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
|
|||||||
Session::CallableHandle handle;
|
Session::CallableHandle handle;
|
||||||
Status status = session->MakeCallable(opts, &handle);
|
Status status = session->MakeCallable(opts, &handle);
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status.error_message(),
|
status.error_message(),
|
||||||
strings::StrCat(
|
strings::StrCat(
|
||||||
"Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
|
"Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
|
||||||
|
@ -54,7 +54,7 @@ const string RegisteredFactoriesErrorMessageLocked()
|
|||||||
factory_types.push_back(executor_factory.first);
|
factory_types.push_back(executor_factory.first);
|
||||||
}
|
}
|
||||||
return strings::StrCat("Registered factories are {",
|
return strings::StrCat("Registered factories are {",
|
||||||
str_util::Join(factory_types, ", "), "}.");
|
absl::StrJoin(factory_types, ", "), "}.");
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ Status GetOpSig(const string& op, const OpDef** sig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void HasError(const Status& s, StringPiece substr) {
|
void HasError(const Status& s, StringPiece substr) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
|
EXPECT_TRUE(absl::StrContains(s.ToString(), substr))
|
||||||
<< s << ", expected substring " << substr;
|
<< s << ", expected substring " << substr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -254,8 +254,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
||||||
EXPECT_TRUE(errors::IsNotFound(status2))
|
EXPECT_TRUE(errors::IsNotFound(status2))
|
||||||
<< "Actual status: " << status2.ToString();
|
<< "Actual status: " << status2.ToString();
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "Handle"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "not found"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
@ -324,8 +324,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
|
|
||||||
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
||||||
EXPECT_TRUE(errors::IsNotFound(status2));
|
EXPECT_TRUE(errors::IsNotFound(status2));
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "Handle"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "not found"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -150,8 +150,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
|
|
||||||
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
Status status2 = Run(flr, handle, opts, args, std::move(rets));
|
||||||
EXPECT_TRUE(errors::IsNotFound(status2));
|
EXPECT_TRUE(errors::IsNotFound(status2));
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "Handle"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
|
||||||
EXPECT_TRUE(str_util::StrContains(status2.error_message(), "not found"));
|
EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -431,7 +431,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
|
|||||||
string gpu_thread_mode;
|
string gpu_thread_mode;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ReadStringFromEnvVar("TF_GPU_THREAD_MODE", "global", &gpu_thread_mode));
|
ReadStringFromEnvVar("TF_GPU_THREAD_MODE", "global", &gpu_thread_mode));
|
||||||
gpu_thread_mode = str_util::Lowercase(gpu_thread_mode);
|
gpu_thread_mode = absl::AsciiStrToLower(gpu_thread_mode);
|
||||||
if (gpu_thread_mode != "global") {
|
if (gpu_thread_mode != "global") {
|
||||||
int64 gpu_thread_count = -1;
|
int64 gpu_thread_count = -1;
|
||||||
// Default to two threads. One for device compute and another for memory
|
// Default to two threads. One for device compute and another for memory
|
||||||
@ -1760,8 +1760,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
|
|||||||
std::vector<int> raw_ids(ids->size());
|
std::vector<int> raw_ids(ids->size());
|
||||||
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
|
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
|
||||||
[](PlatformGpuId id) -> int { return id.value(); });
|
[](PlatformGpuId id) -> int { return id.value(); });
|
||||||
LOG(INFO) << "Adding visible gpu devices: "
|
LOG(INFO) << "Adding visible gpu devices: " << absl::StrJoin(raw_ids, ", ");
|
||||||
<< str_util::Join(raw_ids, ", ");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -55,7 +55,7 @@ Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
|
void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
|
EXPECT_TRUE(absl::StrContains(s.ToString(), substr))
|
||||||
<< s << ", expected substring " << substr;
|
<< s << ", expected substring " << substr;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -86,7 +86,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
|
|||||||
// Precondition: device_names must be sorted so that all devices in
|
// Precondition: device_names must be sorted so that all devices in
|
||||||
// the same task are adjacent.
|
// the same task are adjacent.
|
||||||
VLOG(2) << "Sorted task names: "
|
VLOG(2) << "Sorted task names: "
|
||||||
<< str_util::Join(col_params->instance.task_names, ", ");
|
<< absl::StrJoin(col_params->instance.task_names, ", ");
|
||||||
std::vector<int> dev_per_task;
|
std::vector<int> dev_per_task;
|
||||||
const string* prior_task_name = &col_params->instance.task_names[0];
|
const string* prior_task_name = &col_params->instance.task_names[0];
|
||||||
int dev_count = 1;
|
int dev_count = 1;
|
||||||
|
@ -51,7 +51,7 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||||||
}
|
}
|
||||||
|
|
||||||
testing::StopTiming();
|
testing::StopTiming();
|
||||||
string t = str_util::Uppercase(device);
|
string t = absl::AsciiStrToUpper(device);
|
||||||
// Allow NewDevice to allocate a new threadpool with different number of
|
// Allow NewDevice to allocate a new threadpool with different number of
|
||||||
// threads for each new benchmark.
|
// threads for each new benchmark.
|
||||||
LocalDevice::set_use_global_threadpool(false);
|
LocalDevice::set_use_global_threadpool(false);
|
||||||
|
@ -42,7 +42,7 @@ constexpr const char* const kLowerUsingSwitchMergeAttr =
|
|||||||
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
||||||
|
|
||||||
static void AssertHasSubstr(StringPiece s, StringPiece expected) {
|
static void AssertHasSubstr(StringPiece s, StringPiece expected) {
|
||||||
ASSERT_TRUE(str_util::StrContains(s, expected))
|
ASSERT_TRUE(absl::StrContains(s, expected))
|
||||||
<< "'" << s << "' does not contain '" << expected << "'";
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,7 +217,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
});
|
});
|
||||||
done2.WaitForNotification();
|
done2.WaitForNotification();
|
||||||
EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status;
|
EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status;
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "not found."));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -479,7 +479,7 @@ void TestTwoDeviceMult(
|
|||||||
if (!error.empty()) {
|
if (!error.empty()) {
|
||||||
EXPECT_TRUE(errors::IsInvalidArgument(status))
|
EXPECT_TRUE(errors::IsInvalidArgument(status))
|
||||||
<< "Actual status: " << status;
|
<< "Actual status: " << status;
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), error))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), error))
|
||||||
<< "Actual error message: " << status.error_message();
|
<< "Actual error message: " << status.error_message();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -505,11 +505,11 @@ void TestTwoDeviceInputOutput(
|
|||||||
FunctionLibraryRuntime::Options opts;
|
FunctionLibraryRuntime::Options opts;
|
||||||
opts.rendezvous = fixture->rendezvous_;
|
opts.rendezvous = fixture->rendezvous_;
|
||||||
Tensor x1 = test::AsTensor<float>({1, 2});
|
Tensor x1 = test::AsTensor<float>({1, 2});
|
||||||
if (str_util::StrContains(inst_opts.input_devices[0], "GPU")) {
|
if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
|
||||||
x1 = fixture->CPUToGPU(x1);
|
x1 = fixture->CPUToGPU(x1);
|
||||||
}
|
}
|
||||||
Tensor x2 = test::AsTensor<float>({10, 20});
|
Tensor x2 = test::AsTensor<float>({10, 20});
|
||||||
if (str_util::StrContains(inst_opts.input_devices[1], "GPU")) {
|
if (absl::StrContains(inst_opts.input_devices[1], "GPU")) {
|
||||||
x2 = fixture->CPUToGPU(x2);
|
x2 = fixture->CPUToGPU(x2);
|
||||||
}
|
}
|
||||||
Tensor y1;
|
Tensor y1;
|
||||||
@ -517,7 +517,7 @@ void TestTwoDeviceInputOutput(
|
|||||||
TF_CHECK_OK(fixture->Run("TwoDeviceInputOutput", opts, {{"T", DT_FLOAT}},
|
TF_CHECK_OK(fixture->Run("TwoDeviceInputOutput", opts, {{"T", DT_FLOAT}},
|
||||||
inst_opts, {x1, x2}, {&y1, &y2}));
|
inst_opts, {x1, x2}, {&y1, &y2}));
|
||||||
|
|
||||||
if (str_util::StrContains(inst_opts.output_devices[0], "GPU")) {
|
if (absl::StrContains(inst_opts.output_devices[0], "GPU")) {
|
||||||
EXPECT_TRUE(IsCUDATensor(y1));
|
EXPECT_TRUE(IsCUDATensor(y1));
|
||||||
y1 = fixture->GPUToCPU(y1);
|
y1 = fixture->GPUToCPU(y1);
|
||||||
} else {
|
} else {
|
||||||
@ -525,7 +525,7 @@ void TestTwoDeviceInputOutput(
|
|||||||
}
|
}
|
||||||
test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({2, 4}));
|
test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({2, 4}));
|
||||||
|
|
||||||
if (str_util::StrContains(inst_opts.output_devices[1], "GPU")) {
|
if (absl::StrContains(inst_opts.output_devices[1], "GPU")) {
|
||||||
EXPECT_TRUE(IsCUDATensor(y2));
|
EXPECT_TRUE(IsCUDATensor(y2));
|
||||||
y2 = fixture->GPUToCPU(y2);
|
y2 = fixture->GPUToCPU(y2);
|
||||||
} else {
|
} else {
|
||||||
@ -607,7 +607,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListInput) {
|
|||||||
"FuncWithListInput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
|
"FuncWithListInput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
|
||||||
MakeOptions("CPU:0", {"CPU:0"}, {}), &handle);
|
MakeOptions("CPU:0", {"CPU:0"}, {}), &handle);
|
||||||
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
||||||
ASSERT_TRUE(str_util::StrContains(
|
ASSERT_TRUE(absl::StrContains(
|
||||||
status.error_message(),
|
status.error_message(),
|
||||||
"FuncWithListInput has an input named \"x1\" that is a list of tensors"))
|
"FuncWithListInput has an input named \"x1\" that is a list of tensors"))
|
||||||
<< "Actual error message: " << status.error_message();
|
<< "Actual error message: " << status.error_message();
|
||||||
@ -621,7 +621,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListOutput) {
|
|||||||
"FuncWithListOutput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
|
"FuncWithListOutput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
|
||||||
MakeOptions("CPU:0", {}, {"CPU:0"}), &handle);
|
MakeOptions("CPU:0", {}, {"CPU:0"}), &handle);
|
||||||
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
||||||
ASSERT_TRUE(str_util::StrContains(
|
ASSERT_TRUE(absl::StrContains(
|
||||||
status.error_message(),
|
status.error_message(),
|
||||||
"FuncWithListOutput has an output named \"y\" that is a list of tensors"))
|
"FuncWithListOutput has an output named \"y\" that is a list of tensors"))
|
||||||
<< "Actual error message: " << status.error_message();
|
<< "Actual error message: " << status.error_message();
|
||||||
@ -747,7 +747,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_PlacerError) {
|
|||||||
"ResourceOutput", test::function::Attrs({{"T", DT_FLOAT}}), inst_opts,
|
"ResourceOutput", test::function::Attrs({{"T", DT_FLOAT}}), inst_opts,
|
||||||
&handle);
|
&handle);
|
||||||
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
|
||||||
ASSERT_TRUE(str_util::StrContains(status.error_message(), "Cannot place"));
|
ASSERT_TRUE(absl::StrContains(status.error_message(), "Cannot place"));
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_OP("BrokenOp")
|
REGISTER_OP("BrokenOp")
|
||||||
|
@ -174,7 +174,7 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
|
|||||||
// Precondition: device_names must be sorted so that all devices in
|
// Precondition: device_names must be sorted so that all devices in
|
||||||
// the same task are adjacent.
|
// the same task are adjacent.
|
||||||
VLOG(2) << "Sorted task names: "
|
VLOG(2) << "Sorted task names: "
|
||||||
<< str_util::Join(col_params->instance.task_names, ", ");
|
<< absl::StrJoin(col_params->instance.task_names, ", ");
|
||||||
std::vector<int> dev_per_task;
|
std::vector<int> dev_per_task;
|
||||||
const string* prior_task_name = &col_params->instance.task_names[0];
|
const string* prior_task_name = &col_params->instance.task_names[0];
|
||||||
int dev_count = 1;
|
int dev_count = 1;
|
||||||
|
@ -57,7 +57,7 @@ const string RegisteredFactoriesErrorMessageLocked() {
|
|||||||
factory_types.push_back(session_factory.first);
|
factory_types.push_back(session_factory.first);
|
||||||
}
|
}
|
||||||
return strings::StrCat("Registered factories are {",
|
return strings::StrCat("Registered factories are {",
|
||||||
str_util::Join(factory_types, ", "), "}.");
|
absl::StrJoin(factory_types, ", "), "}.");
|
||||||
}
|
}
|
||||||
string SessionOptionsToString(const SessionOptions& options) {
|
string SessionOptionsToString(const SessionOptions& options) {
|
||||||
return strings::StrCat("target: \"", options.target,
|
return strings::StrCat("target: \"", options.target,
|
||||||
@ -102,7 +102,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options,
|
|||||||
"Multiple session factories registered for the given session "
|
"Multiple session factories registered for the given session "
|
||||||
"options: {",
|
"options: {",
|
||||||
SessionOptionsToString(options), "} Candidate factories are {",
|
SessionOptionsToString(options), "} Candidate factories are {",
|
||||||
str_util::Join(factory_types, ", "), "}. ",
|
absl::StrJoin(factory_types, ", "), "}. ",
|
||||||
RegisteredFactoriesErrorMessageLocked());
|
RegisteredFactoriesErrorMessageLocked());
|
||||||
} else {
|
} else {
|
||||||
return errors::NotFound(
|
return errors::NotFound(
|
||||||
|
@ -32,7 +32,7 @@ TEST(SessionTest, InvalidTargetReturnsNull) {
|
|||||||
Session* session;
|
Session* session;
|
||||||
Status s = tensorflow::NewSession(options, &session);
|
Status s = tensorflow::NewSession(options, &session);
|
||||||
EXPECT_EQ(s.code(), error::NOT_FOUND);
|
EXPECT_EQ(s.code(), error::NOT_FOUND);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.error_message(),
|
s.error_message(),
|
||||||
"No session factory registered for the given session options"));
|
"No session factory registered for the given session options"));
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ class FakeSessionFactory : public SessionFactory {
|
|||||||
FakeSessionFactory() {}
|
FakeSessionFactory() {}
|
||||||
|
|
||||||
bool AcceptsOptions(const SessionOptions& options) override {
|
bool AcceptsOptions(const SessionOptions& options) override {
|
||||||
return str_util::StartsWith(options.target, "fake");
|
return absl::StartsWith(options.target, "fake");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewSession(const SessionOptions& options,
|
Status NewSession(const SessionOptions& options,
|
||||||
@ -70,9 +70,9 @@ TEST(SessionTest, MultipleFactoriesForTarget) {
|
|||||||
Status s = tensorflow::NewSession(options, &session);
|
Status s = tensorflow::NewSession(options, &session);
|
||||||
EXPECT_EQ(s.code(), error::INTERNAL);
|
EXPECT_EQ(s.code(), error::INTERNAL);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(s.error_message(), "Multiple session factories"));
|
absl::StrContains(s.error_message(), "Multiple session factories"));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_1"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "FAKE_SESSION_1"));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_2"));
|
EXPECT_TRUE(absl::StrContains(s.error_message(), "FAKE_SESSION_2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -162,8 +162,8 @@ TEST_F(ShapeRefinerTest, BadShapes) {
|
|||||||
// an error.
|
// an error.
|
||||||
Status s = m.AddNode(mm.node());
|
Status s = m.AddNode(mm.node());
|
||||||
ASSERT_FALSE(s.ok());
|
ASSERT_FALSE(s.ok());
|
||||||
ASSERT_TRUE(str_util::StrContains(
|
ASSERT_TRUE(absl::StrContains(s.error_message(),
|
||||||
s.error_message(), "Dimensions must be equal, but are 1 and 2"));
|
"Dimensions must be equal, but are 1 and 2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, SetShape) {
|
TEST_F(ShapeRefinerTest, SetShape) {
|
||||||
@ -1051,8 +1051,8 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
|
|||||||
TF_ASSERT_OK(m.AddNode(input.node()));
|
TF_ASSERT_OK(m.AddNode(input.node()));
|
||||||
}
|
}
|
||||||
TF_ASSERT_OK(m.AddNode(pack.node()));
|
TF_ASSERT_OK(m.AddNode(pack.node()));
|
||||||
EXPECT_TRUE(str_util::StrContains(m.AddNode(result).error_message(),
|
EXPECT_TRUE(
|
||||||
"but is rank 2"));
|
absl::StrContains(m.AddNode(result).error_message(), "but is rank 2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
||||||
|
@ -94,7 +94,7 @@ void NodeExecStatsWrapper::Done(const string& device) {
|
|||||||
} else {
|
} else {
|
||||||
text =
|
text =
|
||||||
strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
|
strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
|
||||||
str_util::Join(node_->requested_inputs(), ", "), ")");
|
absl::StrJoin(node_->requested_inputs(), ", "), ")");
|
||||||
}
|
}
|
||||||
stats_->set_timeline_label(text);
|
stats_->set_timeline_label(text);
|
||||||
step_stats_collector_->Save(device, this);
|
step_stats_collector_->Save(device, this);
|
||||||
|
@ -132,6 +132,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:proto_text",
|
"//tensorflow/core:proto_text",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -29,7 +29,7 @@ namespace {
|
|||||||
|
|
||||||
// TODO(cais): Switch to safe_strtob when available.
|
// TODO(cais): Switch to safe_strtob when available.
|
||||||
Status ParseBoolString(const string& bool_str, bool* bool_val) {
|
Status ParseBoolString(const string& bool_str, bool* bool_val) {
|
||||||
const string lower_bool_str = str_util::Lowercase(bool_str);
|
const string lower_bool_str = absl::AsciiStrToLower(bool_str);
|
||||||
if (lower_bool_str == "false" || lower_bool_str == "f" ||
|
if (lower_bool_str == "false" || lower_bool_str == "f" ||
|
||||||
lower_bool_str == "0") {
|
lower_bool_str == "0") {
|
||||||
*bool_val = false;
|
*bool_val = false;
|
||||||
@ -430,7 +430,7 @@ Status DebugNodeInserter::SetDebugNodeAttributes(
|
|||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
unfulfilled_keys.size(),
|
unfulfilled_keys.size(),
|
||||||
" attribute key(s) were not valid for debug node ", debug_node->name(),
|
" attribute key(s) were not valid for debug node ", debug_node->name(),
|
||||||
": ", str_util::Join(unfulfilled_keys, ", "));
|
": ", absl::StrJoin(unfulfilled_keys, ", "));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,8 @@ limitations under the License.
|
|||||||
#pragma comment(lib, "Ws2_32.lib")
|
#pragma comment(lib, "Ws2_32.lib")
|
||||||
#endif // #ifndef PLATFORM_WINDOWS
|
#endif // #ifndef PLATFORM_WINDOWS
|
||||||
|
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
#include "absl/strings/match.h"
|
||||||
#include "tensorflow/core/debug/debug_callback_registry.h"
|
#include "tensorflow/core/debug/debug_callback_registry.h"
|
||||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
@ -371,7 +373,7 @@ Status DebugIO::PublishDebugMetadata(
|
|||||||
|
|
||||||
Status status;
|
Status status;
|
||||||
for (const string& url : debug_urls) {
|
for (const string& url : debug_urls) {
|
||||||
if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
|
if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
Event grpc_event;
|
Event grpc_event;
|
||||||
|
|
||||||
@ -392,7 +394,7 @@ Status DebugIO::PublishDebugMetadata(
|
|||||||
#else
|
#else
|
||||||
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
|
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
|
||||||
#endif
|
#endif
|
||||||
} else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
|
} else if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
|
||||||
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
|
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
|
||||||
const string core_metadata_path = AppendTimestampToFilePath(
|
const string core_metadata_path = AppendTimestampToFilePath(
|
||||||
io::JoinPath(
|
io::JoinPath(
|
||||||
@ -418,7 +420,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
|
|||||||
int32 num_failed_urls = 0;
|
int32 num_failed_urls = 0;
|
||||||
std::vector<Status> fail_statuses;
|
std::vector<Status> fail_statuses;
|
||||||
for (const string& url : debug_urls) {
|
for (const string& url : debug_urls) {
|
||||||
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
|
if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
|
||||||
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
|
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
|
||||||
|
|
||||||
const int64 tensorBytes =
|
const int64 tensorBytes =
|
||||||
@ -440,7 +442,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
|
|||||||
num_failed_urls++;
|
num_failed_urls++;
|
||||||
fail_statuses.push_back(s);
|
fail_statuses.push_back(s);
|
||||||
}
|
}
|
||||||
} else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
|
} else if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
|
Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
|
||||||
debug_node_key, tensor, wall_time_us, url, gated_grpc);
|
debug_node_key, tensor, wall_time_us, url, gated_grpc);
|
||||||
@ -452,7 +454,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
|
|||||||
#else
|
#else
|
||||||
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
|
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
|
||||||
#endif
|
#endif
|
||||||
} else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
|
} else if (absl::StartsWith(absl::AsciiStrToLower(url), kMemoryURLScheme)) {
|
||||||
const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
|
const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
|
||||||
auto* callback_registry = DebugCallbackRegistry::singleton();
|
auto* callback_registry = DebugCallbackRegistry::singleton();
|
||||||
auto* callback = callback_registry->GetCallback(dump_root_dir);
|
auto* callback = callback_registry->GetCallback(dump_root_dir);
|
||||||
@ -502,7 +504,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
|
|||||||
|
|
||||||
Status status = Status::OK();
|
Status status = Status::OK();
|
||||||
for (const string& debug_url : debug_urls) {
|
for (const string& debug_url : debug_urls) {
|
||||||
if (debug_url.find(kFileURLScheme) == 0) {
|
if (absl::StartsWith(debug_url, kFileURLScheme)) {
|
||||||
const string dump_root_dir =
|
const string dump_root_dir =
|
||||||
io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
|
io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
|
||||||
DebugNodeKey::DeviceNameToDevicePath(device_name));
|
DebugNodeKey::DeviceNameToDevicePath(device_name));
|
||||||
@ -513,7 +515,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
|
|||||||
|
|
||||||
status.Update(
|
status.Update(
|
||||||
DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
|
DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
|
||||||
} else if (debug_url.find(kGrpcURLScheme) == 0) {
|
} else if (absl::StartsWith(debug_url, kGrpcURLScheme)) {
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
|
status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
|
||||||
debug_url));
|
debug_url));
|
||||||
@ -578,7 +580,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DebugIO::CloseDebugURL(const string& debug_url) {
|
Status DebugIO::CloseDebugURL(const string& debug_url) {
|
||||||
if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
|
if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) {
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
return DebugGrpcIO::CloseGrpcStream(debug_url);
|
return DebugGrpcIO::CloseGrpcStream(debug_url);
|
||||||
#else
|
#else
|
||||||
@ -846,7 +848,7 @@ Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
|
|||||||
Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
|
Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
|
||||||
const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
|
const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
|
||||||
const string addr_with_path =
|
const string addr_with_path =
|
||||||
grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
|
absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme)
|
||||||
? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
|
? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
|
||||||
: grpc_stream_url;
|
: grpc_stream_url;
|
||||||
const string server_stream_addr =
|
const string server_stream_addr =
|
||||||
|
@ -571,6 +571,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ BaseRemoteRendezvous::~BaseRemoteRendezvous() {
|
|||||||
// and device name and does no lookups in the worker->device_mgr.
|
// and device name and does no lookups in the worker->device_mgr.
|
||||||
static bool IsLocalDevice(const StringPiece worker_name,
|
static bool IsLocalDevice(const StringPiece worker_name,
|
||||||
const StringPiece device_name) {
|
const StringPiece device_name) {
|
||||||
return str_util::StartsWith(device_name, worker_name);
|
return absl::StartsWith(device_name, worker_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
|
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
|
||||||
|
@ -134,7 +134,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
|
|||||||
worker_session_->worker_cache->ListWorkers(&workers);
|
worker_session_->worker_cache->ListWorkers(&workers);
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Could not find worker with target: ", options.target,
|
"Could not find worker with target: ", options.target,
|
||||||
" Available workers: ", str_util::Join(workers, ", "));
|
" Available workers: ", absl::StrJoin(workers, ", "));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make RPC and obtain a graph handle.
|
// Make RPC and obtain a graph handle.
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
|
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
|
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
|
||||||
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
@ -243,7 +244,7 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
|||||||
CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
|
CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
|
||||||
gr->group.runtime_details.communicator_key = resp.communicator_key();
|
gr->group.runtime_details.communicator_key = resp.communicator_key();
|
||||||
VLOG(2) << "Group communicator_key="
|
VLOG(2) << "Group communicator_key="
|
||||||
<< str_util::CEscape(gr->group.runtime_details.communicator_key);
|
<< absl::CEscape(gr->group.runtime_details.communicator_key);
|
||||||
{
|
{
|
||||||
// Group membership should never change. Once a record is in group_table_
|
// Group membership should never change. Once a record is in group_table_
|
||||||
// it never gets removed.
|
// it never gets removed.
|
||||||
@ -251,7 +252,7 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
|||||||
auto it = group_table_.find(gr->group.group_key);
|
auto it = group_table_.find(gr->group.group_key);
|
||||||
if (it == group_table_.end()) {
|
if (it == group_table_.end()) {
|
||||||
VLOG(2) << "UpdateGroupCache: communicator_key="
|
VLOG(2) << "UpdateGroupCache: communicator_key="
|
||||||
<< str_util::CEscape(gr->group.runtime_details.communicator_key);
|
<< absl::CEscape(gr->group.runtime_details.communicator_key);
|
||||||
group_table_[gr->group.group_key] = std::move(gr);
|
group_table_[gr->group.group_key] = std::move(gr);
|
||||||
} else {
|
} else {
|
||||||
auto& previous_gr = group_table_[gr->group.group_key];
|
auto& previous_gr = group_table_[gr->group.group_key];
|
||||||
@ -260,10 +261,9 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
|||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"UpdateGroupCache: CompleteGroupResponse for group ",
|
"UpdateGroupCache: CompleteGroupResponse for group ",
|
||||||
gr->group.group_key, " gives communicator_key=",
|
gr->group.group_key, " gives communicator_key=",
|
||||||
str_util::CEscape(gr->group.runtime_details.communicator_key),
|
absl::CEscape(gr->group.runtime_details.communicator_key),
|
||||||
" but cache already holds communicator_key=",
|
" but cache already holds communicator_key=",
|
||||||
str_util::CEscape(
|
absl::CEscape(previous_gr->group.runtime_details.communicator_key));
|
||||||
previous_gr->group.runtime_details.communicator_key));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,7 +92,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
|||||||
n->name(),
|
n->name(),
|
||||||
NodeDetails(n->type_string(),
|
NodeDetails(n->type_string(),
|
||||||
strings::StrCat(
|
strings::StrCat(
|
||||||
"(", str_util::Join(n->requested_inputs(), ", "))));
|
"(", absl::StrJoin(n->requested_inputs(), ", "))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -306,7 +306,7 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
|
|||||||
task_strings.emplace_back(
|
task_strings.emplace_back(
|
||||||
strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
|
strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
|
||||||
}
|
}
|
||||||
return strings::StrCat(job_id_, " -> {", str_util::Join(task_strings, ", "),
|
return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -459,7 +459,7 @@ Status GrpcSession::ReleaseCallable(CallableHandle handle) {
|
|||||||
class GrpcSessionFactory : public SessionFactory {
|
class GrpcSessionFactory : public SessionFactory {
|
||||||
public:
|
public:
|
||||||
bool AcceptsOptions(const SessionOptions& options) override {
|
bool AcceptsOptions(const SessionOptions& options) override {
|
||||||
return str_util::StartsWith(options.target, kSchemePrefix);
|
return absl::StartsWith(options.target, kSchemePrefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewSession(const SessionOptions& options,
|
Status NewSession(const SessionOptions& options,
|
||||||
|
@ -67,7 +67,7 @@ Status FillServerDef(const string& cluster_spec, const string& job_name,
|
|||||||
my_num_tasks = host_ports.size();
|
my_num_tasks = host_ports.size();
|
||||||
}
|
}
|
||||||
LOG(INFO) << "Peer " << job_name << " " << num_tasks << " {"
|
LOG(INFO) << "Peer " << job_name << " " << num_tasks << " {"
|
||||||
<< str_util::Join(host_ports, ", ") << "}";
|
<< absl::StrJoin(host_ports, ", ") << "}";
|
||||||
}
|
}
|
||||||
if (my_num_tasks == 0) {
|
if (my_num_tasks == 0) {
|
||||||
return errors::InvalidArgument("Job name \"", options->job_name(),
|
return errors::InvalidArgument("Job name \"", options->job_name(),
|
||||||
|
@ -46,7 +46,7 @@ Status TestCluster::MakeTestCluster(const string& binary_path,
|
|||||||
}
|
}
|
||||||
|
|
||||||
const string tf_jobs = strings::StrCat("--tf_jobs=localhost|",
|
const string tf_jobs = strings::StrCat("--tf_jobs=localhost|",
|
||||||
str_util::Join(ret->targets_, ";"));
|
absl::StrJoin(ret->targets_, ";"));
|
||||||
|
|
||||||
int num_cpus = 1;
|
int num_cpus = 1;
|
||||||
int num_gpus = 0;
|
int num_gpus = 0;
|
||||||
|
@ -61,7 +61,7 @@ Status FillServerDef(const string& job_spec, const string& job_name,
|
|||||||
my_tasks_per_replica = tasks_per_replica;
|
my_tasks_per_replica = tasks_per_replica;
|
||||||
}
|
}
|
||||||
LOG(INFO) << "Peer " << job_def->name() << " " << tasks_per_replica << " {"
|
LOG(INFO) << "Peer " << job_def->name() << " " << tasks_per_replica << " {"
|
||||||
<< str_util::Join(host_ports, ", ") << "}";
|
<< absl::StrJoin(host_ports, ", ") << "}";
|
||||||
}
|
}
|
||||||
if (my_tasks_per_replica == 0) {
|
if (my_tasks_per_replica == 0) {
|
||||||
return errors::InvalidArgument("Invalid job specification");
|
return errors::InvalidArgument("Invalid job specification");
|
||||||
|
@ -64,7 +64,7 @@ Status ServerFactory::GetFactory(const ServerDef& server_def,
|
|||||||
return errors::NotFound(
|
return errors::NotFound(
|
||||||
"No server factory registered for the given ServerDef: ",
|
"No server factory registered for the given ServerDef: ",
|
||||||
server_def.DebugString(), "\nThe available server factories are: [ ",
|
server_def.DebugString(), "\nThe available server factories are: [ ",
|
||||||
str_util::Join(server_names, ", "), " ]");
|
absl::StrJoin(server_names, ", "), " ]");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a server based on the given `server_def`, and stores it in
|
// Creates a server based on the given `server_def`, and stores it in
|
||||||
|
@ -46,10 +46,10 @@ TEST(ServerLibTest, NewServerNoFactoriesAccept) {
|
|||||||
std::unique_ptr<ServerInterface> server;
|
std::unique_ptr<ServerInterface> server;
|
||||||
Status s = NewServer(server_def, &server);
|
Status s = NewServer(server_def, &server);
|
||||||
ASSERT_NE(s, Status::OK());
|
ASSERT_NE(s, Status::OK());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.error_message(),
|
s.error_message(),
|
||||||
"No server factory registered for the given ServerDef"));
|
"No server factory registered for the given ServerDef"));
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(),
|
EXPECT_TRUE(absl::StrContains(s.error_message(),
|
||||||
"The available server factories are: ["));
|
"The available server factories are: ["));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ TEST_F(SessionMgrTest, UnknownSessionHandle) {
|
|||||||
Status s = mgr_.WorkerSessionForSession(session_handle, &session);
|
Status s = mgr_.WorkerSessionForSession(session_handle, &session);
|
||||||
EXPECT_TRUE(errors::IsAborted(s));
|
EXPECT_TRUE(errors::IsAborted(s));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(s.error_message(), "Session handle is not found"));
|
absl::StrContains(s.error_message(), "Session handle is not found"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
|
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
@ -183,7 +184,7 @@ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
|
|||||||
}
|
}
|
||||||
|
|
||||||
string SummarizeString(const string& str) {
|
string SummarizeString(const string& str) {
|
||||||
string escaped = str_util::CEscape(str);
|
string escaped = absl::CEscape(str);
|
||||||
|
|
||||||
// If the string is long, replace the middle with ellipses.
|
// If the string is long, replace the middle with ellipses.
|
||||||
constexpr int kMaxStringSummarySize = 80;
|
constexpr int kMaxStringSummarySize = 80;
|
||||||
@ -214,7 +215,7 @@ string SummarizeFunc(const NameAttrList& func) {
|
|||||||
strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
|
strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
|
||||||
}
|
}
|
||||||
std::sort(entries.begin(), entries.end());
|
std::sort(entries.begin(), entries.end());
|
||||||
return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
|
return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -276,7 +277,7 @@ string SummarizeAttrValue(const AttrValue& attr_value) {
|
|||||||
pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
|
pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
|
||||||
pieces[5] = "...";
|
pieces[5] = "...";
|
||||||
}
|
}
|
||||||
return strings::StrCat("[", str_util::Join(pieces, ", "), "]");
|
return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
|
||||||
}
|
}
|
||||||
case AttrValue::kFunc: {
|
case AttrValue::kFunc: {
|
||||||
return SummarizeFunc(attr_value.func());
|
return SummarizeFunc(attr_value.func());
|
||||||
@ -335,7 +336,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
|
|||||||
// check if has_list is false and some other field in attr_value is
|
// check if has_list is false and some other field in attr_value is
|
||||||
// set to flag the error. This test can be made more strict once
|
// set to flag the error. This test can be made more strict once
|
||||||
// support for GraphDef versions <= 4 is dropped.
|
// support for GraphDef versions <= 4 is dropped.
|
||||||
if (str_util::StartsWith(type, "list(") && !attr_value.has_list()) {
|
if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
|
||||||
if (num_set) {
|
if (num_set) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"AttrValue missing value with expected type '", type, "'");
|
"AttrValue missing value with expected type '", type, "'");
|
||||||
@ -346,7 +347,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Okay to have an empty list, but not to be missing a non-list value.
|
// Okay to have an empty list, but not to be missing a non-list value.
|
||||||
if (num_set == 0 && !str_util::StartsWith(type, "list(")) {
|
if (num_set == 0 && !absl::StartsWith(type, "list(")) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"AttrValue missing value with expected type '", type, "'");
|
"AttrValue missing value with expected type '", type, "'");
|
||||||
}
|
}
|
||||||
@ -390,29 +391,29 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
|
|||||||
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
|
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
|
||||||
// Parse type.
|
// Parse type.
|
||||||
string field_name;
|
string field_name;
|
||||||
bool is_list = str_util::ConsumePrefix(&type, "list(");
|
bool is_list = absl::ConsumePrefix(&type, "list(");
|
||||||
if (str_util::ConsumePrefix(&type, "string")) {
|
if (absl::ConsumePrefix(&type, "string")) {
|
||||||
field_name = "s";
|
field_name = "s";
|
||||||
} else if (str_util::ConsumePrefix(&type, "int")) {
|
} else if (absl::ConsumePrefix(&type, "int")) {
|
||||||
field_name = "i";
|
field_name = "i";
|
||||||
} else if (str_util::ConsumePrefix(&type, "float")) {
|
} else if (absl::ConsumePrefix(&type, "float")) {
|
||||||
field_name = "f";
|
field_name = "f";
|
||||||
} else if (str_util::ConsumePrefix(&type, "bool")) {
|
} else if (absl::ConsumePrefix(&type, "bool")) {
|
||||||
field_name = "b";
|
field_name = "b";
|
||||||
} else if (str_util::ConsumePrefix(&type, "type")) {
|
} else if (absl::ConsumePrefix(&type, "type")) {
|
||||||
field_name = "type";
|
field_name = "type";
|
||||||
} else if (str_util::ConsumePrefix(&type, "shape")) {
|
} else if (absl::ConsumePrefix(&type, "shape")) {
|
||||||
field_name = "shape";
|
field_name = "shape";
|
||||||
} else if (str_util::ConsumePrefix(&type, "tensor")) {
|
} else if (absl::ConsumePrefix(&type, "tensor")) {
|
||||||
field_name = "tensor";
|
field_name = "tensor";
|
||||||
} else if (str_util::ConsumePrefix(&type, "func")) {
|
} else if (absl::ConsumePrefix(&type, "func")) {
|
||||||
field_name = "func";
|
field_name = "func";
|
||||||
} else if (str_util::ConsumePrefix(&type, "placeholder")) {
|
} else if (absl::ConsumePrefix(&type, "placeholder")) {
|
||||||
field_name = "placeholder";
|
field_name = "placeholder";
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (is_list && !str_util::ConsumePrefix(&type, ")")) {
|
if (is_list && !absl::ConsumePrefix(&type, ")")) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
@ -49,7 +50,7 @@ std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
|
|||||||
|
|
||||||
string CollGroupRuntimeDetails::ToString() const {
|
string CollGroupRuntimeDetails::ToString() const {
|
||||||
return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
|
return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
|
||||||
str_util::CEscape(communicator_key), "}");
|
absl::CEscape(communicator_key), "}");
|
||||||
}
|
}
|
||||||
|
|
||||||
string CollGroupParams::ToString() const {
|
string CollGroupParams::ToString() const {
|
||||||
|
@ -141,7 +141,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
|||||||
{}, {}, {});
|
{}, {}, {});
|
||||||
auto s = MatMulShape(&c);
|
auto s = MatMulShape(&c);
|
||||||
EXPECT_FALSE(s.ok());
|
EXPECT_FALSE(s.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 1"));
|
s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
|||||||
{S({2, 5}), S({3, 4})}, {}, {}, {});
|
{S({2, 5}), S({3, 4})}, {}, {}, {});
|
||||||
auto s = MatMulShape(&c);
|
auto s = MatMulShape(&c);
|
||||||
EXPECT_FALSE(s.ok());
|
EXPECT_FALSE(s.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.ToString(),
|
s.ToString(),
|
||||||
"Invalid argument: Dimensions must be equal, but are 5 and 3"));
|
"Invalid argument: Dimensions must be equal, but are 5 and 3"));
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
|||||||
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
|
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
|
||||||
auto s = MatMulShape(&c);
|
auto s = MatMulShape(&c);
|
||||||
EXPECT_FALSE(s.ok());
|
EXPECT_FALSE(s.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 3"));
|
s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 3"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
@ -34,7 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
#include "tensorflow/core/util/equal_graph_def.h"
|
#include "tensorflow/core/util/equal_graph_def.h"
|
||||||
|
|
||||||
@ -289,8 +289,7 @@ class FunctionInstantiationHelper {
|
|||||||
// must lie in the range [node_name, node_colon_bound).
|
// must lie in the range [node_name, node_colon_bound).
|
||||||
auto it = index_.lower_bound(node_name);
|
auto it = index_.lower_bound(node_name);
|
||||||
while (it != index_.end() && it->first <= node_colon_bound) {
|
while (it != index_.end() && it->first <= node_colon_bound) {
|
||||||
if (it->first == node_name ||
|
if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
|
||||||
tensorflow::str_util::StartsWith(it->first, node_colon)) {
|
|
||||||
nid = it->second.nid;
|
nid = it->second.nid;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -498,7 +497,7 @@ string Print(const AttrValue& attr_value) {
|
|||||||
}
|
}
|
||||||
std::sort(entries.begin(), entries.end());
|
std::sort(entries.begin(), entries.end());
|
||||||
return strings::StrCat(attr_value.func().name(), "[",
|
return strings::StrCat(attr_value.func().name(), "[",
|
||||||
str_util::Join(entries, ", "), "]");
|
absl::StrJoin(entries, ", "), "]");
|
||||||
}
|
}
|
||||||
return SummarizeAttrValue(attr_value);
|
return SummarizeAttrValue(attr_value);
|
||||||
}
|
}
|
||||||
@ -523,21 +522,21 @@ string Print(const NodeDef& n) {
|
|||||||
entries.push_back("device=<FAILED_TO_PARSE>");
|
entries.push_back("device=<FAILED_TO_PARSE>");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
|
strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
|
||||||
}
|
}
|
||||||
strings::StrAppend(&out, "(");
|
strings::StrAppend(&out, "(");
|
||||||
std::vector<StringPiece> dat;
|
std::vector<StringPiece> dat;
|
||||||
std::vector<string> dep;
|
std::vector<string> dep;
|
||||||
for (StringPiece s : n.input()) {
|
for (StringPiece s : n.input()) {
|
||||||
if (str_util::ConsumePrefix(&s, "^")) {
|
if (absl::ConsumePrefix(&s, "^")) {
|
||||||
dep.emplace_back(s);
|
dep.emplace_back(s);
|
||||||
} else {
|
} else {
|
||||||
dat.push_back(s);
|
dat.push_back(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
|
strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
|
||||||
if (!dep.empty()) {
|
if (!dep.empty()) {
|
||||||
strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
|
strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -901,27 +900,27 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
|
|||||||
}
|
}
|
||||||
if (!options.target.empty()) {
|
if (!options.target.empty()) {
|
||||||
entries.push_back(
|
entries.push_back(
|
||||||
strings::StrCat("_target", "=", str_util::CEscape(options.target)));
|
strings::StrCat("_target", "=", absl::CEscape(options.target)));
|
||||||
}
|
}
|
||||||
for (int i = 0; i < options.input_devices.size(); ++i) {
|
for (int i = 0; i < options.input_devices.size(); ++i) {
|
||||||
entries.push_back(strings::StrCat(
|
entries.push_back(strings::StrCat("_input_dev", i, "=",
|
||||||
"_input_dev", i, "=", str_util::CEscape(options.input_devices[i])));
|
absl::CEscape(options.input_devices[i])));
|
||||||
}
|
}
|
||||||
for (int i = 0; i < options.output_devices.size(); ++i) {
|
for (int i = 0; i < options.output_devices.size(); ++i) {
|
||||||
entries.push_back(strings::StrCat(
|
entries.push_back(strings::StrCat(
|
||||||
"_output_dev", i, "=", str_util::CEscape(options.output_devices[i])));
|
"_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
|
||||||
}
|
}
|
||||||
for (const auto& iter : options.input_tensor_shapes) {
|
for (const auto& iter : options.input_tensor_shapes) {
|
||||||
entries.push_back(
|
entries.push_back(
|
||||||
strings::StrCat("_input_tensor_shape", iter.first, "=",
|
strings::StrCat("_input_tensor_shape", iter.first, "=",
|
||||||
str_util::CEscape(iter.second.DebugString())));
|
absl::CEscape(iter.second.DebugString())));
|
||||||
}
|
}
|
||||||
for (const auto& iter : options.input_resource_dtypes_and_shapes) {
|
for (const auto& iter : options.input_resource_dtypes_and_shapes) {
|
||||||
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
|
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
|
||||||
DataTypeString(iter.second.first)));
|
DataTypeString(iter.second.first)));
|
||||||
entries.push_back(
|
entries.push_back(
|
||||||
strings::StrCat("_input_resource_shape", iter.first, "=",
|
strings::StrCat("_input_resource_shape", iter.first, "=",
|
||||||
str_util::CEscape(iter.second.second.DebugString())));
|
absl::CEscape(iter.second.second.DebugString())));
|
||||||
}
|
}
|
||||||
if (options.lib_def) {
|
if (options.lib_def) {
|
||||||
entries.push_back(strings::StrCat(
|
entries.push_back(strings::StrCat(
|
||||||
@ -938,11 +937,11 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
|
|||||||
string config_proto_serialized;
|
string config_proto_serialized;
|
||||||
options.config_proto.SerializeToString(&config_proto_serialized);
|
options.config_proto.SerializeToString(&config_proto_serialized);
|
||||||
if (!config_proto_serialized.empty()) {
|
if (!config_proto_serialized.empty()) {
|
||||||
entries.push_back(strings::StrCat(
|
entries.push_back(strings::StrCat("_config_proto", "=",
|
||||||
"_config_proto", "=", str_util::CEscape(config_proto_serialized)));
|
absl::CEscape(config_proto_serialized)));
|
||||||
}
|
}
|
||||||
std::sort(entries.begin(), entries.end());
|
std::sort(entries.begin(), entries.end());
|
||||||
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
|
return strings::StrCat(funcname, "[", absl::StrJoin(entries, ","), "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
|
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
|
||||||
|
@ -556,7 +556,7 @@ TEST(TFunc, IntsOnDeviceArgSet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void HasError(const Status& s, const string& substr) {
|
static void HasError(const Status& s, const string& substr) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
|
EXPECT_TRUE(absl::StrContains(s.ToString(), substr))
|
||||||
<< ">>" << s << "<<, expected substring >>" << substr << "<<";
|
<< ">>" << s << "<<, expected substring >>" << substr << "<<";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ static Status RemoveNewDefaultAttrsFromNodeDef(
|
|||||||
std::vector<string> to_remove;
|
std::vector<string> to_remove;
|
||||||
for (const auto& attr : node_def->attr()) {
|
for (const auto& attr : node_def->attr()) {
|
||||||
// If the attr is not in consumer_op_def and doesn't start with '_'...
|
// If the attr is not in consumer_op_def and doesn't start with '_'...
|
||||||
if (!str_util::StartsWith(attr.first, "_") &&
|
if (!absl::StartsWith(attr.first, "_") &&
|
||||||
FindAttr(attr.first, *consumer_op_def) == nullptr) {
|
FindAttr(attr.first, *consumer_op_def) == nullptr) {
|
||||||
const OpDef::AttrDef* producer_attr_def =
|
const OpDef::AttrDef* producer_attr_def =
|
||||||
FindAttr(attr.first, *producer_op_def);
|
FindAttr(attr.first, *producer_op_def);
|
||||||
|
@ -249,7 +249,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
|||||||
for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
|
for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
|
||||||
NodeDef* node_def = fdef->mutable_node_def(n_index);
|
NodeDef* node_def = fdef->mutable_node_def(n_index);
|
||||||
for (int i = 0; i < node_def->input_size(); ++i) {
|
for (int i = 0; i < node_def->input_size(); ++i) {
|
||||||
if (str_util::StartsWith(node_def->input(i), "^")) {
|
if (absl::StartsWith(node_def->input(i), "^")) {
|
||||||
// Control input
|
// Control input
|
||||||
const string normalized =
|
const string normalized =
|
||||||
node_names.Renormalize(node_def->input(i).substr(1));
|
node_names.Renormalize(node_def->input(i).substr(1));
|
||||||
|
@ -137,7 +137,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
|||||||
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
|
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
|
||||||
if (!host_memory_args.empty()) {
|
if (!host_memory_args.empty()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"HostMemory args '", str_util::Join(host_memory_args, "', '"),
|
"HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
|
||||||
"' not found in OpDef: ", SummarizeOpDef(*op_def));
|
"' not found in OpDef: ", SummarizeOpDef(*op_def));
|
||||||
}
|
}
|
||||||
CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
|
CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
|
||||||
|
@ -238,7 +238,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
|
|||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
errors_ptr->size(), " errors while building NodeDef '",
|
errors_ptr->size(), " errors while building NodeDef '",
|
||||||
node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
|
node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
|
||||||
str_util::Join(*errors_ptr, "\n"));
|
absl::StrJoin(*errors_ptr, "\n"));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
NodeDef node_def_backup;
|
NodeDef node_def_backup;
|
||||||
|
@ -83,7 +83,7 @@ class NodeDefBuilderTest : public ::testing::Test {
|
|||||||
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
|
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
|
||||||
if (status.ok()) return;
|
if (status.ok()) return;
|
||||||
for (const string& message : messages) {
|
for (const string& message : messages) {
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), message))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), message))
|
||||||
<< status << ", " << message;
|
<< status << ", " << message;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -104,7 +104,7 @@ class NodeDefBuilderTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
|
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
|
||||||
if (status.ok()) return;
|
if (status.ok()) return;
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), message))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), message))
|
||||||
<< "Actual error: " << status.error_message()
|
<< "Actual error: " << status.error_message()
|
||||||
<< "\nDoes not contain: " << message;
|
<< "\nDoes not contain: " << message;
|
||||||
}
|
}
|
||||||
|
@ -203,7 +203,7 @@ Status AttrSlice::Find(StringPiece attr_name,
|
|||||||
// Skip AttachDef for internal attrs since it is a little bit
|
// Skip AttachDef for internal attrs since it is a little bit
|
||||||
// expensive and it is common for them to correctly not be included
|
// expensive and it is common for them to correctly not be included
|
||||||
// in a NodeDef.
|
// in a NodeDef.
|
||||||
if (!str_util::StartsWith(attr_name, "_") && ndef_ != nullptr) {
|
if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
|
||||||
s = AttachDef(s, *ndef_);
|
s = AttachDef(s, *ndef_);
|
||||||
}
|
}
|
||||||
return s;
|
return s;
|
||||||
@ -500,7 +500,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
|||||||
size_t num_inputs = 0;
|
size_t num_inputs = 0;
|
||||||
// TODO(josh11b): Unify the input field validation.
|
// TODO(josh11b): Unify the input field validation.
|
||||||
for (const string& input : node_def.input()) {
|
for (const string& input : node_def.input()) {
|
||||||
if (str_util::StartsWith(input, "^")) {
|
if (absl::StartsWith(input, "^")) {
|
||||||
seen_control = true;
|
seen_control = true;
|
||||||
if (input.find(':') != string::npos) {
|
if (input.find(':') != string::npos) {
|
||||||
return errors::InvalidArgument("Control input '", input,
|
return errors::InvalidArgument("Control input '", input,
|
||||||
@ -526,7 +526,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
|||||||
}
|
}
|
||||||
for (const auto& attr : node_def.attr()) {
|
for (const auto& attr : node_def.attr()) {
|
||||||
// Allow internal optional attributes with names starting with "_".
|
// Allow internal optional attributes with names starting with "_".
|
||||||
if (str_util::StartsWith(attr.first, "_")) {
|
if (absl::StartsWith(attr.first, "_")) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto iter = op_attrs.find(attr.first);
|
auto iter = op_attrs.find(attr.first);
|
||||||
|
@ -68,7 +68,7 @@ void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
|
|||||||
<< "; OpDef: " << SummarizeOpDef(op_def);
|
<< "; OpDef: " << SummarizeOpDef(op_def);
|
||||||
|
|
||||||
LOG(INFO) << "Message: " << status.error_message();
|
LOG(INFO) << "Message: " << status.error_message();
|
||||||
EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
|
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
|
||||||
<< "NodeDef: " << SummarizeNodeDef(bad)
|
<< "NodeDef: " << SummarizeNodeDef(bad)
|
||||||
<< "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
|
<< "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
|
||||||
<< "\nDoes not contain: " << message;
|
<< "\nDoes not contain: " << message;
|
||||||
@ -270,7 +270,7 @@ void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
|
|||||||
EXPECT_TRUE(errors::IsInvalidArgument(status))
|
EXPECT_TRUE(errors::IsInvalidArgument(status))
|
||||||
<< status << "; NodeDef: " << SummarizeNodeDef(bad);
|
<< status << "; NodeDef: " << SummarizeNodeDef(bad);
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(StringPiece(status.ToString()), message))
|
EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message))
|
||||||
<< "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
|
<< "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
|
||||||
<< message;
|
<< message;
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
|
|||||||
out->Reserve(sorted.size());
|
out->Reserve(sorted.size());
|
||||||
|
|
||||||
for (const auto& item : sorted) {
|
for (const auto& item : sorted) {
|
||||||
if (include_internal || !str_util::StartsWith(item.first, "_")) {
|
if (include_internal || !absl::StartsWith(item.first, "_")) {
|
||||||
*out->Add() = item.second->op_def;
|
*out->Add() = item.second->op_def;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -97,7 +97,7 @@ class OpCompatibilityTest : public OpsTestBase {
|
|||||||
ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. "
|
ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. "
|
||||||
<< SummarizeOpDef(new_op_def);
|
<< SummarizeOpDef(new_op_def);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), error))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), error))
|
||||||
<< status << " does not contain " << error;
|
<< status << " does not contain " << error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,8 +118,7 @@ class OpCompatibilityTest : public OpsTestBase {
|
|||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
ADD_FAILURE() << SummarizeNodeDef(*node_def());
|
ADD_FAILURE() << SummarizeNodeDef(*node_def());
|
||||||
} else {
|
} else {
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status.error_message(), validation_error))
|
||||||
str_util::StrContains(status.error_message(), validation_error))
|
|
||||||
<< status << " does not contain " << validation_error;
|
<< status << " does not contain " << validation_error;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,7 +179,7 @@ class OpCompatibilityTest : public OpsTestBase {
|
|||||||
<< SummarizeOpDef(*new_op_def);
|
<< SummarizeOpDef(*new_op_def);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(), compatibility_error))
|
absl::StrContains(status.error_message(), compatibility_error))
|
||||||
<< status << " does not contain " << compatibility_error;
|
<< status << " does not contain " << compatibility_error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
@ -112,11 +114,11 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
|
|||||||
|
|
||||||
bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
|
bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
|
||||||
auto capture_begin = sp->begin();
|
auto capture_begin = sp->begin();
|
||||||
if (str_util::ConsumePrefix(sp, "numbertype") ||
|
if (absl::ConsumePrefix(sp, "numbertype") ||
|
||||||
str_util::ConsumePrefix(sp, "numerictype") ||
|
absl::ConsumePrefix(sp, "numerictype") ||
|
||||||
str_util::ConsumePrefix(sp, "quantizedtype") ||
|
absl::ConsumePrefix(sp, "quantizedtype") ||
|
||||||
str_util::ConsumePrefix(sp, "realnumbertype") ||
|
absl::ConsumePrefix(sp, "realnumbertype") ||
|
||||||
str_util::ConsumePrefix(sp, "realnumberictype")) {
|
absl::ConsumePrefix(sp, "realnumberictype")) {
|
||||||
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
|
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -157,32 +159,32 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
bool is_list = ConsumeListPrefix(&spec);
|
bool is_list = ConsumeListPrefix(&spec);
|
||||||
string type;
|
string type;
|
||||||
StringPiece type_string; // Used if type == "type"
|
StringPiece type_string; // Used if type == "type"
|
||||||
if (str_util::ConsumePrefix(&spec, "string")) {
|
if (absl::ConsumePrefix(&spec, "string")) {
|
||||||
type = "string";
|
type = "string";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "int")) {
|
} else if (absl::ConsumePrefix(&spec, "int")) {
|
||||||
type = "int";
|
type = "int";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "float")) {
|
} else if (absl::ConsumePrefix(&spec, "float")) {
|
||||||
type = "float";
|
type = "float";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "bool")) {
|
} else if (absl::ConsumePrefix(&spec, "bool")) {
|
||||||
type = "bool";
|
type = "bool";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "type")) {
|
} else if (absl::ConsumePrefix(&spec, "type")) {
|
||||||
type = "type";
|
type = "type";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "shape")) {
|
} else if (absl::ConsumePrefix(&spec, "shape")) {
|
||||||
type = "shape";
|
type = "shape";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "tensor")) {
|
} else if (absl::ConsumePrefix(&spec, "tensor")) {
|
||||||
type = "tensor";
|
type = "tensor";
|
||||||
} else if (str_util::ConsumePrefix(&spec, "func")) {
|
} else if (absl::ConsumePrefix(&spec, "func")) {
|
||||||
type = "func";
|
type = "func";
|
||||||
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
|
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
|
||||||
type = "type";
|
type = "type";
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
AttrValue* allowed = attr->mutable_allowed_values();
|
||||||
VERIFY(ProcessCompoundType(type_string, allowed),
|
VERIFY(ProcessCompoundType(type_string, allowed),
|
||||||
"Expected to see a compound type, saw: ", type_string);
|
"Expected to see a compound type, saw: ", type_string);
|
||||||
} else if (str_util::ConsumePrefix(&spec, "{")) {
|
} else if (absl::ConsumePrefix(&spec, "{")) {
|
||||||
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
|
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
AttrValue* allowed = attr->mutable_allowed_values();
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
if (str_util::StartsWith(spec, "\"") || str_util::StartsWith(spec, "'")) {
|
if (absl::StartsWith(spec, "\"") || absl::StartsWith(spec, "'")) {
|
||||||
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
|
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
|
||||||
while (true) {
|
while (true) {
|
||||||
StringPiece escaped_string;
|
StringPiece escaped_string;
|
||||||
@ -191,16 +193,16 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
"Trouble parsing allowed string at '", spec, "'");
|
"Trouble parsing allowed string at '", spec, "'");
|
||||||
string unescaped;
|
string unescaped;
|
||||||
string error;
|
string error;
|
||||||
VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
|
VERIFY(absl::CUnescape(escaped_string, &unescaped, &error),
|
||||||
"Trouble unescaping \"", escaped_string,
|
"Trouble unescaping \"", escaped_string,
|
||||||
"\", got error: ", error);
|
"\", got error: ", error);
|
||||||
allowed->mutable_list()->add_s(unescaped);
|
allowed->mutable_list()->add_s(unescaped);
|
||||||
if (str_util::ConsumePrefix(&spec, ",")) {
|
if (absl::ConsumePrefix(&spec, ",")) {
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
if (str_util::ConsumePrefix(&spec, "}"))
|
if (absl::ConsumePrefix(&spec, "}"))
|
||||||
break; // Allow ending with ", }".
|
break; // Allow ending with ", }".
|
||||||
} else {
|
} else {
|
||||||
VERIFY(str_util::ConsumePrefix(&spec, "}"),
|
VERIFY(absl::ConsumePrefix(&spec, "}"),
|
||||||
"Expected , or } after strings in list, not: '", spec, "'");
|
"Expected , or } after strings in list, not: '", spec, "'");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -218,12 +220,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
"Unrecognized type string '", type_string, "'");
|
"Unrecognized type string '", type_string, "'");
|
||||||
allowed->mutable_list()->add_type(dt);
|
allowed->mutable_list()->add_type(dt);
|
||||||
}
|
}
|
||||||
if (str_util::ConsumePrefix(&spec, ",")) {
|
if (absl::ConsumePrefix(&spec, ",")) {
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
if (str_util::ConsumePrefix(&spec, "}"))
|
if (absl::ConsumePrefix(&spec, "}"))
|
||||||
break; // Allow ending with ", }".
|
break; // Allow ending with ", }".
|
||||||
} else {
|
} else {
|
||||||
VERIFY(str_util::ConsumePrefix(&spec, "}"),
|
VERIFY(absl::ConsumePrefix(&spec, "}"),
|
||||||
"Expected , or } after types in list, not: '", spec, "'");
|
"Expected , or } after types in list, not: '", spec, "'");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -236,7 +238,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
|
|
||||||
// Write the type into *attr.
|
// Write the type into *attr.
|
||||||
if (is_list) {
|
if (is_list) {
|
||||||
VERIFY(str_util::ConsumePrefix(&spec, ")"),
|
VERIFY(absl::ConsumePrefix(&spec, ")"),
|
||||||
"Expected ) to close 'list(', not: '", spec, "'");
|
"Expected ) to close 'list(', not: '", spec, "'");
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
attr->set_type(strings::StrCat("list(", type, ")"));
|
attr->set_type(strings::StrCat("list(", type, ")"));
|
||||||
@ -245,7 +247,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read optional minimum constraint at the end.
|
// Read optional minimum constraint at the end.
|
||||||
if ((is_list || type == "int") && str_util::ConsumePrefix(&spec, ">=")) {
|
if ((is_list || type == "int") && absl::ConsumePrefix(&spec, ">=")) {
|
||||||
int64 min_limit = -999;
|
int64 min_limit = -999;
|
||||||
VERIFY(ConsumeAttrNumber(&spec, &min_limit),
|
VERIFY(ConsumeAttrNumber(&spec, &min_limit),
|
||||||
"Could not parse integer lower limit after '>=', found '", spec,
|
"Could not parse integer lower limit after '>=', found '", spec,
|
||||||
@ -255,7 +257,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse default value, if present.
|
// Parse default value, if present.
|
||||||
if (str_util::ConsumePrefix(&spec, "=")) {
|
if (absl::ConsumePrefix(&spec, "=")) {
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
|
VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
|
||||||
"Could not parse default value '", spec, "'");
|
"Could not parse default value '", spec, "'");
|
||||||
@ -465,7 +467,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
|
|||||||
|
|
||||||
// Remove trailing spaces.
|
// Remove trailing spaces.
|
||||||
for (string& line : lines) {
|
for (string& line : lines) {
|
||||||
str_util::StripTrailingWhitespace(&line);
|
absl::StripTrailingAsciiWhitespace(&line);
|
||||||
}
|
}
|
||||||
|
|
||||||
// First non-blank line -> summary.
|
// First non-blank line -> summary.
|
||||||
@ -485,7 +487,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
|
|||||||
int end_l = l;
|
int end_l = l;
|
||||||
// Trim trailing blank lines from the description.
|
// Trim trailing blank lines from the description.
|
||||||
while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
|
while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
|
||||||
string desc = str_util::Join(
|
string desc = absl::StrJoin(
|
||||||
gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
|
gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
|
||||||
if (!desc.empty()) op_def->set_description(desc);
|
if (!desc.empty()) op_def->set_description(desc);
|
||||||
|
|
||||||
@ -520,7 +522,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
|
|||||||
if (!description[i].empty()) description[i].remove_prefix(min_indent);
|
if (!description[i].empty()) description[i].remove_prefix(min_indent);
|
||||||
}
|
}
|
||||||
// Concatenate lines into a single string.
|
// Concatenate lines into a single string.
|
||||||
const string complete(str_util::Join(description, "\n"));
|
const string complete(absl::StrJoin(description, "\n"));
|
||||||
|
|
||||||
// Find name.
|
// Find name.
|
||||||
bool found = false;
|
bool found = false;
|
||||||
@ -651,7 +653,7 @@ Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
|
|||||||
FinalizeDoc(doc_, op_def, &errors);
|
FinalizeDoc(doc_, op_def, &errors);
|
||||||
|
|
||||||
if (errors.empty()) return Status::OK();
|
if (errors.empty()) return Status::OK();
|
||||||
return errors::InvalidArgument(str_util::Join(errors, "\n"));
|
return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -251,7 +251,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
|
|||||||
Status ValidateOpDef(const OpDef& op_def) {
|
Status ValidateOpDef(const OpDef& op_def) {
|
||||||
using ::tensorflow::strings::Scanner;
|
using ::tensorflow::strings::Scanner;
|
||||||
|
|
||||||
if (!str_util::StartsWith(op_def.name(), "_")) {
|
if (!absl::StartsWith(op_def.name(), "_")) {
|
||||||
VALIDATE(Scanner(op_def.name())
|
VALIDATE(Scanner(op_def.name())
|
||||||
.One(Scanner::UPPERLETTER)
|
.One(Scanner::UPPERLETTER)
|
||||||
.Any(Scanner::LETTER_DIGIT)
|
.Any(Scanner::LETTER_DIGIT)
|
||||||
@ -271,11 +271,11 @@ Status ValidateOpDef(const OpDef& op_def) {
|
|||||||
|
|
||||||
// Validate type
|
// Validate type
|
||||||
StringPiece type(attr.type());
|
StringPiece type(attr.type());
|
||||||
bool is_list = str_util::ConsumePrefix(&type, "list(");
|
bool is_list = absl::ConsumePrefix(&type, "list(");
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
|
for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
|
||||||
"tensor", "func"}) {
|
"tensor", "func"}) {
|
||||||
if (str_util::ConsumePrefix(&type, valid)) {
|
if (absl::ConsumePrefix(&type, valid)) {
|
||||||
found = true;
|
found = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -283,7 +283,7 @@ Status ValidateOpDef(const OpDef& op_def) {
|
|||||||
VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
|
VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
|
||||||
"'");
|
"'");
|
||||||
if (is_list) {
|
if (is_list) {
|
||||||
VALIDATE(str_util::ConsumePrefix(&type, ")"),
|
VALIDATE(absl::ConsumePrefix(&type, ")"),
|
||||||
"'list(' is missing ')' in attr ", attr.name(), "'s type ",
|
"'list(' is missing ')' in attr ", attr.name(), "'s type ",
|
||||||
attr.type());
|
attr.type());
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ class ValidateOpDefTest : public ::testing::Test {
|
|||||||
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(INFO) << "message: " << status;
|
LOG(INFO) << "message: " << status;
|
||||||
EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
|
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
|
||||||
<< "Actual: " << status << "\nExpected to contain: " << message;
|
<< "Actual: " << status << "\nExpected to contain: " << message;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
@ -55,7 +57,7 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) {
|
|||||||
while (str_util::EndsWith(to_append, " ")) {
|
while (str_util::EndsWith(to_append, " ")) {
|
||||||
to_append.remove_suffix(1);
|
to_append.remove_suffix(1);
|
||||||
}
|
}
|
||||||
while (str_util::ConsumePrefix(&str, " ")) {
|
while (absl::ConsumePrefix(&str, " ")) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go on to the next line.
|
// Go on to the next line.
|
||||||
@ -67,8 +69,8 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ConsumeEquals(StringPiece* description) {
|
bool ConsumeEquals(StringPiece* description) {
|
||||||
if (str_util::ConsumePrefix(description, "=")) {
|
if (absl::ConsumePrefix(description, "=")) {
|
||||||
while (str_util::ConsumePrefix(description,
|
while (absl::ConsumePrefix(description,
|
||||||
" ")) { // Also remove spaces after "=".
|
" ")) { // Also remove spaces after "=".
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -101,7 +103,7 @@ static bool StartsWithFieldName(StringPiece line,
|
|||||||
const std::vector<string>& multi_line_fields) {
|
const std::vector<string>& multi_line_fields) {
|
||||||
StringPiece up_to_colon;
|
StringPiece up_to_colon;
|
||||||
if (!SplitAt(':', &line, &up_to_colon)) return false;
|
if (!SplitAt(':', &line, &up_to_colon)) return false;
|
||||||
while (str_util::ConsumePrefix(&up_to_colon, " "))
|
while (absl::ConsumePrefix(&up_to_colon, " "))
|
||||||
; // Remove leading spaces.
|
; // Remove leading spaces.
|
||||||
for (const auto& field : multi_line_fields) {
|
for (const auto& field : multi_line_fields) {
|
||||||
if (up_to_colon == field) {
|
if (up_to_colon == field) {
|
||||||
@ -122,9 +124,9 @@ static bool ConvertLine(StringPiece line,
|
|||||||
StringPiece up_to_colon;
|
StringPiece up_to_colon;
|
||||||
StringPiece after_colon = line;
|
StringPiece after_colon = line;
|
||||||
SplitAt(':', &after_colon, &up_to_colon);
|
SplitAt(':', &after_colon, &up_to_colon);
|
||||||
while (str_util::ConsumePrefix(&after_colon, " "))
|
while (absl::ConsumePrefix(&after_colon, " "))
|
||||||
; // Remove leading spaces.
|
; // Remove leading spaces.
|
||||||
if (!str_util::ConsumePrefix(&after_colon, "\"")) {
|
if (!absl::ConsumePrefix(&after_colon, "\"")) {
|
||||||
// We only convert string fields, so don't convert this line.
|
// We only convert string fields, so don't convert this line.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -138,7 +140,7 @@ static bool ConvertLine(StringPiece line,
|
|||||||
// We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
|
// We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
|
||||||
|
|
||||||
string unescaped;
|
string unescaped;
|
||||||
if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
|
if (!absl::CUnescape(escaped, &unescaped, nullptr)) {
|
||||||
// Error unescaping, abort the conversion.
|
// Error unescaping, abort the conversion.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -184,9 +186,9 @@ string PBTxtToMultiline(StringPiece pbtxt,
|
|||||||
static bool FindMultiline(StringPiece line, size_t colon, string* end) {
|
static bool FindMultiline(StringPiece line, size_t colon, string* end) {
|
||||||
if (colon == StringPiece::npos) return false;
|
if (colon == StringPiece::npos) return false;
|
||||||
line.remove_prefix(colon + 1);
|
line.remove_prefix(colon + 1);
|
||||||
while (str_util::ConsumePrefix(&line, " ")) {
|
while (absl::ConsumePrefix(&line, " ")) {
|
||||||
}
|
}
|
||||||
if (str_util::ConsumePrefix(&line, "<<")) {
|
if (absl::ConsumePrefix(&line, "<<")) {
|
||||||
*end = string(line);
|
*end = string(line);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -230,7 +232,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
|
|||||||
bool first = true;
|
bool first = true;
|
||||||
while (!multiline_pbtxt.empty()) {
|
while (!multiline_pbtxt.empty()) {
|
||||||
SplitAt('\n', &multiline_pbtxt, &line);
|
SplitAt('\n', &multiline_pbtxt, &line);
|
||||||
if (str_util::ConsumePrefix(&line, end)) break;
|
if (absl::ConsumePrefix(&line, end)) break;
|
||||||
if (first) {
|
if (first) {
|
||||||
first = false;
|
first = false;
|
||||||
} else {
|
} else {
|
||||||
@ -241,7 +243,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Escape what we extracted and then output it in quotes.
|
// Escape what we extracted and then output it in quotes.
|
||||||
strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
|
strings::StrAppend(&pbtxt, " \"", absl::CEscape(unescaped), "\"", line,
|
||||||
"\n");
|
"\n");
|
||||||
}
|
}
|
||||||
return pbtxt;
|
return pbtxt;
|
||||||
@ -265,7 +267,7 @@ static void StringReplace(const string& from, const string& to, string* s) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Join the pieces back together with a new delimiter.
|
// Join the pieces back together with a new delimiter.
|
||||||
*s = str_util::Join(split, to.c_str());
|
*s = absl::StrJoin(split, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void RenameInDocs(const string& from, const string& to,
|
static void RenameInDocs(const string& from, const string& to,
|
||||||
@ -417,10 +419,10 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
|
|||||||
new_api_def.arg_order().end(),
|
new_api_def.arg_order().end(),
|
||||||
base_api_def->arg_order().begin())) {
|
base_api_def->arg_order().begin())) {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
|
"Invalid arg_order: ", absl::StrJoin(new_api_def.arg_order(), ", "),
|
||||||
" for ", base_api_def->graph_op_name(),
|
" for ", base_api_def->graph_op_name(),
|
||||||
". All elements in arg_order override must match base arg_order: ",
|
". All elements in arg_order override must match base arg_order: ",
|
||||||
str_util::Join(base_api_def->arg_order(), ", "));
|
absl::StrJoin(base_api_def->arg_order(), ", "));
|
||||||
}
|
}
|
||||||
|
|
||||||
base_api_def->clear_arg_order();
|
base_api_def->clear_arg_order();
|
||||||
|
@ -104,7 +104,7 @@ OpKernel::OpKernel(OpKernelConstruction* context,
|
|||||||
input_name_map_(context->num_inputs()),
|
input_name_map_(context->num_inputs()),
|
||||||
output_name_map_(context->num_outputs()),
|
output_name_map_(context->num_outputs()),
|
||||||
graph_def_version_(context->graph_def_version()),
|
graph_def_version_(context->graph_def_version()),
|
||||||
is_internal_(str_util::StartsWith(type_string(), "_")),
|
is_internal_(absl::StartsWith(type_string(), "_")),
|
||||||
cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
|
cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
|
NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
|
||||||
@ -1030,7 +1030,7 @@ static Status IsProbablySafeToLoad(const string& path) {
|
|||||||
}
|
}
|
||||||
if (!missing_features.empty()) {
|
if (!missing_features.empty()) {
|
||||||
string errmsg = "Missing CPU features: ";
|
string errmsg = "Missing CPU features: ";
|
||||||
errmsg.append(str_util::Join(missing_features, ", "));
|
errmsg.append(absl::StrJoin(missing_features, ", "));
|
||||||
return Status(errors::Code::FAILED_PRECONDITION, errmsg);
|
return Status(errors::Code::FAILED_PRECONDITION, errmsg);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -579,8 +579,8 @@ TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
|
|||||||
{"T|list(type)|[DT_FLOAT]"}));
|
{"T|list(type)|[DT_FLOAT]"}));
|
||||||
|
|
||||||
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
|
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(
|
||||||
GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
|
absl::StrContains(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
|
||||||
"Invalid argument: "));
|
"Invalid argument: "));
|
||||||
|
|
||||||
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
|
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
|
||||||
@ -598,7 +598,7 @@ TEST_F(OpKernelBuilderTest, DuplicateKernel) {
|
|||||||
PrioritizedDeviceTypeVector devs;
|
PrioritizedDeviceTypeVector devs;
|
||||||
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status.error_message(), "Multiple OpKernel registrations match NodeDef"));
|
status.error_message(), "Multiple OpKernel registrations match NodeDef"));
|
||||||
|
|
||||||
ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
|
ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
|
||||||
@ -618,7 +618,7 @@ TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
|
|||||||
PrioritizedDeviceTypeVector devs;
|
PrioritizedDeviceTypeVector devs;
|
||||||
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status.error_message(), "Multiple OpKernel registrations match NodeDef"));
|
status.error_message(), "Multiple OpKernel registrations match NodeDef"));
|
||||||
|
|
||||||
ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
|
ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
|
||||||
@ -640,7 +640,7 @@ TEST_F(OpKernelBuilderTest, BadConstraint) {
|
|||||||
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(),
|
absl::StrContains(status.error_message(),
|
||||||
"OpKernel 'BadConstraint' has constraint on attr "
|
"OpKernel 'BadConstraint' has constraint on attr "
|
||||||
"'T' not in NodeDef"));
|
"'T' not in NodeDef"));
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ string ResourceMgr::DebugString() const {
|
|||||||
line.type.c_str(), line.resource->c_str(), line.detail.c_str()));
|
line.type.c_str(), line.resource->c_str(), line.detail.c_str()));
|
||||||
}
|
}
|
||||||
std::sort(text.begin(), text.end());
|
std::sort(text.begin(), text.end());
|
||||||
return str_util::Join(text, "\n");
|
return absl::StrJoin(text, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
|
Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
|
||||||
|
@ -73,7 +73,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void HasError(const Status& s, const string& substr) {
|
static void HasError(const Status& s, const string& substr) {
|
||||||
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
|
EXPECT_TRUE(absl::StrContains(s.ToString(), substr))
|
||||||
<< s << ", expected substring " << substr;
|
<< s << ", expected substring " << substr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,7 +334,7 @@ string InferenceContext::DebugString(ShapeHandle s) {
|
|||||||
if (RankKnown(s)) {
|
if (RankKnown(s)) {
|
||||||
std::vector<string> vals;
|
std::vector<string> vals;
|
||||||
for (auto d : s->dims_) vals.push_back(DebugString(d));
|
for (auto d : s->dims_) vals.push_back(DebugString(d));
|
||||||
return strings::StrCat("[", str_util::Join(vals, ","), "]");
|
return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
|
||||||
} else {
|
} else {
|
||||||
return "?";
|
return "?";
|
||||||
}
|
}
|
||||||
@ -360,7 +360,7 @@ string InferenceContext::DebugString(
|
|||||||
for (const ShapeAndType& s : shape_and_types) {
|
for (const ShapeAndType& s : shape_and_types) {
|
||||||
pieces.push_back(DebugString(s));
|
pieces.push_back(DebugString(s));
|
||||||
}
|
}
|
||||||
return strings::StrCat("[", str_util::Join(pieces, ","), "]");
|
return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
|
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
|
||||||
@ -1176,15 +1176,15 @@ Status InferenceContext::AttachContext(const Status& status) {
|
|||||||
|
|
||||||
string error_context = strings::StrCat(
|
string error_context = strings::StrCat(
|
||||||
" for '", node_def_->name(), "' (op: '", node_def_->op(),
|
" for '", node_def_->name(), "' (op: '", node_def_->op(),
|
||||||
"') with input shapes: ", str_util::Join(input_shapes, ", "));
|
"') with input shapes: ", absl::StrJoin(input_shapes, ", "));
|
||||||
if (!input_from_tensors_str.empty()) {
|
if (!input_from_tensors_str.empty()) {
|
||||||
strings::StrAppend(&error_context, " and with computed input tensors: ",
|
strings::StrAppend(&error_context, " and with computed input tensors: ",
|
||||||
str_util::Join(input_from_tensors_str, ", "));
|
absl::StrJoin(input_from_tensors_str, ", "));
|
||||||
}
|
}
|
||||||
if (!input_from_tensors_as_shape_str.empty()) {
|
if (!input_from_tensors_as_shape_str.empty()) {
|
||||||
strings::StrAppend(&error_context,
|
strings::StrAppend(&error_context,
|
||||||
" and with input tensors computed as partial shapes: ",
|
" and with input tensors computed as partial shapes: ",
|
||||||
str_util::Join(input_from_tensors_as_shape_str, ","));
|
absl::StrJoin(input_from_tensors_as_shape_str, ","));
|
||||||
}
|
}
|
||||||
|
|
||||||
strings::StrAppend(&error_context, ".");
|
strings::StrAppend(&error_context, ".");
|
||||||
|
@ -153,7 +153,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
|||||||
};
|
};
|
||||||
Status s = c.Run(fn);
|
Status s = c.Run(fn);
|
||||||
// Extra error message is attached when Run fails.
|
// Extra error message is attached when Run fails.
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s.ToString(),
|
s.ToString(),
|
||||||
"Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')"))
|
"Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')"))
|
||||||
<< s;
|
<< s;
|
||||||
@ -367,7 +367,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
|||||||
|
|
||||||
// WithRankAtMost on shape with known dimensionality.
|
// WithRankAtMost on shape with known dimensionality.
|
||||||
s1 = in1;
|
s1 = in1;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.WithRankAtMost(in1, 2, &s1).ToString(),
|
c.WithRankAtMost(in1, 2, &s1).ToString(),
|
||||||
"Invalid argument: Shape must be at most rank 2 but is rank 3"));
|
"Invalid argument: Shape must be at most rank 2 but is rank 3"));
|
||||||
|
|
||||||
@ -405,7 +405,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
|||||||
|
|
||||||
// WithRankAtLeast on shape with known dimensionality.
|
// WithRankAtLeast on shape with known dimensionality.
|
||||||
s1 = in1;
|
s1 = in1;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.WithRankAtLeast(in1, 4, &s1).ToString(),
|
c.WithRankAtLeast(in1, 4, &s1).ToString(),
|
||||||
"Invalid argument: Shape must be at least rank 4 but is rank 3"));
|
"Invalid argument: Shape must be at least rank 4 but is rank 3"));
|
||||||
|
|
||||||
@ -448,12 +448,12 @@ TEST_F(ShapeInferenceTest, WithValue) {
|
|||||||
out1 = d0;
|
out1 = d0;
|
||||||
|
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(c.WithValue(d0, 0, &out1).ToString(),
|
absl::StrContains(c.WithValue(d0, 0, &out1).ToString(),
|
||||||
"Invalid argument: Dimension must be 0 but is 1"));
|
"Invalid argument: Dimension must be 0 but is 1"));
|
||||||
EXPECT_FALSE(IsSet(out1));
|
EXPECT_FALSE(IsSet(out1));
|
||||||
out1 = d0;
|
out1 = d0;
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(c.WithValue(d0, 2, &out1).ToString(),
|
absl::StrContains(c.WithValue(d0, 2, &out1).ToString(),
|
||||||
"Invalid argument: Dimension must be 2 but is 1"));
|
"Invalid argument: Dimension must be 2 but is 1"));
|
||||||
|
|
||||||
EXPECT_FALSE(IsSet(out1));
|
EXPECT_FALSE(IsSet(out1));
|
||||||
@ -513,12 +513,12 @@ TEST_F(ShapeInferenceTest, MergeDim) {
|
|||||||
EXPECT_EQ(3, merged_dims.size());
|
EXPECT_EQ(3, merged_dims.size());
|
||||||
|
|
||||||
// Merging unequal values is an error.
|
// Merging unequal values is an error.
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Merge(d2, d1, &out).ToString(),
|
c.Merge(d2, d1, &out).ToString(),
|
||||||
"Invalid argument: Dimensions must be equal, but are 2 and 1"));
|
"Invalid argument: Dimensions must be equal, but are 2 and 1"));
|
||||||
|
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Merge(d1, d2, &out).ToString(),
|
c.Merge(d1, d2, &out).ToString(),
|
||||||
"Invalid argument: Dimensions must be equal, but are 1 and 2"));
|
"Invalid argument: Dimensions must be equal, but are 1 and 2"));
|
||||||
|
|
||||||
@ -727,21 +727,21 @@ TEST_F(ShapeInferenceTest, MergeShape) {
|
|||||||
|
|
||||||
// Incompatible merges give errors and set out to nullptr.
|
// Incompatible merges give errors and set out to nullptr.
|
||||||
out = s_unknown;
|
out = s_unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Merge(s_u_2, s_1_3, &out).ToString(),
|
c.Merge(s_u_2, s_1_3, &out).ToString(),
|
||||||
"Invalid argument: Dimension 1 in both shapes must be equal, but "
|
"Invalid argument: Dimension 1 in both shapes must be equal, but "
|
||||||
"are 2 and 3"));
|
"are 2 and 3"));
|
||||||
|
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
out = s_unknown;
|
out = s_unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Merge(s_1_3, s_u_2, &out).ToString(),
|
c.Merge(s_1_3, s_u_2, &out).ToString(),
|
||||||
"Invalid argument: Dimension 1 in both shapes must be equal, but "
|
"Invalid argument: Dimension 1 in both shapes must be equal, but "
|
||||||
"are 3 and 2"));
|
"are 3 and 2"));
|
||||||
|
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
out = s_unknown;
|
out = s_unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Merge(s_1, s_1_2, &out).ToString(),
|
c.Merge(s_1, s_1_2, &out).ToString(),
|
||||||
"Invalid argument: Shapes must be equal rank, but are 1 and 2"));
|
"Invalid argument: Shapes must be equal rank, but are 1 and 2"));
|
||||||
|
|
||||||
@ -790,7 +790,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
|||||||
// Incompatible merges give errors and set outs to nullptr.
|
// Incompatible merges give errors and set outs to nullptr.
|
||||||
s_out = s_unknown;
|
s_out = s_unknown;
|
||||||
s_prefix_out = s_unknown;
|
s_prefix_out = s_unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(),
|
c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(),
|
||||||
"Invalid argument: Dimensions must be equal, but are 1 and 2"));
|
"Invalid argument: Dimensions must be equal, but are 1 and 2"));
|
||||||
|
|
||||||
@ -799,7 +799,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
|||||||
|
|
||||||
s_out = s_unknown;
|
s_out = s_unknown;
|
||||||
s_prefix_out = s_unknown;
|
s_prefix_out = s_unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(),
|
c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(),
|
||||||
"Invalid argument: Shape must be at least rank 3 but is rank 2"));
|
"Invalid argument: Shape must be at least rank 3 but is rank 2"));
|
||||||
EXPECT_FALSE(IsSet(s_out));
|
EXPECT_FALSE(IsSet(s_out));
|
||||||
@ -859,19 +859,19 @@ TEST_F(ShapeInferenceTest, Subshape) {
|
|||||||
|
|
||||||
// Errors.
|
// Errors.
|
||||||
out = unknown;
|
out = unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Subshape(in0, 6, -3, &out).ToString(),
|
c.Subshape(in0, 6, -3, &out).ToString(),
|
||||||
"Invalid argument: Subshape must have computed start <= end, but is 5 "
|
"Invalid argument: Subshape must have computed start <= end, but is 5 "
|
||||||
"and 2 (computed from start 6 and end -3 over shape with rank 5)"));
|
"and 2 (computed from start 6 and end -3 over shape with rank 5)"));
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
out = unknown;
|
out = unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(),
|
EXPECT_TRUE(absl::StrContains(c.Subshape(in0, -50, 100, &out).ToString(),
|
||||||
"Invalid argument: Subshape start out of "
|
"Invalid argument: Subshape start out of "
|
||||||
"bounds: -50, for shape with rank 5"));
|
"bounds: -50, for shape with rank 5"));
|
||||||
|
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
out = unknown;
|
out = unknown;
|
||||||
EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(),
|
EXPECT_TRUE(absl::StrContains(c.Subshape(in0, 0, -50, &out).ToString(),
|
||||||
"Invalid argument: Subshape end out of "
|
"Invalid argument: Subshape end out of "
|
||||||
"bounds: -50, for shape with rank 5"));
|
"bounds: -50, for shape with rank 5"));
|
||||||
|
|
||||||
@ -1086,31 +1086,31 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
|||||||
EXPECT_EQ("?", create(&t));
|
EXPECT_EQ("?", create(&t));
|
||||||
|
|
||||||
t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
|
t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
create(&t), "Input tensor must be int32 or int64, but was float"));
|
create(&t), "Input tensor must be int32 or int64, but was float"));
|
||||||
|
|
||||||
t = ::tensorflow::test::AsScalar<int32>(1);
|
t = ::tensorflow::test::AsScalar<int32>(1);
|
||||||
auto s_scalar = create(&t);
|
auto s_scalar = create(&t);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
s_scalar,
|
s_scalar,
|
||||||
"Input tensor must be rank 1, or if its rank 0 it must have value -1"))
|
"Input tensor must be rank 1, or if its rank 0 it must have value -1"))
|
||||||
<< s_scalar;
|
<< s_scalar;
|
||||||
|
|
||||||
t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
|
t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
|
||||||
auto s_matrix = create(&t);
|
auto s_matrix = create(&t);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(s_matrix,
|
||||||
s_matrix, "Input tensor must be rank 1, but was rank 2"))
|
"Input tensor must be rank 1, but was rank 2"))
|
||||||
<< s_matrix;
|
<< s_matrix;
|
||||||
|
|
||||||
// Test negative values for the dims.
|
// Test negative values for the dims.
|
||||||
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
|
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(create(&t),
|
||||||
create(&t), "Invalid value in tensor used for shape: -2"));
|
"Invalid value in tensor used for shape: -2"));
|
||||||
|
|
||||||
// Test negative values for the dims.
|
// Test negative values for the dims.
|
||||||
t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
|
t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(create(&t),
|
||||||
create(&t), "Invalid value in tensor used for shape: -2"));
|
"Invalid value in tensor used for shape: -2"));
|
||||||
|
|
||||||
// Test when the input shape is wrong.
|
// Test when the input shape is wrong.
|
||||||
{
|
{
|
||||||
@ -1168,8 +1168,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
|||||||
EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
|
EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
|
||||||
EXPECT_EQ("?", c.DebugString(out));
|
EXPECT_EQ("?", c.DebugString(out));
|
||||||
proto.add_dim()->set_size(0);
|
proto.add_dim()->set_size(0);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(
|
||||||
c.MakeShapeFromShapeProto(proto, &out).error_message(),
|
absl::StrContains(c.MakeShapeFromShapeProto(proto, &out).error_message(),
|
||||||
"An unknown shape must not have any dimensions set."));
|
"An unknown shape must not have any dimensions set."));
|
||||||
EXPECT_FALSE(IsSet(out));
|
EXPECT_FALSE(IsSet(out));
|
||||||
|
|
||||||
@ -1184,7 +1184,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
|||||||
|
|
||||||
// With invalid dimension value.
|
// With invalid dimension value.
|
||||||
proto.add_dim()->set_size(-2);
|
proto.add_dim()->set_size(-2);
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.MakeShapeFromShapeProto(proto, &out).error_message(),
|
c.MakeShapeFromShapeProto(proto, &out).error_message(),
|
||||||
"Shape [0,?,1000,-2] has dimensions with values below -1 "
|
"Shape [0,?,1000,-2] has dimensions with values below -1 "
|
||||||
"(where -1 means unknown)"));
|
"(where -1 means unknown)"));
|
||||||
@ -1254,7 +1254,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
|||||||
EXPECT_EQ("20", c.DebugString(d));
|
EXPECT_EQ("20", c.DebugString(d));
|
||||||
|
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
|
absl::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
|
||||||
"Dimension size, given by scalar input 1, must be "
|
"Dimension size, given by scalar input 1, must be "
|
||||||
"non-negative but is -1"));
|
"non-negative but is -1"));
|
||||||
|
|
||||||
@ -1265,7 +1265,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
|||||||
EXPECT_EQ("20", c.DebugString(d));
|
EXPECT_EQ("20", c.DebugString(d));
|
||||||
|
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
|
absl::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
|
||||||
"Dimension size, given by scalar input 1, must be "
|
"Dimension size, given by scalar input 1, must be "
|
||||||
"non-negative but is -1"));
|
"non-negative but is -1"));
|
||||||
}
|
}
|
||||||
@ -1320,18 +1320,18 @@ TEST_F(ShapeInferenceTest, Divide) {
|
|||||||
EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
|
EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
|
||||||
EXPECT_EQ("3", c.DebugString(out));
|
EXPECT_EQ("3", c.DebugString(out));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
|
||||||
"Dimension size must be evenly divisible by 5 but is 6"));
|
"Dimension size must be evenly divisible by 5 but is 6"));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
|
||||||
"Divisor must be positive but is 0"));
|
"Divisor must be positive but is 0"));
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
|
||||||
"Divisor must be positive but is 0"));
|
"Divisor must be positive but is 0"));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
|
||||||
"Divisor must be positive but is -1"));
|
"Divisor must be positive but is -1"));
|
||||||
|
|
||||||
@ -1340,11 +1340,11 @@ TEST_F(ShapeInferenceTest, Divide) {
|
|||||||
EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
|
EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
|
||||||
EXPECT_EQ("1", c.DebugString(out));
|
EXPECT_EQ("1", c.DebugString(out));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
|
||||||
"Divisor must be positive but is 0"));
|
"Divisor must be positive but is 0"));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
|
c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
|
||||||
"Divisor must be positive but is -1"));
|
"Divisor must be positive but is -1"));
|
||||||
}
|
}
|
||||||
@ -1394,7 +1394,7 @@ TEST_F(ShapeInferenceTest, Add) {
|
|||||||
EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
|
EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
|
||||||
EXPECT_TRUE(SameHandle(out, d_6));
|
EXPECT_TRUE(SameHandle(out, d_6));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message(),
|
c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message(),
|
||||||
"Dimension size overflow from adding 6 and 9223372036854775802"));
|
"Dimension size overflow from adding 6 and 9223372036854775802"));
|
||||||
}
|
}
|
||||||
@ -1444,7 +1444,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
|
|||||||
EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
|
EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
|
||||||
EXPECT_TRUE(SameHandle(out, d_6));
|
EXPECT_TRUE(SameHandle(out, d_6));
|
||||||
|
|
||||||
EXPECT_TRUE(str_util::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
c.Subtract(d_5, d_6, &out).error_message(),
|
c.Subtract(d_5, d_6, &out).error_message(),
|
||||||
"Negative dimension size caused by subtracting 6 from 5"));
|
"Negative dimension size caused by subtracting 6 from 5"));
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (str_util::StartsWith(expected, "in")) {
|
if (absl::StartsWith(expected, "in")) {
|
||||||
if (in_index == -1) {
|
if (in_index == -1) {
|
||||||
return Unknown(err_prefix,
|
return Unknown(err_prefix,
|
||||||
" should have matched an input shape by "
|
" should have matched an input shape by "
|
||||||
@ -135,8 +135,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify the dimensions.
|
// Verify the dimensions.
|
||||||
CHECK(str_util::StartsWith(expected, "[") &&
|
CHECK(absl::StartsWith(expected, "[") && str_util::EndsWith(expected, "]"))
|
||||||
str_util::EndsWith(expected, "]"))
|
|
||||||
<< expected;
|
<< expected;
|
||||||
expected.remove_prefix(1);
|
expected.remove_prefix(1);
|
||||||
expected.remove_suffix(1);
|
expected.remove_suffix(1);
|
||||||
@ -178,7 +177,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
|||||||
return Unknown(err_prefix, " expected to be unknown but was ",
|
return Unknown(err_prefix, " expected to be unknown but was ",
|
||||||
c.Value(out_dim), err_suffix);
|
c.Value(out_dim), err_suffix);
|
||||||
}
|
}
|
||||||
} else if (str_util::StartsWith(expected_dim, "d")) {
|
} else if (absl::StartsWith(expected_dim, "d")) {
|
||||||
// Compare the dimension values.
|
// Compare the dimension values.
|
||||||
auto v = str_util::Split(expected_dim, '|');
|
auto v = str_util::Split(expected_dim, '|');
|
||||||
if (in_dim_idx.first == -1) {
|
if (in_dim_idx.first == -1) {
|
||||||
|
@ -29,8 +29,7 @@ namespace {
|
|||||||
#define EXPECT_CONTAINS(str, substr) \
|
#define EXPECT_CONTAINS(str, substr) \
|
||||||
do { \
|
do { \
|
||||||
string s = (str); \
|
string s = (str); \
|
||||||
EXPECT_TRUE(::tensorflow::str_util::StrContains(s, substr)) \
|
EXPECT_TRUE(absl::StrContains(s, substr)) << "String: " << s; \
|
||||||
<< "String: " << s; \
|
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
static OpShapeInferenceFn* global_fn_ptr = nullptr;
|
static OpShapeInferenceFn* global_fn_ptr = nullptr;
|
||||||
@ -100,7 +99,7 @@ TEST(ShapeInferenceTestutilTest, Failures) {
|
|||||||
ShapeInferenceTestOp("NoSuchOp"), "", "")
|
ShapeInferenceTestOp("NoSuchOp"), "", "")
|
||||||
.error_message();
|
.error_message();
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StartsWith(error_message, "Op type not registered 'NoSuchOp'"));
|
absl::StartsWith(error_message, "Op type not registered 'NoSuchOp'"));
|
||||||
|
|
||||||
// Wrong shape error messages.
|
// Wrong shape error messages.
|
||||||
EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0),
|
EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0),
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/resource_handle.pb.h"
|
#include "tensorflow/core/framework/resource_handle.pb.h"
|
||||||
@ -964,9 +965,9 @@ inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
|
|||||||
}
|
}
|
||||||
inline string PrintOneElement(const string& a, bool print_v2) {
|
inline string PrintOneElement(const string& a, bool print_v2) {
|
||||||
if (print_v2) {
|
if (print_v2) {
|
||||||
return "\"" + str_util::CEscape(a) + "\"";
|
return "\"" + absl::CEscape(a) + "\"";
|
||||||
} else {
|
} else {
|
||||||
return str_util::CEscape(a);
|
return absl::CEscape(a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
|
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
|
||||||
|
@ -759,7 +759,7 @@ Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
|
|||||||
n = MultiplyWithoutOverflow(n, dim);
|
n = MultiplyWithoutOverflow(n, dim);
|
||||||
if (n < 0) {
|
if (n < 0) {
|
||||||
return errors::InvalidArgument("Can't compute total size of shape [",
|
return errors::InvalidArgument("Can't compute total size of shape [",
|
||||||
str_util::Join(shape, ","),
|
absl::StrJoin(shape, ","),
|
||||||
"]; product would overflow int64");
|
"]; product would overflow int64");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -479,7 +479,7 @@ TensorShapeIterOld TensorShapeOld::end() const {
|
|||||||
|
|
||||||
string TensorShapeOld::DebugString() const {
|
string TensorShapeOld::DebugString() const {
|
||||||
return strings::StrCat(
|
return strings::StrCat(
|
||||||
"[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
|
"[", absl::StrJoin(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
|
string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
|
||||||
|
@ -141,8 +141,8 @@ TEST(TypesTest, ComplexTypes) {
|
|||||||
TEST(TypesTest, IntegerTypes) {
|
TEST(TypesTest, IntegerTypes) {
|
||||||
for (auto dt : AllTypes()) {
|
for (auto dt : AllTypes()) {
|
||||||
const string name = DataTypeString(dt);
|
const string name = DataTypeString(dt);
|
||||||
EXPECT_EQ(DataTypeIsInteger(dt), str_util::StartsWith(name, "int") ||
|
EXPECT_EQ(DataTypeIsInteger(dt),
|
||||||
str_util::StartsWith(name, "uint"))
|
absl::StartsWith(name, "int") || absl::StartsWith(name, "uint"))
|
||||||
<< "DataTypeInteger failed for " << name;
|
<< "DataTypeInteger failed for " << name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -260,7 +260,7 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) {
|
|||||||
ClientSession session(root);
|
ClientSession session(root);
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = session.Run({create_const}, &outputs);
|
Status s = session.Run({create_const}, &outputs);
|
||||||
EXPECT_TRUE(str_util::StrContains(s.error_message(),
|
EXPECT_TRUE(absl::StrContains(s.error_message(),
|
||||||
"GPU copy from non-DMA string tensor"))
|
"GPU copy from non-DMA string tensor"))
|
||||||
<< s.ToString();
|
<< s.ToString();
|
||||||
}
|
}
|
||||||
@ -367,7 +367,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) {
|
|||||||
Status err = session.Run({create_op, identity}, &outputs);
|
Status err = session.Run({create_op, identity}, &outputs);
|
||||||
EXPECT_EQ(err.code(), errors::Code::INVALID_ARGUMENT);
|
EXPECT_EQ(err.code(), errors::Code::INVALID_ARGUMENT);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(err.error_message(),
|
absl::StrContains(err.error_message(),
|
||||||
"During Variant Host->Device Copy: non-DMA-copy "
|
"During Variant Host->Device Copy: non-DMA-copy "
|
||||||
"attempted of tensor type: string"))
|
"attempted of tensor type: string"))
|
||||||
<< err.error_message();
|
<< err.error_message();
|
||||||
|
@ -188,8 +188,7 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
|||||||
Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
|
Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
|
||||||
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit zeros_like"));
|
||||||
str_util::StrContains(s0.error_message(), "early exit zeros_like"));
|
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
||||||
v = vv_ok;
|
v = vv_ok;
|
||||||
@ -214,8 +213,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
|
|||||||
Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
|
Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
|
||||||
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit zeros_like"));
|
||||||
str_util::StrContains(s0.error_message(), "early exit zeros_like"));
|
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
||||||
v = vv_ok;
|
v = vv_ok;
|
||||||
@ -261,7 +259,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
|
|||||||
Status s0 = BinaryOpVariants<CPUDevice>(
|
Status s0 = BinaryOpVariants<CPUDevice>(
|
||||||
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add"));
|
EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit add"));
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
||||||
v_a = vv_ok;
|
v_a = vv_ok;
|
||||||
@ -288,7 +286,7 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
|
|||||||
Status s0 = BinaryOpVariants<GPUDevice>(
|
Status s0 = BinaryOpVariants<GPUDevice>(
|
||||||
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add"));
|
EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit add"));
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
||||||
v_a = vv_ok;
|
v_a = vv_ok;
|
||||||
|
@ -67,8 +67,8 @@ Status ValidateControlFlowInfo(const Graph* graph,
|
|||||||
// BackPropLoopCounter runs in the same frame as the backprop loop. They
|
// BackPropLoopCounter runs in the same frame as the backprop loop. They
|
||||||
// are the only cases that multiple loops share the same frame.
|
// are the only cases that multiple loops share the same frame.
|
||||||
if (frame.loop_cond &&
|
if (frame.loop_cond &&
|
||||||
!str_util::StrContains(frame.loop_cond->name(), "LoopCounter") &&
|
!absl::StrContains(frame.loop_cond->name(), "LoopCounter") &&
|
||||||
!str_util::StrContains(node->name(), "LoopCounter")) {
|
!absl::StrContains(node->name(), "LoopCounter")) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Invalid loop structure: Loop \"", cf.frame_name,
|
"Invalid loop structure: Loop \"", cf.frame_name,
|
||||||
"\" has more than one LoopCond node: ", FormatNodeForError(*node),
|
"\" has more than one LoopCond node: ", FormatNodeForError(*node),
|
||||||
|
@ -60,17 +60,17 @@ TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
|
|||||||
std::vector<ControlFlowInfo> info;
|
std::vector<ControlFlowInfo> info;
|
||||||
Status status = BuildControlFlowInfo(graph.get(), &info);
|
Status status = BuildControlFlowInfo(graph.get(), &info);
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"has inputs from different frames"))
|
"has inputs from different frames"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"{{node outer/body/inner/Merge}}"))
|
"{{node outer/body/inner/Merge}}"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
"{{node outer/body/inner/Enter}}"))
|
"{{node outer/body/inner/Enter}}"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(), "{{node outer/Switch}}"))
|
absl::StrContains(status.error_message(), "{{node outer/Switch}}"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,9 +109,9 @@ TEST(ValidateControlFlowTest, MismatchedParentFrames) {
|
|||||||
status = BuildControlFlowInfo(graph.get(), &info);
|
status = BuildControlFlowInfo(graph.get(), &info);
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(), "Mismatched parent frames"))
|
absl::StrContains(status.error_message(), "Mismatched parent frames"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Enter2}}"))
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Enter2}}"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,14 +133,13 @@ TEST(ValidateControlFlowTest, TwoLoopCond) {
|
|||||||
std::vector<ControlFlowInfo> info;
|
std::vector<ControlFlowInfo> info;
|
||||||
Status status = BuildControlFlowInfo(graph.get(), &info);
|
Status status = BuildControlFlowInfo(graph.get(), &info);
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
EXPECT_TRUE(
|
||||||
"more than one LoopCond node"))
|
absl::StrContains(status.error_message(), "more than one LoopCond node"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
str_util::StrContains(status.error_message(), "{{node sub/LoopCond}}"))
|
absl::StrContains(status.error_message(), "{{node sub/LoopCond}}"))
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node LoopCond}}"))
|
||||||
str_util::StrContains(status.error_message(), "{{node LoopCond}}"))
|
|
||||||
<< status.error_message();
|
<< status.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -487,7 +487,7 @@ Status GraphConstructor::BuildNodeIndex() {
|
|||||||
bool in_control_dependence = false;
|
bool in_control_dependence = false;
|
||||||
for (int i = 0; i < node_def.input_size(); ++i) {
|
for (int i = 0; i < node_def.input_size(); ++i) {
|
||||||
StringPiece input_name = node_def.input(i);
|
StringPiece input_name = node_def.input(i);
|
||||||
if (!input_name.empty() && str_util::StartsWith(input_name, "^")) {
|
if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
|
||||||
in_control_dependence = true;
|
in_control_dependence = true;
|
||||||
} else if (in_control_dependence) {
|
} else if (in_control_dependence) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
@ -535,7 +535,7 @@ Status GraphConstructor::InitFromEdges() {
|
|||||||
bool has_loop_back_edge = false;
|
bool has_loop_back_edge = false;
|
||||||
for (int i = 0; i < node_def.input_size(); ++i) {
|
for (int i = 0; i < node_def.input_size(); ++i) {
|
||||||
StringPiece input_name(node_def.input(i));
|
StringPiece input_name(node_def.input(i));
|
||||||
if (str_util::StartsWith(input_name, "^")) {
|
if (absl::StartsWith(input_name, "^")) {
|
||||||
num_control_edges++;
|
num_control_edges++;
|
||||||
} else {
|
} else {
|
||||||
TensorId id(ParseTensorName(input_name));
|
TensorId id(ParseTensorName(input_name));
|
||||||
@ -585,7 +585,7 @@ Status GraphConstructor::ValidateColocationConstraints(
|
|||||||
if (iter == node_def.attr().end()) return Status::OK();
|
if (iter == node_def.attr().end()) return Status::OK();
|
||||||
for (const string& c : iter->second.list().s()) {
|
for (const string& c : iter->second.list().s()) {
|
||||||
StringPiece s(c);
|
StringPiece s(c);
|
||||||
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) &&
|
if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
|
||||||
gdef_nodes_.find(s) == gdef_nodes_.end()) {
|
gdef_nodes_.find(s) == gdef_nodes_.end()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Node '", node_def.name(),
|
"Node '", node_def.name(),
|
||||||
@ -824,7 +824,7 @@ void GraphConstructor::AddPrefixToNodeDef(
|
|||||||
// imported).
|
// imported).
|
||||||
if (input_already_exists[i]) continue;
|
if (input_already_exists[i]) continue;
|
||||||
StringPiece input(node_def->input(i));
|
StringPiece input(node_def->input(i));
|
||||||
if (str_util::ConsumePrefix(&input, "^")) {
|
if (absl::ConsumePrefix(&input, "^")) {
|
||||||
node_def->set_input(i, strings::StrCat("^", prefix_, input));
|
node_def->set_input(i, strings::StrCat("^", prefix_, input));
|
||||||
} else {
|
} else {
|
||||||
node_def->set_input(i, strings::StrCat(prefix_, input));
|
node_def->set_input(i, strings::StrCat(prefix_, input));
|
||||||
@ -836,7 +836,7 @@ void GraphConstructor::AddPrefixToNodeDef(
|
|||||||
node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
|
node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
|
||||||
for (int i = 0; i < list->s_size(); ++i) {
|
for (int i = 0; i < list->s_size(); ++i) {
|
||||||
StringPiece v(list->s(i));
|
StringPiece v(list->s(i));
|
||||||
if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) {
|
if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
|
||||||
list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
|
list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -879,7 +879,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
|
|||||||
bool updated = false;
|
bool updated = false;
|
||||||
for (int i = 0; i < coloc_values.size(); ++i) {
|
for (int i = 0; i < coloc_values.size(); ++i) {
|
||||||
StringPiece val(coloc_values[i]);
|
StringPiece val(coloc_values[i]);
|
||||||
if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
|
if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
|
||||||
auto name_pair = uniquified_names_.find(string(val));
|
auto name_pair = uniquified_names_.find(string(val));
|
||||||
if (name_pair == uniquified_names_.end()) continue;
|
if (name_pair == uniquified_names_.end()) continue;
|
||||||
updated = true;
|
updated = true;
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user