Validate NodeDef
s from FunctionDefLibrary
of a GraphDef
.
We already validated `NodeDef`s from a `GraphDef` but missed validating those from the `FunctionDefLibrary`. Thus, some maliciously crafted models could evade detection and cause denial of service due to a `CHECK`-fail. PiperOrigin-RevId: 332536309 Change-Id: I052efe919ff1fe2f90815e286a1aa4c54c7b94ff
This commit is contained in:
parent
2d88f470de
commit
adf095206f
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
@ -73,26 +74,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
|
||||
// Ensure that constant tensors loaded from the saved model have valid shape.
|
||||
// Also ensure that constant nodes have a value assigned to them.
|
||||
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
||||
static Status ValidateNode(const NodeDef& node) {
|
||||
const auto node_iterator = node.attr().find("value");
|
||||
if (node_iterator != node.attr().end()) {
|
||||
AttrValue node_value = node_iterator->second;
|
||||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
|
||||
"\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
}
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||
for (const auto& node : graph_def.node()) {
|
||||
const auto node_iterator = node.attr().find("value");
|
||||
if (node_iterator != node.attr().end()) {
|
||||
AttrValue node_value = node_iterator->second;
|
||||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"",
|
||||
node.op(), "\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||
}
|
||||
|
||||
if (graph_def.has_library()) {
|
||||
const FunctionDefLibrary& library = graph_def.library();
|
||||
for (const auto& function : library.function()) {
|
||||
for (const auto& node : function.node_def()) {
|
||||
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,8 @@ constexpr char kTestFuzzGeneratedNegativeShape[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/negative_shape";
|
||||
constexpr char kTestFuzzGeneratedConstWithNoValue[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/const_with_no_value";
|
||||
constexpr char kTestFuzzGeneratedBadNodeAttr[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/bad_node_attr";
|
||||
|
||||
class LoaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -328,5 +330,20 @@ TEST_F(LoaderTest, ConstNoValue) {
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, BadNodeAttr) {
|
||||
SavedModelBundle bundle;
|
||||
RunOptions run_options;
|
||||
SessionOptions session_options;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestFuzzGeneratedBadNodeAttr);
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_NE(
|
||||
st.error_message().find("constant tensor but no value has been provided"),
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
0
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/assets/empty
vendored
Normal file
0
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/assets/empty
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index
vendored
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user