Cleanup: Don't crash when querying node for non-existing attributes.
PiperOrigin-RevId: 217420663
This commit is contained in:
parent cd1975be1e
commit 3f7d60ca9d
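
The change routes attribute lookups through two small helpers, CheckAttrExists and CheckAttrsExist (added to grappler/utils in this commit), so optimizer stages return an InvalidArgument status instead of crashing when a node lacks an expected attribute. Below is a minimal sketch of the calling pattern; the GetOutputType helper is hypothetical and not part of this commit.

// Hypothetical example of the defensive pattern applied throughout this change.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

// GetOutputType is an illustrative helper, not part of the commit.
Status GetOutputType(const NodeDef& node, DataType* type) {
  // Unsafe: node.attr().at("T") aborts if the "T" attribute is missing.
  // Safe: validate first and surface an error the caller can handle.
  TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
  *type = node.attr().at("T").type();
  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow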
@@ -26,6 +26,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:span",
     ],
 )
 

@@ -67,7 +67,8 @@ bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
     return false;
   }
 
-  if (node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
+  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
+      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
     return false;
   }
 
@@ -158,14 +159,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
   SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
 }
 
-Status CheckAttrExists(const NodeDef& node, const string& key) {
-  if (node.attr().count(key) == 0) {
-    return errors::InvalidArgument("Node '", node.name(), "'lacks '", key,
-                                   "' attr: ", node.DebugString());
-  }
-  return Status::OK();
-}
-
 NodeDef* GetTailOfValuePreservingChain(
     const NodeDef& node, const NodeMap& node_map,
     const std::unordered_set<string>& nodes_to_preserve) {
@@ -641,7 +634,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
     CHECK(!inputs.empty()) << "Inputs must be non-empty";
 
     // Do not create redundant AddN nodes
-    if (inputs.size() == 1) {
+    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
       return inputs[0];
     }
 
@@ -1450,10 +1443,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 
   bool IsSupported(const NodeDef* node) const override {
     if (IsInPreserveSet(*node)) return false;
-    if (IsConcat(*node)) {
+    if (IsConcat(*node) && node->attr().count("N") != 0) {
       const int n = node->attr().at("N").i();
       return n > 1;
-    } else if (IsSplit(*node) || IsSplitV(*node)) {
+    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
+               node->attr().count("num_split") != 0) {
       const int num_split = node->attr().at("num_split").i();
       if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
         // TODO(rmlarsen): Remove this constraint when we have optimizations
@@ -1556,6 +1550,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
   Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
     if (node_is_concat_) {
       // Handle concat nodes by looking backwards in the graph.
+      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
       const int n = node.attr().at("N").i();
       const int start = node.op() == "Concat" ? 1 : 0;
       const int end = start + n;
@@ -2029,6 +2024,8 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
 
     // Check that 'scale * weight' can be const folded.
     TF_RETURN_IF_TRUE(!IsConstant(*scale));
+    TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype", "value"}));
+    TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
     TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
                       weights->attr().at("dtype").type());
 
@@ -2803,6 +2800,7 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
   }
 
   Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
+    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
     DataType dtype = root->attr().at("T").type();
 
     // Keep a trace of all supported input nodes that can be fused together.
@@ -3023,10 +3021,9 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
                       const PartialTensorShape& pack_output_shape,
                       int pack_axis, int* slice_start_value, bool* found) {
     *found = false;
-    for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
-                     "shrink_axis_mask"}) {
-      TF_RETURN_IF_ERROR(CheckAttrExists(*node, key));
-    }
+    TF_RETURN_IF_ERROR(
+        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
+                                "new_axis_mask", "shrink_axis_mask"}));
 
     const int begin_mask = node->attr().at("begin_mask").i();
     const int end_mask = node->attr().at("end_mask").i();
@@ -3056,14 +3053,14 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
     Tensor slice_strides_t;
 
     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
-    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
-
     if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
       return Status::OK();
     }
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
     if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
       return Status::OK();
     }
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
     if (!slice_strides_t.FromProto(
             slice_strides->attr().at("value").tensor())) {
       return Status::OK();

@@ -349,6 +349,9 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
       if (IsReallyConstant(*array_size)) {
         // Don't materialize 0 sizes to avoid triggering incorrect static
         // checks. A 0 sized array that can't grow isn't useful anyway.
+        if (array_size->attr().count("value") == 0) {
+          continue;
+        }
         const TensorProto& raw_val = array_size->attr().at("value").tensor();
         if (raw_val.dtype() != DT_INT32) {
           continue;
@@ -454,6 +457,9 @@ bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
       *min_id = std::min<int64>(*min_id, dim.size());
     }
   } else {
+    if (shape_node.attr().count("value") == 0) {
+      return false;
+    }
     const TensorProto& raw_val = shape_node.attr().at("value").tensor();
     if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
       return false;
@@ -552,6 +558,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
   reduce_dims[0] = bcast.grad_x_reduce_idx();
   reduce_dims[1] = bcast.grad_y_reduce_idx();
 
+  TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
   const DataType type = node.attr().at("T").type();
   NodeDef* out[2];
   for (int j = 0; j < 2; ++j) {
@@ -790,7 +797,8 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
     if (is_const) {
       // Don't fold strings constants for now since this causes problems with
       // checkpointing.
-      if (input_node->attr().at("dtype").type() == DT_STRING) {
+      if (input_node->attr().count("dtype") == 0 ||
+          input_node->attr().at("dtype").type() == DT_STRING) {
         return false;
       }
       // Special case: If a Merge node has at least one constant input that
@@ -985,6 +993,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
           strings::StrCat("Can't fold ", node.name(), ", its ", input,
                           " isn't constant"));
     }
+    TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
     const TensorProto& raw_val = input_node->attr().at("value").tensor();
     Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
     CHECK(value->FromProto(raw_val));
@@ -1398,16 +1407,13 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
     return false;
   }
-  if (node.op() == "OnesLike") {
-    return true;
-  }
+  if (node.op() == "OnesLike") return true;
   if (node.op() == "Fill") {
     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
     return values != nullptr && IsOnes(*values);
   }
-  if (node.op() != "Const") {
-    return false;
-  }
+  if (node.op() != "Const") return false;
+  if (node.attr().count("dtype") == 0) return false;
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
     IS_ONES_CASE(DT_BOOL);
@@ -1434,16 +1440,13 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
     return false;
   }
-  if (node.op() == "ZerosLike") {
-    return true;
-  }
+  if (node.op() == "ZerosLike") return true;
   if (node.op() == "Fill") {
     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
     return values != nullptr && IsZeros(*values);
   }
-  if (!IsConstant(node)) {
-    return false;
-  }
+  if (!IsConstant(node)) return false;
+  if (node.attr().count("dtype") == 0) return false;
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
     IS_ZEROS_CASE(DT_BOOL);
@@ -1737,11 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
 bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
                                           GraphDef* optimized_graph,
                                           NodeDef* node) {
+  if (node->attr().count("num_split") == 0) return false;
   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
     return true;
   }
 
   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
     return true;
@@ -1918,6 +1921,8 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
                                              NodeDef* node, bool* success) {
   if (use_shape_info && IsStridedSlice(*node) &&
       properties.GetInputProperties(node->name()).size() == 4) {
+    TF_RETURN_IF_ERROR(
+        CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
     if (node->attr().at("new_axis_mask").i() != 0 ||
         node->attr().at("shrink_axis_mask").i() != 0) {
       // Skip nodes with new/shrink axis mask, since they involve dimension
@@ -1952,6 +1957,8 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
       return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                      s.value().DebugString());
     }
+    TF_RETURN_IF_ERROR(
+        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
     int begin_mask = node->attr().at("begin_mask").i();
     int end_mask = node->attr().at("end_mask").i();
     std::set<int> expanded_ellipsis_indices;
@@ -2280,7 +2287,7 @@ bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
     // Replace the reduction node with an identity node, that can be further
     // optimized by the model pruner.
     DataType output_type;
-    if (node->attr().count("T") > 0) {
+    if (node->attr().count("T") != 0) {
       output_type = node->attr().at("T").type();
     } else {
       // This is an 'any' or 'all' reduction. The output is always boolean.
@@ -2297,8 +2304,10 @@ bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
 
 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
                                       bool use_shape_info, NodeDef* node) {
-  if (!use_shape_info) return false;
-  if (!IsSimplifiableReshape(*node, properties)) return false;
+  if (!use_shape_info || node->attr().count("T") == 0 ||
+      !IsSimplifiableReshape(*node, properties)) {
+    return false;
+  }
   DataType output_type = node->attr().at("T").type();
   node->set_op("Identity");
   node->clear_attr();
@@ -2310,6 +2319,7 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
 Status ConstantFolding::SimplifyArithmeticOperations(
     const GraphProperties& properties, bool use_shape_info,
     GraphDef* optimized_graph, NodeDef* node, bool* success) {
+  *success = false;
   const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
   const bool is_matmul = IsMatMul(*node);
   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
@@ -2354,6 +2364,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
 
   // Replace 1 / y with Reciprocal op.
   if (y_matches_output_shape && is_any_div && x_is_one) {
+    TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
     DataType type = node->attr().at("T").type();
     if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
       ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);

@@ -547,5 +547,20 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
 
 #undef HANDLE_CASE
 
+Status CheckAttrExists(const NodeDef& node, const string& key) {
+  if (node.attr().count(key) == 0) {
+    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
+                                   "' attr: ", node.ShortDebugString());
+  }
+  return Status::OK();
+}
+
+Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
+  for (const string& key : keys) {
+    TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
+  }
+  return Status::OK();
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow

@@ -17,8 +17,12 @@ limitations under the License.
 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
 
 #include <functional>
 #include <iterator>
 #include <set>
 #include <unordered_set>
 #include <utility>
 #include <vector>
 
+#include "absl/types/span.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -29,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -244,6 +249,12 @@ int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
 // Removes redundant control inputs from node.
 void DedupControlInputs(NodeDef* node);
 
+// Returns an error if an attribute with the given key does not exist in node.
+Status CheckAttrExists(const NodeDef& node, const string& key);
+
+// Returns an error if attributes with the given keys do not exist in node.
+Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
+
 // Returns the data type in attribute `attr_name` of `node`. If that attribute
 // doesn't exist, returns DT_INVALID.
 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);

@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/utils.h"
 
 #include <unistd.h>
 #include <memory>
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -24,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -350,6 +354,32 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
   EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
 }
 
+TEST(CheckAttrExists, All) {
+  NodeDef node;
+  node.set_name("node");
+  (*node.mutable_attr())["apple"].set_i(7);
+  (*node.mutable_attr())["pear"].set_b(true);
+
+  TF_EXPECT_OK(CheckAttrExists(node, "apple"));
+  TF_EXPECT_OK(CheckAttrExists(node, "pear"));
+
+  TF_EXPECT_OK(CheckAttrsExist(node, {}));
+  TF_EXPECT_OK(CheckAttrsExist(node, {"apple"}));
+  TF_EXPECT_OK(CheckAttrsExist(node, {"pear"}));
+  TF_EXPECT_OK(CheckAttrsExist(node, {"apple", "pear"}));
+  TF_EXPECT_OK(CheckAttrsExist(node, {"pear", "apple"}));
+
+  Status status = CheckAttrExists(node, "banana");
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.ToString(),
+            "Invalid argument: Node 'node' lacks 'banana' attr: name: \"node\" "
+            "attr { key: \"apple\" value { i: 7 } } attr { key: \"pear\" value "
+            "{ b: true } }");
+  EXPECT_FALSE(CheckAttrsExist(node, {""}).ok());
+  EXPECT_FALSE(CheckAttrsExist(node, {"pear", "cherry"}).ok());
+  EXPECT_FALSE(CheckAttrsExist(node, {"banana", "apple"}).ok());
+}
+
 TEST_F(UtilsTest, DeleteNodes) {
   // TODO(rmlarsen): write forgotten test.
 }