Cleanup: Don't crash when querying node for non-existing attributes.

PiperOrigin-RevId: 217420663
This commit is contained in:
A. Unique TensorFlower 2018-10-16 17:56:15 -07:00 committed by TensorFlower Gardener
parent cd1975be1e
commit 3f7d60ca9d
6 changed files with 101 additions and 36 deletions

View File

@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)

View File

@ -67,7 +67,8 @@ bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
return false;
}
if (node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
return false;
}
@ -158,14 +159,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
}
Status CheckAttrExists(const NodeDef& node, const string& key) {
if (node.attr().count(key) == 0) {
return errors::InvalidArgument("Node '", node.name(), "'lacks '", key,
"' attr: ", node.DebugString());
}
return Status::OK();
}
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@ -641,7 +634,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
CHECK(!inputs.empty()) << "Inputs must be non-empty";
// Do not create redundant AddN nodes
if (inputs.size() == 1) {
if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
return inputs[0];
}
@ -1450,10 +1443,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
bool IsSupported(const NodeDef* node) const override {
if (IsInPreserveSet(*node)) return false;
if (IsConcat(*node)) {
if (IsConcat(*node) && node->attr().count("N") != 0) {
const int n = node->attr().at("N").i();
return n > 1;
} else if (IsSplit(*node) || IsSplitV(*node)) {
} else if ((IsSplit(*node) || IsSplitV(*node)) &&
node->attr().count("num_split") != 0) {
const int num_split = node->attr().at("num_split").i();
if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
// TODO(rmlarsen): Remove this constraint when we have optimizations
@ -1556,6 +1550,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
if (node_is_concat_) {
// Handle concat nodes by looking backwards in the graph.
TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
const int n = node.attr().at("N").i();
const int start = node.op() == "Concat" ? 1 : 0;
const int end = start + n;
@ -2029,6 +2024,8 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
// Check that 'scale * weight' can be const folded.
TF_RETURN_IF_TRUE(!IsConstant(*scale));
TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype", "value"}));
TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
weights->attr().at("dtype").type());
@ -2803,6 +2800,7 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
DataType dtype = root->attr().at("T").type();
// Keep a trace of all supported input nodes that can be fused together.
@ -3023,10 +3021,9 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
const PartialTensorShape& pack_output_shape,
int pack_axis, int* slice_start_value, bool* found) {
*found = false;
for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
"shrink_axis_mask"}) {
TF_RETURN_IF_ERROR(CheckAttrExists(*node, key));
}
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
"new_axis_mask", "shrink_axis_mask"}));
const int begin_mask = node->attr().at("begin_mask").i();
const int end_mask = node->attr().at("end_mask").i();
@ -3056,14 +3053,14 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
Tensor slice_strides_t;
TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
return Status::OK();
}
TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
return Status::OK();
}
TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
if (!slice_strides_t.FromProto(
slice_strides->attr().at("value").tensor())) {
return Status::OK();

View File

@ -349,6 +349,9 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (IsReallyConstant(*array_size)) {
// Don't materialize 0 sizes to avoid triggering incorrect static
// checks. A 0 sized array that can't grow isn't useful anyway.
if (array_size->attr().count("value") == 0) {
continue;
}
const TensorProto& raw_val = array_size->attr().at("value").tensor();
if (raw_val.dtype() != DT_INT32) {
continue;
@ -454,6 +457,9 @@ bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
*min_id = std::min<int64>(*min_id, dim.size());
}
} else {
if (shape_node.attr().count("value") == 0) {
return false;
}
const TensorProto& raw_val = shape_node.attr().at("value").tensor();
if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
return false;
@ -552,6 +558,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
reduce_dims[0] = bcast.grad_x_reduce_idx();
reduce_dims[1] = bcast.grad_y_reduce_idx();
TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
const DataType type = node.attr().at("T").type();
NodeDef* out[2];
for (int j = 0; j < 2; ++j) {
@ -790,7 +797,8 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (is_const) {
// Don't fold strings constants for now since this causes problems with
// checkpointing.
if (input_node->attr().at("dtype").type() == DT_STRING) {
if (input_node->attr().count("dtype") == 0 ||
input_node->attr().at("dtype").type() == DT_STRING) {
return false;
}
// Special case: If a Merge node has at least one constant input that
@ -985,6 +993,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
strings::StrCat("Can't fold ", node.name(), ", its ", input,
" isn't constant"));
}
TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
const TensorProto& raw_val = input_node->attr().at("value").tensor();
Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
CHECK(value->FromProto(raw_val));
@ -1398,16 +1407,13 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
if (node.op() == "OnesLike") {
return true;
}
if (node.op() == "OnesLike") return true;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsOnes(*values);
}
if (node.op() != "Const") {
return false;
}
if (node.op() != "Const") return false;
if (node.attr().count("dtype") == 0) return false;
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
IS_ONES_CASE(DT_BOOL);
@ -1434,16 +1440,13 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
if (node.op() == "ZerosLike") {
return true;
}
if (node.op() == "ZerosLike") return true;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsZeros(*values);
}
if (!IsConstant(node)) {
return false;
}
if (!IsConstant(node)) return false;
if (node.attr().count("dtype") == 0) return false;
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
IS_ZEROS_CASE(DT_BOOL);
@ -1737,11 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph,
NodeDef* node) {
if (node->attr().count("num_split") == 0) return false;
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
return true;
}
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
@ -1918,6 +1921,8 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
NodeDef* node, bool* success) {
if (use_shape_info && IsStridedSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 4) {
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
if (node->attr().at("new_axis_mask").i() != 0 ||
node->attr().at("shrink_axis_mask").i() != 0) {
// Skip nodes with new/shrink axis mask, since they involve dimension
@ -1952,6 +1957,8 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
}
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
@ -2280,7 +2287,7 @@ bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
// Replace the reduction node with an identity node, that can be further
// optimized by the model pruner.
DataType output_type;
if (node->attr().count("T") > 0) {
if (node->attr().count("T") != 0) {
output_type = node->attr().at("T").type();
} else {
// This is an 'any' or 'all' reduction. The output is always boolean.
@ -2297,8 +2304,10 @@ bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
bool use_shape_info, NodeDef* node) {
if (!use_shape_info) return false;
if (!IsSimplifiableReshape(*node, properties)) return false;
if (!use_shape_info || node->attr().count("T") == 0 ||
!IsSimplifiableReshape(*node, properties)) {
return false;
}
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
node->clear_attr();
@ -2310,6 +2319,7 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
Status ConstantFolding::SimplifyArithmeticOperations(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) {
*success = false;
const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
@ -2354,6 +2364,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
// Replace 1 / y with Reciprocal op.
if (y_matches_output_shape && is_any_div && x_is_one) {
TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);

View File

@ -547,5 +547,20 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
#undef HANDLE_CASE
Status CheckAttrExists(const NodeDef& node, const string& key) {
if (node.attr().count(key) == 0) {
return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
"' attr: ", node.ShortDebugString());
}
return Status::OK();
}
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
for (const string& key : keys) {
TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
}
return Status::OK();
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -17,8 +17,12 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
#include <functional>
#include <iterator>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
@ -29,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
@ -244,6 +249,12 @@ int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);
// Returns an error if an attribute with the given key does not exist in node.
Status CheckAttrExists(const NodeDef& node, const string& key);
// Returns an error if attributes with the given keys do not exist in node.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);

View File

@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils.h"
#include <unistd.h>
#include <memory>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
@ -24,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
@ -350,6 +354,32 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
}
TEST(CheckAttrExists, All) {
NodeDef node;
node.set_name("node");
(*node.mutable_attr())["apple"].set_i(7);
(*node.mutable_attr())["pear"].set_b(true);
TF_EXPECT_OK(CheckAttrExists(node, "apple"));
TF_EXPECT_OK(CheckAttrExists(node, "pear"));
TF_EXPECT_OK(CheckAttrsExist(node, {}));
TF_EXPECT_OK(CheckAttrsExist(node, {"apple"}));
TF_EXPECT_OK(CheckAttrsExist(node, {"pear"}));
TF_EXPECT_OK(CheckAttrsExist(node, {"apple", "pear"}));
TF_EXPECT_OK(CheckAttrsExist(node, {"pear", "apple"}));
Status status = CheckAttrExists(node, "banana");
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.ToString(),
"Invalid argument: Node 'node' lacks 'banana' attr: name: \"node\" "
"attr { key: \"apple\" value { i: 7 } } attr { key: \"pear\" value "
"{ b: true } }");
EXPECT_FALSE(CheckAttrsExist(node, {""}).ok());
EXPECT_FALSE(CheckAttrsExist(node, {"pear", "cherry"}).ok());
EXPECT_FALSE(CheckAttrsExist(node, {"banana", "apple"}).ok());
}
TEST_F(UtilsTest, DeleteNodes) {
// TODO(rmlarsen): write forgotten test.
}