Remove remaining dependencies from core to contrib (in v2 only).
This is likely to break things in 2.0, only tests will tell. I fixed some of the things that will definitely break (saved_model stuff), but this definitely removes some kernels that used to be linked into some of the tools (specifically, graph_transforms), which we may have to put back. PiperOrigin-RevId: 226946051
This commit is contained in:
parent
a1193c2954
commit
402da5c870
@ -31,6 +31,7 @@ load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"cc_header_only_library",
|
||||
"if_android",
|
||||
"if_not_v2",
|
||||
"if_not_windows",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
@ -6877,15 +6878,16 @@ tf_kernel_library(
|
||||
name = "summary_kernels",
|
||||
srcs = ["summary_kernels.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/tensorboard/db:schema",
|
||||
"//tensorflow/contrib/tensorboard/db:summary_db_writer",
|
||||
"//tensorflow/contrib/tensorboard/db:summary_file_writer",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:summary_ops_op_lib",
|
||||
"//tensorflow/core/lib/db:sqlite",
|
||||
],
|
||||
] + if_not_v2([
|
||||
"//tensorflow/contrib/tensorboard/db:schema",
|
||||
"//tensorflow/contrib/tensorboard/db:summary_db_writer",
|
||||
"//tensorflow/contrib/tensorboard/db:summary_file_writer",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -664,13 +664,14 @@ def tf_additional_cloud_op_deps():
|
||||
"//tensorflow:linux_s390x": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//tensorflow:no_gcp_support": [],
|
||||
"//tensorflow:api_version_2": [],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
|
||||
"//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
|
||||
],
|
||||
})
|
||||
|
||||
# TODO(jart, jhseu): Delete when GCP is default on.
|
||||
# TODO(jhseu): Delete when GCP is default on.
|
||||
def tf_additional_cloud_kernel_deps():
|
||||
return select({
|
||||
"//tensorflow:android": [],
|
||||
@ -678,6 +679,7 @@ def tf_additional_cloud_kernel_deps():
|
||||
"//tensorflow:linux_s390x": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//tensorflow:no_gcp_support": [],
|
||||
"//tensorflow:api_version_2": [],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
|
||||
"//tensorflow/contrib/cloud/kernels:gcs_config_ops",
|
||||
|
@ -115,8 +115,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform",
|
||||
|
@ -19,6 +19,7 @@ exports_files(["LICENSE"])
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_binary")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_v2")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
|
||||
py_library(
|
||||
@ -406,9 +407,10 @@ py_library(
|
||||
":debug_errors",
|
||||
":debug_fibonacci",
|
||||
":debug_keras",
|
||||
] + if_not_v2([
|
||||
":debug_mnist",
|
||||
":debug_tflearn_iris",
|
||||
],
|
||||
]),
|
||||
)
|
||||
|
||||
py_binary(
|
||||
|
@ -38,7 +38,20 @@ py_library(
|
||||
name = "saved_model_utils",
|
||||
srcs = ["saved_model_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow/contrib/saved_model:reader"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "saved_model_utils_test",
|
||||
size = "small",
|
||||
srcs = ["saved_model_utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_windows"], # TODO: needs investigation on Windows
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/saved_model",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -250,7 +263,6 @@ py_binary(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/contrib/saved_model:saved_model_py",
|
||||
"//tensorflow/python",
|
||||
"//tensorflow/python/debug:local_cli_wrapper",
|
||||
],
|
||||
|
@ -30,9 +30,8 @@ import sys
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from six import integer_types
|
||||
from tensorflow.contrib.saved_model.python.saved_model import reader
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.client import session
|
||||
@ -56,7 +55,7 @@ def _show_tag_sets(saved_model_dir):
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
"""
|
||||
tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
|
||||
tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
|
||||
print('The given SavedModel contains the following tag-sets:')
|
||||
for tag_set in sorted(tag_sets):
|
||||
print(', '.join(sorted(tag_set)))
|
||||
@ -190,7 +189,7 @@ def _show_all(saved_model_dir):
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
"""
|
||||
tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
|
||||
tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
|
||||
for tag_set in sorted(tag_sets):
|
||||
print("\nMetaGraphDef with tag-set: '%s' "
|
||||
"contains the following SignatureDefs:" % ', '.join(tag_set))
|
||||
@ -654,7 +653,7 @@ def scan(args):
|
||||
scan_meta_graph_def(
|
||||
saved_model_utils.get_meta_graph_def(args.dir, args.tag_set))
|
||||
else:
|
||||
saved_model = reader.read_saved_model(args.dir)
|
||||
saved_model = saved_model_utils.read_saved_model(args.dir)
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
scan_meta_graph_def(meta_graph_def)
|
||||
|
||||
|
@ -18,7 +18,78 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.saved_model.python.saved_model import reader
|
||||
import os
|
||||
|
||||
from google.protobuf import message
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.core.protobuf import saved_model_pb2
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
def read_saved_model(saved_model_dir):
|
||||
"""Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel file.
|
||||
|
||||
Returns:
|
||||
A `SavedModel` protocol buffer.
|
||||
|
||||
Raises:
|
||||
IOError: If the file does not exist, or cannot be successfully parsed.
|
||||
"""
|
||||
# Build the path to the SavedModel in pbtxt format.
|
||||
path_to_pbtxt = os.path.join(
|
||||
compat.as_bytes(saved_model_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
# Build the path to the SavedModel in pb format.
|
||||
path_to_pb = os.path.join(
|
||||
compat.as_bytes(saved_model_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
|
||||
# Ensure that the SavedModel exists at either path.
|
||||
if not file_io.file_exists(path_to_pbtxt) and not file_io.file_exists(
|
||||
path_to_pb):
|
||||
raise IOError("SavedModel file does not exist at: %s" % saved_model_dir)
|
||||
|
||||
# Parse the SavedModel protocol buffer.
|
||||
saved_model = saved_model_pb2.SavedModel()
|
||||
if file_io.file_exists(path_to_pb):
|
||||
try:
|
||||
file_content = file_io.FileIO(path_to_pb, "rb").read()
|
||||
saved_model.ParseFromString(file_content)
|
||||
return saved_model
|
||||
except message.DecodeError as e:
|
||||
raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
|
||||
elif file_io.file_exists(path_to_pbtxt):
|
||||
try:
|
||||
file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
|
||||
text_format.Merge(file_content.decode("utf-8"), saved_model)
|
||||
return saved_model
|
||||
except text_format.ParseError as e:
|
||||
raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
|
||||
else:
|
||||
raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
|
||||
(saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT,
|
||||
constants.SAVED_MODEL_FILENAME_PB))
|
||||
|
||||
|
||||
def get_saved_model_tag_sets(saved_model_dir):
|
||||
"""Retrieves all the tag-sets available in the SavedModel.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel.
|
||||
|
||||
Returns:
|
||||
String representation of all tag-sets in the SavedModel.
|
||||
"""
|
||||
saved_model = read_saved_model(saved_model_dir)
|
||||
all_tags = []
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
all_tags.append(list(meta_graph_def.meta_info_def.tags))
|
||||
return all_tags
|
||||
|
||||
|
||||
def get_meta_graph_def(saved_model_dir, tag_set):
|
||||
@ -39,7 +110,7 @@ def get_meta_graph_def(saved_model_dir, tag_set):
|
||||
Returns:
|
||||
A MetaGraphDef corresponding to the tag-set.
|
||||
"""
|
||||
saved_model = reader.read_saved_model(saved_model_dir)
|
||||
saved_model = read_saved_model(saved_model_dir)
|
||||
set_of_tags = set(tag_set.split(','))
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
|
||||
|
116
tensorflow/python/tools/saved_model_utils_test.py
Normal file
116
tensorflow/python/tools/saved_model_utils_test.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SavedModel utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import builder as saved_model_builder
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.tools import saved_model_utils
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
file_io.delete_recursively(test.get_temp_dir())
|
||||
|
||||
|
||||
class SavedModelUtilTest(test.TestCase):
|
||||
|
||||
def _init_and_validate_variable(self, sess, variable_name, variable_value):
|
||||
v = variables.Variable(variable_value, name=variable_name)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertEqual(variable_value, v.eval())
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testReadSavedModelValid(self):
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
|
||||
builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
|
||||
builder.save()
|
||||
|
||||
actual_saved_model_pb = saved_model_utils.read_saved_model(saved_model_dir)
|
||||
self.assertEqual(len(actual_saved_model_pb.meta_graphs), 1)
|
||||
self.assertEqual(
|
||||
len(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags), 1)
|
||||
self.assertEqual(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags[0],
|
||||
tag_constants.TRAINING)
|
||||
|
||||
def testReadSavedModelInvalid(self):
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
|
||||
with self.assertRaisesRegexp(
|
||||
IOError, "SavedModel file does not exist at: %s" % saved_model_dir):
|
||||
saved_model_utils.read_saved_model(saved_model_dir)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testGetSavedModelTagSets(self):
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags")
|
||||
builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
|
||||
|
||||
# Graph with a single variable. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
# - a single tag (from predefined constants).
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
|
||||
|
||||
# Graph that updates the single variable. SavedModel invoked to:
|
||||
# - simply add the model (weights are not updated).
|
||||
# - a single tag (from predefined constants).
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 43)
|
||||
builder.add_meta_graph([tag_constants.SERVING])
|
||||
|
||||
# Graph that updates the single variable. SavedModel is invoked:
|
||||
# - to add the model (weights are not updated).
|
||||
# - multiple predefined tags.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 44)
|
||||
builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
|
||||
|
||||
# Graph that updates the single variable. SavedModel is invoked:
|
||||
# - to add the model (weights are not updated).
|
||||
# - multiple predefined tags for serving on TPU.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 44)
|
||||
builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
|
||||
|
||||
# Graph that updates the single variable. SavedModel is invoked:
|
||||
# - to add the model (weights are not updated).
|
||||
# - multiple custom tags.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 45)
|
||||
builder.add_meta_graph(["foo", "bar"])
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
actual_tags = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
|
||||
expected_tags = [["train"], ["serve"], ["serve", "gpu"], ["serve", "tpu"],
|
||||
["foo", "bar"]]
|
||||
self.assertEqual(expected_tags, actual_tags)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,6 +12,7 @@ load(
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
"tf_py_test",
|
||||
"if_not_v2",
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
@ -131,12 +132,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/contrib/rnn:gru_ops_op_lib",
|
||||
"//tensorflow/contrib/rnn:lstm_ops_op_lib",
|
||||
"//tensorflow/core/kernels:quantization_utils",
|
||||
] + if_not_windows([
|
||||
"//tensorflow/core/kernels:remote_fused_graph_rewriter_transform",
|
||||
"//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform",
|
||||
]) + if_not_v2([
|
||||
"//tensorflow/contrib/rnn:gru_ops_op_lib",
|
||||
"//tensorflow/contrib/rnn:lstm_ops_op_lib",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -102,6 +102,7 @@ BLACKLIST = [
|
||||
"//tensorflow/contrib/framework:checkpoint_ops_testdata",
|
||||
"//tensorflow/contrib/bayesflow:reinforce_simple_example",
|
||||
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
|
||||
"//tensorflow/contrib/saved_model:reader", # Not present in v2
|
||||
"//tensorflow/contrib/timeseries/examples:predict",
|
||||
"//tensorflow/contrib/timeseries/examples:multivariate",
|
||||
"//tensorflow/contrib/timeseries/examples:known_anomaly",
|
||||
|
Loading…
Reference in New Issue
Block a user