diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 3bb4660e449..d193679ec19 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -19,12 +19,16 @@ limitations under the License. #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/reader.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" @@ -65,12 +69,34 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { return end_microseconds - start_microseconds; } +// Ensure that constant tensors loaded from the saved model have valid shape. +// TODO(b/154763635): this is temporary and will be replaced with a better audit +static Status ValidateSavedTensors(const GraphDef& graph_def) { + for (const auto& node : graph_def.node()) { + const auto node_iterator = node.attr().find("value"); + if (node_iterator != node.attr().end()) { + AttrValue node_value = node_iterator->second; + if (node_value.has_tensor()) { + const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); + if (node_shape.num_elements() < 0) { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), "\" (op \"", + node.op(), "\") which initializes from a tensor with ", + node_shape.num_elements(), " elements"); + } + } + } + } + return Status::OK(); +} + Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr<Session>* session) { Session* session_p = nullptr; TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); session->reset(session_p); + TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def())); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index 9fc71552d6f..46f365613f1 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -40,6 +40,8 @@ constexpr char kTestDataInitOpV2[] = "cc/saved_model/testdata/half_plus_two_v2/00000123"; constexpr char kTestDataV2DebugInfo[] = "cc/saved_model/testdata/x_plus_y_v2_debuginfo"; +constexpr char kTestNegativeShapeFuzzGenerated[] = + "cc/saved_model/testdata/negative_shape/fuzz_generated"; class LoaderTest : public ::testing::Test { protected: @@ -256,5 +258,17 @@ TEST_F(LoaderTest, SavedModelV2DebugInfo) { EXPECT_NE(bundle.debug_info.get(), nullptr); } +TEST_F(LoaderTest, NegativeShapeDimension) { + SavedModelBundle bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(), + kTestNegativeShapeFuzzGenerated); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/negative_shape/fuzz_generated b/tensorflow/cc/saved_model/testdata/negative_shape/fuzz_generated new file mode 100644 index 00000000000..5ee5c360ce0 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/negative_shape/fuzz_generated differ