Properly handle negative shape dimensions from improper saved models.
PiperOrigin-RevId: 308283636 Change-Id: Ib10849425de7d541d8dacfe4d0c709fbac9180b6
This commit is contained in:
parent
40d3f089be
commit
f760f88b42
tensorflow/cc/saved_model
@ -19,12 +19,16 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/cc/saved_model/constants.h"
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
#include "tensorflow/cc/saved_model/reader.h"
|
#include "tensorflow/cc/saved_model/reader.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||||
@ -65,12 +69,34 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
|
|||||||
return end_microseconds - start_microseconds;
|
return end_microseconds - start_microseconds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that constant tensors loaded from the saved model have valid shape.
|
||||||
|
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
||||||
|
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||||
|
for (const auto& node : graph_def.node()) {
|
||||||
|
const auto node_iterator = node.attr().find("value");
|
||||||
|
if (node_iterator != node.attr().end()) {
|
||||||
|
AttrValue node_value = node_iterator->second;
|
||||||
|
if (node_value.has_tensor()) {
|
||||||
|
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||||
|
if (node_shape.num_elements() < 0) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Saved model contains node \"", node.name(), "\" (op \"",
|
||||||
|
node.op(), "\") which initializes from a tensor with ",
|
||||||
|
node_shape.num_elements(), " elements");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
||||||
const SessionOptions& session_options,
|
const SessionOptions& session_options,
|
||||||
std::unique_ptr<Session>* session) {
|
std::unique_ptr<Session>* session) {
|
||||||
Session* session_p = nullptr;
|
Session* session_p = nullptr;
|
||||||
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
|
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
|
||||||
session->reset(session_p);
|
session->reset(session_p);
|
||||||
|
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
|
||||||
return (*session)->Create(meta_graph_def.graph_def());
|
return (*session)->Create(meta_graph_def.graph_def());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ constexpr char kTestDataInitOpV2[] =
|
|||||||
"cc/saved_model/testdata/half_plus_two_v2/00000123";
|
"cc/saved_model/testdata/half_plus_two_v2/00000123";
|
||||||
constexpr char kTestDataV2DebugInfo[] =
|
constexpr char kTestDataV2DebugInfo[] =
|
||||||
"cc/saved_model/testdata/x_plus_y_v2_debuginfo";
|
"cc/saved_model/testdata/x_plus_y_v2_debuginfo";
|
||||||
|
constexpr char kTestNegativeShapeFuzzGenerated[] =
|
||||||
|
"cc/saved_model/testdata/negative_shape/fuzz_generated";
|
||||||
|
|
||||||
class LoaderTest : public ::testing::Test {
|
class LoaderTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -256,5 +258,17 @@ TEST_F(LoaderTest, SavedModelV2DebugInfo) {
|
|||||||
EXPECT_NE(bundle.debug_info.get(), nullptr);
|
EXPECT_NE(bundle.debug_info.get(), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LoaderTest, NegativeShapeDimension) {
|
||||||
|
SavedModelBundle bundle;
|
||||||
|
RunOptions run_options;
|
||||||
|
SessionOptions session_options;
|
||||||
|
|
||||||
|
const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||||
|
kTestNegativeShapeFuzzGenerated);
|
||||||
|
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||||
|
{kSavedModelTagServe}, &bundle);
|
||||||
|
EXPECT_FALSE(st.ok());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
BIN
tensorflow/cc/saved_model/testdata/negative_shape/fuzz_generated
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/negative_shape/fuzz_generated
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user