Use GetNodeAttrSimple() when it is possible that the attr is not present.

In the Status-returning GetNodeAttr(), constructing an `errors::NotFound()` when the attr is not present involves expensive string concatenation.

Additionally, change GetNodeAttr() to GetNodeAttrString() on hot codepaths (e.g. `Executor::PropagateOutputs()`) to avoid copying a string on each call, and add overloads of GetNodeAttrSimple() that enable accessing const-pointers to non-POD types in the AttrValue proto without copying them.

PiperOrigin-RevId: 261141528
This commit is contained in:
Derek Murray 2019-08-01 09:59:07 -07:00 committed by TensorFlower Gardener
parent 01ebab4044
commit 6a42e239dc
27 changed files with 222 additions and 77 deletions

View File

@ -272,7 +272,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
std::unordered_set<string> current_constraints(colocation_constraints_);
const AttrSlice attrs = colocate_with_op.node()->attrs();
std::vector<string> node_constraints;
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
if (GetNodeAttrSimple(attrs, kColocationAttrName, &node_constraints)) {
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {

View File

@ -1317,7 +1317,7 @@ Status EncapsulateSubgraphsPass::Run(
// Returns true iff `node` carries the kXlaCompiledKernelAttr attribute with
// value true, i.e. the node is a kernel produced by XLA compilation.
bool IsXlaCompiledKernel(const Node& node) {
  bool is_compiled = false;
  // GetNodeAttrSimple() returns false when the attr is absent, avoiding the
  // cost of constructing a NotFound status on the common path.
  bool has_compilation_attr =
      GetNodeAttrSimple(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
      is_compiled;
  return has_compilation_attr ? is_compiled : false;
}

View File

@ -245,8 +245,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// while iterating.
std::vector<Node*> launch_nodes;
for (Node* n : graph->nodes()) {
string name;
if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
if (!name.empty()) {
launch_nodes.push_back(n);
}
}

View File

@ -913,10 +913,9 @@ xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
for (Node* n : g.op_nodes()) {
bool is_lifted_arg;
string outside_compilation_attr;
if (GetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg).ok() &&
GetNodeAttr(n->def(), "_xla_outside_compilation",
&outside_compilation_attr)
.ok()) {
if (GetNodeAttrSimple(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
GetNodeAttrSimple(n->def(), "_xla_outside_compilation",
&outside_compilation_attr)) {
TF_RET_CHECK(is_lifted_arg);
TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
outside_compilation_attr_to_node[outside_compilation_attr] = n;

View File

@ -677,7 +677,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
}
DataType dtype;
if (!GetNodeAttr(n->def(), "dtype", &dtype).ok() ||
if (!GetNodeAttrSimple(n->def(), "dtype", &dtype) ||
!DataTypeIsInteger(dtype)) {
return false;
}
@ -695,7 +695,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
}
const TensorProto* proto = nullptr;
if (!GetNodeAttr(const_input->def(), "value", &proto).ok()) {
if (!GetNodeAttrSimple(const_input->def(), "value", &proto)) {
return false;
}
@ -935,8 +935,8 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
return absl::nullopt;
}
string scope;
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
if (!scope.empty()) {
return scope;
}
@ -999,7 +999,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
bool is_xla_compile_attr_true = false;
bool xla_compile_attr;
if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) {
if (GetNodeAttrSimple(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
is_xla_compile_attr_true |= xla_compile_attr;
}

View File

@ -52,7 +52,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
string cluster;
if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
if (GetNodeAttrSimple(node->attrs(), kXlaClusterAttr, &cluster)) {
CHECK(!cluster.empty());
ids[node->name()] = cluster;
}

View File

@ -135,7 +135,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
if (constant_value) {
const TensorProto* proto = nullptr;
if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
if (!GetNodeAttrSimple(node->def(), "value", &proto)) {
if (listener->IsInterested()) {
*listener << "\ncould not find \"value\" attribute in node";
}

View File

@ -85,9 +85,9 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass {
// With `parallel_iterations == 1` it's safe to use TemporaryVariable.
if (is_in_while_loop) {
int parallel_iterations;
Status s = GetNodeAttr(frame->attrs(), kParallelIterationsAttrName,
&parallel_iterations);
if (s.ok() && parallel_iterations == 1) {
bool found = GetNodeAttrSimple(
frame->attrs(), kParallelIterationsAttrName, &parallel_iterations);
if (found && parallel_iterations == 1) {
is_in_while_loop = false;
}
}
@ -112,8 +112,8 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass {
// The pieces of AccumulateNV2 should all be on the same node.
node_builder.Device(n->requested_device());
string colo;
if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
if (!colo.empty()) {
node_builder.Attr(kColocationAttrName, colo);
}
return node_builder;
@ -261,8 +261,8 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass {
.Attr("T", dtype)
.Input(data_inputs)
.ControlInputs(control_inputs);
string colo;
if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
if (!colo.empty()) {
builder.Attr(kColocationAttrName, colo);
}
TF_RETURN_IF_ERROR(builder.Finalize(g, &add_n_node));

View File

@ -2548,9 +2548,9 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
const Node* node,
FrameState** child) {
// Get the child frame name.
string enter_name;
Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
DCHECK(s.ok()) << s;
const string& enter_name = GetNodeAttrString(node->attrs(), "frame_name");
DCHECK(!enter_name.empty())
<< "Could not find \"frame_name\" attr in node " << node->name();
const string child_name = MakeFrameName(frame, iter, enter_name);
{
@ -2567,8 +2567,10 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
if (vlog_) VLOG(2) << "Create frame: " << child_name;
int parallel_iters;
s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
DCHECK(s.ok()) << s;
bool found_parallel_iters =
GetNodeAttrSimple(node->attrs(), "parallel_iterations", &parallel_iters);
DCHECK(found_parallel_iters)
<< "Could not find \"parallel_iterations\" attr in node " << node->name();
FrameState* temp = new FrameState(impl_, parallel_iters);
temp->frame_name = child_name;
temp->frame_id = Hash64(child_name);

View File

@ -1654,7 +1654,7 @@ namespace {
Status ValidateNoInline(const FunctionBody* fbody) {
const auto attr = AttrSlice(&fbody->fdef.attr());
bool noinline = false;
if (GetNodeAttr(attr, kNoInlineAttr, &noinline).ok() && noinline) {
if (GetNodeAttrSimple(attr, kNoInlineAttr, &noinline) && noinline) {
return errors::InvalidArgument(
"Can't inline function marked with '_noinline'");
}

View File

@ -466,8 +466,8 @@ Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
// All the node types handled here have their output datatype set in
// either attribute 'dtype' or 'T'.
if (!GetNodeAttr(node, "dtype", type).ok() &&
!GetNodeAttr(node, "T", type).ok()) {
if (!GetNodeAttrSimple(node, "dtype", type) &&
!GetNodeAttrSimple(node, "T", type)) {
return errors::InvalidArgument(
"Could not determine output type for feed node: ", node.name(),
" of type ", node.op());

View File

@ -33,8 +33,9 @@ bool LowerAsMultiDeviceFunction(const Node* n) {
if (n->IsPartitionedCall()) return true;
bool match;
Status s = GetNodeAttr(n->attrs(), kLowerAsMultiDeviceFunctionAttr, &match);
return s.ok() && match;
bool found =
GetNodeAttrSimple(n->attrs(), kLowerAsMultiDeviceFunctionAttr, &match);
return found && match;
}
} // namespace

View File

@ -40,15 +40,15 @@ constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
// Checks if boolean attribute is defined and its value is 'true'.
bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
  bool match;
  // GetNodeAttrSimple() returns false when the attr is absent or has the
  // wrong type, without constructing an expensive error status.
  bool found = GetNodeAttrSimple(n->attrs(), attr_name, &match);
  return found && match;
}
// Checks if string attribute is defined and it's not empty.
bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
  string match;
  // GetNodeAttrSimple() returns false when the attr is absent or has the
  // wrong type, without constructing an expensive error status.
  bool found = GetNodeAttrSimple(n->attrs(), attr_name, &match);
  return found && !match.empty();
}
bool LowerUsingSwitchMergeIsOn(const Node* n) {

View File

@ -55,8 +55,8 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
NodeDebugInfo debug_info(*n);
NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);
node_builder.Device(n->requested_device());
string colo;
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
const string& colo = GetNodeAttrString(n_attrs, "_class");
if (!colo.empty()) {
node_builder.Attr("_class", colo);
}
return node_builder;

View File

@ -492,6 +492,13 @@ void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
}
}
// Moves the strings in `value` into `out` as a list(string) attr value,
// avoiding a copy of each string. `value` is left in a moved-from state.
void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
  auto* list = out->mutable_list();
  list->Clear();  // Ensure list() is set even if `value` is empty.
  for (auto& s : value) {
    list->add_s(std::move(s));
  }
}
void SetAttrValue(const TensorShape& value, AttrValue* out) {
value.AsProto(out->mutable_shape());
}

View File

@ -87,6 +87,8 @@ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
void SetAttrValue(const AttrValue& value, AttrValue* out);
void MoveAttrValue(std::vector<string>&& value, AttrValue* out);
// Returns true if a and b have the same value.
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);

View File

@ -621,7 +621,7 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
strings::StrAppend(&out, "\n(");
auto get_type_and_device = [](const NodeDef& n) {
DataType dt;
if (!GetNodeAttr(n, "T", &dt).ok()) {
if (!GetNodeAttrSimple(n, "T", &dt)) {
dt = DT_INVALID;
}
if (!n.device().empty()) {
@ -1389,7 +1389,7 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
// If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
// Foo's attributes.
const NameAttrList* forward_func_attrs;
if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
if (!GetNodeAttrSimple(ndef, kFuncAttr, &forward_func_attrs)) {
return nullptr;
}
const string& func_name = forward_func_attrs->name();
@ -1434,7 +1434,7 @@ template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
const string& attr, T* value) const {
const FunctionDef* fdef = GetAttrImpl(ndef);
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
if (fdef && GetNodeAttrSimple(AttrSlice(&fdef->attr()), attr, value)) {
return Status::OK();
}
return errors::InvalidArgument("Attr ", attr, " is not defined.");

View File

@ -156,14 +156,14 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
}
std::vector<int32> hostmem_attr;
if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
if (GetNodeAttrSimple(ndef, "_input_hostmem", &hostmem_attr)) {
for (int32 i : hostmem_attr) {
if (0 <= i && i < inp_mtypes->size()) {
(*inp_mtypes)[i] = HOST_MEMORY;
}
}
}
if (GetNodeAttr(ndef, "_output_hostmem", &hostmem_attr).ok()) {
if (GetNodeAttrSimple(ndef, "_output_hostmem", &hostmem_attr)) {
for (int32 i : hostmem_attr) {
if (0 <= i && i < out_mtypes->size()) {
(*out_mtypes)[i] = HOST_MEMORY;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/op.h"
@ -243,6 +244,7 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
const AttrValue* attr_value; \
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
value->reserve(attr_value->list().FIELD().size()); \
for (const auto& v : attr_value->list().FIELD()) { \
__VA_ARGS__; \
value->APPEND_OP(CAST); \
@ -276,6 +278,7 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
if (!s.ok()) { \
return false; \
} \
value->reserve(attr_value->list().FIELD().size()); \
for (const auto& v : attr_value->list().FIELD()) { \
__VA_ARGS__; \
value->APPEND_OP(CAST); \
@ -286,22 +289,50 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
DEFINE_GET_ATTR_SIMPLE(int64, i, "int", emplace_back, v, ;)
DEFINE_GET_ATTR(
int32, i, "int", emplace_back, static_cast<int32>(v),
if (static_cast<int64>(static_cast<int32>(v)) != v) {
return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
" out of range for an int32");
})
DEFINE_GET_ATTR_SIMPLE(
int32, i, "int", emplace_back, static_cast<int32>(v),
if (static_cast<int64>(static_cast<int32>(v)) != v) {
static int log_counter = 0;
if (log_counter < 10) {
log_counter++;
LOG(WARNING) << "Attr " << attr_name << " has value " << v
<< " out of range for an int32";
}
return false;
})
DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
DEFINE_GET_ATTR_SIMPLE(float, f, "float", emplace_back, v, ;)
// std::vector<bool> specialization does not have emplace_back until
// c++14, so we have to use push_back (see
// http://en.cppreference.com/w/cpp/container/vector/emplace_back)
DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
DEFINE_GET_ATTR_SIMPLE(bool, b, "bool", push_back, v, ;)
DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
;)
DEFINE_GET_ATTR_SIMPLE(DataType, type, "type", emplace_back,
static_cast<DataType>(v),
;)
DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
DEFINE_GET_ATTR_SIMPLE(
TensorShape, shape, "shape", emplace_back, TensorShape(v),
if (!TensorShape::IsValidShape(v).ok()) {
static int log_counter = 0;
if (log_counter < 10) {
log_counter++;
LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
<< v.DebugString();
}
return false;
})
DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
PartialTensorShape(v),
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
@ -332,6 +363,40 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
return attr_value->s();
}
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<const string*>* value) {
const AttrValue* attr_value = attrs.Find(attr_name);
if (attr_value == nullptr) {
return false;
}
Status s = AttrValueHasType(*attr_value, "list(string)");
if (!s.ok()) {
return false;
}
value->reserve(attr_value->list().s().size());
for (const auto& v : attr_value->list().s()) {
value->push_back(&v);
}
return true;
}
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<const TensorShapeProto*>* value) {
const AttrValue* attr_value = attrs.Find(attr_name);
if (attr_value == nullptr) {
return false;
}
Status s = AttrValueHasType(*attr_value, "list(shape)");
if (!s.ok()) {
return false;
}
value->reserve(attr_value->list().shape().size());
for (const auto& v : attr_value->list().shape()) {
value->push_back(&v);
}
return true;
}
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value) {
const AttrValue* attr_value;
@ -352,6 +417,20 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
return Status::OK();
}
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value) {
const AttrValue* attr_value = attrs.Find(attr_name);
if (attr_value == nullptr) {
return false;
}
Status s = AttrValueHasType(*attr_value, "tensor");
if (!s.ok()) {
return false;
}
*value = &attr_value->tensor();
return true;
}
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value) {
const AttrValue* attr_value;
@ -361,6 +440,20 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
return Status::OK();
}
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value) {
const AttrValue* attr_value = attrs.Find(attr_name);
if (attr_value == nullptr) {
return false;
}
Status s = AttrValueHasType(*attr_value, "func");
if (!s.ok()) {
return false;
}
*value = &attr_value->func();
return true;
}
namespace { // Helper for InOutTypesForNode().
template <class NodeDefOrAttrSlice>

View File

@ -235,11 +235,15 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
// REQUIRES: Must not use *value beyond the lifetime of node_def.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
// This version avoids copying the NameAttrList.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
// These versions copy the NameAttrList(s).
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
@ -253,7 +257,41 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type: "string"
int64* value); // type: "int"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type: "int"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
int32* value); // type: "int"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
float* value); // type: "float"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
bool* value); // type: "bool"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
DataType* value); // type: "type"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
TensorShape* value); // type: "shape"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type: "list(string)"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type: "list(int)"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<float>* value); // type: "list(float)"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<bool>* value); // type: "list(bool)"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<DataType>* value); // type: "list(type)"
bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShape> value); // type: "shape"
// Overloads of GetNodeAttrSimple() that avoid copying the non-POD attribute
// values.
bool GetNodeAttrSimple(
const AttrSlice& attrs, StringPiece attr_name,
std::vector<const string*>* value); // type: "list(string)"
bool GetNodeAttrSimple(
const AttrSlice& attrs, StringPiece attr_name,
std::vector<const TensorShapeProto*>* value); // type: "list(shape)"
// Look up the attr with name attr_name and return a reference to its value.
// If no attr with attr_name is found in node_def, or the attr does not have

View File

@ -189,6 +189,11 @@ class Node {
UpdateProperties();
}
// Sets the attr `name` on this node, moving the strings out of `val` so
// they are not copied into the AttrValue.
void AddAttr(const string& name, std::vector<string>&& val) {
  auto* attr = AddAttrHelper(name);
  MoveAttrValue(std::move(val), attr);
  UpdateProperties();
}
void ClearAttr(const string& name);
// Returns into '*e' the edge connecting to the 'idx' input of this Node.

View File

@ -728,9 +728,9 @@ Status GraphConstructor::ValidateShape(Node* node) {
if (!opts_.importing || !opts_.validate_shape) return Status::OK();
TF_RETURN_IF_ERROR(refiner_->AddNode(node));
// For nodes with the _output_shapes attribute, override the shape.
std::vector<TensorShapeProto> shape_attrs;
std::vector<const TensorShapeProto*> shape_attrs;
const char* kAttrName = "_output_shapes";
if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
if (!GetNodeAttrSimple(node->attrs(), kAttrName, &shape_attrs)) {
// No _output_shapes attribute, the AddNode call above was sufficient.
return Status::OK();
}
@ -753,7 +753,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
<< " outputs. Output shapes may be inaccurate.";
}
for (int i = 0; i < node->num_outputs(); ++i) {
const TensorShapeProto& p = shape_attrs[i];
const TensorShapeProto& p = *shape_attrs[i];
shape_inference::ShapeHandle h;
Status s = ic->MakeShapeFromShapeProto(p, &h);
if (!s.ok()) {
@ -772,7 +772,6 @@ Status GraphConstructor::ValidateShape(Node* node) {
// This is an escape hatch that allows us to correct shape
// functions that are not critical to correct execution but
// would cause graphs to fail if imported after correcting.
//
const string& op = node->type_string();
const std::vector<string> whitelist = {
// To be removed after 2017/03/08.
@ -991,11 +990,10 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
Node* node = pair.second.node;
if (node == nullptr) continue;
std::vector<string> coloc_values;
Status status =
GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
if (!status.ok()) continue;
if (!GetNodeAttrSimple(node->attrs(), kColocationAttrName, &coloc_values))
continue;
bool updated = false;
for (int i = 0; i < coloc_values.size(); ++i) {
for (size_t i = 0; i < coloc_values.size(); ++i) {
StringPiece val(coloc_values[i]);
if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
auto name_pair = uniquified_names_.find(string(val));
@ -1006,7 +1004,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
}
}
if (updated) {
node->AddAttr(kColocationAttrName, coloc_values);
node->AddAttr(kColocationAttrName, std::move(coloc_values));
}
}
}

View File

@ -947,13 +947,13 @@ void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
// Not related to send/recv.
return;
}
string send_device;
if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
const string& send_device = GetNodeAttrString(*ndef, "send_device");
if (send_device.empty()) {
// No known send_device. The runtime will detect it later.
return;
}
int64 incarnation = PartitionOptions::kIllegalIncarnation;
if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
if (!GetNodeAttrSimple(*ndef, "send_device_incarnation", &incarnation) ||
(incarnation == PartitionOptions::kIllegalIncarnation)) {
incarnation = opts.get_incarnation(send_device);
SetAttrValue(incarnation,

View File

@ -1479,7 +1479,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
DCHECK(n);
float alpha;
bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
bool has_attr = GetNodeAttrSimple(n->def(), "alpha", &alpha);
DCHECK(has_attr);
// If the alpha of LeakyRelu is less than 1, rewrite the node.
@ -1542,7 +1542,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
// it includes those we support.
DataType T;
if (!GetNodeAttr(n->def(), "T", &T).ok() ||
if (!GetNodeAttrSimple(n->def(), "T", &T) ||
!mkl_op_registry::IsMklLayoutDependentOp(csinfo_.mkl_fused_conv2d, T)) {
return false;
}
@ -1932,7 +1932,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
// If this is an MKL op, then it will create extra output for MKL layout.
DataType T;
if (GetNodeAttr(n->def(), "T", &T).ok() &&
if (GetNodeAttrSimple(n->def(), "T", &T) &&
mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
// If this is an MKL op, then it will generate an edge that will receive
// Mkl tensor from a node.
@ -3428,13 +3428,13 @@ MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
DataType Tinput, Tfilter;
bool type_attrs_present = false;
if (GetNodeAttr(n->def(), "Tinput", &Tinput).ok() &&
GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok() &&
if (GetNodeAttrSimple(n->def(), "Tinput", &Tinput) &&
GetNodeAttrSimple(n->def(), "Tfilter", &Tfilter) &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
type_attrs_present = true;
} else if (GetNodeAttr(n->def(), "T1", &T1).ok() &&
GetNodeAttr(n->def(), "T2", &T2).ok() &&
} else if (GetNodeAttrSimple(n->def(), "T1", &T1) &&
GetNodeAttrSimple(n->def(), "T2", &T2) &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
type_attrs_present = true;
@ -3465,7 +3465,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
// MklRelu if type is INT32.
DataType T;
if (!GetNodeAttr(n->def(), "T", &T).ok()) {
if (!GetNodeAttrSimple(n->def(), "T", &T)) {
return nullptr;
}
@ -3721,7 +3721,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
// If graph node is not Mkl node, then return.
DataType T = DT_INVALID;
if (!GetNodeAttr(n->def(), "T", &T).ok() ||
if (!GetNodeAttrSimple(n->def(), "T", &T) ||
!mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
return result;
}
@ -3746,7 +3746,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
// Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
// node, then we don't need to do anything.
Node* e_src = e->src();
if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
if (GetNodeAttrSimple(e_src->def(), "T", &T) &&
mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
// Source node for edge 'e' is Mkl node.
// Destination node and destination input slot of e is node 'n' and 'idx'

View File

@ -132,7 +132,7 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
// Do not remove ops with attribute _grappler_do_not_remove. This is useful
// for debugging.
bool do_not_remove;
if (GetNodeAttr(attrs, "_grappler_do_not_remove", &do_not_remove).ok() &&
if (GetNodeAttrSimple(attrs, "_grappler_do_not_remove", &do_not_remove) &&
do_not_remove) {
result.insert(node.name());
}

View File

@ -146,7 +146,7 @@ class FakeDevice : public Device {
// Returns true iff `fdef` carries the kNoSpecializeAttr attribute with value
// true.
bool MarkedNoSpecialize(const FunctionDef& fdef) {
  const auto attr = AttrSlice(&fdef.attr());
  bool nospecialize = false;
  // GetNodeAttrSimple() avoids constructing a NotFound status when the attr
  // is absent (the common case).
  return GetNodeAttrSimple(attr, kNoSpecializeAttr, &nospecialize) &&
         nospecialize;
}
@ -787,15 +787,14 @@ using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
// Checks if boolean attribute is defined and its value is 'true'.
bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
  bool match;
  // GetNodeAttrSimple() returns false when the attr is absent or has the
  // wrong type, without constructing an expensive error status.
  bool found = GetNodeAttrSimple(n->attrs(), attr_name, &match);
  return found && match;
}
// Checks if string attribute is defined and it's not empty.
bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
  // GetNodeAttrString() returns a const reference to the stored value (or to
  // an empty string if the attr is absent), avoiding a copy on each call.
  const string& value = GetNodeAttrString(n->attrs(), attr_name);
  return !value.empty();
}
bool LowerUsingSwitchMergeIsOn(const Node* n) {

View File

@ -456,7 +456,8 @@ bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx, int node_index,
// Squeeze must not squeeze output channel dimension.
std::vector<int32> dims;
if (!GetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims).ok()) return false;
if (!GetNodeAttrSimple(*squeeze_node_def, "squeeze_dims", &dims))
return false;
for (auto dim : dims) {
if (dim == 3) return false;
}
@ -531,7 +532,7 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
// We successfully found a Conv2D+FusedBatchNorm pattern.
matched->contraction = conv2d_node_view->node_index();
matched->fused_batch_norm = node_index;
if (!GetNodeAttr(*node_def, "epsilon", &matched->epsilon).ok()) return false;
if (!GetNodeAttrSimple(*node_def, "epsilon", &matched->epsilon)) return false;
return true;
}
@ -684,7 +685,7 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
// Check that the node is in inference mode.
bool is_training = true;
if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok()) return false;
if (!GetNodeAttrSimple(*node_def, kIsTraining, &is_training)) return false;
if (is_training) return false;
const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
@ -1477,7 +1478,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
bool is_training = true;
if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok()) return false;
if (!GetNodeAttrSimple(*node_def, kIsTraining, &is_training)) return false;
if (is_training) return false;
return true;