From 402da5c870bb7d73af0eeba68a250226e7bdcad2 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Wed, 26 Dec 2018 12:59:26 -0800 Subject: [PATCH] Remove remaining dependencies from core to contrib (in v2 only). This is likely to break things in 2.0, only tests will tell. I fixed some of the things that will definitely break (saved_model stuff), but this definitely removes some kernels that used to be linked into some of the tools (specifically, graph_transforms), which we may have to put back. PiperOrigin-RevId: 226946051 --- tensorflow/core/kernels/BUILD | 10 +- .../core/platform/default/build_config.bzl | 4 +- tensorflow/lite/python/BUILD | 2 - tensorflow/python/debug/BUILD | 4 +- tensorflow/python/tools/BUILD | 16 ++- tensorflow/python/tools/saved_model_cli.py | 9 +- tensorflow/python/tools/saved_model_utils.py | 75 ++++++++++- .../python/tools/saved_model_utils_test.py | 116 ++++++++++++++++++ tensorflow/tools/graph_transforms/BUILD | 6 +- .../tools/pip_package/pip_smoke_test.py | 1 + 10 files changed, 224 insertions(+), 19 deletions(-) create mode 100644 tensorflow/python/tools/saved_model_utils_test.py diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 6fa139e5dfd..8eed2cd0a81 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -31,6 +31,7 @@ load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_android", + "if_not_v2", "if_not_windows", "tf_cc_binary", "tf_cc_test", @@ -6877,15 +6878,16 @@ tf_kernel_library( name = "summary_kernels", srcs = ["summary_kernels.cc"], deps = [ - "//tensorflow/contrib/tensorboard/db:schema", - "//tensorflow/contrib/tensorboard/db:summary_db_writer", - "//tensorflow/contrib/tensorboard/db:summary_file_writer", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:summary_ops_op_lib", "//tensorflow/core/lib/db:sqlite", - ], + ] + if_not_v2([ + "//tensorflow/contrib/tensorboard/db:schema", + "//tensorflow/contrib/tensorboard/db:summary_db_writer", + "//tensorflow/contrib/tensorboard/db:summary_file_writer", + ]), ) tf_kernel_library( diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 769e2890252..58fd0e27121 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -664,13 +664,14 @@ def tf_additional_cloud_op_deps(): "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], "//tensorflow:no_gcp_support": [], + "//tensorflow:api_version_2": [], "//conditions:default": [ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", ], }) -# TODO(jart, jhseu): Delete when GCP is default on. +# TODO(jhseu): Delete when GCP is default on. def tf_additional_cloud_kernel_deps(): return select({ "//tensorflow:android": [], @@ -678,6 +679,7 @@ def tf_additional_cloud_kernel_deps(): "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], "//tensorflow:no_gcp_support": [], + "//tensorflow:api_version_2": [], "//conditions:default": [ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/contrib/cloud/kernels:gcs_config_ops", diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index e666812bd23..7bfd1a89964 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -115,8 +115,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework", "//tensorflow/python:platform", diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 1dcdb880f55..27a700f813c 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -19,6 +19,7 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "py_binary") +load("//tensorflow:tensorflow.bzl", "if_not_v2") load("//tensorflow:tensorflow.bzl", "if_not_windows") py_library( @@ -406,9 +407,10 @@ py_library( ":debug_errors", ":debug_fibonacci", ":debug_keras", + ] + if_not_v2([ ":debug_mnist", ":debug_tflearn_iris", - ], + ]), ) py_binary( diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 901d6bc335f..f1a911eb489 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -38,7 +38,20 @@ py_library( name = "saved_model_utils", srcs = ["saved_model_utils.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/contrib/saved_model:reader"], +) + +py_test( + name = "saved_model_utils_test", + size = "small", + srcs = ["saved_model_utils_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows + visibility = ["//visibility:private"], + deps = [ + ":saved_model_utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python/saved_model", + ], ) py_library( @@ -250,7 +263,6 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":saved_model_utils", - "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python", "//tensorflow/python/debug:local_cli_wrapper", ], diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index afc4e517cdd..cdef42e2bf8 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -30,9 +30,8 @@ import sys import warnings import numpy as np - from six import integer_types -from tensorflow.contrib.saved_model.python.saved_model import reader + from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session @@ -56,7 +55,7 @@ def _show_tag_sets(saved_model_dir): Args: saved_model_dir: Directory containing the SavedModel to inspect. """ - tag_sets = reader.get_saved_model_tag_sets(saved_model_dir) + tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir) print('The given SavedModel contains the following tag-sets:') for tag_set in sorted(tag_sets): print(', '.join(sorted(tag_set))) @@ -190,7 +189,7 @@ def _show_all(saved_model_dir): Args: saved_model_dir: Directory containing the SavedModel to inspect. """ - tag_sets = reader.get_saved_model_tag_sets(saved_model_dir) + tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir) for tag_set in sorted(tag_sets): print("\nMetaGraphDef with tag-set: '%s' " "contains the following SignatureDefs:" % ', '.join(tag_set)) @@ -654,7 +653,7 @@ def scan(args): scan_meta_graph_def( saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)) else: - saved_model = reader.read_saved_model(args.dir) + saved_model = saved_model_utils.read_saved_model(args.dir) for meta_graph_def in saved_model.meta_graphs: scan_meta_graph_def(meta_graph_def) diff --git a/tensorflow/python/tools/saved_model_utils.py b/tensorflow/python/tools/saved_model_utils.py index c27d7a2658a..17c4b8cb831 100644 --- a/tensorflow/python/tools/saved_model_utils.py +++ b/tensorflow/python/tools/saved_model_utils.py @@ -18,7 +18,78 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.saved_model.python.saved_model import reader +import os + +from google.protobuf import message +from google.protobuf import text_format +from tensorflow.core.protobuf import saved_model_pb2 +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat + + +def read_saved_model(saved_model_dir): + """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`. + + Args: + saved_model_dir: Directory containing the SavedModel file. + + Returns: + A `SavedModel` protocol buffer. + + Raises: + IOError: If the file does not exist, or cannot be successfully parsed. + """ + # Build the path to the SavedModel in pbtxt format. + path_to_pbtxt = os.path.join( + compat.as_bytes(saved_model_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) + # Build the path to the SavedModel in pb format. + path_to_pb = os.path.join( + compat.as_bytes(saved_model_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + + # Ensure that the SavedModel exists at either path. + if not file_io.file_exists(path_to_pbtxt) and not file_io.file_exists( + path_to_pb): + raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) + + # Parse the SavedModel protocol buffer. + saved_model = saved_model_pb2.SavedModel() + if file_io.file_exists(path_to_pb): + try: + file_content = file_io.FileIO(path_to_pb, "rb").read() + saved_model.ParseFromString(file_content) + return saved_model + except message.DecodeError as e: + raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) + elif file_io.file_exists(path_to_pbtxt): + try: + file_content = file_io.FileIO(path_to_pbtxt, "rb").read() + text_format.Merge(file_content.decode("utf-8"), saved_model) + return saved_model + except text_format.ParseError as e: + raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e))) + else: + raise IOError("SavedModel file does not exist at: %s/{%s|%s}" % + (saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT, + constants.SAVED_MODEL_FILENAME_PB)) + + +def get_saved_model_tag_sets(saved_model_dir): + """Retrieves all the tag-sets available in the SavedModel. + + Args: + saved_model_dir: Directory containing the SavedModel. + + Returns: + String representation of all tag-sets in the SavedModel. + """ + saved_model = read_saved_model(saved_model_dir) + all_tags = [] + for meta_graph_def in saved_model.meta_graphs: + all_tags.append(list(meta_graph_def.meta_info_def.tags)) + return all_tags def get_meta_graph_def(saved_model_dir, tag_set): @@ -39,7 +110,7 @@ def get_meta_graph_def(saved_model_dir, tag_set): Returns: A MetaGraphDef corresponding to the tag-set. """ - saved_model = reader.read_saved_model(saved_model_dir) + saved_model = read_saved_model(saved_model_dir) set_of_tags = set(tag_set.split(',')) for meta_graph_def in saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set_of_tags: diff --git a/tensorflow/python/tools/saved_model_utils_test.py b/tensorflow/python/tools/saved_model_utils_test.py new file mode 100644 index 00000000000..5512dea1f74 --- /dev/null +++ b/tensorflow/python/tools/saved_model_utils_test.py @@ -0,0 +1,116 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModel utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.tools import saved_model_utils + + +def tearDownModule(): + file_io.delete_recursively(test.get_temp_dir()) + + +class SavedModelUtilTest(test.TestCase): + + def _init_and_validate_variable(self, sess, variable_name, variable_value): + v = variables.Variable(variable_value, name=variable_name) + sess.run(variables.global_variables_initializer()) + self.assertEqual(variable_value, v.eval()) + + @test_util.deprecated_graph_mode_only + def testReadSavedModelValid(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model") + builder = saved_model_builder.SavedModelBuilder(saved_model_dir) + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) + builder.save() + + actual_saved_model_pb = saved_model_utils.read_saved_model(saved_model_dir) + self.assertEqual(len(actual_saved_model_pb.meta_graphs), 1) + self.assertEqual( + len(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags), 1) + self.assertEqual(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags[0], + tag_constants.TRAINING) + + def testReadSavedModelInvalid(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model") + with self.assertRaisesRegexp( + IOError, "SavedModel file does not exist at: %s" % saved_model_dir): + saved_model_utils.read_saved_model(saved_model_dir) + + @test_util.deprecated_graph_mode_only + def testGetSavedModelTagSets(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags") + builder = saved_model_builder.SavedModelBuilder(saved_model_dir) + + # Graph with a single variable. SavedModel invoked to: + # - add with weights. + # - a single tag (from predefined constants). + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) + + # Graph that updates the single variable. SavedModel invoked to: + # - simply add the model (weights are not updated). + # - a single tag (from predefined constants). + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 43) + builder.add_meta_graph([tag_constants.SERVING]) + + # Graph that updates the single variable. SavedModel is invoked: + # - to add the model (weights are not updated). + # - multiple predefined tags. + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 44) + builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) + + # Graph that updates the single variable. SavedModel is invoked: + # - to add the model (weights are not updated). + # - multiple predefined tags for serving on TPU. + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 44) + builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU]) + + # Graph that updates the single variable. SavedModel is invoked: + # - to add the model (weights are not updated). + # - multiple custom tags. + with self.session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 45) + builder.add_meta_graph(["foo", "bar"]) + + # Save the SavedModel to disk. + builder.save() + + actual_tags = saved_model_utils.get_saved_model_tag_sets(saved_model_dir) + expected_tags = [["train"], ["serve"], ["serve", "gpu"], ["serve", "tpu"], + ["foo", "bar"]] + self.assertEqual(expected_tags, actual_tags) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index eb1ed1f2ca8..f229099e493 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -12,6 +12,7 @@ load( "tf_cc_binary", "tf_cc_test", "tf_py_test", + "if_not_v2", ) exports_files(["LICENSE"]) @@ -131,12 +132,13 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", - "//tensorflow/contrib/rnn:gru_ops_op_lib", - "//tensorflow/contrib/rnn:lstm_ops_op_lib", "//tensorflow/core/kernels:quantization_utils", ] + if_not_windows([ "//tensorflow/core/kernels:remote_fused_graph_rewriter_transform", "//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform", + ]) + if_not_v2([ + "//tensorflow/contrib/rnn:gru_ops_op_lib", + "//tensorflow/contrib/rnn:lstm_ops_op_lib", ]), alwayslink = 1, ) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index ff821b86430..51d010c9e17 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -102,6 +102,7 @@ BLACKLIST = [ "//tensorflow/contrib/framework:checkpoint_ops_testdata", "//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long + "//tensorflow/contrib/saved_model:reader", # Not present in v2 "//tensorflow/contrib/timeseries/examples:predict", "//tensorflow/contrib/timeseries/examples:multivariate", "//tensorflow/contrib/timeseries/examples:known_anomaly",