Avoid direct access to Node::def() where some other method works.
PiperOrigin-RevId: 163704839
commit 4ec29c5d95
parent 1560c55d2d
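For context, the pattern the hunks below converge on, as a minimal sketch (not part of this commit; DescribeNode and the "T" attribute name are illustrative only): read the op type and attributes through the Node accessors instead of going through the stored NodeDef. Node::type_string() replaces node->def().op(), and Node::attrs() returns an AttrSlice that supports Find(), range-for iteration, and the GetNodeAttr() helpers.

    #include "tensorflow/core/framework/node_def_util.h"
    #include "tensorflow/core/framework/types.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    // Hypothetical helper: inspect a node via its accessors only.
    Status DescribeNode(const Node* node) {
      const string& op = node->type_string();         // was: node->def().op()
      const AttrValue* v = node->attrs().Find("T");    // was: AttrSlice(node->def()).Find("T")
      DataType dtype;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));  // GetNodeAttr accepts an AttrSlice
      LOG(INFO) << op << ": T=" << DataTypeString(dtype)
                << (v != nullptr ? " (attr present)" : "");
      return Status::OK();
    }

    }  // namespace tensorflow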
@@ -1030,7 +1030,9 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
     NodeDef* ndef = gdef->add_node();
     ndef->set_name(NewName(n, pretty));
     ndef->set_op(n->type_string());
-    *(ndef->mutable_attr()) = n->def().attr();
+    for (const auto& attr : n->attrs()) {
+      (*ndef->mutable_attr())[attr.first] = attr.second;
+    }
     inputs.clear();
     inputs.resize(n->num_inputs());
     for (const Edge* e : n->in_edges()) {
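The attribute copy above can be factored into a small helper; a minimal sketch in the same setting as the sketch above, assuming only the AttrSlice iteration this hunk introduces (CopyAttrs is an illustrative name, not something this commit adds):

    // Copy every attribute of a Node into a NodeDef without reading n->def().
    void CopyAttrs(const Node* n, NodeDef* ndef) {
      for (const auto& attr : n->attrs()) {  // AttrSlice iterates name -> AttrValue pairs
        (*ndef->mutable_attr())[attr.first] = attr.second;
      }
    }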
@@ -147,7 +147,7 @@ class ColocationGraph {
     // attribute with the calls to ColocateNodeToGroup.
     bool found_spec = false;
     const AttrValue* attr_value =
-        AttrSlice(node->def()).Find(kColocationAttrNameStringPiece);
+        node->attrs().Find(kColocationAttrNameStringPiece);
     if (attr_value != nullptr && attr_value->has_list()) {
       for (const string& class_spec : attr_value->list().s()) {
         StringPiece spec(class_spec);
@@ -184,7 +184,7 @@ class ColocationGraph {
       // error, return it.
       Status s = ColocateNodes(*node, *root_node);
       if (!s.ok()) {
-        return AttachDef(s, node->def());
+        return AttachDef(s, *node);
       }
     }
     return Status::OK();
@@ -418,7 +418,7 @@ class ColocationGraph {
       }
       Status status = InitializeMember(*node, &members_[node->id()]);
       if (!status.ok()) {
-        return AttachDef(status, node->def());
+        return AttachDef(status, *node);
       }
     }
     return Status::OK();
@@ -727,7 +727,7 @@ Status SimplePlacer::Run() {
               "be on the same device), but the two nodes "
               "were assigned two different devices: ",
               status.error_message()),
-          dst->def());
+          *dst);
       }
     }
   }
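The three AttachDef call sites above rely on the overload that takes the Node itself rather than its NodeDef. A minimal sketch of the resulting error-return shape, in the same setting as the first sketch (ValidateHasInputs and its check are illustrative only):

    // Annotate a failing Status with the offending node; no NodeDef access needed.
    Status ValidateHasInputs(const Node* node) {
      if (node->num_inputs() == 0) {
        Status s = errors::InvalidArgument("node has no inputs");
        return AttachDef(s, *node);  // overload taking const Node&, as used above
      }
      return Status::OK();
    }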
@@ -110,7 +110,9 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
   AddNodeAttr("Tout", n->input_types(), &ndef);
   NameAttrList func;
   func.set_name(n->type_string());
-  *(func.mutable_attr()) = n->def().attr();
+  for (const auto& attr : n->attrs()) {
+    (*func.mutable_attr())[attr.first] = attr.second;
+  }
   AddNodeAttr("f", func, &ndef);
   Status s;
   Node* ret = g->AddNode(ndef, &s);
@@ -189,8 +189,9 @@ bool OptimizerCSE::Optimize(
     if (!n->IsOp()) continue;

     // Don't prune placeholder nodes.
-    if (n->def().op() == "Placeholder" || n->def().op() == "PlaceholderV2" ||
-        n->def().op() == "PlaceholderWithDefault") {
+    if (n->type_string() == "Placeholder" ||
+        n->type_string() == "PlaceholderV2" ||
+        n->type_string() == "PlaceholderWithDefault") {
       continue;
     }

@@ -376,7 +376,7 @@ Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
   std::vector<DataType> data_types;
   std::vector<TensorShape> shapes;
   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
-      original_input_node->def(), &data_types, &shapes);
+      original_input_node->attrs(), &data_types, &shapes);
   if (status.ok()) {
     created_node->AddAttr(
         RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types);
@@ -579,7 +579,7 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
 }

 bool GraphTransferer::NeedsToAddRank(const Node& node) {
-  const string& op_type = node.def().op();
+  const StringPiece op_type(node.type_string());
   if (op_type == "Transpose" || op_type == "ExpandDims") {
     return true;
   }
@@ -587,7 +587,7 @@ bool GraphTransferer::NeedsToAddRank(const Node& node) {
 }

 bool GraphTransferer::IsPadNode(const Node& node) {
-  const string& op_type = node.def().op();
+  const StringPiece op_type(node.type_string());
   if (op_type == "Pad") {
     return true;
   }
@@ -678,7 +678,7 @@ void GraphTransferer::RegisterNodeWithRank(
   CHECK_NOTNULL(input0_node);
   std::vector<TensorShape> shapes;
   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
-      input0_node->def(), nullptr, &shapes);
+      input0_node->attrs(), nullptr, &shapes);
   CHECK_EQ(1, shapes.size()) << "Output size should be 1.";
   const int const_val_id =
       RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs());
@@ -728,7 +728,7 @@ void GraphTransferer::RegisterPadNode(
   CHECK(input_node->IsConstant());

   const TensorProto* tensor_proto = nullptr;
-  TF_CHECK_OK(GetNodeAttr(input_node->def(), "value", &tensor_proto));
+  TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &tensor_proto));
   CHECK_NOTNULL(tensor_proto);
   Tensor const_tensor;
   TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor));
@@ -739,7 +739,7 @@ void GraphTransferer::RegisterPadNode(
   } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) {
     const int width = const_tensor.shape().dim_size(0);
     const TensorProto* proto = nullptr;
-    TF_CHECK_OK(GetNodeAttr(input_node->def(), "value", &proto));
+    TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &proto));
     Tensor const_tensor;
     TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
     CHECK_EQ(DT_INT32, const_tensor.dtype());
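Because GetNodeAttr has overloads that accept an AttrSlice, passing node->attrs() instead of node->def() is a drop-in change at the call sites above. A minimal sketch reading a constant node's tensor the same way, in the same setting as the first sketch (GetConstTensor is an illustrative name; Tensor::FromProto is used here in place of the GraphTransferer-local MakeTensorFromProto helper):

    #include "tensorflow/core/framework/tensor.h"

    // Illustrative helper: extract the tensor held by a Const-like node.
    Status GetConstTensor(const Node* const_node, Tensor* out) {
      const TensorProto* proto = nullptr;
      TF_RETURN_IF_ERROR(GetNodeAttr(const_node->attrs(), "value", &proto));
      if (!out->FromProto(*proto)) {
        return errors::InvalidArgument("could not parse \"value\" attr");
      }
      return Status::OK();
    }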