From e5088cb823964216adfba3155965e0f6f2c7bf7c Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 31 May 2017 17:55:46 -0700 Subject: [PATCH 01/72] Fix discrepancy between measured and analytical cost graph. Use tf_cuda_library for utils. PiperOrigin-RevId: 157660745 --- tensorflow/core/grappler/clusters/BUILD | 4 +++- tensorflow/core/grappler/costs/BUILD | 14 ++++++------- tensorflow/core/grappler/costs/utils.cc | 16 ++++++++------ tensorflow/python/grappler/cost_analyzer.cc | 23 ++++++++++++++------- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 1431641a8fb..fd2f2b32492 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + filegroup( name = "all_files", srcs = glob( @@ -20,7 +22,7 @@ config_setting( }, ) -cc_library( +tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = [ diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 5118a2530b2..aa675fcc771 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + filegroup( name = "all_files", srcs = glob( @@ -108,25 +110,21 @@ cc_test( ], ) -cc_library( +tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - defines = if_cuda(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":op_performance_data_cc", - "//third_party/eigen3", - "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", - ] + if_cuda([ - "//tensorflow/core:cuda", - "@local_config_cuda//cuda:cuda_headers", - ]), + 
"//tensorflow/core/grappler/clusters:utils", + "//third_party/eigen3", + ], ) cc_library( diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 3cc92b56d20..1eca141e57c 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -167,12 +167,16 @@ std::vector FindInputFeatures( inputs.push_back(UnknownInput()); } else { const CostGraphDef::Node* input_cost = it->second; - const CostGraphDef::Node::OutputInfo& output = - input_cost->output_info(output_index); - OpInfo::TensorProperties input; - input.set_dtype(output.dtype()); - *input.mutable_shape() = output.shape(); - inputs.push_back(input); + if (input_cost->output_info_size() == 0) { + inputs.push_back(UnknownInput()); + } else { + const CostGraphDef::Node::OutputInfo& output = + input_cost->output_info(output_index); + OpInfo::TensorProperties input; + input.set_dtype(output.dtype()); + *input.mutable_shape() = output.shape(); + inputs.push_back(input); + } } } diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 273a74dd286..29976b79495 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -56,8 +56,8 @@ void CostAnalyzer::GatherCosts() { CostGraphDef cost_graph_measured; PredictCosts(&measure_estimator_, &cost_graph_measured, &total_time_measured_); + VLOG(1) << "Graph size: " << item_->graph.node_size(); VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size(); - op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph); CostGraphDef cost_graph_analytical; PredictCosts(&analytical_estimator_, &cost_graph_analytical, @@ -66,25 +66,32 @@ void CostAnalyzer::GatherCosts() { << cost_graph_analytical.node_size(); CostGraphDef cost_graph_analytical_filtered; - std::set cost_nodes; - for (auto& node : cost_graph_measured.node()) { - cost_nodes.insert(node.name()); + CostGraphDef 
cost_graph_measured_filtered; + std::map measured_nodes; + for (const auto& node : cost_graph_measured.node()) { + measured_nodes[node.name()] = &node; } for (const auto& node : cost_graph_analytical.node()) { - auto it = cost_nodes.find(node.name()); + auto it = measured_nodes.find(node.name()); // Filter the nodes that are not the cost nodes returned by // MeasuringCostEstimator. - if (it == cost_nodes.end()) { + if (it == measured_nodes.end()) { continue; } - auto added_node = cost_graph_analytical_filtered.add_node(); - *added_node = node; + auto added_node_analytical = cost_graph_analytical_filtered.add_node(); + auto added_node_measured = cost_graph_measured_filtered.add_node(); + *added_node_analytical = node; + *added_node_measured = *(it->second); } VLOG(1) << "cost_graph_analytical_filtered size: " << cost_graph_analytical_filtered.node_size(); + // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and + // op_perf_ cover the same set of nodes. op_perf_analytical_ = CostGraphToOpPerformanceData( cost_graph_analytical_filtered, item_->graph); + op_perf_ = + CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph); } void CostAnalyzer::PreprocessCosts() { From 732a6b1ae350b7879fd3c51575d6c76a273d3245 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Wed, 31 May 2017 19:32:34 -0700 Subject: [PATCH 02/72] Upgrade TypeScript to v2.3.4 PiperOrigin-RevId: 157667511 --- tensorflow/workspace.bzl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index a4b4fa0116f..d1326ff08fb 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -703,16 +703,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""): licenses = ["notice"], # Apache 2.0 sha256_urls = { "a7d00bfd54525bc694b6e32f64c7ebcf5e6b7ae3657be5cc12767bce74654a47": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - 
"https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", + "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/LICENSE.txt", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/LICENSE.txt", ], - "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", + "b8d68724e111d3fd9516255733d1e9469de72e1cc4733c33702f260a011ab117": [ + "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/lib/tsc.js", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/lib/tsc.js", ], "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", + "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/lib/lib.es6.d.ts", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.4/lib/lib.es6.d.ts", ], }, extra_build_file_content = "\n".join([ From 24623653b2305aed001a45007f513b98ab15fa5f Mon Sep 17 00:00:00 2001 From: James Qin Date: Wed, 31 May 2017 20:07:53 -0700 Subject: [PATCH 03/72] Fix graph text format serialization PiperOrigin-RevId: 157669530 --- tensorflow/python/framework/graph_io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py index 0033a370883..f909bcd62d2 100644 --- a/tensorflow/python/framework/graph_io.py +++ b/tensorflow/python/framework/graph_io.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import os.path +from google.protobuf import text_format from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io @@ -64,7 
+65,8 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True): file_io.recursive_create_dir(logdir) path = os.path.join(logdir, name) if as_text: - file_io.atomic_write_string_to_file(path, str(graph_def)) + file_io.atomic_write_string_to_file(path, + text_format.MessageToString(graph_def)) else: file_io.atomic_write_string_to_file(path, graph_def.SerializeToString()) return path From 25bb504ccd204f916441d2bebe8aa008e85e8433 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 May 2017 21:13:46 -0700 Subject: [PATCH 04/72] Make a plugin that serves data for the audio dashboard. Subsequent changes will make TensorBoard use this audio plugin instead of the previous handlers for audio-related data. PiperOrigin-RevId: 157673132 --- tensorflow/BUILD | 1 + tensorflow/contrib/cmake/tf_python.cmake | 1 + tensorflow/contrib/cmake/tf_tests.cmake | 3 +- tensorflow/tensorboard/BUILD | 1 + tensorflow/tensorboard/plugins/audio/BUILD | 48 ++++++ .../tensorboard/plugins/audio/audio_plugin.py | 135 +++++++++++++++ .../plugins/audio/audio_plugin_test.py | 157 ++++++++++++++++++ 7 files changed, 345 insertions(+), 1 deletion(-) create mode 100644 tensorflow/tensorboard/plugins/audio/BUILD create mode 100644 tensorflow/tensorboard/plugins/audio/audio_plugin.py create mode 100644 tensorflow/tensorboard/plugins/audio/audio_plugin_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index b90dc1b2050..055c55a7170 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -379,6 +379,7 @@ filegroup( "//tensorflow/tensorboard/demo:all_files", "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/plugins:all_files", + "//tensorflow/tensorboard/plugins/audio:all_files", "//tensorflow/tensorboard/plugins/histograms:all_files", "//tensorflow/tensorboard/plugins/images:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/contrib/cmake/tf_python.cmake 
b/tensorflow/contrib/cmake/tf_python.cmake index c9b5e20cf8e..243aabc73be 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -229,6 +229,7 @@ add_python_module("tensorflow/tensorboard") add_python_module("tensorflow/tensorboard/backend") add_python_module("tensorflow/tensorboard/backend/event_processing") add_python_module("tensorflow/tensorboard/plugins") +add_python_module("tensorflow/tensorboard/plugins/audio") add_python_module("tensorflow/tensorboard/plugins/histograms") add_python_module("tensorflow/tensorboard/plugins/images") add_python_module("tensorflow/tensorboard/plugins/projector") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index a559d3b94b0..0eee80cccee 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -206,10 +206,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Broken TensorBoard tests due to different paths in windows "${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py" + "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py" + "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py" # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. 
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index b5bff9eaf72..0b9c254b514 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -13,6 +13,7 @@ py_binary( deps = [ "//tensorflow/tensorboard/backend:application", "//tensorflow/tensorboard/backend/event_processing:event_file_inspector", + "//tensorflow/tensorboard/plugins/audio:audio_plugin", "//tensorflow/tensorboard/plugins/histograms:histograms_plugin", "//tensorflow/tensorboard/plugins/images:images_plugin", "//tensorflow/tensorboard/plugins/projector:projector_plugin", diff --git a/tensorflow/tensorboard/plugins/audio/BUILD b/tensorflow/tensorboard/plugins/audio/BUILD new file mode 100644 index 00000000000..5ef52a3d0fe --- /dev/null +++ b/tensorflow/tensorboard/plugins/audio/BUILD @@ -0,0 +1,48 @@ +# Description: +# TensorBoard plugin for audio + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "audio_plugin", + srcs = ["audio_plugin.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/tensorboard/backend:http_util", + "//tensorflow/tensorboard/backend/event_processing:event_accumulator", + "//tensorflow/tensorboard/plugins:base_plugin", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +py_test( + name = "audio_plugin_test", + size = "small", + srcs = ["audio_plugin_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":audio_plugin", + "//tensorflow:tensorflow_py", + "//tensorflow/tensorboard/backend:application", + "//tensorflow/tensorboard/backend/event_processing:event_multiplexer", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +filegroup( + name = 
"all_files", + srcs = glob(["**"]), + visibility = ["//tensorflow:__pkg__"], +) diff --git a/tensorflow/tensorboard/plugins/audio/audio_plugin.py b/tensorflow/tensorboard/plugins/audio/audio_plugin.py new file mode 100644 index 00000000000..ee63b67637d --- /dev/null +++ b/tensorflow/tensorboard/plugins/audio/audio_plugin.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The TensorBoard Audio plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import urllib +from werkzeug import wrappers + +from tensorflow.tensorboard.backend import http_util +from tensorflow.tensorboard.backend.event_processing import event_accumulator +from tensorflow.tensorboard.plugins import base_plugin + +_PLUGIN_PREFIX_ROUTE = event_accumulator.AUDIO + + +class AudioPlugin(base_plugin.TBPlugin): + """Audio Plugin for TensorBoard.""" + + plugin_name = _PLUGIN_PREFIX_ROUTE + + def get_plugin_apps(self, multiplexer, unused_logdir): + self._multiplexer = multiplexer + return { + '/audio': self._serve_audio_metadata, + '/individualAudio': self._serve_individual_audio, + '/tags': self._serve_tags, + } + + def is_active(self): + """The audio plugin is active iff any run has at least one relevant tag.""" + return any(self.index_impl().values()) + + 
def _index_impl(self): + return { + run_name: run_data[event_accumulator.AUDIO] + for (run_name, run_data) in self._multiplexer.Runs().items() + if event_accumulator.AUDIO in run_data + } + + @wrappers.Request.application + def _serve_audio_metadata(self, request): + """Given a tag and a run, serve a list of metadata for audio. + + Note that the audio clips themselves are not sent; instead, we respond with + URLs to the audio. The frontend should treat these URLs as opaque and should + not try to parse information about them or generate them itself, as the + format may change. + + Args: + request: A werkzeug.wrappers.Request object. + + Returns: + A werkzeug.Response application. + """ + tag = request.args.get('tag') + run = request.args.get('run') + + audio_list = self._multiplexer.Audio(run, tag) + response = self._audio_response_for_run(audio_list, run, tag) + return http_util.Respond(request, response, 'application/json') + + def _audio_response_for_run(self, run_audio, run, tag): + """Builds a JSON-serializable object with information about run_audio. + + Args: + run_audio: A list of event_accumulator.AudioValueEvent objects. + run: The name of the run. + tag: The name of the tag the audio entries all belong to. + + Returns: + A list of dictionaries containing the wall time, step, content type, and + URL query string for each audio entry. + """ + response = [] + for index, run_audio_clip in enumerate(run_audio): + response.append({ + 'wall_time': run_audio_clip.wall_time, + 'step': run_audio_clip.step, + 'content_type': run_audio_clip.content_type, + 'query': self._query_for_individual_audio(run, tag, index) + }) + return response + + def _query_for_individual_audio(self, run, tag, index): + """Builds a URL for accessing the specified audio. + + This should be kept in sync with _serve_audio_metadata. Note that the URL is + *not* guaranteed to always return the same audio, since audio may be + unloaded from the reservoir as new audio entries come in.
+ + Args: + run: The name of the run. + tag: The tag. + index: The index of the audio entry. Negative values are OK. + + Returns: + A string representation of a URL that will load the index-th sampled audio + in the given run with the given tag. + """ + query_string = urllib.parse.urlencode({ + 'run': run, + 'tag': tag, + 'index': index + }) + return query_string + + @wrappers.Request.application + def _serve_individual_audio(self, request): + """Serves an individual audio entry.""" + tag = request.args.get('tag') + run = request.args.get('run') + index = int(request.args.get('index')) + audio = self._multiplexer.Audio(run, tag)[index] + return http_util.Respond( + request, audio.encoded_audio_string, audio.content_type) + + @wrappers.Request.application + def _serve_tags(self, request): + index = self._index_impl() + return http_util.Respond(request, index, 'application/json') diff --git a/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py b/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py new file mode 100644 index 00000000000..961691086e1 --- /dev/null +++ b/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py @@ -0,0 +1,157 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests the Tensorboard audio plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json +import os +import shutil +import tempfile + +import numpy +from six.moves import urllib +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf +from werkzeug import test as werkzeug_test +from werkzeug import wrappers + +from tensorflow.tensorboard.backend import application +from tensorflow.tensorboard.backend.event_processing import event_multiplexer +from tensorflow.tensorboard.plugins.audio import audio_plugin + + +class AudioPluginTest(tf.test.TestCase): + + def setUp(self): + self.log_dir = tempfile.mkdtemp() + + # We use numpy.random to generate audio. We seed to avoid non-determinism + # in this test. + numpy.random.seed(42) + + # Create audio summaries for run foo. + tf.reset_default_graph() + sess = tf.Session() + placeholder = tf.placeholder(tf.float32) + tf.summary.audio(name="baz", tensor=placeholder, sample_rate=44100) + merged_summary_op = tf.summary.merge_all() + foo_directory = os.path.join(self.log_dir, "foo") + writer = tf.summary.FileWriter(foo_directory) + writer.add_graph(sess.graph) + for step in xrange(2): + # The floats (sample data) range from -1 to 1. + writer.add_summary(sess.run(merged_summary_op, feed_dict={ + placeholder: numpy.random.rand(42, 22050) * 2 - 1 + }), global_step=step) + writer.close() + + # Create audio summaries for run bar. 
+ tf.reset_default_graph() + sess = tf.Session() + placeholder = tf.placeholder(tf.float32) + tf.summary.audio(name="quux", tensor=placeholder, sample_rate=44100) + merged_summary_op = tf.summary.merge_all() + bar_directory = os.path.join(self.log_dir, "bar") + writer = tf.summary.FileWriter(bar_directory) + writer.add_graph(sess.graph) + for step in xrange(2): + # The floats (sample data) range from -1 to 1. + writer.add_summary(sess.run(merged_summary_op, feed_dict={ + placeholder: numpy.random.rand(42, 11025) * 2 - 1 + }), global_step=step) + writer.close() + + # Start a server with the plugin. + multiplexer = event_multiplexer.EventMultiplexer({ + "foo": foo_directory, + "bar": bar_directory, + }) + plugin = audio_plugin.AudioPlugin() + wsgi_app = application.TensorBoardWSGIApp( + self.log_dir, [plugin], multiplexer, reload_interval=0) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + self.routes = plugin.get_plugin_apps(multiplexer, self.log_dir) + + def tearDown(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + + def _DeserializeResponse(self, byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a response. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode("utf-8")) + + def testRoutesProvided(self): + """Tests that the plugin offers the correct routes.""" + self.assertIsInstance(self.routes["/audio"], collections.Callable) + self.assertIsInstance(self.routes["/individualAudio"], collections.Callable) + self.assertIsInstance(self.routes["/tags"], collections.Callable) + + def testAudioRoute(self): + """Tests that the /audio routes returns with the correct data.""" + response = self.server.get( + "/data/plugin/audio/audio?run=foo&tag=baz/audio/0") + self.assertEqual(200, response.status_code) + + # Verify that the correct entries are returned. 
+ entries = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(entries)) + + # Verify that the 1st entry is correct. + entry = entries[0] + self.assertEqual(0, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["0"], parsed_query["index"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) + + # Verify that the 2nd entry is correct. + entry = entries[1] + self.assertEqual(1, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["1"], parsed_query["index"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) + + def testIndividualAudioRoute(self): + """Tests fetching an individual audio clip.""" + response = self.server.get( + "/data/plugin/audio/individualAudio?run=bar&tag=quux/audio/0&index=0") + self.assertEqual(200, response.status_code) + self.assertEqual("audio/wav", response.headers.get("content-type")) + + def testRunsRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self.server.get("/data/plugin/audio/tags") + self.assertEqual(200, response.status_code) + run_to_tags = self._DeserializeResponse(response.get_data()) + self.assertItemsEqual(("foo", "bar"), run_to_tags.keys()) + self.assertItemsEqual( + ["baz/audio/0", "baz/audio/1", "baz/audio/2"], run_to_tags["foo"]) + self.assertItemsEqual( + ["quux/audio/0", "quux/audio/1", "quux/audio/2"], run_to_tags["bar"]) + + +if __name__ == "__main__": + tf.test.main() From d9620cab82760834a418df7d914c3c21b984b13d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 May 2017 21:24:02 -0700 Subject: [PATCH 05/72] Add flag to determine whether to do L1 optimizations and inline functions. Default is to do them. In tf_optimizer don't inline or do l1 optimizations.
PiperOrigin-RevId: 157673614 --- .../core/grappler/grappler_item_builder.cc | 23 +++++++++++-------- .../core/grappler/grappler_item_builder.h | 7 ++++-- .../grappler/grappler_item_builder_test.cc | 1 + tensorflow/python/grappler/tf_optimizer.i | 4 +++- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 9ac0303da92..384402ad291 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/grappler/grappler_item_builder.h" #include @@ -70,7 +69,8 @@ void InitializeTensor(DataType type, Tensor* tensor) { // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in // order to get the correct session options and environment, and performing the // correct optimizations. -Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) { +Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, + const ItemConfig& cfg) { // Create a session option for a single GPU device. SessionOptions options; @@ -94,7 +94,12 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) { // Optimizer options: L1 and inlining. L1 is default. 
OptimizerOptions* optimizer_opts = options.config.mutable_graph_options()->mutable_optimizer_options(); - optimizer_opts->set_do_function_inlining(true); + if (cfg.apply_optimizations) { + optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L1); + } else { + optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0); + } + optimizer_opts->set_do_function_inlining(cfg.inline_functions); // Create the function library runtime. std::unique_ptr flib(NewFunctionLibraryRuntime( @@ -130,13 +135,11 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( new_item->graph = meta_graph.graph_def(); // Optimize the graph (function inlining, l1 optimizations, etc). - if (cfg.apply_optimizations) { - Status optimize_status = - OptimizeGraph(meta_graph.graph_def(), &new_item->graph); - if (!optimize_status.ok()) { - LOG(ERROR) << "Function optimization failed: " << optimize_status; - return nullptr; - } + Status optimize_status = + OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg); + if (!optimize_status.ok()) { + LOG(ERROR) << "Function optimization failed: " << optimize_status; + return nullptr; } // Attempt to detect the fetch node(s). diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index 62be8dfe14f..3aa1d2027f5 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -31,7 +31,8 @@ struct ItemConfig { : ignore_user_placement(true), ignore_colocation(true), placeholder_unknown_output_shape_dim(-1), - apply_optimizations(true) {} + apply_optimizations(true), + inline_functions(true) {} // If true, ignore all user specified node placement. bool ignore_user_placement; @@ -40,8 +41,10 @@ struct ItemConfig { // Dimension to use if a placeholder node has an _output_shapes attribute with // a dimension of -1. int placeholder_unknown_output_shape_dim; - // If true, does inlining and L1 optimizations. 
+ // If true, does L1 optimizations. bool apply_optimizations; + // If true, does inlining. + bool inline_functions; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 54400f7051c..92225ffb1b4 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -70,6 +70,7 @@ std::unique_ptr CreateGrapplerItem(const GraphDef &def, const CollectionDef &fetches) { MetaGraphDef meta_def; ItemConfig cfg; + cfg.inline_functions = true; *meta_def.mutable_graph_def() = def; (*meta_def.mutable_collection_def())["train_op"] = fetches; return GrapplerItemFromMetaGraphDef("0", meta_def, cfg); diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index 404ce351801..a8067467d91 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -67,7 +67,9 @@ PyObject* TF_OptimizeGraph( const tensorflow::RewriterConfig& rewriter_config, const tensorflow::MetaGraphDef& metagraph, const string& graph_id, TF_Status* out_status) { - const tensorflow::grappler::ItemConfig item_config; + tensorflow::grappler::ItemConfig item_config; + item_config.inline_functions = false; + item_config.apply_optimizations = false; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); std::unordered_map device_map; From 6db400bbcf8fb084f01ad036357e83937c2c6254 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 May 2017 21:55:19 -0700 Subject: [PATCH 06/72] Refactoring Python op code generation. 
PiperOrigin-RevId: 157675126 --- tensorflow/contrib/cmake/tf_python.cmake | 1 + tensorflow/python/BUILD | 5 +- tensorflow/python/framework/python_op_gen.cc | 278 ++++++++++-------- tensorflow/python/framework/python_op_gen.h | 2 +- .../python/framework/python_op_gen_internal.h | 86 ++++++ 5 files changed, 248 insertions(+), 124 deletions(-) create mode 100644 tensorflow/python/framework/python_op_gen_internal.h diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 243aabc73be..132d84d00bb 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -530,6 +530,7 @@ set(tf_python_op_gen_main_srcs "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 1bb92e79ca2..c959ad904d7 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -318,7 +318,10 @@ py_test( cc_library( name = "python_op_gen", srcs = ["framework/python_op_gen.cc"], - hdrs = ["framework/python_op_gen.h"], + hdrs = [ + "framework/python_op_gen.h", + "framework/python_op_gen_internal.h", + ], visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index a3168a00883..00260fe0bf7 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -36,9 +36,10 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/python/framework/python_op_gen_internal.h" namespace tensorflow { -namespace { +namespace python_op_gen_internal { const int kRightMargin = 78; @@ -67,15 +68,11 @@ bool IsPythonReserved(const string& s) { "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError", "UnicodeWarning", "UserWarning", "ValueError", "Warning", "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__", - "__package__", - // Imports and symbols used in the generated code: - "_text_format", "_op_def_pb2", "_common_shapes", "_op_def_registry", - "_ops", "_op_def_library"}); + "__package__"}); return kPythonReserved->count(s) > 0; } -// Add a _ to the end of s if necessary to avoid a Python keyword or built-in. string AvoidPythonReserved(const string& s) { if (IsPythonReserved(s)) return strings::StrCat(s, "_"); return s; @@ -323,8 +320,8 @@ string StringToPython(const string& str) { return strings::StrCat("\"", str_util::CEscape(str), "\""); } -string DataTypeToPython(DataType dtype) { - return strings::StrCat("tf.", PythonDataTypeString(dtype)); +string DataTypeToPython(DataType dtype, const string& dtype_module) { + return strings::StrCat(dtype_module, PythonDataTypeString(dtype)); } string ShapeToPython(const TensorShapeProto& shape) { @@ -346,7 +343,8 @@ string TensorToPython(const TensorProto& proto) { return ProtoShortDebugString(proto); } -string AttrListToPython(const AttrValue& value) { +string AttrListToPython(const AttrValue& value, + const string& dtype_module = "tf.") { string ret; if (value.list().s_size() > 0) { for (int i = 0; i < value.list().s_size(); ++i) { @@ -371,7 +369,8 @@ string AttrListToPython(const AttrValue& value) { } else if (value.list().type_size() > 0) { for (int i = 0; i < value.list().type_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); - strings::StrAppend(&ret, 
DataTypeToPython(value.list().type(i))); + strings::StrAppend(&ret, + DataTypeToPython(value.list().type(i), dtype_module)); } } else if (value.list().shape_size() > 0) { for (int i = 0; i < value.list().shape_size(); ++i) { @@ -392,7 +391,8 @@ string AttrListToPython(const AttrValue& value) { return ret; } -string AttrValueToPython(const string& type, const AttrValue& value) { +string AttrValueToPython(const string& type, const AttrValue& value, + const string& dtype_module) { if (type == "string") { return StringToPython(value.s()); } else if (type == "int") { @@ -402,7 +402,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) { } else if (type == "bool") { return value.b() ? "True" : "False"; } else if (type == "type") { - return DataTypeToPython(value.type()); + return DataTypeToPython(value.type(), dtype_module); } else if (type == "shape") { return ShapeToPython(value.shape()); } else if (type == "tensor") { @@ -410,7 +410,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) { } else if (type == "func") { return StringToPython(value.func().name()); } else if (StringPiece(type).starts_with("list(")) { - return strings::StrCat("[", AttrListToPython(value), "]"); + return strings::StrCat("[", AttrListToPython(value, dtype_module), "]"); } else { return "?"; } @@ -432,35 +432,41 @@ void GenerateLowerCaseOpName(const string& str, string* result) { } } -} // namespace +static void AddDelimiter(string* append_to, const string& delim) { + if (!append_to->empty()) strings::StrAppend(append_to, delim); +} -string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { - string result; - // Map from attr name to the first input arg it is inferred from. 
- std::unordered_map inferred_attrs; +GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name) + : op_def_(op_def), + function_name_(function_name), + num_outs_(op_def.output_arg_size()) {} + +GenPythonOp::~GenPythonOp() {} + +string GenPythonOp::Code() { // This has all the input args followed by those attrs that don't have // defaults. std::vector args_no_default; // The parameters with defaults (these have to be listed after those without). // No input args are included, just attrs. std::vector args_with_defaults; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); + for (int i = 0; i < op_def_.input_arg_size(); ++i) { + const auto& arg(op_def_.input_arg(i)); args_no_default.push_back(arg.name()); if (!arg.type_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name()); + gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); } else if (!arg.type_list_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(), + gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(), arg.name()); } if (!arg.number_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name()); + gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); } } - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < op_def_.attr_size(); ++i) { + const auto& attr(op_def_.attr(i)); // Do not add inferred attrs to the Python function signature. - if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) { + if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { if (attr.has_default_value()) { args_with_defaults.push_back(attr.name()); } else { @@ -471,110 +477,92 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { // Save the list of attr parameters (attrs that won't be inferred), // those with defaults go at the end. 
- std::vector attrs; // Get the attrs in the order we want by taking the attrs without defaults // from the end of args_no_default, and adding args_no_default. - attrs.reserve(args_no_default.size() - op_def.input_arg_size() + - args_with_defaults.size()); - attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(), - args_no_default.end()); - attrs.insert(attrs.end(), args_with_defaults.begin(), - args_with_defaults.end()); + attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() + + args_with_defaults.size()); + attrs_.insert(attrs_.end(), + args_no_default.begin() + op_def_.input_arg_size(), + args_no_default.end()); + attrs_.insert(attrs_.end(), args_with_defaults.begin(), + args_with_defaults.end()); - std::vector param_names; - param_names.reserve(args_no_default.size() + args_with_defaults.size()); + param_names_.reserve(args_no_default.size() + args_with_defaults.size()); string parameters; for (const string& name : args_no_default) { - if (!parameters.empty()) strings::StrAppend(¶meters, ", "); + AddDelimiter(¶meters, ", "); const string param = AvoidPythonReserved(name); strings::StrAppend(¶meters, param); - param_names.push_back(param); + param_names_.push_back(param); } for (const string& name : args_with_defaults) { - if (!parameters.empty()) strings::StrAppend(¶meters, ", "); + AddDelimiter(¶meters, ", "); const string param = AvoidPythonReserved(name); strings::StrAppend(¶meters, param, "=None"); - param_names.push_back(param); + param_names_.push_back(param); } + AddDelimiter(¶meters, ", "); + strings::StrAppend(¶meters, "name=None"); - const string lower_op_name = strings::StrCat(is_hidden ? 
"_" : "", op_name); + AddDefLine(parameters); + AddDocStringDescription(); + AddDocStringArgs(); + AddDocStringInputs(); + AddDocStringAttrs(); + AddDocStringNameArg(); + AddOutputGlobals(); + AddDocStringOutputs(); + strings::StrAppend(&result_, " \"\"\"\n"); + AddBody(" "); + strings::StrAppend(&result_, "\n\n"); - const int num_outs = op_def.output_arg_size(); - // Prepare a NamedTuple type to hold the outputs, if there are multiple - if (num_outs > 1) { - // Prepare the list of output names - std::vector out_names(num_outs); - for (int i = 0; i < num_outs; ++i) { - if (!op_def.output_arg(i).name().empty()) { - out_names[i] = op_def.output_arg(i).name(); - } else { - out_names[i] = strings::StrCat("output", i); - } - } - string out_names_list = - strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); + return prelude_ + result_; +} - // Provide the output names as a Python list - string lower_op_name_outputs = - strings::StrCat("_", lower_op_name, "_outputs"); - const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); - strings::StrAppend(&result, "\n", - WordWrap(outputs_prefix, out_names_list, kRightMargin), - "\n"); +void GenPythonOp::AddDefLine(const string& parameters) { + const string def_prefix = strings::StrCat("def ", function_name_, "("); + strings::StrAppend( + &result_, WordWrap(def_prefix, parameters + "):", kRightMargin), "\n"); +} - strings::StrAppend(&result, "_", op_def.name(), - "Output = _collections.namedtuple(\n"); - const string tuple_type_prefix = " "; - const string tuple_type_suffix = strings::StrCat( - "\"", op_def.name(), "\", ", lower_op_name_outputs, ")"); - strings::StrAppend( - &result, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), - "\n\n"); - } - strings::StrAppend(&result, "\n"); - - // Print: def Function(parameters): - const string def_prefix = strings::StrCat("def ", lower_op_name, "("); - const bool has_args = args_no_default.size() + args_with_defaults.size() > 0; - const 
string def_suffix = - strings::StrCat(parameters, has_args ? ", " : "", "name=None):"); - - strings::StrAppend(&result, WordWrap(def_prefix, def_suffix, kRightMargin), - "\n"); - - // Format the Op's descriptions so that it can be a Python docstring. +void GenPythonOp::AddDocStringDescription() { string comment; - if (op_def.summary().empty()) { + if (op_def_.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def.summary(), "\n"); - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description())); + comment = strings::StrCat(op_def_.summary(), "\n"); + if (!op_def_.description().empty()) { + strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description())); } } + strings::StrAppend(&result_, " r\"\"\"", comment, "\n"); +} - strings::StrAppend(&result, " r\"\"\"", comment, "\n Args:\n"); +void GenPythonOp::AddDocStringArgs() { + strings::StrAppend(&result_, " Args:\n"); +} - // Inputs - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - StringPiece description = op_def.input_arg(i).description(); +void GenPythonOp::AddDocStringInputs() { + for (int i = 0; i < op_def_.input_arg_size(); ++i) { + const auto& arg(op_def_.input_arg(i)); + StringPiece description = op_def_.input_arg(i).description(); string desc; if (ConsumeEquals(&description)) { // Skip the generated type info. 
- desc = strings::StrCat(param_names[i], ": "); + desc = strings::StrCat(param_names_[i], ": "); } else { - desc = strings::StrCat(param_names[i], ": ", - ArgTypeName(op_def, arg, inferred_attrs, false)); + desc = strings::StrCat(param_names_[i], ": ", + ArgTypeName(op_def_, arg, inferred_attrs_, false)); } if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - strings::StrAppend(&result, Indent(4, 6, desc)); + strings::StrAppend(&result_, Indent(4, 6, desc)); } +} - // Attrs - for (const string& name : attrs) { - const auto& attr = *FindAttr(name, op_def); +void GenPythonOp::AddDocStringAttrs() { + for (const string& name : attrs_) { + const auto& attr = *FindAttr(name, op_def_); string desc = strings::StrCat(AvoidPythonReserved(name), ": "); static const char* const kAttrTypeName[][2] = { @@ -638,40 +626,86 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { AppendWithinWidth(&desc, attr.description(), kRightMargin - 4 /* indent */); } - strings::StrAppend(&result, Indent(4, 6, desc)); + strings::StrAppend(&result_, Indent(4, 6, desc)); } +} - strings::StrAppend(&result, +void GenPythonOp::AddDocStringNameArg() { + strings::StrAppend(&result_, " name: A name for the operation (optional).\n"); +} - std::vector output_type_string; - output_type_string.reserve(num_outs); - for (int i = 0; i < num_outs; ++i) { - output_type_string.push_back( - ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true)); +void GenPythonOp::AddOutputGlobals() { + // Prepare a NamedTuple type to hold the outputs, if there are multiple + if (num_outs_ > 1) { + // Prepare the list of output names + std::vector out_names(num_outs_); + for (int i = 0; i < num_outs_; ++i) { + if (!op_def_.output_arg(i).name().empty()) { + out_names[i] = op_def_.output_arg(i).name(); + } else { + out_names[i] = strings::StrCat("output", i); + } + } + string out_names_list = + strings::StrCat("[\"", str_util::Join(out_names, 
"\", \""), "\"]"); + + // Provide the output names as a Python list + string lower_op_name_outputs = + strings::StrCat("_", function_name_, "_outputs"); + const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); + strings::StrAppend(&prelude_, "\n", + WordWrap(outputs_prefix, out_names_list, kRightMargin), + "\n"); + + strings::StrAppend(&prelude_, "_", op_def_.name(), + "Output = _collections.namedtuple(\n"); + const string tuple_type_prefix = " "; + const string tuple_type_suffix = strings::StrCat( + "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")"); + strings::StrAppend( + &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), + "\n\n"); } - strings::StrAppend(&result, GetReturns(op_def, output_type_string)); + strings::StrAppend(&prelude_, "\n"); +} - string return_prefix = strings::StrCat(" result = _op_def_lib.apply_op("); - string return_args = strings::StrCat("\"", op_def.name(), "\", "); - for (size_t i = 0; i < param_names.size(); ++i) { - strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", "); +void GenPythonOp::AddDocStringOutputs() { + std::vector output_type_string; + output_type_string.reserve(num_outs_); + for (int i = 0; i < num_outs_; ++i) { + output_type_string.push_back( + ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true)); + } + strings::StrAppend(&result_, GetReturns(op_def_, output_type_string)); +} + +void GenPythonOp::AddBody(const string& prefix) { + string return_prefix = + strings::StrCat(prefix, "result = _op_def_lib.apply_op("); + string return_args = strings::StrCat("\"", op_def_.name(), "\", "); + for (size_t i = 0; i < param_names_.size(); ++i) { + strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i], + ", "); } strings::StrAppend(&return_args, "name=name)"); - strings::StrAppend(&result, " \"\"\"\n", + strings::StrAppend(&result_, // Wrap the arguments, and indent to the (. 
WordWrap(return_prefix, return_args, kRightMargin), "\n"); - if (num_outs <= 1) { - strings::StrAppend(&result, " return result\n"); + if (num_outs_ <= 1) { + strings::StrAppend(&result_, prefix, "return result\n"); } else { - strings::StrAppend(&result, " return _", op_def.name(), + strings::StrAppend(&result_, prefix, "return _", op_def_.name(), "Output._make(result)\n"); } - strings::StrAppend(&result, "\n\n"); +} - return result; +} // namespace python_op_gen_internal + +string GetPythonOp(const OpDef& op_def, const string& function_name) { + return python_op_gen_internal::GenPythonOp(op_def, function_name).Code(); } string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, @@ -711,20 +745,20 @@ from tensorflow.python.framework import op_def_library as _op_def_library } } - // PrintPythonOp(op_def, is_hidden, op_def.name()); - string lower_case_name; - GenerateLowerCaseOpName(op_def.name(), &lower_case_name); + string function_name; + python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), + &function_name); + if (is_hidden) function_name = strings::StrCat("_", function_name); // When users create custom python wrappers, they may link in the // default op registry by accident, and because they can't // enumerate all 'hidden' symbols, this guard is to prevent // instantiating a python reserved word in their wrapper. 
- if (!is_hidden && IsPythonReserved(lower_case_name)) { + if (python_op_gen_internal::IsPythonReserved(function_name)) { continue; } - strings::StrAppend(&result, - GetPythonOp(op_def, is_hidden, lower_case_name)); + strings::StrAppend(&result, GetPythonOp(op_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index d865c238743..f485044c5af 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -31,7 +31,7 @@ void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, bool require_shapes); string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, bool require_shapes); -string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name); +string GetPythonOp(const OpDef& op_def, const string& function_name); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h new file mode 100644 index 00000000000..44b1aed71f1 --- /dev/null +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ +#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ + +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace python_op_gen_internal { + +// Returns true if s is a Python keyword or built-in. +bool IsPythonReserved(const string& s); + +// Add a _ to the end of s if necessary to avoid a Python keyword or built-in. +string AvoidPythonReserved(const string& s); + +// Convert an AttrValue with type `type` to the Python representation for +// that value. +string AttrValueToPython(const string& type, const AttrValue& value, + const string& dtype_module = "tf."); + +void GenerateLowerCaseOpName(const string& str, string* result); + +class GenPythonOp { + public: + GenPythonOp(const OpDef& op_def, const string& function_name); + virtual ~GenPythonOp(); + + virtual string Code(); + + protected: + // Print: def Function(parameters): + void AddDefLine(const string& parameters); + + // Format the Op's descriptions so that it can be a Python docstring. + void AddDocStringDescription(); + + void AddDocStringArgs(); + void AddDocStringInputs(); + void AddDocStringAttrs(); + void AddDocStringNameArg(); + void AddOutputGlobals(); + void AddDocStringOutputs(); + void AddBody(const string& prefix); + + // From constructor arguments + const OpDef& op_def_; + const string& function_name_; + const int num_outs_; + + // Return value from Code() is prelude_ + result_. 
+ string prelude_; // Code before function definition + string result_; // Function definition + + // Map from attr name to the first input arg it is inferred from + std::unordered_map inferred_attrs_; + + // The names of the non-inferred attrs, in parameter order + std::vector attrs_; + + // All parameters, including inputs & non-inferred attrs, required and those + // with defaults, except "name" + std::vector param_names_; +}; + +} // namespace python_op_gen_internal +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ From d3e840a6c1d26c59fe7b01963e0a2a1dc0067496 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 31 May 2017 21:56:30 -0700 Subject: [PATCH 07/72] Disable writing of compressed checkpoints. Snappy compression (and decompression) was enabled after the 1.1 release (in commit 63b2f999d3f22cfe915b89103faa1b0a1b1b7617). This means that checkpoints produced by the 1.2.0 release candidates will cause TensorFlow 1.1 (and prior) binaries to crash as they CHECK fail when trying to load snappy-compressed tables. To ease transition, disable writing of compressed checkpoints in 1.2.0 for now. Reconsider this in the next release. PiperOrigin-RevId: 157675189 --- .../core/util/tensor_bundle/tensor_bundle.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index dd04cea40d1..334444a4a22 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -255,6 +255,16 @@ Status CorruptFileError(const Status& in_status, const string& filename, detail, "): ", in_status.error_message())); } +table::Options TableBuilderOptions() { + table::Options o; + // Compressed tables cannot be read by TensorFlow releases prior to 1.1. 
+ // To smoothen the transition, compressed writes are disabled for now + // (version 1.2) with the intention that they will be enabled again at + // some point (perhaps the 1.3 release?). + o.compression = table::kNoCompression; + return o; +} + } // namespace BundleWriter::BundleWriter(Env* env, StringPiece prefix) @@ -442,7 +452,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix, table::Table* table = nullptr; TF_RETURN_IF_ERROR( - table::Table::Open(table::Options(), file.get(), file_size, &table)); + table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table)); std::unique_ptr table_deleter(table); std::unique_ptr iter(table->NewIterator()); @@ -555,7 +565,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, TF_RETURN_IF_ERROR( env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata)); { - table::TableBuilder builder(table::Options(), merged_metadata.get()); + table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get()); // Header entry. BundleHeaderProto header; header.set_num_shards(merge.num_shards); From a23255bc079bb94006aa0bfdc5000eed0d97098a Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 1 Jun 2017 08:31:23 -0700 Subject: [PATCH 08/72] Adds missing group OP to benchmark PiperOrigin-RevId: 157716500 --- tensorflow/python/kernel_tests/cholesky_op_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index 9a1c918b150..b7f8f5c51f6 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -303,7 +304,7 @@ class CholeskyBenchmark(test.Benchmark): ops.device("/cpu:0"): l = linalg_ops.cholesky(data) self.run_op_benchmark( - sess, l, + sess, control_flow_ops.group(l,), min_iters=25, name="cholesky_cpu_{size}".format(size=size)) @@ -328,7 +329,7 @@ class CholeskyBenchmark(test.Benchmark): ops.device(device): grad = grad_fn(l, grad_data) self.run_op_benchmark( - sess, grad, + sess, control_flow_ops.group(grad,), min_iters=25, name="{name}_{dev}_{size}".format( name=name, dev=grad.device, size=size)) From ce32228c49e595f966485acee947131e4ab04905 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 7 Jan 2017 09:19:27 -0800 Subject: [PATCH 09/72] Add release notes for Intel MKL integration. PiperOrigin-RevId: 157722003 --- RELEASE.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index ec24d6fd80a..1590aabfefd 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -39,6 +39,15 @@ be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post- processing of the rnn. 
For RNN decoding, this functionality has been replaced with an alternative API in `tf.contrib.seq2seq`. +* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of + optimized deep learning primitives: In addition to matrix multiplication and + convolution, these building blocks include: + Direct batched convolution + Pooling: maximum, minimum, average + Normalization: LRN, batch normalization + Activation: rectified linear unit (ReLU) + Data manipulation: multi-dimensional transposition (conversion), split, + concat, sum and scale. ## Bug Fixes and Other Changes * In python, `Operation.get_attr` on type attributes returns the Python DType From eb10a4c494d95e7c17ddc44ef35197d08f2f6b33 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 09:42:04 -0700 Subject: [PATCH 10/72] Preallocate vector storage when the ultimate vector size is known in advance PiperOrigin-RevId: 157724431 --- tensorflow/c/c_api.cc | 1 + tensorflow/cc/client/client_session.cc | 2 ++ tensorflow/cc/framework/gradient_checker.cc | 1 + tensorflow/compiler/aot/compile.cc | 1 + tensorflow/compiler/tf2xla/kernels/fill_op.cc | 1 + tensorflow/compiler/tf2xla/kernels/slice_op.cc | 1 + tensorflow/compiler/xla/service/allocation_tracker.cc | 1 + tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc | 1 + tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 1 + .../compiler/xla/service/gpu/stream_assignment_test.cc | 1 + tensorflow/compiler/xla/service/hlo_instruction.cc | 1 + tensorflow/compiler/xla/service/service.cc | 4 ++++ tensorflow/compiler/xla/service/user_computation.cc | 1 + tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 2 ++ tensorflow/compiler/xla/tests/log_test.cc | 1 + tensorflow/compiler/xla/tests/params_test.cc | 1 + tensorflow/compiler/xla/tests/slice_test.cc | 1 + tensorflow/compiler/xla/tests/vector_ops_simple_test.cc | 2 ++ 
.../xla/tools/dumped_computation_to_operation_list.cc | 1 + tensorflow/compiler/xla/tools/dumped_computation_to_text.cc | 1 + tensorflow/compiler/xla/tools/replay_computation.cc | 1 + tensorflow/contrib/batching/kernels/batch_kernels.cc | 1 + tensorflow/contrib/batching/shared_batch_scheduler_test.cc | 1 + .../contrib/boosted_trees/lib/utils/dropout_utils_test.cc | 1 + tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc | 1 + .../contrib/layers/kernels/sparse_feature_cross_kernel.cc | 1 + .../core/common_runtime/gpu/gpu_bfc_allocator_test.cc | 1 + tensorflow/core/common_runtime/kernel_benchmark_testlib.cc | 2 ++ tensorflow/core/common_runtime/session_factory.cc | 1 + tensorflow/core/common_runtime/shape_refiner_test.cc | 1 + tensorflow/core/framework/common_shape_fns.cc | 1 + tensorflow/core/framework/resource_mgr.cc | 1 + tensorflow/core/framework/shape_inference.cc | 1 + tensorflow/core/framework/shape_inference_test.cc | 1 + tensorflow/core/framework/shape_inference_testutil_test.cc | 1 + tensorflow/core/graph/graph_constructor.cc | 1 + tensorflow/core/graph/graph_test.cc | 1 + tensorflow/core/grappler/optimizers/layout_optimizer.cc | 1 + tensorflow/core/kernels/adjust_contrast_op_test.cc | 1 + tensorflow/core/kernels/dequantize_op_test.cc | 1 + tensorflow/core/kernels/dynamic_partition_op.cc | 2 ++ tensorflow/core/kernels/fractional_max_pool_op.cc | 2 ++ tensorflow/core/kernels/gather_op_test.cc | 1 + tensorflow/core/kernels/mfcc_mel_filterbank_test.cc | 1 + tensorflow/core/kernels/mfcc_test.cc | 2 ++ tensorflow/core/kernels/quantization_utils_test.cc | 2 ++ tensorflow/core/kernels/sdca_ops_test.cc | 6 ++++++ tensorflow/core/kernels/serialize_sparse_op.cc | 1 + tensorflow/core/kernels/sparse_cross_op.cc | 1 + tensorflow/core/kernels/sparse_tensors_map_ops.cc | 1 + tensorflow/core/kernels/stage_op.cc | 1 + tensorflow/core/lib/gtl/inlined_vector_test.cc | 2 ++ tensorflow/core/lib/gtl/optional_test.cc | 1 + tensorflow/core/ops/array_grad.cc | 1 + 
tensorflow/core/ops/array_ops.cc | 2 ++ tensorflow/core/ops/array_ops_test.cc | 5 +++++ tensorflow/core/ops/control_flow_ops_test.cc | 2 ++ tensorflow/core/ops/functional_ops_test.cc | 1 + tensorflow/core/ops/math_ops_test.cc | 1 + tensorflow/core/ops/sparse_ops_test.cc | 1 + tensorflow/core/ops/string_ops_test.cc | 1 + tensorflow/core/platform/cloud/retrying_file_system_test.cc | 1 + tensorflow/core/util/command_line_flags_test.cc | 1 + tensorflow/core/util/ctc/ctc_beam_search_test.cc | 3 +++ tensorflow/python/framework/cpp_shape_inference.cc | 1 + tensorflow/python/lib/core/py_func.cc | 1 + tensorflow/stream_executor/cuda/cuda_dnn.cc | 1 + tensorflow/tools/graph_transforms/summarize_graph_main.cc | 2 ++ tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc | 1 + 69 files changed, 95 insertions(+) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a9644a5555d..77faa475ed4 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, } std::vector dim_vec; + dim_vec.reserve(num_dims); for (int i = 0; i < num_dims; ++i) { dim_vec.push_back(ic->MakeDim(dims[i])); } diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 2879445441d..ba056a8f3a8 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, feeds.emplace_back(feed.first.name(), feed.second.tensor); } std::vector output_tensor_names; + output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { output_tensor_names.push_back(output.name()); } std::vector target_node_names; + target_node_names.reserve(run_outputs.size()); for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); } diff --git a/tensorflow/cc/framework/gradient_checker.cc 
b/tensorflow/cc/framework/gradient_checker.cc index 8f20ff1457b..b8e5411bf71 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose( size_t x_num = x_shapes.size(); // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys'). OutputList dys; + dys.reserve(y_shapes.size()); for (const auto& y_shape : y_shapes) { // TODO(suharshs): This currently assumes that all x's are the same type. dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type())); diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index bb79fe81ab3..ca17c5ab690 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client, compile_result->program_shape = *pshape_or.ValueOrDie(); xla::ProgramShape* pshape = &compile_result->program_shape; std::vector arg_layouts; + arg_layouts.reserve(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 1b5f94d4e5f..1e1d2a1b4b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel { // Convert the dims literal into a vector that we can pass to // ComputationBuilder. 
std::vector broadcast; + broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); } diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 87cd266708b..51c97d85d7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel { // slice will be an empty handle if the output has no elements. CHECK_EQ(begin.size(), size.size()); std::vector limits; + limits.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + size[i]); } diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 83759a7a0c6..ad2fee2d39a 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -171,6 +171,7 @@ StatusOr> AllocationTracker::DeconstructTuple( executor, allocation->device_memory(), allocation->shape())); std::vector element_handles; + element_handles.reserve(element_bases.size()); for (int i = 0; i < element_bases.size(); ++i) { element_handles.push_back(RegisterInternal( allocation->backend(), allocation->device_ordinal(), element_bases[i], diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index c6749851dbb..dc421695cb1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, 
/*name=*/tensorflow::strings::Printf("param%d", i)))); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 00471f72c99..9a09d2c02bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1631,6 +1631,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // Compute the input buffer indices. std::vector io_buffers; + io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); } diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 28d47d2b0f8..56e3ff99fa9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index deb355145a8..19a97c0175f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands, } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector bounds; + bounds.reserve(slice_starts_.size()); for (int i = 0; i < slice_starts_.size(); ++i) { bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); diff --git a/tensorflow/compiler/xla/service/service.cc 
b/tensorflow/compiler/xla/service/service.cc index 0b94b37d376..c8f2188b53c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -649,6 +649,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), executor->device_ordinal())); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } @@ -677,6 +678,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, BuildExecutables(versioned_handles, std::move(module_configs), execute_backend_.get(), executors)); std::vector executable_ptrs; + executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { executable_ptrs.push_back(executable.get()); } @@ -752,6 +754,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, << module_config->entry_computation_layout().ToString(); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } @@ -820,6 +823,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, << module_config->entry_computation_layout().ToString(); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index ac5f67418ed..4cde03849e9 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit( // to append dimensions on the left the broadcast_dimensions should just // be the n highest dimension 
numbers of the output shape where n is // the number of input dimensions. + broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { broadcast_dimensions.push_back(i + ShapeUtil::Rank(request.output_shape()) - diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e2e6e25c06c..7a512166171 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); std::vector values; + values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } @@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; + expected.reserve(values.size()); for (float value : values) { expected.push_back(value * value); } diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index b520d89de3c..d3d1039e1bb 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) { builder.Log(x); std::vector expected; + expected.reserve(input.size()); for (float f : input) { expected.push_back(std::log(f)); } diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2f05576ceeb..cd8f06efd82 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } std::vector param_data; + param_data.reserve(param_data_owner.size()); for (const std::unique_ptr& data : param_data_owner) { 
param_data.push_back(data.get()); } diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index d63582fb98a..82bdd6d35f0 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase { template void RunSliceTenToTwo() { std::vector constant; + constant.reserve(10); for (int i = 0; i < 10; ++i) { constant.push_back(static_cast(i)); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 4ab4c84aa56..41bac6234da 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { ComputationBuilder builder(client_, TestName()); std::vector exponents; + exponents.reserve(count); for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } @@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { auto exp = builder.Exp(x); std::vector expected; + expected.reserve(exponents.size()); for (float exponent : exponents) { expected.push_back(std::exp(exponent)); } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4c242abc9b7..8d7f7fd1237 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { client->GetComputationShape(computation).ConsumeValueOrDie(); std::vector layouts; + layouts.reserve(program_shape->parameters_size()); for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } diff --git 
a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 152e0dcf56a..2a3a8803283 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { client->GetComputationShape(computation).ConsumeValueOrDie(); std::vector layouts; + layouts.reserve(program_shape->parameters_size()); for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ffb2d5aefba..f4d46b26e65 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -74,6 +74,7 @@ StatusOr> ReplayComputation( } std::vector execute_arguments; + execute_arguments.reserve(arguments.size()); for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 1e0957298ba..3c06325651f 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -347,6 +347,7 @@ class BatchResource : public ResourceBase { // Concatenate the tasks ith input tensors into a big output tensor. 
std::vector to_concatenate; + to_concatenate.reserve(batch->num_tasks()); for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { to_concatenate.push_back(batch->task(task_idx).inputs.at(i)); } diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 809958c737e..3e924ae5f13 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) { &callback_data](std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); std::vector batch_data; + batch_data.reserve(batch->num_tasks()); for (int i = 0; i < batch->num_tasks(); ++i) { batch_data.push_back(batch->mutable_task(i)->size()); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc index 66e0995ecd0..f658532acb2 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc @@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector& vec1, std::vector GetWeightsByIndex(const std::vector& weights, const std::vector& indices) { std::vector res; + res.reserve(indices.size()); for (const int index : indices) { res.push_back(weights[index]); } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 2c6e278fec7..2871c146289 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) { } std::vector sine_wave; + sine_wave.reserve(20000); for (int i = 0; i < 20000; ++i) { sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0)); } diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc 
b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 47a5b2a2077..219473153bd 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel { ExtractFeatureData(indices_list_in, batch_size, &feature_counts, &feature_start_indices); + columns.reserve(values_list_in.size()); for (int i = 0; i < values_list_in.size(); ++i) { columns.emplace_back(new SparseTensorColumn( values_list_in[i], std::move(feature_counts[i]), diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 9bc86ef6ef8..1c4aaa5f748 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) { int size_index = 0; std::vector ptrs; + ptrs.reserve(delay); for (int i = 0; i < delay; i++) { ptrs.push_back(nullptr); } diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 4e14e6fe1a6..7b5cc1c5cba 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -123,10 +123,12 @@ void Benchmark::RunWithArgs( } // Gets inputs' and outputs' rendezvous keys. 
std::vector> in; + in.reserve(inputs.size()); for (const auto& p : inputs) { in.push_back({GetRendezvousKey(p.first), p.second}); } std::vector out; + out.reserve(outputs.size()); for (const auto& n : outputs) { out.push_back(GetRendezvousKey(n)); } diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 2e81811b7c2..dba7a9253e9 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options, // TODO(mrry): Consider providing a system-default fallback option // in this case. std::vector factory_types; + factory_types.reserve(candidate_factories.size()); for (const auto& candidate_factory : candidate_factories) { factory_types.push_back(candidate_factory.first); } diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 986b657e0e9..466b779e9b0 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -259,6 +259,7 @@ REGISTER_OP("ShapeData") } std::vector dims; + dims.reserve(shape_data->NumElements()); for (int i = 0; i < shape_data->NumElements(); ++i) { dims.emplace_back(c->MakeDim(shape_data->flat()(i))); } diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index d5e6e293d6d..035bceb640f 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -746,6 +746,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, } // Build result of different unknown dims. 
std::vector dims; + dims.reserve(rank); for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); c->set_output(0, c->MakeShape(dims)); return Status::OK(); diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index ab7dd0c5475..c3666f7ab90 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -96,6 +96,7 @@ string ResourceMgr::DebugString() const { } } std::vector text; + text.reserve(lines.size()); for (const Line& line : lines) { text.push_back(strings::Printf( "%-20s | %-40s | %-40s | %-s", line.container->c_str(), diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 2cbbf966b8a..6ad464cc61e 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -565,6 +565,7 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t, } const auto num_dims = Value(shape_dim); std::vector dims; + dims.reserve(num_dims); for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim()); return ReturnCreatedShape(dims, out); } diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 6f63937108c..a9c0303d4cb 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -783,6 +783,7 @@ TEST_F(ShapeInferenceTest, MakeShape) { std::vector dims; auto in0 = c.input(0); const int rank = c.Rank(in0); + dims.reserve(rank); for (int i = 0; i < rank; ++i) { dims.push_back(c.Dim(in0, rank - i - 1)); } diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc index de14c071b46..20a6807064b 100644 --- a/tensorflow/core/framework/shape_inference_testutil_test.cc +++ b/tensorflow/core/framework/shape_inference_testutil_test.cc @@ -51,6 +51,7 @@ string RunInferShapes(const string& 
op_name, const string& ins, ShapeInferenceTestOp op(op_name); const int num_inputs = 1 + std::count(ins.begin(), ins.end(), ';'); std::vector src_list; + src_list.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) src_list.emplace_back("a", 0, DT_FLOAT); NodeDef node_def; TF_CHECK_OK(NodeDefBuilder("dummy", op_name) diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 1d7eea2206f..19442d8c087 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -496,6 +496,7 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { void RemoveInputs(NodeDef* node_def, const std::vector& inputs_to_remove) { // TODO(skyewm): is there a better way to do this? std::vector inputs; + inputs.reserve(node_def->input_size()); for (int i = 0; i < node_def->input_size(); ++i) { inputs.push_back(node_def->input(i)); } diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 89784c631f0..68848ae8c84 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -110,6 +110,7 @@ class GraphTest : public ::testing::Test { // are readable. 
static std::vector Stringify(const std::vector& nodes) { std::vector result; + result.reserve(nodes.size()); for (Node* n : nodes) { result.push_back(n->DebugString()); } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index e37c4a5b36a..c42218e447b 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -536,6 +536,7 @@ class AddNProcessor : public AgnosticNodeProcessor { protected: std::vector GetInputPos() const override { std::vector input_pos; + input_pos.reserve(node_->input_size()); for (int i = 0; i < node_->input_size(); i++) { input_pos.push_back(i); } diff --git a/tensorflow/core/kernels/adjust_contrast_op_test.cc b/tensorflow/core/kernels/adjust_contrast_op_test.cc index 06fd7ca419b..d028c0bc591 100644 --- a/tensorflow/core/kernels/adjust_contrast_op_test.cc +++ b/tensorflow/core/kernels/adjust_contrast_op_test.cc @@ -73,6 +73,7 @@ TEST_F(AdjustContrastOpTest, Big_99x99x3) { TF_EXPECT_OK(InitOp()); std::vector values; + values.reserve(99 * 99 * 3); for (int i = 0; i < 99 * 99 * 3; ++i) { values.push_back(i % 255); } diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc index efce8101754..8992629d426 100644 --- a/tensorflow/core/kernels/dequantize_op_test.cc +++ b/tensorflow/core/kernels/dequantize_op_test.cc @@ -105,6 +105,7 @@ static void BM_DequantizeMinCombinedCpu(int iters) { auto root = Scope::NewRootScope().ExitOnError(); const int64 num_values = 1500 * 250; std::vector inputs; + inputs.reserve(num_values); for (int i = 0; i < num_values; ++i) inputs.push_back(i); ops::Dequantize(root, test::AsTensor(inputs), test::AsTensor({-1.5f}), diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc index 06765d8ee3a..861e16b2fd0 100644 --- a/tensorflow/core/kernels/dynamic_partition_op.cc +++ 
b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -104,6 +104,7 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { const auto data_flat = data->flat(); std::vector, Eigen::Aligned> > out_vec; + out_vec.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_vec.push_back(outputs[p]->vec()); } @@ -124,6 +125,7 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { // If data has extra dimensions, use Eigen slices std::vector, Eigen::Aligned> > out_flat; + out_flat.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_flat.push_back(outputs[p]->flat_outer_dims()); } diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc index dfba8e01e4e..33d73c84776 100644 --- a/tensorflow/core/kernels/fractional_max_pool_op.cc +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -245,9 +245,11 @@ class FractionalMaxPoolGradOp : public OpKernel { constexpr int tensor_in_and_out_dims = 4; std::vector input_size; std::vector output_size; + input_size.reserve(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { input_size.push_back(tensor_in.dim_size(i)); } + output_size.reserve(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { output_size.push_back(tensor_out.dim_size(i)); } diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index 23645dafad4..37c1462f10c 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -164,6 +164,7 @@ static Graph* Gather(int dim) { random::PhiloxRandom philox(301, 17); random::SimplePhilox rnd(&philox); std::vector indices_vec; + indices_vec.reserve(kLookups); for (int i = 0; i < kLookups; i++) { indices_vec.push_back(rnd.Uniform(kRows)); } diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc index 
c3a7e779403..602dfeb4e54 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc @@ -29,6 +29,7 @@ TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) { std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(i + 1); } diff --git a/tensorflow/core/kernels/mfcc_test.cc b/tensorflow/core/kernels/mfcc_test.cc index 9ab726e5b9c..07b94e2e6c2 100644 --- a/tensorflow/core/kernels/mfcc_test.cc +++ b/tensorflow/core/kernels/mfcc_test.cc @@ -26,6 +26,7 @@ TEST(MfccTest, AgreesWithPythonGoldenValues) { Mfcc mfcc; std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(i + 1); } @@ -51,6 +52,7 @@ TEST(MfccTest, AvoidsNansWithZeroInput) { Mfcc mfcc; std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(0.0); } diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index c547b166eee..901ea65bdc1 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -37,6 +37,7 @@ void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device, float input_min, int tolerance = 1) { const int values_count = values_quantized.size(); std::vector expected_values; + expected_values.reserve(values_count); for (int value_index = 0; value_index < values_count; ++value_index) { expected_values.push_back(FloatToQuantized( QuantizedToFloat(values_quantized[value_index], input_min, input_max), @@ -78,6 +79,7 @@ void TestRequantizeMany8To32Bit(float input_min, float input_max, int tolerance = 256) { const int values_count = values_quantized.size(); std::vector expected_values; + expected_values.reserve(values_count); for (int value_index = 0; value_index < values_count; 
++value_index) { expected_values.push_back(FloatToQuantized( QuantizedToFloat(values_quantized[value_index], input_min, input_max), diff --git a/tensorflow/core/kernels/sdca_ops_test.cc b/tensorflow/core/kernels/sdca_ops_test.cc index 400f330ce7b..ce50116a2d0 100644 --- a/tensorflow/core/kernels/sdca_ops_test.cc +++ b/tensorflow/core/kernels/sdca_ops_test.cc @@ -57,6 +57,7 @@ Node* Var(Graph* const g, const int n) { std::vector VarVector(Graph* const g, const int nodes, const int node_size) { std::vector result; + result.reserve(nodes); for (int i = 0; i < nodes; ++i) { result.push_back(Var(g, node_size)); } @@ -164,6 +165,7 @@ void GetGraphs(const int32 num_examples, const int32 num_sparse_feature_groups, sparse_weights.push_back(NodeBuilder::NodeOut(n)); } std::vector dense_weights; + dense_weights.reserve(dense_weight_nodes.size()); for (Node* n : dense_weight_nodes) { dense_weights.push_back(NodeBuilder::NodeOut(n)); } @@ -171,20 +173,24 @@ void GetGraphs(const int32 num_examples, const int32 num_sparse_feature_groups, std::vector sparse_example_indices; std::vector sparse_feature_indices; std::vector sparse_values; + sparse_example_indices.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_example_indices.push_back(NodeBuilder::NodeOut( SparseExampleIndices(g, sparse_features_per_group, num_examples))); } + sparse_feature_indices.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_feature_indices.push_back(NodeBuilder::NodeOut( SparseFeatureIndices(g, sparse_features_per_group, num_examples))); } + sparse_values.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_values.push_back( NodeBuilder::NodeOut(RandomZeroOrOne(g, num_examples * 4))); } std::vector dense_features; + dense_features.reserve(num_dense_feature_groups); for (int i = 0; i < num_dense_feature_groups; ++i) { dense_features.push_back(NodeBuilder::NodeOut( 
RandomZeroOrOneMatrix(g, num_examples, dense_features_per_group))); diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 4f73583ed80..4d04a206754 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -361,6 +361,7 @@ class DeserializeManySparseOp : public OpKernel { std::iota(std_order.begin(), std_order.end(), 0); std::vector tensors_to_concat; + tensors_to_concat.reserve(num_sparse_tensors); for (int i = 0; i < num_sparse_tensors; ++i) { tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i], preconcat_shape, std_order); diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 2b4d5effdad..ed93caad331 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -452,6 +452,7 @@ class SparseCrossOp : public OpKernel { ExtractFeatureData(indices_list_in, batch_size, &feature_counts, &feature_start_indices); + columns.reserve(values_list_in.size()); for (int i = 0; i < values_list_in.size(); ++i) { columns.emplace_back(new SparseTensorColumn( values_list_in[i], std::move(feature_counts[i]), diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc index f7b609191af..047e7c9e5d7 100644 --- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc +++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc @@ -463,6 +463,7 @@ class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { std::iota(std_order.begin(), std_order.end(), 0); std::vector tensors_to_concat; + tensors_to_concat.reserve(N); for (int i = 0; i < N; ++i) { tensors_to_concat.emplace_back(std::move(indices_to_concat[i]), std::move(values_to_concat[i]), diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index c3644fb98cc..387c2471ceb 100644 --- 
a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -88,6 +88,7 @@ class StageOp : public OpKernel { OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf)); core::ScopedUnref scope(buf); Buffer::Tuple tuple; + tuple.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { tuple.push_back(ctx->input(i)); } diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc index ef1d44fa944..2721885c4a7 100644 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc @@ -778,6 +778,7 @@ BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); static void BM_StdVectorFill(int iters, int len) { for (int i = 0; i < iters; i++) { std::vector v; + v.reserve(len); for (int j = 0; j < len; j++) { v.push_back(j); } @@ -810,6 +811,7 @@ static void BM_StdVectorFillString(int iters, int len) { "012345678901234567", "to cause allocation"}; for (int i = 0; i < iters; i++) { std::vector v; + v.reserve(len); for (int j = 0; j < len; j++) { v.push_back(strings[j & 3]); } diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc index bd203b9e859..547bee7b75f 100644 --- a/tensorflow/core/lib/gtl/optional_test.cc +++ b/tensorflow/core/lib/gtl/optional_test.cc @@ -1078,6 +1078,7 @@ TEST(optionalTest, NoExcept) { static_assert( !std::is_nothrow_move_constructible>::value, ""); std::vector> v; + v.reserve(10); for (int i = 0; i < 10; ++i) v.emplace_back(); } diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index e9c313e9031..325dbc48835 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -248,6 +248,7 @@ Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); std::vector dys; + dys.reserve(N); for (int i = 0; i < N; ++i) { dys.push_back(strings::StrCat("dy:", i)); } diff --git 
a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 1fa5a4ed25e..85a6cfcac91 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -613,6 +613,7 @@ REGISTER_OP("Const") TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape())); TensorShape shape(proto->tensor_shape()); std::vector dims; + dims.reserve(shape.dims()); for (int i = 0; i < shape.dims(); ++i) { dims.push_back(c->MakeDim(shape.dim_size(i))); } @@ -894,6 +895,7 @@ REGISTER_OP("MatrixDiagPart") } const int32 rank = c->Rank(in); std::vector dims; + dims.reserve(rank - 2); for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i)); DimensionHandle min_dim; diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 1be68b6000e..a7b4422bab6 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -31,6 +31,7 @@ TEST(ArrayOpsTest, Pack_ShapeFn) { auto set_axis = [&op](int axis) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Pack") .Input(src_list) @@ -281,6 +282,7 @@ TEST(ArrayOpsTest, ShapeN_ShapeFn) { ShapeInferenceTestOp op("ShapeN"); int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "ShapeN") .Input(src_list) @@ -546,6 +548,7 @@ TEST(ArrayOpsTest, Concat_ShapeFn) { ShapeInferenceTestOp op("Concat"); auto set_n = [&op](int n) { std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Concat") .Input({"concat_dim", 0, DT_INT32}) @@ -619,6 +622,7 @@ TEST(ArrayOpsTest, ConcatV2_ShapeFn) { ShapeInferenceTestOp op("ConcatV2"); auto set_n = [&op](int n) { std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) 
src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2") .Input(src_list) @@ -695,6 +699,7 @@ TEST(ArrayOpsTest, ConcatOffset_ShapeFn) { const int n = 4; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT32); TF_ASSERT_OK(NodeDefBuilder("test", "ConcatOffset") .Input({"concat_dim", 0, DT_INT32}) diff --git a/tensorflow/core/ops/control_flow_ops_test.cc b/tensorflow/core/ops/control_flow_ops_test.cc index 9aa14e27a0a..b6abafc51b8 100644 --- a/tensorflow/core/ops/control_flow_ops_test.cc +++ b/tensorflow/core/ops/control_flow_ops_test.cc @@ -28,6 +28,7 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Merge") .Input(src_list) @@ -54,6 +55,7 @@ TEST(ControlFlowOpsTest, RefSelect_ShapeFn) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 1, DT_FLOAT_REF); TF_ASSERT_OK(NodeDefBuilder("test", "RefSelect") .Input("index", 0, DT_INT32) diff --git a/tensorflow/core/ops/functional_ops_test.cc b/tensorflow/core/ops/functional_ops_test.cc index 37ee301c3bd..64b5ccea5a8 100644 --- a/tensorflow/core/ops/functional_ops_test.cc +++ b/tensorflow/core/ops/functional_ops_test.cc @@ -33,6 +33,7 @@ TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) { in_type_list.emplace_back(DT_FLOAT); src_list.emplace_back("a", 0, DT_FLOAT); } + out_type_list.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { out_type_list.emplace_back(DT_FLOAT); } diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 31bbe916f43..c10e667f564 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -27,6 +27,7 @@ TEST(MathOpsTest, AddN_ShapeFn) { ShapeInferenceTestOp op("AddN"); auto set_n = [&op](int n) { std::vector 
src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "AddN") .Input(src_list) diff --git a/tensorflow/core/ops/sparse_ops_test.cc b/tensorflow/core/ops/sparse_ops_test.cc index b3ee92fa21e..21b27346889 100644 --- a/tensorflow/core/ops/sparse_ops_test.cc +++ b/tensorflow/core/ops/sparse_ops_test.cc @@ -255,6 +255,7 @@ TEST(SparseOpsTest, SparseConcat_ShapeFn) { ShapeInferenceTestOp op("SparseConcat"); std::vector src_list; int n = 2; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT64); TF_ASSERT_OK(NodeDefBuilder("test", "SparseConcat") .Input(src_list) diff --git a/tensorflow/core/ops/string_ops_test.cc b/tensorflow/core/ops/string_ops_test.cc index 79130bae2c0..f4d3adbb2a3 100644 --- a/tensorflow/core/ops/string_ops_test.cc +++ b/tensorflow/core/ops/string_ops_test.cc @@ -27,6 +27,7 @@ TEST(StringOpsTest, StringJoin_ShapeFn) { ShapeInferenceTestOp op("StringJoin"); int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_STRING); TF_ASSERT_OK(NodeDefBuilder("test", "StringJoin") .Input(src_list) diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index aced1aa8baf..232dcb3e71a 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -25,6 +25,7 @@ typedef std::vector> ExpectedCalls; ExpectedCalls CreateRetriableErrors(const string& method, int n) { ExpectedCalls expected_calls; + expected_calls.reserve(n); for (int i = 0; i < n; i++) { expected_calls.emplace_back(std::make_tuple( method, errors::Unavailable(strings::StrCat("Retriable error #", i)))); diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index 62025463af7..c86a70ec9d0 100644 --- 
a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -27,6 +27,7 @@ namespace { std::vector CharPointerVectorFromStrings( const std::vector &strings) { std::vector result; + result.reserve(strings.size()); for (const string &string : strings) { result.push_back(const_cast(string.c_str())); } diff --git a/tensorflow/core/util/ctc/ctc_beam_search_test.cc b/tensorflow/core/util/ctc/ctc_beam_search_test.cc index 32fbb6802f5..217c7ce1f6b 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search_test.cc +++ b/tensorflow/core/util/ctc/ctc_beam_search_test.cc @@ -150,6 +150,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) { // using Eigen::Map. Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } @@ -199,6 +200,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) { // using Eigen::Map. Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } @@ -293,6 +295,7 @@ TEST(CtcBeamSearch, LabelSelection) { // using Eigen::Map. 
Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index 04bcbddde46..2931b8c378c 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -182,6 +182,7 @@ std::vector RunCppShapeInference( std::vector input_constant_tensor_values_v; int cnt = PyList_Size(input_constant_tensor_values); + input_constant_tensor_values_v.reserve(cnt); for (int i = 0; i < cnt; ++i) { input_constant_tensor_values_v.push_back( PyList_GetItem(input_constant_tensor_values, i)); diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 9b2d7618837..c48296eccb0 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -347,6 +347,7 @@ Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret) { PyArray_Descr* descr = PyArray_DescrFromType(typenum); CHECK(descr); std::vector dims; + dims.reserve(t.dims()); for (int i = 0; i < t.dims(); ++i) { dims.push_back(t.dim_size(i)); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index e1674745c84..ec6919f9784 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -2942,6 +2942,7 @@ bool CudnnSupport::DoMatMul(Stream* stream, } const auto toPtrs = [](std::vector>& v) { std::vector*> ptrs; + ptrs.reserve(v.size()); for (auto& mem : v) { ptrs.push_back(&mem); } diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index f8ff5ece36b..de2e0de6d99 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ 
b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -80,6 +80,7 @@ void PrintBenchmarkUsage(const std::vector& placeholders, shape = PartialTensorShape(shape_proto); } } + sizes.reserve(shape.dims()); for (int i = 0; i < shape.dims(); ++i) { sizes.push_back(shape.dim_size(i)); } @@ -87,6 +88,7 @@ void PrintBenchmarkUsage(const std::vector& placeholders, input_layer_shapes.push_back(sizes_string); } std::vector output_layers; + output_layers.reserve(outputs.size()); for (const NodeDef* node : outputs) { output_layers.push_back(node->name()); } diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc index 1586e1ba412..62e29b5128f 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc @@ -634,6 +634,7 @@ void Generator::AppendDebugStringFunctions(const Descriptor& md) { Print().Print("namespace internal {").Print(); Print(sig, " {").Nest(); std::vector fields; + fields.reserve(md.field_count()); for (int i = 0; i < md.field_count(); ++i) { fields.push_back(md.field(i)); } From f60b6bdcb59f5538f3301207eabc30c10a9b6d46 Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Thu, 1 Jun 2017 10:13:32 -0700 Subject: [PATCH 11/72] Add a warning to documentation of MonitoredSession. PiperOrigin-RevId: 157728225 --- tensorflow/python/training/monitored_session.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index b8554abb4ff..ff77470a824 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -279,7 +279,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. 
For workers, this utility sets proper session creator which waits for the chief to - initialize/restore. + initialize/restore. Please check `tf.train.MonitoredSession` for more + information. Args: @@ -633,6 +634,12 @@ class MonitoredSession(_MonitoredSession): See `MonitoredTrainingSession` for an example usage based on chief or worker. + Note: This is not a `tf.Session`. For example, it cannot do the following: + + * it cannot be set as default session. + * it cannot be sent to saver.save. + * it cannot be sent to tf.train.start_queue_runners. + Args: session_creator: A factory object to create session. Typically a `ChiefSessionCreator` which is the default one. From 2b75a9a6ea3ad646f64f70f99ccbb070a860e64a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 10:56:16 -0700 Subject: [PATCH 12/72] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 157734029 --- tensorflow/go/op/wrappers.go | 89 +++++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 16 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 476e4af1d2f..72c28f95dfd 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -208,6 +208,16 @@ func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWith } } +// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. 
// // Arguments: @@ -254,16 +264,26 @@ func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { } } +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // // and `max` to 'outputs' tensor of same shape as `inputs`. // -// [min; max] is the clamping range for the 'inputs' data. Op divides this range -// into 255 steps (total of 256 values), then replaces each 'inputs' value with the -// closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// This operation has a gradient and thus allows for training `min` and `max` values. +// This operation has a gradient and thus allows for training `min` and `max` +// values. func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return @@ -2149,7 +2169,7 @@ func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { // dimension. Must sum to the dimension of value along split_dim. // Can contain one -1 indicating that dimension is to be inferred. // split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(value))`. +// `[-rank(value), rank(value))`. 
// // // Returns Tensors whose shape matches that of `value` @@ -2184,7 +2204,7 @@ func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, split_dim tf.O // // Arguments: // split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(value))`. +// `[-rank(value), rank(value))`. // value: The tensor to split. // num_split: The number of ways to split. Must evenly divide // `value.shape[split_dim]`. @@ -3325,12 +3345,21 @@ func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr { } } +// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. // -// Attributes [min; max] define the clamping range for the 'inputs' data. Op -// divides this range into 255 steps (total of 256 values), then replaces each -// 'inputs' value with the closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// Attributes `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // // Quantization is called fake since the output is still in floating point. 
func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) { @@ -6410,6 +6439,14 @@ func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgs } } +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxArgs operation. // // Arguments: @@ -8601,17 +8638,27 @@ func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVa } } +// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, // // `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` // to 'outputs' tensor of same shape as `inputs`. // -// [min; max] is the clamping range for the 'inputs' data in the corresponding -// depth channel. Op divides this range into 255 steps (total of 256 values), then -// replaces each 'inputs' value with the closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. 
// // This operation has a gradient and thus allows for training `min` and `max` values. +// This operation has a gradient and thus allows for training `min` and `max` +// values. func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { if scope.Err() != nil { return @@ -21779,6 +21826,16 @@ func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVars } } +// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxVars operation. // // Arguments: From 02ac85399d4fb35d5055ecf426632b9446a70041 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 11:30:36 -0700 Subject: [PATCH 13/72] Introduce new class Literal to replace protobuf Literal. This renames the existing Literal message to LiteralProto and introduces a new C++ class named Literal to replace it. The LiteralProto is only used at RPC boundaries, or when protobuf-specific functionality is required. The Literal class offers a 'ToProto' function to generate a new LiteralProto message when necessary. Currently, all the static functions in class LiteralUtil, just forward to their counterparts in class Literal. This will change in a future CL. Class Literal implements all the buffers as std::vectors. The only exception is preds(), which given the std::vector<bool> representation, makes it unusable for the semantics we require (it's not possible to get the address of the underlying vector, for instance). The CL adds a BoolVector class to work around that issue. 
In future CLs, the std::vector representation may be changed to something more efficient, if needed. PiperOrigin-RevId: 157739125 --- tensorflow/compiler/tf2xla/literal_util.h | 1 + tensorflow/compiler/xla/client/client.cc | 10 +- tensorflow/compiler/xla/client/client.h | 1 + .../xla/client/computation_builder.cc | 7 +- .../compiler/xla/client/local_client.cc | 11 +- tensorflow/compiler/xla/literal_util.cc | 913 ++++++----- tensorflow/compiler/xla/literal_util.h | 1433 ++++++++++++----- .../compiler/xla/packed_literal_reader.cc | 4 +- .../compiler/xla/packed_literal_reader.h | 1 + tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/dfs_hlo_visitor.h | 1 + .../service/dfs_hlo_visitor_with_default.h | 1 + tensorflow/compiler/xla/service/hlo.proto | 2 +- .../compiler/xla/service/hlo_instruction.cc | 8 +- .../compiler/xla/service/hlo_instruction.h | 3 +- .../compiler/xla/service/llvm_ir/llvm_util.h | 1 + tensorflow/compiler/xla/service/service.cc | 31 +- tensorflow/compiler/xla/service/session.proto | 4 +- .../compiler/xla/service/transfer_manager.h | 1 + .../xla/service/transfer_manager_test.cc | 2 +- .../compiler/xla/service/user_computation.cc | 2 +- .../xla/service/user_computation_test.cc | 7 +- .../xla/tests/client_library_test_base.cc | 2 +- .../compiler/xla/tests/literal_test_util.cc | 4 +- .../xla/tests/literal_test_util_test.cc | 5 +- tensorflow/compiler/xla/text_literal_reader.h | 1 + tensorflow/compiler/xla/text_literal_writer.h | 1 + .../compiler/xla/tools/replay_computation.cc | 5 +- tensorflow/compiler/xla/tools/show_literal.cc | 7 +- tensorflow/compiler/xla/xla.proto | 8 +- tensorflow/compiler/xla/xla_data.proto | 6 +- 31 files changed, 1619 insertions(+), 865 deletions(-) diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index e8b2233853d..fe08e83c239 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,6 +18,7 @@ limitations under 
the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 6b38f856442..454d0fbd965 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -58,14 +58,13 @@ StatusOr> Client::Transfer( "server provided response without a literal in " "TransferToClient request"); } - - return WrapUnique(response.release_literal()); + return MakeUnique(response.literal()); } StatusOr> Client::TransferToServer( const Literal& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -93,7 +92,7 @@ StatusOr> Client::TransferToServer( Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -141,7 +140,8 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - return WrapUnique(response.release_literal()); + Literal literal(response.literal()); + return MakeUnique(literal); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 50de730a52b..797835160fa 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 22a70681468..940d38c44e7 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp( } ConstantRequest request; - Literal* literal = request.mutable_literal(); - populate(literal); - VLOG(3) << "created constant: " << literal->ShortDebugString(); + Literal literal; + populate(&literal); + *request.mutable_literal() = literal.ToProto(); + VLOG(3) << "created constant: " << request.literal().ShortDebugString(); OpRequest op_request; *op_request.mutable_constant_request() = request; *op_request.mutable_computation() = computation_.handle(); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 6f2914b4718..96944a53b7e 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_RETURN_IF_ERROR( - LiteralFromShapedBuffer(*argument, session_module->add_arguments())); + Literal literal; + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); + *session_module->add_arguments() = literal.ToProto(); } return tensorflow::Status::OK(); } @@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments( tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* 
result, SessionModule* session_module) { session_module->clear_result(); - return LiteralFromShapedBuffer(*result, session_module->mutable_result()); + Literal literal(session_module->result()); + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); + *session_module->mutable_result() = literal.ToProto(); + return tensorflow::Status::OK(); } +// TODO(dnovillo) Change signature to return StatusOr. tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer, Literal* literal) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index ec4012a7036..5162c2b0cc3 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -36,7 +36,7 @@ limitations under the License. namespace xla { -LiteralUtil::StrideConfig::StrideConfig( +Literal::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice dimensions) : dimensions(dimensions), @@ -59,30 +59,28 @@ LiteralUtil::StrideConfig::StrideConfig( } } -/* static */ std::unique_ptr LiteralUtil::CreateFromShape( - const Shape& shape) { +std::unique_ptr Literal::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(); *literal->mutable_shape() = shape; - Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); + literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( +/* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } template -/* static */ Status LiteralUtil::CopyRange( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status Literal::CopyRange(const 
Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { const Shape& src_shape = src_literal.shape(); - const Shape& dest_shape = dest_literal->shape(); - tensorflow::gtl::ArraySlice src_data = GetArraySlice(src_literal); - tensorflow::gtl::MutableArraySlice dest_data = - GetMutableArraySlice(dest_literal); + const Shape& dest_shape = shape(); + tensorflow::gtl::ArraySlice src_data = src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = GetMutableArraySlice(); TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); @@ -90,8 +88,8 @@ template // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); - StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data, - LinearIndex(src_literal, src_base), 0, 1); + StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data, + src_literal.LinearIndex(src_base), 0, 1); } else if (!ShapeUtil::HasZeroElements(dest_shape)) { TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -113,8 +111,8 @@ template std::transform(indexes.begin(), indexes.end(), dest_base.begin(), dest_indexes.begin(), std::plus()); - int64 src_index = LinearIndex(src_literal, src_indexes); - int64 dest_index = LinearIndex(*dest_literal, dest_indexes); + int64 src_index = src_literal.LinearIndex(src_indexes); + int64 dest_index = LinearIndex(dest_indexes); StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, src_index, stride_config.source_stride, @@ -129,37 +127,28 @@ template return Status::OK(); } -/* static */ Status LiteralUtil::Copy( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, - 
tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK( - ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape())); +Status Literal::Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); switch (src_literal.shape().element_type()) { case U32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case U64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case S32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case S64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F16: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case PRED: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); default: break; } @@ -167,28 +156,28 @@ template src_literal.shape().element_type()); } -/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { +/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(0); + return 
*Literal::CreateR0(0); case U32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case U64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S8: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case F16: - return *LiteralUtil::CreateR0(static_cast(0.0f)); + return *Literal::CreateR0(static_cast(0.0f)); case F32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case F64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -201,31 +190,31 @@ template } } -/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { +/* static */ Literal Literal::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case PRED: - return *LiteralUtil::CreateR0(true); + return *Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0(static_cast(1.0f)); + return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 
1"; case OPAQUE: @@ -235,33 +224,32 @@ template } } -/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case F32: - return *LiteralUtil::CreateR0( - -std::numeric_limits::infinity()); + return *Literal::CreateR0(-std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( -std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; @@ -272,33 +260,32 @@ template } } -/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); 
+ return *Literal::CreateR0(std::numeric_limits::max()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case F32: - return *LiteralUtil::CreateR0( - std::numeric_limits::infinity()); + return *Literal::CreateR0(std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(true); + return *Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; @@ -309,14 +296,14 @@ template } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( +/* static */ std::unique_ptr Literal::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique(); *literal->mutable_shape() = @@ -325,150 +312,145 @@ template return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, + float to, + int64 rows, + int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return 
CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::Relayout( - const Literal& original, const Layout& layout) { - std::unique_ptr result = CloneToUnique(original); +std::unique_ptr Literal::Relayout(const Layout& layout) const { + std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; - const Shape& shape = original.shape(); - DimensionVector base(ShapeUtil::Rank(shape), 0); - DimensionVector copy_size(shape.dimensions().begin(), - shape.dimensions().end()); + DimensionVector base(ShapeUtil::Rank(shape()), 0); + DimensionVector copy_size(shape().dimensions().begin(), + shape().dimensions().end()); - TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size)); + TF_CHECK_OK(result->Copy(*this, base, base, copy_size)); return result; } -/* static */ StatusOr> LiteralUtil::Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice dimensions) { - if (ShapeUtil::IsTuple(input.shape())) { +StatusOr> Literal::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (ShapeUtil::IsTuple(shape())) { return InvalidArgument("Reshape does not support tuples."); } std::unique_ptr output; - if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { - std::vector minor_to_major(ShapeUtil::Rank(input.shape())); + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + std::vector minor_to_major(ShapeUtil::Rank(shape())); std::iota(minor_to_major.rbegin(), minor_to_major.rend(), static_cast(0)); - output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major)); + output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); } else { - output = CloneToUnique(input); + output = CloneToUnique(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
*output->mutable_shape() = - ShapeUtil::MakeShape(input.shape().element_type(), dimensions); + ShapeUtil::MakeShape(shape().element_type(), dimensions); - int64 elements_before = ShapeUtil::ElementsIn(input.shape()); + int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); if (elements_before != elements_after) { return InvalidArgument( - "Shapes before and after LiteralUtil::Reshape have different numbers " + "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(input.shape()).c_str(), + ShapeUtil::HumanString(shape()).c_str(), ShapeUtil::HumanString(output->shape()).c_str()); } return std::move(output); } -/* static */ std::unique_ptr LiteralUtil::Transpose( - const Literal& original, tensorflow::gtl::ArraySlice permutation) { - CHECK(!ShapeUtil::IsTuple(original.shape())) - << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape()))) +std::unique_ptr Literal::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. // This is considerably faster than iterating over every array element using // the EachCell<>() and Set<>() APIs. 
std::vector inverse_permutation = InversePermutation(permutation); - Shape shape = - ShapeUtil::PermuteDimensions(inverse_permutation, original.shape()); - // Replace the layout with one affine to the original shape, such that a + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a // transpose operation can be performed by leaving the flat values // representation intact. // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leave the original most minor (the 8 sized), the + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // most minor. // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. 
- Layout* layout = shape.mutable_layout(); + Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); - for (auto index : original.shape().layout().minor_to_major()) { + for (auto index : shape().layout().minor_to_major()) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(shape); + std::unique_ptr new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), - ShapeUtil::ByteSizeOf(original.shape())); - std::memcpy(MutableInternalData(new_literal.get()), InternalData(original), - ShapeUtil::ByteSizeOf(original.shape())); + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->MutableInternalData(), InternalData(), + ShapeUtil::ByteSizeOf(shape())); return new_literal; } -/* static */ std::unique_ptr LiteralUtil::Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - CHECK(!ShapeUtil::IsTuple(literal.shape())) - << "tuple is not supported for reshape"; +std::unique_ptr Literal::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); int64 dimension = limit_indices[dnum] - start_indices[dnum]; CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } const auto result_shape = ShapeUtil::MakeShapeWithLayout( - literal.shape().element_type(), result_dimensions, - AsInt64Slice(literal.shape().layout().minor_to_major())); + shape().element_type(), result_dimensions, + 
AsInt64Slice(shape().layout().minor_to_major())); auto result_literal = MakeUnique(); *result_literal->mutable_shape() = result_shape; - Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); + result_literal->Reserve(ShapeUtil::ElementsIn(result_shape)); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - float value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + float value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case S32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - int32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + int32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case U32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - uint32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + uint32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; default: @@ -477,98 +459,95 @@ template } } -/* static */ std::unique_ptr LiteralUtil::CloneToUnique( - const Literal& literal) { +std::unique_ptr Literal::CloneToUnique() const { auto unique = MakeUnique(); - *unique = literal; + 
*unique = *this; return unique; } -/* static */ string LiteralUtil::GetAsString( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - switch (literal.shape().element_type()) { +string Literal::GetAsString( + tensorflow::gtl::ArraySlice multi_index) const { + switch (shape().element_type()) { case PRED: - return Get(literal, multi_index) ? "true" : "false"; + return Get(multi_index) ? "true" : "false"; case U8: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F16: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); default: return tensorflow::strings::StrCat( - "[", PrimitiveType_Name(literal.shape().element_type()), "]"); + "[", PrimitiveType_Name(shape().element_type()), "]"); } } -/* static */ int64 LiteralUtil::LinearIndex( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), - multi_index); +int64 Literal::LinearIndex( + tensorflow::gtl::ArraySlice multi_index) const { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), 
multi_index); } -/* static */ string LiteralUtil::ToString(const Literal& literal) { - const Shape& shape = literal.shape(); +string Literal::ToString() const { std::vector pieces; auto element_to_string = - [&literal](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = literal.shape().element_type(); + [this](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = shape().element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. - return Get(literal, indices) ? "1" : "0"; + return Get(indices) ? "1" : "0"; } return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - GetAsString(literal, indices); + GetAsString(indices); }; // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(shape)) { - pieces.push_back(ShapeUtil::HumanString(shape)); + if (ShapeUtil::IsTuple(shape())) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" (\n"); - for (const auto& element_literal : literal.tuple_literals()) { - pieces.push_back(ToString(element_literal)); + for (const auto& element_literal : tuple_literals()) { + pieces.push_back(element_literal.ToString()); pieces.push_back(",\n"); } pieces.push_back(")"); - } else if (ShapeUtil::Rank(shape) == 0) { - pieces.push_back(GetAsString(literal, {})); - } else if (ShapeUtil::Rank(shape) == 1) { + } else if (ShapeUtil::Rank(shape()) == 0) { + pieces.push_back(GetAsString({})); + } else if (ShapeUtil::Rank(shape()) == 1) { pieces.push_back("{"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(element_to_string({i0})); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 2) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); 
++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(element_to_string({i0, i1})); } pieces.push_back(" "); pieces.push_back("},\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 3) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(element_to_string({i0, i1, i2})); } pieces.push_back(" }"); @@ -576,17 +555,17 @@ template pieces.push_back(" }"); } pieces.push_back("\n}"); - } else if (ShapeUtil::Rank(shape) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 4) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(" {"); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(element_to_string({i0, 
i1, i2, i3})); } pieces.push_back("},\n"); @@ -596,20 +575,20 @@ template pieces.push_back(" },\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 5) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back( tensorflow::strings::Printf(" { // i2=%lld\n", i2)); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(" {"); - for (int64 i4 = 0; i4 < shape.dimensions(4); ++i4) { + for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); } pieces.push_back("},\n"); @@ -622,14 +601,14 @@ template } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {...}"); } return tensorflow::str_util::Join(pieces, ""); } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { auto literal = MakeUnique(); std::vector shape; @@ -641,136 +620,137 @@ template return literal; } -/* static */ const void* LiteralUtil::InternalData(const Literal& literal) { - switch (literal.shape().element_type()) { +const void* Literal::InternalData() const { + return const_cast( + const_cast(this)->MutableInternalData()); +} + +void* Literal::MutableInternalData() { + 
// NOTE: We access the vectors directly to avoid the const reference + // created by the accessor functions. + switch (shape().element_type()) { case PRED: - return reinterpret_cast(literal.preds().data()); + return reinterpret_cast(preds_.data()); case U8: - return reinterpret_cast(literal.u8s().data()); + return reinterpret_cast(u8s_.data()); case S32: - return reinterpret_cast(literal.s32s().data()); + return reinterpret_cast(s32s_.data()); case S64: - return reinterpret_cast(literal.s64s().data()); + return reinterpret_cast(s64s_.data()); case U32: - return reinterpret_cast(literal.u32s().data()); + return reinterpret_cast(u32s_.data()); case U64: - return reinterpret_cast(literal.u64s().data()); + return reinterpret_cast(u64s_.data()); case F32: - return reinterpret_cast(literal.f32s().data()); + return reinterpret_cast(f32s_.data()); case F64: - return reinterpret_cast(literal.f64s().data()); + return reinterpret_cast(f64s_.data()); case F16: - return reinterpret_cast(literal.f16s().data()); + return reinterpret_cast(f16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal.shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ void* LiteralUtil::MutableInternalData(Literal* literal) { - return const_cast(LiteralUtil::InternalData(*literal)); -} - -/* static */ void LiteralUtil::Reserve(int64 num_elements, Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - switch (literal->shape().element_type()) { +void Literal::Reserve(int64 num_elements) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + switch (shape().element_type()) { case PRED: - Resize(num_elements, false, literal); + Resize(num_elements, false); break; case S8: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U8: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case S32: - Resize(num_elements, 0, literal); + 
Resize(num_elements, 0); break; case S64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U32: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F32: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F16: - Resize(num_elements, static_cast(0.0f), literal); + Resize(num_elements, static_cast(0.0f)); break; default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal->shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ tensorflow::Status LiteralUtil::ValidateLiteral( - const Literal& literal) { - TF_CHECK_OK(ShapeUtil::ValidateShape(literal.shape())); - int64 expected = ShapeUtil::ElementsIn(literal.shape()); +tensorflow::Status Literal::ValidateLiteral() const { + TF_CHECK_OK(ShapeUtil::ValidateShape(shape())); + int64 expected = ShapeUtil::ElementsIn(shape()); int64 actual = -1; - switch (literal.shape().element_type()) { + switch (shape().element_type()) { case PRED: - actual = literal.preds().size(); + actual = preds_size(); break; case U8: - actual = literal.u8s().size(); + actual = u8s_size(); break; case S32: - actual = literal.s32s_size(); + actual = s32s_size(); break; case U32: - actual = literal.u32s_size(); + actual = u32s_size(); break; case S64: - actual = literal.s64s_size(); + actual = s64s_size(); break; case U64: - actual = literal.u64s_size(); + actual = u64s_size(); break; case F32: - actual = literal.f32s_size(); + actual = f32s_size(); break; case F64: - actual = literal.f64s_size(); + actual = f64s_size(); break; case F16: - actual = literal.f16s().size() / sizeof(half); + actual = f16s().size() / sizeof(half); break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + - 
PrimitiveType_Name(literal.shape().element_type())); + PrimitiveType_Name(shape().element_type())); } if (expected != actual) { return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( "literal has bad number of elements for its shape %s: want %lld " "got %lld", - ShapeUtil::HumanString(literal.shape()).c_str(), expected, actual)); + ShapeUtil::HumanString(shape()).c_str(), expected, actual)); } return tensorflow::Status::OK(); } -/* static */ void LiteralUtil::EachCellAsString( - const Literal& literal, +void Literal::EachCellAsString( const std::function indices, - const string& value)>& per_cell) { - if (ShapeUtil::HasZeroElements(literal.shape())) { + const string& value)>& per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - literal.shape(), /*linear_index=*/0); + shape(), /*linear_index=*/0); do { - per_cell(indices, GetAsString(literal, indices)); - } while (IndexUtil::BumpIndices(literal.shape(), &indices)); + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } namespace { @@ -784,8 +764,8 @@ template bool EqualElements(const Literal& literal1, const Literal& literal2, int dimension, std::vector* multi_index) { if (dimension == ShapeUtil::Rank(literal1.shape())) { - return (LiteralUtil::Get(literal1, *multi_index) == - LiteralUtil::Get(literal2, *multi_index)); + return (literal1.Get(*multi_index) == + literal2.Get(*multi_index)); } for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { (*multi_index)[dimension] = i; @@ -799,219 +779,197 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } // namespace -/* static */ bool LiteralUtil::Equal(const Literal& literal1, - const Literal& literal2) { - if (!ShapeUtil::Compatible(literal1.shape(), literal2.shape())) { +bool Literal::Equal(const Literal& literal2) const { + if (!ShapeUtil::Compatible(shape(), literal2.shape())) { 
return false; } - if (ShapeUtil::IsTuple(literal1.shape())) { + if (ShapeUtil::IsTuple(shape())) { // Because the shapes are compatible, they must have the same number of // tuple elements. - CHECK_EQ(literal1.tuple_literals_size(), literal2.tuple_literals_size()); - for (int i = 0; i < literal1.tuple_literals_size(); ++i) { - if (!Equal(literal1.tuple_literals(i), literal2.tuple_literals(i))) { + CHECK_EQ(tuple_literals_size(), literal2.tuple_literals_size()); + for (int i = 0; i < tuple_literals_size(); ++i) { + if (!tuple_literals(i).Equal(literal2.tuple_literals(i))) { return false; } } return true; } else { - std::vector multi_index(ShapeUtil::Rank(literal1.shape()), 0); - switch (literal1.shape().element_type()) { + std::vector multi_index(ShapeUtil::Rank(shape()), 0); + switch (shape().element_type()) { case PRED: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U8: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F16: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); default: - 
LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " - << PrimitiveType_Name(literal1.shape().element_type()); + LOG(FATAL) << "Unimplemented: Literal::Equal for type " + << PrimitiveType_Name(shape().element_type()); } } } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_preds(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_preds(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. - auto values = literal->mutable_u8s(); + auto values = mutable_u8s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. 
- auto values = literal->mutable_u8s(); + auto values = mutable_u8s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_s32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_s32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_u32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_u32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && alignof(int64) == alignof(tensorflow::protobuf_int64), "The int64 and tensorflow::protobuf_int64 types are not " "compatible"); - auto values = literal->mutable_s64s(); + auto values = mutable_s64s(); // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is // necessary from the raw data pointer returned by the mutable_data() API. 
return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->mutable_data()), values->size()); + reinterpret_cast(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && alignof(uint64) == alignof(tensorflow::protobuf_uint64), "The uint64 and tensorflow::protobuf_uint64 types are not " "compatible"); - auto values = literal->mutable_u64s(); + auto values = mutable_u64s(); // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t // while tensorflow::uint64 is defined as unsigned long long, a // reinterpret_cast<> is necessary from the raw data pointer returned by the // mutable_data() API. return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->mutable_data()), values->size()); + reinterpret_cast(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_f32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_f32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_f64s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_f64s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice 
Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. // TODO - there is an endianess problem here. fix it, or wait for uint16 // support in protobuf - auto values = literal->mutable_f16s(); + auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), PRED); - return literal.preds(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), PRED); + return tensorflow::gtl::ArraySlice(preds().data(), preds().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U8); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U8); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.u8s().data()), - literal.u8s().size()); + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S8); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S8); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.u8s().data()), - literal.u8s().size()); + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U32); - return literal.u32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U32); + return u32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice 
-LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U64); - return AsUInt64Slice(literal.u64s()); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U64); + return u64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S32); - return literal.s32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S32); + return s32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S64); - return AsInt64Slice(literal.s64s()); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S64); + return s64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), F64); - return literal.f64s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F64); + return f64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), F16); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F16); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.f16s().data()), - literal.f16s().size() / sizeof(half)); + reinterpret_cast(f16s().data()), + f16s().size() / sizeof(half)); } template @@ -1019,48 +977,48 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { auto multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - if (LiteralUtil::Get(literal, multi_index) != value) { + if (literal.Get(multi_index) != 
value) { return false; } } return true; } -/* static */ bool LiteralUtil::IsAll(const Literal& literal, int8 value) { - switch (literal.shape().element_type()) { +bool Literal::IsAll(int8 value) const { + switch (shape().element_type()) { case U8: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U32: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U64: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case S8: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F16: - return AllElementsEqualValue(literal, static_cast(value)); + return AllElementsEqualValue(*this, static_cast(value)); case PRED: if (value == 0) { - return AllElementsEqualValue(literal, false); + return AllElementsEqualValue(*this, false); } if (value == 1) { - return AllElementsEqualValue(literal, true); + return AllElementsEqualValue(*this, true); } return false; default: @@ -1068,119 +1026,218 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { } } -/* static */ bool LiteralUtil::IsAllFloat(const Literal& literal, float value) { - switch (literal.shape().element_type()) { +bool Literal::IsAllFloat(float value) const { + switch (shape().element_type()) { case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return 
AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F16: - return AllElementsEqualValue(literal, static_cast(value)); + return AllElementsEqualValue(*this, static_cast(value)); default: return false; } } -/* static */ bool LiteralUtil::IsZero( - const Literal& literal, tensorflow::gtl::ArraySlice indices) { - switch (literal.shape().element_type()) { +bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { + switch (shape().element_type()) { case U8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case F32: - return Get(literal, indices) == 0.0f; + return Get(indices) == 0.0f; case F64: - return Get(literal, indices) == 0.0; + return Get(indices) == 0.0; case F16: - return Get(literal, indices) == static_cast(0.0f); + return Get(indices) == static_cast(0.0f); case PRED: - return Get(literal, indices) == false; + return Get(indices) == false; default: LOG(FATAL) << "Input literal must be an array."; } } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_preds()->Resize(num_elements, value); +/* static */ void Literal::Resize(int64 num_elements, bool value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_preds()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u8s()->resize(num_elements, value); +void 
Literal::Resize(int64 num_elements, int8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u8s()->resize(num_elements, value); +void Literal::Resize(int64 num_elements, uint8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_s32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, int32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_s32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, uint32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_s64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, int64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_s64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, 
uint64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, float value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, float value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, double value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, double value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, half value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f16s()->resize(num_elements * sizeof(half)); - auto data = GetMutableArraySlice(literal); +void Literal::Resize(int64 num_elements, half value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice(); for (int i = 0; i < num_elements; i++) { data[i] = value; } } +template +static void CopyToRepeatedField(proto2::RepeatedField* dest, + const std::vector& src) { + *dest = proto2::RepeatedField(src.begin(), src.end()); +} + +LiteralProto Literal::ToProto() const { + LiteralProto proto; + proto.Clear(); + *proto.mutable_shape() = shape(); + switch (shape().element_type()) { + case PRED: + if (preds().begin()) { + *proto.mutable_preds() = + proto2::RepeatedField(preds().begin(), preds().end()); + } + break; + case U8: + *proto.mutable_u8s() = u8s_string(); + break; + case S32: + 
CopyToRepeatedField(proto.mutable_s32s(), s32s()); + break; + case S64: + CopyToRepeatedField(proto.mutable_s64s(), s64s()); + break; + case U32: + CopyToRepeatedField(proto.mutable_u32s(), u32s()); + break; + case U64: + CopyToRepeatedField(proto.mutable_u64s(), u64s()); + break; + case F16: + *proto.mutable_f16s() = + string(reinterpret_cast(f16s_.data()), + f16s_.size() / sizeof(half)); + break; + case F32: + CopyToRepeatedField(proto.mutable_f32s(), f32s()); + break; + case F64: + CopyToRepeatedField(proto.mutable_f64s(), f64s()); + break; + case TUPLE: + for (const auto& tuple : tuple_literals()) { + *proto.add_tuple_literals() = tuple.ToProto(); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } + + return proto; +} + +template +static void CopyFromRepeatedField(std::vector* dest, + const proto2::RepeatedField& src) { + *dest = std::vector(src.begin(), src.end()); +} + +void Literal::CopyFromProto(const LiteralProto& literal_proto) { + if (!literal_proto.has_shape()) { + return; + } + + *mutable_shape() = literal_proto.shape(); + switch (shape().element_type()) { + case PRED: + *mutable_preds() = BoolVector(literal_proto.preds().begin(), + literal_proto.preds().end()); + break; + case U8: + set_u8s(literal_proto.u8s()); + break; + case S32: + CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + break; + case S64: + CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + break; + case U32: + CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + break; + case U64: + CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + break; + case F16: { + const string& s(literal_proto.f16s()); + CHECK_EQ(0, s.size() % sizeof(half)); + f16s_ = std::vector(s.size() / sizeof(half)); + memcpy(f16s_.data(), s.data(), s.size() / sizeof(half)); + break; + } + case F32: + CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + break; + case F64: + CopyFromRepeatedField(mutable_f64s(), 
literal_proto.f64s()); + break; + case TUPLE: + for (const auto& proto : literal_proto.tuple_literals()) { + mutable_tuple_literals()->push_back(Literal(proto)); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a05dc968ee5..31f08150ef8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -48,15 +49,210 @@ limitations under the License. namespace xla { +// This class is a simple vector of boolean values. It's used to workaround some +// implementations of std::vector that use a bitset which does not have +// the semantics expected by Literal::preds(). +class BoolVector { + public: + typedef bool* iterator; + typedef const bool* const_iterator; + + BoolVector() : bits_(nullptr), size_(0), capacity_(0) {} + + BoolVector(const_iterator other_begin, const_iterator other_end) + : bits_(nullptr), size_(0), capacity_(0) { + if (other_begin && other_end) { + resize(other_end - other_begin + 1); + memcpy(begin(), other_begin, size()); + } + } + + BoolVector(const BoolVector& other) { CopyFrom(other); } + + BoolVector& operator=(const BoolVector& other) { + CopyFrom(other); + return *this; + } + + void push_back(const bool& value) { + resize(size_ + 1); + bits_[size_ - 1] = value; + } + + bool* data() const { return bits_.get(); } + + size_t size() const { return size_; } + + size_t capacity() const { return capacity_; } + + void resize(size_t new_size, bool val = false) { + if (new_size == 0) { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } else { + size_t old_size = size(); + if (new_size > old_size) { + grow(new_size); + } + if (old_size < new_size) { + memset(&bits_[old_size], val, new_size - old_size); + } + size_ = new_size; + } + } + + void 
clear() { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } + + iterator begin() { return &bits_[0]; } + iterator end() { return &bits_[size()]; } + const_iterator begin() const { return &bits_[0]; } + const_iterator end() const { return &bits_[size()]; } + + private: + void grow(size_t n) { + if (capacity_ < n) { + capacity_ = 2 * n; + bool* new_bits = new bool[capacity_](); + if (size_ > 0) { + memcpy(new_bits, bits_.get(), size_); + } + bits_.reset(new_bits); + } + } + + void CopyFrom(const BoolVector& other) { + bits_ = MakeUnique(other.capacity()); + memcpy(begin(), other.begin(), other.size()); + size_ = other.size(); + capacity_ = other.capacity(); + } + + std::unique_ptr bits_; + size_t size_; + size_t capacity_; +}; + // Utility class for dealing with XLA literal values. Most methods are // templated by native (host) type which corresponds to a unique XLA // PrimitiveType. See ComputationBuilder for details. Not all primitive types // defined in xla_data.proto have a corresponding native type or even have a // storage location in the Literal proto yet (for example, primitive type F16). -class LiteralUtil { +class Literal { public: - // Create new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the + Literal() {} + + Literal(const Literal& other) = default; + + explicit Literal(const LiteralProto& other) { CopyFromProto(other); } + + Literal& operator=(const Literal& other) = default; + + LiteralProto ToProto() const; + + bool has_shape() const { + return shape_.element_type() != PRIMITIVE_TYPE_INVALID; + } + + // Basic accessor functions. Names mirror the original protobuf + // functions for convenience. 
+ string DebugString() const { return ToProto().DebugString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + + void Clear() { + shape_.Clear(); + preds_.clear(); + u8s_.clear(); + s32s_.clear(); + s64s_.clear(); + u32s_.clear(); + u64s_.clear(); + f16s_.clear(); + f32s_.clear(); + f64s_.clear(); + tuple_literals_.clear(); + } + + int preds_size() const { return preds().size(); } + const BoolVector& preds() const { return preds_; } + BoolVector* mutable_preds() { return &preds_; } + + int s32s_size() const { return s32s().size(); } + int32 s32s(int i) const { return s32s_[i]; } + const std::vector& s32s() const { return s32s_; } + std::vector* mutable_s32s() { return &s32s_; } + + int s64s_size() const { return s64s().size(); } + void add_s64s(int64 value) { s64s_.push_back(value); } + const std::vector& s64s() const { return s64s_; } + std::vector* mutable_s64s() { return &s64s_; } + + int u32s_size() const { return u32s().size(); } + uint32 u32s(int i) const { return u32s_[i]; } + const std::vector& u32s() const { return u32s_; } + std::vector* mutable_u32s() { return &u32s_; } + + int u64s_size() const { return u64s().size(); } + const std::vector& u64s() const { return u64s_; } + std::vector* mutable_u64s() { return &u64s_; } + + int f16s_size() const { return f16s().size(); } + half f16s(int i) const { return f16s_[i]; } + const std::vector& f16s() const { return f16s_; } + std::vector* mutable_f16s() { return &f16s_; } + + int f32s_size() const { return f32s().size(); } + float f32s(int i) const { return f32s_[i]; } + void add_f32s(float value) { f32s_.push_back(value); } + const std::vector& f32s() const { return f32s_; } + std::vector& f32s() { return f32s_; } + std::vector* mutable_f32s() { return &f32s_; } + + int f64s_size() const { return f64s().size(); } + const std::vector& f64s() const { return f64s_; } + std::vector* mutable_f64s() { return &f64s_; } + + int tuple_literals_size() const { return 
tuple_literals().size(); } + const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } + Literal* add_tuple_literals() { + tuple_literals_.push_back(Literal()); + return &tuple_literals_.back(); + } + std::vector* mutable_tuple_literals() { return &tuple_literals_; } + const std::vector& tuple_literals() const { return tuple_literals_; } + + int u8s_size() const { return u8s().size(); } + const std::vector& u8s() const { return u8s_; } + void set_u8s(const std::vector& value) { u8s_ = value; } + void set_u8s(absl::string_view value) { + u8s_ = std::vector(value.size()); + u8s_.clear(); + append_u8s(value); + } + + void append_u8s(absl::string_view value) { + u8s_.insert(u8s_.end(), value.begin(), value.end()); + } + + string u8s_string() const { return string(u8s().begin(), u8s().end()); } + + std::vector* mutable_u8s() { return &u8s_; } + + const Shape& shape() const { return shape_; } + Shape* mutable_shape() { return &shape_; } + + void Swap(Literal* other) { + Literal temp = *this; + *this = *other; + *other = temp; + } + + // CreatesCreate new literal of a given rank. To minimize ambiguity (for users + // and the compiler) these CreateR[0-2] methods should explicitly specify the // native type. For example: // // CreateR1({1.0, 42.0}); @@ -101,12 +297,12 @@ class LiteralUtil { values, const Layout& layout); - // Create a new Literal object with the shape specified as parameter. + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). static std::unique_ptr CreateFromShape(const Shape& shape); - // Create a new Literal object with its values havings the primitive_type + // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. 
// The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -115,86 +311,84 @@ class LiteralUtil { tensorflow::gtl::ArraySlice dimensions); // Copies the values from src_literal, starting at src_base shape indexes, - // to dest_literal, starting at dest_base, where the copy size in each + // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. - // The src_literal and dest_literal must have the same primitive type, + // The src_literal and this literal must have the same primitive type, // src_base+copy_size must fit the source literal dimensions, as well as // dest_base+copy_size must fit the destination literal dimensions. - static Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. + // Creates a new value that has the equivalent value as this literal, but + // conforms to new_layout; e.g. a literal matrix that was in {0, 1} + // minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. 
// // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - static std::unique_ptr Relayout(const Literal& literal, - const Layout& new_layout); + std::unique_ptr Relayout(const Layout& new_layout) const; - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. - static StatusOr> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice shape); + // Creates a new literal by reshaping this literal to have 'shape'. Both the + // original shape and 'shape' must contain the same number of elements. The + // implementation currently only supports monotonic dim0-major layouts. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice shape) const; - // Creates a new literal by reordering the dimensions of the original literal. + // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers // in the original literal, and it specifies the order of the new dimensions // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation); + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; - // Creates a sub-array from the given literal by extracting the indices + // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the // same rank and layout as for the given literal. 
The number of indices in // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices); + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. template - static std::unique_ptr Replicate(const Literal& input, int64 times); + std::unique_ptr Replicate(int64 times) const; - // Create a literal by converting each element in an original literal to a new + // Creates a literal by converting each element in this literal to a new // type. template - static std::unique_ptr Convert(const Literal& literal); + std::unique_ptr Convert() const; - // Create a literal value zero of the given primitive type. + // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Create a literal value one of the given primitive type. + // Creates a literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); // Creates a literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Create a literal value containing the maximum value of the given + // Creates a literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Create a literal of the given shape where each element is `value`. 
+ // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); - // Create a new literal from an array. The variants not ending with WithLayout - // use the default XLA layout for the literal's linear representation in - // memory. + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); @@ -236,39 +430,33 @@ class LiteralUtil { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal); + // Clones this literal into an owned unique_ptr version. + std::unique_ptr CloneToUnique() const; - // Returns the linear index of the given index within the literal's + // Returns the linear index of the given index within this literal's // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + int64 LinearIndex(tensorflow::gtl::ArraySlice multi_index) const; // Gets or sets an element in the literal at the given index. The index is // CHECKed against the dimension sizes. template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value); + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); // Retrieves the mutable array slice interface which can be used to manipulate // pre-allocated literal values. 
template - static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( - Literal* literal); + tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template - static NativeT GetFirstElement(const Literal& literal); + NativeT GetFirstElement() const; // As Get(), but determines the correct type and converts the value // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + string GetAsString(tensorflow::gtl::ArraySlice multi_index) const; // Returns an identity matrix (rank 2) with the given row and column count. template @@ -280,10 +468,10 @@ class LiteralUtil { // Validates that the data payload of the literal matches the literal shape; // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal); + tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. - static string ToString(const Literal& literal); + string ToString() const; // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -292,112 +480,97 @@ class LiteralUtil { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). - static void EachCellAsString( - const Literal& literal, + void EachCellAsString( const std::function indices, - const string& value)>& per_cell); + const string& value)>& per_cell) const; template - static void EachCell( - const Literal& literal, - std::function indices, - NativeT value)> - per_cell); + void EachCell(std::function indices, + NativeT value)> + per_cell) const; - // Templated methods which populate the given repeated field in the Literal - // proto with the given value(s). 
The Shape field of the Literal proto is set + // Templated methods which populate the given repeated field in this literal + // with the given value(s). The Shape field of this literal is set // to match the array dimensions and type. Examples: // // // Populate with floats. // Array2D float_values = ... - // PopulateR2FromArray2D(values, literal); + // literal.PopulateR2FromArray2D(values); // // // Populate with int32s. - // PopulateR2({{1, 2}, {3, 4}}, literal); + // literal.PopulateR2({{1, 2}, {3, 4}}); // template - static void PopulateR0(NativeT values, Literal* literal); + void PopulateR0(NativeT values); template - static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal); - static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal); + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); template - static void PopulateR2( + void PopulateR2(std::initializer_list> values); + template + void PopulateR2WithLayout( std::initializer_list> values, - Literal* literal); + const Layout& layout); template - static void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout, Literal* literal); + void PopulateR2FromArray2D(const Array2D& values); template - static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal); + void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout, - Literal* literal); + void PopulateR3FromArray3D(const Array3D& values); template - static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal); + void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout, - Literal* literal); + void PopulateR4FromArray4D(const Array4D& values); template 
- static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal); - template - static void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout, - Literal* literal); + void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Populates literal values by calling the generator function for every cell - // in the literal object. + // in this literal object. template - static Status Populate( - Literal* literal, + Status Populate( const std::function indexes)>& generator); // Creates a Literal of the given dimensions with all elements set to the // given value. template - static void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions, - Literal* literal); + void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions); - // Returns a pointer to the underlying buffer in the protobuf containing the - // array data. Use with care. - static const void* InternalData(const Literal& literal); - static void* MutableInternalData(Literal* literal); - - // Allocates space in the repeated_field of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the buffer are set - // to zero. num_elements must equal the number of elements in the literals + // Returns a pointer to the underlying vector corresponding to the Literal's // shape. - static void Reserve(int64 num_elements, Literal* literal); + const void* InternalData() const; + void* MutableInternalData(); - // Allocates space in the repeated_field of the literal sufficient to hold - // num_elements of the literal's primitive type and sets each element in the + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type. Values in the vector are set + // to zero. num_elements must equal the number of elements in the literal's + // shape. 
+ void Reserve(int64 num_elements); + + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type and sets each element in this // literal to the given value. num_elements must equal the number of elements - // in the literals shape. + // in this literal's shape. template - static void Resize(int64 num_elements, NativeT value, Literal* literal); + void Resize(int64 num_elements, NativeT value); - // Returns true if the two given literals have the same shape and - // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2); + // Returns true if this literal has the same shape and value as the given + // literal. Layout is not considered in the comparison. + bool Equal(const Literal& literal2) const; - // Returns whether every element in the given literal is equal to value. + // Returns whether every element in this literal is equal to value. // // value is an int8 because we expect this to be called with small // compile-time constants (0, -1, etc.) and so that whatever value you pass // can be represented exactly by floating-point types as small as 16 bits. // - // If value doesn't fit in literal's type, returns false. Values of 1/0 are - // considered equal to true/false; other values are not considered equal to - // true. - static bool IsAll(const Literal& literal, int8 value); + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. + bool IsAll(int8 value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is // equal to a particular floating-point number. @@ -408,34 +581,34 @@ class LiteralUtil { // admonishments about floating-point equality checks apply. We expect you to // use this to check for values that can be expressed precisely as a float, // e.g. -0.5. 
- static bool IsAllFloat(const Literal& literal, float value); + bool IsAllFloat(float value) const; - // Returns whether the literal is zero at the specified index. The literal + // Returns whether this literal is zero at the specified index. This literal // must be an array. - static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices); + bool IsZero(tensorflow::gtl::ArraySlice indices) const; private: - // Returns an ArraySlice view of the array for the given literal for the - // given NativeT (e.g., float). These - // functions map native type to XLA PrimitiveType via template - // specialization. The unspecialized forms below aborts to handle the error - // case where the given native type does not map to an XLA primitive type. + // Returns an ArraySlice view of the array for this literal for the given + // NativeT (e.g., float). These functions map native type to XLA PrimitiveType + // via template specialization. The unspecialized forms below aborts to handle + // the error case where the given native type does not map to an XLA primitive + // type. template - static tensorflow::gtl::ArraySlice GetArraySlice( - const Literal& literal) { + tensorflow::gtl::ArraySlice GetArraySlice() const { static_assert(!std::is_same::value, "Cannot map native type to primitive type."); } + // Copy from a LiteralProto instance. + void CopyFromProto(const LiteralProto& literal_proto); + // Internal template helper for the Copy() API, matching its arguments one by // one. 
template - static Status CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. @@ -460,6 +633,549 @@ class LiteralUtil { int64 minor_loop_size = 1; }; + Shape shape_; + BoolVector preds_; + std::vector u8s_; + std::vector s32s_; + std::vector s64s_; + std::vector u32s_; + std::vector u64s_; + std::vector f16s_; + std::vector f32s_; + std::vector f64s_; + std::vector tuple_literals_; +}; + +// Utility class for dealing with XLA literal values. Most methods are +// templated by native (host) type which corresponds to a unique XLA +// PrimitiveType. See ComputationBuilder for details. Not all primitive types +// defined in xla_data.proto have a corresponding native type or even have a +// storage location in the Literal proto yet (for example, primitive type F16). +// +// TODO(dnovillo) - All functions in this class simply redirect to the +// corresponding function in class Literal. Remove this class after converting +// all user code to use Literal directly. +class LiteralUtil { + public: + // Creates new literal of a given rank. To minimize ambiguity (for users and + // the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: + // + // CreateR1({1.0, 42.0}); + // CreateR2({{1, 2}, {3, 4}}); + // + // The variants not ending with WithLayout use the default XLA layout for the + // literal's linear representation in memory. 
+ template + static std::unique_ptr CreateR0(NativeT value) { + return Literal::CreateR0(value); + } + + template + static std::unique_ptr CreateR1( + tensorflow::gtl::ArraySlice values) { + return Literal::CreateR1(values); + } + + static std::unique_ptr CreateR1( + const tensorflow::core::Bitmap& values) { + return Literal::CreateR1(values); + } + + template + static std::unique_ptr CreateR2( + std::initializer_list> values) { + return Literal::CreateR2(values); + } + + template + static std::unique_ptr CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout) { + return Literal::CreateR2WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3( + std::initializer_list< + std::initializer_list>> + values) { + return Literal::CreateR3(values); + } + + template + static std::unique_ptr CreateR3WithLayout( + std::initializer_list< + std::initializer_list>> + values, + const Layout& layout) { + return Literal::CreateR3WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4( + std::initializer_list>>> + values) { + return Literal::CreateR4(values); + } + + template + static std::unique_ptr CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout) { + return Literal::CreateR4WithLayout(values, layout); + } + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape) { + return Literal::CreateFromShape(shape); + } + + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). 
+ static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions) { + return Literal::CreateFromDimensions(primitive_type, dimensions); + } + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); + } + + // Creates a new value that has the equivalent value as literal, but conforms + // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major + // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension + // layout and the value in the cell at any given logical index (i0, i1) will + // be the same. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + static std::unique_ptr Relayout(const Literal& literal, + const Layout& new_layout) { + return literal.Relayout(new_layout); + } + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. The implementation + // currently only supports monotonic dim0-major layouts. 
+ static StatusOr> Reshape( + const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { + return input.Reshape(shape); + } + + // Creates a new literal by reordering the dimensions of the original literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + static std::unique_ptr Transpose( + const Literal& literal, tensorflow::gtl::ArraySlice permutation) { + return literal.Transpose(permutation); + } + + // Creates a sub-array from the given literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + static std::unique_ptr Slice( + const Literal& literal, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + return literal.Slice(start_indices, limit_indices); + } + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // literal replicated four times. + template + static std::unique_ptr Replicate(const Literal& input, int64 times) { + return input.Replicate(times); + } + + // Creates a literal by converting each element in an original literal to a + // new type. + template + static std::unique_ptr Convert(const Literal& literal) { + return literal.Convert(); + } + + // Creates a literal value zero of the given primitive type. 
+ static Literal Zero(PrimitiveType primitive_type) { + return Literal::Zero(primitive_type); + } + + // Creates a literal value one of the given primitive type. + static Literal One(PrimitiveType primitive_type) { + return Literal::One(primitive_type); + } + + // Creates a literal value containing the minimum value of the given + // primitive type. For floating-point types, returns -inf. + static Literal MinValue(PrimitiveType primitive_type) { + return Literal::MinValue(primitive_type); + } + + // Creates a literal value containing the maximum value of the given + // primitive type. For floating-point types, returns inf. + static Literal MaxValue(PrimitiveType primitive_type) { + return Literal::MaxValue(primitive_type); + } + + // Creates a literal of the given shape where each element is `value`. + template + static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + tensorflow::gtl::ArraySlice dimensions, NativeT value) { + return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); + } + + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. 
+ template + static std::unique_ptr CreateR2FromArray2D( + const Array2D& values) { + return Literal::CreateR2FromArray2D(values); + } + + template + static std::unique_ptr CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return Literal::CreateR2FromArray2DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3FromArray3D( + const Array3D& values) { + return Literal::CreateR3FromArray3D(values); + } + + template + static std::unique_ptr CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return Literal::CreateR3FromArray3DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4FromArray4D( + const Array4D& values) { + return Literal::CreateR4FromArray4D(values); + } + + template + static std::unique_ptr CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return Literal::CreateR4FromArray4DWithLayout(values, layout); + } + + // Creates a new vector of U8s literal value from a string. + static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { + return Literal::CreateR1U8(value); + } + + // Creates a linspace-populated literal with the given number of rows and + // columns. + static std::unique_ptr CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { + return Literal::CreateR2F32Linspace(from, to, rows, cols); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z dimension given by "projection". + template + static std::unique_ptr CreateR3Projected( + std::initializer_list> values, + int64 projection) { + return Literal::CreateR3Projected(values, projection); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. 
+ template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z) { + return Literal::CreateR4Projected(values, projection_p, projection_z); + } + + // Clones literal into an owned unique_ptr version. + static std::unique_ptr CloneToUnique(const Literal& literal) { + return literal.CloneToUnique(); + } + + // Returns the linear index of the given index within the literal's + // element_type repeated field. + static int64 LinearIndex(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.LinearIndex(multi_index); + } + + // Gets or sets an element in the literal at the given index. The index is + // CHECKed against the dimension sizes. + template + static NativeT Get(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.Get(multi_index); + } + + template + static void Set(Literal* literal, + tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + literal->Set(multi_index, value); + } + + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( + Literal* literal) { + return literal->GetMutableArraySlice(); + } + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + static NativeT GetFirstElement(const Literal& literal) { + return literal.GetFirstElement(); + } + + // As Get(), but determines the correct type and converts the value + // into text. + static string GetAsString(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.GetAsString(multi_index); + } + + // Returns an identity matrix (rank 2) with the given row and column count. + template + static std::unique_ptr MakeIdentityR2(int64 size) { + return Literal::MakeIdentityR2(size); + } + + // Returns a tuple literal composed of given literals. 
+ static std::unique_ptr MakeTuple( + tensorflow::gtl::ArraySlice elements) { + return Literal::MakeTuple(elements); + } + + // Validates that the data payload of the literal matches the literal shape; + // if it does not, an appropriate status is returned. + static tensorflow::Status ValidateLiteral(const Literal& literal) { + return literal.ValidateLiteral(); + } + + // Returns a string representation of the literal value. + static string ToString(const Literal& literal) { return literal.ToString(); } + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + static void EachCellAsString( + const Literal& literal, + const std::function indices, + const string& value)>& per_cell) { + literal.EachCellAsString(per_cell); + } + + template + static void EachCell( + const Literal& literal, + std::function indices, + NativeT value)> + per_cell) { + literal.EachCell(per_cell); + } + + // Templated methods which populate the given repeated field in the Literal + // proto with the given value(s). The Shape field of the Literal proto is set + // to match the array dimensions and type. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // PopulateR2FromArray2D(values, literal); + // + // // Populate with int32s. 
+ // PopulateR2({{1, 2}, {3, 4}}, literal); + // + template + static void PopulateR0(NativeT values, Literal* literal) { + literal->PopulateR0(values); + } + + template + static void PopulateR1(tensorflow::gtl::ArraySlice values, + Literal* literal) { + literal->PopulateR1(values); + } + + static void PopulateR1(const tensorflow::core::Bitmap& values, + Literal* literal) { + literal->PopulateR1(values); + } + + template + static void PopulateR2( + std::initializer_list> values, + Literal* literal) { + literal->PopulateR2(values); + } + + template + static void PopulateR2WithLayout( + std::initializer_list> values, + const Layout& layout, Literal* literal) { + literal->PopulateR2WithLayout(values, layout); + } + + template + static void PopulateR2FromArray2D(const Array2D& values, + Literal* literal) { + literal->PopulateR2FromArray2D(values); + } + + template + static void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR2FromArray2DWithLayout(values, layout); + } + + template + static void PopulateR3FromArray3D(const Array3D& values, + Literal* literal) { + literal->PopulateR3FromArray3D(values); + } + + template + static void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR3FromArray3DWithLayout(values, layout); + } + + template + static void PopulateR4FromArray4D(const Array4D& values, + Literal* literal) { + literal->PopulateR4FromArray4D(values); + } + + template + static void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR4FromArray4DWithLayout(values, layout); + } + + // Populates literal values by calling the generator function for every cell + // in the literal object. 
+ template + static Status Populate( + Literal* literal, + const std::function indexes)>& + generator) { + return literal->Populate(generator); + } + + // Creates a Literal of the given dimensions with all elements set to the + // given value. + template + static void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions, + Literal* literal) { + return literal->PopulateWithValue(value, dimensions); + } + + // Returns a pointer to the underlying vector containing the array data. Use + // with care. + static const void* InternalData(const Literal& literal) { + return literal.InternalData(); + } + + static void* MutableInternalData(Literal* literal) { + return literal->MutableInternalData(); + } + + // Allocates space in the underlying vector of the literal sufficient to hold + // num_elements of the literal's primitive type. Values in the vector are set + // to zero. num_elements must equal the number of elements in the literals + // shape. + static void Reserve(int64 num_elements, Literal* literal) { + literal->Reserve(num_elements); + } + + // Allocates space in the underlying vector of the literal sufficient to hold + // num_elements of the literal's primitive type and sets each element in the + // literal to the given value. num_elements must equal the number of elements + // in the literals shape. + template + static void Resize(int64 num_elements, NativeT value, Literal* literal) { + literal->Resize(num_elements, value); + } + + // Returns true if the two given literals have the same shape and + // values. Layout is not considered in the comparison. + static bool Equal(const Literal& literal1, const Literal& literal2) { + return literal1.Equal(literal2); + } + + // Returns whether every element in the given literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) 
and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in literal's type, returns false. Values of 1/0 are + // considered equal to true/false; other values are not considered equal to + // true. + static bool IsAll(const Literal& literal, int8 value) { + return literal.IsAll(value); + } + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. + static bool IsAllFloat(const Literal& literal, float value) { + return literal.IsAllFloat(value); + } + + // Returns whether the literal is zero at the specified index. The literal + // must be an array. + static bool IsZero(const Literal& literal, + tensorflow::gtl::ArraySlice indices) { + return literal.IsZero(indices); + } + TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; @@ -467,160 +1183,131 @@ class LiteralUtil { // GetMutableArraySlice. The specializations map native type to XLA primitive // type. 
template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ inline tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - DCHECK(literal.shape().element_type() == F32); - return literal.f32s(); +inline tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + DCHECK(shape().element_type() == F32); + return f32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice 
Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, - Literal* literal); +void Literal::Resize(int64 num_elements, bool value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int8 
value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint8 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int32 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint32 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int64 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint64 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, float value, - Literal* literal); +void Literal::Resize(int64 num_elements, float value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, double value, - Literal* literal); +void Literal::Resize(int64 num_elements, double value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, half value, - Literal* literal); +void Literal::Resize(int64 num_elements, half value); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { +/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); - PopulateR0(value, literal.get()); + literal->PopulateR0(value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ std::unique_ptr Literal::CreateR2WithLayout( std::initializer_list> values, const Layout& 
layout) { auto literal = MakeUnique(); - PopulateR2WithLayout(values, layout, literal.get()); + literal->PopulateR2WithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ std::unique_ptr Literal::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ std::unique_ptr Literal::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -645,14 +1332,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ std::unique_ptr Literal::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ std::unique_ptr Literal::CreateR4WithLayout( std::initializer_list>>> values, @@ -683,7 +1370,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ std::unique_ptr Literal::CreateR4( std::initializer_list>>> values) { @@ -691,38 +1378,37 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR2FromArray2DWithLayout(values, layout, literal.get()); + literal->PopulateR2FromArray2DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { return CreateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* 
static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR3FromArray3DWithLayout(values, layout, literal.get()); + literal->PopulateR3FromArray3DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { return CreateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ std::unique_ptr Literal::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -747,7 +1433,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ std::unique_ptr Literal::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -775,99 +1461,92 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { return CreateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR4FromArray4DWithLayout(values, layout, literal.get()); + literal->PopulateR4FromArray4DWithLayout(values, layout); return literal; } template -/* static */ NativeT LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - int64 linear_index = LinearIndex(literal, multi_index); - return GetArraySlice(literal).at(linear_index); +NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index) 
const { + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice().at(linear_index); } template -/* static */ NativeT LiteralUtil::GetFirstElement(const Literal& literal) { - return GetArraySlice(literal).at(0); +NativeT Literal::GetFirstElement() const { + return GetArraySlice().at(0); } template <> -/* static */ inline uint8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == U8); - int64 linear_index = LinearIndex(literal, multi_index); - return literal.u8s()[linear_index]; +inline uint8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == U8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; } template <> -/* static */ inline int8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == S8); - int64 linear_index = LinearIndex(literal, multi_index); - return literal.u8s()[linear_index]; +inline int8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == S8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; } template <> -/* static */ inline half LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == F16); - int64 linear_index = LinearIndex(literal, multi_index); - return GetArraySlice(literal)[linear_index]; +inline half Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == F16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; } template -/* static */ void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - int64 linear_index = LinearIndex(*literal, multi_index); - GetMutableArraySlice(literal).at(linear_index) = value; +void 
Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + int64 linear_index = LinearIndex(multi_index); + GetMutableArraySlice().at(linear_index) = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint8 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u8s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + uint8 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u8s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int8 value) { - return Set(literal, multi_index, value); +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int8 value) { + return Set(multi_index, value); } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_s64s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_s64s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u64s())[linear_index] = value; +/* static */ inline void Literal::Set( + tensorflow::gtl::ArraySlice multi_index, uint64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u64s())[linear_index] = value; } // Returns an identity matrix (rank 2) with the given row and column count. 
template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -876,55 +1555,51 @@ template } template -/* static */ void LiteralUtil::EachCell( - const Literal& literal, +void Literal::EachCell( std::function indices, NativeT value)> - per_cell) { - if (ShapeUtil::HasZeroElements(literal.shape())) { + per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } - std::vector indices(ShapeUtil::Rank(literal.shape()), 0); + std::vector indices(ShapeUtil::Rank(shape()), 0); do { - per_cell(indices, Get(literal, indices)); - } while (IndexUtil::BumpIndices(literal.shape(), &indices)); + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } template -/* static */ inline void LiteralUtil::PopulateR0(NativeT value, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( +inline void Literal::PopulateR0(NativeT value) { + *mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {}); - Resize(1, value, literal); + Resize(1, value); } template -/* static */ void LiteralUtil::PopulateR1( - tensorflow::gtl::ArraySlice values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { + *mutable_shape() = ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())}); - Reserve(values.size(), literal); + Reserve(values.size()); for (int64 i = 0; i < values.size(); ++i) { - Set(literal, {i}, values[i]); + Set({i}, values[i]); } } -/* static */ inline void LiteralUtil::PopulateR1( - const tensorflow::core::Bitmap& values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + *mutable_shape() = ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); - 
Reserve(values.bits(), literal); + Reserve(values.bits()); for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set(literal, {i}, values.get(i)); + Set({i}, values.get(i)); } } template -/* static */ void LiteralUtil::PopulateR2WithLayout( +void Literal::PopulateR2WithLayout( std::initializer_list> values, - const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -932,17 +1607,17 @@ template const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); const int64 num_elements = dim1_size * dim0_size; - Reserve(num_elements, literal); + Reserve(num_elements); int64 dim0 = 0; for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - Set(literal, {dim0, dim1}, value); + Set({dim0, dim1}, value); ++dim1; } CHECK_EQ(dim1_size, dim1); @@ -951,84 +1626,79 @@ template } template -/* static */ void LiteralUtil::PopulateR2( - std::initializer_list> values, - Literal* literal) { - PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), literal); +void Literal::PopulateR2( + std::initializer_list> values) { + PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ void LiteralUtil::PopulateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.height(), 
values.width()}, AsInt64Slice(layout.minor_to_major())); const int64 dim1_size = values.width(); const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); - Reserve(dim1_size * dim0_size, literal); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + Reserve(dim1_size * dim0_size); for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set(literal, {dim0, dim1}, values(dim0, dim1)); + Set({dim0, dim1}, values(dim0, dim1)); } } } template -/* static */ void LiteralUtil::PopulateR2FromArray2D( - const Array2D& values, Literal* literal) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), - literal); +void Literal::PopulateR2FromArray2D(const Array2D& values) { + PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ void LiteralUtil::PopulateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.n1(), values.n2(), values.n3()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + Reserve(values.n1() * values.n2() * values.n3()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < 
values.n3(); ++dim2) { - Set(literal, {dim0, dim1, dim2}, values(dim0, dim1, dim2)); + Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); } } } } template -/* static */ void LiteralUtil::PopulateR3FromArray3D( - const Array3D& values, Literal* literal) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3(), - literal); +void Literal::PopulateR3FromArray3D(const Array3D& values) { + PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ void LiteralUtil::PopulateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.planes(), values.depth(), values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - CHECK_EQ(values.n4(), literal->shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + CHECK_EQ(values.n4(), shape().dimensions(3)); + Reserve(values.n1() * values.n2() * values.n3() * values.n4()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set(literal, {dim0, dim1, dim2, dim3}, - values(dim0, dim1, dim2, dim3)); + Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); } } } @@ -1036,31 +1706,29 @@ template } template -/* static */ void 
LiteralUtil::PopulateR4FromArray4D( - const Array4D& values, Literal* literal) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4(), - literal); +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ Status LiteralUtil::Populate( - Literal* literal, +Status Literal::Populate( const std::function indexes)>& generator) { - const Shape& shape = literal->shape(); - int64 rank = ShapeUtil::Rank(shape); - TF_RET_CHECK(shape.element_type() == + const Shape& this_shape = shape(); + int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); tensorflow::gtl::MutableArraySlice data = - GetMutableArraySlice(literal); + GetMutableArraySlice(); if (rank > 0) { - StrideConfig stride_config(shape, shape, AsInt64Slice(shape.dimensions())); + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); DimensionVector minor_scan_indexes(rank, 0); int64 minor_dimension_size = - ShapeUtil::GetDimension(shape, stride_config.minor_dimension); + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); auto init_function = [&](const std::vector& indexes) { - int64 index = LinearIndex(*literal, indexes); + int64 index = LinearIndex(indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); for (int64 i = 0; i < minor_dimension_size; ++i) { minor_scan_indexes[stride_config.minor_dimension] = i; @@ -1068,8 +1736,9 @@ template } return true; }; - ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, - stride_config.step, init_function); + ShapeUtil::ForEachIndex(this_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + init_function); } else { data.at(0) = generator({}); } @@ -1077,30 +1746,27 @@ template } template -/* static */ void LiteralUtil::PopulateWithValue( - NativeT 
value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( +void Literal::PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions) { + *mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), dimensions); - Resize(ShapeUtil::ElementsIn(literal->shape()), value, literal); + Resize(ShapeUtil::ElementsIn(shape()), value); } template -/* static */ std::unique_ptr LiteralUtil::Convert( - const Literal& literal) { - const Shape& shape = literal.shape(); +std::unique_ptr Literal::Convert() const { + const Shape& this_shape = shape(); auto result_literal = MakeUnique(); Shape* result_shape = result_literal->mutable_shape(); - *result_shape = shape; + *result_shape = this_shape; result_shape->set_element_type( primitive_util::NativeToPrimitiveType()); - LiteralUtil::Reserve(ShapeUtil::ElementsIn(*result_shape), - result_literal.get()); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); tensorflow::gtl::ArraySlice src_data = - GetArraySlice(literal); + GetArraySlice(); tensorflow::gtl::MutableArraySlice dest_data = - GetMutableArraySlice(result_literal.get()); - int64 num_elements = ShapeUtil::ElementsIn(shape); + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(this_shape); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = static_cast(src_data[i]); @@ -1110,36 +1776,35 @@ template template /* static */ std::unique_ptr -LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( +Literal::CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( primitive_util::NativeToPrimitiveType(), dimensions); auto literal = MakeUnique(); - *literal->mutable_shape() = shape; - Reserve(ShapeUtil::ElementsIn(shape), literal.get()); + *literal->mutable_shape() = 
this_shape; + literal->Reserve(ShapeUtil::ElementsIn(this_shape)); std::vector index(dimensions.size(), 0); do { - Set(literal.get(), index, value); - } while (IndexUtil::BumpIndices(shape, &index)); + literal->Set(index, value); + } while (IndexUtil::BumpIndices(this_shape, &index)); return literal; } template -/* static */ std::unique_ptr LiteralUtil::Replicate( - const Literal& input, int64 times) { +std::unique_ptr Literal::Replicate(int64 times) const { DimensionVector bounds = {times}; - bounds.reserve(input.shape().dimensions_size() + 1); - for (int64 bound : input.shape().dimensions()) { + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } auto literal = MakeUnique(); *literal->mutable_shape() = - ShapeUtil::MakeShape(input.shape().element_type(), bounds); + ShapeUtil::MakeShape(shape().element_type(), bounds); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; } - Reserve(elements, literal.get()); + literal->Reserve(elements); DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; @@ -1147,8 +1812,8 @@ template bool done = false; while (!done) { - const auto element = Get(input, input_indices); - Set(literal.get(), output_indices, element); + const auto element = Get(input_indices); + literal->Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 21766a2a0c8..d488830a6cd 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -60,8 +60,8 @@ StatusOr> PackedLiteralReader::Read( int64 elements = ShapeUtil::ElementsIn(shape); LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), result.get()); - tensorflow::protobuf::RepeatedField* field = result->mutable_f32s(); - 
char* data = tensorflow::bit_cast(field->mutable_data()); + std::vector* field = result->mutable_f32s(); + char* data = tensorflow::bit_cast(field->data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 563d978cf5d..45a9fe01278 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f4eda7f91e7..f21ce6bc3ac 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -531,6 +531,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 49e9874cda2..78a398f8efa 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index c27710fbdb2..6557c3aa8e6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 6583e509674..cfd1f0f53b7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -46,7 +46,7 @@ message HloInstructionProto { xla.OpMetadata metadata = 7; // Literal, only present for kConstant. - xla.Literal literal = 8; + xla.LiteralProto literal = 8; // Parameter info, only present for kParameter. 
int64 parameter_number = 9; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 19a97c0175f..b02089206e9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat; WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_.reset(new Literal); - *instruction->literal_->mutable_u8s() += tag; + instruction->literal_->append_u8s(tag); return instruction; } @@ -1551,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; switch (opcode_) { case HloOpcode::kConstant: - *proto.mutable_literal() = *literal_; + *proto.mutable_literal() = literal_->ToProto(); break; case HloOpcode::kParameter: proto.set_parameter_number(parameter_number_); @@ -1648,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -const string& HloInstruction::tracing_tag() const { +string HloInstruction::TracingTag() const { CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK(literal_ != nullptr); - return literal_->u8s(); + return literal_->u8s_string(); } bool HloInstruction::IsFused() const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3db185896da..3bf46341be2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -30,6 +30,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -535,7 +536,7 @@ class HloInstruction { // Returns a tag to be used in tracing. // // Precondition: opcode() == HloOpcode::kTrace - const string& tracing_tag() const; + string TracingTag() const; // Returns whether the instruction is a constant. bool IsConstant() const; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 08bb10dbd98..d9a98ae5eb4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Value.h" #include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index c8f2188b53c..2157604518d 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -77,8 +77,10 @@ tensorflow::Status RecordArguments( SessionModule* module) { module->clear_arguments(); for (const Allocation* allocation : arg_allocations) { - TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(), - module->add_arguments())); + Literal argument; + TF_RETURN_IF_ERROR( + LiteralFromAllocation(allocation, allocation->shape(), &argument)); + *module->add_arguments() = argument.ToProto(); } return tensorflow::Status::OK(); } @@ -87,8 +89,11 @@ tensorflow::Status RecordArguments( tensorflow::Status RecordResult(const 
Allocation* result_allocation, SessionModule* module) { module->clear_result(); - return LiteralFromAllocation(result_allocation, result_allocation->shape(), - module->mutable_result()); + Literal result; + TF_RETURN_IF_ERROR(LiteralFromAllocation( + result_allocation, result_allocation->shape(), &result)); + *module->mutable_result() = result.ToProto(); + return tensorflow::Status::OK(); } } // namespace @@ -912,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, literal_shape = &allocation->shape(); } - return LiteralFromAllocation(allocation, *literal_shape, - result->mutable_literal()); + Literal literal; + auto status = LiteralFromAllocation(allocation, *literal_shape, &literal); + *result->mutable_literal() = literal.ToProto(); + return status; } tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - const Literal& literal = arg->literal(); + Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { @@ -982,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, } return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, arg->literal()); + executor, Literal(arg->literal())); } tensorflow::Status Service::TransferFromOutfeed( @@ -1005,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed( executor = execute_backend_->Replicas()[arg->replica_id()]; } - return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), result->mutable_literal()); + Literal literal; + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( + executor, arg->shape_with_layout(), &literal)); + *result->mutable_literal() = literal.ToProto(); + return tensorflow::Status::OK(); } tensorflow::Status Service::ResetDevice(const 
ResetDeviceRequest* arg, diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto index 4902cb521c2..bb8d1cd2a10 100644 --- a/tensorflow/compiler/xla/service/session.proto +++ b/tensorflow/compiler/xla/service/session.proto @@ -75,10 +75,10 @@ message SessionModule { repeated SessionComputation embedded_computations = 2; // The arguments passed to the computation. - repeated Literal arguments = 3; + repeated LiteralProto arguments = 3; // The result of the computation. - Literal result = 4; + LiteralProto result = 4; // The name of the platform used to run the computation. string execution_platform = 5; diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index a417b988bfe..15f6b7bfb4a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index 564111c4f2b..ca38601d919 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { const Shape shape = ShapeUtil::MakeShape(U8, {4}); TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s()); + CHECK_EQ("klmn", literal.u8s_string()); } TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4cde03849e9..b97823d2dc0 
100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(constant_request.literal()))); + LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); break; } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ddd13edeb86..ea691201263 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = - *LiteralUtil::CreateR1({123.0f, 42.0f}); + LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, computation.AddConstantInstruction(constant_request)); @@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { UserComputation computation("TheComputation", handle); ConstantRequest a_request; - *a_request.mutable_literal() = *LiteralUtil::CreateR1({123.0f, 42.0f}); + *a_request.mutable_literal() = + LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, computation.AddConstantInstruction(a_request)); ConstantRequest b_request; - *b_request.mutable_literal() = *LiteralUtil::CreateR0(1.0f); + *b_request.mutable_literal() = LiteralUtil::CreateR0(1.0f)->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, computation.AddConstantInstruction(b_request)); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7bf1168dc39..08e3f81a283 100644 --- 
a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); - EXPECT_EQ(expected, actual->u8s()); + EXPECT_EQ(expected, actual->u8s_string()); } void ClientLibraryTestBase::ComputeAndCompareTuple( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 23453db57bc..eb979ad189d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -262,7 +262,7 @@ class NearComparator { max_abs_err_ = 0.0; *miscompares_.mutable_shape() = ShapeUtil::ChangeElementType(actual.shape(), PRED); - miscompares_.mutable_preds()->Resize( + miscompares_.mutable_preds()->resize( ShapeUtil::ElementsIn(miscompares_.shape()), false); multi_index_.resize(expected.shape().dimensions_size(), 0); @@ -389,7 +389,7 @@ class NearComparator { tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec, name.c_str())); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal)); + filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; } diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index fdec11c0e98..a94f45f73b7 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { - Literal literal; + LiteralProto literal_proto; 
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, - &literal)); + &literal_proto)); + Literal literal(literal_proto); if (result.find("expected") != string::npos) { EXPECT_EQ("2", LiteralUtil::ToString(literal)); } else if (result.find("actual") != string::npos) { diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 3cfbb2c7fbf..e45e5291c9b 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 545bd22da91..7375493f430 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index f4d46b26e65..3a75bf64954 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -66,7 +66,8 @@ StatusOr> ReplayComputation( if (use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available - for (const Literal& literal : module.arguments()) { + for (const auto& proto : module.arguments()) { + Literal literal(proto); TF_ASSIGN_OR_RETURN(std::unique_ptr data, client->TransferToServer(literal)); arguments.push_back(std::move(data)); @@ -101,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(module.result()).c_str()); + LiteralUtil::ToString(Literal(module.result())).c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index cf363913b15..b6538f5de07 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -37,9 +37,10 @@ int main(int argc, char **argv) { << " "; } - xla::Literal literal; + xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], - &literal)); - LOG(INFO) << "literal: " << literal.ShortDebugString(); + &literal_proto)); + xla::Literal literal(literal_proto); + LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); fprintf(stderr, "%s\n", 
xla::LiteralUtil::ToString(literal).c_str()); } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 716eb424424..193ae49afee 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -92,11 +92,11 @@ message TransferToClientRequest { } message TransferToClientResponse { - Literal literal = 1; + LiteralProto literal = 1; } message TransferToServerRequest { - Literal literal = 1; + LiteralProto literal = 1; DeviceHandle device_handle = 2; } @@ -105,7 +105,7 @@ message TransferToServerResponse { } message TransferToInfeedRequest { - Literal literal = 1; + LiteralProto literal = 1; int64 replica_id = 2; DeviceHandle device_handle = 3; } @@ -123,7 +123,7 @@ message TransferFromOutfeedRequest { } message TransferFromOutfeedResponse { - Literal literal = 1; + LiteralProto literal = 1; } message ResetDeviceRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1239816c50e..44a94e171fa 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -275,7 +275,7 @@ message ChannelHandle { // // Transfers to/from the client are encoded in literal form, and the structure // of the repeated fields is implied by the shape. -message Literal { +message LiteralProto { Shape shape = 1; repeated bool preds = 2; bytes u8s = 3; @@ -285,7 +285,7 @@ message Literal { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; - repeated Literal tuple_literals = 10; + repeated LiteralProto tuple_literals = 10; bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } @@ -337,7 +337,7 @@ message Window { // field in OpRequest. message ConstantRequest { - Literal literal = 2; + LiteralProto literal = 2; } message GetTupleElementRequest { From 0aa3e01941d231fe313e600eaa5f7cc052c1c077 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 1 Jun 2017 11:41:41 -0700 Subject: [PATCH 14/72] Internal change PiperOrigin-RevId: 157740660 --- tensorflow/core/BUILD | 1 + tensorflow/tensorflow.bzl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 25863802d94..2647bfe3bbd 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -62,6 +62,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "full_path", "if_android", "if_ios", "if_x86", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 348745f8d2b..b0ed57996c0 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -22,6 +22,9 @@ load( "if_mkl",) +def full_path(relative_paths): + return [PACKAGE_NAME + "/" + relative for relative in relative_paths] + # List of proto files for android builds def tf_android_core_proto_sources(core_proto_sources_relative): return [ From 9fc1642250713f27f520af0da080c388390912c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 11:41:49 -0700 Subject: [PATCH 15/72] Fix index_table_from_file to allow vocabulary_file be a Tensor PiperOrigin-RevId: 157740677 --- tensorflow/contrib/lookup/lookup_ops.py | 7 ++++--- tensorflow/contrib/lookup/lookup_ops_test.py | 20 ++++++++++++++++++- .../python/kernel_tests/lookup_ops_test.py | 18 ++++++++++++++++- tensorflow/python/ops/lookup_ops.py | 7 ++++--- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 65474f03fa0..e49b62afa28 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None, ``` Args: - vocabulary_file: The vocabulary filename. + vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. num_oov_buckets: The number of out-of-vocabulary buckets. 
vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary feature values. @@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None, ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater than zero. """ - if not vocabulary_file: - raise ValueError("vocabulary_file must be specified.") + if vocabulary_file is None or ( + isinstance(vocabulary_file, str) and not vocabulary_file): + raise ValueError("vocabulary_file must be specified and must not be empty.") if num_oov_buckets < 0: raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." % num_oov_buckets) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5ec169b6db4..180dfefe29d 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): + vocabulary_file = self._createVocabFile("f2i_vocab1.txt") + with self.test_session(): + vocabulary_file = constant_op.constant(vocabulary_file) + table = lookup.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) @@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase): 860), # 3 + fingerprint("toccata") mod 300. 
ids.eval()) - def test_index_table_from_file_with_only_oov_buckets(self): + def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): + self.assertRaises( + ValueError, + lookup.index_table_from_file, + vocabulary_file="") + + def test_index_table_from_file_fails_with_empty_vocabulary(self): self.assertRaises( ValueError, lookup.index_table_from_file, diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 2a90bc539bb..79254cb28c2 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -280,6 +280,18 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): + vocabulary_file = self._createVocabFile("f2i_vocab1.txt") + with self.test_session(): + vocabulary_file = constant_op.constant(vocabulary_file) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) @@ -340,7 +352,11 @@ class IndexTableFromFile(test.TestCase): 860), # 3 + fingerprint("toccata") mod 300. 
ids.eval()) - def test_index_table_from_file_with_only_oov_buckets(self): + def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): + self.assertRaises( + ValueError, lookup_ops.index_table_from_file, vocabulary_file="") + + def test_index_table_from_file_fails_with_empty_vocabulary(self): self.assertRaises( ValueError, lookup_ops.index_table_from_file, vocabulary_file=None) diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 82277ebaccb..bffaf6324fc 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -893,7 +893,7 @@ def index_table_from_file(vocabulary_file=None, ``` Args: - vocabulary_file: The vocabulary filename. + vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. num_oov_buckets: The number of out-of-vocabulary buckets. vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary feature values. @@ -911,8 +911,9 @@ def index_table_from_file(vocabulary_file=None, ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater than zero. """ - if not vocabulary_file: - raise ValueError("vocabulary_file must be specified.") + if vocabulary_file is None or ( + isinstance(vocabulary_file, str) and not vocabulary_file): + raise ValueError("vocabulary_file must be specified and must not be empty.") if num_oov_buckets < 0: raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." % num_oov_buckets) From d29bbeca3d237e10c678242f34bde908ca68ccc3 Mon Sep 17 00:00:00 2001 From: Dandelion Man? Date: Thu, 1 Jun 2017 12:01:54 -0700 Subject: [PATCH 16/72] Fix outdated code ref in TensorBoard README, add link to SO question. 
PiperOrigin-RevId: 157743374 --- tensorflow/tensorboard/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/tensorboard/README.md b/tensorflow/tensorboard/README.md index 20be8593cb3..5aff57a241b 100644 --- a/tensorflow/tensorboard/README.md +++ b/tensorflow/tensorboard/README.md @@ -330,7 +330,9 @@ TensorBoard uses [reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling) to downsample your data so that it can be loaded into RAM. You can modify the number of elements it will keep per tag in -[tensorboard/backend/server.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/backend/server.py). +[tensorboard/backend/application.py](https://www.github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/tensorboard/backend/application.py). +See this [StackOverflow question](http://stackoverflow.com/questions/43702546/tensorboard-doesnt-show-all-data-points/) +for some more information. ### I get a network security popup every time I run TensorBoard on a mac! From 15a740ebbba3fc176b5bc4318db84d470e356dad Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Thu, 1 Jun 2017 12:06:18 -0700 Subject: [PATCH 17/72] Update and Move DNNLinearCombinedRegressor to estimator/canned. PiperOrigin-RevId: 157744087 --- .../estimator/canned/dnn_linear_combined.py | 373 ++++++++++++++++++ 1 file changed, 373 insertions(+) create mode 100644 tensorflow/python/estimator/canned/dnn_linear_combined.py diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py new file mode 100644 index 00000000000..db5a72b8f21 --- /dev/null +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -0,0 +1,373 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimators for Linear and DNN joined training models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import six + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import sync_replicas_optimizer +from tensorflow.python.training import training_util + +# The default learning rates are a historical artifact of the initial +# implementation, but seem a reasonable choice. +_DNN_LEARNING_RATE = 0.05 +_LINEAR_LEARNING_RATE = 0.2 + + +def _check_no_sync_replicas_optimizer(optimizer): + if isinstance(optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): + raise ValueError( + 'SyncReplicasOptimizer does not support multi optimizers case. 
' + 'Therefore, it is not supported in DNNLinearCombined model. ' + 'If you want to use this optimizer, please use either DNN or Linear ' + 'model.') + + +def _linear_learning_rate(num_linear_feature_columns): + """Returns the default learning rate of the linear model. + + The calculation is a historical artifact of this initial implementation, but + has proven a reasonable choice. + + Args: + num_linear_feature_columns: The number of feature columns of the linear + model. + + Returns: + A float. + """ + default_learning_rate = 1. / math.sqrt(num_linear_feature_columns) + return min(_LINEAR_LEARNING_RATE, default_learning_rate) + + +def _add_layer_summary(value, tag): + summary.scalar('%s/fraction_of_zero_values' % tag, nn.zero_fraction(value)) + summary.histogram('%s/activation' % tag, value) + + +def _dnn_linear_combined_model_fn( + features, labels, mode, head, + linear_feature_columns=None, linear_optimizer='Ftrl', + dnn_feature_columns=None, dnn_optimizer='Adagrad', dnn_hidden_units=None, + dnn_activation_fn=nn.relu, dnn_dropout=None, + input_layer_partitioner=None, config=None): + """Deep Neural Net and Linear combined model_fn. + + Args: + features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`). + labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype + `int32` or `int64` in the range `[0, n_classes)`. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `Head` instance. + linear_feature_columns: An iterable containing all the feature columns used + by the Linear model. + linear_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the Linear model. Defaults to the Ftrl + optimizer. + dnn_feature_columns: An iterable containing all the feature columns used by + the DNN model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN model. Defaults to the Adagrad + optimizer. 
+ dnn_hidden_units: List of hidden units per DNN layer. + dnn_activation_fn: Activation function applied to each DNN layer. If `None`, + will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability we will drop out a given DNN + coordinate. + input_layer_partitioner: Partitioner for input layer. + config: `RunConfig` object to configure the runtime settings. + + Returns: + `ModelFnOps` + + Raises: + ValueError: If both `linear_feature_columns` and `dnn_features_columns` + are empty at the same time, or `input_layer_partitioner` is missing. + """ + if not linear_feature_columns and not dnn_feature_columns: + raise ValueError( + 'Either linear_feature_columns or dnn_feature_columns must be defined.') + num_ps_replicas = config.num_ps_replicas if config else 0 + input_layer_partitioner = input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) + + linear_optimizer = optimizers.get_optimizer_instance( + linear_optimizer, + learning_rate=_linear_learning_rate(len(linear_feature_columns))) + _check_no_sync_replicas_optimizer(linear_optimizer) + + dnn_optimizer = optimizers.get_optimizer_instance( + dnn_optimizer, + learning_rate=_DNN_LEARNING_RATE) + _check_no_sync_replicas_optimizer(dnn_optimizer) + + # Build DNN Logits. 
+ dnn_parent_scope = 'dnn' + + if not dnn_feature_columns: + dnn_logits = None + else: + if not dnn_hidden_units: + raise ValueError( + 'dnn_hidden_units must be defined when dnn_feature_columns is ' + 'specified.') + dnn_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas)) + with variable_scope.variable_scope( + dnn_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner): + with variable_scope.variable_scope('input', + partitioner=input_layer_partitioner): + net = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns) + + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + 'hiddenlayer_%d' % layer_id, + values=(net,)) as dnn_hidden_layer_scope: + net = core_layers.dense( + net, + units=num_hidden_units, + activation=dnn_activation_fn, + kernel_initializer=init_ops.glorot_uniform_initializer(), + name=dnn_hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = core_layers.dropout(net, rate=dnn_dropout, training=True) + _add_layer_summary(net, dnn_hidden_layer_scope.name) + + with variable_scope.variable_scope( + 'logits', + values=(net,)) as dnn_logits_scope: + dnn_logits = core_layers.dense( + net, + units=head.logits_dimension, + activation=None, + kernel_initializer=init_ops.glorot_uniform_initializer(), + name=dnn_logits_scope) + _add_layer_summary(dnn_logits, dnn_logits_scope.name) + + linear_parent_scope = 'linear' + + if not linear_feature_columns: + linear_logits = None + else: + with variable_scope.variable_scope( + linear_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=input_layer_partitioner) as scope: + linear_logits = feature_column_lib.linear_model( + features=features, + feature_columns=linear_feature_columns, + units=head.logits_dimension) + _add_layer_summary(linear_logits, scope.name) + + # Combine logits and build full
model. + if dnn_logits is not None and linear_logits is not None: + logits = dnn_logits + linear_logits + elif dnn_logits is not None: + logits = dnn_logits + else: + logits = linear_logits + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + train_ops = [] + global_step = training_util.get_global_step() + if dnn_logits is not None: + train_ops.append( + dnn_optimizer.minimize( + loss, + var_list=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=dnn_parent_scope))) + if linear_logits is not None: + train_ops.append( + linear_optimizer.minimize( + loss, + var_list=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=linear_parent_scope))) + + train_op = control_flow_ops.group(*train_ops) + with ops.control_dependencies([train_op]): + with ops.colocate_with(global_step): + return state_ops.assign_add(global_step, 1) + + return head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + +class DNNLinearCombinedRegressor(estimator.Estimator): + """An estimator for TensorFlow Linear and DNN joined models for regression. + + Note: This estimator is also known as wide-n-deep. + + Example: + + ```python + numeric_feature = numeric_column(...) + sparse_column_a = categorical_column_with_hash_bucket(...) + sparse_column_b = categorical_column_with_hash_bucket(...) + + sparse_feature_a_x_sparse_feature_b = crossed_column(...) + sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a, + ...) + sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b, + ...)
+ + estimator = DNNLinearCombinedRegressor( + # wide settings + linear_feature_columns=[sparse_feature_a_x_sparse_feature_b], + linear_optimizer=tf.train.FtrlOptimizer(...), + # deep settings + dnn_feature_columns=[ + sparse_feature_a_emb, sparse_feature_b_emb, numeric_feature], + dnn_hidden_units=[1000, 500, 100], + dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + + # To apply L1 and L2 regularization, you can set optimizers as follows: + tf.train.ProximalAdagradOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.001) + # It is same for FtrlOptimizer. + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + """ + + def __init__(self, + model_dir=None, + linear_feature_columns=None, + linear_optimizer=None, + dnn_feature_columns=None, + dnn_optimizer=None, + dnn_hidden_units=None, + dnn_activation_fn=nn.relu, + dnn_dropout=None, + label_dimension=1, + input_layer_partitioner=None, + config=None): + """Initializes a DNNLinearCombinedRegressor instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. 
This can + also be used to load checkpoints from the directory into an estimator + to continue training a previously saved model. + linear_feature_columns: An iterable containing all the feature columns + used by linear part of the model. All items in the set must be + instances of classes derived from `FeatureColumn`. + linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the linear part of the model. If `None`, will use a FTRL optimizer. + dnn_feature_columns: An iterable containing all the feature columns used + by deep part of the model. All items in the set must be instances of + classes derived from `FeatureColumn`. + dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the deep part of the model. If `None`, will use an Adagrad optimizer. + dnn_hidden_units: List of hidden units per layer. All layers are fully + connected. + dnn_activation_fn: Activation function applied to each layer. If None, + will use `tf.nn.relu`. + dnn_dropout: When not None, the probability we will drop out + a given coordinate. + label_dimension: Number of regression targets per example. This is the + size of the last dimension of the labels and logits `Tensor` objects + (typically, these have shape `[batch_size, label_dimension]`). + input_layer_partitioner: Partitioner for input layer. Defaults to + `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: RunConfig object to configure the runtime settings. + + Raises: + ValueError: If both linear_feature_columns and dnn_feature_columns are + empty at the same time.
+ """ + linear_feature_columns = linear_feature_columns or [] + dnn_feature_columns = dnn_feature_columns or [] + self._feature_columns = linear_feature_columns + dnn_feature_columns + if not self._feature_columns: + raise ValueError('Either linear_feature_columns or dnn_feature_columns ' + 'must be defined.') + + def _model_fn(features, labels, mode, config): + return _dnn_linear_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + label_dimension=label_dimension), + linear_feature_columns=linear_feature_columns, + linear_optimizer=linear_optimizer, + dnn_feature_columns=dnn_feature_columns, + dnn_optimizer=dnn_optimizer, + dnn_hidden_units=dnn_hidden_units, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + super(DNNLinearCombinedRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) From 7106f9fac32c61af59285e6ccb0b9c623a8334c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 12:24:35 -0700 Subject: [PATCH 18/72] Implemented an initial version of virtual scheduler unit test. 
PiperOrigin-RevId: 157746305 --- tensorflow/core/grappler/costs/BUILD | 22 +++ .../grappler/costs/virtual_scheduler_test.cc | 136 ++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tensorflow/core/grappler/costs/virtual_scheduler_test.cc diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index aa675fcc771..206fac1decc 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -181,6 +181,28 @@ cc_library( ], ) +cc_test( + name = "virtual_scheduler_test", + srcs = ["virtual_scheduler_test.cc"], + deps = [ + ":graph_properties", + ":utils", + ":virtual_placer", + ":virtual_scheduler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:cost_estimator", + ], +) + cc_library( name = "measuring_cost_estimator", srcs = ["measuring_cost_estimator.cc"], diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc new file mode 100644 index 00000000000..cc4a63e5ff0 --- /dev/null +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/virtual_placer.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class VirtualSchedulerTest : public ::testing::Test { + protected: + void SetUp() override { + // Initializes cluster_ and placer_. + std::unordered_map<string, DeviceProperties> devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; + + cluster_.reset(new VirtualCluster(devices)); + placer_.reset(new VirtualPlacer(cluster_.get())); + } + + void CreateSchedulerWithConv2Ds() { + // Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in + // fetch nodes.
+ const int bs = 4; + const int width = 10; + const int height = 10; + const int depth_in = 8; + const int kernel = 3; + const int depth_out = 16; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = tensorflow::ops::RandomUniform( + s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT); + auto y = tensorflow::ops::RandomUniform( + s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT); + auto z = tensorflow::ops::RandomUniform( + s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT); + auto f = tensorflow::ops::RandomUniform( + s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT); + std::vector<int> strides = {1, 1, 1, 1}; + auto c0 = + tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); + auto c1 = + tensorflow::ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); + auto c2 = + tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + LOG(INFO) << def.DebugString(); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_conv2d_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"c0", "c1"}; + + scheduler_.reset(new VirtualScheduler( + grappler_item_.get(), true /* use_static_shapes */, + "CPU" /* default_device_type */, cluster_.get(), placer_.get())); + TF_CHECK_OK(scheduler_->Init()); + } + + // SetUp() inits cluster_ and placer_. + std::unique_ptr<VirtualCluster> cluster_; + std::unique_ptr<VirtualPlacer> placer_; + + // grappler_item_ and scheduler_ will be initialized differently for each test + // case + std::unique_ptr<GrapplerItem> grappler_item_; + std::unique_ptr<VirtualScheduler> scheduler_; +}; + +TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { + CreateSchedulerWithConv2Ds(); // init scheduler_. + + Costs zero_costs = Costs::ZeroCosts(); + std::unordered_map<string, NodeInfo> ops_executed; + do { + NodeInfo node_info = scheduler_->GetCurrNodeInfo(); + ops_executed[node_info.name] = node_info; + + // Check scheduling order: x and f before c0, and y and f before c1.
+ if (node_info.name == "c0") { + EXPECT_GT(ops_executed.count("x"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + } else if (node_info.name == "c1") { + EXPECT_GT(ops_executed.count("y"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + } + } while (scheduler_->MarkCurrNodeExecuted(zero_costs)); + + // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be + // executed. + EXPECT_EQ(8, ops_executed.size()); + + // x, y, f, c0, and c1 should be in the ops executed. + EXPECT_GT(ops_executed.count("x"), 0); + EXPECT_GT(ops_executed.count("y"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + EXPECT_GT(ops_executed.count("c0"), 0); + EXPECT_GT(ops_executed.count("c1"), 0); + + // z and c2 shouldn't be part of it. + EXPECT_EQ(ops_executed.count("z"), 0); + EXPECT_EQ(ops_executed.count("c2"), 0); + + // Check input / output properties. + EXPECT_EQ(1, ops_executed["x"].outputs.size()); + EXPECT_EQ(1, ops_executed["y"].outputs.size()); + EXPECT_EQ(1, ops_executed["f"].outputs.size()); + EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); + EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); +} +} // end namespace grappler +} // end namespace tensorflow From 7ad0d0698ab443324bbe68dd5d6476111c6b229a Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Thu, 1 Jun 2017 12:43:25 -0700 Subject: [PATCH 19/72] Add type error to start_queue_runners if given session is not a `tf.Session`. Due to semver, we suppress the error if a MonitoredSession is provided. 
PiperOrigin-RevId: 157748375 --- .../python/training/queue_runner_impl.py | 14 ++++++++++ .../python/training/queue_runner_test.py | 28 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index d713e222aee..4e58602a6f7 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -22,6 +22,7 @@ import threading import weakref from tensorflow.core.protobuf import queue_runner_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging @@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection: A `GraphKey` specifying the graph collection to get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`. + Raises: + ValueError: if `sess` is None and there isn't any default session. + TypeError: if `sess` is not a `tf.Session` object. + Returns: A list of threads. """ @@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, raise ValueError("Cannot start queue runners: No default session is " "registered. Use `with sess.as_default()` or pass an " "explicit session to tf.start_queue_runners(sess=sess)") + + if not isinstance(sess, session.SessionInterface): + # Following check is due to backward compatibility. (b/62061352) + if sess.__class__.__name__ in [ + "MonitoredSession", "SingularMonitoredSession"]: + return [] + raise TypeError("sess must be a `tf.Session` object. 
" + "Given class: {}".format(sess.__class__)) + with sess.graph.as_default(): threads = [] for qr in ops.get_collection(collection): diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py index 5b00ac9fc31..51c0eecf46a 100644 --- a/tensorflow/python/training/queue_runner_test.py +++ b/tensorflow/python/training/queue_runner_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import coordinator +from tensorflow.python.training import monitored_session from tensorflow.python.training import queue_runner_impl @@ -247,6 +248,33 @@ class QueueRunnerTest(test.TestCase): # The variable should be 3. self.assertEqual(3, var.eval()) + def testStartQueueRunnersRaisesIfNotASession(self): + zero64 = constant_op.constant(0, dtype=dtypes.int64) + var = variables.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = data_flow_ops.FIFOQueue(10, dtypes.float32) + init_op = variables.global_variables_initializer() + qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) + queue_runner_impl.add_queue_runner(qr) + with self.test_session(): + init_op.run() + with self.assertRaisesRegexp(TypeError, "tf.Session"): + queue_runner_impl.start_queue_runners("NotASession") + + def testStartQueueRunnersIgnoresMonitoredSession(self): + zero64 = constant_op.constant(0, dtype=dtypes.int64) + var = variables.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = data_flow_ops.FIFOQueue(10, dtypes.float32) + init_op = variables.global_variables_initializer() + qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) + queue_runner_impl.add_queue_runner(qr) + with self.test_session(): + init_op.run() + threads = queue_runner_impl.start_queue_runners( + monitored_session.MonitoredSession()) + self.assertFalse(threads) + def testStartQueueRunnersNonDefaultGraph(self): # CountUpTo will raise 
OUT_OF_RANGE when it reaches the count. graph = ops.Graph() From fdffafbc19d85a63c72b76ecfeb2d92a4c43dc75 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 1 Jun 2017 14:11:58 -0700 Subject: [PATCH 20/72] Add QueueDequeueUpTo to the list of dequeue ops PiperOrigin-RevId: 157760201 --- tensorflow/core/grappler/op_types.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 5c2438e258e..7a239aeffec 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -30,8 +30,8 @@ bool IsConstant(const NodeDef& node) { bool IsDequeueOp(const NodeDef& node) { static const std::set dequeue_ops = { - "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", - "QueueDequeue"}; + "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", + "QueueDequeue", "QueueDequeueUpToV2", "QueueDequeueUpTo"}; return dequeue_ops.count(node.op()) > 0; } From 7866fa01b79a297908e7871d3b274fa02a5ce5e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 14:29:25 -0700 Subject: [PATCH 21/72] This change significantly reduces time and resources used to load large TensorFlow graphs. For a real-world large graph (13k nodes, 20k edges), this change: * reduces all heap allocations by 19% * reduces retained (final) heap allocations by 2.2% * reduces CPU time by 11.2% In most TF graphs, the set of unique values set to Node::assigned_device_name() is quite small. This change adds an interning table to the Graph object, which contains all of the unique values used for Node::set_assigned_device_name(), as well as a look-up table. This is the main source of the reduction in retained heap memory; nearly all nodes are assigned to just one or two unique devices. This change removes the "string assigned_device_name_" field from the Node class, and replaces it with "int assigned_device_name_index_". 
However, because you need both the index and the name table to get the actual value, the Node::assigned_device_name() accessor needs access to the parent Graph. This requires adding a "Graph* graph_" field to the Node class. In the future, if all users of this property are converted to use Graph::assigned_device_name(Node*), then the Node::graph_ field can be deleted, and the space reclaimed. However, doing so is out of the scope of this CL, and even with this new pointer field, the Node class is smaller than it was before, so this is still a net win. The placement algorithm in simple_placer.cc is one of the main accessors of the Node::assigned_device_name property. This CL contains significant changes to simple_placer.cc, which directly take advantage of the fact that the property is an index into a name table, rather than treating it simply as a string. Many temporary allocations are also removed, which is the main source of the reduction in total heap allocations. This CL also contains a few changes that remove short-lived allocations in unrelated code, such as the changes in op.cc/h, costmodel.cc, etc. It is extremely easy in C++ to accidentally allocate memory, especially when implicit conversions and copy constructors allocate memory. All of the changes in this CL were motivated by empirical measurement, using CPU profiling and heap profiling. 
PiperOrigin-RevId: 157762909 --- tensorflow/core/common_runtime/device.h | 4 +- tensorflow/core/common_runtime/device_set.cc | 2 +- .../core/common_runtime/simple_placer.cc | 630 ++++++++++-------- .../core/common_runtime/simple_placer.h | 4 +- tensorflow/core/framework/op.cc | 4 +- tensorflow/core/framework/op.h | 6 +- tensorflow/core/framework/types.h | 1 + tensorflow/core/graph/costmodel.cc | 8 + tensorflow/core/graph/graph.cc | 36 +- tensorflow/core/graph/graph.h | 85 ++- tensorflow/core/graph/graph_constructor.cc | 7 +- 11 files changed, 473 insertions(+), 314 deletions(-) diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index c0e58f143e3..11024805cb2 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -60,7 +60,9 @@ class Device : public DeviceBase { const string& name() const { return device_attributes_.name(); } // Parsed name of this device - const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; } + const DeviceNameUtils::ParsedName& parsed_name() const { + return parsed_name_; + } // Describes what kind of device this is. 
This is intended to be // human-readable and not computer-parsed, except that two devices diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index 0ed9470655b..493349176ea 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const { // static int DeviceSet::DeviceTypeOrder(const DeviceType& d) { - return DeviceFactory::DevicePriority(d.type()); + return DeviceFactory::DevicePriority(d.type_string()); } static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) { diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 13a5133887d..6b7c47f8fe5 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -69,36 +69,6 @@ std::vector FilterSupportedDevices( return filtered_devices; } -// Returns the name of the colocation group of the node by inspecting -// the kColocationAttrName attribute of the NodeDef. -void ColocationGroups(const Node& node, - std::vector* colocation_groups) { - std::vector class_specs; - // TODO(vrv): We should consider adding a GetNodeAttr that returns a - // StringPiece, to avoid a copy. - if (!GetNodeAttrSimple(node.attrs(), kColocationAttrNameStringPiece, - &class_specs)) { - // No attribute value is equivalent to the empty colocation_group. 
- *colocation_groups = { - strings::StrCat(kColocationGroupPrefixStringPiece, node.name())}; - return; - } - - bool found_spec = false; - for (const string& class_spec : class_specs) { - StringPiece spec(class_spec); - if (spec.Consume(kColocationGroupPrefixStringPiece)) { - found_spec = true; - colocation_groups->emplace_back(class_spec); - } - } - - if (!found_spec) { - *colocation_groups = { - strings::StrCat(kColocationGroupPrefixStringPiece, node.name())}; - } -} - // This class maintains the connected components of a colocation // constraint graph, and uses this information to assign a satisfying // device placement to the nodes of the graph. @@ -130,51 +100,96 @@ void ColocationGroups(const Node& node, class ColocationGraph { public: ColocationGraph(Graph* graph, const DeviceSet* device_set, - const SessionOptions* options) - : device_set_(device_set), + bool allow_soft_placement) + : graph_(graph), + device_set_(device_set), device_types_(device_set->PrioritizedDeviceTypeList()), - options_(options) { - members_.reserve(graph->num_node_ids()); + allow_soft_placement_(allow_soft_placement) { + members_.resize(graph->num_node_ids()); } - // Adds the given node to this ColocationGraph as a singleton. + // Adds each node of the Graph to this ColocationGraph as a singleton. // // NOTE: The implementation assumes that the ids of nodes passed to // this method are dense and zero-based; the memory used will be linear in // the largest node ID. // NOTE: If this method returns an error, *this is left in an undefined // state. - Status AddNode(const Node& node) { - Member member; - TF_RETURN_IF_ERROR(InitializeMember(node, &member)); - CHECK_GE(member.parent, 0); - members_.resize(member.parent + 1); - members_[member.parent] = std::move(member); + Status ColocateAllNodes() { + // This maps from a colocation group identifier to the 'root' of that + // colocation group. 
Note that the keys in this map are StringPiece; the + // actual strings are stored under the NodeDef. The lifetime of this map + // is limited to this ColocateAllNodes() method, and no part of the + // NodeDef trees are changed during the lifetime of this method, so using + // StringPiece as a key is safe. + // + // Also, as a further optimization, we remove the "loc:@" prefix from + // "class" attribute values, when they are used as keys in this table. + // This allows us to use StringPiece values that refer to substrings of + // 'string' values stored in NodeDef attribute lists, as well as StringPiece + // values that refer to 'string' values from NodeDef::name(), without + // performing any string allocations. + std::unordered_map + colocation_group_root; - // When adding the node, identify whether it is part of a - // colocation group. - std::vector colocation_groups; - ColocationGroups(node, &colocation_groups); - Status s; - for (const string& colocation_group : colocation_groups) { - auto it = colocation_group_root_.find(colocation_group); - if (it == colocation_group_root_.end()) { - // This is the first node of the colocation group, so - // designate this node as the 'root' of that colocation group. - colocation_group_root_[colocation_group] = &node; - } else { - // Try to colocate the node with the root. If there is an - // error, return it. - s = ColocateNodes(node, *(it->second)); - if (!s.ok()) { - return s; + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + + // When adding the node, identify whether it is part of a + // colocation group. + + // This code is effectively the equivalent of GetNodeAttr() for a string + // array, but it avoids all internal allocations (the allocation of the + // backing store of the std::vector as well as the copies of the + // strings within it). Instead, we combine the query of the colocation + // attribute with the calls to ColocateNodeToGroup. 
+ bool found_spec = false; + const AttrValue* attr_value = + AttrSlice(node->def()).Find(kColocationAttrNameStringPiece); + if (attr_value != nullptr && attr_value->has_list()) { + for (const string& class_spec : attr_value->list().s()) { + StringPiece spec(class_spec); + if (spec.Consume(kColocationGroupPrefixStringPiece)) { + found_spec = true; + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, spec)); + } } } + + if (!found_spec) { + // If the node does not specify a colocation group, then use the + // name of this node as the colocation group. + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, node->name())); + } } return Status::OK(); } + Status ColocateNodeToGroup( + std::unordered_map* + colocation_group_root, + Node* node, StringPiece colocation_group) { + const Node*& root_node = (*colocation_group_root)[colocation_group]; + if (root_node == nullptr) { + // This is the first node of the colocation group, so + // designate this node as the 'root' of that colocation group. + root_node = node; + } else { + // Try to colocate the node with the root. If there is an + // error, return it. + Status s = ColocateNodes(*node, *root_node); + if (!s.ok()) { + return AttachDef(s, node->def()); + } + } + return Status::OK(); + } + // Merge the (possibly disjoint) sets containing nodes "x" and // "y". Returns OK if the all nodes in the union of these sets can // be placed on the same device type. @@ -184,105 +199,104 @@ class ColocationGraph { Status ColocateNodes(const Node& x, const Node& y) { int x_root = FindRoot(x.id()); int y_root = FindRoot(y.id()); + return ColocateNodes(x, x_root, y, y_root); + } - Status s; - if (x_root != y_root) { - // Merge the sets by swinging the parent pointer of the smaller - // tree to point to the root of the larger tree. Together with - // path compression in ColocationGraph::FindRoot, this ensures - // that we do not experience pathological performance on graphs - // such as chains. 
- int new_root, old_root; - if (members_[x_root].rank < members_[y_root].rank) { - // The tree rooted at x_root is shallower, so connect it to - // y_root. The rank of y_root is unchanged because its new - // child has strictly less rank. - members_[x_root].parent = y_root; - new_root = y_root; - old_root = x_root; - } else if (members_[x_root].rank > members_[y_root].rank) { - // The tree rooted at y_root is shallower, so connect it to - // x_root. The rank of x_root is unchanged because its new - // child has strictly less rank. - members_[y_root].parent = x_root; - new_root = x_root; - old_root = y_root; - } else { - // Both trees have the same rank, so break the tie by choosing - // x_root as the new root. - members_[y_root].parent = x_root; - // Increment the rank of the tree rooted at x_root, because it - // is now strictly deeper than before. - ++members_[x_root].rank; - new_root = x_root; - old_root = y_root; - } - - // Merge the partial device specifications, and ensure that they are - // compatible. NULL options_ is treated as allowing soft placement. - // TODO(mrry): Consider enriching the error message by pointing - // out which nodes have the explicit partial device - // specifications that caused this conflict. - s = DeviceNameUtils::MergeDevNames( - &members_[new_root].device_name, members_[old_root].device_name, - options_ == nullptr || options_->config.allow_soft_placement()); - if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); - } - - // Transfer ids in the old group to the new one. - members_[new_root].ids_in_group.insert( - members_[old_root].ids_in_group.begin(), - members_[old_root].ids_in_group.end()); - members_[old_root].ids_in_group.clear(); - - // Ensure that the common root has at least one supported device - // type, by computing the intersection of - // members_[new_root].supported_device_types and - // members_[old_root].supported_device_types. 
- MergeSupportedDevices(&members_[new_root].supported_device_types, - members_[old_root].supported_device_types); - if (members_[new_root].supported_device_types.empty()) { - string debug_info; - AddDebugInfo(x_root, &debug_info); - AddDebugInfo(y_root, &debug_info); - return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " - "other nodes colocated with them.", - debug_info); - } + // This overload of ColocateNodes() allows a caller to provide the root node + // ids for the two nodes. For large graphs, this noticeably reduces the + // graph load time. + Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { + if (x_root == y_root) { + return Status::OK(); } + + DCHECK_EQ(x_root, FindRoot(x.id())); + DCHECK_EQ(y_root, FindRoot(y.id())); + + Member& x_root_member = members_[x_root]; + Member& y_root_member = members_[y_root]; + + // Merge the sets by swinging the parent pointer of the smaller + // tree to point to the root of the larger tree. Together with + // path compression in ColocationGraph::FindRoot, this ensures + // that we do not experience pathological performance on graphs + // such as chains. + int new_root, old_root; + if (x_root_member.rank < y_root_member.rank) { + // The tree rooted at x_root is shallower, so connect it to + // y_root. The rank of y_root is unchanged because its new + // child has strictly less rank. + x_root_member.parent = y_root; + new_root = y_root; + old_root = x_root; + } else if (x_root_member.rank > y_root_member.rank) { + // The tree rooted at y_root is shallower, so connect it to + // x_root. The rank of x_root is unchanged because its new + // child has strictly less rank. + y_root_member.parent = x_root; + new_root = x_root; + old_root = y_root; + } else { + // Both trees have the same rank, so break the tie by choosing + // x_root as the new root. 
+ y_root_member.parent = x_root; + // Increment the rank of the tree rooted at x_root, because it + // is now strictly deeper than before. + ++x_root_member.rank; + new_root = x_root; + old_root = y_root; + } + + Member& new_root_member = members_[new_root]; + Member& old_root_member = members_[old_root]; + + // Merge the partial device specifications, and ensure that they are + // compatible. NULL options_ is treated as allowing soft placement. + // TODO(mrry): Consider enriching the error message by pointing + // out which nodes have the explicit partial device + // specifications that caused this conflict. + Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name, + old_root_member.device_name, + allow_soft_placement_); + if (!s.ok()) { + return errors::InvalidArgument("Cannot colocate nodes '", x.name(), + "' and '", y.name(), ": ", + s.error_message()); + } + + // Ensure that the common root has at least one supported device + // type, by computing the intersection of + // new_root_member.supported_device_types and + // old_root_member.supported_device_types. + MergeSupportedDevices(&new_root_member.supported_device_types, + old_root_member.supported_device_types); + if (new_root_member.supported_device_types.empty()) { + return errors::InvalidArgument( + "Cannot colocate nodes '", x.name(), "' and '", y.name(), + "' because no device type supports both of those nodes and the " + "other nodes colocated with them.", + DebugInfo(x_root), DebugInfo(y_root)); + } + return Status::OK(); } - // Returns the device name associated with 'node'. 
- DeviceNameUtils::ParsedName DeviceForNode(const Node& node) { - int node_root = FindRoot(node.id()); - return members_[node_root].device_name; - } - - void SetDeviceForNode(Node* node, const DeviceNameUtils::ParsedName& device) { - int node_root = FindRoot(node->id()); - members_[node_root].device_name = device; - } - // For the given node, subject to the constraints previously given // to this ColocationGraph, set its assigned_device_name. Returns OK // if a satisfying device can be found, otherwise an error. - Status GetDevicesForNode(Node* node, std::vector* possible_devices) { - possible_devices->clear(); + // + // Note: This method returns a pointer to a field within members_. + // The caller must not use the returned pointer after there is any possibility + // that the members_[i].possible_devices field has been modified. + Status GetDevicesForNode(Node* node, + std::vector** possible_devices) { + *possible_devices = nullptr; const int node_root = FindRoot(node->id()); if (!members_[node_root].possible_devices.empty()) { - *possible_devices = members_[node_root].possible_devices; + *possible_devices = &members_[node_root].possible_devices; return Status::OK(); } - // String containing additional debugging info on failures. - string debug_info; - // We have not yet computed the possible devices for the // colocated node set containing 'node', so we do so now using the // constraints on the root node. @@ -304,10 +318,8 @@ class ColocationGraph { devices, members_[node_root].supported_device_types); } - // Perform soft placement if allow_soft_placement is set. options_ - // being NULL is treated as allowing soft placement. - if (devices.empty() && - (options_ == nullptr || options_->config.allow_soft_placement())) { + // Perform soft placement if allow_soft_placement_ is set. + if (devices.empty() && allow_soft_placement_) { // The soft_device_name is the same as the node's device name // without specifying the device type or ID. 
DeviceNameUtils::ParsedName soft_device_name = @@ -326,7 +338,7 @@ class ColocationGraph { // Return an error when a physical device that matches an explicit // device specification is not found. This ensures that we don't // assign a node to GPU when the user wanted to force it on CPU. - AddDebugInfo(node_root, &debug_info); + string debug_info = DebugInfo(node_root); DeviceNameUtils::ParsedName specified_device_name; if (DeviceNameUtils::ParseFullName(node->requested_device(), @@ -386,21 +398,32 @@ class ColocationGraph { device_set_->devices(), members_[node_root].supported_device_types); if (devices.empty()) { - AddDebugInfo(node_root, &debug_info); return errors::InvalidArgument( "Node had no OpKernel registered to support this operation: ", "Operation was ", node->type_string(), " and inputs were ", - DataTypeVectorString(node->input_types()), debug_info); + DataTypeVectorString(node->input_types()), DebugInfo(node_root)); } } // Cache the result of the possible devices for this node group. - members_[node_root].possible_devices = devices; - *possible_devices = members_[node_root].possible_devices; + members_[node_root].possible_devices = std::move(devices); + *possible_devices = &members_[node_root].possible_devices; + return Status::OK(); + } + + Status InitializeMembers() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + Status status = InitializeMember(*node, &members_[node->id()]); + if (!status.ok()) { + return AttachDef(status, node->def()); + } + } return Status::OK(); } - private: // Represents a node in the disjoint node set forest, and the // accumulated constraints on the device used by that node. struct Member { @@ -409,15 +432,6 @@ class ColocationGraph { // id if it is a root. parent <= 0 indicates that this member is invalid. int parent = -1; - // The set of ids that are part of the disjoint node set forest. - // - // This is only fully specified in the root of a disjoint - // node set forest. 
- std::set ids_in_group; - - // The type of the op for this node. - string op_type; - // A proxy for the depth of the tree that is used to prefer // connecting smaller trees to larger trees when merging disjoint // sets. @@ -438,49 +452,56 @@ class ColocationGraph { std::vector possible_devices; }; - // Adds debugging info to 'output' for the node referred to by - // 'node_root'. - void AddDebugInfo(const int node_root, string* output) { - if (members_[node_root].ids_in_group.size() > 1) { - strings::StrAppend(output, "\nColocation Debug Info:\n"); + // Returns debugging info for the node referred to by 'node_root'. + string DebugInfo(const int node_root) { + string text( + "\nColocation Debug Info:\n" + "Colocation group had the following types and devices: "); - // If this node is part of a colocation group, then we want to - // collect the mapping of ops to supported devices, so that - // the user can see why an unsatisfiable placement occurred. - strings::StrAppend( - output, "Colocation group had the following types and devices: "); + // If this node is part of a colocation group, then we want to + // collect the mapping of ops to supported devices, so that + // the user can see why an unsatisfiable placement occurred. 
- std::unordered_map type_to_devices; - for (const int id : members_[node_root].ids_in_group) { - const string& op_type = members_[id].op_type; - string devices_registered; - for (const auto& device_type : members_[id].supported_device_types) { - strings::StrAppend(&devices_registered, DeviceTypeString(device_type), - " "); - } + std::unordered_map type_to_devices; + int num_nodes_found = 0; - type_to_devices[op_type] = devices_registered; + for (const Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + int id = node->id(); + if (FindRoot(id) != node_root) { + continue; + } + ++num_nodes_found; + const string& op_type = node->type_string(); + string devices_registered; + for (const auto& device_type : members_[id].supported_device_types) { + strings::StrAppend(&devices_registered, DeviceTypeString(device_type), + " "); } - for (const auto& td : type_to_devices) { - strings::StrAppend(output, "\n", td.first, ": ", td.second); - } + type_to_devices[op_type] = std::move(devices_registered); } + + for (const auto& td : type_to_devices) { + strings::StrAppend(&text, "\n", td.first, ": ", td.second); + } + + if (num_nodes_found <= 1) { + text.clear(); + } + return text; } Status InitializeMember(const Node& node, Member* member) { const int id = node.id(); - member->ids_in_group.insert(id); - member->op_type = node.type_string(); - - if (id < 0) { - return errors::InvalidArgument("Node id was not positive: ", id); - } + DCHECK_GE(id, 0); member->parent = id; TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( device_types_, node.def(), &member->supported_device_types)); - if (!node.assigned_device_name().empty()) { + if (node.has_assigned_device_name()) { // This node has already been assigned to a device, so we // respect this placement, after sanity-checking it. 
The // device_name and supported_device_types for this node reflect @@ -490,17 +511,16 @@ class ColocationGraph { // NOTE: Since any assignment must have been performed by // the TensorFlow runtime, we consider errors in this branch to // be INTERNAL. - if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(), + const string& assigned_device_name = node.assigned_device_name(); + if (!DeviceNameUtils::ParseFullName(assigned_device_name, &member->device_name)) { return errors::Internal("Malformed assigned device '", - node.assigned_device_name(), "'"); + assigned_device_name, "'"); } - std::vector devices; const Device* assigned_device = - device_set_->FindDeviceByName(node.assigned_device_name()); + device_set_->FindDeviceByName(assigned_device_name); if (assigned_device == nullptr) { - return errors::Internal("Assigned device '", - node.assigned_device_name(), + return errors::Internal("Assigned device '", assigned_device_name, "' does not match any device"); } @@ -510,7 +530,7 @@ class ColocationGraph { } } - return errors::Internal("Assigned device '", node.assigned_device_name(), + return errors::Internal("Assigned device '", assigned_device_name, "' does not have registered OpKernel support " "for ", node.type_string()); @@ -577,32 +597,32 @@ class ColocationGraph { // Returns the root node of the disjoint tree to which the node with the // given id is connected. int FindRoot(int node_id) { - DCHECK_GE(members_[node_id].parent, 0); - if (members_[node_id].parent != node_id) { + Member& member = members_[node_id]; + + int parent = member.parent; + DCHECK_GE(parent, 0); + + if (parent != node_id) { // NOTE: Compress paths from node_id to its root, so that future // calls to FindRoot and ColocateNodes are more efficient. 
- members_[node_id].parent = FindRoot(members_[node_id].parent); + int root = FindRoot(parent); + if (parent != root) { + parent = root; + member.parent = root; + } } - return members_[node_id].parent; + + DCHECK_GE(parent, 0); + return parent; } + Graph* const graph_; // Not owned. std::vector members_; const DeviceSet* device_set_; // Not owned. const std::vector device_types_; - const SessionOptions* options_; // Not owned; - - // Maps from a colocation group identifier to the 'root' of that - // colocation group. - std::unordered_map colocation_group_root_; + const bool allow_soft_placement_; }; -// Returns true if the node only depends on its input's metadata -// (shape). Not necessarily a complete list. -bool IsMetadataNode(const Node* node) { - const string& node_type = node->type_string(); - return (node_type == "Size" || node_type == "Shape" || node_type == "Rank"); -} - // Returns true if the node has no inputs and produces outputs // that are consumed by a single node. // @@ -618,12 +638,14 @@ bool IsGeneratorNode(const Node* node) { SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, const SessionOptions* options) - : graph_(graph), devices_(devices), options_(options) {} + : graph_(graph), + devices_(devices), + options_(options), + log_device_placement_(options != nullptr && + options->config.log_device_placement()) {} SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices) - : graph_(graph), devices_(devices) { - options_ = nullptr; -} + : SimplePlacer(graph, devices, nullptr) {} SimplePlacer::~SimplePlacer() {} @@ -632,91 +654,93 @@ Status SimplePlacer::Run() { return errors::FailedPrecondition("No devices are registered"); } - ColocationGraph colocation_graph(graph_, devices_, options_); - Status status; + ColocationGraph colocation_graph( + graph_, devices_, + options_ == nullptr || options_->config.allow_soft_placement()); + + TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers()); // 1. First add all of the nodes. 
Note that steps (1) and (2) // requires two passes over the nodes because the graph (and hence // the constraints) may not be acyclic. - for (Node* node : graph_->op_nodes()) { - status = colocation_graph.AddNode(*node); - if (!status.ok()) return AttachDef(status, *node); - } + TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes()); // 2. Enumerate the constraint edges, and use them to update the disjoint // node set. - for (Node* node : graph_->op_nodes()) { - // If `node` has an input edge with reference type, add an - // edge from the source of that edge to `node`. - for (const auto& edge : node->in_edges()) { - if (!edge->IsControlEdge() && - (IsRefType(node->input_type(edge->dst_input())) || - node->input_type(edge->dst_input()) == DT_RESOURCE)) { - // If both the source node and this node have paritally - // specified a device, then 'node's device should be - // cleared: the reference edge forces 'node' to be on the - // same device as the source node. - auto source_parsed_name = colocation_graph.DeviceForNode(*edge->src()); - auto dest_parsed_name = colocation_graph.DeviceForNode(*node); - if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && - DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { - // Add a log saying that we are ignoring a specified device - // for 'node' if the two names were incompatible. 
- if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, - dest_parsed_name)) { - LOG(INFO) << "Ignoring device specification " - << DeviceNameUtils::ParsedNameToString( - colocation_graph.DeviceForNode(*node)) - << " for node '" << node->name() - << "' because the input edge from '" - << edge->src()->name() - << "' is a reference connection and already has a device " - "field set to " - << DeviceNameUtils::ParsedNameToString( - colocation_graph.DeviceForNode(*edge->src())); - // Make 'node' colocated with the source - colocation_graph.SetDeviceForNode(node, source_parsed_name); + // If `node` has an input edge with reference type, add an + // edge from the source of that edge to `node`. + for (const Edge* edge : graph_->edges()) { + if (edge->IsControlEdge()) { + continue; + } + Node* src = edge->src(); + Node* dst = edge->dst(); + DataType input_type = dst->input_type(edge->dst_input()); + if (input_type == DT_RESOURCE || IsRefType(input_type)) { + int src_root_id = colocation_graph.FindRoot(src->id()); + int dst_root_id = colocation_graph.FindRoot(dst->id()); + auto& src_root = colocation_graph.members_[src_root_id]; + auto& dst_root = colocation_graph.members_[dst_root_id]; + // If both the source node and this node have paritally + // specified a device, then 'node's device should be + // cleared: the reference edge forces 'node' to be on the + // same device as the source node. + const auto& source_parsed_name = src_root.device_name; + const auto& dest_parsed_name = dst_root.device_name; + if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && + DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { + // Add a log saying that we are ignoring a specified device + // for 'dst' if the two names were incompatible. 
+ if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, + dest_parsed_name)) { + LOG(INFO) << "Ignoring device specification " + << DeviceNameUtils::ParsedNameToString(dest_parsed_name) + << " for node '" << dst->name() + << "' because the input edge from '" << src->name() + << "' is a reference connection and already has a device " + "field set to " + << DeviceNameUtils::ParsedNameToString(source_parsed_name); + + // Make 'dst' colocated with the source + dst_root.device_name = source_parsed_name; + } else { + bool source_subset_of_dest = DeviceNameUtils::IsSpecification( + source_parsed_name, dest_parsed_name); + bool dest_subset_of_source = DeviceNameUtils::IsSpecification( + dest_parsed_name, source_parsed_name); + + if (source_subset_of_dest && !dest_subset_of_source) { + src_root.device_name = dest_parsed_name; } else { - bool source_subset_of_dest = DeviceNameUtils::IsSpecification( - source_parsed_name, dest_parsed_name); - bool dest_subset_of_source = DeviceNameUtils::IsSpecification( - dest_parsed_name, source_parsed_name); - - if (source_subset_of_dest && !dest_subset_of_source) { - colocation_graph.SetDeviceForNode(edge->src(), dest_parsed_name); - } else { - colocation_graph.SetDeviceForNode(node, source_parsed_name); - } + dst_root.device_name = source_parsed_name; } } + } - status = colocation_graph.ColocateNodes(*edge->src(), *node); - if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Nodes were connected by a " - "reference connection (requiring them to " - "be on the same device), but the two nodes " - "were assigned two different devices: ", - status.error_message()), - *node); - } + Status status = + colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Nodes were connected by a " + "reference connection (requiring them to " + "be on the same device), but the two nodes " + "were assigned two different devices: ", + 
status.error_message()), + dst->def()); } } } // 3. For each node, assign a device based on the constraints in the // disjoint node set. - std::vector devices; std::vector second_pass; for (Node* node : graph_->op_nodes()) { // The graph may have come pre-populated by the framework with assigned // devices (e.g., for stateful placements), so the placer should not try to // place nodes that are already placed. - if (!node->assigned_device_name().empty()) { - // Although the device is already assigned, we run this function to - // possibly log pre-assigned placements. - AssignAndLog(node->assigned_device_name(), node); + if (node->has_assigned_device_name()) { + LogDeviceAssignment(node); continue; } @@ -731,7 +755,8 @@ Status SimplePlacer::Run() { continue; } - status = colocation_graph.GetDevicesForNode(node, &devices); + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", @@ -748,12 +773,12 @@ Status SimplePlacer::Run() { // given a choice of devices. Once we have a better idea of the // types of heuristics we want to use and the information needed // to perform good placement we can add an interface for this. - string assigned_device = devices[0]->name(); + int assigned_device = -1; // Heuristic B: If the node only operates on metadata, not data, // then it is desirable to place that metadata node with its // input. - if (IsMetadataNode(node)) { + if (IsMetadata(node)) { // Make sure that the input device type is in the list of supported // device types for this node. const Node* input = (*node->in_edges().begin())->src(); @@ -761,19 +786,24 @@ Status SimplePlacer::Run() { // node's assignment to the second pass, so that we handle the // case where a metadata node's input comes from a backedge // of a loop. 
- const string& input_device_name = input->assigned_device_name(); - if (CanAssignToDevice(input_device_name, devices)) { - assigned_device = input_device_name; + if (CanAssignToDevice(input->assigned_device_name(), *devices)) { + assigned_device = input->assigned_device_name_index(); } } + // Provide the default, if necessary. + if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + AssignAndLog(assigned_device, node); } // 4. Perform a second pass assignment for those nodes explicitly // skipped during the first pass. for (Node* node : second_pass) { - status = colocation_graph.GetDevicesForNode(node, &devices); + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", @@ -781,25 +811,30 @@ Status SimplePlacer::Run() { *node); } - string assigned_device = devices[0]->name(); + int assigned_device = -1; // Heuristic A application. if (IsGeneratorNode(node)) { const Node* output = (*node->out_edges().begin())->dst(); - const string& output_device_name = output->assigned_device_name(); + int output_device_name = output->assigned_device_name_index(); const bool consumers_on_same_device = std::all_of( node->out_edges().begin(), node->out_edges().end(), [output_device_name](const Edge* e) { - return e->dst()->assigned_device_name() == output_device_name; + return e->dst()->assigned_device_name_index() == output_device_name; }); if (consumers_on_same_device && - CanAssignToDevice(output_device_name, devices)) { + CanAssignToDevice(output->assigned_device_name(), *devices)) { assigned_device = output_device_name; } } + // Provide the default, if necessary. 
+ if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + AssignAndLog(assigned_device, node); } @@ -824,11 +859,14 @@ bool SimplePlacer::CanAssignToDevice( return false; } -void SimplePlacer::AssignAndLog(const string& assigned_device, - Node* node) const { - node->set_assigned_device_name(assigned_device); +void SimplePlacer::AssignAndLog(int assigned_device, Node* node) const { + node->set_assigned_device_name_index(assigned_device); + LogDeviceAssignment(node); +} + +void SimplePlacer::LogDeviceAssignment(const Node* node) const { // Log placement if log_device_placement is set. - if (options_ && options_->config.log_device_placement()) { + if (log_device_placement_) { printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), node->assigned_device_name().c_str()); LOG(INFO) << node->name() << ": " diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h index a041e968309..9c63cef40bb 100644 --- a/tensorflow/core/common_runtime/simple_placer.h +++ b/tensorflow/core/common_runtime/simple_placer.h @@ -86,11 +86,13 @@ class SimplePlacer { // Assigns 'node's devices to 'assigned_device', and logs the // placement if the SessionOptions entry in 'options_' requests it. - void AssignAndLog(const string& assigned_device, Node* node) const; + void AssignAndLog(int assigned_device, Node* node) const; + void LogDeviceAssignment(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. const SessionOptions* options_; // Not owned. 
+ const bool log_device_placement_; TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer); }; diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 5ddac6b1982..fe333dc9ffa 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -48,7 +48,7 @@ OpRegistry::~OpRegistry() { for (const auto& e : registry_) delete e.second; } -void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) { +void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { mutex_lock lock(mu_); if (initialized_) { TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); @@ -181,7 +181,7 @@ Status OpRegistry::CallDeferred() const { } Status OpRegistry::RegisterAlreadyLocked( - OpRegistrationDataFactory op_data_factory) const { + const OpRegistrationDataFactory& op_data_factory) const { std::unique_ptr op_reg_data(new OpRegistrationData); Status s = op_data_factory(op_reg_data.get()); if (s.ok()) { diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 892ed9b60b4..c5a0983a547 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -70,7 +70,7 @@ class OpRegistry : public OpRegistryInterface { OpRegistry(); ~OpRegistry() override; - void Register(OpRegistrationDataFactory op_data_factory); + void Register(const OpRegistrationDataFactory& op_data_factory); Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const override; @@ -138,8 +138,8 @@ class OpRegistry : public OpRegistryInterface { // Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. 
- Status RegisterAlreadyLocked(OpRegistrationDataFactory op_data_factory) const - EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) + const EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 0a81b1cb9f3..f562880e7cf 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -57,6 +57,7 @@ class DeviceType { explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} const char* type() const { return type_.c_str(); } + const string& type_string() const { return type_; } bool operator<(const DeviceType& other) const; bool operator==(const DeviceType& other) const; diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index 69247a4f621..f798af85e15 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -476,6 +476,14 @@ static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) { } // namespace void CostModel::InitFromGraph(const Graph& g) { + const int num_node_ids = g.num_node_ids(); + slot_bytes_.reserve(num_node_ids); + count_.reserve(num_node_ids); + time_.reserve(num_node_ids); + max_mem_usage_.reserve(num_node_ids); + max_exec_time_.reserve(num_node_ids); + output_port_alloc_ids_.reserve(num_node_ids); + AddNodesToCostModel(g, this); AssignSizes(g, this); EstimateComputationCosts(g, this); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 80161ceb56b..dcb8520cf73 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -56,6 +56,9 @@ const std::unordered_map& Node::kNodeClassTable = {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, {"GetSessionTensor", NC_GET_SESSION_TENSOR}, {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, + {"Size", NC_METADATA}, + {"Shape", 
NC_METADATA}, + {"Rank", NC_METADATA}, }); #undef REF_CLASS @@ -77,7 +80,7 @@ string Node::DebugString() const { strings::StrAppend(&ret, " sink}"); } else { strings::StrAppend(&ret, " op device:"); - strings::StrAppend(&ret, "{", assigned_device_name_, "}"); + strings::StrAppend(&ret, "{", assigned_device_name(), "}"); strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}"); } return ret; @@ -88,7 +91,7 @@ Node::Node() cost_id_(-1), class_(NC_UNINITIALIZED), props_(nullptr), - assigned_device_name_() {} + assigned_device_name_index_(0) {} Node::~Node() { if (props_) { @@ -124,7 +127,7 @@ void Node::Clear() { props_ = nullptr; } - assigned_device_name_.clear(); + assigned_device_name_index_ = 0; } gtl::iterator_range Node::out_nodes() const { @@ -241,6 +244,10 @@ Graph::Graph(const OpRegistryInterface* ops) versions_.set_producer(TF_GRAPH_DEF_VERSION); versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); + // Initialize the name interning table for assigned_device_name. + device_names_.push_back(""); + DCHECK_EQ(0, InternDeviceName("")); + // Source and sink have no endpoints, just control edges. NodeDef def; def.set_name("_SOURCE"); @@ -503,6 +510,7 @@ Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { node = free_nodes_.back(); free_nodes_.pop_back(); } + node->graph_ = this; const int id = nodes_.size(); int cost_id = cost_node ? cost_node->cost_id() : id; node->Initialize(id, cost_id, props); @@ -519,4 +527,26 @@ void Graph::ReleaseNode(Node* node) { node->Clear(); } +// Ensures that 'device_name' is present in the device name table, and returns +// the index of that device name. The index is stable, and can be used in +// calls to Node::set_assigned_device_name_index(). +int Graph::InternDeviceName(const string& device_name) { + // Special case, very common. Also, this allows us to use a single map + // lookup below, instead of two. The 'if (index_cell > 0)' test below + // relies on this check. 
+ if (device_name.empty()) { + return 0; + } + + int& index_cell = device_names_map_[device_name]; + if (index_cell > 0) { + return index_cell; + } + + const int index = device_names_map_.size(); + index_cell = index; + device_names_.push_back(device_name); + return index; +} + } // namespace tensorflow diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index e82580f204b..8cb270170e9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -104,10 +104,13 @@ class Node { // fully specifies a device, and satisfies def().device(). // TODO(josh11b): Move assigned_device_name outside of Node into a // NodeId->DeviceName map. - string assigned_device_name() const { return assigned_device_name_; } - void set_assigned_device_name(const string& device_name) { - assigned_device_name_ = device_name; + const string& assigned_device_name() const; + void set_assigned_device_name(const string& device_name); + bool has_assigned_device_name() const { + return assigned_device_name_index_ > 0; } + int assigned_device_name_index() const { return assigned_device_name_index_; } + void set_assigned_device_name_index(int index); // Read only access to attributes AttrSlice attrs() const { return AttrSlice(def()); } @@ -155,6 +158,8 @@ class Node { bool IsHostSend() const { return class_ == NC_HOST_SEND; } bool IsHostRecv() const { return class_ == NC_HOST_RECV; } + bool IsMetadata() const { return class_ == NC_METADATA; } + template void AddAttr(const string& name, const T& val) { MaybeCopyOnWrite(); @@ -232,6 +237,7 @@ class Node { NC_GET_SESSION_HANDLE, NC_GET_SESSION_TENSOR, NC_DELETE_SESSION_TENSOR, + NC_METADATA, NC_OTHER // Not a special kind of node }; @@ -248,8 +254,16 @@ class Node { Properties* props_; - // Name of device assigned to perform this computation. - string assigned_device_name_; + // Index within Graph::device_names_ of the name of device assigned + // to perform this computation. 
+ int assigned_device_name_index_; + + // A back-pointer to the Graph that owns this node. Currently, this exists + // solely to allow Node::[set_]assigned_device_name() to work. However, if all + // callers of Node::[set_]assigned_device_name() are modified to use the + // equivalent methods defined directly on Graph, then we can remove this + // field and reclaim that memory. + Graph* graph_; TF_DISALLOW_COPY_AND_ASSIGN(Node); }; @@ -478,6 +492,26 @@ class Graph { const OpRegistryInterface* op_registry() const { return &ops_; } const FunctionLibraryDefinition& flib_def() const { return ops_; } + void CheckDeviceNameIndex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, static_cast(device_names_.size())); + } + + int InternDeviceName(const string& device_name); + + const string& get_assigned_device_name(const Node& node) const { + return device_names_[node.assigned_device_name_index()]; + } + + void set_assigned_device_name_index(Node* node, int device_name_index) { + CheckDeviceNameIndex(device_name_index); + node->assigned_device_name_index_ = device_name_index; + } + + void set_assigned_device_name(Node* node, const string& device_name) { + node->assigned_device_name_index_ = InternDeviceName(device_name); + } + // TODO(josh11b): uint64 hash() const; private: @@ -518,6 +552,30 @@ class Graph { // For generating unique names. int name_counter_ = 0; + // In most graphs, the number of unique values used for the + // Node::assigned_device_name() property is quite small. If the graph is + // large, then this duplication of values can consume a significant amount of + // memory. Instead, we represent the same information using an interning + // table, which consists of a vector of unique strings (device_names_), as + // well a map (device_names_map_) from unique strings to indices within the + // unique string table. + // + // The InternDeviceName() method handles adding a new entry into the table, + // or locating the index of an existing entry. 
+ // + // The fact that Node::assigned_device_name() is implemented using an + // interning table is intentionally public. This allows algorithms that + // frequently access this field to do so efficiently, especially for the case + // where the assigned_device_name of one Node is copied directly from that + // of another Node. + + // A table of the unique assigned device names. Indices do NOT correspond + // to node IDs. Index 0 is always the empty string. + std::vector device_names_; + + // Maps unique device names to indices within device_names_[i]. + std::unordered_map device_names_map_; + TF_DISALLOW_COPY_AND_ASSIGN(Graph); }; @@ -550,6 +608,10 @@ inline bool IsIdentity(const Node* node) { return node->IsIdentity(); } // Returns true iff 'n' is a control flow node. inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); } +// Returns true if the node only depends on its input's metadata +// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops. +inline bool IsMetadata(const Node* n) { return n->IsMetadata(); } + inline bool IsHostMemoryPreserving(const Node* node) { return IsIdentity(node) || IsControlFlow(node); } @@ -666,6 +728,19 @@ inline gtl::iterator_range Graph::op_nodes() const { return gtl::make_range(begin, end); } +inline void Node::set_assigned_device_name_index(int index) { + graph_->CheckDeviceNameIndex(index); + assigned_device_name_index_ = index; +} + +inline void Node::set_assigned_device_name(const string& device_name) { + graph_->set_assigned_device_name(this, device_name); +} + +inline const string& Node::assigned_device_name() const { + return graph_->get_assigned_device_name(*this); +} + } // namespace tensorflow #endif // TENSORFLOW_GRAPH_GRAPH_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 19442d8c087..7f2cc45d6a5 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -612,6 +612,9 @@ Status 
GraphConstructor::Convert() { std::vector inputs; int processed = 0; + + std::vector input_already_exists; + // Process the NodeDefs in topological order. // (InitFromEdges() sets this up by filling in ready_ with nodes that have no // inputs, pending_counts_ with the number of inputs for each node and @@ -631,8 +634,8 @@ Status GraphConstructor::Convert() { // importing refers to a preexisting node in g_ (i.e. input[i] existed prior // to importing gdef_). Conversely, input_already_exists[i] is false iff // the input refers to a node in gdef_. - std::vector input_already_exists(original_node_def.input_size(), - false); + input_already_exists.clear(); + input_already_exists.resize(original_node_def.input_size(), false); if (opts_.importing) { // TODO(ashankar): The line below means an additional copy of the NodeDef, From d6fe47af57f27cdf7a2edc5c8cb8c38393099748 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 14:44:53 -0700 Subject: [PATCH 22/72] Use tensorflow::StringPiece in literal_util. Use template for RepeatedField assignment. 
PiperOrigin-RevId: 157765477 --- tensorflow/compiler/xla/literal_util.cc | 45 ++++++++++++++----------- tensorflow/compiler/xla/literal_util.h | 4 +-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 5162c2b0cc3..4648680dc53 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -308,7 +308,7 @@ Status Literal::Copy(const Literal& src_literal, auto literal = MakeUnique(); *literal->mutable_shape() = ShapeUtil::MakeShape(U8, {static_cast(value.size())}); - literal->set_u8s(value.ToString()); + literal->set_u8s(tensorflow::StringPiece(value.ToString())); return literal; } @@ -1130,10 +1130,16 @@ void Literal::Resize(int64 num_elements, half value) { } } -template -static void CopyToRepeatedField(proto2::RepeatedField* dest, +template +static void CopyToRepeatedField(RepeatedFieldT* dest, const std::vector& src) { - *dest = proto2::RepeatedField(src.begin(), src.end()); + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +template +static void CopyToRepeatedBoolField(RepeatedFieldT* dest, + const BoolVector& src) { + *dest = RepeatedFieldT(src.begin(), src.end()); } LiteralProto Literal::ToProto() const { @@ -1143,24 +1149,23 @@ LiteralProto Literal::ToProto() const { switch (shape().element_type()) { case PRED: if (preds().begin()) { - *proto.mutable_preds() = - proto2::RepeatedField(preds().begin(), preds().end()); + CopyToRepeatedBoolField(proto.mutable_preds(), preds()); } break; case U8: *proto.mutable_u8s() = u8s_string(); break; case S32: - CopyToRepeatedField(proto.mutable_s32s(), s32s()); + CopyToRepeatedField(proto.mutable_s32s(), s32s()); break; case S64: - CopyToRepeatedField(proto.mutable_s64s(), s64s()); + CopyToRepeatedField(proto.mutable_s64s(), s64s()); break; case U32: - CopyToRepeatedField(proto.mutable_u32s(), u32s()); + CopyToRepeatedField(proto.mutable_u32s(), u32s()); break; case U64: - 
CopyToRepeatedField(proto.mutable_u64s(), u64s()); + CopyToRepeatedField(proto.mutable_u64s(), u64s()); break; case F16: *proto.mutable_f16s() = @@ -1168,10 +1173,10 @@ LiteralProto Literal::ToProto() const { f16s_.size() / sizeof(half)); break; case F32: - CopyToRepeatedField(proto.mutable_f32s(), f32s()); + CopyToRepeatedField(proto.mutable_f32s(), f32s()); break; case F64: - CopyToRepeatedField(proto.mutable_f64s(), f64s()); + CopyToRepeatedField(proto.mutable_f64s(), f64s()); break; case TUPLE: for (const auto& tuple : tuple_literals()) { @@ -1185,9 +1190,9 @@ LiteralProto Literal::ToProto() const { return proto; } -template +template static void CopyFromRepeatedField(std::vector* dest, - const proto2::RepeatedField& src) { + const RepeatedFieldT& src) { *dest = std::vector(src.begin(), src.end()); } @@ -1206,16 +1211,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { set_u8s(literal_proto.u8s()); break; case S32: - CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); break; case S64: - CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); break; case U32: - CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); break; case U64: - CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); break; case F16: { const string& s(literal_proto.f16s()); @@ -1225,10 +1230,10 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { break; } case F32: - CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); break; case F64: - CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); + CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); break; case TUPLE: for (const auto& proto : 
literal_proto.tuple_literals()) { diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 31f08150ef8..b2b63cd9e25 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -228,13 +228,13 @@ class Literal { int u8s_size() const { return u8s().size(); } const std::vector& u8s() const { return u8s_; } void set_u8s(const std::vector& value) { u8s_ = value; } - void set_u8s(absl::string_view value) { + void set_u8s(tensorflow::StringPiece value) { u8s_ = std::vector(value.size()); u8s_.clear(); append_u8s(value); } - void append_u8s(absl::string_view value) { + void append_u8s(tensorflow::StringPiece value) { u8s_.insert(u8s_.end(), value.begin(), value.end()); } From cc346e69063824eb2e9a8a66dec6e8878da69f9e Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 1 Jun 2017 15:17:54 -0700 Subject: [PATCH 23/72] Strip the :x suffix when generating control inputs from input names PiperOrigin-RevId: 157770257 --- tensorflow/core/grappler/optimizers/constant_folding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 291d7f35bc4..c9169d63f4b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -176,7 +176,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { // Turn the inputs into control dependencies. CHECK_EQ(1, node.input_size()); - node.set_input(0, strings::StrCat("^", node.input(0))); + node.set_input(0, strings::StrCat("^", NodeName(node.input(0)))); } } } From 2e44be35dc037b9c191569fd43caf1a7fcfceaec Mon Sep 17 00:00:00 2001 From: Vinu Rajashekhar Date: Thu, 1 Jun 2017 15:18:42 -0700 Subject: [PATCH 24/72] Adds a protected DeleteResourceMgr(...) method in Device. 
PiperOrigin-RevId: 157770378 --- tensorflow/core/common_runtime/device.cc | 6 +++++- tensorflow/core/common_runtime/device.h | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index aa8a2d989bf..8fc64fff69a 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes) rmgr_ = new ResourceMgr(parsed_name_.job); } -Device::~Device() { delete rmgr_; } +Device::~Device() { + if (rmgr_ != nullptr) { + DeleteResourceMgr(); + } +} // static DeviceAttributes Device::BuildDeviceAttributes( diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 11024805cb2..7312226f388 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -151,6 +151,12 @@ class Device : public DeviceBase { return BuildDeviceAttributes(name, device, memory_limit, locality, ""); } + protected: + void DeleteResourceMgr() { + delete rmgr_; + rmgr_ = nullptr; + } + private: const DeviceAttributes device_attributes_; DeviceNameUtils::ParsedName parsed_name_; From 8032e1f75dd6a56b39a07890f60acc6b275d0683 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Thu, 1 Jun 2017 15:24:16 -0700 Subject: [PATCH 25/72] Make function instantiation use std::vector instead of GraphDef It's about to turn into std::vector; this change gets us partway there. 
RELNOTES: n/a PiperOrigin-RevId: 157771141 --- tensorflow/core/common_runtime/function.cc | 2 +- .../core/common_runtime/function_test.cc | 4 +- tensorflow/core/framework/function.cc | 68 ++++++----- tensorflow/core/framework/function.h | 3 +- tensorflow/core/framework/function_test.cc | 18 +-- tensorflow/core/framework/graph_def_util.cc | 8 ++ tensorflow/core/framework/graph_def_util.h | 1 + tensorflow/core/graph/graph_constructor.cc | 106 ++++++++++++------ tensorflow/core/graph/graph_constructor.h | 6 + tensorflow/core/util/equal_graph_def.cc | 32 ++++-- tensorflow/core/util/equal_graph_def.h | 2 + 11 files changed, 162 insertions(+), 88 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 93bd3a6adbe..6e0f312bc04 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper( GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - Status s = ConvertGraphDefToGraph(opts, result.gdef, graph); + Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph); if (!s.ok()) { delete graph; } else { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index e27fc3898dc..dec6ca996aa 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test { GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g)); const int version = g->versions().producer(); LocalExecutorParams params; @@ -949,7 +949,7 @@ GraphDef Optimize(const std::function& pass, GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - 
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get())); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get())); pass(g.get()); std::unique_ptr g1(new Graph(OpRegistry::Global())); CopyGraph(*g, g1.get()); diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 186095201d1..9026075a2f0 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -140,7 +140,7 @@ class FunctionInstantiationHelper { FunctionInstantiationHelper(GetFunctionSignature get_function, InstantiationResult* result) : get_function_(std ::move(get_function)), result_(*result) { - result_.gdef.Clear(); + result_.nodes.clear(); } // Builds index for nodes that can be used as node's input arguments. @@ -151,15 +151,14 @@ class FunctionInstantiationHelper { TF_RETURN_IF_ERROR( ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); CHECK_GE(dtypes.size(), size_t{1}); - GraphDef* gdef = &result_.gdef; - int arg_index = gdef->node_size(); + int arg_index = result_.nodes.size(); TF_RETURN_IF_ERROR( AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); - // Creates dtypes.size() nodes in the gdef. + // Creates dtypes.size() nodes in the graph. for (size_t i = 0; i < dtypes.size(); ++i) { TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i), {true, arg_index, 0, false, {dtypes[i]}})); - DCHECK_EQ(arg_index, gdef->node_size()); + DCHECK_EQ(arg_index, result_.nodes.size()); string name = arg_def.name(); if (dtypes.size() > 1) { strings::StrAppend(&name, "_", i); @@ -332,13 +331,13 @@ class FunctionInstantiationHelper { // Adds the actual node inputs to the result graph by converting indexes to // the node names. 
void AddNodeInputs() { - for (int i = 0; i < result_.gdef.node_size(); i++) { + for (int i = 0; i < result_.nodes.size(); i++) { NodeInfo& node_info = nodes_[i]; for (const auto& p : node_info.data_inputs) { - result_.gdef.mutable_node(i)->add_input(Name(p.first, p.second)); + result_.nodes[i].add_input(Name(p.first, p.second)); } for (int index : node_info.control_inputs) { - result_.gdef.mutable_node(i)->add_input(Dep(index)); + result_.nodes[i].add_input(Dep(index)); } } } @@ -348,11 +347,10 @@ class FunctionInstantiationHelper { // node's input arguments. // // If is_func_arg is true, the name is a function's argument. In - // this case, the produced graph def has gdef.node[nid ... nid + - // dtype.size()). + // this case, the produced graph def has node[nid:nid + dtype.size()]. // // Otherwise, the name is a function body's node return value. In - // this case, the produced graph def has one node gdef.node[nid] and + // this case, the produced graph def has one node node[nid] and // the node's output index [idx ... idx + num) corresponds to the // named outputs. // @@ -398,10 +396,11 @@ class FunctionInstantiationHelper { } NodeDef* AddNode(const string& name) { - NodeDef* gnode = result_.gdef.add_node(); + result_.nodes.emplace_back(); + NodeDef* gnode = &result_.nodes.back(); gnode->set_name(name); nodes_.push_back({name, {}, {}}); - CHECK_EQ(result_.gdef.node_size(), nodes_.size()); + CHECK_EQ(result_.nodes.size(), nodes_.size()); return gnode; } @@ -429,7 +428,7 @@ class FunctionInstantiationHelper { // Control inputs (dependencies). std::vector control_inputs; }; - // nodes_[i] is the information about result_.gdef.node(i). + // nodes_[i] is the information about result_.nodes[i]. 
std::vector nodes_; }; @@ -545,17 +544,17 @@ string Print(const FunctionDef& fdef) { return out; } -string Print(const GraphDef& gdef) { +string Print(gtl::ArraySlice nodes) { std::vector arg; std::vector ret; std::vector body; - for (const NodeDef& n : gdef.node()) { - if (n.op() == "_Arg") { - arg.push_back(&n); - } else if (n.op() == "_Retval") { - ret.push_back(&n); + for (const NodeDef* n : nodes) { + if (n->op() == "_Arg") { + arg.push_back(n); + } else if (n->op() == "_Retval") { + ret.push_back(n); } else { - body.push_back(&n); + body.push_back(n); } } auto comp = [](const NodeDef* x, const NodeDef* y) { @@ -570,12 +569,11 @@ string Print(const GraphDef& gdef) { string out; strings::StrAppend(&out, "\n("); auto get_type = [](const NodeDef& n) { - for (auto a : n.attr()) { - if (a.first == "T") { - return DataTypeString(a.second.type()); - } + DataType dt; + if (!GetNodeAttr(n, "T", &dt).ok()) { + dt = DT_INVALID; } - return DataTypeString(DT_INVALID); + return DataTypeString(dt); }; for (size_t i = 0; i < arg.size(); ++i) { const NodeDef* n = arg[i]; @@ -663,13 +661,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), - result->gdef.node_size() + i); + result->nodes.size() + i); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; } } - // Emits one gdef.node for each fdef.node_def. + // Emits one node for each fdef.node_def. 
for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); if (!s.ok()) { @@ -697,7 +695,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, string DebugString(const FunctionDef& func_def) { return Print(func_def); } string DebugString(const GraphDef& instantiated_func_def) { - return Print(instantiated_func_def); + std::vector ptrs; + for (const NodeDef& n : instantiated_func_def.node()) { + ptrs.push_back(&n); + } + return Print(ptrs); +} + +string DebugString(gtl::ArraySlice instantiated_func_nodes) { + std::vector ptrs; + for (const NodeDef& n : instantiated_func_nodes) { + ptrs.push_back(&n); + } + return Print(ptrs); } string DebugStringWhole(const GraphDef& gdef) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 188c3855c6e..6c2da84790c 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -200,7 +200,7 @@ typedef std::function struct InstantiationResult { DataTypeVector arg_types; DataTypeVector ret_types; - GraphDef gdef; + std::vector nodes; }; Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, @@ -216,6 +216,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, // etc.) string DebugString(const FunctionDef& func_def); string DebugString(const GraphDef& instantiated_func_def); +string DebugString(gtl::ArraySlice instantiated_func_nodes); // Returns a debug string for a top level graph (the main program and // its supporting functions defined in its library). 
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index c83ecf4e5e8..ba4b15aefd7 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -108,7 +108,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, ControlDep) { @@ -154,7 +154,7 @@ ControlDep(x:int32) -> (y:int32) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } REGISTER_OP("HasDefaultType") @@ -198,7 +198,7 @@ BackCompat() -> (y:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector()); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, NTimesT) { @@ -234,7 +234,7 @@ NTimesT(x:float, y:float) -> (z:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } // NOTE: This is the simplest Map op. It takes a f:T->U. 
@@ -299,7 +299,7 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, ControlDeps) { @@ -344,7 +344,7 @@ ControlDeps(x:float) -> () { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, XTimesTwo) { @@ -425,7 +425,7 @@ Test(i:float) -> (o:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } REGISTER_OP("Cond") @@ -493,7 +493,7 @@ MySelect(x:float) -> (z:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } static void HasError(const Status& s, const string& substr) { @@ -1028,7 +1028,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { *proto.add_gradient() = grad; FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); TF_EXPECT_OK(lib_def.AddLibrary(lib_def3)); -}; +} TEST(FunctionLibraryDefinitionTest, ToProto) { FunctionDefLibrary proto1; diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index b76ab40b683..d731003366a 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -39,6 +39,14 @@ string SummarizeGraphDef(const GraphDef& graph_def) { return ret; } +string SummarizeGraphDef(gtl::ArraySlice node_defs) { + string ret; + for (const NodeDef& node : node_defs) { + strings::StrAppend(&ret, SummarizeNodeDef(node), 
";\n"); + } + return ret; +} + Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { for (const NodeDef& node : graph_def.node()) { TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index 56355eaf367..27e3de581ad 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -27,6 +27,7 @@ namespace tensorflow { // Produce a human-readable version of a GraphDef that is more concise // than a text-format proto. string SummarizeGraphDef(const GraphDef& graph_def); +string SummarizeGraphDef(gtl::ArraySlice node_defs); // Validates the syntax of a GraphDef provided externally. // diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 7f2cc45d6a5..28ebf7e8c3d 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -91,24 +91,36 @@ class GraphConstructor { bool importing; }; - static Status Construct(const Options& opts, const GraphDef* gdef, Graph* g, + typedef gtl::ArraySlice NodeDefSlice; + + // versions and library may be nullptr + static Status Construct(const Options& opts, NodeDefSlice node_defs, + const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors) { - TF_RETURN_IF_ERROR(CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION, - TF_GRAPH_DEF_VERSION_MIN_PRODUCER, - "GraphDef", "graph")); - GraphConstructor c(opts, gdef, g, refiner, return_tensors); + if (versions) { + TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, + TF_GRAPH_DEF_VERSION_MIN_PRODUCER, + "GraphDef", "graph")); + } + GraphConstructor c(opts, node_defs, versions, library, g, refiner, + return_tensors); const Status s = c.TryImport(); if (!s.ok()) c.Undo(); return s; } private: - GraphConstructor(const Options& opts, const GraphDef* gdef, Graph* 
g, + GraphConstructor(const Options& opts, NodeDefSlice node_defs, + const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors) : opts_(opts), - gdef_(gdef), + node_defs_(node_defs), + versions_(versions), + library_(library), g_(g), original_versions_(g->versions()), refiner_(refiner), @@ -159,7 +171,9 @@ class GraphConstructor { // From constructor const Options opts_; - const GraphDef* gdef_; + const NodeDefSlice node_defs_; + const VersionDef* versions_; + const FunctionDefLibrary* library_; Graph* g_; const VersionDef original_versions_; @@ -168,7 +182,7 @@ class GraphConstructor { // May be null. Not owned. std::vector>* return_tensors_; - // Mapping from node name to the index within gdef_ + // Mapping from node name to the index within node_defs_ struct NodeInfo { explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} // std::unordered_map<> requires that we have a default constructor. @@ -183,18 +197,18 @@ class GraphConstructor { // Mapping from node name to the existing node in g_ std::unordered_map existing_nodes_; - // Index of NodeDefs in gdef_ with all inputs already converted. + // Index of NodeDefs in node_defs_ with all inputs already converted. std::vector ready_; - // Mapping between index within gdef_ and the number of inputs that + // Mapping between index within node_defs_ and the number of inputs that // still need to be converted. std::vector pending_count_; - // Mapping between index within gdef_ and the index within gdef_ of + // Mapping between index within node_defs_ and the index within node_defs_ of // all nodes it outputs to. std::vector> outputs_; - // Used in the conversion from gdef_ to g_ to represent the ith input + // Used in the conversion from node_defs_ to g_ to represent the ith input // of a node. 
struct InputInfo { explicit InputInfo(const string& node_name, Node* n, int i) @@ -205,7 +219,7 @@ class GraphConstructor { int index; }; - // Used in the conversion from gdef_ to g_ to represent an edge from + // Used in the conversion from node_defs_ to g_ to represent an edge from // the node named 'name' to node 'n'. struct EdgeInfo { explicit EdgeInfo(const string& name, int i1, Node* n, int i2) @@ -254,8 +268,8 @@ Status GraphConstructor::EnsureNoNameCollisions() { } } if (opts_.prefix.empty() && opts_.importing) { - for (int n = 0; n < gdef_->node_size(); ++n) { - const string& name = gdef_->node(n).name(); + for (const NodeDef* n : node_defs_) { + const string& name = n->name(); if (existing_nodes_.find(name) != existing_nodes_.end()) { return errors::InvalidArgument("Node '", name, "' already exists in the Graph"); @@ -312,8 +326,8 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() { Status GraphConstructor::BuildNodeIndex() { // Validate the node names and add them to gdef_nodes_. - for (int n = 0; n < gdef_->node_size(); ++n) { - const NodeDef& node_def(gdef_->node(n)); + for (int n = 0; n < node_defs_.size(); ++n) { + const NodeDef& node_def = *node_defs_[n]; if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { return errors::InvalidArgument( "Node '", node_def.name(), @@ -351,13 +365,13 @@ Status GraphConstructor::BuildNodeIndex() { } Status GraphConstructor::InitFromEdges() { - const int num_nodes = gdef_->node_size(); + const int num_nodes = node_defs_.size(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { - const NodeDef& node_def(gdef_->node(n)); + const NodeDef& node_def = *node_defs_[n]; if (IsMerge(node_def)) { // for merge only wait for one non-control input. 
int32 num_control_edges = 0; @@ -489,7 +503,9 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); AddDefaultsToNodeDef(*op_def, node_def); TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); - TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, gdef_->versions().producer())); + if (versions_) { + TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer())); + } return Status::OK(); } @@ -608,7 +624,9 @@ void GraphConstructor::AddPrefixToNodeDef( Status GraphConstructor::Convert() { // Import functions before adding nodes, since imported nodes may refer to // functions - TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library())); + if (library_) { + TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_)); + } std::vector inputs; int processed = 0; @@ -626,14 +644,14 @@ Status GraphConstructor::Convert() { inputs.clear(); bool has_data_back_edge = false; - const NodeDef& original_node_def = gdef_->node(o); + const NodeDef& original_node_def = *node_defs_[o]; NodeDef imported_node_def; const NodeDef* node_def; // input_already_exists[i] is true iff the i-th input of the node we're // importing refers to a preexisting node in g_ (i.e. input[i] existed prior - // to importing gdef_). Conversely, input_already_exists[i] is false iff - // the input refers to a node in gdef_. + // to importing node_defs_). Conversely, input_already_exists[i] is false + // iff the input refers to a node in node_defs_. 
input_already_exists.clear(); input_already_exists.resize(original_node_def.input_size(), false); @@ -731,8 +749,8 @@ Status GraphConstructor::Convert() { } } - if (processed < gdef_->node_size()) { - return errors::InvalidArgument(gdef_->node_size() - processed, + if (processed < node_defs_.size()) { + return errors::InvalidArgument(node_defs_.size() - processed, " nodes in a cycle"); } return Status::OK(); @@ -756,20 +774,21 @@ Status GraphConstructor::AddBackEdges() { } Status GraphConstructor::UpdateVersionDef() { + if (versions_ == nullptr) return Status::OK(); + if (!opts_.importing) { - g_->set_versions(gdef_->versions()); + g_->set_versions(*versions_); return Status::OK(); } VersionDef versions = g_->versions(); - versions.set_producer( - std::min(versions.producer(), gdef_->versions().producer())); + versions.set_producer(std::min(versions.producer(), versions_->producer())); versions.set_min_consumer( - std::max(versions.min_consumer(), gdef_->versions().min_consumer())); - if (gdef_->versions().bad_consumers_size() > 0) { + std::max(versions.min_consumer(), versions_->min_consumer())); + if (versions_->bad_consumers_size() > 0) { std::set bad(versions.bad_consumers().begin(), versions.bad_consumers().end()); - bad.insert(gdef_->versions().bad_consumers().begin(), - gdef_->versions().bad_consumers().end()); + bad.insert(versions_->bad_consumers().begin(), + versions_->bad_consumers().end()); versions.clear_bad_consumers(); for (int v : bad) { versions.add_bad_consumers(v); @@ -837,7 +856,20 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g) { ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); - return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr); + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, &refiner, nullptr); +} + +Status 
ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + gtl::ArraySlice nodes, Graph* g) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); + // TODO(irving): Copy will go away once NodeInfo exists + std::vector node_defs; + for (const auto& n : nodes) { + node_defs.push_back(&n); + } + return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, + &refiner, nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, @@ -886,7 +918,9 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, refiner->set_graph_def_version( std::min(refiner->graph_def_version(), gdef.versions().producer())); - return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors); + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, refiner, + return_tensors); } void CopyGraph(const Graph& src, Graph* dest) { diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 9b80f211fc6..bc4f23ed2d1 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -46,6 +46,12 @@ struct GraphConstructorOptions { extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g); +// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function +// instantiation. +// TODO(irving): This will turn into std::vector soon. +extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + gtl::ArraySlice nodes, Graph* g); + // Add the graph in GraphDef gdef into an existing Graph *g. // // On error, returns non-OK and leaves *g unmodified. diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index 2db026da56c..8ad91e5adb2 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -24,16 +24,10 @@ limitations under the License. 
namespace tensorflow { -bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, - string* diff, const EqualGraphDefOptions& options) { - // Intentionally do not check that versions match so that this routine can - // be used for less brittle golden file tests. - return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options); -} - -bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, - const protobuf::RepeatedPtrField& expected, - string* diff, const EqualGraphDefOptions& options) { +template +static bool EqualNodeDefsHelper( + const NodeDefs& actual, const protobuf::RepeatedPtrField& expected, + string* diff, const EqualGraphDefOptions& options) { std::unordered_map actual_index; for (const NodeDef& node : actual) { actual_index[node.name()] = &node; @@ -68,6 +62,24 @@ bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, return true; } +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff, const EqualGraphDefOptions& options) { + // Intentionally do not check that versions match so that this routine can + // be used for less brittle golden file tests. 
+ return EqualNodeDefsHelper(actual.node(), expected.node(), diff, options); +} + +bool EqualGraphDef(gtl::ArraySlice actual, const GraphDef& expected, + string* diff, const EqualGraphDefOptions& options) { + return EqualNodeDefsHelper(actual, expected.node(), diff, options); +} + +bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, + const protobuf::RepeatedPtrField& expected, + string* diff, const EqualGraphDefOptions& options) { + return EqualNodeDefsHelper(actual, expected, diff, options); +} + namespace { string JoinStringField(const protobuf::RepeatedPtrField& f) { diff --git a/tensorflow/core/util/equal_graph_def.h b/tensorflow/core/util/equal_graph_def.h index 1ce6181c2e7..29d0385493f 100644 --- a/tensorflow/core/util/equal_graph_def.h +++ b/tensorflow/core/util/equal_graph_def.h @@ -36,6 +36,8 @@ struct EqualGraphDefOptions { // nodes must be consistent. bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, string* diff, const EqualGraphDefOptions& options = {}); +bool EqualGraphDef(gtl::ArraySlice actual, const GraphDef& expected, + string* diff, const EqualGraphDefOptions& options = {}); // Determines if actual and expected are equal, ignoring: ordering of // attrs, internal attributes (if set in `options`), and control inputs. From 6b16c33b324a6eb106e7a27ca7901dfa33121b27 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 15:51:55 -0700 Subject: [PATCH 26/72] Make audio-related logic use the audio plugin. Previously, fetching audio and related data from TensorBoard used handlers within application.py. We now remove those handlers in favor of routes offered by the audio plugin. ML Dash is updated as well. 
PiperOrigin-RevId: 157774953 --- tensorflow/tensorboard/backend/application.py | 90 +------------------ .../tensorboard/backend/application_test.py | 25 +----- .../components/tf_backend/backend.ts | 31 ++++++- .../components/tf_backend/router.ts | 11 --- .../tf_backend/test/backendTests.ts | 5 +- tensorflow/tensorboard/http_api.md | 23 +++-- tensorflow/tensorboard/tensorboard.py | 2 + 7 files changed, 51 insertions(+), 136 deletions(-) diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py index cf1c376be08..ef2d0c6d693 100644 --- a/tensorflow/tensorboard/backend/application.py +++ b/tensorflow/tensorboard/backend/application.py @@ -30,7 +30,6 @@ import time import six from six import StringIO -from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin from six.moves.urllib import parse as urlparse import tensorflow as tf @@ -59,6 +58,7 @@ DEFAULT_SIZE_GUIDANCE = { # Once everything has been migrated, we should be able to delete # /data/runs entirely. 
_MIGRATED_DATA_KEYS = frozenset(( + 'audio', 'histograms', 'images', 'scalars', @@ -69,9 +69,7 @@ LOGDIR_ROUTE = '/logdir' RUNS_ROUTE = '/runs' PLUGIN_PREFIX = '/plugin' PLUGINS_LISTING_ROUTE = '/plugins_listing' -AUDIO_ROUTE = '/' + event_accumulator.AUDIO COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS -INDIVIDUAL_AUDIO_ROUTE = '/individualAudio' GRAPH_ROUTE = '/' + event_accumulator.GRAPH RUN_METADATA_ROUTE = '/' + event_accumulator.RUN_METADATA TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] @@ -161,14 +159,10 @@ class TensorBoardWSGIApp(object): reload_multiplexer(self._multiplexer, path_to_run) self.data_applications = { - DATA_PREFIX + AUDIO_ROUTE: - self._serve_audio, DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: self._serve_compressed_histograms, DATA_PREFIX + GRAPH_ROUTE: self._serve_graph, - DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE: - self._serve_individual_audio, DATA_PREFIX + LOGDIR_ROUTE: self._serve_logdir, # TODO(chizeng): Delete this RPC once we have skylark rules that obviate @@ -209,30 +203,6 @@ class TensorBoardWSGIApp(object): path = DATA_PREFIX + PLUGIN_PREFIX + '/' + plugin.plugin_name + route self.data_applications[path] = app - # We use underscore_names for consistency with inherited methods. - - def _audio_response_for_run(self, run_audio, run, tag): - """Builds a JSON-serializable object with information about run_audio. - - Args: - run_audio: A list of event_accumulator.AudioValueEvent objects. - run: The name of the run. - tag: The name of the tag the audio files all belong to. - - Returns: - A list of dictionaries containing the wall time, step, URL, and - content_type for each audio clip. 
- """ - response = [] - for index, run_audio_clip in enumerate(run_audio): - response.append({ - 'wall_time': run_audio_clip.wall_time, - 'step': run_audio_clip.step, - 'content_type': run_audio_clip.content_type, - 'query': self._query_for_individual_audio(run, tag, index) - }) - return response - def _path_is_safe(self, path): """Check path is safe (stays within current directory). @@ -337,61 +307,6 @@ class TensorBoardWSGIApp(object): return http_util.Respond( request, compressed_histograms, 'application/json') - @wrappers.Request.application - def _serve_audio(self, request): - """Given a tag and list of runs, serve a list of audio. - - Note that the audio clips themselves are not sent; instead, we respond with - URLs to the audio. The frontend should treat these URLs as opaque and should - not try to parse information about them or generate them itself, as the - format may change. - - Args: - request: A werkzeug.wrappers.Request object. - - Returns: - A werkzeug.Response application. - """ - tag = request.args.get('tag') - run = request.args.get('run') - - audio_list = self._multiplexer.Audio(run, tag) - response = self._audio_response_for_run(audio_list, run, tag) - return http_util.Respond(request, response, 'application/json') - - @wrappers.Request.application - def _serve_individual_audio(self, request): - """Serves an individual audio clip.""" - tag = request.args.get('tag') - run = request.args.get('run') - index = int(request.args.get('index')) - audio = self._multiplexer.Audio(run, tag)[index] - return http_util.Respond( - request, audio.encoded_audio_string, audio.content_type) - - def _query_for_individual_audio(self, run, tag, index): - """Builds a URL for accessing the specified audio. - - This should be kept in sync with _serve_individual_audio. Note that the URL - is *not* guaranteed to always return the same audio, since audio may be - unloaded from the reservoir as new audio comes in. - - Args: - run: The name of the run. - tag: The tag. 
- index: The index of the audio. Negative values are OK. - - Returns: - A string representation of a URL that will load the index-th - sampled audio in the given run with the given tag. - """ - query_string = urllib.parse.urlencode({ - 'run': run, - 'tag': tag, - 'index': index - }) - return query_string - @wrappers.Request.application def _serve_plugins_listing(self, request): """Serves an object mapping plugin name to whether it is enabled. @@ -418,8 +333,7 @@ class TensorBoardWSGIApp(object): Returns: A werkzeug Response with the following content: - {runName: {audio: [tag4, tag5, tag6], - firstEventTimestamp: 123456.789}} + {runName: {firstEventTimestamp: 123456.789}} """ runs = self._multiplexer.Runs() for run_name, run_data in runs.items(): diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py index 08f3485047a..9729e6395a0 100644 --- a/tensorflow/tensorboard/backend/application_test.py +++ b/tensorflow/tensorboard/backend/application_test.py @@ -168,7 +168,6 @@ class TensorboardServerTest(tf.test.TestCase): { 'run1': { 'compressedHistograms': ['histogram'], - 'audio': ['audio'], # if only_use_meta_graph, the graph is from the metagraph 'graph': True, 'meta_graph': self._only_use_meta_graph, @@ -193,7 +192,7 @@ class TensorboardServerTest(tf.test.TestCase): def testDataPaths_disableAllCaching(self): """Test the format of the /data/runs endpoint.""" - for path in ('/data/runs', '/data/logdir', '/data/audio?run=run1&tag=audio', + for path in ('/data/runs', '/data/logdir', '/data/run_metadata?run=run1&tag=test%20run'): connection = http_client.HTTPConnection('localhost', self._server.server_address[1]) @@ -204,20 +203,6 @@ class TensorboardServerTest(tf.test.TestCase): response.read() connection.close() - def testAudio(self): - """Test listing audio and retrieving an individual audio clip.""" - audio_json = self._getJson('/data/audio?tag=audio&run=run1') - audio_query = audio_json[0]['query'] - # We 
don't care about the format of the audio query. - del audio_json[0]['query'] - self.assertEqual(audio_json, [{ - 'wall_time': 0, - 'step': 0, - 'content_type': 'audio/wav' - }]) - response = self._get('/data/individualAudio?%s' % audio_query) - self.assertEqual(response.status, 200) - def testGraph(self): """Test retrieving the graph definition.""" response = self._get('/data/graph?run=run1&limit_attr_size=1024' @@ -324,20 +309,12 @@ class TensorboardServerTest(tf.test.TestCase): device_stats = run_metadata.step_stats.dev_stats.add() device_stats.device = 'test device' writer.add_run_metadata(run_metadata, 'test run') - - audio_value = tf.Summary.Audio( - sample_rate=44100, - length_frames=22050, - num_channels=2, - encoded_audio_string=b'', - content_type='audio/wav') writer.add_event( tf.Event( wall_time=0, step=0, summary=tf.Summary(value=[ tf.Summary.Value(tag='histogram', histo=histogram_value), - tf.Summary.Value(tag='audio', audio=audio_value) ]))) writer.flush() diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index 351aeef662a..1c93b4f429a 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -170,8 +170,9 @@ export class Backend { /** * Return a promise showing the Run-to-Tag mapping for audio data. */ - public audioRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'audio')); + public audioTags(): Promise { + return this.requestManager.request( + this.router.pluginRoute('audio', '/tags')); } /** @@ -296,7 +297,7 @@ export class Backend { * Return a promise containing AudioDatums for given run and tag. 
*/ public audio(tag: string, run: string): Promise> { - const url = this.router.audio(tag, run); + const url = (this.router.pluginRunTagRoute('audio', '/audio')(tag, run)); let p: Promise; p = this.requestManager.request(url); return p.then(map(this.createAudio.bind(this))); @@ -356,11 +357,33 @@ export class Backend { } private createAudio(x: AudioMetadata): Audio&Datum { + const pluginRoute = this.router.pluginRoute('audio', '/individualAudio'); + + let query = x.query; + if (pluginRoute.indexOf('?') > -1) { + // The route already has GET parameters. Append our parameters to them. + query = '&' + query; + } else { + // The route lacks GET parameters. We append them. + query = '?' + query; + } + + if (this.router.isDemoMode()) { + query = demoify(query); + } + + let individualAudioUrl = pluginRoute + query; + // Include wall_time just to disambiguate the URL and force the browser + // to reload the audio when the URL changes. The backend doesn't care + // about the value. + individualAudioUrl += + this.router.isDemoMode() ? 
'.wav' : '&ts=' + x.wall_time; + return { content_type: x.content_type, wall_time: timeToDate(x.wall_time), step: x.step, - url: this.router.individualAudio(x.query), + url: individualAudioUrl, }; } } diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts index 615ecf3d2f2..b31f9b366ea 100644 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ b/tensorflow/tensorboard/components/tf_backend/router.ts @@ -22,8 +22,6 @@ export interface Router { runs: () => string; isDemoMode: () => boolean; compressedHistograms: RunTagUrlFn; - audio: RunTagUrlFn; - individualAudio: (query: string) => string; graph: (run: string, limit_attr_size?: number, large_attrs_key?: string) => string; @@ -57,13 +55,6 @@ export function router(dataDir = 'data', demoMode = false): Router { return url; }; } - function individualAudioUrl(query: string) { - var url = dataDir + '/' + clean('individualAudio?' + query); - if (demoMode) { - url += '.wav'; - } - return url; - } function graphUrl( run: string, limit_attr_size?: number, large_attrs_key?: string) { let query_params = [['run', clean(run)]]; @@ -96,10 +87,8 @@ export function router(dataDir = 'data', demoMode = false): Router { logdir: () => dataDir + '/logdir', runs: () => dataDir + '/runs' + (demoMode ? '.json' : ''), isDemoMode: () => demoMode, - individualAudio: individualAudioUrl, graph: graphUrl, compressedHistograms: standardRoute('compressedHistograms'), - audio: standardRoute('audio'), runMetadata: standardRoute('run_metadata', '.pbtxt'), healthPills: () => dataDir + '/plugin/debugger/health_pills', textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? 
'.json' : ''), diff --git a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts index 530091c28e2..0ef58157aef 100644 --- a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts +++ b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts @@ -105,9 +105,6 @@ describe('backend tests', () => { const audio = audioClips[0]; assertIsDatum(audio); chai.assert.equal(audio.content_type, 'audio/wav'); - const nonDemoQuery = 'index=0&tag=audio1&run=run1'; - const expectedUrl = demoRouter.individualAudio(nonDemoQuery); - chai.assert.equal(audio.url, expectedUrl); done(); }); }); @@ -138,7 +135,7 @@ describe('backend tests', () => { chai.assert.deepEqual(x, image); next(); }); - backend.audioRuns().then((x) => { + backend.audioTags().then((x) => { chai.assert.deepEqual(x, audio); next(); }); diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md index 0cf788601a7..541394cbe00 100644 --- a/tensorflow/tensorboard/http_api.md +++ b/tensorflow/tensorboard/http_api.md @@ -56,14 +56,12 @@ all of the data available from the TensorBoard server. Here is an example: { "train_run": { "compressedHistograms": ["foo_histogram", "bar_histogram"], - "audio": ["input_audio"], "graph": true, "firstEventTimestamp": 123456.789 "run_metadata": ["forward prop", "inference"] }, "eval": { "compressedHistograms": ["foo_histogram", "bar_histogram"], - "audio": ["input_audio"], "graph": false, "run_metadata": [] } @@ -80,6 +78,7 @@ will have the same tag type across different runs. Each of the following tag types `` has been migrated to `/data/plugin//tags`, and will not appear in the output from this route: + - `audio` - `images` - `scalars` - `histograms` @@ -238,13 +237,13 @@ tags present in the corresponding run. Here is an example: Note that runs without any image tags are included as keys with value the empty array. 
-## `/audio?run=foo&tag=bar` +## `/data/plugin/audio/audio?run=foo&tag=bar` Gets a sample of AudioMetadatas for the given run and tag. Returns an array of objects containing information about available audio, crucially including the query parameter that may be used to retrieve that audio. -(See /individualAudio for details.) +(See /data/plugin/audio/individualAudio for details.) For example: @@ -256,7 +255,7 @@ For example: # param for /individualAudio } -## `/individualAudio?{{query}}` +## `/data/plugin/audio/individualAudio?{{query}}` Retrieves an individual audio clip. The audio query should not be generated by the frontend, but instead acquired from calling the /audio route (the audio @@ -270,6 +269,20 @@ replaced with other clips. (See Notes for details on the reservoir sampling.) An example call to this route would look like this: /individualAudio?index=0&tagname=input%2Faudio%2F2&run=train +## `/data/plugin/audio/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all audio +tags present in the corresponding run. Here is an example: + + { + "train": ["foo_audio", "bar_audio"], + "eval": ["foo_audio", "bar_audio"], + } + +Note that runs without any audio tags are included as keys with value the empty +array. + ## `/data/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` Returns the graph definition for the given run in gzipped pbtxt format. 
The diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py index a9d07bd10dc..bce5dd259dd 100644 --- a/tensorflow/tensorboard/tensorboard.py +++ b/tensorflow/tensorboard/tensorboard.py @@ -32,6 +32,7 @@ from werkzeug import serving from tensorflow.tensorboard.backend import application from tensorflow.tensorboard.backend.event_processing import event_file_inspector as efi +from tensorflow.tensorboard.plugins.audio import audio_plugin from tensorflow.tensorboard.plugins.histograms import histograms_plugin from tensorflow.tensorboard.plugins.images import images_plugin from tensorflow.tensorboard.plugins.projector import projector_plugin @@ -203,6 +204,7 @@ def main(unused_argv=None): return 0 else: plugins = [ + audio_plugin.AudioPlugin(), histograms_plugin.HistogramsPlugin(), images_plugin.ImagesPlugin(), scalars_plugin.ScalarsPlugin(), From 4f3ae76996ec5cdf7791d4633615112d80c56120 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 15:52:27 -0700 Subject: [PATCH 27/72] Add beam_search kernels used by BeamSearchDecoder to tensorflow.contrib. 
PiperOrigin-RevId: 157775011 --- tensorflow/contrib/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index d3fb30ca50f..b4ff1da30ff 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -84,6 +84,7 @@ cc_library( "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", + "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", ], @@ -99,6 +100,7 @@ cc_library( "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", + "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", ], From 2ee09b873a7f658fba151d8e39d2a8bc67e136a6 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Thu, 1 Jun 2017 16:12:12 -0700 Subject: [PATCH 28/72] [XLA] Various improvements to ShapeTree. Add support for holding non-copyable types, operator==, and a CopySubtreeFrom method for copying a subtree from one ShapeTree to another. PiperOrigin-RevId: 157777636 --- tensorflow/compiler/xla/shape_tree.h | 83 ++++++++++++- tensorflow/compiler/xla/shape_tree_test.cc | 134 +++++++++++++++++++++ 2 files changed, 215 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index aa4341d18e1..122d6ce4a98 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,6 +44,7 @@ struct ShapeTreeNode { // Children of this node. 
std::vector> children; + ShapeTreeNode() = default; explicit ShapeTreeNode(const T& data) : data(data) {} ShapeTreeNode(const ShapeTreeNode& other) @@ -85,8 +86,9 @@ class ShapeTree { public: // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} - // Create ShapeTree with the given shape, and default T values for all nodes. - explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {} + // Create ShapeTree with the given shape, and default-constructed T values for + // all nodes. + explicit ShapeTree(const Shape& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(const Shape& shape, const T& init_value); @@ -127,6 +129,19 @@ class ShapeTree { const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>; Status ForEachMutableElement(const MutableVisitorFunction& func); + // Copy the subtree of values from 'other' rooted at ShapeIndex + // 'source_base_index' into the subtree of value in this ShapeTree rooted at + // 'target_base_index'. + // + // Precondition: The subshape of other.shape() at index source_base_index must + // be compatible with the subshape of shape() at index target_base_index. + void CopySubtreeFrom(const ShapeTree& other, + const ShapeIndex& source_base_index, + const ShapeIndex& target_base_index); + + bool operator==(const ShapeTree& other) const; + bool operator!=(const ShapeTree& other) const { return !(*this == other); } + private: using Node = internal::ShapeTreeNode; @@ -134,6 +149,10 @@ class ShapeTree { // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); + // Initialize node->children based on 'shape'. All children have + // default-constructed data values. + void InitChildren(const Shape& shape, Node* node); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). 
@@ -165,6 +184,24 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, } } +template +void ShapeTree::InitChildren(const Shape& shape, Node* node) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + node->children.emplace_back(new Node()); + InitChildren(shape.tuple_shapes(i), node->children.back().get()); + } + } +} + +template +ShapeTree::ShapeTree(const Shape& shape) : root_(), shape_(shape) { + // The shape_ field is just used to hold the structure of the shape. + // It should not be relied upon to store layout information. + LayoutUtil::ClearLayout(&shape_); + InitChildren(shape_, &root_); +} + template ShapeTree::ShapeTree(const Shape& shape, const T& init_value) : root_(init_value), shape_(shape) { @@ -240,6 +277,48 @@ Status ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { return ForEachMutableHelper(func, &root_, &index); } +template +void ShapeTree::CopySubtreeFrom(const ShapeTree& other, + const ShapeIndex& source_base_index, + const ShapeIndex& target_base_index) { + CHECK(ShapeUtil::Compatible( + ShapeUtil::GetSubshape(shape(), target_base_index), + ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ForEachMutableElement( + [this, &other, &source_base_index, &target_base_index]( + const ShapeIndex& index, bool /*is_leaf*/, T* data) { + // Copy the data element only if index is in the + // subtree rooted at target_base_index. + for (int i = 0; i < target_base_index.size(); ++i) { + if (i >= index.size() || index[i] != target_base_index[i]) { + return Status::OK(); + } + } + // Construct source element index to copy from. 
+ ShapeIndex source_index = source_base_index; + for (int i = target_base_index.size(); i < index.size(); ++i) { + source_index.push_back(index[i]); + } + *data = other.element(source_index); + return Status::OK(); + }) + .IgnoreError(); +} + +template +bool ShapeTree::operator==(const ShapeTree& other) const { + bool equal = true; + ForEachElement([this, &other, &equal](const ShapeIndex& index, + bool /*is_leaf*/, const T& data) { + if (data != other.element(index)) { + equal = false; + } + return Status::OK(); + }) + .IgnoreError(); + return equal; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index efb6f422e00..1b9e18023ef 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { EXPECT_DEATH(shape_tree.element({0, 0}), ""); } +TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { + ShapeTree> shape_tree{tuple_shape_}; + EXPECT_EQ(shape_tree.element({2}).get(), nullptr); + *shape_tree.mutable_element({2}) = MakeUnique(42); + EXPECT_EQ(*shape_tree.element({2}), 42); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) { + // Test CopySubtreeFrom method for a single value copied between array-shaped + // ShapeTrees. + ShapeTree source(array_shape_); + *source.mutable_element(/*index=*/{}) = 42; + ShapeTree destination(array_shape_, 123); + + EXPECT_EQ(destination.element(/*index=*/{}), 123); + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 42); +} + +TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of all elements from one + // tuple-shaped ShapeTree to another. 
+ ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2}), 13); +} + +TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of a single element from one + // tuple-shaped ShapeTree to another. + ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{0}, + /*target_base_index=*/{1}); + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 11); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) { + // Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a + // nested-tuple-shaped ShapeTree. 
+ ShapeTree source( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_})); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + + ShapeTree destination(nested_tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{2, 0}); + + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); + EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) { + // Test CopySubtreeFrom method for a copy from a nested-tuple-shape. 
+ ShapeTree source(nested_tuple_shape_, 42); + *source.mutable_element(/*index=*/{1}) = 10; + *source.mutable_element(/*index=*/{1, 0}) = 11; + *source.mutable_element(/*index=*/{1, 1}) = 12; + + ShapeTree destination( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{1}, + /*target_base_index=*/{}); + + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); +} + +TEST_F(ShapeTreeTest, OperatorEquals) { + { + ShapeTree a(array_shape_, 123); + ShapeTree b(array_shape_, 42); + ShapeTree c(array_shape_, 42); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + } + { + ShapeTree a(tuple_shape_); + *a.mutable_element(/*index=*/{}) = 10; + *a.mutable_element(/*index=*/{0}) = 11; + *a.mutable_element(/*index=*/{1}) = 12; + + ShapeTree b(tuple_shape_); + *b.mutable_element(/*index=*/{}) = 10; + *b.mutable_element(/*index=*/{0}) = 42; + *b.mutable_element(/*index=*/{1}) = 11; + + ShapeTree c(tuple_shape_); + *c.mutable_element(/*index=*/{}) = 10; + *c.mutable_element(/*index=*/{0}) = 42; + *c.mutable_element(/*index=*/{1}) = 11; + + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + EXPECT_FALSE(b != c); + } +} + } // namespace } // namespace xla From 5bc685d7f16b0fc27b936e63fa01668e4af4034c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 16:24:40 -0700 Subject: [PATCH 29/72] [XLA] If an op has a single "large" operand, we want to fuse this op into some of its consumers, even if we can't fuse into all of them. 
PiperOrigin-RevId: 157779106 --- .../xla/service/instruction_fusion.cc | 21 ++++++- .../xla/service/instruction_fusion_test.cc | 62 ++++++++++++++++--- 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 5069215031b..721640cdbd8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -151,7 +151,26 @@ StatusOr InstructionFusion::Run(HloModule* module) { return true; }; - if (std::all_of(hlo->users().begin(), hlo->users().end(), + // An "effectively unary" operation is one that has one "large" + // input with the others being negligible in terms of memory usage. + // We use "has a smaller true rank than the output" as a heuristic + // for "negligible" memory usage. + auto effectively_unary = [](HloInstruction* hlo) { + if (hlo->operands().size() == 1) { + return true; + } + auto output_rank = ShapeUtil::TrueRank(hlo->shape()); + return std::count_if( + hlo->operands().begin(), hlo->operands().end(), + [output_rank](HloInstruction* operand) { + return ((operand->opcode() != HloOpcode::kBroadcast) && + ShapeUtil::TrueRank(operand->shape()) >= + output_rank); + }) <= 1; + }; + + if (effectively_unary(hlo) || + std::all_of(hlo->users().begin(), hlo->users().end(), user_fusable_into_hlo)) { all_consumers_fusable.insert(hlo); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9a79e4c3824..d2df0b699ef 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { HloComputation::Builder builder(TestName()); - auto param0 = 
builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {16, 16}), "0")); - HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); - HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1)); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(unary2, computation->root_instruction()); + EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); } +TEST_F(InstructionFusionTest, AllowUnaryDuplication) { + HloComputation::Builder builder(TestName()); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + HloInstruction* unary1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); + builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + HloInstruction* unary2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary2, 
computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto small_shape = ShapeUtil::MakeShape(F32, {16}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, small_shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace xla From eebd44123674b3db65d7790778c9b41945406ec9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 16:34:21 -0700 Subject: [PATCH 30/72] Add a frontend method for retrieving numeric alerts from the debugger plugin. This route responds with a list of alerts (occurrences of bad values) in ascending timestamp order. 
PiperOrigin-RevId: 157780270 --- .../components/tf_backend/backend.ts | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index 1c93b4f429a..a7a222beaaf 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -100,6 +100,20 @@ export type HealthPillDatum = Datum & HealthPill; // data entries. export interface HealthPillsResponse { [key: string]: HealthPillDatum[]; } +// An object that encapsulates an alert issued by the debugger. This alert is +// sent by debugging libraries after bad values (NaN, +/- Inf) are encountered. +export interface DebuggerNumericsAlertReport { + device_name: string; + tensor_name: string; + first_timestamp: number; + nan_event_count: number; + neg_inf_event_count: number; + pos_inf_event_count: number; +} +// A DebuggerNumericsAlertReportResponse contains alerts issued by the debugger +// in ascending order of timestamp. This helps the user identify for instance +// when bad values first appeared in the model. +export type DebuggerNumericsAlertReportResponse = DebuggerNumericsAlertReport[]; export const TYPES = [ 'scalar', 'histogram', 'compressedHistogram', 'graph', 'image', 'audio', @@ -240,7 +254,8 @@ export class Backend { } /** - * Returns a promise for requesting the health pills for a list of nodes. + * Returns a promise for requesting the health pills for a list of nodes. This + * route is used by the debugger plugin. */ public healthPills(nodeNames: string[], step?: number): Promise { @@ -258,6 +273,16 @@ export class Backend { return this.requestManager.request(this.router.healthPills(), postData); } + /** + * Returns a promise for alerts for bad values (detected by the debugger). + * This route is used by the debugger plugin. 
+ */ + public debuggerNumericsAlerts(): + Promise { + return this.requestManager.request( + this.router.pluginRoute('debugger', '/numerics_alert_report')); + } + /** * Return a promise containing HistogramDatums for given run and tag. */ From f7de292df3534f8758236f9255e538ce6d402b72 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Jun 2017 16:39:34 -0700 Subject: [PATCH 31/72] Update placeholder nodes' shapes in the GraphDef to reflect manually specified values for incomplete placeholder shapes. Previously, these overrides were only specified in the feed nodes, which improves estimates when using dynamic shapes but not when using static shapes. With this change, static shapes also benefit. PiperOrigin-RevId: 157780800 --- .../core/grappler/grappler_item_builder.cc | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 384402ad291..8f7333f1dbf 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -183,13 +183,17 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( // from it. We do this because in newer protos, the input placeholder // shape is not empty if the shape is partially defined. 
TensorShape shape; + TensorShapeProto shape_proto; std::vector dims; for (const auto& dim_proto : node.attr().at("shape").shape().dim()) { if (cfg.placeholder_unknown_output_shape_dim >= 0 && dim_proto.size() == -1) { dims.push_back(cfg.placeholder_unknown_output_shape_dim); + shape_proto.add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); } else { dims.push_back(dim_proto.size()); + shape_proto.add_dim()->set_size(dim_proto.size()); } } Status make_shape_status = @@ -214,6 +218,7 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( (shape.dims() == 0) && (node.attr().count("_output_shapes") == 1) && (node.attr().at("_output_shapes").list().shape(0).dim_size() != 0)) { shape.Clear(); + shape_proto.clear_dim(); for (int dim_i = 0; dim_i < node.attr().at("_output_shapes").list().shape(0).dim_size(); @@ -222,19 +227,27 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( node.attr().at("_output_shapes").list().shape(0).dim(dim_i); if (dim.size() == -1) { shape.AddDim(cfg.placeholder_unknown_output_shape_dim); + shape_proto.add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); } else { - shape.AddDim(node.attr() - .at("_output_shapes") - .list() - .shape(0) - .dim(dim_i) - .size()); + int size = node.attr() + .at("_output_shapes") + .list() + .shape(0) + .dim(dim_i) + .size(); + shape.AddDim(size); + shape_proto.add_dim()->set_size(size); } } } Tensor fake_input(type, shape); InitializeTensor(type, &fake_input); new_item->feed.emplace_back(node.name(), fake_input); + // Set the shape of the node in the graph. This is needed for statically + // inferring shapes and is a no-op when dynamically inferring shapes as + // the Placeholder shape will match the shape passed from new_item->feed. + *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto; } // Delete user specified placement if requested. From fb4bc806a83b629652f7919f4e0c0e9ae08198c0 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 1 Jun 2017 16:45:07 -0700 Subject: [PATCH 32/72] Fix flakiness in GpuMultiSessionMemoryTest. PiperOrigin-RevId: 157781368 --- .../python/kernel_tests/basic_gpu_test.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py index e6d0c06d140..dbbc2de811e 100644 --- a/tensorflow/python/kernel_tests/basic_gpu_test.py +++ b/tensorflow/python/kernel_tests/basic_gpu_test.py @@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase): class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase): """Tests concurrent sessions executing on the same GPU.""" - def _run_session(self, results): + def _run_session(self, session, results): n_iterations = 500 - with self.test_session(use_gpu=True) as s: + with session as s: data = variables.Variable(1.0) with ops.device('/gpu:0'): random_seed.set_random_seed(1) @@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase): for _ in xrange(n_iterations): value = s.run(x4) - results.append(value) - if value != results[0]: + results.add(value.flat[0]) + if len(results) != 1: break def testConcurrentSessions(self): - if not test.is_gpu_available(): - return - n_threads = 4 - results = [[]] * n_threads - threads = [ - threading.Thread(target=self._run_session, args=(results[i],)) - for i in xrange(n_threads) - ] + threads = [] + results = [] + for _ in xrange(n_threads): + session = self.test_session(graph=ops.Graph(), use_gpu=True) + results.append(set()) + args = (session, results[-1]) + threads.append(threading.Thread(target=self._run_session, args=args)) + for thread in threads: thread.start() for thread in threads: thread.join() - flat_results = [x for x in itertools.chain(*results)] - self.assertNotEqual(0, len(flat_results)) - for result in flat_results: - self.assertEqual(result, flat_results[0]) + flat_results = set([x for x in 
itertools.chain(*results)]) + self.assertEqual(1, + len(flat_results), + 'Expected single value, got %r' % flat_results) if __name__ == '__main__': From a65a70ea5a59728087e5cb1d01d58248604bfb3d Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Thu, 1 Jun 2017 17:05:43 -0700 Subject: [PATCH 33/72] Fix pip tests under contrib/text PiperOrigin-RevId: 157783952 --- tensorflow/contrib/text/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD index ff69c4e2cbe..6bcb03238cc 100644 --- a/tensorflow/contrib/text/BUILD +++ b/tensorflow/contrib/text/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "py_test", "tf_custom_op_library", "tf_custom_op_py_library", "tf_gen_op_libs", From 40411cd5c68e4f91a1fc0d5861fac88b404329bb Mon Sep 17 00:00:00 2001 From: Dandelion Mané Date: Thu, 1 Jun 2017 17:06:15 -0700 Subject: [PATCH 34/72] Refactor projector plugin to only use tf public methods. Remove all reference to the PluginAsset system, which is deprecated. Part of an ongoing effort to have TensorBoard only consume the public TensorFlow api.
PiperOrigin-RevId: 157784016 --- .../plugins/projector/projector_plugin.py | 169 +------ .../projector/projector_plugin_test.py | 439 ------------------ 2 files changed, 14 insertions(+), 594 deletions(-) diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin.py b/tensorflow/tensorboard/plugins/projector/projector_plugin.py index 9b9e3197c6b..9a3a305d53a 100644 --- a/tensorflow/tensorboard/plugins/projector/projector_plugin.py +++ b/tensorflow/tensorboard/plugins/projector/projector_plugin.py @@ -23,14 +23,11 @@ import imghdr import math import os import numpy as np -from six import BytesIO import tensorflow as tf from werkzeug import wrappers from google.protobuf import json_format from google.protobuf import text_format -from tensorflow.python.lib.io import file_io -from tensorflow.python.summary import plugin_asset from tensorflow.tensorboard.backend.http_util import Respond from tensorflow.tensorboard.plugins.base_plugin import TBPlugin from tensorflow.tensorboard.plugins.projector import projector_config_pb2 @@ -141,148 +138,8 @@ class EmbeddingMetadata(object): self.name_to_values[column_name] = column_values -class ProjectorPluginAsset(plugin_asset.PluginAsset): - """Provides a registry for assets needed by the Projector plugin.""" - plugin_name = _PLUGIN_NAME - - def __init__(self): - self._config = projector_config_pb2.ProjectorConfig() - self._assets = {} - self._used_names = set() - - def add_metadata_for_embedding_variable(self, - var_name, - metadata=None, - thumbnails=None, - thumbnail_dim=None): - """Adds metadata for an embedding variable stored in a checkpoint file. - - Args: - var_name: Name of the embedding variable. - metadata: Optional. A `Metadata` container mapping column header names to - the values of that column. - thumbnails: Optional. A 4D `ndarray` or a list of 3D `ndarray`s. Each - 3D array represents the pixels [height, width, channels] of a single - thumbnail. 
The i-th image corresponds to the i-th row (data point) of - the embedding variable. - thumbnail_dim: Required if `thumbnails` is provided. A tuple - (height, width) of a single thumbnail in the sprite. - - Raises: - ValueError: If the name of the variable was previously used in this - object, or both `metadata` and `thumbnails` are None. - """ - - if metadata is None and thumbnails is None: - raise ValueError('At least one of (`metadata`, `thumbnails`) must be ' - 'provided') - self._convert_embedding_to_assets(var_name, None, metadata, thumbnails, - thumbnail_dim) - - def add_embedding(self, - name, - values, - metadata=None, - thumbnails=None, - thumbnail_dim=None): - """Adds an embedding asset to be visualized by the Embedding Projector. - - Args: - name: Name of the embedding. - values: 2D `ndarray` of shape [numPoints, dimensionality] - containing the embedding values. The i-th row corresponds to the i-th - data point. - metadata: Optional. A `Metadata` container mapping column header names to - the values of that column. - thumbnails: Optional. A 4D `ndarray` or a list of 3D `ndarray`s. Each - 3D array represents the pixels [height, width, channels] of a single - thumbnail. The i-th image corresponds to the i-th row (data point) of - the `values` matrix. - thumbnail_dim: Required if `thumbnails` is provided. A tuple - (height, width) of a single thumbnail in the sprite. - - Raises: - ValueError: If the name of the embedding was previously used in this - object, or `values` is not a 2D array. - """ - - # Sanity checks. 
- if values.ndim != 2: - raise ValueError('`values` must be a 2D array, but is ' - '%d-D' % values.ndim) - self._convert_embedding_to_assets(name, values, metadata, thumbnails, - thumbnail_dim) - - def _convert_embedding_to_assets(self, - name, - values=None, - metadata=None, - thumbnails=None, - thumbnail_dim=None): - """Converts the data associated with embeddings into serializable assets.""" - - if name in self._used_names: - raise ValueError('The name "%s" was previously used' % name) - if thumbnails is not None and not thumbnail_dim: - raise ValueError('`thumbnail_dim` is required when `thumbnails` is ' - 'provided') - if thumbnail_dim is not None: - if not isinstance(thumbnail_dim, (list, tuple, np.ndarray)): - raise ValueError('`thumbnail_dim` must be either a list, tuple or ' - '`ndarray`') - if len(thumbnail_dim) != 2: - raise ValueError('`thumbnail_dim` must be of length 2, ' - 'but is of length %d' % len(thumbnail_dim)) - if metadata: - if values is not None and len(values) != metadata.num_points: - raise ValueError('First dimension of `values` "%d" must match ' - '`metadata.num_points` "%d"' % (len(values), - metadata.num_points)) - if not metadata.column_names: - raise ValueError('The provided metadata has no columns. 
Did you forget ' - 'to add a column?') - - self._used_names.add(name) - embedding_info = self._config.embeddings.add() - embedding_info.tensor_name = name - - if values is not None: - bytes_io = BytesIO() - np.savetxt(bytes_io, values, fmt='%.6g', delimiter='\t') - fname = '{}_values.tsv'.format(name) - embedding_info.tensor_path = fname - embedding_info.tensor_shape.extend(values.shape) - self._assets[fname] = bytes_io.getvalue() - - if metadata: - metadata_tsv_lines = [] - should_have_header = len(metadata.column_names) > 1 - if should_have_header: - metadata_tsv_lines.append('\t'.join(metadata.column_names)) - - for i in range(metadata.num_points): - row = [ - metadata.name_to_values[col_name][i] - for col_name in metadata.column_names - ] - metadata_tsv_lines.append('\t'.join(map(str, row))) - fname = '{}_metadata.tsv'.format(name) - embedding_info.metadata_path = fname - self._assets[fname] = '\n'.join(metadata_tsv_lines) + '\n' - - if thumbnails is not None: - fname = '{}_sprite.png'.format(name) - embedding_info.sprite.image_path = fname - embedding_info.sprite.single_image_dim.extend(thumbnail_dim) - self._assets[fname] = _make_sprite_image(thumbnails, thumbnail_dim) - - def assets(self): - self._assets[PROJECTOR_FILENAME] = text_format.MessageToString(self._config) - return self._assets - - def _read_tensor_tsv_file(fpath): - with file_io.FileIO(fpath, 'r') as f: + with tf.gfile.GFile(fpath, 'r') as f: tensor = [] for line in f: if line: @@ -304,8 +161,9 @@ def _latest_checkpoints_changed(configs, run_path_pairs): if run_name not in configs: config = projector_config_pb2.ProjectorConfig() config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) - if file_io.file_exists(config_fpath): - file_content = file_io.read_file_to_string(config_fpath) + if tf.gfile.Exists(config_fpath): + with tf.gfile.GFile(config_fpath, 'r') as f: + file_content = f.read() text_format.Merge(file_content, config) else: config = configs[run_name] @@ -466,8 +324,9 @@ class 
ProjectorPlugin(TBPlugin): for run_name, assets_dir in run_path_pairs: config = projector_config_pb2.ProjectorConfig() config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) - if file_io.file_exists(config_fpath): - file_content = file_io.read_file_to_string(config_fpath) + if tf.gfile.Exists(config_fpath): + with tf.gfile.GFile(config_fpath, 'r') as f: + file_content = f.read() text_format.Merge(file_content, config) has_tensor_files = False for embedding in config.embeddings: @@ -592,12 +451,12 @@ class ProjectorPlugin(TBPlugin): 'No metadata file found for tensor "%s" in the config file "%s"' % (name, self.config_fpaths[run]), 'text/plain', 400) fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not file_io.file_exists(fpath) or file_io.is_directory(fpath): + if not tf.gfile.Exists(fpath) or tf.gfile.IsDirectory(fpath): return Respond(request, '"%s" not found, or is not a file' % fpath, 'text/plain', 400) num_header_rows = 0 - with file_io.FileIO(fpath, 'r') as f: + with tf.gfile.GFile(fpath, 'r') as f: lines = [] # Stream reading the file with early break in case the file doesn't fit in # memory. 
@@ -639,7 +498,7 @@ class ProjectorPlugin(TBPlugin): if embedding and embedding.tensor_path: fpath = _rel_to_abs_asset_path(embedding.tensor_path, self.config_fpaths[run]) - if not file_io.file_exists(fpath): + if not tf.gfile.Exists(fpath): return Respond(request, 'Tensor file "%s" does not exist' % fpath, 'text/plain', 400) @@ -688,12 +547,12 @@ class ProjectorPlugin(TBPlugin): 'No bookmarks file found for tensor "%s" in the config file "%s"' % (name, self.config_fpaths[run]), 'text/plain', 400) fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not file_io.file_exists(fpath) or file_io.is_directory(fpath): + if not tf.gfile.Exists(fpath) or tf.gfile.IsDirectory(fpath): return Respond(request, '"%s" not found, or is not a file' % fpath, 'text/plain', 400) bookmarks_json = None - with file_io.FileIO(fpath, 'rb') as f: + with tf.gfile.GFile(fpath, 'rb') as f: bookmarks_json = f.read() return Respond(request, bookmarks_json, 'application/json') @@ -723,10 +582,10 @@ class ProjectorPlugin(TBPlugin): fpath = os.path.expanduser(embedding_info.sprite.image_path) fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not file_io.file_exists(fpath) or file_io.is_directory(fpath): + if not tf.gfile.Exists(fpath) or tf.gfile.IsDirectory(fpath): return Respond(request, '"%s" does not exist or is directory' % fpath, 'text/plain', 400) - f = file_io.FileIO(fpath, 'rb') + f = tf.gfile.GFile(fpath, 'rb') encoded_image_string = f.read() f.close() image_type = imghdr.what(None, encoded_image_string) diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py index f3468c3ffe2..06cf2c3d0d4 100644 --- a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py +++ b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py @@ -30,7 +30,6 @@ from werkzeug import test as werkzeug_test from werkzeug import wrappers from google.protobuf import text_format 
-from tensorflow.python.summary import plugin_asset from tensorflow.tensorboard.backend import application from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -170,8 +169,6 @@ class ProjectorAppTest(tf.test.TestCase): def testEndpointsNoAssets(self): g = tf.Graph() - with g.as_default(): - plugin_asset.get_plugin_asset(projector_plugin.ProjectorPluginAsset) fw = tf.summary.FileWriter(self.log_dir, graph=g) fw.close() @@ -180,203 +177,6 @@ class ProjectorAppTest(tf.test.TestCase): run_json = self._GetJson('/data/plugin/projector/runs') self.assertEqual(run_json, []) - def testEndpointsMetadataForVariableAssets(self): - self._GenerateProjectorTestData() - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'c']) - manager.add_metadata_for_embedding_variable('test', metadata) - - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertTrue(run_json) - - run = run_json[0] - metedata_query = '/data/plugin/projector/metadata?run=%s&name=test' % run - metadata_tsv = self._Get(metedata_query).data - self.assertEqual(metadata_tsv, b'a\nb\nc\n') - - unk_tensor_query = '/data/plugin/projector/tensor?run=%s&name=test' % run - response = self._Get(unk_tensor_query) - self.assertEqual(response.status_code, 400) - - expected_tensor = np.array([[6, 6]], dtype=np.float32) - tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run - tensor_bytes = self._Get(tensor_query).data - self._AssertTensorResponse(tensor_bytes, expected_tensor) - - def testEndpointsMetadataForVariableAssetsButNoCheckpoint(self): - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - 
metadata.add_column('labels', ['a', 'b', 'c']) - manager.add_metadata_for_embedding_variable('test', metadata) - - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertEqual(run_json, []) - - def testEndpointsTensorAndMetadataAssets(self): - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'c']) - manager.add_metadata_for_embedding_variable('test', metadata) - expected_tensor = np.array([[1, 2], [3, 4], [5, 6]]) - image1 = np.array([[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]) - image2 = np.array([[[10, 20, 30], [40, 50, 60]], - [[70, 80, 90], [100, 110, 120]]]) - manager.add_embedding('emb', expected_tensor, metadata, [image1, image2], - [2, 2]) - - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertTrue(run_json) - - run = run_json[0] - metadata_query = '/data/plugin/projector/metadata?run=%s&name=emb' % run - metadata_tsv = self._Get(metadata_query).data - self.assertEqual(metadata_tsv, b'a\nb\nc\n') - - unk_metadata_query = '/data/plugin/projector/metadata?run=%s&name=q' % run - response = self._Get(unk_metadata_query) - self.assertEqual(response.status_code, 400) - - tensor_query = '/data/plugin/projector/tensor?run=%s&name=emb' % run - tensor_bytes = self._Get(tensor_query).data - self._AssertTensorResponse(tensor_bytes, expected_tensor) - - unk_tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run - response = self._Get(unk_tensor_query) - self.assertEqual(response.status_code, 400) - - image_query = '/data/plugin/projector/sprite_image?run=%s&name=emb' % run - image_bytes = self._Get(image_query).data - with tf.Graph().as_default(): - s = tf.Session() - 
image_array = tf.image.decode_png(image_bytes).eval(session=s).tolist() - expected_sprite_image = [ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[7, 8, 9], [10, 11, 12], [70, 80, 90], [100, 110, 120]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]] - ] - self.assertEqual(image_array, expected_sprite_image) - - def testSpriteImageRequestMissingRunAndName(self): - self._SetupWSGIApp() - q = '/data/plugin/projector/sprite_image' - response = self._Get(q) - self.assertEqual(response.status_code, 400) - - def testSpriteImageRequestMissingName(self): - self._SetupWSGIApp() - q = '/data/plugin/projector/sprite_image?run=.' - response = self._Get(q) - self.assertEqual(response.status_code, 400) - - def testSpriteImageRequestMissingRun(self): - self._SetupWSGIApp() - q = '/data/plugin/projector/sprite_image?name=emb' - response = self._Get(q) - self.assertEqual(response.status_code, 400) - - def testSpriteImageUnknownRun(self): - self._GenerateProjectorTestData() - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - image1 = np.array([[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]) - image2 = np.array([[[10, 20, 30], [40, 50, 60]], - [[70, 80, 90], [100, 110, 120]]]) - manager.add_metadata_for_embedding_variable('var1', - thumbnails=[image1, image2], - thumbnail_dim=[2, 2]) - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - self._SetupWSGIApp() - - q = '/data/plugin/projector/sprite_image?run=unknown&name=var1' - response = self._Get(q) - self.assertEqual(response.status_code, 400) - - def testSpriteImageUnknownName(self): - self._GenerateProjectorTestData() - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - image1 = np.array([[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]) - image2 = np.array([[[10, 20, 30], [40, 50, 60]], - [[70, 80, 90], 
[100, 110, 120]]]) - manager.add_metadata_for_embedding_variable('var1', - thumbnails=[image1, image2], - thumbnail_dim=[2, 2]) - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - self._SetupWSGIApp() - q = '/data/plugin/projector/sprite_image?run=.&name=unknown' - response = self._Get(q) - self.assertEqual(response.status_code, 400) - - def testEndpointsComboTensorAssetsAndCheckpoint(self): - self._GenerateProjectorTestData() - g = tf.Graph() - with g.as_default(): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'c']) - manager.add_metadata_for_embedding_variable('var1', metadata) - - new_tensor_values = np.array([[1, 2], [3, 4], [5, 6]]) - manager.add_embedding('new_tensor', new_tensor_values) - - fw = tf.summary.FileWriter(self.log_dir, graph=g) - fw.close() - - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertTrue(run_json) - - run = run_json[0] - var1_values = np.array([[6, 6]], dtype=np.float32) - var1_tensor_query = '/data/plugin/projector/tensor?run=%s&name=var1' % run - tensor_bytes = self._Get(var1_tensor_query).data - self._AssertTensorResponse(tensor_bytes, var1_values) - - metadata_query = '/data/plugin/projector/metadata?run=%s&name=var1' % run - metadata_tsv = self._Get(metadata_query).data - self.assertEqual(metadata_tsv, b'a\nb\nc\n') - - tensor_query = '/data/plugin/projector/tensor?run=%s&name=new_tensor' % run - tensor_bytes = self._Get(tensor_query).data - self._AssertTensorResponse(tensor_bytes, new_tensor_values) - def _AssertTensorResponse(self, tensor_bytes, expected_tensor): tensor = np.reshape(np.fromstring(tensor_bytes, dtype=np.float32), expected_tensor.shape) @@ -507,245 +307,6 @@ class MetadataColumnsTest(tf.test.TestCase): metadata.add_column('Labels', np.array(['a', 'b'])) -class ProjectorPluginAssetTest(tf.test.TestCase): - - def 
testNoAssets(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - self.assertEqual(manager.assets(), {'projector_config.pbtxt': ''}) - - def testAddEmbeddingNoMetadata(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - manager.add_embedding('test', np.array([[1, 2, 3.1]])) - - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - embedding.tensor_name = 'test' - embedding.tensor_shape.extend([1, 3]) - embedding.tensor_path = 'test_values.tsv' - expected_config_pbtxt = text_format.MessageToString(config) - - self.assertEqual(manager.assets(), { - 'projector_config.pbtxt': expected_config_pbtxt, - 'test_values.tsv': b'1\t2\t3.1\n' - }) - - def testAddEmbeddingIncorrectRank(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - with self.assertRaises(ValueError): - manager.add_embedding('test', np.array([1, 2, 3.1])) - - def testAddEmbeddingWithTwoMetadataColumns(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'друг јазик']) - metadata.add_column('sizes', [10, 20, 30]) - manager.add_embedding('test', np.array([[1], [2], [3]]), metadata) - - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - embedding.tensor_name = 'test' - embedding.tensor_shape.extend([3, 1]) - embedding.tensor_path = 'test_values.tsv' - embedding.metadata_path = 'test_metadata.tsv' - expected_config_pbtxt = text_format.MessageToString(config) - - self.assertEqual(manager.assets(), { - 'projector_config.pbtxt': expected_config_pbtxt, - 'test_values.tsv': b'1\n2\n3\n', - 'test_metadata.tsv': 'labels\tsizes\na\t10\nb\t20\nдруг јазик\t30\n' - }) - - def testAddEmbeddingWithOneMetadataColumn(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) 
- - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'c']) - manager.add_embedding('test', np.array([[1], [2], [3]]), metadata) - - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - embedding.tensor_name = 'test' - embedding.tensor_shape.extend([3, 1]) - embedding.tensor_path = 'test_values.tsv' - embedding.metadata_path = 'test_metadata.tsv' - expected_config_pbtxt = text_format.MessageToString(config) - - self.assertEqual(manager.assets(), { - 'projector_config.pbtxt': expected_config_pbtxt, - 'test_values.tsv': b'1\n2\n3\n', - 'test_metadata.tsv': 'a\nb\nc\n' - }) - - def testAddEmbeddingWithThumbnails(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - image1 = np.array([[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]) - image2 = np.array([[[10, 20, 30], [40, 50, 60]], - [[70, 80, 90], [100, 110, 120]]]) - manager.add_embedding( - 'test', - np.array([[1], [2], [3]]), - thumbnails=[image1, image2], - thumbnail_dim=[2, 2]) - - assets = manager.assets() - - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - embedding.tensor_name = 'test' - embedding.tensor_shape.extend([3, 1]) - embedding.tensor_path = 'test_values.tsv' - embedding.sprite.image_path = 'test_sprite.png' - embedding.sprite.single_image_dim.extend([2, 2]) - expected_config_pbtxt = text_format.MessageToString(config) - - self.assertEqual(assets['projector_config.pbtxt'], expected_config_pbtxt) - self.assertEqual(assets['test_values.tsv'], b'1\n2\n3\n') - - png_bytes = assets['test_sprite.png'] - with tf.Graph().as_default(): - s = tf.Session() - image_array = tf.image.decode_png(png_bytes).eval(session=s).tolist() - expected_master_image = [ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[7, 8, 9], [10, 11, 12], [70, 80, 90], [100, 110, 120]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0], 
[0, 0, 0]] - ] - self.assertEqual(image_array, expected_master_image) - - def testAddEmbeddingWithSpriteImageButNoThumbnailDim(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) - with self.assertRaises(ValueError): - manager.add_embedding( - 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails) - - def testAddEmbeddingThumbnailDimNotAList(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) - with self.assertRaises(ValueError): - manager.add_embedding( - 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails, - thumbnail_dim=4) - - def testAddEmbeddingThumbnailDimNotOfLength2(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) - with self.assertRaises(ValueError): - manager.add_embedding( - 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails, - thumbnail_dim=[4]) - - def testAddEmbeddingThumbnailListHasNoEntries(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - with self.assertRaises(ValueError): - manager.add_embedding('test', np.array([[1]]), thumbnails=[], - thumbnail_dim=[1, 1]) - - def testAddEmbeddingThumbnailListNotOfRank4(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - with self.assertRaises(ValueError): - manager.add_embedding('test2', np.array([[1]]), - thumbnails=np.array([[1]]), thumbnail_dim=[1, 1]) - - def testAddEmbeddingThumbnailListEntriesNot3DTensors(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - with self.assertRaises(ValueError): - manager.add_embedding('test3', np.array([[1]]), thumbnails=[[1, 2, 3]], - thumbnail_dim=[1, 1]) - - def testAddEmbeddingWithMetadataOfIncorrectLength(self): - manager 
= plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('labels', ['a', 'b', 'c']) - # values has length 2, while metadata has length 3. - values = np.array([[1], [2]]) - - with self.assertRaises(ValueError): - manager.add_embedding('test', values, metadata) - - def testAddMetadataForVariableButNoColumns(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - metadata = projector_plugin.EmbeddingMetadata(3) - with self.assertRaises(ValueError): - manager.add_metadata_for_embedding_variable('test', metadata) - - def testAddMetadataForVariable(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('Labels', ['a', 'b', 'c']) - manager.add_metadata_for_embedding_variable('test', metadata) - - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - embedding.tensor_name = 'test' - embedding.metadata_path = 'test_metadata.tsv' - expected_config_pbtxt = text_format.MessageToString(config) - - self.assertEqual(manager.assets(), { - 'projector_config.pbtxt': expected_config_pbtxt, - 'test_metadata.tsv': 'a\nb\nc\n' - }) - - def testAddMetadataForVariableAtLeastOneParamIsRequired(self): - manager = plugin_asset.get_plugin_asset( - projector_plugin.ProjectorPluginAsset) - with self.assertRaises(ValueError): - manager.add_metadata_for_embedding_variable('test') - - def testNoAssetsProperSerializationOnDisk(self): - logdir = self.get_temp_dir() - plugin_dir = os.path.join(logdir, 'plugins', - projector_plugin.ProjectorPluginAsset.plugin_name) - - with tf.Graph().as_default() as g: - plugin_asset.get_plugin_asset(projector_plugin.ProjectorPluginAsset) - fw = tf.summary.FileWriter(logdir, graph=g) - fw.close() - - with tf.gfile.Open(os.path.join(plugin_dir, 'projector_config.pbtxt')) as f: - content = f.read() - 
self.assertEqual(content, '') - - def testNoReferenceToPluginNoSerializationOnDisk(self): - logdir = self.get_temp_dir() - plugin_dir = os.path.join(logdir, 'plugins', - projector_plugin.ProjectorPluginAsset.plugin_name) - - with tf.Graph().as_default() as g: - fw = tf.summary.FileWriter(logdir, graph=g) - fw.close() - - self.assertFalse( - tf.gfile.Exists(plugin_dir), - 'The projector plugin directory should not exist.') - - class LRUCacheTest(tf.test.TestCase): def testInvalidSize(self): From d741d81c5f3412dd97da25a07574c327b2f8b1fc Mon Sep 17 00:00:00 2001 From: Dandelion Man? Date: Thu, 1 Jun 2017 17:11:04 -0700 Subject: [PATCH 35/72] Expose tf.test.StubOutForTesting in the tf testing api Also redirect TensorBoard usage to use that endpoint. This is part of my ongoing effort to have TensorBoard only depend on TensorFlow via its public api, so that it can be split into a project with a fast external build. PiperOrigin-RevId: 157784552 --- tensorflow/python/platform/test.py | 4 +++ .../directory_watcher_test.py | 3 +- .../event_accumulator_test.py | 3 +- .../event_multiplexer_test.py | 3 +- ...ensorflow.test.-stub-out-for-testing.pbtxt | 28 +++++++++++++++++++ .../tools/api/golden/tensorflow.test.pbtxt | 4 +++ .../api/lib/python_object_to_proto_visitor.py | 6 ++++ 7 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 5cb2c152b04..a307347f606 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -61,6 +61,9 @@ else: # Import Benchmark class Benchmark = _googletest.Benchmark # pylint: disable=invalid-name +# Import StubOutForTesting class +StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name + def main(argv=None): """Runs all unit tests.""" @@ -117,6 +120,7 @@ _allowed_symbols = [ # We piggy-back googletest documentation. 
'Benchmark', 'mock', + 'StubOutForTesting', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py b/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py index 190ae6a96b4..d44f74a8a43 100644 --- a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py +++ b/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py @@ -24,7 +24,6 @@ import shutil import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.tensorboard.backend.event_processing import directory_watcher from tensorflow.tensorboard.backend.event_processing import io_wrapper @@ -55,7 +54,7 @@ class DirectoryWatcherTest(tf.test.TestCase): os.mkdir(self._directory) self._watcher = directory_watcher.DirectoryWatcher(self._directory, _ByteLoader) - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() def tearDown(self): self.stubs.CleanUp() diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py index a2ac371a931..9efd64bd2ef 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py @@ -24,7 +24,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.python.summary.writer.writer import SummaryToEventTransformer from tensorflow.tensorboard.backend.event_processing import event_accumulator as ea @@ -182,7 +181,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def setUp(self): super(MockingEventAccumulatorTest, self).setUp() - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() self._real_constructor = ea.EventAccumulator 
self._real_generator = ea._GeneratorFromPath diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py index a7a6413ad1f..ea536dfaad6 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py @@ -24,7 +24,6 @@ import shutil import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.tensorboard.backend.event_processing import event_accumulator from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -116,7 +115,7 @@ class EventMultiplexerTest(tf.test.TestCase): def setUp(self): super(EventMultiplexerTest, self).setUp() - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator) diff --git a/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt new file mode 100644 index 00000000000..e02a0c6097c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.test.StubOutForTesting" +tf_class { + is_instance: "" + member_method { + name: "CleanUp" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Set" + argspec: "args=[\'self\', \'parent\', \'child_name\', \'new_child\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SmartSet" + argspec: "args=[\'self\', \'obj\', \'attr_name\', \'new_attr\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SmartUnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "UnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + 
} + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.pbtxt index 1e717ad2371..2a88f26ed02 100644 --- a/tensorflow/tools/api/golden/tensorflow.test.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.test.pbtxt @@ -4,6 +4,10 @@ tf_module { name: "Benchmark" mtype: "" } + member { + name: "StubOutForTesting" + mtype: "" + } member { name: "TestCase" mtype: "" diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 3197d0288f8..43ba52f9834 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -92,6 +92,12 @@ def _SanitizedMRO(obj): if 'tensorflow' not in str_repr: break + # Hack - tensorflow.test.StubOutForTesting may or may not be type + # depending on the environment. To avoid inconsistency, break after we add + # StubOutForTesting to the return_list. + if 'StubOutForTesting' in str_repr: + break + return return_list From 23cdf96b85177e657585da52651d89c5e6620e8d Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Thu, 1 Jun 2017 17:29:27 -0700 Subject: [PATCH 36/72] Re-enable session_test.py A number of CL's have split up session_test.py to be a bit smaller. As a result, this CL will re-enable the session_test to see if it remains flaky. 
PiperOrigin-RevId: 157786407 --- tensorflow/python/BUILD | 60 +++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c959ad904d7..93606ce4ce4 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2945,38 +2945,34 @@ tf_cuda_library( alwayslink = 1, ) -# Disabled due to http://b/62145493 -# py_test( -# name = "session_test", -# size = "medium", # http://b/62144199 -# srcs = ["client/session_test.py"], -# srcs_version = "PY2AND3", -# tags = [ -# "no_gpu", -# "no_pip_gpu", # testInteractivePlacePrunedGraph fails on invalid assumption about GPU ops. -# ], -# deps = [ -# ":array_ops", -# ":client", -# ":construction_fails_op", -# ":control_flow_ops", -# ":data_flow_ops", -# ":errors", -# ":framework", -# ":framework_for_generated_wrappers", -# ":framework_test_lib", -# ":math_ops", -# ":platform_test", -# ":state_ops", -# ":training", -# ":util", -# ":variables", -# "//third_party/py/numpy", -# "@six_archive//:six", -# "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", -# "//tensorflow/core/distributed_runtime/rpc:grpc_session", -# ], -# ) +py_test( + name = "session_test", + size = "small", + srcs = ["client/session_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", # testInteractivePlacePrunedGraph fails on invalid assumption about GPU ops. 
+ ], + deps = [ + ":array_ops", + ":client", + ":control_flow_ops", + ":data_flow_ops", + ":errors", + ":framework", + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":math_ops", + ":platform_test", + ":state_ops", + ":training", + ":util", + ":variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) py_test( name = "session_clusterspec_prop_test", From 7d7a40309693f01359537dce97fd6ff82e19755d Mon Sep 17 00:00:00 2001 From: William Chargin Date: Thu, 1 Jun 2017 17:37:28 -0700 Subject: [PATCH 37/72] Extract the distributions dashboard to a plugin This continues the great plugin migration. The distributions plugin was similar to the histograms plugin, but it also purported to allow CSV download like the scalars plugin. However, the existing implementation of this was flawed, and would always yield a 500 on current prod [1] (unless there were actually no data). This indicates that no one is actually using it---probably because there isn't a relevant button on the frontend, anyway!---so I just removed it. This also changes most frontend occurrences of "compressedHistograms" to "distributions" while we're at it. [1]: Due to the reference `value.rank_in_bps` in the handler `_serve_compressed_histograms`; this field does not exist and throws an `AttributeError`. 
PiperOrigin-RevId: 157787156 --- tensorflow/BUILD | 1 + tensorflow/contrib/cmake/tf_python.cmake | 1 + tensorflow/tensorboard/BUILD | 1 + tensorflow/tensorboard/backend/application.py | 46 +------ .../tensorboard/backend/application_test.py | 21 --- .../event_processing/event_accumulator.py | 2 +- .../components/tf_backend/backend.ts | 8 +- .../components/tf_backend/router.ts | 2 - tensorflow/tensorboard/http_api.md | 23 +++- .../tensorboard/plugins/distributions/BUILD | 50 +++++++ .../distributions/distributions_plugin.py | 69 ++++++++++ .../distributions_plugin_test.py | 125 ++++++++++++++++++ tensorflow/tensorboard/tensorboard.py | 8 +- 13 files changed, 277 insertions(+), 80 deletions(-) create mode 100644 tensorflow/tensorboard/plugins/distributions/BUILD create mode 100644 tensorflow/tensorboard/plugins/distributions/distributions_plugin.py create mode 100644 tensorflow/tensorboard/plugins/distributions/distributions_plugin_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 055c55a7170..ce1387ba43c 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -380,6 +380,7 @@ filegroup( "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins/audio:all_files", + "//tensorflow/tensorboard/plugins/distributions:all_files", "//tensorflow/tensorboard/plugins/histograms:all_files", "//tensorflow/tensorboard/plugins/images:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 132d84d00bb..95dbefc37ab 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -230,6 +230,7 @@ add_python_module("tensorflow/tensorboard/backend") add_python_module("tensorflow/tensorboard/backend/event_processing") add_python_module("tensorflow/tensorboard/plugins") 
add_python_module("tensorflow/tensorboard/plugins/audio") +add_python_module("tensorflow/tensorboard/plugins/distributions") add_python_module("tensorflow/tensorboard/plugins/histograms") add_python_module("tensorflow/tensorboard/plugins/images") add_python_module("tensorflow/tensorboard/plugins/projector") diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index 0b9c254b514..a8a4fb16614 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -14,6 +14,7 @@ py_binary( "//tensorflow/tensorboard/backend:application", "//tensorflow/tensorboard/backend/event_processing:event_file_inspector", "//tensorflow/tensorboard/plugins/audio:audio_plugin", + "//tensorflow/tensorboard/plugins/distributions:distributions_plugin", "//tensorflow/tensorboard/plugins/histograms:histograms_plugin", "//tensorflow/tensorboard/plugins/images:images_plugin", "//tensorflow/tensorboard/plugins/projector:projector_plugin", diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py index ef2d0c6d693..46f081a67c9 100644 --- a/tensorflow/tensorboard/backend/application.py +++ b/tensorflow/tensorboard/backend/application.py @@ -22,15 +22,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import csv import os import re import threading import time import six -from six import StringIO -from six.moves import xrange # pylint: disable=redefined-builtin from six.moves.urllib import parse as urlparse import tensorflow as tf from werkzeug import wrappers @@ -59,6 +56,7 @@ DEFAULT_SIZE_GUIDANCE = { # /data/runs entirely. 
_MIGRATED_DATA_KEYS = frozenset(( 'audio', + 'distributions', 'histograms', 'images', 'scalars', @@ -69,7 +67,6 @@ LOGDIR_ROUTE = '/logdir' RUNS_ROUTE = '/runs' PLUGIN_PREFIX = '/plugin' PLUGINS_LISTING_ROUTE = '/plugins_listing' -COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS GRAPH_ROUTE = '/' + event_accumulator.GRAPH RUN_METADATA_ROUTE = '/' + event_accumulator.RUN_METADATA TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] @@ -80,16 +77,6 @@ TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] _VALID_PLUGIN_RE = re.compile(r'^[A-Za-z0-9_.-]+$') -class _OutputFormat(object): - """An enum used to list the valid output formats for API calls. - - Not all API calls support all formats (for example, only scalars and - compressed histograms support CSV). - """ - JSON = 'json' - CSV = 'csv' - - def standard_tensorboard_wsgi( logdir, purge_orphaned_data, @@ -159,8 +146,6 @@ class TensorBoardWSGIApp(object): reload_multiplexer(self._multiplexer, path_to_run) self.data_applications = { - DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: - self._serve_compressed_histograms, DATA_PREFIX + GRAPH_ROUTE: self._serve_graph, DATA_PREFIX + LOGDIR_ROUTE: @@ -278,35 +263,6 @@ class TensorBoardWSGIApp(object): return http_util.Respond( request, str(run_metadata), 'text/x-protobuf') # pbtxt - @wrappers.Request.application - def _serve_compressed_histograms(self, request): - """Given a tag and single run, return an array of compressed histograms.""" - tag = request.args.get('tag') - run = request.args.get('run') - compressed_histograms = self._multiplexer.CompressedHistograms(run, tag) - if request.args.get('format') == _OutputFormat.CSV: - string_io = StringIO() - writer = csv.writer(string_io) - - # Build the headers; we have two columns for timing and two columns for - # each compressed histogram bucket. 
- headers = ['Wall time', 'Step'] - if compressed_histograms: - bucket_count = len(compressed_histograms[0].compressed_histogram_values) - for i in xrange(bucket_count): - headers += ['Edge %d basis points' % i, 'Edge %d value' % i] - writer.writerow(headers) - - for compressed_histogram in compressed_histograms: - row = [compressed_histogram.wall_time, compressed_histogram.step] - for value in compressed_histogram.compressed_histogram_values: - row += [value.rank_in_bps, value.value] - writer.writerow(row) - return http_util.Respond(request, string_io.getvalue(), 'text/csv') - else: - return http_util.Respond( - request, compressed_histograms, 'application/json') - @wrappers.Request.application def _serve_plugins_listing(self, request): """Serves an object mapping plugin name to whether it is enabled. diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py index 9729e6395a0..f05c9352466 100644 --- a/tensorflow/tensorboard/backend/application_test.py +++ b/tensorflow/tensorboard/backend/application_test.py @@ -167,7 +167,6 @@ class TensorboardServerTest(tf.test.TestCase): run_json, { 'run1': { - 'compressedHistograms': ['histogram'], # if only_use_meta_graph, the graph is from the metagraph 'graph': True, 'meta_graph': self._only_use_meta_graph, @@ -265,13 +264,8 @@ class TensorboardServerTest(tf.test.TestCase): """Generates the test data directory. The test data has a single run named run1 which contains: - - a histogram [1] - a graph definition - [1]: Histograms no longer appear in `/runs`, but compressed - histograms do, and they use the same test data. Thus, histograms are - still here for now. - Returns: temp_dir: The directory the test data is generated under. 
""" @@ -281,14 +275,6 @@ class TensorboardServerTest(tf.test.TestCase): os.makedirs(run1_path) writer = tf.summary.FileWriter(run1_path) - histogram_value = tf.HistogramProto( - min=0, - max=2, - num=3, - sum=6, - sum_squares=5, - bucket_limit=[0, 1, 2], - bucket=[1, 1, 1]) # Add a simple graph event. graph_def = tf.GraphDef() node1 = graph_def.node.add() @@ -309,13 +295,6 @@ class TensorboardServerTest(tf.test.TestCase): device_stats = run_metadata.step_stats.dev_stats.add() device_stats.device = 'test device' writer.add_run_metadata(run_metadata, 'test run') - writer.add_event( - tf.Event( - wall_time=0, - step=0, - summary=tf.Summary(value=[ - tf.Summary.Value(tag='histogram', histo=histogram_value), - ]))) writer.flush() writer.close() diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py index 1669c060844..1562f0f8339 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py @@ -72,7 +72,7 @@ SUMMARY_TYPES = { ## The tagTypes below are just arbitrary strings chosen to pass the type ## information of the tag from the backend to the frontend -COMPRESSED_HISTOGRAMS = 'compressedHistograms' +COMPRESSED_HISTOGRAMS = 'distributions' HISTOGRAMS = 'histograms' IMAGES = 'images' AUDIO = 'audio' diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index a7a222beaaf..2db8ddc23d2 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -193,8 +193,9 @@ export class Backend { * Return a promise showing the Run-to-Tag mapping for compressedHistogram * data. 
*/ - public compressedHistogramRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'compressedHistograms')); + public compressedHistogramTags(): Promise { + return this.requestManager.request( + this.router.pluginRoute('distributions', '/tags')); } /** @@ -343,7 +344,8 @@ export class Backend { */ private compressedHistogram(tag: string, run: string): Promise> { - const url = this.router.compressedHistograms(tag, run); + const url = (this.router.pluginRunTagRoute( + 'distributions', '/distributions')(tag, run)); let p: Promise[]>; p = this.requestManager.request(url); return p.then(map(detupler((x) => x))); diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts index b31f9b366ea..115634be125 100644 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ b/tensorflow/tensorboard/components/tf_backend/router.ts @@ -21,7 +21,6 @@ export interface Router { logdir: () => string; runs: () => string; isDemoMode: () => boolean; - compressedHistograms: RunTagUrlFn; graph: (run: string, limit_attr_size?: number, large_attrs_key?: string) => string; @@ -88,7 +87,6 @@ export function router(dataDir = 'data', demoMode = false): Router { runs: () => dataDir + '/runs' + (demoMode ? '.json' : ''), isDemoMode: () => demoMode, graph: graphUrl, - compressedHistograms: standardRoute('compressedHistograms'), runMetadata: standardRoute('run_metadata', '.pbtxt'), healthPills: () => dataDir + '/plugin/debugger/health_pills', textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? '.json' : ''), diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md index 541394cbe00..b015539cf52 100644 --- a/tensorflow/tensorboard/http_api.md +++ b/tensorflow/tensorboard/http_api.md @@ -55,13 +55,11 @@ all of the data available from the TensorBoard server. 
Here is an example: { "train_run": { - "compressedHistograms": ["foo_histogram", "bar_histogram"], "graph": true, "firstEventTimestamp": 123456.789 "run_metadata": ["forward prop", "inference"] }, "eval": { - "compressedHistograms": ["foo_histogram", "bar_histogram"], "graph": false, "run_metadata": [] } @@ -81,6 +79,7 @@ and will not appear in the output from this route: - `audio` - `images` - `scalars` + - `compressedHistograms`, moved to `distributions` - `histograms` ## `/data/plugin/scalars/tags` @@ -160,7 +159,21 @@ Annotated Example: (note - real data is higher precision) ] ] -## '/data/compressedHistograms?run=foo&tag=bar' +## `/data/plugin/distributions/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all +distribution tags present in the corresponding run. Here is an example: + + { + "train_run": ["foo_histogram", "bar_histogram"], + "eval": ["foo_histogram", "bar_histogram"] + } + +Note that runs without any distribution tags are included as keys with +value the empty array. + +## `/data/plugin/distributions/distributions?run=foo&tag=bar` Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, step, CompressedHistogramValues]) for the given run and tag. 
@@ -180,8 +193,8 @@ Annotated Example: (note - real data is higher precision) [ 1441154832.580509, # wall_time 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile + [ [0, -3.67], # CompressedHistogramValue for 0th percentile + [2500, -4.19], # CompressedHistogramValue for 25th percentile [5000, 6.29], [7500, 1.64], [10000, 3.67] diff --git a/tensorflow/tensorboard/plugins/distributions/BUILD b/tensorflow/tensorboard/plugins/distributions/BUILD new file mode 100644 index 00000000000..de1f73143c6 --- /dev/null +++ b/tensorflow/tensorboard/plugins/distributions/BUILD @@ -0,0 +1,50 @@ +# Description: +# TensorBoard plugin for distributions + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +## Distributions Plugin ## +py_library( + name = "distributions_plugin", + srcs = ["distributions_plugin.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/tensorboard/backend:http_util", + "//tensorflow/tensorboard/backend/event_processing:event_accumulator", + "//tensorflow/tensorboard/plugins:base_plugin", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +py_test( + name = "distributions_plugin_test", + size = "small", + srcs = ["distributions_plugin_test.py"], + main = "distributions_plugin_test.py", + srcs_version = "PY2AND3", + deps = [ + ":distributions_plugin", + "//tensorflow:tensorflow_py", + "//tensorflow/tensorboard/backend:application", + "//tensorflow/tensorboard/backend/event_processing:event_accumulator", + "//tensorflow/tensorboard/backend/event_processing:event_multiplexer", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + visibility = ["//tensorflow:__pkg__"], +) diff --git 
a/tensorflow/tensorboard/plugins/distributions/distributions_plugin.py b/tensorflow/tensorboard/plugins/distributions/distributions_plugin.py new file mode 100644 index 00000000000..4bb9dfaf545 --- /dev/null +++ b/tensorflow/tensorboard/plugins/distributions/distributions_plugin.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The TensorBoard Distributions (a.k.a. 
compressed histograms) plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from werkzeug import wrappers + +from tensorflow.tensorboard.backend import http_util +from tensorflow.tensorboard.backend.event_processing import event_accumulator +from tensorflow.tensorboard.plugins import base_plugin + +_PLUGIN_PREFIX_ROUTE = event_accumulator.COMPRESSED_HISTOGRAMS + + +class DistributionsPlugin(base_plugin.TBPlugin): + """Distributions Plugin for TensorBoard.""" + + plugin_name = _PLUGIN_PREFIX_ROUTE + + def get_plugin_apps(self, multiplexer, unused_logdir): + self._multiplexer = multiplexer + return { + '/distributions': self.distributions_route, + '/tags': self.tags_route, + } + + def is_active(self): + """This plugin is active iff any run has at least one relevant tag.""" + return any(self.index_impl().values()) + + def index_impl(self): + return { + run_name: run_data[event_accumulator.COMPRESSED_HISTOGRAMS] + for (run_name, run_data) in self._multiplexer.Runs().items() + if event_accumulator.COMPRESSED_HISTOGRAMS in run_data + } + + def distributions_impl(self, tag, run): + """Result of the form `(body, mime_type)`.""" + values = self._multiplexer.CompressedHistograms(run, tag) + return (values, 'application/json') + + @wrappers.Request.application + def tags_route(self, request): + index = self.index_impl() + return http_util.Respond(request, index, 'application/json') + + @wrappers.Request.application + def distributions_route(self, request): + """Given a tag and single run, return array of compressed histograms.""" + tag = request.args.get('tag') + run = request.args.get('run') + (body, mime_type) = self.distributions_impl(tag, run) + return http_util.Respond(request, body, mime_type) diff --git a/tensorflow/tensorboard/plugins/distributions/distributions_plugin_test.py b/tensorflow/tensorboard/plugins/distributions/distributions_plugin_test.py new file mode 100644 index 
00000000000..b5aae6dea79 --- /dev/null +++ b/tensorflow/tensorboard/plugins/distributions/distributions_plugin_test.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration tests for the Distributions Plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.tensorboard.backend.event_processing import event_accumulator +from tensorflow.tensorboard.backend.event_processing import event_multiplexer +from tensorflow.tensorboard.plugins.distributions import distributions_plugin + + +class DistributionsPluginTest(tf.test.TestCase): + + _STEPS = 99 + + _DISTRIBUTION_TAG = 'my-favorite-distribution' + _SCALAR_TAG = 'my-boring-scalars' + + _RUN_WITH_DISTRIBUTION = '_RUN_WITH_DISTRIBUTION' + _RUN_WITH_SCALARS = '_RUN_WITH_SCALARS' + + def set_up_with_runs(self, run_names): + self.logdir = self.get_temp_dir() + for run_name in run_names: + self.generate_run(run_name) + multiplexer = event_multiplexer.EventMultiplexer(size_guidance={ + # don't truncate my test data, please + event_accumulator.COMPRESSED_HISTOGRAMS: + self._STEPS, + }) + multiplexer.AddRunsFromDirectory(self.logdir) + 
multiplexer.Reload() + self.plugin = distributions_plugin.DistributionsPlugin() + self.apps = self.plugin.get_plugin_apps(multiplexer, None) + + def generate_run(self, run_name): + if run_name == self._RUN_WITH_DISTRIBUTION: + (use_distributions, use_scalars) = (True, False) + elif run_name == self._RUN_WITH_SCALARS: + (use_distributions, use_scalars) = (False, True) + else: + assert False, 'Invalid run name: %r' % run_name + tf.reset_default_graph() + sess = tf.Session() + placeholder = tf.placeholder(tf.float32, shape=[3]) + if use_distributions: + tf.summary.histogram(self._DISTRIBUTION_TAG, placeholder) + if use_scalars: + tf.summary.scalar(self._SCALAR_TAG, tf.reduce_mean(placeholder)) + summ = tf.summary.merge_all() + + subdir = os.path.join(self.logdir, run_name) + writer = tf.summary.FileWriter(subdir) + writer.add_graph(sess.graph) + for step in xrange(self._STEPS): + feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]} + s = sess.run(summ, feed_dict=feed_dict) + writer.add_summary(s, global_step=step) + writer.close() + + def test_index(self): + self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION, + self._RUN_WITH_SCALARS]) + self.assertEqual({ + self._RUN_WITH_DISTRIBUTION: [self._DISTRIBUTION_TAG], + self._RUN_WITH_SCALARS: [], + }, self.plugin.index_impl()) + + def _test_distributions_json(self, run_name, should_have_distributions): + self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION, + self._RUN_WITH_SCALARS]) + if should_have_distributions: + (data, mime_type) = self.plugin.distributions_impl( + self._DISTRIBUTION_TAG, run_name) + self.assertEqual('application/json', mime_type) + self.assertEqual(len(data), self._STEPS) + for i in xrange(self._STEPS): + self.assertEqual(i, data[i].step) + else: + with self.assertRaises(KeyError): + self.plugin.distributions_impl( + self._DISTRIBUTION_TAG, run_name) + + def test_distributions_json_with_scalars(self): + self._test_distributions_json(self._RUN_WITH_DISTRIBUTION, True) + + def 
test_distributions_json_with_histogram(self): + self._test_distributions_json(self._RUN_WITH_SCALARS, False) + + def test_active_with_distribution(self): + self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION]) + self.assertTrue(self.plugin.is_active()) + + def test_active_with_scalars(self): + self.set_up_with_runs([self._RUN_WITH_SCALARS]) + self.assertFalse(self.plugin.is_active()) + + def test_active_with_both(self): + self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION, + self._RUN_WITH_SCALARS]) + self.assertTrue(self.plugin.is_active()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py index bce5dd259dd..70830b9a8c8 100644 --- a/tensorflow/tensorboard/tensorboard.py +++ b/tensorflow/tensorboard/tensorboard.py @@ -33,6 +33,7 @@ from werkzeug import serving from tensorflow.tensorboard.backend import application from tensorflow.tensorboard.backend.event_processing import event_file_inspector as efi from tensorflow.tensorboard.plugins.audio import audio_plugin +from tensorflow.tensorboard.plugins.distributions import distributions_plugin from tensorflow.tensorboard.plugins.histograms import histograms_plugin from tensorflow.tensorboard.plugins.images import images_plugin from tensorflow.tensorboard.plugins.projector import projector_plugin @@ -204,10 +205,11 @@ def main(unused_argv=None): return 0 else: plugins = [ - audio_plugin.AudioPlugin(), - histograms_plugin.HistogramsPlugin(), - images_plugin.ImagesPlugin(), scalars_plugin.ScalarsPlugin(), + images_plugin.ImagesPlugin(), + audio_plugin.AudioPlugin(), + distributions_plugin.DistributionsPlugin(), + histograms_plugin.HistogramsPlugin(), projector_plugin.ProjectorPlugin(), text_plugin.TextPlugin(), ] From 69075f3546dfc29dbef8b7c5d990f3af094cbd5f Mon Sep 17 00:00:00 2001 From: Yangzihao Wang Date: Thu, 1 Jun 2017 17:50:43 -0700 Subject: [PATCH 38/72] Add functional support for cudnnConvolutionBiasActivationForward(). 
PiperOrigin-RevId: 157788425 --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 135 ++++++++++++++-- tensorflow/stream_executor/cuda/cuda_dnn.h | 42 ++++- tensorflow/stream_executor/dnn.h | 70 ++++++++- tensorflow/stream_executor/stream.cc | 166 ++++++++++++++++++-- tensorflow/stream_executor/stream.h | 56 +++++++ 5 files changed, 434 insertions(+), 35 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index ec6919f9784..cd8994f73a0 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -239,6 +239,17 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_R5 #endif +// APIs in R6 +// clang-format off +#if CUDNN_VERSION >= 6000 +#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \ + __macro(cudnnConvolutionBiasActivationForward) + +// clang-format on +CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_R6 +#endif + #undef CUDNN_DNN_ROUTINE_EACH } // namespace wrap @@ -1791,6 +1802,7 @@ bool CudnnSupport::DoConvolveImpl( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, @@ -1917,6 +1929,26 @@ bool CudnnSupport::DoConvolveImpl( } } + const bool has_biases = (biases != nullptr); + const bool supported_activation_mode = + (activation_mode == dnn::ActivationMode::kRelu6 || + activation_mode == dnn::ActivationMode::kReluX || + activation_mode == dnn::ActivationMode::kRelu); + + if (has_biases && !supported_activation_mode) { + LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only " + "support relu activation."; + return false; + } + + if (has_biases && activation_mode != 
dnn::ActivationMode::kNone) { + LOG(ERROR) << "To use cudnnConvolutionBiasActivationForward() " + "with a valid biases tensor, need to also provide " + "a valid activation mode (currently only supports " + "kRelu6, kReluX, and kRelu)."; + return false; + } + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); @@ -1931,14 +1963,45 @@ bool CudnnSupport::DoConvolveImpl( return false; } } - status = wrap::cudnnConvolutionForward( - parent_, ToHandle(dnn_handle_), - /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(), - /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), - /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, /*workSpace=*/scratch.opaque(), - /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + if (has_biases) { + CHECK(supported_activation_mode); +#if CUDNN_VERSION < 6000 + LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only " + "supported for cuDNN version >= 6."; + return false; +#else + BatchDescriptor bias_dimensions; + bias_dimensions.set_count(1) + .set_feature_map_count(output_descriptor.feature_map_count()) + .set_height(1) + .set_width(1) + .set_layout(dnn::DataLayout::kBatchYXDepth); + ScopedTensorDescriptor bias_descriptor{ + parent_, bias_dimensions, static_cast(cudnn_type)}; + ScopedActivationDescriptor activation_desc{parent_, activation_mode, + output_descriptor.value_max()}; + status = wrap::cudnnConvolutionBiasActivationForward( + parent_, ToHandle(dnn_handle_), + /*alpha1=*/&alpha, /*srcDesc=*/input_nd.handle(), + /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), + /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&beta, + /*zDesc=*/output_nd.handle(), /*z=*/nullptr, + /*biasDesc=*/bias_descriptor.handle(), + /*bias=*/biases.opaque(), 
/*activationDesc=*/activation_desc.handle(), + /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); +#endif // CUDNN_VERSION < 6000 + } else { + status = wrap::cudnnConvolutionForward( + parent_, ToHandle(dnn_handle_), + /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(), + /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), + /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, + /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + } if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { timer->Destroy(); @@ -2211,16 +2274,48 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, + filter_data, convolution_descriptor, biases, activation_mode, + output_descriptor, output_data, scratch_allocator, algorithm_config, + output_profile_result); +} + +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& output_descriptor, DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveImpl( + stream, CUDNN_DATA_FLOAT, 
batch_descriptor, input_data, filter_descriptor, + filter_data, convolution_descriptor, /*biases=*/nullptr, + dnn::ActivationMode::kNone, output_descriptor, output_data, scratch_allocator, algorithm_config, output_profile_result); } +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const BatchDescriptor& output_descriptor, + DeviceMemory* output_data) { + LOG(ERROR) << "double-based DNN not yet implemented"; + return false; +} + bool CudnnSupport::DoConvolve( Stream* stream, const BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, @@ -2239,13 +2334,33 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, + filter_data, convolution_descriptor, biases, activation_mode, + output_descriptor, output_data, scratch_allocator, algorithm_config, + output_profile_result); +} + +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const 
dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveImpl( + stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor, + filter_data, convolution_descriptor, /*biases=*/nullptr, + dnn::ActivationMode::kNone, output_descriptor, output_data, scratch_allocator, algorithm_config, output_profile_result); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 2c8ed9a3353..7824885e1b3 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -137,7 +137,43 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) override; - bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor, + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const 
DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -156,7 +192,7 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data) override; - bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor, + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -477,6 +513,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 8e56933ba38..e8b5bbf5b1a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -796,6 +796,7 @@ class NormalizeDescriptor { // Describes a kind of non-linearity (threshold-like mathematical function). enum class ActivationMode { + kNone, kSigmoid, // Rectified linear activation: f(x) = x < 0 ? 0 : x kRelu, @@ -910,9 +911,11 @@ class DnnSupport { // input_data: un-owned device memory region which contains the // convolution input. // filter_descriptor: dimensions of the convolution filter. 
- // weights: coefficients for the convolution filter, these are multiplied - // against values in the input that the filter convolves over. // convolution_descriptor: stride of the convolution filter. + // biases: un-owned device memory region containing biases to add to the + // input. This can be DeviceMemory pointing to NULL only when activation_mode + // is kNone. + // activation_mode: Type of activation to perform. // output_descriptor: dimensions of the output layer. // output_data: un-owned device memory region in which to place the // convolution result. @@ -939,6 +942,55 @@ class DnnSupport { // that if the inverse of the filter is applied to the output in VALID mode // the result is the same size as the input - this requires even more // padding of the input. + virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return false; + } + + // Enqueues a double-precision fused convolution, bias add, and activation + // operation onto the stream. See DoConvolve above for argument details. 
+ virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data) { + return false; + } + + // Enqueues a half-precision fused convolution, bias add, and activation + // operation onto the stream. See DoConvolve above for argument details. + virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return false; + } + + // Enqueues a single-precision convolution operation (without bias add + // or activation) onto the stream. + // See DoConvolve above for argument details. virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, @@ -950,11 +1002,8 @@ class DnnSupport { const dnn::AlgorithmConfig& algorithm_config, ProfileResult* output_profile_result) = 0; - // Return a list of algorithms supported by the forward convolution pass. - virtual bool GetConvolveAlgorithms( - bool with_winograd_nonfused, std::vector* out_algorithms); - - // Enqueues a double-precision convolution operation onto the stream. + // Enqueues a double-precision convolution operation (without bias add + // or activation) onto the stream. // See DoConvolve above for argument details. 
virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& batch_descriptor, @@ -965,7 +1014,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data) = 0; - // Enqueues a half-precision convolution operation onto the stream. + // Enqueues a half-precision convolution operation (without bias add + // or activation) onto the stream. // See DoConvolve above for argument details. virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& batch_descriptor, @@ -979,6 +1029,10 @@ class DnnSupport { const dnn::AlgorithmConfig& algorithm_config, ProfileResult* output_profile_result) = 0; + // Return a list of algorithms supported by the forward convolution pass. + virtual bool GetConvolveAlgorithms( + bool with_winograd_nonfused, std::vector* out_algorithms); + // Version of DoConvolve that uses pre-quantized 8 bit coefficients. // coefficient_scales specifies the scaling of each column of coefficients: // original float coefficient[row * num_columns + column] = diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index a393b077034..bb586c58485 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -350,9 +350,65 @@ Stream &Stream::ThenConvolveWithScratch( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory &filter_data, const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output, + DeviceMemory *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, 
filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(filter_descriptor), PARAM(filter_data), PARAM(convolution_descriptor), PARAM(output_descriptor), @@ -362,9 +418,9 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 
CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -389,9 +445,74 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, 
algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } } else { SetErrorAndLogNoDnnSupport(); } @@ -461,6 +582,21 @@ Stream &Stream::ThenConvolveWithAlgorithm( return *this; } +Stream &Stream::ThenConvolve( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output) { + return ThenConvolveWithScratch( + input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, 
+ output, /*scratch_allocator=*/nullptr); +} + Stream &Stream::ThenConvolve( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -582,7 +718,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -676,7 +812,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -718,7 +854,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -779,7 +915,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -3868,7 +4004,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3900,7 +4036,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, 
a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3934,7 +4070,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3973,7 +4109,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 5b46b86f54a..bc1d05cc08c 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -240,6 +240,16 @@ class Stream { DeviceMemory *offset_backprop); // TODO(leary) add double-precision version of this interface. 
+ Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output); + Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, const dnn::FilterDescriptor &filter_descriptor, @@ -268,6 +278,27 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output_data); + Stream &ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator); + + Stream &ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveWithScratch( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -286,6 +317,31 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + 
const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + + Stream &ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + Stream &ThenConvolveWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, From 9ae941c4a8c2d6e5a87c7a200ebde5bd0b07e5b2 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 1 Jun 2017 18:40:21 -0700 Subject: [PATCH 39/72] Turn reductions along an empty set of dimensions into identity nodes. 
PiperOrigin-RevId: 157792209 --- tensorflow/core/grappler/op_types.cc | 14 ++-- tensorflow/core/grappler/op_types.h | 1 + .../grappler/optimizers/constant_folding.cc | 68 +++++++++++++++++-- .../grappler/optimizers/constant_folding.h | 5 +- .../optimizers/constant_folding_test.cc | 41 +++++++++++ tensorflow/core/grappler/utils.cc | 4 ++ tensorflow/core/grappler/utils.h | 4 ++ 7 files changed, 127 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 7a239aeffec..ebe380070de 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -29,10 +29,10 @@ bool IsConstant(const NodeDef& node) { } bool IsDequeueOp(const NodeDef& node) { - static const std::set dequeue_ops = { - "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", - "QueueDequeue", "QueueDequeueUpToV2", "QueueDequeueUpTo"}; - return dequeue_ops.count(node.op()) > 0; + const auto& op = node.op(); + return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" || + op == "QueueDequeueV2" || op == "QueueDequeue" || + op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo"; } bool IsMerge(const NodeDef& node) { @@ -46,6 +46,12 @@ bool IsPlaceholder(const NodeDef& node) { op == "PlaceholderWithDefault"; } +bool IsReduction(const NodeDef& node) { + const auto& op = node.op(); + return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || + op == "Mean" || op == "Any" || op == "All"; +} + bool IsTranspose(const NodeDef& node) { const auto op = node.op(); return op == "Transpose"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 04bb78149f7..d32487c1286 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -26,6 +26,7 @@ bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsMerge(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsReduction(const NodeDef& node); 
bool IsTranspose(const NodeDef& node); bool IsVariable(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index c9169d63f4b..ea5bfe164b3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -232,7 +232,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { } for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { continue; } bool is_const = IsConstant(*node_map_->GetNode(input)); @@ -267,7 +267,7 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, Status ConstantFolding::EvaluateNode(const NodeDef& node, const TensorVector& inputs, - TensorVector* output) { + TensorVector* output) const { Status status; auto op_kernel = CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node, @@ -299,7 +299,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, std::vector* outputs) { TensorVector inputs; for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { break; } TensorVector output; @@ -337,12 +337,12 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) { node_map_->AddNode(added_node->name(), added_node); for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { *added_node->add_input() = input; } else { NodeDef* input_node = node_map_->GetNode(input); for (const auto& fanin_of_input : input_node->input()) { - if (fanin_of_input[0] == '^') { + if (IsControlInput(fanin_of_input)) { *added_node->add_input() = fanin_of_input; } } @@ -396,6 +396,60 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { return Status::OK(); } +// Returns true iff this reduction can be reduced to an identity (i.e if the set +// of dimensions to reduce along is empty). This happens often in the gradient +// graphs. 
+bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const { + if (IsReduction(node)) { + CHECK_LE(2, node.input_size()); + const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); + if (IsConstant(*reductions_indices)) { + TensorVector output; + Status s = EvaluateNode(*reductions_indices, TensorVector(), &output); + if (!s.ok()) { + return false; + } + CHECK_EQ(1, output.size()); + int output_size = output[0]->NumElements(); + delete output[0].tensor; + if (output_size == 0) { + return true; + } + } + } + return false; +} + +Status ConstantFolding::SimplifyGraph(GraphDef* output) { + for (auto& node : *output->mutable_node()) { + if (IsSimplifiableReduction(node)) { + // Replace the reduction node with an identity node, that can be further + // optimized by the model pruner. + const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); + DataType output_type; + if (node.attr().count("T") > 0) { + output_type = node.attr().at("T").type(); + } else { + // This is an 'any' or 'all' reduction. The output is always boolean. 
+ output_type = DT_BOOL; + } + node.set_op("Identity"); + node.clear_attr(); + (*node.mutable_attr())["T"].set_type(output_type); + if (node.input_size() > 2) { + node.mutable_input()->SwapElements(1, node.input_size() - 1); + } + node.mutable_input()->RemoveLast(); + for (const auto& input : reductions_indices->input()) { + if (IsControlInput(input)) { + *node.add_input() = input; + } + } + } + } + return Status::OK(); +} + Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { graph_ = item.graph; @@ -404,10 +458,14 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, for (const auto& node : item.fetch) { nodes_to_preserve_.insert(NodeName(node)); } + for (const auto& node : item.feed) { + nodes_to_preserve_.insert(NodeName(node.first)); + } device_.reset(new DeviceSimple()); *output = GraphDef(); TF_RETURN_IF_ERROR(MaterializeShapes(item)); TF_RETURN_IF_ERROR(FoldGraph(output)); + TF_RETURN_IF_ERROR(SimplifyGraph(output)); LOG(INFO) << "Optimized graph size: " << output->node_size(); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index fd77fc945e3..9689e97a123 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -50,7 +50,7 @@ class ConstantFolding : public GraphOptimizer { Status EvaluateNode(const NodeDef& node, const gtl::InlinedVector& inputs, - gtl::InlinedVector* output); + gtl::InlinedVector* output) const; Status EvaluateOneFoldable(const NodeDef& node, std::vector* outputs); @@ -59,6 +59,9 @@ class ConstantFolding : public GraphOptimizer { Status FoldGraph(GraphDef* output); + bool IsSimplifiableReduction(const NodeDef& node) const; + Status SimplifyGraph(GraphDef* output); + std::unique_ptr device_; GraphDef graph_; std::unique_ptr node_map_; diff --git 
a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 58bbb817d0b..87e42c72e24 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -245,6 +245,47 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) { EXPECT_EQ(3, found); } +TEST_F(ConstantFoldingTest, NoOpReduction) { + // Build a simple graph with a reduction that can be reduced to the identity. + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output d = ops::Const(scope.WithOpName("d"), 3.14f, {3, 5, 7}); + Output v = ops::PlaceholderWithDefault(scope.WithOpName("v"), d, {3, 5, 7}); + Output c = ops::Const(scope.WithOpName("c"), 0, {0}); + Output i = ops::Identity(scope.WithOpName("i"), c); + Output p = ops::Prod(scope.WithOpName("p"), v, i); + Output s = ops::Square(scope.WithOpName("s"), p); + + GrapplerItem item; + item.fetch.push_back("s"); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + ASSERT_EQ("c", item.graph.node(2).name()); + (*item.graph.mutable_node(2)->add_input()) = "^v"; + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + auto expected = EvaluateNodes(item.graph, {"s"}); + auto optimized = EvaluateNodes(output, {"s"}); + EXPECT_EQ(1, expected.size()); + EXPECT_EQ(1, optimized.size()); + test::ExpectTensorEqual(expected[0], optimized[0]); + + bool found = false; + for (const auto& node : output.node()) { + if (node.name() == "p") { + found = true; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("v", node.input(0)); + EXPECT_EQ("^v", node.input(1)); + } + } + EXPECT_TRUE(found); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 06ef61a9613..b7a04f4423d 100644 --- 
a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -83,6 +83,10 @@ string ParseNodeName(const string& name, int* position) { } } +bool IsControlInput(const string& name) { + return !name.empty() && name[0] == '^'; +} + string NodeName(const string& name) { int position; return ParseNodeName(name, &position); diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 17b980c5b8c..fd2fb60c9b9 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -46,6 +46,10 @@ class NodeMap { std::unordered_map> outputs_; }; +// True iff 'name' refers to a control inputs, i.e. a node name prefixed with +// the ^ character. +bool IsControlInput(const string& name); + // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. string NodeName(const string& name); From 0503ce09c74b2249a38bd4a6254b0b5964836ec3 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 1 Jun 2017 18:49:33 -0700 Subject: [PATCH 40/72] Wipe out previous shape inference result when importing a grappler item Run graph optimizations last: since they can be expensive it's best to filter invalid items first. PiperOrigin-RevId: 157792834 --- .../core/grappler/grappler_item_builder.cc | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 8f7333f1dbf..1ad3cbb4cb9 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -134,14 +134,6 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( new_item->id = id; new_item->graph = meta_graph.graph_def(); - // Optimize the graph (function inlining, l1 optimizations, etc). 
- Status optimize_status = - OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg); - if (!optimize_status.ok()) { - LOG(ERROR) << "Function optimization failed: " << optimize_status; - return nullptr; - } - // Attempt to detect the fetch node(s). if (meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); @@ -250,6 +242,10 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto; } + // Erase the recorded result of any previous shape inference to start again + // from scratch. + node.mutable_attr()->erase("_output_shapes"); + // Delete user specified placement if requested. if (cfg.ignore_user_placement) { node.clear_device(); @@ -329,6 +325,14 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } } + // Optimize the graph (function inlining, l1 optimizations, etc). + Status optimize_status = + OptimizeGraph(new_item->graph, &new_item->graph, cfg); + if (!optimize_status.ok()) { + LOG(ERROR) << "Function optimization failed: " << optimize_status; + return nullptr; + } + return new_item; } From b659bc39f27e81b3249f73710671059589c5daa1 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 1 Jun 2017 19:24:34 -0700 Subject: [PATCH 41/72] Simplify TensorBoard build - Remove tensorboard_typescript_genrule - Remove tensorboard_typescript_bundle - Introduce ts_web_library Skylark rule which supports seamless TypeScript compilation. - Use Closure Compiler in semi-advanced mode to compile JavaScript. 
This is done in a way that preserves diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD index 1e599cb710f..18009043d23 100644 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_audio_dashboard", srcs = [ "tf-audio-dashboard.html", @@ -17,7 +17,7 @@ web_library( "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_icon_button", "@org_polymer_paper_slider", "@org_polymer_paper_spinner", @@ -25,7 +25,7 @@ web_library( ], ) -web_library( +ts_web_library( name = "index", srcs = [ "demo/index.html", @@ -35,11 +35,11 @@ web_library( deps = [ ":tf_audio_dashboard", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/demo:demo_data", "@org_polymer_iron_component_page", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html index 177bc85db0d..dc8cd91d439 100644 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html @@ -42,14 +42,16 @@ limitations under the License. 
+ + + diff --git a/tensorflow/tensorboard/components/tf_backend/tf-backend.html b/tensorflow/tensorboard/components/tf_backend/tf-backend.html index 5bf26633628..4cfed247a5e 100644 --- a/tensorflow/tensorboard/components/tf_backend/tf-backend.html +++ b/tensorflow/tensorboard/components/tf_backend/tf-backend.html @@ -20,4 +20,8 @@ limitations under the License. - + + + + + diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD index 3ec3d26051f..afe98ec5b5e 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/BUILD +++ b/tensorflow/tensorboard/components/tf_color_scale/BUILD @@ -1,57 +1,37 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_color_scale", srcs = [ - "bundle.js", + "colorScale.ts", + "palettes.ts", "tf-color-scale.html", ], path = "/tf-color-scale", deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-color-scale", deps = [ ":tf_color_scale", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_button", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - ], -) - 
-tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF": [ - "palettes.ts", - "colorScale.ts", - ]}, -) - filegroup( name = "all_files", srcs = glob(["**"]), diff --git a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts index ff90d46aa24..6916e3bb2dd 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts +++ b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts @@ -19,9 +19,8 @@ limitations under the License. // ccs.domain(runs); // ccs.getColor("train"); // ccs.getColor("test1"); -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import {palettes} from './palettes' +import {palettes} from './palettes' export class ColorScale { private palette: string[]; @@ -29,8 +28,8 @@ export class ColorScale { /** * Creates a color scale with optional custom palette. - * @param {string[]} [palette=palettes.googleColorBlind] - The color - * palette you want as an Array of hex strings. + * @param {Array} [palette=palettes.googleColorBlind] - The color + * palette you want as an Array of hex strings. */ constructor(palette: string[] = palettes.googleColorBlindAssist) { this.palette = palette; @@ -38,8 +37,8 @@ export class ColorScale { /** * Set the domain of strings. - * @param {string[]} strings - An array of possible strings to use as the - * domain for your scale. + * @param {Array} strings - An array of possible strings to use as the + * domain for your scale. 
*/ public domain(strings: string[]): this { this.identifiers = d3.map(); diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD index 6071b20886e..dab2779dc3c 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD +++ b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD @@ -3,43 +3,25 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "colorScaleTests.ts", "tests.html", ], path = "/tf-color-scale/test", deps = [ "//tensorflow/tensorboard/components/tf_color_scale", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "//tensorflow/tensorboard/components/tf_color_scale:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF": ["colorScaleTests.ts"]}, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html index eccc32cdec5..59c802d02bf 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html +++ 
b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html @@ -21,4 +21,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html index 3dedfaf1a1c..a325f0a04cd 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html @@ -26,5 +26,6 @@ a set of colors. @element tf-color-scale --> - + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD index f9a990e3799..b504fe79f99 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD +++ b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD @@ -1,33 +1,33 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_dashboard_common", srcs = [ + "dashboard-behavior.ts", "dashboard-style.html", + "reload-behavior.ts", "run-color-style.html", "scrollbar-style.html", "tensorboard-color.html", "tf-categorizer.html", - "tf-categorizer-bundle.js", + "tf-categorizer.ts", "tf-chart-scaffold.html", "tf-collapsable-pane.html", "tf-dashboard.html", - "tf-dashboard.js", "tf-dashboard-layout.html", "tf-downloader.html", "tf-multi-checkbox.html", - "tf-multi-checkbox-bundle.js", + "tf-multi-checkbox.ts", "tf-no-data-warning.html", "tf-option-selector.html", "tf-panes-helper.html", "tf-regex-group.html", - "tf-regex-group-bundle.js", + "tf-regex-group.ts", 
"tf-run-selector.html", "tf-sidebar-helper.html", ], @@ -35,9 +35,9 @@ web_library( deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_storage", "//tensorflow/tensorboard/components/vz_sorting", - "@org_polymer", "@org_polymer_iron_ajax", "@org_polymer_iron_collapse", "@org_polymer_iron_icons", @@ -56,7 +56,7 @@ web_library( ], ) -web_library( +ts_web_library( name = "demo", srcs = [ "tf-categorizer-demo.html", @@ -73,91 +73,9 @@ web_library( ], ) -tensorboard_typescript_bundle( - name = "tf_categorizer_bundle", - out = "tf-categorizer-bundle.ts", - namespace_srcs = {"TF.Dashboard.Categorizer": ["tf-categorizer.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.Categorizer": {"compareTagNames": "VZ.Sorting.compareTagNames"}}, -) - -tensorboard_typescript_genrule( - name = "tf_categorizer_ts", - srcs = ["tf-categorizer-bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/vz_sorting:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_regex_group_bundle", - out = "tf-regex-group-bundle.ts", - namespace_srcs = {"TF.Dashboard.RegexGroup": ["tf-regex-group.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.RegexGroup": {"storage": "TF.URIStorage"}}, -) - -tensorboard_typescript_genrule( - name = "tf_regex_group_ts", - srcs = ["tf-regex-group-bundle.ts"], - typings = [ - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_storage:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_multi_checkbox_bundle", - out = "tf-multi-checkbox-bundle.ts", - namespace_srcs = {"TF.Dashboard.MultiCheckbox": 
["tf-multi-checkbox.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.MultiCheckbox": {"storage": "TF.URIStorage"}}, -) - -tensorboard_typescript_genrule( - name = "tf_multi_checkbox_ts", - srcs = ["tf-multi-checkbox-bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_storage:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_dashboard_bundle", - out = "tf-dashboard.ts", - namespace_srcs = { - "TF.Dashboard": [ - "dashboard-behavior.ts", - "reload-behavior.ts", - ], - }, -) - -tensorboard_typescript_genrule( - name = "tf_dashboard_ts", - srcs = ["tf-dashboard.ts"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - tensorboard_webcomponent_library( name = "legacy", - srcs = glob(["*.html"]) + [":legacy_ts"], + srcs = [":tf_dashboard_common"], destdir = "tf-dashboard-common", deps = [ "//tensorflow/tensorboard/components/tf_imports_google:lib", @@ -182,19 +100,8 @@ tensorboard_webcomponent_library( ], ) -tensorboard_ts_library( - name = "legacy_ts", - srcs = [ - "dashboard-behavior.ts", - "reload-behavior.ts", - "tf-categorizer.ts", - ], - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//tensorflow/tensorboard/components/vz_sorting:legacy_ts", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/polymer:polymer_without_externs", - ], +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], ) diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts 
b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts index 3e40da14528..aa063c74220 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts @@ -16,6 +16,8 @@ limitations under the License. /** * A behavior that TensorBoard dashboards must implement. This behavior serves * the purpose of an interface. + * + * @polymerBehavior */ export function DashboardBehavior(dashboardName) { return { diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts index 8b5ca120d60..61fe0c07812 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts @@ -20,6 +20,8 @@ limitations under the License. * and call a `reload` method on that child. * May later extend it so it has more sophisticated logic, e.g. reloading * only tags that are in view. 
+ * + * @polymerBehavior */ export function ReloadBehavior(tagName) { return { diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD index e82c4bd63cd..3cad646b967 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD +++ b/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD @@ -3,44 +3,25 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", "tests.html", + "tf-categorizer-tests.ts", ], path = "/tf-dashboard-common/test", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "//tensorflow/tensorboard/components/tf_dashboard_common:tf-categorizer-bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF.Dashboard": ["tf-categorizer-tests.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard": {"cat": "TF.Dashboard.Categorizer"}}, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html 
b/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html index cd33cee4742..c9ad14730f0 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html @@ -21,4 +21,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html index 6388ab5e7d4..f09eb03582d 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html @@ -59,5 +59,5 @@ categories are exclusive. } - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts index ebece842461..0eaf852ff13 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash'; - -import {compareTagNames} from '../vz_sorting/sorting'; +import {compareTagNames} from '../vz-sorting/sorting'; /** * This module contains methods that allow sorting tags into 'categories'. diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html index 475c2cef3bd..9e2f6b9589b 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html @@ -22,4 +22,5 @@ limitations under the License. 
- + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html index 8a56616f820..fad4642963f 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html @@ -156,5 +156,5 @@ handle these situations gracefully. } - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts index 44a14a21cfe..4b38d82b14e 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as _ from 'lodash'; import * as storage from '../tf-storage/storage'; Polymer({ diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html index e68b306ee33..c1d3cf06aea 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html @@ -95,5 +95,5 @@ more regexes). 
- + diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD index dcd5047bf49..fe089b80b42 100644 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_distribution_dashboard", srcs = ["tf-distribution-dashboard.html"], path = "/tf-distribution-dashboard", @@ -13,24 +13,24 @@ web_library( "//tensorflow/tensorboard/components/tf_color_scale", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/vz_distribution_chart", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_paper_icon_button", "@org_polymer_paper_styles", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-distribution-dashboard", deps = [ ":tf_distribution_dashboard", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html b/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html index 5e825f13f5c..2c300446480 100644 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html +++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html @@ -43,14 +43,17 @@ limitations under the License. 
diff --git a/tensorflow/tensorboard/components/tf_globals/BUILD b/tensorflow/tensorboard/components/tf_globals/BUILD index ca59c2fb93a..0ffefd79682 100644 --- a/tensorflow/tensorboard/components/tf_globals/BUILD +++ b/tensorflow/tensorboard/components/tf_globals/BUILD @@ -1,29 +1,23 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_globals", srcs = [ - "bundle.js", + "globals.ts", "tf-globals.html", ], path = "/tf-globals", ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF.Globals": ["globals.ts"]}, +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_globals"], + destdir = "tf-globals", ) filegroup( @@ -31,25 +25,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-globals.html", - ":legacy_ts", - ], - destdir = "tf-globals", -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = ["globals.ts"], - deps_mgmt = "off", - runtime = "nodejs", -) diff --git a/tensorflow/tensorboard/components/tf_globals/globals.ts b/tensorflow/tensorboard/components/tf_globals/globals.ts index 7d4229dccb0..fb6bb83b97f 100644 --- 
a/tensorflow/tensorboard/components/tf_globals/globals.ts +++ b/tensorflow/tensorboard/components/tf_globals/globals.ts @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - - // The names of TensorBoard tabs. export const TABS = [ 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', diff --git a/tensorflow/tensorboard/components/tf_globals/tf-globals.html b/tensorflow/tensorboard/components/tf_globals/tf-globals.html index b0fd74d4f20..efb8e92e080 100644 --- a/tensorflow/tensorboard/components/tf_globals/tf-globals.html +++ b/tensorflow/tensorboard/components/tf_globals/tf-globals.html @@ -15,5 +15,5 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/components/tf_graph/BUILD b/tensorflow/tensorboard/components/tf_graph/BUILD index 115964a59bd..92d2e8a42a1 100644 --- a/tensorflow/tensorboard/components/tf_graph/BUILD +++ b/tensorflow/tensorboard/components/tf_graph/BUILD @@ -1,10 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph", srcs = [ "tf-graph.html", @@ -15,7 +16,7 @@ web_library( deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_graph_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_iron_flex_layout", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -28,26 +29,28 @@ web_library( ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph"], + destdir = "tf-graph", + 
deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//third_party/javascript/polymer/v1/iron-flex-layout:lib", + "//third_party/javascript/polymer/v1/iron-icons:lib", + "//third_party/javascript/polymer/v1/paper-button:lib", + "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", + "//third_party/javascript/polymer/v1/paper-input:lib", + "//third_party/javascript/polymer/v1/paper-menu:lib", + "//third_party/javascript/polymer/v1/paper-radio-group:lib", + "//third_party/javascript/polymer/v1/paper-toggle-button:lib", + "//third_party/javascript/polymer/v1/paper-tooltip:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph.html", - "tf-graph-minimap.html", - "tf-graph-scene.html", - ], - destdir = "tf-graph", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/BUILD b/tensorflow/tensorboard/components/tf_graph/demo/BUILD index 524d0ff7679..b578a51798b 100644 --- a/tensorflow/tensorboard/components/tf_graph/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph/demo", @@ -13,9 +13,9 @@ web_library( 
"//tensorflow/tensorboard/components/tf_graph", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_loader", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html index 10a65f54d52..ccf8ecc697a 100644 --- a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html @@ -941,7 +941,7 @@ Polymer({ delete this._nodeGroupIndex[n]; }, addEdgeGroup: function(n, selection) { - this._edgeGroupIndex[e] = selection; + this._edgeGroupIndex[n] = selection; }, getEdgeGroup: function(e) { return this._edgeGroupIndex[e]; diff --git a/tensorflow/tensorboard/components/tf_graph_app/BUILD b/tensorflow/tensorboard/components/tf_graph_app/BUILD index 415b20598ec..af568893821 100644 --- a/tensorflow/tensorboard/components/tf_graph_app/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_app/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_app", srcs = [ "index.html", @@ -16,9 +16,23 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_controls", "//tensorflow/tensorboard/components/tf_graph_loader", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_component_page", - "@org_polymer_webcomponentsjs", + ], +) + 
+tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_app"], + destdir = "tf-graph-app", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_board:legacy", + "//tensorflow/tensorboard/components/tf_graph_controls:legacy", + "//tensorflow/tensorboard/components/tf_graph_loader:legacy", + "//third_party/javascript/polymer/v1/iron-component-page:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + "//third_party/javascript/polymer/v1/webcomponentsjs:lib", ], ) @@ -27,23 +41,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "index.html", - "tf-graph-app.html", - ], - destdir = "tf-graph-app", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - "//third_party/javascript/polymer/v1/iron-list:lib", - "//third_party/javascript/polymer/v1/paper-radio-group:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD index 147cb0947c4..0f984664ce2 100644 --- a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_app/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-app/demo", diff --git a/tensorflow/tensorboard/components/tf_graph_board/BUILD 
b/tensorflow/tensorboard/components/tf_graph_board/BUILD index f1c1ed1fc0f..14a66166582 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_board/BUILD @@ -1,44 +1,38 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_board", - srcs = [ - "tf-graph-board.html", - ], + srcs = ["tf-graph-board.html"], path = "/tf-graph-board", deps = [ "//tensorflow/tensorboard/components/tf_graph", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_info", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_progress", ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_board"], + destdir = "tf-graph-board", + deps = [ + "//tensorflow/tensorboard/components/tf_graph:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_info:legacy", + "//third_party/javascript/polymer/v1/paper-progress:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-board.html", - ], - destdir = "tf-graph-board", - deps = [ - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_info:legacy", - ], -) diff --git 
a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD index 2d668769e62..4bf52c5a567 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_board/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-board/demo", @@ -13,9 +13,9 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_loader", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html b/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html index 0ee694e1e66..79409ce2a0c 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html +++ b/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html @@ -180,10 +180,9 @@ Polymer({ graph: Object, stats: Object, /** - * @type {value: number, msg: string} - * * A number between 0 and 100 denoting the % of progress * for the progress bar and the displayed message. 
+ * @type {{value: number, msg: string}} */ progress: Object, colorBy: String, diff --git a/tensorflow/tensorboard/components/tf_graph_common/BUILD b/tensorflow/tensorboard/components/tf_graph_common/BUILD index a372ab8279b..25e0403aa34 100644 --- a/tensorflow/tensorboard/components/tf_graph_common/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_common/BUILD @@ -1,15 +1,31 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_common", srcs = [ + "annotation.ts", + "colors.ts", + "common.ts", + "contextmenu.ts", + "edge.ts", + "externs.ts", + "graph.ts", + "hierarchy.ts", + "layout.ts", + "minimap.ts", + "node.ts", + "parser.ts", + "proto.ts", + "render.ts", + "scene.ts", + "template.ts", "tf-graph-common.html", - ":ts", + "util.ts", ], path = "/tf-graph-common", deps = [ @@ -17,18 +33,17 @@ web_library( "//tensorflow/tensorboard/components/tf_imports:dagre", "//tensorflow/tensorboard/components/tf_imports:graphlib", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = glob(["*.ts"]), - typings = [ - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_common"], + destdir = "tf-graph-common", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_google:lib", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -37,36 +52,3 @@ 
filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-common.html", - ":legacy_ts", - ], - destdir = "tf-graph-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = glob(["*.ts"]), - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//third_party/javascript/node_modules/typescript:es2015.promise", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/polymer:polymer_without_externs", - "//third_party/javascript/typings/webcomponents_js", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/BUILD index 65cafa9570a..7004b7145a3 100644 --- a/tensorflow/tensorboard/components/tf_graph_controls/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_controls/BUILD @@ -1,19 +1,18 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_controls", - srcs = [ - "tf-graph-controls.html", - ], + srcs = ["tf-graph-controls.html"], path = "/tf-graph-controls", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_graph_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", 
"@org_polymer_paper_button", "@org_polymer_paper_dropdown_menu", "@org_polymer_paper_menu", @@ -23,25 +22,25 @@ web_library( ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_controls"], + destdir = "tf-graph-controls", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//third_party/javascript/polymer/v1/paper-button:lib", + "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", + "//third_party/javascript/polymer/v1/paper-menu:lib", + "//third_party/javascript/polymer/v1/paper-radio-group:lib", + "//third_party/javascript/polymer/v1/paper-toggle-button:lib", + "//third_party/javascript/polymer/v1/paper-tooltip:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-controls.html", - ], - destdir = "tf-graph-controls", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD index c47cb90a03e..cd86ac7320a 100644 --- a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD @@ -1,19 +1,19 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run 
//third_party/tensorflow/tensorboard/components/tf_graph_controls/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-graph-controls/demo", deps = [ "//tensorflow/tensorboard/components/tf_graph_controls", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD index d1866b5d807..20f9d3990b5 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD @@ -1,14 +1,13 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_dashboard", - srcs = [ - "tf-graph-dashboard.html", - ], + srcs = ["tf-graph-dashboard.html"], path = "/tf-graph-dashboard", deps = [ "//tensorflow/tensorboard/components/tf_backend", @@ -17,7 +16,24 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_controls", "//tensorflow/tensorboard/components/tf_graph_loader", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/vz_sorting", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_dashboard"], + destdir = "tf-graph-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend:legacy", + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph:legacy", + "//tensorflow/tensorboard/components/tf_graph_board:legacy", + 
"//tensorflow/tensorboard/components/tf_graph_controls:legacy", + "//tensorflow/tensorboard/components/tf_graph_loader:legacy", + "//tensorflow/tensorboard/components/vz_sorting:legacy", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -26,23 +42,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-dashboard.html", - ], - destdir = "tf-graph-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD index 3658f45b153..58cd2854c57 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD @@ -1,19 +1,19 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_dashboard/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-dashboard/demo", deps = [ "//tensorflow/tensorboard/components/tf_graph_dashboard", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git 
a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html index 67756cc1298..2035e87898a 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html @@ -37,14 +37,17 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/tf_imports/dagre.html b/tensorflow/tensorboard/components/tf_imports/dagre.html index 1e2f6ef9af6..cb57b9a5cd8 100644 --- a/tensorflow/tensorboard/components/tf_imports/dagre.html +++ b/tensorflow/tensorboard/components/tf_imports/dagre.html @@ -42,4 +42,4 @@ THE SOFTWARE. - + diff --git a/tensorflow/tensorboard/components/tf_imports/graphlib.html b/tensorflow/tensorboard/components/tf_imports/graphlib.html index 783e33be0a6..05942123ab0 100644 --- a/tensorflow/tensorboard/components/tf_imports/graphlib.html +++ b/tensorflow/tensorboard/components/tf_imports/graphlib.html @@ -17,4 +17,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_imports/lodash.html b/tensorflow/tensorboard/components/tf_imports/lodash.html index cbe35f10505..65ff6a4b032 100644 --- a/tensorflow/tensorboard/components/tf_imports/lodash.html +++ b/tensorflow/tensorboard/components/tf_imports/lodash.html @@ -15,4 +15,4 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/numericjs.html b/tensorflow/tensorboard/components/tf_imports/numericjs.html index 7559054aaba..81fa9491688 100644 --- a/tensorflow/tensorboard/components/tf_imports/numericjs.html +++ b/tensorflow/tensorboard/components/tf_imports/numericjs.html @@ -40,4 +40,4 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--> - + diff --git a/tensorflow/tensorboard/components/tf_imports/plottable.html b/tensorflow/tensorboard/components/tf_imports/plottable.html index 2c3e10a7c44..77ad544d5a0 100644 --- a/tensorflow/tensorboard/components/tf_imports/plottable.html +++ b/tensorflow/tensorboard/components/tf_imports/plottable.html @@ -40,5 +40,5 @@ THE SOFTWARE. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/threejs.html b/tensorflow/tensorboard/components/tf_imports/threejs.html index d6adad43b03..7f4233b5713 100644 --- a/tensorflow/tensorboard/components/tf_imports/threejs.html +++ b/tensorflow/tensorboard/components/tf_imports/threejs.html @@ -39,5 +39,5 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --> - - + + diff --git a/tensorflow/tensorboard/components/tf_imports/weblas.html b/tensorflow/tensorboard/components/tf_imports/weblas.html index 054d04ea85e..c07020598fc 100644 --- a/tensorflow/tensorboard/components/tf_imports/weblas.html +++ b/tensorflow/tensorboard/components/tf_imports/weblas.html @@ -39,4 +39,4 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--> - + diff --git a/tensorflow/tensorboard/components/tf_option_selector/BUILD b/tensorflow/tensorboard/components/tf_option_selector/BUILD index 6f79ac536ab..cd0150529e7 100644 --- a/tensorflow/tensorboard/components/tf_option_selector/BUILD +++ b/tensorflow/tensorboard/components/tf_option_selector/BUILD @@ -1,16 +1,16 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_option_selector", srcs = ["tf-option-selector.html"], path = "/tf-option-selector", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD index f2a491a2b25..2de11a231e6 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_scalar_dashboard", srcs = [ "tf-scalar-dashboard.html", @@ -16,8 +16,8 @@ web_library( "//tensorflow/tensorboard/components/tf_color_scale", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/vz_line_chart", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_paper_checkbox", "@org_polymer_paper_dropdown_menu", diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD 
b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD index 3b135d68afc..497767363ec 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD @@ -1,22 +1,22 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-scalar-dashboard/demo", deps = [ "//tensorflow/tensorboard/components/tf_backend", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/components/tf_scalar_dashboard", "//tensorflow/tensorboard/demo:demo_data", - "@org_polymer", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html index 7429c87b873..10cf83b2e9a 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html @@ -45,14 +45,17 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/tf_storage/tf-storage.html b/tensorflow/tensorboard/components/tf_storage/tf-storage.html index 91b8976519d..ff3f7b0ad4a 100644 --- a/tensorflow/tensorboard/components/tf_storage/tf-storage.html +++ b/tensorflow/tensorboard/components/tf_storage/tf-storage.html @@ -18,4 +18,4 @@ limitations under the License. 
- + diff --git a/tensorflow/tensorboard/components/tf_tensorboard/BUILD b/tensorflow/tensorboard/components/tf_tensorboard/BUILD index b649bb53f2a..72f9a0852ae 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/BUILD +++ b/tensorflow/tensorboard/components/tf_tensorboard/BUILD @@ -1,16 +1,16 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") +load("//tensorflow/tensorboard:vulcanize.bzl", "tensorboard_html_binary") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_tensorboard", srcs = [ + "autoReloadBehavior.ts", "style.html", "tf-tensorboard.html", - ":ts", ], path = "/tf-tensorboard", visibility = ["//visibility:public"], @@ -23,11 +23,11 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_dashboard", "//tensorflow/tensorboard/components/tf_histogram_dashboard", "//tensorflow/tensorboard/components/tf_image_dashboard", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_scalar_dashboard", "//tensorflow/tensorboard/components/tf_storage", "//tensorflow/tensorboard/components/tf_text_dashboard", "//tensorflow/tensorboard/components/vz_projector", - "@org_polymer", "@org_polymer_font_roboto", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -40,20 +40,22 @@ web_library( ], ) -web_library( +ts_web_library( name = "demo", srcs = ["demo.html"], path = "/tf-tensorboard", deps = [ ":tf_tensorboard", "//tensorflow/tensorboard/demo:demo_data", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["autoReloadBehavior.ts"], +tensorboard_html_binary( + name = "devserver", + testonly = 1, + input_path = "/tf-tensorboard/demo.html", + output_path = "/index.html", + deps = [":demo"], ) filegroup( diff --git 
a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts b/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts index 1f6b4cf6419..54df16f5b5d 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts @@ -12,49 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF.TensorBoard { - export var AUTORELOAD_LOCALSTORAGE_KEY = 'TF.TensorBoard.autoReloadEnabled'; - var getAutoReloadFromLocalStorage: () => boolean = () => { - var val = window.localStorage.getItem(AUTORELOAD_LOCALSTORAGE_KEY); - return val === 'true' || val == null; // defaults to true - }; +export var AUTORELOAD_LOCALSTORAGE_KEY = 'TF.TensorBoard.autoReloadEnabled'; - export var AutoReloadBehavior = { - properties: { - autoReloadEnabled: { - type: Boolean, - observer: '_autoReloadObserver', - value: getAutoReloadFromLocalStorage, - }, - _autoReloadId: { - type: Number, - }, - autoReloadIntervalSecs: { - type: Number, - value: 30, - }, +var getAutoReloadFromLocalStorage: () => boolean = () => { + var val = window.localStorage.getItem(AUTORELOAD_LOCALSTORAGE_KEY); + return val === 'true' || val == null; // defaults to true +}; + +/** + * @polymerBehavior + */ +export var AutoReloadBehavior = { + properties: { + autoReloadEnabled: { + type: Boolean, + observer: '_autoReloadObserver', + value: getAutoReloadFromLocalStorage, }, - detached: function() { - window.clearTimeout(this._autoReloadId); + _autoReloadId: { + type: Number, }, - _autoReloadObserver: function(autoReload) { - window.localStorage.setItem(AUTORELOAD_LOCALSTORAGE_KEY, autoReload); - if (autoReload) { - var _this = this; - this._autoReloadId = window.setTimeout( - this._doAutoReload.bind(this), 
this.autoReloadIntervalSecs * 1000); - } else { - window.clearTimeout(this._autoReloadId); - } + autoReloadIntervalSecs: { + type: Number, + value: 30, }, - _doAutoReload: function() { - if (this.reload == null) { - throw new Error('AutoReloadBehavior requires a reload method'); - } - this.reload(); + }, + detached: function() { + window.clearTimeout(this._autoReloadId); + }, + _autoReloadObserver: function(autoReload) { + window.localStorage.setItem(AUTORELOAD_LOCALSTORAGE_KEY, autoReload); + if (autoReload) { + var _this = this; this._autoReloadId = window.setTimeout( this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); + } else { + window.clearTimeout(this._autoReloadId); } - }; -} + }, + _doAutoReload: function() { + if (this.reload == null) { + throw new Error('AutoReloadBehavior requires a reload method'); + } + this.reload(); + this._autoReloadId = window.setTimeout( + this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); + } +}; diff --git a/tensorflow/tensorboard/components/tf_tensorboard/demo.html b/tensorflow/tensorboard/components/tf_tensorboard/demo.html index c8a9238aef0..f691f6211bc 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/demo.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/demo.html @@ -18,7 +18,6 @@ limitations under the License. TensorBoard Demo - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts index 0f049d40ab6..b68fd8c9438 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts @@ -12,19 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ + +import {AUTORELOAD_LOCALSTORAGE_KEY, AutoReloadBehavior} from '../autoReloadBehavior'; + declare function fixture(id: string): void; + window.HTMLImports.whenReady(() => { Polymer({ is: 'autoreload-test-element', - behaviors: [TF.TensorBoard.AutoReloadBehavior], + behaviors: [AutoReloadBehavior], }); describe('autoReload-behavior', function() { - var testElement; - var ls = window.localStorage; - var key = TF.TensorBoard.AUTORELOAD_LOCALSTORAGE_KEY; - var clock; - var callCount: number; + let testElement; + const ls = window.localStorage; + const key = AUTORELOAD_LOCALSTORAGE_KEY; + let clock; + let callCount: number; beforeEach(function() { ls.setItem(key, 'false'); // start it turned off so we can mutate fns diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts index 2308298ced9..a00027963be 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {TABS} from '../../tf-globals/globals'; + describe('end-to-end test', () => { window.HTMLImports.whenReady(() => { let tb = d3.select('tf-tensorboard'); var tabs = (tb.node()).$.tabs; function testTab(tabIndex: number) { - it(`selecting ${TF.Globals.TABS[tabIndex]} tab`, done => { + it(`selecting ${TABS[tabIndex]} tab`, done => { // Every dashboard emits a rendered event when it is done rendering. tb.on('rendered', () => done()); tabs.set('selected', tabIndex); @@ -32,7 +34,7 @@ describe('end-to-end test', () => { // have failed. Re-selecting the default tab and listening for // "rendered" event won't work since the content is not re-stamped. 
let selected = +tabs.get('selected'); - for (let i = 0; i < TF.Globals.TABS.length; i++) { + for (let i = 0; i < TABS.length; i++) { if (i !== selected) { testTab(i); } diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts index 4dd62a0c382..905ed4ee4aa 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {TABS} from '../../tf-globals/globals'; + describe('fast tab switch', () => { window.HTMLImports.whenReady(() => { let tb = d3.select('tf-tensorboard'); + // tslint:disable-next-line:no-any be quiet tsc var tabs = (tb.node()).$.tabs; // This test will select the events tab. Once the events tab @@ -23,9 +26,9 @@ describe('fast tab switch', () => { // the images tab wihout waiting for the graph tab to finish // rendering. Finally, it finishes when the images tab // has rendered and no errors were thrown. - let eventsTabIndex = TF.Globals.TABS.indexOf('events'); - let imagesTabIndex = TF.Globals.TABS.indexOf('images'); - let graphTabIndex = TF.Globals.TABS.indexOf('graphs'); + const eventsTabIndex = TABS.indexOf('events'); + const imagesTabIndex = TABS.indexOf('images'); + const graphTabIndex = TABS.indexOf('graphs'); // Listen for when the events tab rendered. 
tb.on('rendered', () => { diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts index 3c7fe2c9e72..33e11e3094d 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +import * as backend_router from '../../tf-backend/router'; +import {TABS} from '../../tf-globals/globals'; + describe('tf-tensorboard tests', () => { window.HTMLImports.whenReady(() => { let tensorboard: any; @@ -25,16 +29,16 @@ describe('tf-tensorboard tests', () => { setTimeout(function() { let tabs = tensorboard.$.tabs.getElementsByTagName('paper-tab'); let tabMode = Array.prototype.map.call(tabs, (x) => x.dataMode); - chai.assert.deepEqual(tabMode, TF.Globals.TABS, 'mode is correct'); + chai.assert.deepEqual(tabMode, TABS, 'mode is correct'); let tabText = Array.prototype.map.call(tabs, (x) => x.innerText.toLowerCase()); - chai.assert.deepEqual(tabText, TF.Globals.TABS, 'text is correct'); + chai.assert.deepEqual(tabText, TABS, 'text is correct'); done(); }); }); it('respects router manually provided', function() { - let router = TF.Backend.router('data', true); + const router = backend_router.router('data', true); tensorboard.router = router; tensorboard.demoDir = null; chai.assert.equal(tensorboard._backend.router, router); @@ -46,7 +50,7 @@ describe('tf-tensorboard tests', () => { }); describe('reloading the selected dashboard', function() { - TF.Globals.TABS.forEach((name, tabIndex) => { + TABS.forEach((name, tabIndex) => { // These tabs do not support reload mode. 
if (name === 'graphs' || name === 'projections') { return; @@ -70,7 +74,7 @@ describe('tf-tensorboard tests', () => { }); it('reload is disabled for graph dashboard', function(done) { - let idx = TF.Globals.TABS.indexOf('graphs'); + const idx = TABS.indexOf('graphs'); chai.assert.notEqual(idx, -1, 'graphs was found'); tensorboard.$.tabs.set('selected', idx); setTimeout( diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html index ac3132fadaf..00a30686f69 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html @@ -44,7 +44,6 @@ tf-tensorboard is the frontend entry point for TensorBoard. It implements a toolbar (via paper-header-panel and paper-toolbar) that allows the user to toggle between various dashboards. --> - + - diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD index a1a97778280..b6dfdbefb4c 100644 --- a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_text_dashboard", srcs = [ "tf-text-dashboard.html", @@ -17,7 +17,7 @@ web_library( "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_dialog", "@org_polymer_paper_icon_button", "@org_polymer_paper_material", @@ -26,15 +26,15 @@ web_library( ], ) -web_library( +ts_web_library( name 
= "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-text-dashboard", deps = [ ":tf_text_dashboard", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/index.html b/tensorflow/tensorboard/components/tf_text_dashboard/index.html index 77d19b948c9..d01f4777ed3 100644 --- a/tensorflow/tensorboard/components/tf_text_dashboard/index.html +++ b/tensorflow/tensorboard/components/tf_text_dashboard/index.html @@ -44,6 +44,9 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts index 17e35978249..f3911d301d9 100644 --- a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts +++ b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts @@ -12,13 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -/* tslint:disable:no-namespace variable-name */ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash' -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable -import {Dataset} from 'Plottable/plottable'; -import * as ChartHelpers from '../vz_line_chart/vz-chart-helpers'; +import * as ChartHelpers from '../vz-line-chart/vz-chart-helpers'; export class DistributionChart { private run2datasets: {[run: string]: Plottable.Dataset}; diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD index 005090b8e06..6f21df0c865 100644 --- a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD +++ b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD @@ -1,29 +1,41 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_histogram_timeseries", srcs = ["vz-histogram-timeseries.html"], path = "/vz-histogram-timeseries", deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/vz-histogram-timeseries", deps = [ ":vz_histogram_timeseries", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_button", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":vz_histogram_timeseries"], + visibility = 
["//learning/vis/vz_elements/catalog:__pkg__"], + destdir = "vz-histogram-timeseries", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_google:lib", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -32,22 +44,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "index.html", - "vz-histogram-timeseries.html", - ], - visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], - destdir = "vz-histogram-timeseries", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) diff --git a/tensorflow/tensorboard/components/vz_line_chart/BUILD b/tensorflow/tensorboard/components/vz_line_chart/BUILD index c641587158b..7d8d0d60749 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/BUILD +++ b/tensorflow/tensorboard/components/vz_line_chart/BUILD @@ -1,16 +1,17 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_line_chart", srcs = [ - "bundle.js", + "dragZoomInteraction.ts", + "vz-chart-helpers.ts", "vz-line-chart.html", + "vz-line-chart.ts", ], path = "/vz-line-chart", visibility = ["//visibility:public"], @@ -18,11 +19,11 @@ web_library( "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", 
"//tensorflow/tensorboard/components/tf_imports:plottable", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/vz-line-chart", @@ -30,60 +31,12 @@ web_library( ":vz_line_chart", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/tf_imports:plottable.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.ChartHelpers": [ - "vz-chart-helpers.ts", - ], - "VZ": [ - "vz-line-chart.ts", - "dragZoomInteraction.ts", - ], - }, - namespace_symbol_aliases = { - "VZ.ChartHelpers": { - "Dataset": "Plottable.Dataset", - }, - }, -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - tensorboard_webcomponent_library( name = "legacy", - srcs = [ - "index.html", - "vz-line-chart.html", - ":legacy_ts", - ], + srcs = [":vz_line_chart"], visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], destdir = "vz-line-chart", deps = [ @@ -93,24 +46,8 @@ tensorboard_webcomponent_library( ], ) -tensorboard_ts_library( - name = "legacy_ts", - srcs = [ - "dragZoomInteraction.ts", - "vz-chart-helpers.ts", - "vz-line-chart.ts", - ], - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//third_party/javascript/node_modules/typescript:es2015.promise", - 
"//third_party/javascript/plottable:bundle", - "//third_party/javascript/typings/chai", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/mocha", - "//third_party/javascript/typings/polymer:polymer_without_externs", - "//third_party/javascript/typings/sinon", - "//third_party/javascript/typings/webcomponents_js", - ], +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], ) diff --git a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts b/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts index 2c1f4989c4c..c7f1f30e76b 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts @@ -13,11 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable - - export class DragZoomLayer extends Plottable.Components.SelectionBoxLayer { private _dragInteraction: Plottable.Interactions.Drag; private _doubleClickInteraction: Plottable.Interactions.Click; diff --git a/tensorflow/tensorboard/components/vz_line_chart/index.html b/tensorflow/tensorboard/components/vz_line_chart/index.html index fb571a51837..856ab7d1efe 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/index.html +++ b/tensorflow/tensorboard/components/vz_line_chart/index.html @@ -21,7 +21,6 @@ limitations under the License. 
vz-line-chart demo - diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts index cd8f1376172..fa89e06ada1 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts @@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* tslint:disable:no-namespace variable-name */ - - -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable -import {Dataset} from 'Plottable/plottable'; export interface Datum { wall_time: Date; @@ -123,6 +117,7 @@ export function computeDomain(values: number[], ignoreOutliers: boolean) { } export function accessorize(key: string): Plottable.IAccessor { + // tslint:disable-next-line:no-any be quiet tsc return (d: any, index: number, dataset: Plottable.Dataset) => d[key]; } @@ -157,19 +152,21 @@ export function wallX(): XComponents { accessor: (d: Datum) => d.wall_time, }; } -export let relativeAccessor = (d: any, index: number, dataset: Dataset) => { - // We may be rendering the final-point datum for scatterplot. - // If so, we will have already provided the 'relative' property - if (d.relative != null) { - return d.relative; - } - let data = dataset.data(); - // I can't imagine how this function would be called when the data is - // empty (after all, it iterates over the data), but lets guard just - // to be safe. - let first = data.length > 0 ? 
+data[0].wall_time : 0; - return (+d.wall_time - first) / (60 * 60 * 1000); // ms to hours -}; +export let relativeAccessor = + // tslint:disable-next-line:no-any be quiet tsc + (d: any, index: number, dataset: Plottable.Dataset) => { + // We may be rendering the final-point datum for scatterplot. + // If so, we will have already provided the 'relative' property + if (d.relative != null) { + return d.relative; + } + let data = dataset.data(); + // I can't imagine how this function would be called when the data is + // empty (after all, it iterates over the data), but lets guard just + // to be safe. + let first = data.length > 0 ? +data[0].wall_time : 0; + return (+d.wall_time - first) / (60 * 60 * 1000); // ms to hours + }; export let relativeFormatter = (n: number) => { // we will always show 2 units of precision, e.g days and hours, or diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html index 85e24ae4be0..38e0d7cb8d8 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html @@ -125,5 +125,7 @@ such as different X scales (linear and temporal), tooltips and smoothing. - + + + diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts index d50a7834f5f..5da6190ea24 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts @@ -14,10 +14,6 @@ limitations under the License. 
==============================================================================*/ /* tslint:disable:no-namespace variable-name */ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash' -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable - import {DragZoomLayer} from './dragZoomInteraction' import * as ChartHelpers from './vz-chart-helpers' @@ -142,7 +138,7 @@ Polymer({ * Sets the series that the chart displays. Series with other names will * not be displayed. * - * @param {String[]} names Array with the names of the series to + * @param {Array} names Array with the names of the series to * display. */ setVisibleSeries: function(names) { @@ -157,8 +153,8 @@ Polymer({ * Sets the data of one of the series. Note that to display this series * its name must be in the setVisibleSeries() array. * - * @param {String} name Name of the series. - * @param {VZ.ChartHelpers.ScalarDatum[]} data Data of the series. This is + * @param {string} name Name of the series. + * @param {Array} data Data of the series. This is * an array of objects with at least the following properties: * - step: (Number) - index of the datum. * - wall_time: (Date) - Date object with the datum's time. 
diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD b/tensorflow/tensorboard/components/vz_projector/BUILD index c1adeabbf53..6d22554efa5 100644 --- a/tensorflow/tensorboard/components/vz_projector/BUILD +++ b/tensorflow/tensorboard/components/vz_projector/BUILD @@ -1,38 +1,69 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_projector", srcs = [ + "analyticsLogger.ts", "bundle.html", - "bundle.js", + "data.ts", + "data-provider.ts", + "data-provider-demo.ts", + "data-provider-proto.ts", + "data-provider-server.ts", + "external.d.ts", + "knn.ts", + "label.ts", + "logging.ts", + "projectorEventContext.ts", + "projectorScatterPlotAdapter.ts", + "renderContext.ts", + "scatterPlot.ts", + "scatterPlotRectangleSelector.ts", + "scatterPlotVisualizer.ts", + "scatterPlotVisualizer3DLabels.ts", + "scatterPlotVisualizerCanvasLabels.ts", + "scatterPlotVisualizerPolylines.ts", + "scatterPlotVisualizerSprites.ts", "styles.html", + "util.ts", + "vector.ts", "vz-projector.html", + "vz-projector.ts", "vz-projector-app.html", "vz-projector-bookmark-panel.html", + "vz-projector-bookmark-panel.ts", "vz-projector-colab.html", "vz-projector-dashboard.html", "vz-projector-data-panel.html", + "vz-projector-data-panel.ts", "vz-projector-input.html", + "vz-projector-input.ts", "vz-projector-inspector-panel.html", + "vz-projector-inspector-panel.ts", "vz-projector-legend.html", + "vz-projector-legend.ts", "vz-projector-metadata-card.html", + "vz-projector-metadata-card.ts", "vz-projector-projections-panel.html", + "vz-projector-projections-panel.ts", + "vz-projector-util.ts", ], path = "/vz-projector", visibility = 
["//visibility:public"], deps = [ + ":bh_tsne", + ":heap", + ":sptree", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:numericjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_imports:threejs", "//tensorflow/tensorboard/components/tf_imports:weblas", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -53,298 +84,23 @@ web_library( ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "external.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:three.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - ], +ts_web_library( + name = "heap", + srcs = ["heap.ts"], + path = "/vz-projector", ) -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.Projector.Heap": ["heap.ts"], - "VZ.Projector.Label": ["label.ts"], - "VZ.Projector.SPTree": ["sptree.ts"], - "VZ.Projector.BhTsne": ["bh_tsne.ts"], - "VZ.Projector.Logging": ["logging.ts"], - "VZ.Projector.RenderContext": ["renderContext.ts"], - "VZ.Projector.ScatterPlotRectangleSelector": ["scatterPlotRectangleSelector.ts"], - "VZ.Projector.AnalyticsLogger": ["analyticsLogger.ts"], - "VZ.Projector.Util": ["util.ts"], - "VZ.Projector.Vector": ["vector.ts"], - "VZ.Projector.Knn": ["knn.ts"], - "VZ.Projector.Data": ["data.ts"], - "VZ.Projector.DataProvider": ["data-provider.ts"], - "VZ.Projector.DataProviderDemo": ["data-provider-demo.ts"], - "VZ.Projector.DataProviderProto": ["data-provider-proto.ts"], - "VZ.Projector.DataProviderServer": ["data-provider-server.ts"], - "VZ.Projector.ProjectorEventContext": ["projectorEventContext.ts"], - "VZ.Projector.ScatterPlot": ["scatterPlot.ts"], - "VZ.Projector.ScatterPlotVisualizer3DLabels": 
["scatterPlotVisualizer3DLabels.ts"], - "VZ.Projector.ScatterPlotVisualizerCanvasLabels": ["scatterPlotVisualizerCanvasLabels.ts"], - "VZ.Projector.ScatterPlotVisualizerPolylines": ["scatterPlotVisualizerPolylines.ts"], - "VZ.Projector.ScatterPlotVisualizerSprites": ["scatterPlotVisualizerSprites.ts"], - "VZ.Projector.ScatterPlotVisualizer": ["scatterPlotVisualizer.ts"], - "VZ.Projector.ProjectorScatterPlotAdapter": ["projectorScatterPlotAdapter.ts"], - "VZ.Projector.ProjectorUtil": ["vz-projector-util.ts"], - "VZ.Projector.ProjectorBookmarkPanel": ["vz-projector-bookmark-panel.ts"], - "VZ.Projector.ProjectorDataPanel": ["vz-projector-data-panel.ts"], - "VZ.Projector.ProjectorInput": ["vz-projector-input.ts"], - "VZ.Projector.ProjectorInspectorPanel": ["vz-projector-inspector-panel.ts"], - "VZ.Projector.ProjectorLegend": ["vz-projector-legend.ts"], - "VZ.Projector.ProjectorMetadataCard": ["vz-projector-metadata-card.ts"], - "VZ.Projector.ProjectorProjectionsPanel": ["vz-projector-projections-panel.ts"], - "VZ.Projector": ["vz-projector.ts"], - }, - namespace_symbol_aliases = { - "VZ.Projector.AnalyticsLogger": { - "ProjectionType": "VZ.Projector.Data.ProjectionType", - }, - "VZ.Projector.BhTsne": { - "SPNode": "VZ.Projector.SPTree.SPNode", - "SPTree": "VZ.Projector.SPTree.SPTree", - }, - "VZ.Projector.DataProviderDemo": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "TENSORS_MSG_ID": "VZ.Projector.DataProvider.TENSORS_MSG_ID", - "dataProvider": "VZ.Projector.DataProvider", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.DataProviderProto": { - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataProto": "VZ.Projector.Data.DataProto", - "DataProvider": 
"VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "analyzeMetadata": "VZ.Projector.DataProvider.analyzeMetadata", - }, - "VZ.Projector.DataProviderServer": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "dataProvider": "VZ.Projector.DataProvider", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.DataProvider": { - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataSet": "VZ.Projector.Data.DataSet", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "logging": "VZ.Projector.Logging", - "runAsyncTask": "VZ.Projector.Util.runAsyncTask", - }, - "VZ.Projector.Data": { - "SpriteMetadata": "VZ.Projector.DataProvider.SpriteMetadata", - "TSNE": "VZ.Projector.BhTsne.TSNE", - "knn": "VZ.Projector.Knn", - "logging": "VZ.Projector.Logging", - "scatterPlot": "VZ.Projector.ScatterPlot", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.Knn": { - "KMin": "VZ.Projector.Heap.KMin", - "Vector": "VZ.Projector.Vector.Vector", - "logging": "VZ.Projector.Logging", - "runAsyncTask": "VZ.Projector.Util.runAsyncTask", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ProjectorEventContext": { - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "NearestEntry": "VZ.Projector.Knn.NearestEntry", - "Projection": 
"VZ.Projector.Data.Projection", - }, - "VZ.Projector.ProjectorScatterPlotAdapter": { - "DataSet": "VZ.Projector.Data.DataSet", - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "LabelRenderParams": "VZ.Projector.RenderContext.LabelRenderParams", - "NearestEntry": "VZ.Projector.Knn.NearestEntry", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionComponents3D": "VZ.Projector.Data.ProjectionComponents3D", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ScatterPlot": "VZ.Projector.ScatterPlot.ScatterPlot", - "ScatterPlotVisualizer3DLabels": "VZ.Projector.ScatterPlotVisualizer3DLabels.ScatterPlotVisualizer3DLabels", - "ScatterPlotVisualizerCanvasLabels": "VZ.Projector.ScatterPlotVisualizerCanvasLabels.ScatterPlotVisualizerCanvasLabels", - "ScatterPlotVisualizerPolylines": "VZ.Projector.ScatterPlotVisualizerPolylines.ScatterPlotVisualizerPolylines", - "ScatterPlotVisualizerSprites": "VZ.Projector.ScatterPlotVisualizerSprites.ScatterPlotVisualizerSprites", - "State": "VZ.Projector.Data.State", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ScatterPlot": { - "BoundingBox": "VZ.Projector.ScatterPlotRectangleSelector.BoundingBox", - "CameraType": "VZ.Projector.RenderContext.CameraType", - "LabelRenderParams": "VZ.Projector.RenderContext.LabelRenderParams", - "Point2D": "VZ.Projector.Vector.Point2D", - "Point3D": "VZ.Projector.Vector.Point3D", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotRectangleSelector": "VZ.Projector.ScatterPlotRectangleSelector.ScatterPlotRectangleSelector", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizer3DLabels": { - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": 
"VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerCanvasLabels": { - "BoundingBox": "VZ.Projector.Label.BoundingBox", - "CameraType": "VZ.Projector.RenderContext.CameraType", - "CollisionGrid": "VZ.Projector.Label.CollisionGrid", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerPolylines": { - "DataSet": "VZ.Projector.Data.DataSet", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerSprites": { - "CameraType": "VZ.Projector.RenderContext.CameraType", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizer": { - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - }, - "VZ.Projector.Util": { - "DataPoint": "VZ.Projector.Data.DataPoint", - "Point2D": "VZ.Projector.Vector.Point2D", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.Vector": { - "assert": "VZ.Projector.Util.assert", - }, - "VZ.Projector.ProjectorBookmarkPanel": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "State": "VZ.Projector.Data.State", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.ProjectorDataPanel": { - "ColorLegendRenderInfo": 
"VZ.Projector.ProjectorLegend.ColorLegendRenderInfo", - "ColorLegendThreshold": "VZ.Projector.ProjectorLegend.ColorLegendThreshold", - "ColorOption": "VZ.Projector.Data.ColorOption", - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "parseRawMetadata": "VZ.Projector.DataProvider.parseRawMetadata", - "parseRawTensors": "VZ.Projector.DataProvider.parseRawTensors", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ProjectorInput": { - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorInspectorPanel": { - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ProjectorInput": "VZ.Projector.ProjectorInput.ProjectorInput", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "adapter": "VZ.Projector.ProjectorScatterPlotAdapter", - "knn": "VZ.Projector.Knn", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ProjectorLegend": { - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorMetadataCard": { - "PointMetadata": "VZ.Projector.Data.PointMetadata", - 
"PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorProjectionsPanel": { - "DataSet": "VZ.Projector.Data.DataSet", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionType": "VZ.Projector.Data.ProjectionType", - "Projector": "VZ.Projector.Projector", - "ProjectorInput": "VZ.Projector.ProjectorInput.ProjectorInput", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "Vector": "VZ.Projector.Vector.Vector", - "data": "VZ.Projector.Data", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector": { - "AnalyticsLogger": "VZ.Projector.AnalyticsLogger.AnalyticsLogger", - "BookmarkPanel": "VZ.Projector.ProjectorBookmarkPanel.BookmarkPanel", - "ColorOption": "VZ.Projector.Data.ColorOption", - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataPanel": "VZ.Projector.ProjectorDataPanel.DataPanel", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataProto": "VZ.Projector.Data.DataProto", - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "DemoDataProvider": "VZ.Projector.DataProviderDemo.DemoDataProvider", - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "DistanceMetricChangedListener": "VZ.Projector.ProjectorEventContext.DistanceMetricChangedListener", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "HoverListener": "VZ.Projector.ProjectorEventContext.HoverListener", - "InspectorPanel": "VZ.Projector.ProjectorInspectorPanel.InspectorPanel", - "MetadataCard": "VZ.Projector.ProjectorMetadataCard.MetadataCard", - "MouseMode": "VZ.Projector.ScatterPlot.MouseMode", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "PolymerElement": 
"VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionChangedListener": "VZ.Projector.ProjectorEventContext.ProjectionChangedListener", - "ProjectionsPanel": "VZ.Projector.ProjectorProjectionsPanel.ProjectionsPanel", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ProjectorScatterPlotAdapter": "VZ.Projector.ProjectorScatterPlotAdapter.ProjectorScatterPlotAdapter", - "ProtoDataProvider": "VZ.Projector.DataProviderProto.ProtoDataProvider", - "SelectionChangedListener": "VZ.Projector.ProjectorEventContext.SelectionChangedListener", - "ServerDataProvider": "VZ.Projector.DataProviderServer.ServerDataProvider", - "ServingMode": "VZ.Projector.DataProvider.ServingMode", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "data": "VZ.Projector.Data", - "knn": "VZ.Projector.Knn", - "logging": "VZ.Projector.Logging", - "stateGetAccessorDimensions": "VZ.Projector.Data.stateGetAccessorDimensions", - "util": "VZ.Projector.Util", - }, - }, +ts_web_library( + name = "sptree", + srcs = ["sptree.ts"], + path = "/vz-projector", +) + +ts_web_library( + name = "bh_tsne", + srcs = ["bh_tsne.ts"], + path = "/vz-projector", + deps = [":sptree"], ) filegroup( @@ -352,97 +108,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -#### Legacy for other consumers -load( - "//tensorflow/tensorboard:defs.bzl", - "tensorboard_webcomponent_library", - "tensorboard_ts_library", - "tensorboard_ts_declaration", -) - -# Standalone embedding projector demos should depend on this target. We -# exclude the HTML file for the dashboard itself. Demos do not need that -# HTML file. This was introduced because standalone demos as of today -# have an additional Closure pass that uses a compilation configuration -# stricter than that of TensorBoard. 
- -_PROJECTOR_LIB_TS_LIB_DEPS = [ - ":ts_lib", - ":tsne_ts_lib", -] - -_PROJECTOR_DESTDIR = "vz-projector" - -_PROJECTOR_LIB_DEPS = [ - "//third_party/javascript/polymer/v1/iron-collapse:lib", - "//third_party/javascript/polymer/v1/iron-icons:lib", - "//third_party/javascript/polymer/v1/paper-button:lib", - "//third_party/javascript/polymer/v1/paper-checkbox:lib", - "//third_party/javascript/polymer/v1/paper-dialog:lib", - "//third_party/javascript/polymer/v1/paper-dialog-scrollable:lib", - "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", - "//third_party/javascript/polymer/v1/paper-icon-button:lib", - "//third_party/javascript/polymer/v1/paper-input:lib", - "//third_party/javascript/polymer/v1/paper-item:lib", - "//third_party/javascript/polymer/v1/paper-listbox:lib", - "//third_party/javascript/polymer/v1/paper-slider:lib", - "//third_party/javascript/polymer/v1/paper-spinner:lib", - "//third_party/javascript/polymer/v1/paper-toast:lib", - "//third_party/javascript/polymer/v1/paper-toggle-button:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - "//third_party/javascript/polymer/v1/polymer:lib", -] - -tensorboard_ts_library( - name = "tsne_ts_lib", - srcs = [ - "bh_tsne.ts", - "sptree.ts", - ], -) - -tensorboard_ts_declaration( - name = "external", - srcs = ["external.d.ts"], -) - -tensorboard_ts_library( - name = "ts_lib", - srcs = glob( - ["*.ts"], - exclude = [ - "*.d.ts", - "*_test.ts", - "bh_tsne.ts", - "sptree.ts", - ], - ), - runtime_deps = [ - "//third_party/javascript/d3/v4:d3", - "//third_party/javascript/numericjs", - "//third_party/javascript/threejs/r77:threejs", - "//third_party/javascript/threejs/r77/examples/js/controls:orbitcontrols", - "//third_party/javascript/weblas", - ], - deps = [ - ":external", - ":tsne_ts_lib", - "//third_party/javascript/node_modules/typescript:es2015.promise", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/polymer:polymer_without_externs", - 
"//third_party/javascript/typings/threejs:three", - "//third_party/javascript/typings/webcomponents_js", - ], -) - -tensorboard_webcomponent_library( - name = "lib", - srcs = glob( - ["*.html"], - exclude = ["vz-projector-dashboard.html"], - ), - ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS, - destdir = _PROJECTOR_DESTDIR, - deps = _PROJECTOR_LIB_DEPS, -) diff --git a/tensorflow/tensorboard/components/vz_projector/bundle.html b/tensorflow/tensorboard/components/vz_projector/bundle.html index 2837fed8708..de87763673b 100644 --- a/tensorflow/tensorboard/components/vz_projector/bundle.html +++ b/tensorflow/tensorboard/components/vz_projector/bundle.html @@ -21,4 +21,36 @@ limitations under the License. - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts index 9d6df953d65..c0da9526598 100644 --- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts +++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 - import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data'; import {NearestEntry} from './knn'; import {ProjectorEventContext} from './projectorEventContext'; diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts index ece4d84ef28..2f3146d213c 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {BoundingBox, CollisionGrid} from './label'; import {CameraType, RenderContext} from './renderContext'; import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; diff --git a/tensorflow/tensorboard/components/vz_projector/test/BUILD b/tensorflow/tensorboard/components/vz_projector/test/BUILD index 7629272c350..a73c50dcd6d 100644 --- a/tensorflow/tensorboard/components/vz_projector/test/BUILD +++ b/tensorflow/tensorboard/components/vz_projector/test/BUILD @@ -3,76 +3,31 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "assert.ts", + "data-provider_test.ts", + 
"data_test.ts", + "sptree_test.ts", "tests.html", + "util_test.ts", + # "scatterPlotRectangleSelector_test.ts", + # "vz-projector-projections-panel_test.ts", ], path = "/vz-projector/test", deps = [ + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/components/vz_projector", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:three.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/tf_imports:plottable.d.ts", - "//tensorflow/tensorboard/components/vz_projector:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.Projector.Test": [ - "assert.ts", - "sptree_test.ts", - "data_test.ts", - "data-provider_test.ts", - "util_test.ts", - - # TODO(smilkov): Migrate these away from jasmine. 
- # "scatterPlotRectangleSelector_test.ts", - # "vz-projector-projections-panel_test.ts", - ], - }, - namespace_symbol_aliases = { - "VZ.Projector.Test": { - "BoundingBox": "VZ.Projector.ScatterPlotRectangleSelector.BoundingBox", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataSet": "VZ.Projector.Data.DataSet", - "ProjectionsPanel": "VZ.Projector.ProjectorProjectionsPanel.ProjectionsPanel", - "SPTree": "VZ.Projector.SPTree.SPTree", - "ScatterPlotRectangleSelector": "VZ.Projector.ScatterPlotRectangleSelector.ScatterPlotRectangleSelector", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "State": "VZ.Projector.Data.State", - "data_provider": "VZ.Projector.DataProvider", - "stateGetAccessorDimensions": "VZ.Projector.Data.stateGetAccessorDimensions", - "util": "VZ.Projector.Util", - }, - }, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/vz_projector/test/tests.html b/tensorflow/tensorboard/components/vz_projector/test/tests.html index dd43079bde1..a6843d0d6b8 100644 --- a/tensorflow/tensorboard/components/vz_projector/test/tests.html +++ b/tensorflow/tensorboard/components/vz_projector/test/tests.html @@ -21,4 +21,11 @@ limitations under the License. - + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector/vector.ts b/tensorflow/tensorboard/components/vz_projector/vector.ts index 0de78ad85df..cab30483138 100644 --- a/tensorflow/tensorboard/components/vz_projector/vector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vector.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {assert} from './util'; /** diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html index 55c15da5ed7..8223c503ecd 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html @@ -37,10 +37,9 @@ limitations under the License. diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts index a6847ed3c87..a9b6f6c5a06 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {ColorOption, ColumnStats, SpriteAndMetadataInfo} from './data'; import {DataProvider, EmbeddingInfo, parseRawMetadata, parseRawTensors, ProjectorConfig} from './data-provider'; import * as util from './util'; diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD index 96e270ce21f..fc309ce4a5d 100644 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/BUILD @@ -1,30 +1,24 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_sorting", srcs = [ - "bundle.js", + "sorting.ts", "vz-sorting.html", ], path = "/vz-sorting", visibility = ["//visibility:public"], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"VZ.Sorting": ["sorting.ts"]}, +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":vz_sorting"], + destdir = "vz-sorting", ) filegroup( @@ -32,25 +26,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ 
- "vz-sorting.html", - ":legacy_ts", - ], - destdir = "vz-sorting", -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = ["sorting.ts"], - deps_mgmt = "off", - runtime = "nodejs", -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD index 07913e3cbde..23e575945b5 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/test/BUILD @@ -3,41 +3,23 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "sortingTests.ts", "tests.html", ], path = "/vz-sorting/test", deps = [ + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", "//tensorflow/tensorboard/components/vz_sorting", - "@org_npmjs_registry_web_component_tester", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:mocha.d.ts", - "@org_definitelytyped//:chai.d.ts", - "//tensorflow/tensorboard/components/vz_sorting:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"VZ.Sorting": ["sortingTests.ts"]}, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html index d1b4a1db31c..c408690603f 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/tests.html +++ b/tensorflow/tensorboard/components/vz_sorting/test/tests.html @@ -18,6 +18,6 @@ limitations under the License. 
+ - - + diff --git a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html index 9f925951cb2..5ff6f311589 100644 --- a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html +++ b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html @@ -15,4 +15,4 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl index 827a74b173f..b3712a8156d 100644 --- a/tensorflow/tensorboard/defs.bzl +++ b/tensorflow/tensorboard/defs.bzl @@ -12,83 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -_DEFAULT_TYPINGS = [ - "@com_microsoft_typescript//:lib.es6.d.ts", -] - -def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs): - """Filegroup of compiled TypeScript sources. - - This is a very unsophisticated TypeScript rule where the user is responsible - for passing all typings and sources via srcs. It's meant as a stopgap because - TypeScript rules currently don't exist for Bazel. The definition of this rule - will need to evolve as more ts_library rules are migrated. - """ - for src in srcs: - if (src.startswith("/") or - src.endswith(".d.ts") or - not src.endswith(".ts")): - fail("srcs must be typescript sources in same package") - typings_out = [src[:-3] + ".d.ts" for src in srcs] - inputs = _DEFAULT_TYPINGS + typings + srcs - # These inputs are meant to work around a sandbox bug in Bazel. If we list - # @com_microsoft_typescript//:tsc.sh under tools, then its - # data attribute won't be considered when --genrule_strategy=sandboxed. See - # https://github.com/bazelbuild/bazel/issues/1147 and its linked issues. 
- data = [ - "@org_nodejs", - "@com_microsoft_typescript", - ] - native.genrule( - name = name, - srcs = inputs + data, - outs = [src[:-3] + ".js" for src in srcs] + typings_out, - cmd = "$(location @com_microsoft_typescript//:tsc.sh)" + - " --inlineSourceMap" + - " --inlineSources" + - # Do not follow triple slash references within typings. - " --noResolve" + - " --declaration" + - " --module es6" + - " --outDir $(@D) " + - " ".join(["$(locations %s)" % i for i in inputs]), - tools = ["@com_microsoft_typescript//:tsc.sh"], - **kwargs - ) - native.filegroup( - name = name + "_typings", - srcs = typings_out, - **kwargs - ) - -def tensorboard_karma_web_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_config(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_declaration(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_development_sources(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_devserver(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_library(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - def tensorboard_webcomponent_library(**kwargs): """Rules referencing this will be deleted from the codebase soon.""" pass - -def tensorboard_wct_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD index 447dff55a3f..f2ea14503a0 100644 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD @@ -5,6 +5,11 @@ 
licenses(["notice"]) # Apache 2.0 java_binary( name = "Vulcanize", srcs = ["Vulcanize.java"], + jvm_flags = [ + "-Xss20m", # JSCompiler needs big stacks for recursive parsing + "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive + "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam + ], visibility = ["//visibility:public"], deps = [ "@com_google_guava", @@ -29,6 +34,21 @@ java_binary( ], ) +# These JS files are always taken into consideration by the Closure Compiler +# when vulcanizing, per vulcanize.bzl. +filegroup( + name = "jslibs", + srcs = [ + # Ordering probably matters + "@com_google_javascript_closure_compiler_externs", + "@com_google_javascript_closure_compiler_externs_polymer", + "externs.js", + "@com_google_javascript_closure_library//:closure/goog/base.js", + "@com_google_javascript_closure_library//:closure/goog/deps.js", + ], + visibility = ["//visibility:public"], +) + filegroup( name = "all_files", srcs = glob(["**"]), diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java index e572415856c..8ef0f31d1e2 100644 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java @@ -15,23 +15,33 @@ package org.tensorflow.tensorboard.vulcanize; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.CharMatcher; import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Splitter; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import 
com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import com.google.javascript.jscomp.BasicErrorManager; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; import com.google.javascript.jscomp.CheckLevel; +import com.google.javascript.jscomp.CompilationLevel; import com.google.javascript.jscomp.Compiler; import com.google.javascript.jscomp.CompilerOptions; -import com.google.javascript.jscomp.CompilerOptions.LanguageMode; -import com.google.javascript.jscomp.CompilerOptions.Reach; +import com.google.javascript.jscomp.DiagnosticGroup; +import com.google.javascript.jscomp.DiagnosticGroups; +import com.google.javascript.jscomp.DiagnosticType; import com.google.javascript.jscomp.JSError; import com.google.javascript.jscomp.PropertyRenamingPolicy; +import com.google.javascript.jscomp.Result; import com.google.javascript.jscomp.SourceFile; -import com.google.javascript.jscomp.VariableRenamingPolicy; +import com.google.javascript.jscomp.WarningsGuard; import com.google.protobuf.TextFormat; import io.bazel.rules.closure.Webpath; import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; @@ -44,12 +54,17 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.jsoup.Jsoup; +import org.jsoup.nodes.Attribute; import org.jsoup.nodes.Comment; import org.jsoup.nodes.DataNode; import org.jsoup.nodes.Document; @@ -63,21 +78,45 @@ import org.jsoup.parser.Tag; /** Simple one-off solution for TensorBoard vulcanization. 
*/ public final class Vulcanize { + private static final Pattern IGNORE_PATHS_PATTERN = + Pattern.compile("/(?:polymer|marked-element)/.*"); + + private static final ImmutableSet EXTRA_JSDOC_TAGS = + ImmutableSet.of("attribute", "hero", "group", "required"); + + private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); + private static final Parser parser = Parser.htmlParser(); private static final Map webfiles = new HashMap<>(); private static final Set alreadyInlined = new HashSet<>(); private static final Set legalese = new HashSet<>(); private static final List licenses = new ArrayList<>(); private static final List stack = new ArrayList<>(); + private static final List sourcesFromJsLibraries = new ArrayList<>(); + private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); + private static final Map sourceTags = new LinkedHashMap<>(); + private static final Multimap suppressions = HashMultimap.create(); + private static CompilationLevel compilationLevel; private static Webpath outputPath; + private static Node firstCompiledScript; private static Node licenseComment; - private static boolean nominify; + private static int insideDemoSnippet; + private static boolean testOnly; public static void main(String[] args) throws IOException { - Webpath inputPath = Webpath.get(args[0]); - outputPath = Webpath.get(args[1]); - Path output = Paths.get(args[2]); - for (int i = 3; i < args.length; i++) { + compilationLevel = CompilationLevel.fromString(args[0]); + testOnly = args[1].equals("true"); + Webpath inputPath = Webpath.get(args[2]); + outputPath = Webpath.get(args[3]); + Path output = Paths.get(args[4]); + for (int i = 5; i < args.length; i++) { + if (args[i].endsWith(".js")) { + sourcesFromJsLibraries.add(SourceFile.fromFile(args[i])); + continue; + } + if (!args[i].endsWith(".pbtxt")) { + continue; + } Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); for (WebfilesSource src : manifest.getSrcList()) { 
webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); @@ -86,6 +125,7 @@ public final class Vulcanize { stack.add(inputPath); Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); transform(document); + compile(); if (licenseComment != null) { licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); } @@ -134,72 +174,30 @@ public final class Vulcanize { } private static Node enterNode(Node node) throws IOException { - Node newNode = node; + if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet++; + } + if (insideDemoSnippet > 0) { + return node; + } if (node instanceof Element) { if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { // Inline HTML. - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - if (alreadyInlined.add(href)) { - newNode = - parse(Files.readAllBytes(checkNotNull(webfiles.get(href), "%s in %s", href, me()))); - stack.add(href); - node.replaceWith(newNode); - } else { - newNode = new TextNode("", node.baseUri()); - node.replaceWith(newNode); - } - } else if (node.nodeName().equals("script")) { - nominify = node.hasAttr("nominify"); - node.removeAttr("nominify"); - Webpath src; - String script; - if (node.attr("src").isEmpty()) { - // Minify JavaScript. - StringBuilder sb = new StringBuilder(); - for (Node child : node.childNodes()) { - if (child instanceof DataNode) { - sb.append(((DataNode) child).getWholeData()); - } - } - src = me(); - script = sb.toString(); - } else { - // Inline JavaScript. 
- src = me().lookup(Webpath.get(node.attr("src"))); - Path other = webfiles.get(src); - if (other != null) { - script = new String(Files.readAllBytes(other), UTF_8); - node.removeAttr("src"); - } else { - src = me(); - script = ""; - } - } - script = minify(src, script); - newNode = - new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) - .appendChild(new DataNode(script, node.baseUri())); - node.replaceWith(newNode); + node = visitHtmlImport(node); + } else if (node.nodeName().equals("script") + && !shouldIgnoreUri(node.attr("src")) + && !node.hasAttr("jscomp-ignore")) { + node = visitScript(node); } else if (node.nodeName().equals("link") && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty()) { - // Inline CSS. - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - Path other = webfiles.get(href); - if (other != null) { - newNode = - new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) - .appendChild( - new DataNode(new String(Files.readAllBytes(other), UTF_8), node.baseUri())); - newNode.removeAttr("rel"); - newNode.removeAttr("href"); - node.replaceWith(newNode); - } + && !node.attr("href").isEmpty() + && !shouldIgnoreUri(node.attr("href"))) { + node = visitStylesheet(node); } - rootifyAttribute(newNode, "href"); - rootifyAttribute(newNode, "src"); - rootifyAttribute(newNode, "action"); - rootifyAttribute(newNode, "assetpath"); + rootifyAttribute(node, "href"); + rootifyAttribute(node, "src"); + rootifyAttribute(node, "action"); + rootifyAttribute(node, "assetpath"); } else if (node instanceof Comment) { String text = ((Comment) node).getData(); if (text.contains("@license")) { @@ -207,53 +205,230 @@ public final class Vulcanize { if (licenseComment == null) { licenseComment = node; } else { - newNode = new TextNode("", node.baseUri()); - node.replaceWith(newNode); + node = replaceNode(node, new TextNode("", node.baseUri())); } } else { - newNode = new TextNode("", node.baseUri()); - 
node.replaceWith(newNode); + node = replaceNode(node, new TextNode("", node.baseUri())); } } + return node; + } + + private static Node leaveNode(Node node) { + if (node instanceof Document) { + stack.remove(stack.size() - 1); + } else if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet--; + } + return node; + } + + private static Node visitHtmlImport(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + if (alreadyInlined.add(href)) { + stack.add(href); + Document subdocument = parse(Files.readAllBytes(getWebfile(href))); + for (Attribute attr : node.attributes()) { + subdocument.attr(attr.getKey(), attr.getValue()); + } + return replaceNode(node, subdocument); + } else { + return replaceNode(node, new TextNode("", node.baseUri())); + } + } + + private static Node visitScript(Node node) throws IOException { + Webpath path; + String script; + if (node.attr("src").isEmpty()) { + path = makeSyntheticName(".js"); + script = getInlineScriptFromNode(node); + } else { + path = me().lookup(Webpath.get(node.attr("src"))); + script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); + } + if (node.attr("src").endsWith(".min.js") + || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { + Node newScript = + new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) + .appendChild(new DataNode(script, node.baseUri())) + .removeAttr("src") + .removeAttr("jscomp-nocompile"); + if (firstCompiledScript != null) { + firstCompiledScript.before(newScript); + return replaceNode(node, new TextNode("", node.baseUri())); + } else { + return replaceNode(node, newScript); + } + } else { + if (firstCompiledScript == null) { + firstCompiledScript = node; + } + sourcesFromScriptTags.put(path, script); + sourceTags.put(path, node); + Optional suppress = getAttrTransitive(node, "jscomp-suppress"); + if (suppress.isPresent()) { + if (suppress.get().isEmpty()) { + suppressions.put(path, "*"); + } else { + 
suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); + } + } + return node; + } + } + + private static Node visitStylesheet(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + return replaceNode( + node, + new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) + .appendChild( + new DataNode( + new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) + .removeAttr("rel") + .removeAttr("href")); + } + + private static Optional getAttrTransitive(Node node, String attr) { + while (node != null) { + if (node.hasAttr(attr)) { + return Optional.of(node.attr(attr)); + } + node = node.parent(); + } + return Optional.absent(); + } + + private static Node replaceNode(Node oldNode, Node newNode) { + oldNode.replaceWith(newNode); return newNode; } - private static String minify(Webpath src, String script) { - if (nominify) { - return script; + private static Path getWebfile(Webpath path) { + return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); + } + + private static void compile() { + if (sourcesFromScriptTags.isEmpty()) { + return; } - Compiler compiler = new Compiler(new JsPrintlessErrorManager()); + CompilerOptions options = new CompilerOptions(); - options.skipAllCompilerPasses(); // too lazy to get externs - options.setLanguageIn(LanguageMode.ECMASCRIPT_2016); - options.setLanguageOut(LanguageMode.ECMASCRIPT5); + compilationLevel.setOptionsForCompilationLevel(options); + + // Nice options. 
+ options.setColorizeErrorOutput(true); options.setContinueAfterErrors(true); - options.setManageClosureDependencies(false); - options.setRenamingPolicy(VariableRenamingPolicy.LOCAL, PropertyRenamingPolicy.OFF); - options.setShadowVariables(true); - options.setInlineVariables(Reach.LOCAL_ONLY); - options.setFlowSensitiveInlineVariables(true); - options.setInlineFunctions(Reach.LOCAL_ONLY); - options.setAssumeClosuresOnlyCaptureReferences(false); + options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); + options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); + options.setGenerateExports(true); + options.setStrictModeInput(false); + options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); + + // So we can chop JS binary back up into the original script tags. + options.setPrintInputDelimiter(true); + options.setInputDelimiter("//~~WEBPATH~~%name%"); + + // Optimizations that are too advanced for us right now. + options.setPropertyRenaming(PropertyRenamingPolicy.OFF); options.setCheckGlobalThisLevel(CheckLevel.OFF); - options.setFoldConstants(true); - options.setCoalesceVariableNames(true); - options.setDeadAssignmentElimination(true); - options.setCollapseVariableDeclarations(true); - options.setConvertToDottedProperties(true); - options.setLabelRenaming(true); - options.setRemoveDeadCode(true); - options.setOptimizeArgumentsArray(true); - options.setRemoveUnusedVariables(Reach.LOCAL_ONLY); - options.setCollapseObjectLiterals(true); - options.setProtectHiddenSideEffects(true); - //options.setPrettyPrint(true); + options.setRemoveUnusedPrototypeProperties(false); + options.setRemoveUnusedPrototypePropertiesInExterns(false); + options.setRemoveUnusedClassProperties(false); + + // Closure pass. 
+ options.setClosurePass(true); + options.setManageClosureDependencies(true); + options.getDependencyOptions().setDependencyPruning(true); + options.getDependencyOptions().setDependencySorting(false); + options.getDependencyOptions().setMoocherDropping(false); + + // Polymer pass. + options.setPolymerVersion(1); + + // Debug flags. + if (testOnly) { + options.setPrettyPrint(true); + options.setGeneratePseudoNames(true); + options.setExportTestFunctions(true); + } + + // Don't print warnings from + diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java index 8ef0f31d1e2..2635f9b12f1 100644 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java @@ -181,18 +181,20 @@ public final class Vulcanize { return node; } if (node instanceof Element) { - if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { - // Inline HTML. - node = visitHtmlImport(node); - } else if (node.nodeName().equals("script") - && !shouldIgnoreUri(node.attr("src")) - && !node.hasAttr("jscomp-ignore")) { - node = visitScript(node); - } else if (node.nodeName().equals("link") - && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty() - && !shouldIgnoreUri(node.attr("href"))) { - node = visitStylesheet(node); + if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { + if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { + // Inline HTML. 
+ node = visitHtmlImport(node); + } else if (node.nodeName().equals("script") + && !shouldIgnoreUri(node.attr("src")) + && !node.hasAttr("jscomp-ignore")) { + node = visitScript(node); + } else if (node.nodeName().equals("link") + && node.attr("rel").equals("stylesheet") + && !node.attr("href").isEmpty() + && !shouldIgnoreUri(node.attr("href"))) { + node = visitStylesheet(node); + } } rootifyAttribute(node, "href"); rootifyAttribute(node, "src"); diff --git a/tensorflow/tensorboard/vulcanize.bzl b/tensorflow/tensorboard/vulcanize.bzl index e444bbe9dfc..c82b8cafdb2 100644 --- a/tensorflow/tensorboard/vulcanize.bzl +++ b/tensorflow/tensorboard/vulcanize.bzl @@ -17,10 +17,10 @@ load("//tensorflow/tensorboard:web.bzl", "web_aspect") def _tensorboard_html_binary(ctx): deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set() - files = set() - jslibs = set(ctx.files._jslibs) - webpaths = set() + manifests = depset(order="topological") + files = depset() + jslibs = depset(ctx.files._jslibs) + webpaths = depset() for dep in deps: manifests += dep.webfiles.manifests webpaths += dep.webfiles.webpaths @@ -75,12 +75,12 @@ def _tensorboard_html_binary(ctx): ctx.executable._WebfilesServer.short_path, long_path(ctx, params_file))) - transitive_runfiles = set() + transitive_runfiles = depset() transitive_runfiles += ctx.attr._WebfilesServer.data_runfiles.files for dep in deps: transitive_runfiles += dep.data_runfiles.files return struct( - files=set([ctx.outputs.html]), + files=depset([ctx.outputs.html]), webfiles=struct( manifest=manifest, manifests=manifests, @@ -101,8 +101,7 @@ tensorboard_html_binary = rule( "input_path": attr.string(mandatory=True), "output_path": attr.string(mandatory=True), "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list( - aspects=[web_aspect], providers=["webfiles"], mandatory=True), + "deps": attr.label_list(aspects=[web_aspect], mandatory=True), "external_assets": attr.string_dict(default={"/_/runfiles": 
"."}), "_jslibs": attr.label( default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:jslibs"), diff --git a/tensorflow/tensorboard/web.bzl b/tensorflow/tensorboard/web.bzl index 0b2ed66a57d..d3c585a1fd6 100644 --- a/tensorflow/tensorboard/web.bzl +++ b/tensorflow/tensorboard/web.bzl @@ -362,7 +362,7 @@ ts_web_library = rule( attrs={ "path": attr.string(), "srcs": attr.label_list(allow_files=True), - "deps": attr.label_list(aspects=[web_aspect], providers=["webfiles"]), + "deps": attr.label_list(aspects=[web_aspect]), "exports": attr.label_list(), "data": attr.label_list(cfg="data", allow_files=True), "suppress": attr.string_list(), From a56d59a84bcd90ad9126b669dea5e6d8e38952f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 09:51:27 -0700 Subject: [PATCH 49/72] Set flow to a value during TensorArray creation, Re-enable tensor_array_ops_test in msan. PiperOrigin-RevId: 157841785 --- tensorflow/core/kernels/tensor_array_ops.cc | 12 +++++++++++- tensorflow/python/kernel_tests/BUILD | 1 - 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index b46b405ffbf..075bacb432b 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" @@ -101,7 +102,7 @@ Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { class TensorArrayCreationOp : public OpKernel { public: explicit TensorArrayCreationOp(OpKernelConstruction* context) - : OpKernel(context) {} + : OpKernel(context), device_type_(context->device_type()) {} void Compute(OpKernelContext* ctx) override { Tensor tensor_array_output_handle; @@ -133,6 +134,12 @@ class TensorArrayCreationOp : public OpKernel { // Create the flow output. Tensor* flow; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow)); + if (device_type_ == DEVICE_CPU) { + // Value doesn't matter, but this makes msan not complain about + // copying an uninitialized value. To do this on GPU would require + // a kernel launch or a host->device memcpy, so we avoid that. + flow->flat()(0) = 0; + } } } @@ -140,6 +147,9 @@ class TensorArrayCreationOp : public OpKernel { virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, Tensor* tensor_array_output_handle, TensorArray** output_tensor_array) = 0; + + private: + const DeviceType device_type_; }; // A per-run local tensor array. The tensor array uses a "per-step" resource diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 8504091279e..c363ad6fe59 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1939,7 +1939,6 @@ cuda_py_test( "//tensorflow/python:variables", ], flaky = 1, # create_local_cluster sometimes times out.
- tags = ["nomsan"], # b/38390993 ) cuda_py_test( From f661128dbf1b591202772a73878da6fff75c9432 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Fri, 2 Jun 2017 10:05:28 -0700 Subject: [PATCH 50/72] Remove unused overloads of SummarizeGraphDef and EqualGraphDef PiperOrigin-RevId: 157843404 --- tensorflow/core/framework/graph_def_util.cc | 8 ------ tensorflow/core/framework/graph_def_util.h | 1 - tensorflow/core/util/equal_graph_def.cc | 32 +++++++-------------- tensorflow/core/util/equal_graph_def.h | 2 -- 4 files changed, 10 insertions(+), 33 deletions(-) diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index d731003366a..b76ab40b683 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -39,14 +39,6 @@ string SummarizeGraphDef(const GraphDef& graph_def) { return ret; } -string SummarizeGraphDef(gtl::ArraySlice node_defs) { - string ret; - for (const NodeDef& node : node_defs) { - strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n"); - } - return ret; -} - Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { for (const NodeDef& node : graph_def.node()) { TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index 27e3de581ad..56355eaf367 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -27,7 +27,6 @@ namespace tensorflow { // Produce a human-readable version of a GraphDef that is more concise // than a text-format proto. string SummarizeGraphDef(const GraphDef& graph_def); -string SummarizeGraphDef(gtl::ArraySlice node_defs); // Validates the syntax of a GraphDef provided externally. 
// diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index 8ad91e5adb2..2db026da56c 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -24,10 +24,16 @@ limitations under the License. namespace tensorflow { -template -static bool EqualNodeDefsHelper( - const NodeDefs& actual, const protobuf::RepeatedPtrField& expected, - string* diff, const EqualGraphDefOptions& options) { +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff, const EqualGraphDefOptions& options) { + // Intentionally do not check that versions match so that this routine can + // be used for less brittle golden file tests. + return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options); +} + +bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, + const protobuf::RepeatedPtrField& expected, + string* diff, const EqualGraphDefOptions& options) { std::unordered_map actual_index; for (const NodeDef& node : actual) { actual_index[node.name()] = &node; @@ -62,24 +68,6 @@ static bool EqualNodeDefsHelper( return true; } -bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, - string* diff, const EqualGraphDefOptions& options) { - // Intentionally do not check that versions match so that this routine can - // be used for less brittle golden file tests. 
- return EqualNodeDefsHelper(actual.node(), expected.node(), diff, options); -} - -bool EqualGraphDef(gtl::ArraySlice actual, const GraphDef& expected, - string* diff, const EqualGraphDefOptions& options) { - return EqualNodeDefsHelper(actual, expected.node(), diff, options); -} - -bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, - const protobuf::RepeatedPtrField& expected, - string* diff, const EqualGraphDefOptions& options) { - return EqualNodeDefsHelper(actual, expected, diff, options); -} - namespace { string JoinStringField(const protobuf::RepeatedPtrField& f) { diff --git a/tensorflow/core/util/equal_graph_def.h b/tensorflow/core/util/equal_graph_def.h index 29d0385493f..1ce6181c2e7 100644 --- a/tensorflow/core/util/equal_graph_def.h +++ b/tensorflow/core/util/equal_graph_def.h @@ -36,8 +36,6 @@ struct EqualGraphDefOptions { // nodes must be consistent. bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, string* diff, const EqualGraphDefOptions& options = {}); -bool EqualGraphDef(gtl::ArraySlice actual, const GraphDef& expected, - string* diff, const EqualGraphDefOptions& options = {}); // Determines if actual and expected are equal, ignoring: ordering of // attrs, internal attributes (if set in `options`), and control inputs. From d5421cf58e4b84832974e51ebc2c3a11ad86efb7 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 2 Jun 2017 10:11:36 -0700 Subject: [PATCH 51/72] Add additional concat test. 
PiperOrigin-RevId: 157844113 --- tensorflow/compiler/xla/tests/concat_test.cc | 33 ++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 63bfac441d3..fcdbe130d0b 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { + ComputationBuilder builder(client_, TestName()); + + Array3D arr0(9, 17, 1); + arr0.Fill(1); + + Array3D arr1(9, 17, 256); + arr1.Fill(2); + + Array3D expected(9, 17, arr0.n3() + arr1.n3()); + for (int64 i = 0; i < expected.n1(); ++i) { + for (int64 j = 0; j < expected.n2(); ++j) { + int64 kk = 0; + for (const Array3D& arr : {arr0, arr1}) { + for (int64 k = 0; k < arr.n3(); ++k, ++kk) { + expected(i, j, kk) = arr(i, j, k); + } + } + } + } + + ComputationDataHandle h0; + auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", + &builder, &h0); + ComputationDataHandle h1; + auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", + &builder, &h1); + + auto concatenated = builder.ConcatInDim({h0, h1}, 2); + + ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; From 0f2db739163809782049b2c956355506c88c77e5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 2 Jun 2017 11:04:35 -0700 Subject: [PATCH 52/72] [TF:XLA] Split union-find implementation in mark_for_compilation_pass.cc into a separate library, make it more generic. 
PiperOrigin-RevId: 157850985 --- tensorflow/compiler/jit/BUILD | 6 ++ .../compiler/jit/mark_for_compilation_pass.cc | 80 +++--------------- tensorflow/compiler/jit/union_find.h | 81 +++++++++++++++++++ 3 files changed, 99 insertions(+), 68 deletions(-) create mode 100644 tensorflow/compiler/jit/union_find.h diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 749139cb1f3..9103fcf363b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -202,6 +202,7 @@ cc_library( deps = [ ":common", ":graph_to_functiondef", + ":union_find", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/kernels:xla_local_launch_op", @@ -221,6 +222,11 @@ cc_library( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + cc_test( name = "compilation_passes_test", size = "small", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ed9d9ad70e4..f1fef85f994 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -206,70 +207,12 @@ Status FindCompilationCandidates( return Status::OK(); } -// Union-Find data structure used to compute clusters. We use our own -// implementation because we want one key feature: when merging clusters, we -// need to know which value becomes the representative of the merged clusters. 
-// We use the representatives to name nodes in a cycle detection graph, and we -// need to control which node is named. -// TODO(phawkins): consider merging this code with union-find implementations -// in Tensorflow, e.g., in SimplePlacer. -class Cluster { - public: - Cluster(); - - int Size() { return FindRoot()->size_; } - - // Merges this cluster with 'other'. This cluster's representative becomes - // the representative of the merged cluster; the representative of 'other' - // is ignored. - void Merge(Cluster* other); - - // Each cluster has an associated integer 'representative', initialized to -1 - // by default. - int GetRepresentative() { return FindRoot()->representative_; } - void SetRepresentative(int representative) { - FindRoot()->representative_ = representative; - } - - private: - // Finds the root element of the cluster. Performs path compression. - Cluster* FindRoot(); - - int representative_; - int rank_; - int size_; // Size of the cluster. - Cluster* parent_; +struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; }; -Cluster::Cluster() - : representative_(-1), rank_(0), size_(1), parent_(nullptr) {} - -void Cluster::Merge(Cluster* other) { - Cluster* a = FindRoot(); - Cluster* b = other->FindRoot(); - if (a == b) return; - if (a->rank_ > b->rank_) { - b->parent_ = a; - a->size_ += b->size_; - return; - } - - a->parent_ = b; - if (a->rank_ == b->rank_) { - b->rank_++; - } - b->representative_ = a->representative_; - b->size_ += a->size_; -} - -Cluster* Cluster::FindRoot() { - if (!parent_) return this; - // Path compression: update intermediate nodes to point to the root of the - // equivalence class. - parent_ = parent_->FindRoot(); - return parent_; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl( // Each compilation candidate belongs to a cluster. 
The cluster's // representative // names the node in the 'cycles' graph that represents the cluster. - std::vector clusters(graph->num_node_ids()); - std::deque worklist; + std::vector> clusters(graph->num_node_ids()); + std::deque*> worklist; for (Node* node : compilation_candidates) { - clusters[node->id()].SetRepresentative(node->id()); + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); worklist.push_back(&clusters[node->id()]); } @@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. while (!worklist.empty()) { - int from = worklist.front()->GetRepresentative(); + int from = worklist.front()->Get().representative; worklist.pop_front(); Node* node_from = graph->FindNodeId(from); @@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl( // Count the number of elements in each cluster. std::vector cluster_sizes(graph->num_node_ids()); for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; cluster_sizes[cluster]++; } @@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl( // if compilation is enabled, otherwise there will be no such candidates). const int min_cluster_size = flags->tf_xla_min_cluster_size; for (Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/jit/union_find.h new file mode 100644 index 00000000000..a1a7a6a4d0d --- /dev/null +++ b/tensorflow/compiler/jit/union_find.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ + +namespace tensorflow { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must be +// copyable. +template +class UnionFind { + public: + UnionFind() : rank_(0), size_(1), parent_(nullptr) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& Get() { return FindRoot()->value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int rank_; + int size_; // Size of the cluster. 
+ UnionFind* parent_; + T value_; +}; + +template +void UnionFind::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return; + if (a->rank_ > b->rank_) { + b->parent_ = a; + a->size_ += b->size_; + return; + } + + a->parent_ = b; + if (a->rank_ == b->rank_) { + b->rank_++; + } + b->value_ = a->value_; + b->size_ += a->size_; +} + +template +UnionFind* UnionFind::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ From cd6c02985e482d344adf1f2bdf5a4980e53f726e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 11:08:48 -0700 Subject: [PATCH 53/72] Add 'streaming_curve_points' metric which returns curve [ROC, PR] approximation at specified number of points. PiperOrigin-RevId: 157851535 --- tensorflow/contrib/metrics/__init__.py | 2 + .../contrib/metrics/python/ops/metric_ops.py | 97 +++++++++++++++++++ .../metrics/python/ops/metric_ops_test.py | 93 ++++++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 85eecfc6cd7..4c16fb50407 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide. 
@@streaming_precision @@streaming_precision_at_thresholds @@streaming_auc +@@streaming_curve_points @@streaming_recall_at_k @@streaming_mean_absolute_error @@streaming_mean_iou @@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 727cdd9597a..c2211961dfb 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds( return values['tn'], update_ops['tn'] +def streaming_curve_points(labels=None, + predictions=None, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None): + """Computes curve (ROC or PR) values for a prespecified number of points. + + The `streaming_curve_points` function creates four local variables, + `true_positives`, `true_negatives`, `false_positives` and `false_negatives` + that are used to compute the curve values. To discretize the curve, a linearly + spaced set of thresholds is used to compute pairs of recall and precision + values. + + For best results, `predictions` should be distributed approximately uniformly + in the range [0, 1] and not peaked around 0 or 1. 
+ + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use when discretizing the + curve. + metrics_collections: An optional list of collections that `points` should be + added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + curve: Specifies the name of the curve to be computed, 'ROC' [default] or + 'PR' for the Precision-Recall-curve. + name: An optional variable_scope name. + + Returns: + points: A `Tensor` with shape [num_thresholds, 2] that contains points of + the curve. + update_op: An operation that increments the `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` variables. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple, or if `curve` is neither 'ROC' nor 'PR'.
+ """ + with variable_scope.variable_scope(name, 'curve_points', (labels, predictions, + weights)): + if curve != 'ROC' and curve != 'PR': + raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + values, update_ops = _streaming_confusion_matrix_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights) + + # Add epsilons to avoid dividing by 0. + epsilon = 1.0e-6 + + def compute_points(tp, fn, tn, fp): + """Computes the roc-auc or pr-auc based on confusion counts.""" + rec = math_ops.div(tp + epsilon, tp + fn + epsilon) + if curve == 'ROC': + fp_rate = math_ops.div(fp, fp + tn + epsilon) + return fp_rate, rec + else: # curve == 'PR'. + prec = math_ops.div(tp + epsilon, tp + fp + epsilon) + return rec, prec + + xs, ys = compute_points(values['tp'], values['fn'], values['tn'], + values['fp']) + points = array_ops.stack([xs, ys], axis=1) + update_op = control_flow_ops.group(*update_ops.values()) + + if metrics_collections: + ops.add_to_collections(metrics_collections, points) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return points, update_op + + def streaming_auc(predictions, labels, weights=None, num_thresholds=200, metrics_collections=None, updates_collections=None, curve='ROC', name=None): @@ -2372,6 +2468,7 @@ __all__ = [ 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', + 'streaming_curve_points', 'streaming_false_negatives', 'streaming_false_negatives_at_thresholds', 'streaming_false_positives', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index f42e974e238..54994ec617c 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py 
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase): self.assertEqual(0, recall.eval()) +class StreamingCurvePointsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metric_ops.streaming_curve_points( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables( + self, + ('curve_points/true_positives:0', 'curve_points/false_negatives:0', + 'curve_points/false_positives:0', 'curve_points/true_negatives:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + points, _ = metric_ops.streaming_curve_points( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [points]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metric_ops.streaming_curve_points( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def _testValueTensorIsIdempotent(self, curve): + predictions = constant_op.constant( + np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32) + labels = constant_op.constant( + np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32) + + points, update_op = metric_ops.streaming_curve_points( + labels, predictions=predictions, curve=curve) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + sess.run(update_op) + initial_points = points.eval() + + sess.run(update_op) + self.assertAllClose(initial_points, points.eval()) + + def testValueTensorIsIdempotentROC(self): + self._testValueTensorIsIdempotent(curve='ROC') + + def testValueTensorIsIdempotentPR(self): + self._testValueTensorIsIdempotent(curve='PR') + + def 
_testCase(self, labels, predictions, curve, expected_points): + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32) + points, update_op = metric_ops.streaming_curve_points( + labels=labels_tensor, + predictions=predictions_tensor, + num_thresholds=3, + curve=curve) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self.assertAllClose(expected_points, points.eval()) + + def testEdgeCasesROC(self): + self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]]) + self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]]) + self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]]) + self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]]) + + def testManyValuesROC(self): + self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]], + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC', + [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]]) + + def testEdgeCasesPR(self): + self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]]) + self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]]) + self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]]) + self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]]) + + def testManyValuesPR(self): + self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]], + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR', + [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]) + + class StreamingAUCTest(test.TestCase): def setUp(self): From 54595f0f38ce7157fdf2cb11a8acef72b9209d2f Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Fri, 2 Jun 2017 11:36:41 -0700 Subject: [PATCH 54/72] Adds the training test for LinearClassifier with n_classes=2. 
PiperOrigin-RevId: 157855473 --- .../python/estimator/canned/linear_test.py | 272 +++++++++++++++++- 1 file changed, 271 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py index 7fda9f0e540..1e10d5b1e42 100644 --- a/tensorflow/python/estimator/canned/linear_test.py +++ b/tensorflow/python/estimator/canned/linear_test.py @@ -624,7 +624,7 @@ def _assert_close(expected, actual, rtol=1e-04, name='assert_close'): with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: expected = ops.convert_to_tensor(expected, name='expected') actual = ops.convert_to_tensor(actual, name='actual') - rdiff = math_ops.abs(expected - actual, 'diff') / expected + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) rtol = ops.convert_to_tensor(rtol, name='rtol') return check_ops.assert_less( rdiff, @@ -845,5 +845,275 @@ class LinearRegressorTrainingTest(test.TestCase): expected_age_weight=age_weight, expected_bias=bias) + +class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _mockOptimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % _AGE_WEIGHT_NAME, + '%s/part_0:0' % _BIAS_NAME + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. 
+ self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = _assert_close( + math_ops.to_float(expected_loss, name='expected'), loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assertCheckpoint( + self, expected_global_step, expected_age_weight=None, expected_bias=None): + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([1, 1], shapes[_AGE_WEIGHT_NAME]) + if expected_age_weight is not None: + self.assertEqual( + expected_age_weight, + checkpoint_utils.load_variable(self._model_dir, _AGE_WEIGHT_NAME)) + + self.assertEqual([1], shapes[_BIAS_NAME]) + if expected_bias is not None: + self.assertEqual( + expected_bias, + checkpoint_utils.load_variable(self._model_dir, _BIAS_NAME)) + + def testFromScratchWithDefaultOptimizer(self): + n_classes = 2 + label = 0 + age = 17 + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. 
+ num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self._assertCheckpoint(num_steps) + + def testTrainWithTwoDimsLabel(self): + n_classes = 2 + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_2, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assertCheckpoint(200) + + def testTrainWithOneDimLabel(self): + n_classes = 2 + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assertCheckpoint(200) + + def testTrainWithTwoDimsWeight(self): + n_classes = 2 + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assertCheckpoint(200) + + def testTrainWithOneDimWeight(self): + n_classes = 2 + batch_size = 20 + + est = 
linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assertCheckpoint(200) + + def testFromScratch(self): + n_classes = 2 + label = 1 + age = 17 + # loss = sigmoid_cross_entropy(logits, label) where logits = 0 (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 + mock_optimizer = self._mockOptimizer(expected_loss=0.69315) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assertCheckpoint( + expected_global_step=num_steps, + expected_age_weight=0., + expected_bias=0.) + + def testFromCheckpoint(self): + # Create initial checkpoint. + n_classes = 2 + label = 1 + age = 17 + age_weight = 2.0 + bias = -35.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME) + variables.Variable([bias], name=_BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + _save_variables_to_ckpt(self._model_dir) + + # logits = age * age_weight + bias = 17 * 2. - 35. = -1. 
+ # loss = sigmoid_cross_entropy(logits, label) + # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 + mock_optimizer = self._mockOptimizer(expected_loss=1.3133) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assertCheckpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + + def testFromCheckpointMultiBatch(self): + # Create initial checkpoint. + n_classes = 2 + label = [1, 0] + age = [17, 18.5] + age_weight = 2.0 + bias = -35.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME) + variables.Variable([bias], name=_BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + _save_variables_to_ckpt(self._model_dir) + + # logits = age * age_weight + bias + # logits[0] = 17 * 2. - 35. = -1. + # logits[1] = 18.5 * 2. - 35. = 2. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 + # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 + mock_optimizer = self._mockOptimizer(expected_loss=1.3133 + 2.1269) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. 
+ num_steps = 10 + est.train( + input_fn=lambda: ({'age': (age)}, (label)), + steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assertCheckpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + if __name__ == '__main__': test.main() From 79099d67761b3e56d1c3764cd34f97401571a211 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 11:43:58 -0700 Subject: [PATCH 55/72] Removes default thresholds from BinaryLogisticHead and adds predict and evaluate tests for DNNClassifier. PiperOrigin-RevId: 157856471 --- .../python/estimator/canned/dnn_test.py | 159 ++++++++++++++++++ tensorflow/python/estimator/canned/head.py | 5 +- .../python/estimator/canned/head_test.py | 48 ++++-- .../python/estimator/canned/metric_keys.py | 2 + 4 files changed, 202 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py index 1a037957d8e..1838d03a94f 100644 --- a/tensorflow/python/estimator/canned/dnn_test.py +++ b/tensorflow/python/estimator/canned/dnn_test.py @@ -459,6 +459,86 @@ class DNNRegressorEvaluateTest(test.TestCase): }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1)) +class DNNClassifierEvaluateTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def test_one_dim(self): + """Asserts evaluation metrics for one-dimensional input and logits.""" + global_step = 100 + _create_checkpoint(( + ([[.6, .5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1.], [1.]], [.3]), + ), global_step, self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=[feature_column.numeric_column('age')], + model_dir=self._model_dir) + def _input_fn(): + # batch_size = 2, one false label, and one true. 
+ return {'age': [[10.], [10.]]}, [[1], [0]] + # Uses identical numbers as DNNModelTest.test_one_dim_logits. + # See that test for calculation of logits. + # logits = [[-2.08], [-2.08]] => + # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]] + # loss = -1. * log(0.111) -1. * log(0.889) = 2.31544200 + expected_loss = 2.31544200 + self.assertAllClose({ + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2., + metric_keys.MetricKeys.ACCURACY: 0.5, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.11105597, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + # There is no good way to calculate AUC for only two data points. But + # that is what the algorithm returns. + metric_keys.MetricKeys.AUC: 0.5, + metric_keys.MetricKeys.AUC_PR: 0.75, + ops.GraphKeys.GLOBAL_STEP: global_step + }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) + + def test_multi_dim(self): + """Asserts evaluation metrics for multi-dimensional input and logits.""" + global_step = 100 + _create_checkpoint(( + ([[.6, .5], [-.6, -.5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]), + ), global_step, self._model_dir) + n_classes = 3 + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=[feature_column.numeric_column('age', shape=[2])], + n_classes=n_classes, + model_dir=self._model_dir) + def _input_fn(): + # batch_size = 2, one false label, and one true. + return {'age': [[10., 8.], [10., 8.]]}, [[1], [0]] + # Uses identical numbers as + # DNNModelFnTest.test_multi_dim_input_multi_dim_logits. + # See that test for calculation of logits. 
+ # logits = [[-0.48, 0.48, 0.39], [-0.48, 0.48, 0.39]] + # probabilities = exp(logits)/sum(exp(logits)) + # = [[0.16670536, 0.43538380, 0.39791084], + # [0.16670536, 0.43538380, 0.39791084]] + # loss = -log(0.43538380) - log(0.16670536) + expected_loss = 2.62305466 + self.assertAllClose({ + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0.5, + ops.GraphKeys.GLOBAL_STEP: global_step + }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) + + class DNNRegressorPredictTest(test.TestCase): def setUp(self): @@ -520,6 +600,85 @@ class DNNRegressorPredictTest(test.TestCase): }, next(dnn_regressor.predict(input_fn=input_fn))) +class DNNClassifierPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def test_one_dim(self): + """Asserts predictions for one-dimensional input and logits.""" + _create_checkpoint(( + ([[.6, .5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1.], [1.]], [.3]), + ), global_step=0, model_dir=self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=(feature_column.numeric_column('x'),), + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[10.]])}, batch_size=1, shuffle=False) + # Uses identical numbers as DNNModelTest.test_one_dim_logits. + # See that test for calculation of logits. 
+ # logits = [-2.08] => + # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597 + # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597] + # class_ids = argmax(probabilities) = [0] + self.assertAllClose({ + prediction_keys.PredictionKeys.LOGITS: [-2.08], + prediction_keys.PredictionKeys.LOGISTIC: [0.11105597], + prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597], + prediction_keys.PredictionKeys.CLASS_IDS: [0], + }, next(dnn_classifier.predict(input_fn=input_fn))) + + def test_multi_dim(self): + """Asserts predictions for multi-dimensional input and logits.""" + _create_checkpoint(( + ([[.6, .5], [-.6, -.5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]), + ), global_step=0, model_dir=self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=(feature_column.numeric_column('x', shape=(2,)),), + n_classes=3, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + # Inputs shape is (batch_size, num_inputs). + x={'x': np.array([[10., 8.]])}, + batch_size=1, + shuffle=False) + # Uses identical numbers as + # DNNModelFnTest.test_multi_dim_input_multi_dim_logits. + # See that test for calculation of logits. 
+ # logits = [-0.48, 0.48, 0.39] => + # probabilities[i] = exp(logits[i]) / sum_j exp(logits[j]) => + # probabilities = [0.16670536, 0.43538380, 0.39791084] + # class_ids = argmax(probabilities) = [1] + predictions = next(dnn_classifier.predict(input_fn=input_fn)) + self.assertItemsEqual( + [prediction_keys.PredictionKeys.LOGITS, + prediction_keys.PredictionKeys.PROBABILITIES, + prediction_keys.PredictionKeys.CLASS_IDS, + prediction_keys.PredictionKeys.CLASSES], + six.iterkeys(predictions)) + self.assertAllClose( + [-0.48, 0.48, 0.39], predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + [0.16670536, 0.43538380, 0.39791084], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllEqual( + [1], predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertAllEqual( + [b'1'], predictions[prediction_keys.PredictionKeys.CLASSES]) + + def _queue_parsed_features(feature_map): tensors_to_enqueue = [] keys = [] diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index b06940ae611..631ddfc5dfc 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -459,7 +459,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( - weight_feature_key=None, thresholds=(0.5,)): + weight_feature_key=None, thresholds=None): """Creates a `Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -482,6 +482,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( Raises: ValueError: if `thresholds` contains a value outside of `(0, 1)`. """ + thresholds = tuple(thresholds) if thresholds else tuple() for threshold in thresholds: if (threshold <= 0.0) or (threshold >= 1.0): raise ValueError('thresholds not in (0, 1): %s.' 
% (thresholds,)) @@ -494,7 +495,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): def __init__(self, weight_feature_key=None, thresholds=None): self._weight_feature_key = weight_feature_key - self._thresholds = tuple(thresholds) + self._thresholds = thresholds @property def logits_dimension(self): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 34c1eb6c828..0efafac87ab 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -845,7 +845,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): prediction_keys.PredictionKeys.CLASS_IDS: np.array(((1,), (0,)), dtype=np.int64), } - default_threshold = .5 keys = metric_keys.MetricKeys expected_metrics = { # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41 @@ -857,9 +856,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., keys.AUC_PR: 1., - keys.ACCURACY_AT_THRESHOLD % default_threshold: 1./2, - keys.PRECISION_AT_THRESHOLD % default_threshold: 2./2, - keys.RECALL_AT_THRESHOLD % default_threshold: 1./2, } # Assert spec contains expected tensors. @@ -888,6 +884,44 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertAllClose( expected_metrics, {k: value_ops[k].eval() for k in value_ops}) + def test_eval_with_thresholds(self): + thresholds = [0.25, 0.5, 0.75] + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + thresholds=thresholds) + + # Create estimator spec. 
+ spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.float32)}, + mode=model_fn.ModeKeys.EVAL, + logits=np.array(((-1,), (1,),), dtype=np.float32), + labels=np.array(((1,), (1,),), dtype=np.int32)) + + # probabilities[i] = 1/(1 + exp(-logits[i])) => + # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = [0.269, 0.731] + # loss = -sum(ln(probabilities[label[i]])) = -ln(0.269) -ln(0.731) + # = 1.62652338 + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: 1.62652338 / 2., + keys.ACCURACY: 1./2, + keys.PREDICTION_MEAN: 1./2, + keys.LABEL_MEAN: 2./2, + keys.ACCURACY_BASELINE: 2./2, + keys.AUC: 0., + keys.AUC_PR: 1., + keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1., + keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1., + keys.RECALL_AT_THRESHOLD % thresholds[0]: 1., + keys.ACCURACY_AT_THRESHOLD % thresholds[1]: .5, + keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1., + keys.RECALL_AT_THRESHOLD % thresholds[1]: .5, + keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 0., + keys.PRECISION_AT_THRESHOLD % thresholds[2]: 0., + keys.RECALL_AT_THRESHOLD % thresholds[2]: 0., + } + + self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) + def test_train(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() @@ -1000,7 +1034,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=np.array(((1,), (1,), (0,)), dtype=np.int32)) - default_threshold = .5 # label_mean = (1*1 + .1*1 + 1.5*0)/(1 + .1 + 1.5) = 1.1/2.6 # = .42307692307 expected_label_mean = .42307692307 @@ -1021,11 +1054,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.ACCURACY_BASELINE: 1 - expected_label_mean, keys.AUC: .45454565, keys.AUC_PR: .6737757325172424, - keys.ACCURACY_AT_THRESHOLD % default_threshold: .38461538461, - # precision = (1*1 + 1.5*0)/(1 + 1.5) = 1/2.5 = .4 - keys.PRECISION_AT_THRESHOLD % default_threshold: .4, - # recall = (1*1 + .1*0)/(1 + 
.1) = 1/1.1 = .90909090909 - keys.RECALL_AT_THRESHOLD % default_threshold: .90909090909, } # Assert spec contains expected tensors. diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py index 1261d1dcfb1..91e3bf1d83a 100644 --- a/tensorflow/python/estimator/canned/metric_keys.py +++ b/tensorflow/python/estimator/canned/metric_keys.py @@ -29,6 +29,8 @@ class MetricKeys(object): LOSS_MEAN = model_fn.MetricKeys.AVERAGE_LOSS ACCURACY = 'accuracy' + # This is the best the model could do by always predicting one class. + # Should be < ACCURACY in a trained model. ACCURACY_BASELINE = 'accuracy_baseline' AUC = 'auc' AUC_PR = 'auc_precision_recall' From 55f6b6ff1365476fa460e7948413876bf1572484 Mon Sep 17 00:00:00 2001 From: David Soergel Date: Fri, 2 Jun 2017 12:18:35 -0700 Subject: [PATCH 56/72] Add explicit SparseTensor support to SignatureDef. PiperOrigin-RevId: 157860466 --- tensorflow/core/protobuf/meta_graph.proto | 28 +++++- .../tensorflow.-tensor-info.-coo-sparse.pbtxt | 88 +++++++++++++++++++ .../api/golden/tensorflow.-tensor-info.pbtxt | 8 ++ 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto index 5b2022321e5..47ec2aa1efe 100644 --- a/tensorflow/core/protobuf/meta_graph.proto +++ b/tensorflow/core/protobuf/meta_graph.proto @@ -202,8 +202,34 @@ message CollectionDef { // Information about a Tensor necessary for feeding or retrieval. message TensorInfo { - string name = 1; + // For sparse tensors, The COO encoding stores a triple of values, indices, + // and shape. + message CooSparse { + // The shape of the values Tensor is [?]. Its dtype must be the dtype of + // the SparseTensor as a whole, given in the enclosing TensorInfo. 
+ string values_tensor_name = 1; + + // The indices Tensor must have dtype int64 and shape [?, ?]. + string indices_tensor_name = 2; + + // The dynamic logical shape represented by the SparseTensor is recorded in + // the Tensor referenced here. It must have dtype int64 and shape [?]. + string dense_shape_tensor_name = 3; + } + + oneof encoding { + // For dense `Tensor`s, the name of the tensor in the graph. + string name = 1; + // There are many possible encodings of sparse matrices + // (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow + // uses only the COO encoding. This is supported and documented in the + // SparseTensor Python class. + CooSparse coo_sparse = 4; + } DataType dtype = 2; + // The static shape should be recorded here, to the extent that it can + // be known in advance. In the case of a SparseTensor, this field describes + // the logical shape of the represented tensor (aka dense_shape). TensorShapeProto tensor_shape = 3; } diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt new file mode 100644 index 00000000000..425c35e0674 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.TensorInfo.CooSparse" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DENSE_SHAPE_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "INDICES_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUES_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: 
"FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt index 87632fb7b9e..41ea393be51 100644 --- a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt @@ -2,6 +2,14 @@ path: "tensorflow.TensorInfo" tf_class { is_instance: "" is_instance: "" + member { + name: "COO_SPARSE_FIELD_NUMBER" + mtype: "" + } + member { + name: "CooSparse" + mtype: "" + } member { name: "DESCRIPTOR" mtype: "" From 4905c0eae45fc509d21dd0b18b15d406e2d94bf8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 12:24:13 -0700 Subject: [PATCH 57/72] Remove TODO - the new tolerance is okay to keep. 
PiperOrigin-RevId: 157861020 --- .../kernel_tests/distributions/dirichlet_multinomial_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index 59af0b8cf85..2f8f85866df 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -252,7 +252,6 @@ class DirichletMultinomialTest(test.TestCase): ]) self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05) - # TODO(cwhipkey): reduce tolerance (b/62216354) self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.05) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) From f4b8d21b8e41636b6e61f0a1de753430108d2ee7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 13:22:15 -0700 Subject: [PATCH 58/72] Change function parameters to references to avoid copying, or otherwise move from function parameters when moving reduces the amount of copying. 
PiperOrigin-RevId: 157867333 --- tensorflow/cc/framework/testutil.cc | 4 +++- .../compiler/jit/encapsulate_subgraphs_pass_test.cc | 8 +++++--- tensorflow/compiler/xla/client/global_data.cc | 3 ++- tensorflow/compiler/xla/reference_util.cc | 4 +++- tensorflow/compiler/xla/service/execution_tracker.cc | 2 +- tensorflow/compiler/xla/shape_util.cc | 2 +- .../contrib/factorization/kernels/clustering_ops.cc | 4 ++-- tensorflow/contrib/session_bundle/session_bundle_test.cc | 2 +- tensorflow/core/framework/op_kernel_test.cc | 3 ++- tensorflow/core/graph/graph_def_builder.cc | 6 ++++-- tensorflow/core/graph/graph_partition.cc | 6 ++++-- tensorflow/core/graph/graph_partition_test.cc | 3 ++- tensorflow/core/graph/optimizer_cse_test.cc | 2 +- tensorflow/core/util/equal_graph_def_test.cc | 4 +++- 14 files changed, 34 insertions(+), 19 deletions(-) diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc index b0746913a16..ca78f31db51 100644 --- a/tensorflow/cc/framework/testutil.cc +++ b/tensorflow/cc/framework/testutil.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/cc/framework/testutil.h" +#include + #include "tensorflow/cc/client/client_session.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/default_device.h" @@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors, void GetTensor(const Scope& scope, Output tensor, Tensor* out) { std::vector outputs; - GetTensors(scope, {tensor}, &outputs); + GetTensors(scope, {std::move(tensor)}, &outputs); *out = outputs[0]; } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index a8869c8e2a7..4a1dbaf05dc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" @@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) { } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { - return ops::UnaryOp("UnaryTest", a, opts); + return ops::UnaryOp("UnaryTest", std::move(a), opts); } Node* Binary(ops::NodeOut a, ops::NodeOut b, const GraphDefBuilder::Options& opts) { - return ops::BinaryOp("BinaryTest", a, b, opts); + return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } Node* AddNLike(const std::vector& inputs, @@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", opts.op_registry()); - node_builder.Input(a).Attr("index", index); + node_builder.Input(std::move(a)).Attr("index", index); return opts.FinalizeBuilder(&node_builder); } diff --git 
a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index be706f7d232..40f59eaa68e 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -23,7 +24,7 @@ limitations under the License. namespace xla { GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) - : handle_(handle), parent_(parent) {} + : handle_(std::move(handle)), parent_(parent) {} GlobalData::~GlobalData() { UnregisterRequest request; diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 32c3c3ae206..e8de559a5ef 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/reference_util.h" #include +#include #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions( std::pair kernel_stride, Padding padding, ConvolutionDimensionNumbers dimension_numbers) { return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, - {1, 1}, {1, 1}, dimension_numbers); + {1, 1}, {1, 1}, + std::move(dimension_numbers)); } /* static */ std::unique_ptr> diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 8d79d07f942..c225e62e3e1 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend, : backend_(CHECK_NOTNULL(backend)), streams_(std::move(streams)), profile_(profile), - result_(result) { + result_(std::move(result)) { for (const auto& stream : streams_) { CHECK(stream != nullptr); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ccc1dc63e78..8d04935a0bc 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { for (const auto& shape : parameters) { *program_shape.add_parameters() = shape; } - *program_shape.mutable_result() = result; + *program_shape.mutable_result() = std::move(result); return program_shape; } diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index 3a964311820..a2136c08bbc 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel 
{ const Eigen::Ref& points_half_squared_norm, const Eigen::Ref& centers, const Eigen::Ref& centers_half_squared_norm, - Eigen::Ref nearest_center_indices, - Eigen::Ref nearest_center_distances) { + const Eigen::Ref& nearest_center_indices, + const Eigen::Ref& nearest_center_distances) { CHECK_LE(k, centers.rows()); if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) { FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers, diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index ad6264d5c8a..eb36d79e0f4 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test { // MetaGraphDef. // Returns the path of the export. // ** Should only be called once per test ** - string SetupExport(MetaGraphDefTwiddler twiddler) { + string SetupExport(const MetaGraphDefTwiddler& twiddler) { return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename); } // SetupExport that allows for the variables and meta_graph_def filenames diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index e8e931b52e4..f87b7178449 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -456,7 +456,8 @@ class OpKernelBuilderTest : public ::testing::Test { } } - string GetKernelClassName(const string& op_type, DeviceType device_type, + string GetKernelClassName(const string& op_type, + const DeviceType& device_type, const std::vector& attrs, DataTypeSlice input_types = {}) { NodeDef def = CreateNodeDef(op_type, attrs); diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index ec1c1b6cea2..33d2021f381 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ 
-15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" +#include + #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" @@ -119,7 +121,7 @@ Node* UnaryOp(const string& op_name, NodeOut input, if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, opts.op_registry()); - node_builder.Input(input); + node_builder.Input(std::move(input)); return opts.FinalizeBuilder(&node_builder); } @@ -128,7 +130,7 @@ Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, opts.op_registry()); - node_builder.Input(a).Input(b); + node_builder.Input(std::move(a)).Input(std::move(b)); return opts.FinalizeBuilder(&node_builder); } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 60363175594..f8c6895dfa1 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/core/framework/memory_types.h" @@ -392,7 +393,8 @@ Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, const string& device_name, const GraphDefBuilder::Options& bopts) { - Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts); + Node* res_node = + ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; @@ -401,7 +403,7 @@ Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, // A next_iteration node for control flow. 
Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, const GraphDefBuilder::Options& bopts) { - Node* res_node = ops::UnaryOp("NextIteration", input, bopts); + Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index ee545dbfbfa..ca49ea0ac49 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include +#include #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" @@ -159,7 +160,7 @@ Output BoolInput(const Scope& scope) { } Output Combine(const Scope& scope, Input a, Input b) { - return ConstructOp(scope, "Combine", {a, b}); + return ConstructOp(scope, "Combine", {std::move(a), std::move(b)}); } class GraphPartitionTest : public ::testing::Test { diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index 94250240eb7..21a63662cf2 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -86,7 +86,7 @@ class OptimizerCSETest : public ::testing::Test { str_util::Join(edges, ";")); } - string DoCSE(std::function consider_fn = nullptr) { + string DoCSE(const std::function& consider_fn = nullptr) { string before = CanonicalGraphString(&graph_); LOG(ERROR) << "Before rewrites: " << before; diff --git a/tensorflow/core/util/equal_graph_def_test.cc b/tensorflow/core/util/equal_graph_def_test.cc index af870c5c607..054cc92c169 100644 --- a/tensorflow/core/util/equal_graph_def_test.cc +++ b/tensorflow/core/util/equal_graph_def_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the 
License. ==============================================================================*/ +#include + #include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/core/framework/node_def_util.h" @@ -40,7 +42,7 @@ Node* Alternate(const GraphDefBuilder::Options& opts) { Node* Combine(ops::NodeOut a, ops::NodeOut b, const GraphDefBuilder::Options& opts) { - return ops::BinaryOp("Combine", a, b, opts); + return ops::BinaryOp("Combine", std::move(a), std::move(b), opts); } class EqualGraphDefTest : public ::testing::Test { From 366990d92dfff7861ff11a4d0c9bb4a9f74f9077 Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Fri, 2 Jun 2017 14:02:25 -0700 Subject: [PATCH 59/72] [XLA] Fix a subtle issue in copy_insertion due to the interaction between copy overriding logic and RecordIndicesToColocatingBuffers: - When building instructions ShapeTree to be copy overridden, it is possible that we create a single kCopy for two identical instructions. An example can be: %tuple.19 = tuple(%constant.4, %constant.1793, %constant.1793) where it is used in a while.init operand, and constant.1793 is read-only within the loop and also used by another while loop. The copy overriding pass will then create the following (logical, not finalized) tuple: %tuple.19 = tuple(%constant.4, %copy.5, %copy.5) - In the subsequent pass RecordAmbiguousOrNonDistinctIndices, to add copies to ensure the points-to set is distinct, the duplicate %copy.5 are ignored because they are not yet finalized, and these indices (1 and 2 in the example) are still marked as to-be copied. Therefore distinctiveness is lost. This fix applies to the override building stage, to explicitly avoid creating shared copies for non-distinct buffers.
PiperOrigin-RevId: 157872231 --- .../compiler/xla/service/copy_insertion.cc | 66 +++++---- .../xla/service/copy_insertion_test.cc | 131 +++++++++++++++++- 2 files changed, 167 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 907b0307d4b..989b73faa80 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( // Mapping from LogicalBuffer to index (used to detect non-distinct indices). FlatMap> buffer_to_source_indices; - TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); - } - } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (const LogicalBuffer* buffer : buffers) { - buffer_to_source_indices[buffer].push_back(index); - } - return Status::OK(); - })); + TF_RETURN_IF_ERROR(points_to.ForEachElement( + [this, &buffer_to_source_indices]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. + if (!indices_to_copy_.element(index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " with ambiguous points-to set."; + RecordIndex(index); + } + } + // For each 'buffer': record a mapping from 'buffer' to 'index'. 
+ for (const LogicalBuffer* buffer : buffers) { + buffer_to_source_indices[buffer].push_back(index); + } + return Status::OK(); + })); // Record all non-distinct indices detected in 'buffer_to_source_indices'. for (const auto& buff_to_src : buffer_to_source_indices) { @@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( FlatMap* shared_copies) { const HloInstruction* init_hlo = while_hlo->operand(0); const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + + // Set of LogicalBuffers already seen (used to detect non-distinct indices). + FlatSet buffer_set; + ShapeTree copy_overrides(init_hlo->shape()); TF_RETURN_IF_ERROR(points_to.ForEachElement( - [init_hlo, read_only_indices, shared_copies, &copy_overrides]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { + [init_hlo, read_only_indices, shared_copies, &buffer_set, + &copy_overrides](const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { // Look for read-only entry parameters. if (!read_only_indices->element(index)) { return Status::OK(); @@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( if (!is_entry_parameter && !is_constant) { continue; } + // We have found an entry parameter or constant that is read-only in // the while body. These buffers are managed by the caller, and cannot // be aliased with non-parameter buffers. Revert this read-only index, @@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( // Optimization to allow multiple while loops that share the same // read-only entry parameters (or constants) to share a single copy. - // Only unambiguous array-shaped buffers are allowed, to reduce code - // complexity. The shape of the entry parameter must be identical to - // the shape of the init_hlo at this index, to ensure there were no - // intervening bitcast or GTE instructions, which are also hard to - // handle.
+ // Only unambiguous and distinct array-shaped buffers are allowed, to + // reduce code complexity. The shape of the entry parameter must be + // identical to the shape of the init_hlo at this index, to ensure + // there were no intervening bitcast or GTE instructions, which are + // also hard to handle. const Shape& pointee_shape = pointee->shape(); const Shape& init_shape = ShapeUtil::GetSubshape(init_hlo->shape(), index); if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && - ShapeUtil::Equal(pointee_shape, init_shape)) { + ShapeUtil::Equal(pointee_shape, init_shape) && + buffer_set.count(buffer) < 1) { HloInstruction** copy = &(*shared_copies)[pointee]; if (*copy == nullptr) { *copy = @@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( *copy_overrides.mutable_element(index) = *copy; } + // Tracks whether this current buffer is distinct. + buffer_set.insert(buffer); + // We've already reverted the read-only index and handled the // single-copy optimization above, so there's nothing more to do. break; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4a14fc5397b..cb9682392ea 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase { EXPECT_IS_OK(copy_insertion.Run(module).status()); // Verify the points to set of the root of the computation after copy - // insertion contains no constants or parameters. + // insertion contains no constants or parameters, and is distinct and + // non-ambiguous. 
auto points_to_analysis = TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); + const auto& points_to = points_to_analysis->GetPointsToSet( + module->entry_computation()->root_instruction()); + EXPECT_TRUE(points_to.IsDistinct()); + EXPECT_TRUE(!points_to.IsAmbiguous()); + tensorflow::gtl::FlatSet maybe_live_out_buffers = points_to_analysis ->GetPointsToSet(module->entry_computation()->root_instruction()) .CreateFlattenedSet(); + for (const LogicalBuffer* buffer : maybe_live_out_buffers) { EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); @@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest { return builder.Build(); } + // Builds a While body computation with two output tuple elements dependent on + // both input tuple elements. + // + // EX: Body({in0, in1, in2}) + // out0 = Add(in0, 1) + // out1 = in1 + // out2 = in2 + // Tuple(out0, out1, out2) + std::unique_ptr BuildDependentBodyComputation2() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + + // Update the induction variable GTE(0). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + // add0 = Add(in0, 1) + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // data1 = GTE(1). + HloInstruction* data1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + + // data2 = GTE(2). 
+ HloInstruction* data2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); + + // Create output Tuple. + builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); + + return builder.Build(); + } + // Builds a While body computation with read-only tuple element 0. // EX: // Body({in0, in1}) @@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Update data GTE(1). auto data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); @@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Create param instruction to access loop state. const Shape& loop_state_shape = nested ? nested_loop_state_shape_ : loop_state_shape_; + auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); // Update the induction variable GTE(0). @@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { op::Copy(old_init->operand(1)->operand(0))))); } -// Tests while init instruction buffer which interfers with while result buffer. +// Tests while init instruction buffer which interferes with while result +// buffer. // // init_data = Broadcast(...) // add_unrelated = Add(init_data) // takes a reference to cause interference @@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { op::Copy(old_init->operand(1)))); } +// Tests while init instruction buffer which has a non-distinct points-to set: +// +// init = Tuple(Parameter(S32, {}), Parameter(F32, {8}), +// Parameter(F32, {8})) +// +// where the second and third parameters are identical *and* the tuple is shared +// by another while instruction.
+// +// Verifies that the resulting points-to set is distinct in the resulting Tuple +// (non-identical Copys). In other words, verifies that copy sharing does not +// insert identical copies to the resulting tuple. +TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { + auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation()); + // Loop body that outputs tuple comprises two elements dependent on the init + // tuple. + auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2()); + auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + + // Loop init tuple contains two identical parameter buffers. + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_param, data_param, data_param})); + + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + // Two while loops share the same loop init tuple. + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition2, body2, loop_init)); + + module_.AddEntryComputation(builder.Build()); + + auto points_to_analysis = + TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie(); + + // Asserts that the init tuples before copy insertion are non-distinct.
+ ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + + auto old_init1 = while_hlo1->operand(0); + auto old_init2 = while_hlo2->operand(0); + + InsertCopies(&module_); + + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Copy(old_init1->operand(0)), + op::Copy(old_init1->operand(1)), + op::Copy(old_init1->operand(2)))); + + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Copy(old_init2->operand(0)), + op::Copy(old_init2->operand(1)), + op::Copy(old_init2->operand(2)))); + + // Verifies that the init tuples after copy insertion are distinct. + points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie(); + const auto& points_to1 = + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); + EXPECT_TRUE(points_to1.IsDistinct()); + + const auto& points_to2 = + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); + EXPECT_TRUE(points_to2.IsDistinct()); +} + } // namespace } // namespace xla From 7cdcd0cca2a97c45c634f45b0ace0771ce5a5498 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 2 Jun 2017 14:23:14 -0700 Subject: [PATCH 60/72] Filter more op types that don't benefit from constant folding.
PiperOrigin-RevId: 157875168 --- .../grappler/optimizers/constant_folding.cc | 35 +++++++++++-------- .../grappler/optimizers/constant_folding.h | 12 ++----- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ea5bfe164b3..c2df76e4315 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -101,6 +101,11 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) { } } // namespace +ConstantFolding::ConstantFolding() { + ops_to_preserve_ = + std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader"); +} + Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically()); @@ -184,28 +189,19 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { } bool ConstantFolding::IsFoldable(const NodeDef& node) const { - DeviceTypeVector device_types; - auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, - &device_types); - if (!status.ok()) { - return false; - } - // Only fold ops with a CPU implementation available. - if (device_types[0] != DeviceType(DEVICE_CPU)) { - return false; - } - + // Skips nodes that must be preserved, and op_types that don't benefit from + // folding if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; } - - if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) { + std::cmatch match; + if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) { return false; } // Don't fold stateful ops such as TruncatedNormal. 
const OpDef* op_def = nullptr; - status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } @@ -217,6 +213,17 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + DeviceTypeVector device_types; + status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, + &device_types); + if (!status.ok()) { + return false; + } + // Only fold ops with a CPU implementation available. + if (device_types[0] != DeviceType(DEVICE_CPU)) { + return false; + } + // Folding not applicable to ops with no inputs. if (node.input().empty()) { return false; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9689e97a123..cb9729ef1ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ +#include #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" @@ -29,7 +30,7 @@ const char kConstantFoldingConst[] = "ConstantFolding"; // Contant folding optimization for a graph. 
class ConstantFolding : public GraphOptimizer { public: - ConstantFolding() {} + ConstantFolding(); ~ConstantFolding() override {} @@ -66,14 +67,7 @@ class ConstantFolding : public GraphOptimizer { GraphDef graph_; std::unique_ptr node_map_; std::set nodes_to_preserve_; - std::set ops_to_preserve_ = {"Save", - "SaveV2", - "SaveSlices", - "Restore", - "RestoreV2", - "RestoreSlice", - "PlaceholderWithDefault", - "Const"}; + std::regex ops_to_preserve_; }; } // end namespace grappler From 9e25c68ad1a0216a16cfd5e87e5e189b42745508 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 14:29:17 -0700 Subject: [PATCH 61/72] Add loss_only_head to hold additional loss terms for multi_head setup PiperOrigin-RevId: 157875934 --- .../learn/python/learn/estimators/__init__.py | 1 + .../learn/python/learn/estimators/head.py | 98 ++++++++++++++++++- .../python/learn/estimators/head_test.py | 20 +++- 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index a40cbc04490..bba479a00ee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head from tensorflow.contrib.learn.python.learn.estimators.head import Head +from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head diff --git 
a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index e4ef6996d8d..22e89de4c2b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -429,6 +429,23 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +def loss_only_head(loss_fn, head_name=None): + """Creates a Head that contains only loss terms. + + Loss only head holds additional loss terms to be added to other heads and + usually represents additional regularization terms in the objective function. + + Args: + loss_fn: a function that takes no argument and returns a list of + scalar tensors. + head_name: a name for for the head. + + Returns: + An instance of `Head` to hold the additional losses. + """ + return _LossOnlyHead(loss_fn, head_name=head_name) + + def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead): return metrics +class _LossOnlyHead(Head): + """`Head` implementation for additional loss terms. + + This class only holds loss terms unrelated to any other heads (labels), + e.g. regularization. + + Common usage: + This is oftem combine with other heads in a multi head setup. + ```python + head = multi_head([ + head1, head2, loss_only_head('regularizer', regularizer)]) + ``` + """ + + def __init__(self, loss_fn, head_name=None): + self._loss_fn = loss_fn + self.head_name = head_name or "loss_only_head" + + @property + def logits_dimension(self): + return 0 + + def create_model_fn_ops(self, + features, + mode, + labels=None, + train_op_fn=None, + logits=None, + logits_input=None, + scope=None): + """See `_Head.create_model_fn_ops`. + + Args: + features: Not been used. + mode: Estimator's `ModeKeys`. + labels: Labels `Tensor`, or `dict` of same. 
+ train_op_fn: Function that takes a scalar loss and returns an op to + optimize with the loss. + logits: Not been used. + logits_input: Not been used. + scope: Optional scope for variable_scope. If provided, will be passed to + all heads. Most users will want to set this to `None`, so each head + constructs a separate variable_scope according to its `head_name`. + + Returns: + A `ModelFnOps` object. + + Raises: + ValueError: if `mode` is not recognition. + """ + _check_mode_valid(mode) + loss = None + train_op = None + if mode != model_fn.ModeKeys.INFER: + with variable_scope.variable_scope(scope, default_name=self.head_name): + loss = self._loss_fn() + if isinstance(loss, list): + loss = math_ops.add_n(loss) + logging_ops.scalar_summary( + _summary_key(self.head_name, mkey.LOSS), loss) + if mode == model_fn.ModeKeys.TRAIN: + if train_op_fn is None: + raise ValueError("train_op_fn can not be None in TRAIN mode") + with ops.name_scope(None, "train_op", (loss,)): + train_op = train_op_fn(loss) + + return model_fn.ModelFnOps( + mode=mode, + loss=loss, + train_op=train_op, + predictions={}, + eval_metric_ops={}) + + class _MultiHead(Head): """`Head` implementation for multi objective learning. @@ -1525,7 +1616,10 @@ class _MultiHead(Head): if isinstance(logits, dict): head_logits_pairs = [] for head in self._heads: - head_logits_pairs.append((head, logits[head.head_name])) + if isinstance(head, _LossOnlyHead): + head_logits_pairs.append((head, None)) + else: + head_logits_pairs.append((head, logits[head.head_name])) else: # Split logits for each head. 
head_logits_pairs = zip(self._heads, self._split_logits(logits)) @@ -1606,6 +1700,8 @@ class _MultiHead(Head): predictions = {} output_alternatives = {} for head, m in zip(self._heads, all_model_fn_ops): + if isinstance(head, _LossOnlyHead): + continue head_name = head.head_name output_alternatives[head_name] = m.output_alternatives[head_name] for k, v in m.predictions.items(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 012b919d631..25a66748587 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) +class LossOnlyHead(test.TestCase): + + def testNoPredictionsAndNoMetrics(self): + head = head_lib.loss_only_head(lambda: 1, head_name="const") + model_fn_ops = head.create_model_fn_ops( + features={}, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn) + self.assertDictEqual(model_fn_ops.predictions, {}) + self.assertDictEqual(model_fn_ops.eval_metric_ops, {}) + self.assertIsNotNone(model_fn_ops.loss) + with session.Session() as sess: + self.assertEqual(1, sess.run(model_fn_ops.loss)) + + class MultiHeadTest(test.TestCase): def testInvalidHeads(self): @@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase): n_classes=3, label_name="label1", head_name="head1") head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib.multi_head((head1, head2)) + head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const") + head = head_lib.multi_head((head1, head2, head3)) labels = { "label1": (1,), "label2": (1,) @@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase): self.assertIsNone(model_fn_ops.output_alternatives) with session.Session() as sess: - self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) + 
self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3) def testTrain_withHeadWeights(self): head1 = head_lib.multi_class_head( From 8939b8562027189b24e7609b77e17122dc3a21d4 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 2 Jun 2017 14:50:27 -0700 Subject: [PATCH 62/72] [tf.contrib.data] Re-implement IteratorGetNext as an AsyncOpKernel. This prevents the op from consuming an inter-op thread pool thread when blocked, and fixes a potential deadlock when many IteratorGetNext ops are blocked. Fixes #10369. PiperOrigin-RevId: 157878885 --- .../kernel_tests/map_dataset_op_test.py | 3 +- tensorflow/core/kernels/BUILD | 1 + tensorflow/core/kernels/iterator_ops.cc | 63 ++++++++++++------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 2f7f8ebbae8..68cd3623c00 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase): results.append(sess.run(get_next)) except errors.OutOfRangeError: return - threads = [self.checkedThread(target=iterator_thread) for _ in range(8)] + threads = [self.checkedThread(target=iterator_thread) + for _ in range(64)] for t in threads: t.start() for t in threads: diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b00a4e534d6..e8bb7b6688c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5387,6 +5387,7 @@ tf_kernel_library( srcs = ["iterator_ops.cc"], deps = [ ":dataset", + ":ops_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index 68d1d292252..b91aae6b077 100644 --- 
a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -18,7 +18,10 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -282,38 +285,54 @@ class OneShotIteratorOp : public OpKernel { IteratorResource* iterator_resource_ = nullptr; }; -class IteratorGetNextOp : public OpKernel { +class IteratorGetNextOp : public AsyncOpKernel { public: - explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit IteratorGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("iterator_get_next_thread_", + SanitizeThreadSuffix(def().name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} - // TODO(mrry): Convert this to an async op, because - // `iterator->GetNext()` could trigger long-running operations - // (e.g. a QueueDequeue or a remote read). - void Compute(OpKernelContext* ctx) override { + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { IteratorResource* iterator; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); - core::ScopedUnref unref_iterator(iterator); - std::vector components; - bool end_of_sequence; + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. 
+ thread_pool_->Schedule([this, ctx, iterator, done]() { + core::ScopedUnref unref_iterator(iterator); - IteratorContext::Params params; - params.env = ctx->env(); - params.step_id = ctx->step_id(); - params.resource_manager = ctx->resource_manager(); - params.runner = *(ctx->runner()); - IteratorContext iter_ctx(std::move(params)); + std::vector components; + bool end_of_sequence; - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); - OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); + IteratorContext::Params params; + params.env = ctx->env(); + params.step_id = ctx->step_id(); + params.resource_manager = ctx->resource_manager(); + params.runner = *(ctx->runner()); + IteratorContext iter_ctx(std::move(params)); - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + OP_REQUIRES_ASYNC(ctx, !end_of_sequence, + errors::OutOfRange("End of sequence"), done); + + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. + ctx->set_output(i, components[i]); + } + + done(); + }); } + + private: + std::unique_ptr thread_pool_; }; class IteratorDisposeOp : public OpKernel { From a4caeb2ea4ba4229ea8444e8eda32b7dba57658c Mon Sep 17 00:00:00 2001 From: William Chargin Date: Fri, 2 Jun 2017 15:39:53 -0700 Subject: [PATCH 63/72] Extract the graphs dashboard to a plugin This completes the great plugin migration! The graphs plugin is somewhat different from the plugins considered so far. First, it exposes two kinds of data: graph data and run metadata. We elect to put both sources of data under the domain of the graphs plugin for now, because it's not clear that the run metadata would be useful for anything else. 
Second, the graph data really has no use for "tags": a run either has an associated graph or it does not. Thus, we expose an endpoint /data/plugin/graphs/runs that is different in format from the /tags routes exposed by other plugins (it returns just a list instead of a run-to-tag mapping). This change removes a bunch of tests from application_test.py. The tests cover the compression behavior of the graph endpoint, but the graph endpoint doesn't have any special logic in the way of compression. Thus, the tests are, apparently, testing that werkzeug (or whatever is relevant here) provides good compression defaults. This isn't necessarily a bad idea, but it shouldn't be coupled to the graph tests. To get test data that includes run metadata, you can run this script: https://raw.githubusercontent.com/tensorflow/tensorflow/326942394e69074d50d5889218a24c9371eff259/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py PiperOrigin-RevId: 157884714 --- tensorflow/BUILD | 1 + tensorflow/contrib/cmake/tf_python.cmake | 1 + tensorflow/tensorboard/BUILD | 1 + tensorflow/tensorboard/backend/BUILD | 1 - tensorflow/tensorboard/backend/application.py | 60 +------- .../tensorboard/backend/application_test.py | 72 +-------- .../components/tf_backend/backend.ts | 43 ++++-- .../components/tf_backend/router.ts | 26 ---- .../tf_backend/test/backendTests.ts | 2 +- .../tf-graph-dashboard.html | 6 +- tensorflow/tensorboard/http_api.md | 27 ++-- tensorflow/tensorboard/plugins/graphs/BUILD | 51 +++++++ .../plugins/graphs/graphs_plugin.py | 140 +++++++++++++++++ .../plugins/graphs/graphs_plugin_test.py | 142 ++++++++++++++++++ tensorflow/tensorboard/tensorboard.py | 2 + 15 files changed, 396 insertions(+), 179 deletions(-) create mode 100644 tensorflow/tensorboard/plugins/graphs/BUILD create mode 100644 tensorflow/tensorboard/plugins/graphs/graphs_plugin.py create mode 100644 tensorflow/tensorboard/plugins/graphs/graphs_plugin_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index ce1387ba43c..0eea54a6efc 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -381,6 +381,7 @@ filegroup( "//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins/audio:all_files", "//tensorflow/tensorboard/plugins/distributions:all_files", + "//tensorflow/tensorboard/plugins/graphs:all_files", "//tensorflow/tensorboard/plugins/histograms:all_files", "//tensorflow/tensorboard/plugins/images:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 95dbefc37ab..74716cd900d 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -231,6 +231,7 @@ add_python_module("tensorflow/tensorboard/backend/event_processing") add_python_module("tensorflow/tensorboard/plugins") add_python_module("tensorflow/tensorboard/plugins/audio") add_python_module("tensorflow/tensorboard/plugins/distributions") +add_python_module("tensorflow/tensorboard/plugins/graphs") add_python_module("tensorflow/tensorboard/plugins/histograms") add_python_module("tensorflow/tensorboard/plugins/images") add_python_module("tensorflow/tensorboard/plugins/projector") diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index a8a4fb16614..1eb5b124157 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -15,6 +15,7 @@ py_binary( "//tensorflow/tensorboard/backend/event_processing:event_file_inspector", "//tensorflow/tensorboard/plugins/audio:audio_plugin", "//tensorflow/tensorboard/plugins/distributions:distributions_plugin", + "//tensorflow/tensorboard/plugins/graphs:graphs_plugin", "//tensorflow/tensorboard/plugins/histograms:histograms_plugin", "//tensorflow/tensorboard/plugins/images:images_plugin", "//tensorflow/tensorboard/plugins/projector:projector_plugin", diff --git a/tensorflow/tensorboard/backend/BUILD b/tensorflow/tensorboard/backend/BUILD index 
3b5ce4c6e3e..c7f22b1b6ab 100644 --- a/tensorflow/tensorboard/backend/BUILD +++ b/tensorflow/tensorboard/backend/BUILD @@ -63,7 +63,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":http_util", - ":process_graph", "//tensorflow:tensorflow_py", "//tensorflow/tensorboard/backend/event_processing:event_accumulator", "//tensorflow/tensorboard/backend/event_processing:event_multiplexer", diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py index 46f081a67c9..9c492e7dd39 100644 --- a/tensorflow/tensorboard/backend/application.py +++ b/tensorflow/tensorboard/backend/application.py @@ -33,7 +33,6 @@ import tensorflow as tf from werkzeug import wrappers from tensorflow.tensorboard.backend import http_util -from tensorflow.tensorboard.backend import process_graph from tensorflow.tensorboard.backend.event_processing import event_accumulator from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -57,8 +56,10 @@ DEFAULT_SIZE_GUIDANCE = { _MIGRATED_DATA_KEYS = frozenset(( 'audio', 'distributions', + 'graph', 'histograms', 'images', + 'run_metadata', 'scalars', )) @@ -67,8 +68,6 @@ LOGDIR_ROUTE = '/logdir' RUNS_ROUTE = '/runs' PLUGIN_PREFIX = '/plugin' PLUGINS_LISTING_ROUTE = '/plugins_listing' -GRAPH_ROUTE = '/' + event_accumulator.GRAPH -RUN_METADATA_ROUTE = '/' + event_accumulator.RUN_METADATA TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] # Slashes in a plugin name could throw the router for a loop. An empty @@ -146,16 +145,12 @@ class TensorBoardWSGIApp(object): reload_multiplexer(self._multiplexer, path_to_run) self.data_applications = { - DATA_PREFIX + GRAPH_ROUTE: - self._serve_graph, DATA_PREFIX + LOGDIR_ROUTE: self._serve_logdir, # TODO(chizeng): Delete this RPC once we have skylark rules that obviate # the need for the frontend to determine which plugins are active. 
DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing, - DATA_PREFIX + RUN_METADATA_ROUTE: - self._serve_run_metadata, DATA_PREFIX + RUNS_ROUTE: self._serve_runs, } @@ -212,57 +207,6 @@ class TensorBoardWSGIApp(object): return http_util.Respond( request, {'logdir': self._logdir}, 'application/json') - @wrappers.Request.application - def _serve_graph(self, request): - """Given a single run, return the graph definition in json format.""" - run = request.args.get('run', None) - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - - try: - graph = self._multiplexer.Graph(run) - except ValueError: - return http_util.Respond( - request, '404 Not Found', 'text/plain; charset=UTF-8', code=404) - - limit_attr_size = request.args.get('limit_attr_size', None) - if limit_attr_size is not None: - try: - limit_attr_size = int(limit_attr_size) - except ValueError: - return http_util.Respond( - request, 'query parameter `limit_attr_size` must be integer', - 'text/plain', 400) - - large_attrs_key = request.args.get('large_attrs_key', None) - try: - process_graph.prepare_graph_for_ui(graph, limit_attr_size, - large_attrs_key) - except ValueError as e: - return http_util.Respond(request, e.message, 'text/plain', 400) - - return http_util.Respond(request, str(graph), 'text/x-protobuf') # pbtxt - - @wrappers.Request.application - def _serve_run_metadata(self, request): - """Given a tag and a TensorFlow run, return the session.run() metadata.""" - tag = request.args.get('tag', None) - run = request.args.get('run', None) - if tag is None: - return http_util.Respond( - request, 'query parameter "tag" is required', 'text/plain', 400) - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - try: - run_metadata = self._multiplexer.RunMetadata(run, tag) - except ValueError: - return http_util.Respond( - request, '404 Not Found', 'text/plain; charset=UTF-8', 
code=404) - return http_util.Respond( - request, str(run_metadata), 'text/x-protobuf') # pbtxt - @wrappers.Request.application def _serve_plugins_listing(self, request): """Serves an object mapping plugin name to whether it is enabled. diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py index f05c9352466..87cfdbc1d8d 100644 --- a/tensorflow/tensorboard/backend/application_test.py +++ b/tensorflow/tensorboard/backend/application_test.py @@ -35,7 +35,6 @@ from six.moves import http_client import tensorflow as tf from werkzeug import serving -from google.protobuf import text_format from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.tensorboard import tensorboard @@ -168,9 +167,7 @@ class TensorboardServerTest(tf.test.TestCase): { 'run1': { # if only_use_meta_graph, the graph is from the metagraph - 'graph': True, 'meta_graph': self._only_use_meta_graph, - 'run_metadata': ['test run'], 'tensors': [], } }) @@ -191,8 +188,7 @@ class TensorboardServerTest(tf.test.TestCase): def testDataPaths_disableAllCaching(self): """Test the format of the /data/runs endpoint.""" - for path in ('/data/runs', '/data/logdir', - '/data/run_metadata?run=run1&tag=test%20run'): + for path in ('/data/runs', '/data/logdir'): connection = http_client.HTTPConnection('localhost', self._server.server_address[1]) connection.request('GET', path) @@ -202,69 +198,11 @@ class TensorboardServerTest(tf.test.TestCase): response.read() connection.close() - def testGraph(self): - """Test retrieving the graph definition.""" - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs') - self.assertEqual(response.status, 200) - graph_pbtxt = response.read() - # Parse the graph from pbtxt into a graph message. 
- graph = tf.GraphDef() - graph = text_format.Parse(graph_pbtxt, graph) - self.assertEqual(len(graph.node), 2) - self.assertEqual(graph.node[0].name, 'a') - self.assertEqual(graph.node[1].name, 'b') - # Make sure the second node has an attribute that was filtered out because - # it was too large and was added to the "too large" attributes list. - self.assertEqual(list(graph.node[1].attr.keys()), ['_very_large_attrs']) - self.assertEqual(graph.node[1].attr['_very_large_attrs'].list.s, - [b'very_large_attr']) - - def testAcceptGzip_compressesResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': 'gzip'}) - self.assertEqual(response.status, 200) - self.assertEqual(response.getheader('Content-Encoding'), 'gzip') - pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read() - graph = text_format.Parse(pbtxt, tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testAcceptAnyEncoding_compressesResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': '*'}) - self.assertEqual(response.status, 200) - self.assertEqual(response.getheader('Content-Encoding'), 'gzip') - pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read() - graph = text_format.Parse(pbtxt, tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testAcceptDoodleEncoding_doesNotCompressResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': 'doodle'}) - self.assertEqual(response.status, 200) - self.assertIsNone(response.getheader('Content-Encoding')) - graph = text_format.Parse(response.read(), tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testRunMetadata(self): - """Test retrieving the run metadata information.""" - response = self._get('/data/run_metadata?run=run1&tag=test%20run') - 
self.assertEqual(response.status, 200) - run_metadata_pbtxt = response.read() - # Parse from pbtxt into a message. - run_metadata = tf.RunMetadata() - text_format.Parse(run_metadata_pbtxt, run_metadata) - self.assertEqual(len(run_metadata.step_stats.dev_stats), 1) - self.assertEqual(run_metadata.step_stats.dev_stats[0].device, 'test device') - def _GenerateTestData(self): """Generates the test data directory. The test data has a single run named run1 which contains: - - a graph definition + - a graph definition and metagraph definition Returns: temp_dir: The directory the test data is generated under. @@ -290,12 +228,6 @@ class TensorboardServerTest(tf.test.TestCase): else: writer.add_graph(graph_def) - # Add a simple run metadata event. - run_metadata = tf.RunMetadata() - device_stats = run_metadata.step_stats.dev_stats.add() - device_stats.device = 'test device' - writer.add_run_metadata(run_metadata, 'test run') - writer.flush() writer.close() diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index 5d048311021..93ea811eda7 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -16,7 +16,7 @@ limitations under the License. import {compareTagNames} from '../vz-sorting/sorting'; import {RequestManager} from './requestManager'; import {Router} from './router'; -import {demoify} from './urlPathHelpers'; +import {demoify, queryEncoder} from './urlPathHelpers'; export interface RunEnumeration { histograms: string[]; @@ -199,16 +199,16 @@ export class Backend { * Return a promise showing list of runs that contain graphs. */ public graphRuns(): Promise { - return this.runs().then((x) => { - return _.keys(x).filter((k) => x[k].graph); - }); + return this.requestManager.request( + this.router.pluginRoute('graphs', '/runs')); } /** * Return a promise showing the Run-to-Tag mapping for run_metadata objects. 
*/ - public runMetadataRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'run_metadata')); + public runMetadataTags(): Promise { + return this.requestManager.request( + this.router.pluginRoute('graphs', '/run_metadata_tags')); } @@ -233,11 +233,25 @@ export class Backend { } /** - * Return a promise of a graph string from the backend. + * Return a URL to fetch a graph (cf. method 'graph'). */ - public graph(tag: string, limitAttrSize?: number, largeAttrKeys?: string): + public graphUrl(run: string, limitAttrSize?: number, largeAttrsKey?: string): + string { + const demoMode = this.router.isDemoMode(); + const base = this.router.pluginRoute('graphs', '/graph'); + const optional = (p) => (p != null && !demoMode || undefined) && p; + const parameters = { + 'run': run, + 'limit_attr_size': optional(limitAttrSize), + 'large_attrs_key': optional(largeAttrsKey), + }; + const extension = demoMode ? '.pbtxt' : ''; + return base + queryEncoder(parameters) + extension; + } + + public graph(run: string, limitAttrSize?: number, largeAttrsKey?: string): Promise { - const url = this.router.graph(tag, limitAttrSize, largeAttrKeys); + const url = this.graphUrl(run, limitAttrSize, largeAttrsKey); return this.requestManager.request(url); } @@ -288,7 +302,7 @@ export class Backend { Promise> { let p: Promise[]>; const url = - (this.router.pluginRunTagRoute('histograms', '/histograms')(tag, run)); + this.router.pluginRunTagRoute('histograms', '/histograms')(tag, run); p = this.requestManager.request(url); return p.then(map(detupler(createHistogram))).then(function(histos) { // Get the minimum and maximum values across all histograms so that the @@ -326,11 +340,18 @@ export class Backend { return p.then(map(this.createAudio.bind(this))); } + /** + * Returns the url for the RunMetadata for the given run/tag. 
+ */ + public runMetadataUrl(tag: string, run: string): string { + return this.router.pluginRunTagRoute('graphs', '/run_metadata')(tag, run); + } + /** * Returns a promise to load the string RunMetadata for given run/tag. */ public runMetadata(tag: string, run: string): Promise { - const url = this.router.runMetadata(tag, run); + const url = this.runMetadataUrl(tag, run); return this.requestManager.request(url); } diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts index 115634be125..ad8bface57f 100644 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ b/tensorflow/tensorboard/components/tf_backend/router.ts @@ -21,10 +21,6 @@ export interface Router { logdir: () => string; runs: () => string; isDemoMode: () => boolean; - graph: - (run: string, limit_attr_size?: number, - large_attrs_key?: string) => string; - runMetadata: RunTagUrlFn; textRuns: () => string; text: RunTagUrlFn; healthPills: () => string; @@ -54,26 +50,6 @@ export function router(dataDir = 'data', demoMode = false): Router { return url; }; } - function graphUrl( - run: string, limit_attr_size?: number, large_attrs_key?: string) { - let query_params = [['run', clean(run)]]; - if (limit_attr_size != null && !demoMode) { - query_params.push(['limit_attr_size', String(limit_attr_size)]); - } - if (large_attrs_key != null && !demoMode) { - query_params.push(['large_attrs_key', large_attrs_key]); - } - let query = query_params - .map(param => { - return param[0] + '=' + encodeURIComponent(param[1]); - }) - .join('&'); - var url = dataDir + '/graph' + clean('?' + query); - if (demoMode) { - url += '.pbtxt'; - } - return url; - } function pluginRoute(pluginName: string, route: string): string { return `${dataDir}/plugin/${pluginName}${route}`; } @@ -86,8 +62,6 @@ export function router(dataDir = 'data', demoMode = false): Router { logdir: () => dataDir + '/logdir', runs: () => dataDir + '/runs' + (demoMode ? 
'.json' : ''), isDemoMode: () => demoMode, - graph: graphUrl, - runMetadata: standardRoute('run_metadata', '.pbtxt'), healthPills: () => dataDir + '/plugin/debugger/health_pills', textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? '.json' : ''), text: standardRoute('plugin/text/text'), diff --git a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts index 0ef58157aef..2ca04e07864 100644 --- a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts +++ b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts @@ -139,7 +139,7 @@ describe('backend tests', () => { chai.assert.deepEqual(x, audio); next(); }); - backend.runMetadataRuns().then((x) => { + backend.runMetadataTags().then((x) => { chai.assert.deepEqual(x, runMetadata); next(); }); diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html index 0ce1e338675..7e0ce2647bd 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html @@ -189,20 +189,20 @@ Polymer({ } // Set this to true so we only initialize once. this._initialized = true; - Promise.all([backend.graphRuns(), backend.runMetadataRuns()]) + Promise.all([backend.graphRuns(), backend.runMetadataTags()]) .then(function(result) { var runsWithGraph = result[0].sort(compareTagNames); var runToMetadata = result[1]; var datasets = _.map(runsWithGraph, function(runName) { return { name: runName, - path: backend.router.graph( + path: backend.graphUrl( runName, tf.graph.LIMIT_ATTR_SIZE, tf.graph.LARGE_ATTRS_KEY), runMetadata: runToMetadata[runName] ? 
_.map( runToMetadata[runName].sort(compareTagNames), function(tag) { return { tag: tag, - path: backend.router.runMetadata(tag, runName) + path: backend.runMetadataUrl(tag, runName) }; }, this) : [] }; diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md index b015539cf52..c62de0376d2 100644 --- a/tensorflow/tensorboard/http_api.md +++ b/tensorflow/tensorboard/http_api.md @@ -55,13 +55,9 @@ all of the data available from the TensorBoard server. Here is an example: { "train_run": { - "graph": true, "firstEventTimestamp": 123456.789 - "run_metadata": ["forward prop", "inference"] }, "eval": { - "graph": false, - "run_metadata": [] } } @@ -81,6 +77,8 @@ and will not appear in the output from this route: - `scalars` - `compressedHistograms`, moved to `distributions` - `histograms` + - `graph`, as `/data/plugin/graphs/runs` + - `run_metadata`, as `/data/plugin/graphs/run_metadata_tags` ## `/data/plugin/scalars/tags` @@ -296,11 +294,19 @@ tags present in the corresponding run. Here is an example: Note that runs without any audio tags are included as keys with value the empty array. -## `/data/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` +## `/data/plugin/graphs/runs` -Returns the graph definition for the given run in gzipped pbtxt format. The -graph is composed of a list of nodes, where each node is a specific TensorFlow -operation which takes as inputs other nodes (operations). +Returns a list of runs that have associated graphs. + +For example: + + ["train"] + +## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` + +Returns the graph definition for the given run in pbtxt format. The +graph is composed of a list of nodes, where each node is a specific +TensorFlow operation which takes as inputs other nodes (operations). The query parameters `limit_attr_size` and `large_attrs_key` are optional. @@ -313,7 +319,10 @@ attributes that are too large. 
The value of this key (list of strings) should be used by the client in order to determine which attributes have been filtered. Must be specified if `limit_attr_size` is specified. -For the query `/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large`, +For the query + + /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, + here is an example pbtxt response of a graph with 3 nodes, where the second node had two large attributes "a" and "b" that were filtered out (size > 1024): diff --git a/tensorflow/tensorboard/plugins/graphs/BUILD b/tensorflow/tensorboard/plugins/graphs/BUILD new file mode 100644 index 00000000000..a6feb08c814 --- /dev/null +++ b/tensorflow/tensorboard/plugins/graphs/BUILD @@ -0,0 +1,51 @@ +# Description: +# TensorBoard plugin for graphs + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +## Graphs Plugin ## +py_library( + name = "graphs_plugin", + srcs = ["graphs_plugin.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/tensorboard/backend:http_util", + "//tensorflow/tensorboard/backend:process_graph", + "//tensorflow/tensorboard/backend/event_processing:event_accumulator", + "//tensorflow/tensorboard/plugins:base_plugin", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +py_test( + name = "graphs_plugin_test", + size = "small", + srcs = ["graphs_plugin_test.py"], + main = "graphs_plugin_test.py", + srcs_version = "PY2AND3", + deps = [ + ":graphs_plugin", + "//tensorflow:tensorflow_py", + "//tensorflow/tensorboard/backend:application", + "//tensorflow/tensorboard/backend/event_processing:event_multiplexer", + "@org_pocoo_werkzeug//:werkzeug", + "@six_archive//:six", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + visibility = ["//tensorflow:__pkg__"], 
+) diff --git a/tensorflow/tensorboard/plugins/graphs/graphs_plugin.py b/tensorflow/tensorboard/plugins/graphs/graphs_plugin.py new file mode 100644 index 00000000000..7fdbf9903db --- /dev/null +++ b/tensorflow/tensorboard/plugins/graphs/graphs_plugin.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The TensorBoard Graphs plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from werkzeug import wrappers + +from tensorflow.tensorboard.backend import http_util +from tensorflow.tensorboard.backend import process_graph +from tensorflow.tensorboard.backend.event_processing import event_accumulator +from tensorflow.tensorboard.plugins import base_plugin + +_PLUGIN_PREFIX_ROUTE = 'graphs' + + +class GraphsPlugin(base_plugin.TBPlugin): + """Graphs Plugin for TensorBoard.""" + + plugin_name = _PLUGIN_PREFIX_ROUTE + + def get_plugin_apps(self, multiplexer, unused_logdir): + self._multiplexer = multiplexer + return { + '/graph': self.graph_route, + '/runs': self.runs_route, + '/run_metadata': self.run_metadata_route, + '/run_metadata_tags': self.run_metadata_tags_route, + } + + def is_active(self): + """The graphs plugin is active iff any run has a graph.""" + return bool(self.index_impl()) + + def index_impl(self): + """Returns a list of 
all runs that have a graph.""" + return [run_name + for (run_name, run_data) in self._multiplexer.Runs().items() + if run_data.get(event_accumulator.GRAPH)] + + def run_metadata_index_impl(self): + """Returns a run-to-tag mapping for metadata.""" + return { + run_name: run_data[event_accumulator.RUN_METADATA] + for (run_name, run_data) in self._multiplexer.Runs().items() + if event_accumulator.RUN_METADATA in run_data + } + + def graph_impl(self, run, limit_attr_size=None, large_attrs_key=None): + """Result of the form `(body, mime_type)`, or `None` if no graph exists.""" + try: + graph = self._multiplexer.Graph(run) + except ValueError: + return None + # This next line might raise a ValueError if the limit parameters + # are invalid (size is negative, size present but key absent, etc.). + process_graph.prepare_graph_for_ui(graph, limit_attr_size, large_attrs_key) + return (str(graph), 'text/x-protobuf') # pbtxt + + def run_metadata_impl(self, run, tag): + """Result of the form `(body, mime_type)`, or `None` if no data exists.""" + try: + run_metadata = self._multiplexer.RunMetadata(run, tag) + except ValueError: + return None + return (str(run_metadata), 'text/x-protobuf') # pbtxt + + @wrappers.Request.application + def runs_route(self, request): + index = self.index_impl() + return http_util.Respond(request, index, 'application/json') + + @wrappers.Request.application + def run_metadata_tags_route(self, request): + index = self.run_metadata_index_impl() + return http_util.Respond(request, index, 'application/json') + + @wrappers.Request.application + def graph_route(self, request): + """Given a single run, return the graph definition in protobuf format.""" + run = request.args.get('run') + if run is None: + return http_util.Respond( + request, 'query parameter "run" is required', 'text/plain', 400) + + limit_attr_size = request.args.get('limit_attr_size', None) + if limit_attr_size is not None: + try: + limit_attr_size = int(limit_attr_size) + except ValueError: 
+ return http_util.Respond( + request, 'query parameter `limit_attr_size` must be an integer', + 'text/plain', 400) + + large_attrs_key = request.args.get('large_attrs_key', None) + + try: + result = self.graph_impl(run, limit_attr_size, large_attrs_key) + except ValueError as e: + return http_util.Respond(request, e.message, 'text/plain', code=400) + else: + if result is not None: + (body, mime_type) = result # pylint: disable=unpacking-non-sequence + return http_util.Respond(request, body, mime_type) + else: + return http_util.Respond(request, '404 Not Found', 'text/plain', + code=404) + + @wrappers.Request.application + def run_metadata_route(self, request): + """Given a tag and a run, return the session.run() metadata.""" + tag = request.args.get('tag') + run = request.args.get('run') + if tag is None: + return http_util.Respond( + request, 'query parameter "tag" is required', 'text/plain', 400) + if run is None: + return http_util.Respond( + request, 'query parameter "run" is required', 'text/plain', 400) + result = self.run_metadata_impl(run, tag) + if result is not None: + (body, mime_type) = result # pylint: disable=unpacking-non-sequence + return http_util.Respond(request, body, mime_type) + else: + return http_util.Respond(request, '404 Not Found', 'text/plain', + code=404) diff --git a/tensorflow/tensorboard/plugins/graphs/graphs_plugin_test.py b/tensorflow/tensorboard/plugins/graphs/graphs_plugin_test.py new file mode 100644 index 00000000000..db4d0cb1b3c --- /dev/null +++ b/tensorflow/tensorboard/plugins/graphs/graphs_plugin_test.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration tests for the Graphs Plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import os.path + +import tensorflow as tf + +from google.protobuf import text_format +from tensorflow.tensorboard.backend.event_processing import event_multiplexer +from tensorflow.tensorboard.plugins.graphs import graphs_plugin + + +class GraphsPluginTest(tf.test.TestCase): + + _RUN_WITH_GRAPH = '_RUN_WITH_GRAPH' + _RUN_WITHOUT_GRAPH = '_RUN_WITHOUT_GRAPH' + + _METADATA_TAG = 'secret-stats' + _MESSAGE_PREFIX_LENGTH_LOWER_BOUND = 1024 + + def generate_run(self, run_name, include_graph): + """Create a run with a text summary, metadata, and optionally a graph.""" + tf.reset_default_graph() + k1 = tf.constant(math.pi, name='k1') + k2 = tf.constant(math.e, name='k2') + result = (k1 ** k2) - k1 + expected = tf.constant(20.0, name='expected') + error = tf.abs(result - expected, name='error') + message_prefix_value = 'error ' * 1000 + true_length = len(message_prefix_value) + assert true_length > self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, true_length + message_prefix = tf.constant(message_prefix_value, name='message_prefix') + error_message = tf.string_join([message_prefix, + tf.as_string(error, name='error_string')], + name='error_message') + summary_message = tf.summary.text('summary_message', error_message) + + sess = tf.Session() + writer = tf.summary.FileWriter(os.path.join(self.logdir, run_name)) + if 
include_graph: + writer.add_graph(sess.graph) + options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + s = sess.run(summary_message, options=options, run_metadata=run_metadata) + writer.add_summary(s) + writer.add_run_metadata(run_metadata, self._METADATA_TAG) + writer.close() + + def set_up_with_runs(self, with_graph=True, without_graph=True): + self.logdir = self.get_temp_dir() + if with_graph: + self.generate_run(self._RUN_WITH_GRAPH, include_graph=True) + if without_graph: + self.generate_run(self._RUN_WITHOUT_GRAPH, include_graph=False) + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + self.plugin = graphs_plugin.GraphsPlugin() + self.plugin.get_plugin_apps(multiplexer, None) + + def test_index(self): + self.set_up_with_runs() + self.assertItemsEqual([self._RUN_WITH_GRAPH], self.plugin.index_impl()) + + def test_run_metadata_index(self): + self.set_up_with_runs() + self.assertDictEqual({ + self._RUN_WITH_GRAPH: [self._METADATA_TAG], + self._RUN_WITHOUT_GRAPH: [self._METADATA_TAG], + }, self.plugin.run_metadata_index_impl()) + + def _get_graph(self, *args, **kwargs): + """Set up runs, then fetch and return the graph as a proto.""" + self.set_up_with_runs() + (graph_pbtxt, mime_type) = self.plugin.graph_impl( + self._RUN_WITH_GRAPH, *args, **kwargs) + self.assertEqual(mime_type, 'text/x-protobuf') + return text_format.Parse(graph_pbtxt, tf.GraphDef()) + + def test_graph_simple(self): + graph = self._get_graph() + node_names = set(node.name for node in graph.node) + self.assertEqual({'k1', 'k2', 'pow', 'sub', 'expected', 'sub_1', 'error', + 'message_prefix', 'error_string', 'error_message', + 'summary_message'}, + node_names) + + def test_graph_large_attrs(self): + key = 'o---;;-;' + graph = self._get_graph( + limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, + large_attrs_key=key) + large_attrs = { + node.name: list(node.attr[key].list.s) 
+ for node in graph.node + if key in node.attr + } + self.assertEqual({'message_prefix': [b'value']}, + large_attrs) + + def test_run_metadata(self): + self.set_up_with_runs() + (metadata_pbtxt, mime_type) = self.plugin.run_metadata_impl( + self._RUN_WITH_GRAPH, self._METADATA_TAG) + self.assertEqual(mime_type, 'text/x-protobuf') + text_format.Parse(metadata_pbtxt, tf.RunMetadata()) + # If it parses, we're happy. + + def test_is_active_with_graph(self): + self.set_up_with_runs(with_graph=True, without_graph=False) + self.assertTrue(self.plugin.is_active()) + + def test_is_active_without_graph(self): + self.set_up_with_runs(with_graph=False, without_graph=True) + self.assertFalse(self.plugin.is_active()) + + def test_is_active_with_both(self): + self.set_up_with_runs(with_graph=True, without_graph=True) + self.assertTrue(self.plugin.is_active()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py index 70830b9a8c8..3665d02ff55 100644 --- a/tensorflow/tensorboard/tensorboard.py +++ b/tensorflow/tensorboard/tensorboard.py @@ -34,6 +34,7 @@ from tensorflow.tensorboard.backend import application from tensorflow.tensorboard.backend.event_processing import event_file_inspector as efi from tensorflow.tensorboard.plugins.audio import audio_plugin from tensorflow.tensorboard.plugins.distributions import distributions_plugin +from tensorflow.tensorboard.plugins.graphs import graphs_plugin from tensorflow.tensorboard.plugins.histograms import histograms_plugin from tensorflow.tensorboard.plugins.images import images_plugin from tensorflow.tensorboard.plugins.projector import projector_plugin @@ -208,6 +209,7 @@ def main(unused_argv=None): scalars_plugin.ScalarsPlugin(), images_plugin.ImagesPlugin(), audio_plugin.AudioPlugin(), + graphs_plugin.GraphsPlugin(), distributions_plugin.DistributionsPlugin(), histograms_plugin.HistogramsPlugin(), projector_plugin.ProjectorPlugin(), From 
0c92dada6a0790d4c0cbd54ce4c801b1940dc4ed Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 16:06:51 -0700 Subject: [PATCH 64/72] Use inplace Cholesky factorization and solves to speed up and reduce memory usage in matrix_solve_ls. Check succes before copying outputs in cholesky_op. PiperOrigin-RevId: 157887564 --- tensorflow/core/kernels/cholesky_op.cc | 6 +++--- tensorflow/core/kernels/matrix_solve_ls_op.cc | 13 +++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index 5c7102f6f67..755ce7c43bd 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -64,11 +64,11 @@ class CholeskyOp : public LinearAlgebraOp { Eigen::Matrix> llt_decomposition(input); - // Output the lower triangular in a dense form. - outputs->at(0) = llt_decomposition.matrixL(); - OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success, errors::InvalidArgument(kErrMsg)); + + // Output the lower triangular in a dense form. + outputs->at(0) = llt_decomposition.matrixL(); } }; diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc index 11e7c94faf3..381a5ec7b9d 100644 --- a/tensorflow/core/kernels/matrix_solve_ls_op.cc +++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc @@ -105,18 +105,19 @@ class MatrixSolveLsOp : public LinearAlgebraOp { // using Cholesky decomposition. Matrix gramian(cols, cols); gramian.template triangularView() = - matrix.transpose() * matrix; + matrix.adjoint() * matrix; if (l2_regularizer > 0) { gramian += (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal(); } - const Eigen::LLT llt(gramian); + const Eigen::LLT, Eigen::Lower> llt(gramian); OP_REQUIRES( context, llt.info() == Eigen::Success, errors::InvalidArgument("Input matrix was rank deficient or " "ill-conditioned. 
Try setting fast=False " "or provide a larger l2_regularizer > 0.")); - outputs->at(0) = llt.solve(matrix.transpose() * rhs); + outputs->at(0).noalias() = matrix.adjoint() * rhs; + llt.solveInPlace(outputs->at(0)); } else { // Underdetermined case (rows < cols): Solves the minimum-norm problem // min ||X||_F^2 s.t. A*X = RHS @@ -125,18 +126,18 @@ class MatrixSolveLsOp : public LinearAlgebraOp { // using Cholesky decomposition. Matrix gramian(rows, rows); gramian.template triangularView() = - matrix * matrix.transpose(); + matrix * matrix.adjoint(); if (l2_regularizer > 0) { gramian += (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal(); } - const Eigen::LLT llt(gramian); + const Eigen::LLT, Eigen::Lower> llt(gramian); OP_REQUIRES( context, llt.info() == Eigen::Success, errors::InvalidArgument("Input matrix was rank deficient or " "ill-conditioned. Try setting fast=False " "or provide an l2_regularizer > 0.")); - outputs->at(0) = matrix.transpose() * llt.solve(rhs); + outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs); } } else { // Use complete orthogonal decomposition which is backwards stable and From 9b8f6113b7894acd07720e55f1cc6a33a1dc4b53 Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Fri, 2 Jun 2017 16:28:10 -0700 Subject: [PATCH 65/72] tensor_bundle: fix that the read path forgets to cache file handles. In a case where a reader is geographically far from the file, this change achieves a speedup of end-to-end checkpoint restore by 5.8x. 
PiperOrigin-RevId: 157889659 --- .../core/util/tensor_bundle/tensor_bundle.cc | 20 +++++++++++++------ .../core/util/tensor_bundle/tensor_bundle.h | 1 + 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 334444a4a22..b495bc31b1f 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -640,6 +640,12 @@ BundleReader::~BundleReader() { delete metadata_; delete iter_; delete table_; + // InputBuffer does not own the underlying RandomAccessFile. + for (auto pair : data_) { + if (pair.second->file() != nullptr) { + delete pair.second->file(); + } + } gtl::STLDeleteValues(&data_); gtl::STLDeleteValues(&tensor_slices_); } @@ -694,14 +700,16 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { } } - // Open the data file if not opened it. - std::unique_ptr file = nullptr; - std::unique_ptr buffered_file(data_[entry.shard_id()]); + // Open the data file if it has not been opened. + io::InputBuffer* buffered_file = data_[entry.shard_id()]; if (buffered_file == nullptr) { + std::unique_ptr file = nullptr; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile( DataFilename(prefix_, entry.shard_id(), num_shards_), &file)); - buffered_file.reset( - new io::InputBuffer(file.get(), 256 << 10 /* 256KB buffer */)); + buffered_file = + new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */); + // The InputBuffer and RandomAccessFile objects are both released in dtor. + data_[entry.shard_id()] = buffered_file; } CHECK(buffered_file != nullptr); @@ -720,7 +728,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { // Relies on io::InputBuffer's buffering, because we issue many neighboring // reads for a single string tensor. 
TF_RETURN_IF_ERROR(ReadStringTensor( - buffered_file.get(), ret->NumElements(), entry.offset(), entry.size(), + buffered_file, ret->NumElements(), entry.offset(), entry.size(), GetStringBackingBuffer(*ret), &actual_crc32c)); } if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index b0bddf7e423..8562f89d920 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -273,6 +273,7 @@ class BundleReader { RandomAccessFile* metadata_; // Owned. table::Table* table_; table::Iterator* iter_; + // Owned the InputBuffer objects and their underlying RandomAccessFile's. std::unordered_map data_; // Maps each partitioned tensor's key to its stored slices (represented in a From f37d0ea47b0b4cf0e66676702f48800e03d18655 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Jun 2017 16:53:15 -0700 Subject: [PATCH 66/72] Internal change -- first draft docs PiperOrigin-RevId: 157891937 --- .../docs_src/programmers_guide/embedding.md | 352 ++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 tensorflow/docs_src/programmers_guide/embedding.md diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md new file mode 100644 index 00000000000..975850349f0 --- /dev/null +++ b/tensorflow/docs_src/programmers_guide/embedding.md @@ -0,0 +1,352 @@ +# Embeddings + +[TOC] + +## Introduction + +An embedding is a mapping from discrete objects, such as words, to vectors of +real numbers. 
For example, a 300-dimensional embedding for English words could +include: + +``` +blue: (0.01359, 0.00075997, 0.24608, ..., -0.2524, 1.0048, 0.06259) +blues: (0.01396, 0.11887, -0.48963, ..., 0.033483, -0.10007, 0.1158) +orange: (-0.24776, -0.12359, 0.20986, ..., 0.079717, 0.23865, -0.014213) +oranges: (-0.35609, 0.21854, 0.080944, ..., -0.35413, 0.38511, -0.070976) +``` + +Embeddings let you apply machine learning to discrete inputs. Classifiers, and +neural networks more generally, are designed to work with dense continuous +vectors, where all values contribute to define what an object is. If discrete +objects are naively encoded as discrete atoms, e.g., unique id numbers, they +hinder learning and generalization. One way to think of embeddings is as a way +to transform non-vector objects into useful inputs for machine learning. + +Embeddings are also useful as outputs of machine learning. Because embeddings +map objects to vectors, applications can use similarity in vector space (e.g., +Euclidean distance or the angle between vectors) as a robust and flexible +measure of object similarity. One common use is to find nearest neighbors. +Using the same word embeddings above, for instance, here are the three nearest +neighbors for each word and the corresponding angles (in degrees): + +``` +blue: (red, 47.6°), (yellow, 51.9°), (purple, 52.4°) +blues: (jazz, 53.3°), (folk, 59.1°), (bluegrass, 60.6°) +orange: (yellow, 53.5°), (colored, 58.0°), (bright, 59.9°) +oranges: (apples, 45.3°), (lemons, 48.3°), (mangoes, 50.4°) +``` + +This would tell an application that apples and oranges are in some way more +similar (45.3° apart) than lemons and oranges (48.3° apart). + +## Training an Embedding + +To train word embeddings in TensorFlow, we first need to split the text into +words and assign an integer to every word in the vocabulary. Let us assume that +this has already been done, and that `word_ids` is a vector of these integers. 
+For example, the sentence “I have a cat.” could be split into +`["I", "have", "a", "cat", "."]` and then the corresponding `word_ids` tensor +would have shape `[5]` and consist of 5 integers. To get these word ids +embedded, we need to create the embedding variable and use the `tf.gather` +function as follows: + +``` +word_embeddings = tf.get_variable("word_embeddings", + [vocabulary_size, embedding_size]) +embedded_word_ids = tf.gather(word_embeddings, word_ids) +``` + +After this, the tensor `embedded_word_ids` will have shape `[5, embedding_size]` +in our example and contain the embeddings (dense vectors) for each of the 5 +words. The variable `word_embeddings` will be learned and at the end of the +training it will contain the embeddings for all words in the vocabulary. +The embeddings can be trained in many ways, depending on the data available. +For example, one could use a recurrent neural network to predict the next word +from the previous one given a large corpus of sentences, or one could train +two networks to do multi-lingual translation. These methods are described in the +[Vector Representations of Words](../tutorials/word2vec.md) tutorial, but in +all cases there is an embedding variable like above and words are embedded +using `tf.gather`, as shown. + +## Visualizing Embeddings + +TensorBoard has a built-in visualizer, called the Embedding Projector, +for interactive visualization of embeddings. The Embedding Projector will read +the embeddings from your checkpoint file and project them into 3 dimensions using +[principal component analysis](https://en.wikipedia.org/wiki/Principal_component_analysis). +For a visual explanation of PCA, see +[this article](http://setosa.io/ev/principal-component-analysis/). Another +very useful projection you can use is +[t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding). + +If you are working with an embedding, you'll probably want to attach +labels/images to the data points.
You can do this by generating a +[metadata file](#metadata) containing the labels for each point and configuring +the projector either by using our Python API, or manually constructing and +saving a +[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto) +in the same directory as your checkpoint file. + +### Setup + +For in-depth information on how to run TensorBoard and make sure you are +logging all the necessary information, see +[TensorBoard: Visualizing Learning](../get_started/summaries_and_tensorboard.md). + +To visualize your embeddings, there are 3 things you need to do: + +1) Set up a 2D tensor that holds your embedding(s). + +```python +embedding_var = tf.get_variable(....) +``` + +2) Periodically save your model variables in a checkpoint in +LOG_DIR. + +```python +saver = tf.train.Saver() +saver.save(session, os.path.join(LOG_DIR, "model.ckpt"), step) +``` + +3) (Optional) Associate metadata with your embedding. + +If you have any metadata (labels, images) associated with your embedding, you +can tell TensorBoard about it either by directly storing a +[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto) +in the LOG_DIR, or by using our Python API. + +For instance, the following projector_config.pbtxt associates the +word_embedding tensor with metadata stored in $LOG_DIR/metadata.tsv: + +``` +embeddings { + tensor_name: 'word_embedding' + metadata_path: '$LOG_DIR/metadata.tsv' +} +``` + +The same config can be produced programmatically using the following code snippet: + +```python +from tensorflow.contrib.tensorboard.plugins import projector + +# Create randomly initialized embedding weights which will be trained.
+vocabulary_size = 10000 +embedding_size = 200 +embedding_var = tf.get_variable('word_embedding', [vocabulary_size, embedding_size]) + +# Format: tensorflow/tensorboard/plugins/projector/projector_config.proto +config = projector.ProjectorConfig() + +# You can add multiple embeddings. Here we add only one. +embedding = config.embeddings.add() +embedding.tensor_name = embedding_var.name +# Link this tensor to its metadata file (e.g. labels). +embedding.metadata_path = os.path.join(LOG_DIR, 'metadata.tsv') + +# Use the same LOG_DIR where you stored your checkpoint. +summary_writer = tf.summary.FileWriter(LOG_DIR) + +# The next line writes a projector_config.pbtxt in the LOG_DIR. TensorBoard will +# read this file during startup. +projector.visualize_embeddings(summary_writer, config) +``` + +After running your model and training your embeddings, run TensorBoard and point +it to the LOG_DIR of the job. + +```sh +tensorboard --logdir=LOG_DIR +``` + +Then click on the *Embeddings* tab on the top pane +and select the appropriate run (if there is more than one run). + + +### Metadata +Usually embeddings have metadata associated with them (e.g. labels, images). The +metadata should be stored in a separate file outside of the model checkpoint +since the metadata is not a trainable parameter of the model. The format should +be a [TSV file](https://en.wikipedia.org/wiki/Tab-separated_values) +(tab characters shown in red) with the first line containing column headers +(shown in bold) and subsequent lines containing the metadata values: + + +Word\tFrequency
+ Airplane\t345
+ Car\t241
+ ... +
+ +There is no explicit key shared with the main data file; instead, the order in +the metadata file is assumed to match the order in the embedding tensor. In +other words, the first line is the header information and the (i+1)-th line in +the metadata file corresponds to the i-th row of the embedding tensor stored in +the checkpoint. + +Note: If the TSV metadata file has only a single column, then we don’t expect a +header row, and assume each row is the label of the embedding. We include this +exception because it matches the commonly-used "vocab file" format. + +### Images +If you have images associated with your embeddings, you will need to +produce a single image consisting of small thumbnails of each data point. +This is known as the +[sprite image](https://www.google.com/webhp#q=what+is+a+sprite+image). +The sprite should have the same number of rows and columns with thumbnails +stored in row-first order: the first data point placed in the top left and the +last data point in the bottom right: + + + + + + + + + + + + + + + + + +
012
345
67
+ +Note in the example above that the last row doesn't have to be filled. For a +concrete example of a sprite, see +[this sprite image](https://www.tensorflow.org/images/mnist_10k_sprite.png) of 10,000 MNIST digits +(100x100). + +Note: We currently support sprites up to 8192px X 8192px. + +After constructing the sprite, you need to tell the Embedding Projector where +to find it: + + +```python +embedding.sprite.image_path = PATH_TO_SPRITE_IMAGE +# Specify the width and height of a single thumbnail. +embedding.sprite.single_image_dim.extend([w, h]) +``` + +### Interaction + +The Embedding Projector has three panels: + +1. *Data panel* on the top left, where you can choose the run, the embedding + tensor and data columns to color and label points by. +2. *Projections panel* on the bottom left, where you choose the type of + projection (e.g. PCA, t-SNE). +3. *Inspector panel* on the right side, where you can search for particular + points and see a list of nearest neighbors. + +### Projections +The Embedding Projector has three methods of reducing the dimensionality of a +data set: two linear and one nonlinear. Each method can be used to create either +a two- or three-dimensional view. + +**Principal Component Analysis** A straightforward technique for reducing +dimensions is Principal Component Analysis (PCA). The Embedding Projector +computes the top 10 principal components. The menu lets you project those +components onto any combination of two or three. PCA is a linear projection, +often effective at examining global geometry. + +**t-SNE** A popular non-linear dimensionality reduction technique is t-SNE. +The Embedding Projector offers both two- and three-dimensional t-SNE views. +Layout is performed client-side animating every step of the algorithm. Because +t-SNE often preserves some local structure, it is useful for exploring local +neighborhoods and finding clusters. 
Although extremely useful for visualizing +high-dimensional data, t-SNE plots can sometimes be mysterious or misleading. +See this [great article](http://distill.pub/2016/misread-tsne/) for how to use +t-SNE effectively. + +**Custom** You can also construct specialized linear projections based on text +searches for finding meaningful directions in space. To define a projection +axis, enter two search strings or regular expressions. The program computes the +centroids of the sets of points whose labels match these searches, and uses the +difference vector between centroids as a projection axis. + +### Navigation + +To explore a data set, you can navigate the views in either a 2D or a 3D mode, +zooming, rotating, and panning using natural click-and-drag gestures. +Clicking on a point causes the right pane to show an explicit textual list of +nearest neighbors, along with distances to the current point. The +nearest-neighbor points themselves are highlighted on the projection. + +Zooming into the cluster gives some information, but it is sometimes more +helpful to restrict the view to a subset of points and perform projections only +on those points. To do so, you can select points in multiple ways: + +1. After clicking on a point, its nearest neighbors are also selected. +2. After a search, the points matching the query are selected. +3. Enabling selection, clicking on a point and dragging defines a selection + sphere. + +After selecting a set of points, you can isolate those points for +further analysis on their own with the "Isolate Points" button in the Inspector +pane on the right hand side. + + +![Selection of nearest neighbors](https://www.tensorflow.org/images/embedding-nearest-points.png "Selection of nearest neighbors") +*Selection of the nearest neighbors of “important” in a word embedding dataset.* + +The combination of filtering with custom projection can be powerful. 
Below, we filtered +the 100 nearest neighbors of “politics” and projected them onto the +“best” - “worst” vector as an x axis. The y axis is random. + +You can see that on the right side we have “ideas”, “science”, “perspective”, +“journalism” while on the left we have “crisis”, “violence” and “conflict”. + + + + + + + + + + +
+ Custom controls panel + + Custom projection +
+ Custom projection controls. + + Custom projection of neighbors of "politics" onto "best" - "worst" vector. +
+ +### Collaborative Features + +To share your findings, you can use the bookmark panel in the bottom right +corner and save the current state (including computed coordinates of any +projection) as a small file. The Projector can then be pointed to a set of one +or more of these files, producing the panel below. Other users can then walk +through a sequence of bookmarks. + +Bookmark panel + + +## Mini-FAQ + +**Is "embedding" an action or a thing?** +Both. People talk about embedding words in a vector space (action) and about +producing word embeddings (things). Common to both is the notion of embedding +as a mapping from discrete objects to vectors. Creating or applying that +mapping is an action, but the mapping itself is a thing. + +**Are embeddings high-dimensional or low-dimensional?** +It depends. A 300-dimensional vector space of words and phrases, for instance, +is often called low-dimensional (and dense) when compared to the millions of +words and phrases it can contain. But mathematically it is high-dimensional, +displaying many properties that are dramatically different from what our human +intuition has learned about 2- and 3-dimensional spaces. + +**Is an embedding the same as an embedding layer?** +No; an embedding layer is a part of a neural network, but an embedding is a more +general concept. From 675d36be0d405c3680b49bb6e924f0ed0e233df9 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Fri, 2 Jun 2017 17:18:21 -0700 Subject: [PATCH 67/72] Add fused batch norm to tf.layers.
PiperOrigin-RevId: 157893874 --- tensorflow/python/BUILD | 7 +- tensorflow/python/layers/normalization.py | 88 +++++++++++++++++-- .../python/layers/normalization_test.py | 81 +++++++++++++++++ .../tools/api/golden/tensorflow.layers.pbtxt | 2 +- 4 files changed, 168 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 93606ce4ce4..dcce808e97d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3555,13 +3555,11 @@ py_test( ], ) -py_test( +cuda_py_test( name = "layers_normalization_test", size = "small", srcs = ["layers/normalization_test.py"], - main = "layers/normalization_test.py", - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":array_ops", ":client_testlib", ":framework_for_generated_wrappers", @@ -3571,6 +3569,7 @@ py_test( ":variables", "//third_party/py/numpy", ], + main = "layers/normalization_test.py", ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index ea6f55281ed..780d1c2b8e0 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -66,9 +66,6 @@ class BatchNormalization(base.Layer): moving_variance_initializer: Initializer for the moving variance. beta_regularizer: Optional regularizer for the beta weight. gamma_regularizer: Optional regularizer for the gamma weight. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: A string, the name of the layer. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. @@ -82,6 +79,11 @@ class BatchNormalization(base.Layer): and should be neither too small (which would add noise) nor too large (which would give stale estimates). 
Note that `momentum` is still applied to get the means and variances for inference. + fused: if `True`, use a faster, fused implementation based on + nn.fused_batch_norm. If `None`, use the fused implementation if possible. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: A string, the name of the layer. """ def __init__(self, @@ -99,6 +101,7 @@ class BatchNormalization(base.Layer): renorm=False, renorm_clipping=None, renorm_momentum=0.99, + fused=False, trainable=True, name=None, **kwargs): @@ -116,6 +119,10 @@ class BatchNormalization(base.Layer): self.beta_regularizer = beta_regularizer self.gamma_regularizer = gamma_regularizer self.renorm = renorm + self.fused = fused + if self.fused and renorm: + raise ValueError( + 'Batch renorm is currently not supported with fused batch norm.') if renorm: renorm_clipping = renorm_clipping or {} keys = ['rmax', 'rmin', 'dmax'] @@ -130,6 +137,13 @@ class BatchNormalization(base.Layer): if not input_shape.ndims: raise ValueError('Input has undefined rank:', input_shape) ndim = len(input_shape) + # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the + # output back to its original shape accordingly. + if self.fused and ndim != 4: + raise ValueError( + 'Only 4D inputs are currently supported with fused batch norm. ' + 'Consider reshaping the input to 4D and reshape the output back ' + 'to its original shape. 
Got input rank: ', ndim) if self.axis < 0: axis = ndim + self.axis else: @@ -137,6 +151,20 @@ class BatchNormalization(base.Layer): if axis < 0 or axis >= ndim: raise ValueError('Value of `axis` argument ' + str(self.axis) + ' is out of range for input with rank ' + str(ndim)) + + if self.fused is None: + self.fused = not self.renorm and ndim == 4 and axis in [1, 3] + + if self.fused: + if axis == 1: + self._data_format = 'NCHW' + elif axis == 3: + self._data_format = 'NHWC' + else: + raise ValueError( + 'Only axis 1 and 3 are currently supported dimensions for ' + 'fused batch norm. Got `axis` dimension: ', axis) + param_dim = input_shape[axis] if not param_dim.value: raise ValueError('Input has undefined `axis` dimension. Input shape: ', @@ -152,6 +180,8 @@ class BatchNormalization(base.Layer): trainable=True) else: self.beta = None + if self.fused: + self._beta_const = array_ops.constant(0.0, shape=(param_dim,)) if self.scale: self.gamma = self.add_variable(name='gamma', shape=(param_dim,), @@ -160,6 +190,8 @@ class BatchNormalization(base.Layer): trainable=True) else: self.gamma = None + if self.fused: + self._gamma_const = array_ops.constant(1.0, shape=(param_dim,)) # Disable variable partitioning when creating the moving mean and variance partitioner = self._scope.partitioner @@ -205,6 +237,45 @@ class BatchNormalization(base.Layer): self._scope.set_partitioner(partitioner) self.built = True + def _fused_batch_norm(self, inputs, training): + """Returns the output of fused batch norm.""" + beta = self.beta if self.center else self._beta_const + gamma = self.gamma if self.scale else self._gamma_const + + def _fused_batch_norm_training(): + return nn.fused_batch_norm( + inputs, + gamma, + beta, + epsilon=self.epsilon, + data_format=self._data_format) + + def _fused_batch_norm_inference(): + return nn.fused_batch_norm( + inputs, + gamma, + beta, + mean=self.moving_mean, + variance=self.moving_variance, + epsilon=self.epsilon, + is_training=False, + 
data_format=self._data_format) + + output, mean, variance = utils.smart_cond( + training, _fused_batch_norm_training, _fused_batch_norm_inference) + + training_value = utils.constant_value(training) + if training_value is not False: + decay = _smart_select(training, lambda: self.momentum, lambda: 1.) + mean_update = moving_averages.assign_moving_average( + self.moving_mean, mean, decay, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, variance, decay, zero_debias=False) + self.add_update(mean_update, inputs=inputs) + self.add_update(variance_update, inputs=inputs) + + return output + def _renorm_correction_and_moments(self, mean, variance, training): """Returns the correction and update values for renorm.""" stddev = math_ops.sqrt(variance + self.epsilon) @@ -265,6 +336,9 @@ class BatchNormalization(base.Layer): return (r, d, new_mean, new_variance) def call(self, inputs, training=False): + if self.fused: + return self._fused_batch_norm(inputs, training=training) + # First, compute the axes along which to reduce the mean / variance, # as well as the broadcast shape to be used for all parameters. input_shape = inputs.get_shape() @@ -353,7 +427,8 @@ def batch_normalization(inputs, reuse=None, renorm=False, renorm_clipping=None, - renorm_momentum=0.99): + renorm_momentum=0.99, + fused=False): """Functional interface for the batch normalization layer. Reference: http://arxiv.org/abs/1502.03167 @@ -415,6 +490,8 @@ def batch_normalization(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. + fused: if `True`, use a faster, fused implementation based on + nn.fused_batch_norm. If `None`, use the fused implementation if possible. Returns: Output tensor. 
@@ -431,10 +508,11 @@ def batch_normalization(inputs, moving_variance_initializer=moving_variance_initializer, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, - trainable=trainable, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_momentum, + fused=fused, + trainable=trainable, name=name, _reuse=reuse, _scope=name) diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index 933f196e011..fa6c9c4a5db 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -262,6 +262,87 @@ class BNTest(test.TestCase): self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + def test4DInputAxis3Fused(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization( + axis=3, epsilon=epsilon, momentum=0.9, fused=True) + inputs = variables.Variable( + np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) + training = array_ops.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(variables.global_variables_initializer()) + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + for _ in range(100): + np_output, _, _ = sess.run( + [outputs] + bn.updates, feed_dict={training: True}) + # Verify that the axis is normalized during training. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Verify that the statistics are updated during training. 
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 2)) + std = np.std(np_inputs, axis=(0, 1, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test4DInputAxis1Fused(self): + if test.is_gpu_available(cuda_only=True): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization( + axis=1, epsilon=epsilon, momentum=0.9, fused=True) + inputs = variables.Variable( + np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) + training = array_ops.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(variables.global_variables_initializer()) + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 4, 1, 1)) + np_beta = np.reshape(np_beta, (1, 4, 1, 1)) + for _ in range(100): + np_output, _, _ = sess.run( + [outputs] + bn.updates, feed_dict={training: True}) + # Verify that the axis is normalized during training. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Verify that the statistics are updated during training. 
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 2, 3)) + std = np.std(np_inputs, axis=(0, 2, 3)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + def testNegativeAxis(self): epsilon = 1e-3 bn = normalization_layers.BatchNormalization( diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt index 78b10c44a23..418ca3ea466 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt @@ -14,7 +14,7 @@ tf_module { } member_method { name: "batch_normalization" - argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'\', \'\', \'\', \'\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], " + argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', 
\'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'\', \'\', \'\', \'\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], " } member_method { name: "conv1d" From 1c70fb6869d9099a0b52592cc6c3c47a1b6819aa Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Fri, 2 Jun 2017 17:45:41 -0700 Subject: [PATCH 68/72] Add training test for multi classes (n>2) linear classifier. PiperOrigin-RevId: 157896002 --- .../python/estimator/canned/linear_test.py | 190 ++++++++++++------ 1 file changed, 131 insertions(+), 59 deletions(-) diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py index 1e10d5b1e42..403e6b4f2a3 100644 --- a/tensorflow/python/estimator/canned/linear_test.py +++ b/tensorflow/python/estimator/canned/linear_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import os import shutil import tempfile @@ -648,7 +649,7 @@ class LinearRegressorTrainingTest(test.TestCase): if self._model_dir: shutil.rmtree(self._model_dir) - def _mockOptimizer(self, expected_loss=None): + def _mock_optimizer(self, expected_loss=None): expected_var_names = [ '%s/part_0:0' % _AGE_WEIGHT_NAME, '%s/part_0:0' % _BIAS_NAME @@ -680,7 +681,7 @@ class LinearRegressorTrainingTest(test.TestCase): mock_optimizer.__deepcopy__ = lambda _: mock_optimizer return mock_optimizer - def _assertCheckpoint( + def _assert_checkpoint( self, expected_global_step, expected_age_weight=None, expected_bias=None): shapes = { name: shape for (name, shape) in @@ -717,7 +718,7 @@ class LinearRegressorTrainingTest(test.TestCase): num_steps = 10 linear_regressor.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self._assertCheckpoint(num_steps) + self._assert_checkpoint(num_steps) def 
testTrainWithOneDimLabel(self): label_dimension = 1 @@ -736,7 +737,7 @@ class LinearRegressorTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testTrainWithOneDimWeight(self): label_dimension = 1 @@ -757,14 +758,14 @@ class LinearRegressorTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testFromScratch(self): # Create LinearRegressor. label = 5. age = 17 # loss = (logits - label)^2 = (0 - 5.)^2 = 25. - mock_optimizer = self._mockOptimizer(expected_loss=25.) + mock_optimizer = self._mock_optimizer(expected_loss=25.) linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -775,7 +776,7 @@ class LinearRegressorTrainingTest(test.TestCase): linear_regressor.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=num_steps, expected_age_weight=0., expected_bias=0.) @@ -795,7 +796,7 @@ class LinearRegressorTrainingTest(test.TestCase): # logits = age * age_weight + bias = 17 * 10. + 5. = 175 # loss = (logits - label)^2 = (175 - 5)^2 = 28900 - mock_optimizer = self._mockOptimizer(expected_loss=28900.) + mock_optimizer = self._mock_optimizer(expected_loss=28900.) 
linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -806,7 +807,7 @@ class LinearRegressorTrainingTest(test.TestCase): linear_regressor.train( input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) @@ -828,7 +829,7 @@ class LinearRegressorTrainingTest(test.TestCase): # logits[0] = 17 * 10. + 5. = 175 # logits[1] = 15 * 10. + 5. = 155 # loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004 - mock_optimizer = self._mockOptimizer(expected_loss=52004.) + mock_optimizer = self._mock_optimizer(expected_loss=52004.) linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -840,13 +841,18 @@ class LinearRegressorTrainingTest(test.TestCase): input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) -class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): +class _BaseLinearClassiferTrainingTest(object): + + def __init__(self, n_classes): + self._n_classes = n_classes + self._logits_dimensions = ( + self._n_classes if self._n_classes > 2 else 1) def setUp(self): self._model_dir = tempfile.mkdtemp() @@ -855,7 +861,7 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): if self._model_dir: shutil.rmtree(self._model_dir) - def _mockOptimizer(self, expected_loss=None): + def _mock_optimizer(self, expected_loss=None): expected_var_names = [ '%s/part_0:0' % _AGE_WEIGHT_NAME, 
'%s/part_0:0' % _BIAS_NAME @@ -887,8 +893,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): mock_optimizer.__deepcopy__ = lambda _: mock_optimizer return mock_optimizer - def _assertCheckpoint( + def _assert_checkpoint( self, expected_global_step, expected_age_weight=None, expected_bias=None): + logits_dimension = self._logits_dimensions + shapes = { name: shape for (name, shape) in checkpoint_utils.list_variables(self._model_dir) @@ -900,20 +908,20 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): checkpoint_utils.load_variable( self._model_dir, ops.GraphKeys.GLOBAL_STEP)) - self.assertEqual([1, 1], shapes[_AGE_WEIGHT_NAME]) + self.assertEqual([1, logits_dimension], shapes[_AGE_WEIGHT_NAME]) if expected_age_weight is not None: - self.assertEqual( + self.assertAllEqual( expected_age_weight, checkpoint_utils.load_variable(self._model_dir, _AGE_WEIGHT_NAME)) - self.assertEqual([1], shapes[_BIAS_NAME]) + self.assertEqual([logits_dimension], shapes[_BIAS_NAME]) if expected_bias is not None: - self.assertEqual( + self.assertAllEqual( expected_bias, checkpoint_utils.load_variable(self._model_dir, _BIAS_NAME)) def testFromScratchWithDefaultOptimizer(self): - n_classes = 2 + n_classes = self._n_classes label = 0 age = 17 est = linear.LinearClassifier( @@ -925,10 +933,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): num_steps = 10 est.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self._assertCheckpoint(num_steps) + self._assert_checkpoint(num_steps) def testTrainWithTwoDimsLabel(self): - n_classes = 2 + n_classes = self._n_classes batch_size = 20 est = linear.LinearClassifier( @@ -947,10 +955,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testTrainWithOneDimLabel(self): - n_classes = 2 + n_classes = self._n_classes 
batch_size = 20 est = linear.LinearClassifier( @@ -967,10 +975,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testTrainWithTwoDimsWeight(self): - n_classes = 2 + n_classes = self._n_classes batch_size = 20 est = linear.LinearClassifier( @@ -988,10 +996,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testTrainWithOneDimWeight(self): - n_classes = 2 + n_classes = self._n_classes batch_size = 20 est = linear.LinearClassifier( @@ -1007,16 +1015,24 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testFromScratch(self): - n_classes = 2 + n_classes = self._n_classes label = 1 age = 17 - # loss = sigmoid_cross_entropy(logits, label) where logits = 0 (weights are - # all zero initially) and label = 1 so, - # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 - mock_optimizer = self._mockOptimizer(expected_loss=0.69315) + # For binary classifer: + # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 + # For multi class classifer: + # loss = cross_entropy(logits, label) where logits are all 0s (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( 1.0 / n_classes ) + # For this particular test case, as logits are same, the formular + # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases. 
+ mock_optimizer = self._mock_optimizer( + expected_loss=-1 * math.log(1.0/n_classes)) est = linear.LinearClassifier( feature_columns=(feature_column_lib.numeric_column('age'),), @@ -1030,31 +1046,49 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): est.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=num_steps, - expected_age_weight=0., - expected_bias=0.) + expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes], + expected_bias=[0.] if n_classes == 2 else [.0] * n_classes) def testFromCheckpoint(self): # Create initial checkpoint. - n_classes = 2 + n_classes = self._n_classes label = 1 age = 17 - age_weight = 2.0 - bias = -35.0 + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes initial_global_step = 100 with ops.Graph().as_default(): - variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME) - variables.Variable([bias], name=_BIAS_NAME) + variables.Variable(age_weight, name=_AGE_WEIGHT_NAME) + variables.Variable(bias, name=_BIAS_NAME) variables.Variable( initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) _save_variables_to_ckpt(self._model_dir) - # logits = age * age_weight + bias = 17 * 2. - 35. = -1. - # loss = sigmoid_cross_entropy(logits, label) - # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 - mock_optimizer = self._mockOptimizer(expected_loss=1.3133) + # For binary classifer: + # logits = age * age_weight + bias = 17 * 2. - 35. = -1. 
+ # loss = sigmoid_cross_entropy(logits, label) + # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = 17 * age_weight + bias and label = 1 + # so, loss = 1 * -log ( soft_max(logits)[1] ) + if n_classes == 2: + expected_loss = 1.3133 + else: + logits = age_weight * age + bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[0, label]) + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) est = linear.LinearClassifier( feature_columns=(feature_column_lib.numeric_column('age'),), @@ -1068,34 +1102,55 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): est.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) def testFromCheckpointMultiBatch(self): # Create initial checkpoint. - n_classes = 2 + n_classes = self._n_classes label = [1, 0] age = [17, 18.5] - age_weight = 2.0 - bias = -35.0 + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). 
+ age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes initial_global_step = 100 with ops.Graph().as_default(): - variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME) - variables.Variable([bias], name=_BIAS_NAME) + variables.Variable(age_weight, name=_AGE_WEIGHT_NAME) + variables.Variable(bias, name=_BIAS_NAME) variables.Variable( initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) _save_variables_to_ckpt(self._model_dir) - # logits = age * age_weight + bias - # logits[0] = 17 * 2. - 35. = -1. - # logits[1] = 18.5 * 2. - 35. = 2. - # loss = sigmoid_cross_entropy(logits, label) - # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 - # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 - mock_optimizer = self._mockOptimizer(expected_loss=1.3133 + 2.1269) + # For binary classifer: + # logits = age * age_weight + bias + # logits[0] = 17 * 2. - 35. = -1. + # logits[1] = 18.5 * 2. - 35. = 2. 
+ # loss = sigmoid_cross_entropy(logits, label) + # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 + # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = [17, 18.5] * age_weight + bias and label = [1, 0] + # so, loss = 1 * -log ( soft_max(logits)[label] ) + if n_classes == 2: + expected_loss = (1.3133 + 2.1269) + else: + logits = age_weight * np.reshape(age, (2, 1)) + bias + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) est = linear.LinearClassifier( feature_columns=(feature_column_lib.numeric_column('age'),), @@ -1110,10 +1165,27 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase): input_fn=lambda: ({'age': (age)}, (label)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) + +class LinearClassiferWithBinaryClassesTrainingTest( + _BaseLinearClassiferTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): + test.TestCase.__init__(self, methodName) + _BaseLinearClassiferTrainingTest.__init__(self, n_classes=2) + + +class LinearClassiferWithMultiClassesTrainingTest( + _BaseLinearClassiferTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): + test.TestCase.__init__(self, methodName) + _BaseLinearClassiferTrainingTest.__init__(self, n_classes=4) + + if __name__ == '__main__': test.main() From 6f4204c3d30b2fded99627fd7839acf9e2ece279 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Fri, 2 Jun 2017 18:15:18 -0700 
Subject: [PATCH 69/72] Fix TensorBoard SHA256 in cmake PiperOrigin-RevId: 157897958 --- tensorflow/contrib/cmake/tf_python.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 74716cd900d..80522a18383 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -897,7 +897,7 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD # Copy resources for TensorBoard. file(DOWNLOAD http://mirror.bazel.build/tensorboard/index.html ${DOWNLOAD_LOCATION}/tensorboard/index.html - EXPECTED_HASH SHA256=60f185c68ff3f906000df9670bf9f46588056b197da7e7b10074411a0c048dae) + EXPECTED_HASH SHA256=25554e708552ad8587152f7a444db3f4ca753f9ed72d9f8105203c1d1806d521) add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tensorboard/components/) From bb7a8d8e728f7598dc9fb03a2ff71d4eaa51d714 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 2 Jun 2017 18:34:22 -0700 Subject: [PATCH 70/72] Don't use the _output_shape attribute in the op_level_cost_estimator since there is no guaranty that it will be present or accurate. 
PiperOrigin-RevId: 157898989 --- tensorflow/core/grappler/costs/BUILD | 15 +++ .../costs/analytical_cost_estimator.cc | 2 +- .../costs/analytical_cost_estimator_test.cc | 110 ++++++++++++++++++ .../grappler/costs/op_level_cost_estimator.cc | 81 +++++-------- .../grappler/costs/op_performance_data.proto | 5 +- .../core/grappler/costs/virtual_scheduler.cc | 6 +- .../core/grappler/costs/virtual_scheduler.h | 1 - .../grappler/costs/virtual_scheduler_test.cc | 6 +- tensorflow/python/BUILD | 7 ++ .../python/grappler/cost_analyzer_test.py | 41 +++++++ 10 files changed, 216 insertions(+), 58 deletions(-) create mode 100644 tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 206fac1decc..d40e66cd168 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -267,3 +267,18 @@ cc_library( "//tensorflow/core/grappler:grappler_item", ], ) + +cc_test( + name = "analytical_cost_estimator_test", + srcs = ["analytical_cost_estimator_test.cc"], + deps = [ + ":analytical_cost_estimator", + ":virtual_scheduler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/clusters:virtual_cluster", + ], +) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 7a1e7fcacef..651c77ad9a1 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -97,7 +97,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, node_costs.compute_time.asMicroSeconds().count()); cost_node->set_memory_time( node_costs.memory_time.asMicroSeconds().count()); - for (const auto& output : node_info.outputs) { + for (const auto& output : 
node_info.op_info.outputs()) { auto output_info = cost_node->add_output_info(); output_info->set_dtype(output.dtype()); auto shape = output_info->mutable_shape(); diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc new file mode 100644 index 00000000000..9e3dd38b09f --- /dev/null +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class AnalyticalCostEstimatorTest : public ::testing::Test { + protected: + void SetUp() override { + // Initializes cluster_ and placer_. 
+ std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + cpu_device.set_num_cores(4); + cpu_device.set_frequency(2600); + cpu_device.set_bandwidth(24 * 1024 * 1024); + devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + gpu_device.set_num_cores(12); + gpu_device.set_frequency(1100); + gpu_device.set_bandwidth(180 * 1024 * 1024); + (*gpu_device.mutable_environment())["architecture"] = "6"; + devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; + + cluster_.reset(new VirtualCluster(devices)); + } + + GrapplerItem CreateMiniGraph() { + const int batch = 1; + const int width = 28; + const int height = 28; + const int num_channels = 1; + const int num_labels = 10; + const int kernel_size = 3; + const int conv_filters = 32; + + Scope s = Scope::NewRootScope(); + auto images = ops::RandomUniform( + s.WithOpName("image"), {batch, width, height, num_channels}, DT_FLOAT); + auto labels = ops::RandomUniform(s.WithOpName("label"), {batch, num_labels}, + DT_FLOAT); + auto w = ops::Variable( + s.WithOpName("W"), + {kernel_size, kernel_size, num_channels, conv_filters}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("B"), {conv_filters}, DT_FLOAT); + auto conv = + ops::Conv2D(s.WithOpName("conv"), images, w, {1, 1, 1, 1}, "SAME"); + auto bias = ops::Add(s.WithOpName("bias"), conv, b); + auto relu = ops::Relu(s.WithOpName("relu"), bias); + auto flat_shape = ops::Const(s.WithOpName("flat_shape"), + {batch, width * height * conv_filters}); + auto flat = ops::Reshape(s.WithOpName("flat"), relu, flat_shape); + + auto w2 = + ops::Variable(s.WithOpName("W2"), + {width * height * conv_filters, num_labels}, DT_FLOAT); + auto b2 = ops::Variable(s.WithOpName("B2"), {num_labels}, DT_FLOAT); + auto matmul = ops::MatMul(s.WithOpName("matmul"), flat, w2); + auto logits = ops::Add(s.WithOpName("logits"), matmul, b2); + auto softmax = ops::Softmax(s.WithOpName("softmax"), 
logits); + auto lsm = ops::Log(s.WithOpName("lsm"), softmax); + + GrapplerItem item; + item.fetch.push_back("lsm"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + return item; + } + + std::unique_ptr cluster_; +}; + +TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { + GrapplerItem item = CreateMiniGraph(); + + AnalyticalCostEstimator estimator(cluster_.get(), true); + TF_ASSERT_OK(estimator.Initialize(item)); + + CostGraphDef cost_graph; + Costs summary; + TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary)); + + EXPECT_EQ(Costs::NanoSeconds(9108), summary.execution_time); + EXPECT_FALSE(summary.inaccurate); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 22c0c803e85..11a57921e56 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -76,32 +76,21 @@ std::pair OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; double bandwidth = -1; - if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; - } if (device.type() == "CPU") { - DeviceProperties local_cpu; - if (device.num_cores() <= 0 || device.frequency() <= 0) { - local_cpu = GetLocalCPUInfo(); - } else { - local_cpu = device; - } - // Check if vector instructions are available, and refine performance // prediction based on this. // Frequencies are stored in MHz in the DeviceProperties. 
- gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3; + gflops = device.num_cores() * device.frequency() * 1e-3; if (bandwidth < 0) { - if (local_cpu.bandwidth() > 0) { - bandwidth = local_cpu.bandwidth() / 1e6; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; } else { bandwidth = 32; } } } else if (device.type() == "GPU") { - const DeviceProperties local_gpu = GetLocalGPUInfo(0); - const string architecture = local_gpu.environment().at("architecture"); + const string architecture = device.environment().at("architecture"); int cores_per_multiprocessor; if (architecture < "3") { // Fermi @@ -110,17 +99,18 @@ std::pair OpLevelCostEstimator::GetDeviceInfo( // Kepler cores_per_multiprocessor = 192; } else if (architecture < "6") { - // Maxwell + // Maxwell cores_per_multiprocessor = 128; } else { - // Pascal. + // Pascal cores_per_multiprocessor = 64; } - gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 * + gflops = device.num_cores() * device.frequency() * 1e-3 * cores_per_multiprocessor * kOpsPerMac; - if (bandwidth < 0) { - CHECK(local_gpu.bandwidth() > 0); - bandwidth = local_gpu.bandwidth() / 1e6; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; + } else { + bandwidth = 100; } } @@ -507,14 +497,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations( return ops; } - if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + if (op_features.outputs_size() != 1) { // Need _output_shapes for input shape. 
- LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure."; + LOG(ERROR) << "No output shape in Conv2DBackPropInput op."; return ops; } - const auto& input_shape = - op_features.attr().at("_output_shapes").list().shape(0); + const auto& input_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( input_shape, op_features.inputs(1).shape(), op_features, found_unknown_shapes); @@ -542,14 +531,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations( return ops; } - if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { - // Need _output_shapes for filter shape. - LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure."; + if (op_features.outputs_size() != 1) { + // Need _output_shapes for input shape. + LOG(ERROR) << "No output shape in Conv2DBackPropFilter op."; return ops; } - const auto& filter_shape = - op_features.attr().at("_output_shapes").list().shape(0); + const auto& filter_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( op_features.inputs(0).shape(), filter_shape, op_features, found_unknown_shapes); @@ -598,28 +586,19 @@ int64 OpLevelCostEstimator::CalculateOutputSize( const OpInfo& op_features, bool* found_unknown_shapes) const { int64 total_output_size = 0; // use float as default for calculations - DataType dt = DT_FLOAT; - for (const auto& item : op_features.attr()) { - VLOG(1) << "Key:" << item.first - << " Value:" << SummarizeAttrValue(item.second); - if (item.first == "_output_shapes") { - for (const auto& original_output_shape : item.second.list().shape()) { - int64 output_size = 1; - int num_dims = std::max(1, original_output_shape.dim_size()); - auto output_shape = MaybeGetMinimumShape( - original_output_shape, num_dims, found_unknown_shapes); - for (const auto& dim : output_shape.dim()) { - output_size *= dim.size(); - } - output_size *= DataTypeSize(dt); - total_output_size += 
output_size; - VLOG(1) << "Output Size: " << output_size - << " Total Output Size:" << total_output_size; - } - } - if (item.first == "T") { - dt = item.second.type(); + for (const auto& output : op_features.outputs()) { + DataType dt = output.dtype(); + const auto& original_output_shape = output.shape(); + int64 output_size = DataTypeSize(dt); + int num_dims = std::max(1, original_output_shape.dim_size()); + auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims, + found_unknown_shapes); + for (const auto& dim : output_shape.dim()) { + output_size *= dim.size(); } + total_output_size += output_size; + VLOG(1) << "Output Size: " << output_size + << " Total Output Size:" << total_output_size; } return total_output_size; } diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto index 887a714c0f7..0d6b337d5a3 100644 --- a/tensorflow/core/grappler/costs/op_performance_data.proto +++ b/tensorflow/core/grappler/costs/op_performance_data.proto @@ -33,7 +33,7 @@ message OpInfo { // Custom parameters impacting the behavior of the op. map attr = 2; - // Input types, shapes and values if known. + // Input data types, shapes and values if known. message TensorProperties { DataType dtype = 1; TensorShapeProto shape = 2; @@ -41,6 +41,9 @@ message OpInfo { }; repeated TensorProperties inputs = 3; + // Optional description of the op outputs + repeated TensorProperties outputs = 5; + // Device on which the operation is run. 
DeviceProperties device = 4; } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 32b4b3c8bc0..8d8d246078c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -316,13 +316,17 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { NodeInfo node_info; node_info.name = node->name(); node_info.device_name = graph_properties_.GetDeviceName(node->name()); - node_info.outputs = graph_properties_.GetOutputProperties(node->name()); + std::vector outputs = + graph_properties_.GetOutputProperties(node->name()); auto& op_info = node_info.op_info; op_info.set_op(node->op()); *op_info.mutable_attr() = node->attr(); for (auto& input : inputs) { op_info.add_inputs()->Swap(&input); } + for (auto& output : outputs) { + op_info.add_outputs()->Swap(&output); + } op_info.mutable_device()->Swap(&device); // add some more to the node_info. return node_info; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 310f6cca09c..7764bdc478a 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -95,7 +95,6 @@ struct NodeInfo { OpInfo op_info; string name; string device_name; - std::vector outputs; }; // The virtual scheduler emulates execution of nodes in a graph, considering diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index cc4a63e5ff0..dad2104b754 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -126,9 +126,9 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { EXPECT_EQ(ops_executed.count("c2"), 0); // Check input / output properties. 
- EXPECT_EQ(1, ops_executed["x"].outputs.size()); - EXPECT_EQ(1, ops_executed["y"].outputs.size()); - EXPECT_EQ(1, ops_executed["f"].outputs.size()); + EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size()); + EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size()); + EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size()); EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index dcce808e97d..0de15b47242 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3797,10 +3797,17 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ + ":array_ops", ":client_testlib", ":cost_analyzer", ":framework_for_generated_wrappers", ":math_ops", + ":nn", + ":nn_grad", + ":random_ops", + ":state_ops", + ":training", + ":variables", "//tensorflow/core:protos_all_py", "//third_party/py/numpy", ], diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py index 19d3c9695bf..726db29f3c1 100644 --- a/tensorflow/python/grappler/cost_analyzer_test.py +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -19,11 +19,18 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import adam class PyWrapOptimizeGraphTest(test.TestCase): @@ -51,6 +58,40 @@ class 
PyWrapOptimizeGraphTest(test.TestCase): # Also print the report to make it easier to debug print("{}".format(report)) + def testSmallNetwork(self): + image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1]) + label = array_ops.placeholder(dtypes.float32, shape=[1, 10]) + w = variables.Variable( + random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1)) + b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1)) + conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME") + h_conv = nn_ops.relu(conv + b) + h_conv_flat = array_ops.reshape(h_conv, [1, -1]) + + w_fc = variables.Variable( + random_ops.truncated_normal([25088, 10], stddev=0.1)) + b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1)) + y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc) + + cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum( + label * math_ops.log(y_conv), reduction_indices=[1])) + _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy) + + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + report = cost_analyzer.GenerateCostReport(mg) + + self.assertTrue(b"MatMul" in report) + self.assertTrue(b"ApplyAdam" in report) + self.assertTrue(b"Conv2D" in report) + self.assertTrue(b"Conv2DBackpropInput" in report) + self.assertTrue(b"Conv2DBackpropFilter" in report) + self.assertTrue(b"Softmax" in report) + + # Also print the report to make it easier to debug + print("{}".format(report)) + + +# print("{}".format(mg.graph_def)) if __name__ == "__main__": test.main() From 1234e2dda6851eb921bfa00355b80f2c9c661a7e Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Fri, 2 Jun 2017 18:56:51 -0700 Subject: [PATCH 71/72] Fix Plottable definition On Mac OS the build directory in the Node package conflicts with BUILD. 
PiperOrigin-RevId: 157899970 --- .../tensorboard/components/tf_imports/BUILD | 200 ++++++++++-------- tensorflow/workspace.bzl | 1 - 2 files changed, 106 insertions(+), 95 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_imports/BUILD b/tensorflow/tensorboard/components/tf_imports/BUILD index b067a6380b1..7014643b03d 100644 --- a/tensorflow/tensorboard/components/tf_imports/BUILD +++ b/tensorflow/tensorboard/components/tf_imports/BUILD @@ -97,11 +97,23 @@ ts_web_library( srcs = [ "plottable.d.ts", "plottable.html", - "@com_palantir_plottable//:plottable.css", - "@com_palantir_plottable//:plottable.js", ], path = "/tf-imports", - deps = [":d3"], + deps = [ + ":d3", + ":plottable_js_css", + ], +) + +ts_web_library( + name = "plottable_js_css", + srcs = [ + "@com_palantir_plottable//:package/plottable.css", + "@com_palantir_plottable//:package/plottable.js", + ], + path = "/tf-imports", + strip_prefix = "package", + visibility = ["//visibility:private"], ) ts_web_library( @@ -187,139 +199,139 @@ tensorboard_typescript_bundle( out = "plottable.d.ts", namespace_srcs = { "Plottable": [ - "@com_palantir_plottable//:build/src/core/dataset.d.ts", - "@com_palantir_plottable//:build/src/core/interfaces.d.ts", - "@com_palantir_plottable//:build/src/core/version.d.ts", + "@com_palantir_plottable//:package/build/src/core/dataset.d.ts", + "@com_palantir_plottable//:package/build/src/core/interfaces.d.ts", + "@com_palantir_plottable//:package/build/src/core/version.d.ts", ], "Plottable.Animators": [ - "@com_palantir_plottable//:build/src/animators/animator.d.ts", - "@com_palantir_plottable//:build/src/animators/easingAnimator.d.ts", - "@com_palantir_plottable//:build/src/animators/nullAnimator.d.ts", + "@com_palantir_plottable//:package/build/src/animators/animator.d.ts", + "@com_palantir_plottable//:package/build/src/animators/easingAnimator.d.ts", + "@com_palantir_plottable//:package/build/src/animators/nullAnimator.d.ts", ], "Plottable.Axes": [ - 
"@com_palantir_plottable//:build/src/axes/axis.d.ts", - "@com_palantir_plottable//:build/src/axes/categoryAxis.d.ts", - "@com_palantir_plottable//:build/src/axes/numericAxis.d.ts", - "@com_palantir_plottable//:build/src/axes/timeAxis.d.ts", + "@com_palantir_plottable//:package/build/src/axes/axis.d.ts", + "@com_palantir_plottable//:package/build/src/axes/categoryAxis.d.ts", + "@com_palantir_plottable//:package/build/src/axes/numericAxis.d.ts", + "@com_palantir_plottable//:package/build/src/axes/timeAxis.d.ts", ], "Plottable.Components": [ - "@com_palantir_plottable//:build/src/components/component.d.ts", - "@com_palantir_plottable//:build/src/components/componentContainer.d.ts", - "@com_palantir_plottable//:build/src/components/dragBoxLayer.d.ts", - "@com_palantir_plottable//:build/src/components/dragLineLayer.d.ts", - "@com_palantir_plottable//:build/src/components/gridlines.d.ts", - "@com_palantir_plottable//:build/src/components/group.d.ts", - "@com_palantir_plottable//:build/src/components/guideLineLayer.d.ts", - "@com_palantir_plottable//:build/src/components/interpolatedColorLegend.d.ts", - "@com_palantir_plottable//:build/src/components/label.d.ts", - "@com_palantir_plottable//:build/src/components/legend.d.ts", - "@com_palantir_plottable//:build/src/components/plotGroup.d.ts", - "@com_palantir_plottable//:build/src/components/selectionBoxLayer.d.ts", - "@com_palantir_plottable//:build/src/components/table.d.ts", - "@com_palantir_plottable//:build/src/components/xDragBoxLayer.d.ts", - "@com_palantir_plottable//:build/src/components/yDragBoxLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/component.d.ts", + "@com_palantir_plottable//:package/build/src/components/componentContainer.d.ts", + "@com_palantir_plottable//:package/build/src/components/dragBoxLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/dragLineLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/gridlines.d.ts", + 
"@com_palantir_plottable//:package/build/src/components/group.d.ts", + "@com_palantir_plottable//:package/build/src/components/guideLineLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/interpolatedColorLegend.d.ts", + "@com_palantir_plottable//:package/build/src/components/label.d.ts", + "@com_palantir_plottable//:package/build/src/components/legend.d.ts", + "@com_palantir_plottable//:package/build/src/components/plotGroup.d.ts", + "@com_palantir_plottable//:package/build/src/components/selectionBoxLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/table.d.ts", + "@com_palantir_plottable//:package/build/src/components/xDragBoxLayer.d.ts", + "@com_palantir_plottable//:package/build/src/components/yDragBoxLayer.d.ts", ], "Plottable.Configs": [ - "@com_palantir_plottable//:build/src/core/config.d.ts", + "@com_palantir_plottable//:package/build/src/core/config.d.ts", ], "Plottable.Formatters": [ - "@com_palantir_plottable//:build/src/core/formatters.d.ts", + "@com_palantir_plottable//:package/build/src/core/formatters.d.ts", ], "Plottable.RenderController": [ - "@com_palantir_plottable//:build/src/core/renderController.d.ts", + "@com_palantir_plottable//:package/build/src/core/renderController.d.ts", ], "Plottable.RenderPolicies": [ - "@com_palantir_plottable//:build/src/core/renderPolicy.d.ts", + "@com_palantir_plottable//:package/build/src/core/renderPolicy.d.ts", ], "Plottable.SymbolFactories": [ - "@com_palantir_plottable//:build/src/core/symbolFactories.d.ts", + "@com_palantir_plottable//:package/build/src/core/symbolFactories.d.ts", ], "Plottable.Dispatchers": [ - "@com_palantir_plottable//:build/src/dispatchers/dispatcher.d.ts", - "@com_palantir_plottable//:build/src/dispatchers/keyDispatcher.d.ts", - "@com_palantir_plottable//:build/src/dispatchers/mouseDispatcher.d.ts", - "@com_palantir_plottable//:build/src/dispatchers/touchDispatcher.d.ts", + "@com_palantir_plottable//:package/build/src/dispatchers/dispatcher.d.ts", 
+ "@com_palantir_plottable//:package/build/src/dispatchers/keyDispatcher.d.ts", + "@com_palantir_plottable//:package/build/src/dispatchers/mouseDispatcher.d.ts", + "@com_palantir_plottable//:package/build/src/dispatchers/touchDispatcher.d.ts", ], "Plottable.Drawers": [ - "@com_palantir_plottable//:build/src/drawers/arcDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/arcOutlineDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/areaDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/canvasBuffer.d.ts", - "@com_palantir_plottable//:build/src/drawers/canvasDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/drawStep.d.ts", - "@com_palantir_plottable//:build/src/drawers/drawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/lineDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/rectangleDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/segmentDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/svgDrawer.d.ts", - "@com_palantir_plottable//:build/src/drawers/symbolDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/arcDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/arcOutlineDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/areaDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/canvasBuffer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/canvasDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/drawStep.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/drawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/lineDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/rectangleDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/segmentDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/svgDrawer.d.ts", + "@com_palantir_plottable//:package/build/src/drawers/symbolDrawer.d.ts", ], "Plottable.Interactions": [ - 
"@com_palantir_plottable//:build/src/interactions/clickInteraction.d.ts", - "@com_palantir_plottable//:build/src/interactions/dragInteraction.d.ts", - "@com_palantir_plottable//:build/src/interactions/interaction.d.ts", - "@com_palantir_plottable//:build/src/interactions/keyInteraction.d.ts", - "@com_palantir_plottable//:build/src/interactions/panZoomInteraction.d.ts", - "@com_palantir_plottable//:build/src/interactions/pointerInteraction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/clickInteraction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/dragInteraction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/interaction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/keyInteraction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/panZoomInteraction.d.ts", + "@com_palantir_plottable//:package/build/src/interactions/pointerInteraction.d.ts", ], "Plottable.Plots": [ - "@com_palantir_plottable//:build/src/plots/areaPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/barPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/clusteredBarPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/commons.d.ts", - "@com_palantir_plottable//:build/src/plots/linePlot.d.ts", - "@com_palantir_plottable//:build/src/plots/piePlot.d.ts", - "@com_palantir_plottable//:build/src/plots/plot.d.ts", - "@com_palantir_plottable//:build/src/plots/rectanglePlot.d.ts", - "@com_palantir_plottable//:build/src/plots/scatterPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/segmentPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/stackedAreaPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/stackedBarPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/waterfallPlot.d.ts", - "@com_palantir_plottable//:build/src/plots/xyPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/areaPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/barPlot.d.ts", + 
"@com_palantir_plottable//:package/build/src/plots/clusteredBarPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/commons.d.ts", + "@com_palantir_plottable//:package/build/src/plots/linePlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/piePlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/plot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/rectanglePlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/scatterPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/segmentPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/stackedAreaPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/stackedBarPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/waterfallPlot.d.ts", + "@com_palantir_plottable//:package/build/src/plots/xyPlot.d.ts", ], "Plottable.Scales": [ - "@com_palantir_plottable//:build/src/scales/index.d.ts", - "@com_palantir_plottable//:build/src/scales/categoryScale.d.ts", - "@com_palantir_plottable//:build/src/scales/colorScale.d.ts", - "@com_palantir_plottable//:build/src/scales/interpolatedColorScale.d.ts", - "@com_palantir_plottable//:build/src/scales/linearScale.d.ts", - "@com_palantir_plottable//:build/src/scales/modifiedLogScale.d.ts", - "@com_palantir_plottable//:build/src/scales/quantitativeScale.d.ts", - "@com_palantir_plottable//:build/src/scales/scale.d.ts", - "@com_palantir_plottable//:build/src/scales/timeScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/index.d.ts", + "@com_palantir_plottable//:package/build/src/scales/categoryScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/colorScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/interpolatedColorScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/linearScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/modifiedLogScale.d.ts", + 
"@com_palantir_plottable//:package/build/src/scales/quantitativeScale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/scale.d.ts", + "@com_palantir_plottable//:package/build/src/scales/timeScale.d.ts", ], "Plottable.Scales.TickGenerators": [ - "@com_palantir_plottable//:build/src/scales/tickGenerators.d.ts", + "@com_palantir_plottable//:package/build/src/scales/tickGenerators.d.ts", ], "Plottable.Utils": [ - "@com_palantir_plottable//:build/src/utils/addD3SelectionMulti.d.ts", - "@com_palantir_plottable//:build/src/utils/bucket.d.ts", - "@com_palantir_plottable//:build/src/utils/callbackSet.d.ts", - "@com_palantir_plottable//:build/src/utils/coerceD3.d.ts", - "@com_palantir_plottable//:build/src/utils/entityStore.d.ts", - "@com_palantir_plottable//:build/src/utils/makeEnum.d.ts", - "@com_palantir_plottable//:build/src/utils/map.d.ts", - "@com_palantir_plottable//:build/src/utils/set.d.ts", - "@com_palantir_plottable//:build/src/utils/transformAwareTranslator.d.ts", + "@com_palantir_plottable//:package/build/src/utils/addD3SelectionMulti.d.ts", + "@com_palantir_plottable//:package/build/src/utils/bucket.d.ts", + "@com_palantir_plottable//:package/build/src/utils/callbackSet.d.ts", + "@com_palantir_plottable//:package/build/src/utils/coerceD3.d.ts", + "@com_palantir_plottable//:package/build/src/utils/entityStore.d.ts", + "@com_palantir_plottable//:package/build/src/utils/makeEnum.d.ts", + "@com_palantir_plottable//:package/build/src/utils/map.d.ts", + "@com_palantir_plottable//:package/build/src/utils/set.d.ts", + "@com_palantir_plottable//:package/build/src/utils/transformAwareTranslator.d.ts", ], "Plottable.Utils.Array": [ - "@com_palantir_plottable//:build/src/utils/arrayUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/arrayUtils.d.ts", ], "Plottable.Utils.Color": [ - "@com_palantir_plottable//:build/src/utils/colorUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/colorUtils.d.ts", ], "Plottable.Utils.DOM": [ - 
"@com_palantir_plottable//:build/src/utils/domUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/domUtils.d.ts", ], "Plottable.Utils.Math": [ - "@com_palantir_plottable//:build/src/utils/mathUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/mathUtils.d.ts", ], "Plottable.Utils.Stacking": [ - "@com_palantir_plottable//:build/src/utils/stackingUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/stackingUtils.d.ts", ], "Plottable.Utils.Window": [ - "@com_palantir_plottable//:build/src/utils/windowUtils.d.ts", + "@com_palantir_plottable//:package/build/src/utils/windowUtils.d.ts", ], }, namespace_symbol_aliases = { diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c529f4d78c7..18bc9a82770 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -812,7 +812,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", ], }, - strip_prefix = {"plottable-3.1.0.tgz": "package"}, ) filegroup_external( From 563f05ff67b0c2c3b52da71337c4de1c43d09f1a Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Fri, 2 Jun 2017 20:22:41 -0700 Subject: [PATCH 72/72] [tf contrib seq2seq] Expand tile_batch to handle nested structures. This allows it to properly tile the initial wrapper state when using BeamSearchDecoder with AttentionWrapper. Unit tests updated to show this use. 
PiperOrigin-RevId: 157903115 --- .../kernel_tests/beam_search_decoder_test.py | 10 +++- .../seq2seq/python/ops/beam_search_decoder.py | 51 +++++++++++-------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index aeafe7c3e59..3d0627467aa 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase): class BeamSearchDecoderTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, has_attention): - encoder_sequence_length = [3, 2, 3, 1, 1] - decoder_sequence_length = [2, 0, 1, 2, 3] + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) batch_size = 5 decoder_max_time = 4 input_depth = 7 @@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase): batch_size_tensor = constant_op.constant(batch_size) embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) if has_attention: inputs = array_ops.placeholder_with_default( np.random.randn(batch_size, decoder_max_time, @@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase): num_units=attention_depth, memory=tiled_inputs, memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) cell = attention_wrapper.AttentionWrapper( cell=cell, attention_mechanism=attention_mechanism, @@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase): alignment_history=False) cell_state = cell.zero_state( dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone( + cell_state=initial_state) bsd = 
beam_search_decoder.BeamSearchDecoder( cell=cell, embedding=embedding, diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index f1d0ab07711..1d1babda163 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput( pass -def tile_batch(t, multiplier, name=None): - """Tile the batch dimension of tensor t. +def _tile_batch(t, multiplier): + """Core single-tensor implementation of tile_batch.""" + t = ops.convert_to_tensor(t, name="t") + shape_t = array_ops.shape(t) + if t.shape.ndims is None or t.shape.ndims < 1: + raise ValueError("t must have statically known rank") + tiling = [1] * (t.shape.ndims + 1) + tiling[1] = multiplier + tiled_static_batch_size = ( + t.shape[0].value * multiplier if t.shape[0].value is not None else None) + tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) + tiled = array_ops.reshape( + tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) + tiled.set_shape( + tensor_shape.TensorShape( + [tiled_static_batch_size]).concatenate(t.shape[1:])) + return tiled - This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of + +def tile_batch(t, multiplier, name=None): + """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. + + For each tensor t in a (possibly nested structure) of tensors, + this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated @@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None): name: Name scope for any created operations. Returns: - A `Tensor` shaped `[batch_size * multiplier, ...]`. 
+ A (possibly nested structure of) `Tensor` shaped + `[batch_size * multiplier, ...]`. Raises: - ValueError: if `t` does not have a statically known rank or it's < 1. + ValueError: if tensor(s) `t` do not have a statically known rank or + the rank is < 1. """ - with ops.name_scope(name, "tile_batch", [t, multiplier]): - t = ops.convert_to_tensor(t, name="t") - shape_t = array_ops.shape(t) - if t.shape.ndims is None or t.shape.ndims < 1: - raise ValueError("t must have statically known rank") - tiling = [1] * (t.shape.ndims + 1) - tiling[1] = multiplier - tiled_static_batch_size = ( - t.shape[0].value * multiplier if t.shape[0].value is not None else None) - tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) - tiled = array_ops.reshape( - tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) - tiled.set_shape( - tensor_shape.TensorShape( - [tiled_static_batch_size]).concatenate(t.shape[1:])) - return tiled + flat_t = nest.flatten(t) + with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): + return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) def _check_maybe(t):