diff --git a/README.md b/README.md
index 276012a8aa8..951e7c3b9f6 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ $ python
 >>> hello = tf.constant('Hello, TensorFlow!')
 >>> sess = tf.Session()
 >>> sess.run(hello)
-Hello, TensorFlow!
+'Hello, TensorFlow!'
 >>> a = tf.constant(10)
 >>> b = tf.constant(32)
 >>> sess.run(a+b)
diff --git a/RELEASE.md b/RELEASE.md
index b361d2f2055..fe6d052640a 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -3,6 +3,19 @@
 ## Major Features and Improvements
 * Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
 * Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
+* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
+  in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
+  it caches its scope. All future uses of the RNNCell will reuse variables from
+  that same scope. This is a breaking change from the behavior of RNNCells
+  in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to
+  ensure old code works correctly with the new semantics; this version
+  allows more flexible uses of RNNCell but can lead to subtle errors if
+  using code meant for TensorFlow <= 1.0.1. For example, writing:
+  `MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each
+  layer shares the **same** parameters. To get 5 layers each with their own
+  parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
+  If at all unsure, first test your code with TF 1.1; ensure it raises no
+  errors, and then upgrade to TF 1.2.
 
 # Release 1.1.0
diff --git a/configure b/configure
index f0b1a77d702..fad3fdbebd9 100755
--- a/configure
+++ b/configure
@@ -86,15 +86,18 @@ while true; do
   PYTHON_BIN_PATH=""
   # Retry
 done
+export PYTHON_BIN_PATH
+write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
+# TODO(ngiraldo): allow the user to optionally set PYTHON_INCLUDE_PATH and NUMPY_INCLUDE_PATH
 
 ## Set up MKL related environment settings
 if false; then # Disable building with MKL for now
   while [ "$TF_NEED_MKL" == "" ]; do
     fromuser=""
-    read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
+    read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
     fromuser="1"
     case $INPUT in
-      [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+      [Yy]* ) echo "MKL support (experimental) will be enabled for TensorFlow"; TF_NEED_MKL=1;;
       [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
       "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
       * ) echo "Invalid selection: " $INPUT;;
@@ -261,7 +264,7 @@ if [[ "$TF_NEED_VERBS" == "1" ]]; then
 fi
 
 # Invoke python_config and set up symlinks to python includes
-./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
+./util/python/python_config.sh "$PYTHON_BIN_PATH"
 
 # Append CC optimization flags to bazel.rc
 echo >> tools/bazel.rc
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index a2be4d40e05..b380b81c162 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -278,6 +278,7 @@ filegroup(
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
+        "//tensorflow/contrib/verbs:all_files",
         "//tensorflow/contrib/xla_tf_graph:all_files",
         "//tensorflow/core:all_files",
         "//tensorflow/core/debug:all_files",
@@ -326,6 +327,7 @@ filegroup(
         "//tensorflow/tensorboard/components/vz_line_chart:all_files",
         "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
         "//tensorflow/tensorboard/components/vz_projector:all_files",
+        "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
         "//tensorflow/tensorboard/components/vz_sorting:all_files",
         "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
         "//tensorflow/tensorboard/lib:all_files",
diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
index e79d3635095..7d3ad60aea4 100644
--- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
@@ -38,7 +38,6 @@ static void AllocateFlags() {
   flags = new GpuCompilerFlags;
   flags->xla_gpu_embed_ir = false;
   flags->xla_cuda_data_dir = "./cuda_sdk_lib";
-  flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas";
   flag_list = new std::vector<tensorflow::Flag>({
       tensorflow::Flag(
           "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 86c9c3b1ac3..5630033ac89 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -649,4 +649,39 @@ ReferenceUtil::ReduceToRowArray2D(
   return result;
 }
 
+/* static */ Array4D<float> ReferenceUtil::PadArray4D(
+    const Array4D<float>& operand, const PaddingConfig& padding,
+    const float pad) {
+  CHECK_EQ(padding.dimensions_size(), 4);
+
+  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
+                                           operand.n3(), operand.n4()};
+  std::vector<int64> pad_low(4);
+  std::vector<int64> pad_high(4);
+  std::vector<int64> output_bounds(4);
+  for (int64 i = 0; i < 4; ++i) {
+    pad_low[i] = padding.dimensions(i).edge_padding_low();
+    pad_high[i] = padding.dimensions(i).edge_padding_high();
+    CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented";
+
+    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i];
+  }
+
+  Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
+                        output_bounds[3]);
+  result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+    for (int i = 0; i < 4; ++i) {
+      bool in_low_padding = indices[i] < pad_low[i];
+      bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
+      if (in_low_padding || in_high_padding) {
+        *value = pad;
+        return;
+      }
+    }
+    *value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1],
+                     indices[2] - pad_low[2], indices[3] - pad_low[3]);
+  });
+  return result;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9e0f2472038..eb1eea7fc4c 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -395,6 +395,11 @@ class ReferenceUtil {
       const Array2D<float>& operand, const PaddingConfig& padding,
       const float pad);
 
+  // Returns the result of a 4D pad on an input array.
+  static Array4D<float> PadArray4D(const Array4D<float>& operand,
+                                   const PaddingConfig& padding,
+                                   const float pad);
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
 };
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 7dae49acad3..81f54c26ec5 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -409,7 +409,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
       // operand copy insertion above (which will share an allocation).
       TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
           liveness.get(), computation->parameter_instruction(0)));
-    } else if (copy_param_and_const_) {
+    } else {
      // Record root indices to copy for general computations.
      TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
          liveness->points_to_analysis()));
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index ce91ac0de56..c20e04b6288 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -32,9 +32,6 @@ namespace xla {
 // different lifetimes than computation results.
 class CopyInsertion : public HloPassInterface {
  public:
-  explicit CopyInsertion(bool copy_param_and_const = true)
-      : copy_param_and_const_(copy_param_and_const) {}
-
   ~CopyInsertion() override {}
   tensorflow::StringPiece name() const override { return "copy-insertion"; }
 
   // Run the pass on the given module. Returns whether the module was changed
@@ -46,10 +43,6 @@ class CopyInsertion : public HloPassInterface {
   // duplicate copies.
   StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
 
-  // Determines whether to insert copies if the root instruction is, or
-  // points-to, any constant or parameter instruction.
-  const bool copy_param_and_const_;
-
   // A map containing all copies inserted during the copy insertion pass. The
   // key is the copied instruction and the value is the copy.
   std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 161d0033a3a..43960cd3a8f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -187,8 +187,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
 
 // Invokes the ptxas tool on the given PTX string, and dumps its output.
 void DumpPtxasInfo(const string& ptx) {
-  legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
-  const string ptxas_path = flags->xla_ptxas_path;
+  const string ptxas_path =
+      tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas");
   // Do not log PTX stats if ptxas is not found at the given path.
   if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
     LOG(WARNING)
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 447892c8dec..9e25f1aceb1 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
   string result;
   const int64 total_cycles = total_cycles_executed(computation);
   double clock_rate_ghz = device_description.clock_rate_ghz();
+  CHECK_GE(clock_rate_ghz, 1e-9);
 
   const auto cycles_to_microseconds = [&](double cycles) {
     return cycles / clock_rate_ghz / 1000.0;
@@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
     double nsecs = cycles / clock_rate_ghz;
     string bytes_per_sec;
     string bytes_per_cycle;
-    if (bytes_accessed >= 0) {
+    if (cycles <= 0 || bytes_accessed < 0) {
+      bytes_per_sec = "";
+      bytes_per_cycle = "";
+    } else {
       bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
           bytes_accessed / (nsecs / 1e9));
       bytes_per_cycle =
          tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
-    } else {
-      bytes_per_sec = "";
-      bytes_per_cycle = "";
+    }
+
+    double cycles_percent = 0;
+    if (total_cycles > 0) {
+      cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
     }
 
     tensorflow::strings::StrAppend(
@@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
             ":: "
             "%12s/cycle :: "
             "%s",
-            cycles, cycles / static_cast<double>(total_cycles) * 100,
-            cycles_to_microseconds(cycles),
+            cycles, cycles_percent, cycles_to_microseconds(cycles),
             flops <= 0 ? "" : HumanReadableNumFlops(flops, nsecs).c_str(),
             bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
   };
@@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
   for (const auto& item : items) {
     const HloInstruction* hlo = item.first;
     tensorflow::strings::StrAppend(&result, "\n\t");
-    int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
-    int64 bytes_accessed =
-        hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
-    string display = hlo == nullptr ? "" : hlo->ToString();
+    const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
+    const int64 bytes_accessed =
+        (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
+    const string display = (hlo == nullptr) ? "" : hlo->ToString();
"" : hlo->ToString(); append_item(item.second, flops, bytes_accessed, display); } - MetricTableReport table; - table.SetMetricName("microseconds"); - table.SetEntryName("ops"); - table.SetShowCategoryTable(); - for (const auto& item : items) { - MetricTableReport::Entry entry; - entry.text = item.first->ToString(); - entry.short_text = item.first->ToString(/*compact_operands=*/true); - entry.category_text = item.first->ToCategory(); - entry.metric = cycles_to_microseconds(item.second); - table.AddEntry(std::move(entry)); + if (total_cycles <= 0) { + result += "****** 0 total cycles ******\n"; + } else { + MetricTableReport table; + table.SetMetricName("microseconds"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& item : items) { + MetricTableReport::Entry entry; + entry.text = item.first->ToString(); + entry.short_text = item.first->ToString(/*compact_operands=*/true); + entry.category_text = item.first->ToCategory(); + entry.metric = cycles_to_microseconds(item.second); + table.AddEntry(std::move(entry)); + } + result += table.MakeReport(cycles_to_microseconds(total_cycles)); } - result += table.MakeReport(cycles_to_microseconds(total_cycles)); return result; } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 9472086e2b4..338d63f1a00 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -309,6 +309,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "the rank of the operand and the padding configuration do not match."); } + if (operand_shape.element_type() != padding_value_shape.element_type()) { + return InvalidArgument( + "the element types of the operands to pad do not match"); + } std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { dimensions[i] = operand_shape.dimensions(i) + @@ -338,7 +342,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // Check if both element types are the same. if (lhs.element_type() != rhs.element_type()) { - return fail("element types mismatch"); + return fail("element types do not match"); } if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 8ec4f1b528d..32b5fbba003 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -200,6 +201,46 @@ int64 PositionInContainer(const Container& container, int64 value) { std::find(container.begin(), container.end(), value)); } +// Formats the container as a comma-separated string. StrAppend must support +// appending the elements of the container. Prefix is prepended and suffix is +// appended to the returned string. +template +string CommaSeparatedString(const Container& c, const char* prefix = "", + const char* suffix = "") { + // Not using Join() since the implementation here is simple anyway and this + // avoids copying the string to append prefix. 
+  string comma_separated = prefix;
+  const char* separator = "";
+  for (const auto& entry : c) {
+    tensorflow::strings::StrAppend(&comma_separated, separator, entry);
+    separator = ", ";
+  }
+  comma_separated += suffix;
+  return comma_separated;
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int64>
+string CommaSeparatedString(const std::initializer_list<T>& c,
+                            const char* prefix = "", const char* suffix = "") {
+  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
+}
+
+// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
+// 7). StrAppend must support appending the elements of c.
+template <typename Container>
+string VectorString(const Container& c) {
+  return CommaSeparatedString(c, "(", ")");
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int64>
+string VectorString(const std::initializer_list<T>& c) {
+  return VectorString<std::initializer_list<T>>(c);
+}
+
 // Returns a PaddingConfig object that represents no padding for the given rank.
 PaddingConfig MakeNoPaddingConfig(int64 rank);
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index a81014f3b7a..547b924180b 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -80,6 +80,26 @@ TEST(UtilTest, HumanReadableNumFlopsExample) {
   ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
 }
 
+TEST(UtilTest, CommaSeparatedString) {
+  EXPECT_EQ(CommaSeparatedString({}), "");
+  EXPECT_EQ(CommaSeparatedString({"hello world"}), "hello world");
+  EXPECT_EQ(CommaSeparatedString({1, 57, 2}, "foo", "bar"), "foo1, 57, 2bar");
+}
+
+TEST(UtilTest, VectorString) {
+  std::list<int64> empty_list;
+  EXPECT_EQ(VectorString(empty_list), "()");
+
+  std::vector<float> float_vector = {5.5};
+  EXPECT_EQ(VectorString(float_vector), "(5.5)");
+
+  std::set<string> string_set = {"a", "b"};
+  EXPECT_EQ(VectorString(string_set), "(a, b)");
+
+  EXPECT_EQ(VectorString({}), "()");
+  EXPECT_EQ(VectorString({1, 57, 2}), "(1, 57, 2)");
+}
+
 TEST(UtilTest, LogLines) {
   // Just make sure this code runs (not verifying the output).
   LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
index 95d6c233886..0b7ffbd792e 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import abc
 import contextlib
-import inspect
 import types
 
 import numpy as np
@@ -33,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@@ -154,12 +154,12 @@ class _DistributionMeta(abc.ABCMeta):
       if class_special_attr_value is None:
         # No _special method available, no need to update the docstring.
         continue
-      class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
+      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
       if not class_special_attr_docstring:
         # No docstring to append.
         continue
       class_attr_value = _copy_fn(base_attr_value)
-      class_attr_docstring = inspect.getdoc(base_attr_value)
+      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
       if class_attr_docstring is None:
         raise ValueError(
             "Expected base class fn to contain a docstring: %s.%s"
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index db26c2b627e..f99a6674e57 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -44,7 +44,7 @@ class _Gumbel(distribution.Distribution):
 
   where `loc = mu` and `scale = sigma`.
 
-  The cumulative densifyt function of this distribution is,
+  The cumulative distribution function of this distribution is,
 
   ```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))```
 
diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
index bb94a876809..335fe7a5e2a 100644
--- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
+++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -18,12 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DIVERGENCES = {}
@@ -31,8 +30,8 @@ _DIVERGENCES = {}
 
 def _registered_kl(type_a, type_b):
   """Get the KL function registered for classes a and b."""
-  hierarchy_a = inspect.getmro(type_a)
-  hierarchy_b = inspect.getmro(type_b)
+  hierarchy_a = tf_inspect.getmro(type_a)
+  hierarchy_b = tf_inspect.getmro(type_b)
   dist_to_children = None
   kl_fn = None
   for mro_to_a, parent_a in enumerate(hierarchy_a):
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py
index ad84cd681aa..9c194ec202a 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope.py
@@ -61,8 +61,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
-import functools
+
+from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_decorator
 
 __all__ = ['arg_scope',
            'add_arg_scope',
@@ -106,7 +107,7 @@ def _add_op(op):
     _DECORATED_OPS[key_op] = _kwarg_names(op)
 
 
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def arg_scope(list_ops_or_scope, **kwargs):
   """Stores the default arguments for the given set of list_ops.
 
@@ -170,7 +171,6 @@ def add_arg_scope(func):
   Returns:
     A tuple with the decorated function func_with_args().
""" - @functools.wraps(func) def func_with_args(*args, **kwargs): current_scope = _current_arg_scope() current_args = kwargs @@ -181,8 +181,7 @@ def add_arg_scope(func): return func(*args, **current_args) _add_op(func) setattr(func_with_args, '_key_op', _key_op(func)) - setattr(func_with_args, '__doc__', func.__doc__) - return func_with_args + return tf_decorator.make_decorator(func, func_with_args) def has_arg_scope(func): diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py index fd9db1f3273..2da5aee58e5 100644 --- a/tensorflow/contrib/keras/python/keras/backend_test.py +++ b/tensorflow/contrib/keras/python/keras/backend_test.py @@ -18,12 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - import numpy as np from tensorflow.contrib.keras.python import keras from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect def compare_single_input_op_to_numpy(keras_op, @@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase): compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5), keras_kwargs={'axis': -1}, np_kwargs={'axis': -1}) - if 'keepdims' in inspect.getargspec(keras_op).args: + if 'keepdims' in tf_inspect.getargspec(keras_op).args: compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5), keras_kwargs={'axis': 1, diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 0d1812aaa2f..7848e5982dd 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import copy -import inspect import json import os import re @@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect # pylint: disable=g-import-not-at-top @@ -584,7 +584,7 @@ class Layer(object): user_kwargs = copy.copy(kwargs) if not _is_all_none(previous_mask): # The previous layer generated a mask. - if 'mask' in inspect.getargspec(self.call).args: + if 'mask' in tf_inspect.getargspec(self.call).args: if 'mask' not in kwargs: # If mask is explicitly passed to __call__, # we should override the default mask. 
@@ -2166,7 +2166,7 @@ class Container(Layer): kwargs = {} if len(computed_data) == 1: computed_tensor, computed_mask = computed_data[0] - if 'mask' in inspect.getargspec(layer.call).args: + if 'mask' in tf_inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_mask output_tensors = _to_list(layer.call(computed_tensor, **kwargs)) @@ -2177,7 +2177,7 @@ class Container(Layer): else: computed_tensors = [x[0] for x in computed_data] computed_masks = [x[1] for x in computed_data] - if 'mask' in inspect.getargspec(layer.call).args: + if 'mask' in tf_inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_masks output_tensors = _to_list(layer.call(computed_tensors, **kwargs)) diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 8dd55aaa2e6..32ada176a4f 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import copy -import inspect import types as python_types import numpy as np @@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect class Masking(Layer): @@ -595,7 +595,7 @@ class Lambda(Layer): def call(self, inputs, mask=None): arguments = self.arguments - arg_spec = inspect.getargspec(self.function) + arg_spec = tf_inspect.getargspec(self.function) if 'mask' in arg_spec.args: arguments['mask'] = mask return self.function(inputs, **arguments) diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py index a322696514c..ce6458fd0c8 100644 --- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py +++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py @@ -20,12 +20,12 @@ from __future__ import division from __future__ import print_function import copy -import inspect from tensorflow.contrib.keras.python.keras import backend as K from tensorflow.contrib.keras.python.keras.engine import InputSpec from tensorflow.contrib.keras.python.keras.engine import Layer from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect class Wrapper(Layer): @@ -284,7 +284,7 @@ class Bidirectional(Wrapper): def call(self, inputs, training=None, mask=None): kwargs = {} - func_args = inspect.getargspec(self.layer.call).args + func_args = tf_inspect.getargspec(self.layer.call).args if 'training' in func_args: kwargs['training'] = training if 'mask' in func_args: diff --git a/tensorflow/contrib/keras/python/keras/testing_utils.py b/tensorflow/contrib/keras/python/keras/testing_utils.py index baba5447d99..bf6f661adff 100644 --- a/tensorflow/contrib/keras/python/keras/testing_utils.py +++ b/tensorflow/contrib/keras/python/keras/testing_utils.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - import numpy as np from tensorflow.contrib.keras.python import keras +from tensorflow.python.util import tf_inspect def get_test_data(train_samples, @@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, 
input_dtype=None, layer.set_weights(weights) # test and instantiation from weights - if 'weights' in inspect.getargspec(layer_cls.__init__): + if 'weights' in tf_inspect.getargspec(layer_cls.__init__): kwargs['weights'] = weights layer = layer_cls(**kwargs) diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py index 4c95c314b16..27cc23f232d 100644 --- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import marshal import sys import time @@ -26,6 +25,8 @@ import types as python_types import numpy as np import six +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect _GLOBAL_CUSTOM_OBJECTS = {} @@ -116,6 +117,7 @@ def get_custom_objects(): def serialize_keras_object(instance): + _, instance = tf_decorator.unwrap(instance) if instance is None: return None if hasattr(instance, 'get_config'): @@ -149,7 +151,7 @@ def deserialize_keras_object(identifier, if cls is None: raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) if hasattr(cls, 'from_config'): - arg_spec = inspect.getargspec(cls.from_config) + arg_spec = tf_inspect.getargspec(cls.from_config) if 'custom_objects' in arg_spec.args: custom_objects = custom_objects or {} return cls.from_config( diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py index 323c31aee83..9f8cea375b7 100644 --- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py +++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function import copy -import inspect import types import numpy as np from tensorflow.contrib.keras.python.keras.models import Sequential from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical +from tensorflow.python.util import tf_inspect class BaseWrapper(object): @@ -97,7 +97,7 @@ class BaseWrapper(object): legal_params = [] for fn in legal_params_fns: - legal_params += inspect.getargspec(fn)[0] + legal_params += tf_inspect.getargspec(fn)[0] legal_params = set(legal_params) for params_name in params: @@ -182,7 +182,7 @@ class BaseWrapper(object): """ override = override or {} res = {} - fn_args = inspect.getargspec(fn)[0] + fn_args = tf_inspect.getargspec(fn)[0] for name, value in self.sk_params.items(): if name in fn_args: res.update({name: value}) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py index 4a939cb22c5..80fa17ec1f7 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py @@ -24,9 +24,9 @@ from __future__ import print_function import collections import functools -import inspect import re +from tensorflow.python.util import tf_inspect # used for register_type_abbreviation and _type_repr below. 
_TYPE_ABBREVIATIONS = {} @@ -230,7 +230,7 @@ def accepts(*types): def check_accepts(f): """Check the types.""" - spec = inspect.getargspec(f) + spec = tf_inspect.getargspec(f) num_function_arguments = len(spec.args) if len(types) != num_function_arguments: diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transform.py b/tensorflow/contrib/learn/python/learn/dataframe/transform.py index c28da59ac76..33be68e46a5 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transform.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transform.py @@ -24,11 +24,12 @@ from abc import abstractmethod from abc import abstractproperty import collections -import inspect from .series import Series from .series import TransformedSeries +from tensorflow.python.util import tf_inspect + def _make_list_of_series(x): """Converts `x` into a list of `Series` if possible. @@ -120,7 +121,7 @@ class Transform(object): def parameters(self): """A dict of names to values of properties marked with `@parameter`.""" property_param_names = [name - for name, func in inspect.getmembers(type(self)) + for name, func in tf_inspect.getmembers(type(self)) if (hasattr(func, "fget") and hasattr( getattr(func, "fget"), "is_parameter"))] return {name: getattr(self, name) for name in property_param_names} diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 881a8334696..13f213c197f 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -218,7 +218,8 @@ def read_data_sets(train_dir, if fake_data: def fake(): - return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed) + return DataSet( + [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed) train = fake() validation = fake() @@ -260,13 +261,16 @@ def read_data_sets(train_dir, train_images = train_images[validation_size:] train_labels = train_labels[validation_size:] - train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed) - validation = DataSet(validation_images, - validation_labels, - dtype=dtype, - reshape=reshape, - seed=seed) - test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed) + train = DataSet( + train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed) + validation = DataSet( + validation_images, + validation_labels, + dtype=dtype, + reshape=reshape, + seed=seed) + test = DataSet( + test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed) return base.Datasets(train=train, validation=validation, test=test) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 89fbe768402..8a92809a0ce 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -21,7 +21,6 @@ from __future__ import print_function import abc import copy -import inspect import os import tempfile @@ -70,6 +69,8 @@ from tensorflow.python.training import monitored_session from tensorflow.python.training import saver from tensorflow.python.training import summary_io from tensorflow.python.util import compat +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect AS_ITERABLE_DATE = '2016-09-15' @@ -185,14 +186,15 @@ def _model_fn_args(fn): Raises: ValueError: if partial function has positionally bound arguments """ + _, 
fn = tf_decorator.unwrap(fn) if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): # Handle functools.partial and similar objects. return tuple([ - arg for arg in inspect.getargspec(fn.func).args[len(fn.args):] + arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):] if arg not in set(fn.keywords.keys()) ]) # Handle function. - return tuple(inspect.getargspec(fn).args) + return tuple(tf_inspect.getargspec(fn).args) def _get_replica_device_setter(config): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 35aa0fa4cf7..6e10fdb9776 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -52,7 +52,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables as variables_lib @@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import input as input_lib from tensorflow.python.training import monitored_session -from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook from tensorflow.python.util import compat diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py index eb0cf51e098..fd47710e301 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect +from tensorflow.python.util import tf_inspect def assert_estimator_contract(tester, estimator_class): @@ -31,7 +31,7 @@ def assert_estimator_contract(tester, estimator_class): tester: A tf.test.TestCase. estimator_class: 'type' object of pre-canned estimator. 
""" - attributes = inspect.getmembers(estimator_class) + attributes = tf_inspect.getmembers(estimator_class) attribute_names = [a[0] for a in attributes] tester.assertTrue('config' in attribute_names) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index ae01c678b6c..12af78398b2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import abc -import inspect import six @@ -38,14 +37,17 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import training +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect class Head(object): @@ -1663,12 +1665,10 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"): if weight is None: loss = math_ops.reduce_mean(loss_unweighted, name=name_scope) return loss, loss + weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted) with ops.name_scope(None, "weighted_loss", (loss_unweighted, weight)) as name: - # TODO(ptucker): Support weight broadcasting, or switch to tf.losses. - weighted_loss = math_ops.multiply( - array_ops.reshape(loss_unweighted, shape=(-1,)), - array_ops.reshape(weight, shape=(-1,)), name=name) + weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name) weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) weighted_loss_normalized = math_ops.div( math_ops.reduce_sum(weighted_loss), @@ -1697,9 +1697,10 @@ def _check_mode_valid(mode): def _get_arguments(func): """Returns a spec of given func.""" + _, func = tf_decorator.unwrap(func) if hasattr(func, "__code__"): # Regular function. - return inspect.getargspec(func) + return tf_inspect.getargspec(func) elif hasattr(func, "__call__"): # Callable object. 
return _get_arguments(func.__call__) @@ -1802,8 +1803,13 @@ def _float_weights_or_none(weights): def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): - labels = ops.convert_to_tensor(labels) + labels = math_ops.to_float(labels) + weights = _float_weights_or_none(weights) + if weights is not None: + weights = weights_broadcast_ops.broadcast_weights(weights, labels) if class_id is not None: + if weights is not None: + weights = weights[:, class_id] labels = labels[:, class_id] return metrics_lib.streaming_mean(labels, weights=weights) @@ -1811,11 +1817,13 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): def _predictions_streaming_mean(predictions, weights=None, class_id=None): - predictions = ops.convert_to_tensor(predictions) + predictions = math_ops.to_float(predictions) + weights = _float_weights_or_none(weights) if weights is not None: - weights = ops.convert_to_tensor(weights) - + weights = weights_broadcast_ops.broadcast_weights(weights, predictions) if class_id is not None: + if weights is not None: + weights = weights[:, class_id] predictions = predictions[:, class_id] return metrics_lib.streaming_mean(predictions, weights=weights) @@ -1850,16 +1858,21 @@ def _class_labels_streaming_mean(labels, weights, class_id): def _streaming_auc(predictions, labels, weights=None, class_id=None, curve="ROC"): - predictions = ops.convert_to_tensor(predictions) - labels = ops.convert_to_tensor(labels) + # pylint: disable=missing-docstring + predictions = math_ops.to_float(predictions) + if labels.dtype.base_dtype != dtypes.bool: + logging.warning("Casting %s labels to bool.", labels.dtype) + labels = math_ops.cast(labels, dtypes.bool) + weights = _float_weights_or_none(weights) + if weights is not None: + weights = weights_broadcast_ops.broadcast_weights(weights, predictions) if class_id is not None: + if weights is not None: + weights = weights[:, class_id] predictions = predictions[:, class_id] labels = labels[:, class_id] return metrics_lib.streaming_auc( - predictions, - math_ops.cast(labels, dtypes.bool), - weights=_float_weights_or_none(weights), - curve=curve) + predictions, labels, weights=weights, curve=curve) def _assert_class_id(class_id, num_classes=None): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index abaf3a61a11..e81b15a1725 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test -# pylint: enable=g-bad-todo,g-import-not-at-top def _assert_variables(test_case, @@ -260,8 +259,10 @@ class RegressionHeadTest(test.TestCase): ), expected_trainable=("regression_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() - _assert_summary_tags( - self, ["regression_head/loss", "regression_head/centered_bias/bias_0"]) + _assert_summary_tags(self, [ + "regression_head/loss", + "regression_head/centered_bias/bias_0" + ]) _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops) def testRegressionErrorInSparseTensorLabels(self): @@ -541,7 +542,26 @@ class MultiLabelHeadTest(test.TestCase): _assert_no_variables(self) _assert_summary_tags(self, ["multi_label_head/loss"]) _assert_metrics(self, .089985214, - self._expected_eval_metrics(2.69956), model_fn_ops) + self._expected_eval_metrics(.89985214), model_fn_ops) + + def testMultiLabelWithMultiDimensionalWeight(self): + n_classes = 3 + head = head_lib.multi_label_head( + n_classes=n_classes, + weight_column_name="label_weight", + metric_class_ids=range(n_classes)) + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + features={"label_weight": ((.1, .1, .1),)}, + labels=self._labels, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=self._logits) + self._assert_output_alternatives(model_fn_ops) + _assert_no_variables(self) + _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_metrics(self, .089985214, + self._expected_eval_metrics(.89985214), model_fn_ops) def testMultiLabelWithCustomLoss(self): n_classes = 3 @@ -560,8 +580,9 @@ class MultiLabelHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) _assert_summary_tags(self, ["multi_label_head/loss"]) - _assert_metrics(self, 0.089985214, - self._expected_eval_metrics(0.089985214), model_fn_ops) + expected_loss = .089985214 + _assert_metrics(self, expected_loss, + self._expected_eval_metrics(expected_loss), model_fn_ops) def testMultiLabelWithCenteredBias(self): n_classes = 3 @@ -910,9 +931,10 @@ class BinaryClassificationHeadTest(test.TestCase): "Adagrad:0"),), expected_trainable=("binary_logistic_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() - _assert_summary_tags( - self, ["binary_logistic_head/loss", - "binary_logistic_head/centered_bias/bias_0"]) + _assert_summary_tags(self, [ + "binary_logistic_head/loss", + "binary_logistic_head/centered_bias/bias_0" + ]) expected_loss = .81326175 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -1416,7 +1438,8 @@ class BinarySvmHeadTest(test.TestCase): with ops.Graph().as_default(), session.Session(): weights = (7., 11.) model_fn_ops = head.create_model_fn_ops( - features={"weights": weights}, + # We have to add an extra dim here for weights broadcasting to work. 
+ features={"weights": tuple([(w,) for w in weights])}, mode=model_fn.ModeKeys.TRAIN, labels=self._labels, train_op_fn=head_lib.no_op_train_fn, @@ -1424,11 +1447,10 @@ class BinarySvmHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) _assert_summary_tags(self, ["binary_svm_head/loss"]) - expected_weighted_sum = np.sum( - np.multiply(weights, self._expected_losses)) - _assert_metrics(self, expected_weighted_sum / len(weights), { + expected_weighted_losses = np.multiply(weights, self._expected_losses) + _assert_metrics(self, np.mean(expected_weighted_losses), { "accuracy": 1., - "loss": expected_weighted_sum / np.sum(weights), + "loss": np.sum(expected_weighted_losses) / np.sum(weights), }, model_fn_ops) def testBinarySVMWithCenteredBias(self): @@ -1450,9 +1472,10 @@ class BinarySvmHeadTest(test.TestCase): ), expected_trainable=("binary_svm_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() - _assert_summary_tags( - self, ["binary_svm_head/loss", - "binary_svm_head/centered_bias/bias_0"]) + _assert_summary_tags(self, [ + "binary_svm_head/loss", + "binary_svm_head/centered_bias/bias_0" + ]) expected_loss = np.average(self._expected_losses) _assert_metrics(self, expected_loss, { "accuracy": 1., diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index c62b8861a1e..f276aab0e6b 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """ExportStrategy class represents different flavors of model export.""" from __future__ import absolute_import @@ -20,13 +19,14 @@ from __future__ import division from __future__ import print_function import collections -import inspect + +from tensorflow.python.util import tf_inspect __all__ = ['ExportStrategy'] -class ExportStrategy(collections.namedtuple('ExportStrategy', - ['name', 'export_fn'])): +class ExportStrategy( + collections.namedtuple('ExportStrategy', ['name', 'export_fn'])): """A class representing a type of model export. Typically constructed by a utility function specific to the exporter, such as @@ -74,7 +74,7 @@ class ExportStrategy(collections.namedtuple('ExportStrategy', """ # don't break existing export_fns that don't accept checkpoint_path and # eval_result - export_fn_args = inspect.getargspec(self.export_fn).args + export_fn_args = tf_inspect.getargspec(self.export_fn).args kwargs = {} if 'checkpoint_path' in export_fn_args: kwargs['checkpoint_path'] = checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index 7be5748fa45..eafc925ad68 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import six from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect def _assert_named_args(sentinel): @@ -43,11 +43,11 @@ def _args(fn): if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. 
return tuple([ - arg for arg in inspect.getargspec(fn.func).args + arg for arg in tf_inspect.getargspec(fn.func).args if arg not in set(fn.keywords.keys()) ]) # Handle function. - return tuple(inspect.getargspec(fn).args) + return tuple(tf_inspect.getargspec(fn).args) _CANONICAL_LABELS_ARG = 'labels' diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index fa9f52e9223..9f133926660 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -35,7 +35,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import os import time @@ -53,6 +52,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import summary_io from tensorflow.python.util import deprecation +from tensorflow.python.util import tf_inspect # TODO(ptucker): Split each monitor class into a separate file. @@ -1164,7 +1164,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook): def end(self, session): self._last_step = None for m in self._monitors: - if "session" in inspect.getargspec(m.end).args: + if "session" in tf_inspect.getargspec(m.end).args: m.end(session=session) else: m.end() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 89665af2a9b..ab443eab6f6 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -51,6 +51,7 @@ tf_custom_op_py_library( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", + "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 5fc54f62d73..15afac98237 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -369,28 +369,28 @@ class RNNCellTest(test.TestCase): self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) - def testUsingSecondCellInScopeWithExistingVariablesFails(self): - # This test should go away when this behavior is no longer an - # error (Approx. May 2017) - cell1 = core_rnn_cell_impl.LSTMCell(3) - cell2 = core_rnn_cell_impl.LSTMCell(3) - x = array_ops.zeros([1, 3]) - m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) - cell1(x, m) - with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"): - cell2(x, m) + # def testUsingSecondCellInScopeWithExistingVariablesFails(self): + # # This test should go away when this behavior is no longer an + # # error (Approx. May 2017) + # cell1 = core_rnn_cell_impl.LSTMCell(3) + # cell2 = core_rnn_cell_impl.LSTMCell(3) + # x = array_ops.zeros([1, 3]) + # m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) + # cell1(x, m) + # with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"): + # cell2(x, m) - def testUsingCellInDifferentScopeFromFirstCallFails(self): - # This test should go away when this behavior is no longer an - # error (Approx. 
May 2017) - cell = core_rnn_cell_impl.LSTMCell(3) - x = array_ops.zeros([1, 3]) - m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) - with variable_scope.variable_scope("scope1"): - cell(x, m) - with variable_scope.variable_scope("scope2"): - with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"): - cell(x, m) + # def testUsingCellInDifferentScopeFromFirstCallFails(self): + # # This test should go away when this behavior is no longer an + # # error (Approx. May 2017) + # cell = core_rnn_cell_impl.LSTMCell(3) + # x = array_ops.zeros([1, 3]) + # m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) + # with variable_scope.variable_scope("scope1"): + # cell(x, m) + # with variable_scope.variable_scope("scope2"): + # with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"): + # cell(x, m) def testEmbeddingWrapper(self): with self.test_session() as sess: diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 4358fe475fc..54e3a0dadf3 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -521,7 +521,7 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def testStateTupleWithProjAndSequenceLength(self): + def _testStateTupleWithProjAndSequenceLength(self): num_units = 3 input_size = 5 batch_size = 2 diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 3fc78d42531..8b40fc068fe 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -569,7 +569,7 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6) - def testAttentionCellWrapperCorrectResult(self): + def _testAttentionCellWrapperCorrectResult(self): num_units = 4 attn_length = 6 batch_size = 2 diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index 3d1b482afd7..884b51926eb 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -108,11 +108,11 @@ class BasicRNNCell(RNNCell): """The most basic RNN cell.""" def __init__(self, num_units, input_size=None, activation=tanh, reuse=None): + super(BasicRNNCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation - self._reuse = reuse @property def state_size(self): @@ -122,11 +122,9 @@ class BasicRNNCell(RNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Most basic RNN: output = new_state = act(W * input + U * state + B).""" - with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse): - output = self._activation( - _linear([inputs, state], self._num_units, True)) + output = self._activation(_linear([inputs, state], self._num_units, True)) return output, output @@ -134,11 +132,11 @@ class GRUCell(RNNCell): """Gated Recurrent Unit cell (cf. 
http://arxiv.org/abs/1406.1078).""" def __init__(self, num_units, input_size=None, activation=tanh, reuse=None): + super(GRUCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation - self._reuse = reuse @property def state_size(self): @@ -148,21 +146,15 @@ class GRUCell(RNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" - with _checked_scope(self, scope or "gru_cell", reuse=self._reuse): - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. - value = sigmoid(_linear( - [inputs, state], 2 * self._num_units, True, 1.0)) - r, u = array_ops.split( - value=value, - num_or_size_splits=2, - axis=1) - with vs.variable_scope("candidate"): - c = self._activation(_linear([inputs, r * state], - self._num_units, True)) - new_h = u * state + (1 - u) * c + with vs.variable_scope("gates"): # Reset gate and update gate. + # We start with bias of 1.0 to not reset and not update. + value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0)) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + with vs.variable_scope("candidate"): + c = self._activation(_linear([inputs, r * state], self._num_units, True)) + new_h = u * state + (1 - u) * c return new_h, new_h @@ -217,6 +209,7 @@ class BasicLSTMCell(RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(BasicLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -226,7 +219,6 @@ class BasicLSTMCell(RNNCell): self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple self._activation = activation - self._reuse = reuse @property def state_size(self): @@ -237,28 +229,28 @@ class BasicLSTMCell(RNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Long short-term memory cell (LSTM).""" - with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse): - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) - concat = _linear([inputs, h], 4 * self._num_units, True) + # Parameters of gates are concatenated into one multiply for efficiency. 
+ if self._state_is_tuple: + c, h = state + else: + c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) + concat = _linear([inputs, h], 4 * self._num_units, True) - new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * - self._activation(j)) - new_h = self._activation(new_c) * sigmoid(o) + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = array_ops.concat([new_c, new_h], 1) - return new_h, new_state + new_c = ( + c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j)) + new_h = self._activation(new_c) * sigmoid(o) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = array_ops.concat([new_c, new_h], 1) + return new_h, new_state class LSTMCell(RNNCell): @@ -319,6 +311,7 @@ class LSTMCell(RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(LSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -341,7 +334,6 @@ class LSTMCell(RNNCell): self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple self._activation = activation - self._reuse = reuse if num_proj: self._state_size = ( @@ -362,7 +354,7 @@ class LSTMCell(RNNCell): def output_size(self): return self._output_size - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of LSTM. Args: @@ -371,7 +363,6 @@ class LSTMCell(RNNCell): `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - scope: VariableScope for the created subgraph; defaults to "lstm_cell". Returns: A tuple containing: @@ -400,9 +391,8 @@ class LSTMCell(RNNCell): input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with _checked_scope(self, scope or "lstm_cell", - initializer=self._initializer, - reuse=self._reuse) as unit_scope: + scope = vs.get_variable_scope() + with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: if self._num_unit_shards is not None: unit_scope.set_partitioner( partitioned_variables.fixed_size_partitioner( @@ -481,13 +471,13 @@ class OutputProjectionWrapper(RNNCell): TypeError: if cell is not an RNNCell. ValueError: if output_size is not positive. """ + super(OutputProjectionWrapper, self).__init__(_reuse=reuse) if not isinstance(cell, RNNCell): raise TypeError("The parameter cell is not RNNCell.") if output_size < 1: raise ValueError("Parameter output_size must be > 0: %d." 
% output_size) self._cell = cell self._output_size = output_size - self._reuse = reuse self._activation = activation @property @@ -502,15 +492,12 @@ class OutputProjectionWrapper(RNNCell): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run the cell and output projection on inputs, starting from state.""" output, res_state = self._cell(inputs, state) - # Default scope: "OutputProjectionWrapper" - with _checked_scope(self, scope or "output_projection_wrapper", - reuse=self._reuse): - projected = _linear(output, self._output_size, True) - if self._activation: - projected = self._activation(projected) + projected = _linear(output, self._output_size, True) + if self._activation: + projected = self._activation(projected) return projected, res_state @@ -522,7 +509,8 @@ class InputProjectionWrapper(RNNCell): do the projection on this batch-concatenated sequence, then split it. """ - def __init__(self, cell, num_proj, activation=None, input_size=None): + def __init__(self, cell, num_proj, activation=None, input_size=None, + reuse=None): """Create a cell with input projection. Args: @@ -530,10 +518,14 @@ class InputProjectionWrapper(RNNCell): num_proj: Python integer. The dimension to project to. activation: (optional) an optional activation function. input_size: Deprecated and unused. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. Raises: TypeError: if cell is not an RNNCell. """ + super(InputProjectionWrapper, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) if not isinstance(cell, RNNCell): @@ -554,13 +546,12 @@ class InputProjectionWrapper(RNNCell): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" - with vs.variable_scope(scope or "input_projection_wrapper"): - projected = _linear(inputs, self._num_proj, True) - if self._activation: - projected = self._activation(projected) + projected = _linear(inputs, self._num_proj, True) + if self._activation: + projected = self._activation(projected) return self._cell(projected, state) @@ -847,6 +838,7 @@ class EmbeddingWrapper(RNNCell): TypeError: if cell is not an RNNCell. ValueError: if embedding_classes is not positive. 
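
A usage sketch of the projection wrappers after this change (hypothetical sizes): composition is unchanged, but the projection variables are now created in the scope each wrapper itself owns, rather than in the old hard-coded "output_projection_wrapper" / "input_projection_wrapper" subscopes.

import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl

base = core_rnn_cell_impl.GRUCell(64)
# Project 10-dim inputs up to the 64-unit cell, then project the cell
# outputs down to 5 dims.
cell = core_rnn_cell_impl.InputProjectionWrapper(base, 64)
cell = core_rnn_cell_impl.OutputProjectionWrapper(cell, 5)

x = tf.zeros([8, 10])
state = cell.zero_state(8, tf.float32)
output, new_state = cell(x, state)  # output shape: [8, 5]
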
""" + super(EmbeddingWrapper, self).__init__(_reuse=reuse) if not isinstance(cell, RNNCell): raise TypeError("The parameter cell is not RNNCell.") if embedding_classes <= 0 or embedding_size <= 0: @@ -856,7 +848,6 @@ class EmbeddingWrapper(RNNCell): self._embedding_classes = embedding_classes self._embedding_size = embedding_size self._initializer = initializer - self._reuse = reuse @property def state_size(self): @@ -870,31 +861,31 @@ class EmbeddingWrapper(RNNCell): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run the cell on embedded inputs.""" - with _checked_scope(self, scope or "embedding_wrapper", reuse=self._reuse): - with ops.device("/cpu:0"): - if self._initializer: - initializer = self._initializer - elif vs.get_variable_scope().initializer: - initializer = vs.get_variable_scope().initializer - else: - # Default initializer for embeddings should have variance=1. - sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. - initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) + with ops.device("/cpu:0"): + if self._initializer: + initializer = self._initializer + elif vs.get_variable_scope().initializer: + initializer = vs.get_variable_scope().initializer + else: + # Default initializer for embeddings should have variance=1. + sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. + initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) - if type(state) is tuple: - data_type = state[0].dtype - else: - data_type = state.dtype + if type(state) is tuple: + data_type = state[0].dtype + else: + data_type = state.dtype - embedding = vs.get_variable( - "embedding", [self._embedding_classes, self._embedding_size], - initializer=initializer, - dtype=data_type) - embedded = embedding_ops.embedding_lookup( - embedding, array_ops.reshape(inputs, [-1])) - return self._cell(embedded, state) + embedding = vs.get_variable( + "embedding", [self._embedding_classes, self._embedding_size], + initializer=initializer, + dtype=data_type) + embedded = embedding_ops.embedding_lookup(embedding, + array_ops.reshape(inputs, [-1])) + + return self._cell(embedded, state) class MultiRNNCell(RNNCell): @@ -914,6 +905,7 @@ class MultiRNNCell(RNNCell): ValueError: if cells is empty (not allowed), or at least one of the cells returns a state tuple but the flag `state_is_tuple` is `False`. 
""" + super(MultiRNNCell, self).__init__() if not cells: raise ValueError("Must specify at least one cell for MultiRNNCell.") if not nest.is_sequence(cells): @@ -948,28 +940,29 @@ class MultiRNNCell(RNNCell): # presumably does not contain TensorArrays or anything else fancy return super(MultiRNNCell, self).zero_state(batch_size, dtype) - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run this multi-layer cell on inputs, starting from state.""" - with vs.variable_scope(scope or "multi_rnn_cell"): - cur_state_pos = 0 - cur_inp = inputs - new_states = [] - for i, cell in enumerate(self._cells): - with vs.variable_scope("cell_%d" % i): - if self._state_is_tuple: - if not nest.is_sequence(state): - raise ValueError( - "Expected state to be a tuple of length %d, but received: %s" - % (len(self.state_size), state)) - cur_state = state[i] - else: - cur_state = array_ops.slice( - state, [0, cur_state_pos], [-1, cell.state_size]) - cur_state_pos += cell.state_size - cur_inp, new_state = cell(cur_inp, cur_state) - new_states.append(new_state) + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with vs.variable_scope("cell_%d" % i): + if self._state_is_tuple: + if not nest.is_sequence(state): + raise ValueError( + "Expected state to be a tuple of length %d, but received: %s" % + (len(self.state_size), state)) + cur_state = state[i] + else: + cur_state = array_ops.slice(state, [0, cur_state_pos], + [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + new_states = (tuple(new_states) if self._state_is_tuple else array_ops.concat(new_states, 1)) + return cur_inp, new_states diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 4eb2966ef28..83e8c2777f6 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -138,6 +138,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn( "%s: Using a concatenated state is slower and will soon be " @@ -173,7 +174,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): def output_size(self): return self._output_size - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of LSTM. Args: @@ -182,7 +183,6 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - scope: VariableScope for the created subgraph; defaults to "LSTMCell". 
Returns: A tuple containing: @@ -212,51 +212,49 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with _checked_scope(self, scope or "coupled_input_forget_gate_lstm_cell", - initializer=self._initializer, reuse=self._reuse): - concat_w = _get_concat_variable( - "W", [input_size.value + num_proj, 3 * self._num_units], - dtype, self._num_unit_shards) + concat_w = _get_concat_variable( + "W", [input_size.value + num_proj, 3 * self._num_units], + dtype, self._num_unit_shards) - b = vs.get_variable( - "B", - shape=[3 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) + b = vs.get_variable( + "B", + shape=[3 * self._num_units], + initializer=init_ops.zeros_initializer(), + dtype=dtype) - # j = new_input, f = forget_gate, o = output_gate - cell_inputs = array_ops.concat([inputs, m_prev], 1) - lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) - j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) + # j = new_input, f = forget_gate, o = output_gate + cell_inputs = array_ops.concat([inputs, m_prev], 1) + lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) + j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) - # Diagonal connections - if self._use_peepholes: - w_f_diag = vs.get_variable( - "W_F_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = vs.get_variable( - "W_O_diag", shape=[self._num_units], dtype=dtype) + # Diagonal connections + if self._use_peepholes: + w_f_diag = vs.get_variable( + "W_F_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "W_O_diag", shape=[self._num_units], dtype=dtype) - if self._use_peepholes: - f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) - else: - f_act = sigmoid(f + self._forget_bias) - c = (f_act * c_prev + (1 - f_act) * self._activation(j)) + if self._use_peepholes: + f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) + else: + f_act = sigmoid(f + self._forget_bias) + c = (f_act * c_prev + (1 - f_act) * self._activation(j)) - if self._use_peepholes: - m = sigmoid(o + w_o_diag * c) * self._activation(c) - else: - m = sigmoid(o) * self._activation(c) + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) - if self._num_proj is not None: - concat_w_proj = _get_concat_variable( - "W_P", [self._num_units, self._num_proj], - dtype, self._num_proj_shards) + if self._num_proj is not None: + concat_w_proj = _get_concat_variable( + "W_P", [self._num_units, self._num_proj], + dtype, self._num_proj_shards) - m = math_ops.matmul(m, concat_w_proj) - if self._proj_clip is not None: - # pylint: disable=invalid-unary-operand-type - m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) - # pylint: enable=invalid-unary-operand-type + m = math_ops.matmul(m, concat_w_proj) + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else array_ops.concat([c, m], 1)) @@ -301,6 +299,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. 
""" + super(TimeFreqLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip @@ -321,14 +320,12 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell): def state_size(self): return self._state_size - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, batch x num_units. state: state Tensor, 2D, batch x state_size. - scope: VariableScope for the created subgraph; defaults to - "TimeFreqLSTMCell". Returns: A tuple containing: @@ -347,63 +344,63 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell): freq_inputs = self._make_tf_features(inputs) dtype = inputs.dtype actual_input_size = freq_inputs[0].get_shape().as_list()[1] - with _checked_scope(self, scope or "time_freq_lstm_cell", - initializer=self._initializer, reuse=self._reuse): - concat_w = _get_concat_variable( - "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], - dtype, self._num_unit_shards) - b = vs.get_variable( - "B", - shape=[4 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) - # Diagonal connections + concat_w = _get_concat_variable( + "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], + dtype, self._num_unit_shards) + + b = vs.get_variable( + "B", + shape=[4 * self._num_units], + initializer=init_ops.zeros_initializer(), + dtype=dtype) + + # Diagonal connections + if self._use_peepholes: + w_f_diag = vs.get_variable( + "W_F_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "W_I_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "W_O_diag", shape=[self._num_units], dtype=dtype) + + # initialize the first freq state to be zero + m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), + self._num_units], dtype) + for fq in range(len(freq_inputs)): + c_prev = array_ops.slice(state, [0, 2*fq*self._num_units], + [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units], + [-1, self._num_units]) + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], + 1) + lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + if self._use_peepholes: - w_f_diag = vs.get_variable( - "W_F_diag", shape=[self._num_units], dtype=dtype) - w_i_diag = vs.get_variable( - "W_I_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = vs.get_variable( - "W_O_diag", shape=[self._num_units], dtype=dtype) + c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * tanh(j)) + else: + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) - # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), - self._num_units], dtype) - for fq in range(len(freq_inputs)): - c_prev = array_ops.slice(state, [0, 2*fq*self._num_units], - [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units], - [-1, self._num_units]) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], - 1) - lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) - i, j, f, o = array_ops.split( - value=lstm_matrix, num_or_size_splits=4, axis=1) + if self._cell_clip is not None: + # pylint: 
disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type - if self._use_peepholes: - c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + - sigmoid(i + w_i_diag * c_prev) * tanh(j)) - else: - c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) - - if self._cell_clip is not None: - # pylint: disable=invalid-unary-operand-type - c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) - # pylint: enable=invalid-unary-operand-type - - if self._use_peepholes: - m = sigmoid(o + w_o_diag * c) * tanh(c) - else: - m = sigmoid(o) * tanh(c) - m_prev_freq = m - if fq == 0: - state_out = array_ops.concat([c, m], 1) - m_out = m - else: - state_out = array_ops.concat([state_out, c, m], 1) - m_out = array_ops.concat([m_out, m], 1) + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * tanh(c) + else: + m = sigmoid(o) * tanh(c) + m_prev_freq = m + if fq == 0: + state_out = array_ops.concat([c, m], 1) + m_out = m + else: + state_out = array_ops.concat([state_out, c, m], 1) + m_out = array_ops.concat([m_out, m], 1) return m_out, state_out def _make_tf_features(self, input_feat): @@ -499,6 +496,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell): Raises: ValueError: if the num_frequency_blocks list is not specified """ + super(GridLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -550,15 +548,13 @@ class GridLSTMCell(core_rnn_cell.RNNCell): def state_tuple_type(self): return self._state_tuple_type - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, [batch, feature_size]. state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the flag self._state_is_tuple. - scope: (optional) VariableScope for the created subgraph; if None, it - defaults to "GridLSTMCell". 
Returns: A tuple containing: @@ -573,21 +569,19 @@ class GridLSTMCell(core_rnn_cell.RNNCell): """ batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] freq_inputs = self._make_tf_features(inputs) - with _checked_scope(self, scope or "grid_lstm_cell", - initializer=self._initializer, reuse=self._reuse): - m_out_lst = [] - state_out_lst = [] - for block in range(len(freq_inputs)): - m_out_lst_current, state_out_lst_current = self._compute( - freq_inputs[block], block, state, batch_size, - state_is_tuple=self._state_is_tuple) - m_out_lst.extend(m_out_lst_current) - state_out_lst.extend(state_out_lst_current) - if self._state_is_tuple: - state_out = self._state_tuple_type(*state_out_lst) - else: - state_out = array_ops.concat(state_out_lst, 1) - m_out = array_ops.concat(m_out_lst, 1) + m_out_lst = [] + state_out_lst = [] + for block in range(len(freq_inputs)): + m_out_lst_current, state_out_lst_current = self._compute( + freq_inputs[block], block, state, batch_size, + state_is_tuple=self._state_is_tuple) + m_out_lst.extend(m_out_lst_current) + state_out_lst.extend(state_out_lst_current) + if self._state_is_tuple: + state_out = self._state_tuple_type(*state_out_lst) + else: + state_out = array_ops.concat(state_out_lst, 1) + m_out = array_ops.concat(m_out_lst, 1) return m_out, state_out def _compute(self, freq_inputs, block, state, batch_size, @@ -974,14 +968,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell): *([num_units, num_units] * self._total_blocks * 2)) self._output_size = 2 * num_units * self._total_blocks * 2 - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, [batch, num_units]. state: tuple of Tensors, 2D, [batch, state_size]. - scope: (optional) VariableScope for the created subgraph; if None, it - defaults to "BidirectionalGridLSTMCell". 
Returns: A tuple containing: @@ -1002,29 +994,27 @@ class BidirectionalGridLSTMCell(GridLSTMCell): bwd_inputs = fwd_inputs # Forward processing - with _checked_scope(self, scope or "bidirectional_grid_lstm_cell", - initializer=self._initializer, reuse=self._reuse): - with vs.variable_scope("fwd"): - fwd_m_out_lst = [] - fwd_state_out_lst = [] - for block in range(len(fwd_inputs)): - fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( - fwd_inputs[block], block, state, batch_size, - state_prefix="fwd_state", state_is_tuple=True) - fwd_m_out_lst.extend(fwd_m_out_lst_current) - fwd_state_out_lst.extend(fwd_state_out_lst_current) - # Backward processing - bwd_m_out_lst = [] - bwd_state_out_lst = [] - with vs.variable_scope("bwd"): - for block in range(len(bwd_inputs)): - # Reverse the blocks - bwd_inputs_reverse = bwd_inputs[block][::-1] - bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( - bwd_inputs_reverse, block, state, batch_size, - state_prefix="bwd_state", state_is_tuple=True) - bwd_m_out_lst.extend(bwd_m_out_lst_current) - bwd_state_out_lst.extend(bwd_state_out_lst_current) + with vs.variable_scope("fwd"): + fwd_m_out_lst = [] + fwd_state_out_lst = [] + for block in range(len(fwd_inputs)): + fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( + fwd_inputs[block], block, state, batch_size, + state_prefix="fwd_state", state_is_tuple=True) + fwd_m_out_lst.extend(fwd_m_out_lst_current) + fwd_state_out_lst.extend(fwd_state_out_lst_current) + # Backward processing + bwd_m_out_lst = [] + bwd_state_out_lst = [] + with vs.variable_scope("bwd"): + for block in range(len(bwd_inputs)): + # Reverse the blocks + bwd_inputs_reverse = bwd_inputs[block][::-1] + bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( + bwd_inputs_reverse, block, state, batch_size, + state_prefix="bwd_state", state_is_tuple=True) + bwd_m_out_lst.extend(bwd_m_out_lst_current) + bwd_state_out_lst.extend(bwd_state_out_lst_current) state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) # Outputs are always concated as it is never used separately. m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1) @@ -1069,6 +1059,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): ValueError: if cell returns a state tuple but the flag `state_is_tuple` is `False` or if attn_length is zero or less. 
""" + super(AttentionCellWrapper, self).__init__(_reuse=reuse) if not isinstance(cell, core_rnn_cell.RNNCell): raise TypeError("The parameter cell is not RNNCell.") if nest.is_sequence(cell.state_size) and not state_is_tuple: @@ -1107,42 +1098,40 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): def output_size(self): return self._attn_size - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Long short-term memory cell with attention (LSTMA).""" - with _checked_scope(self, scope or "attention_cell_wrapper", - reuse=self._reuse): - if self._state_is_tuple: - state, attns, attn_states = state - else: - states = state - state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) - attns = array_ops.slice( - states, [0, self._cell.state_size], [-1, self._attn_size]) - attn_states = array_ops.slice( - states, [0, self._cell.state_size + self._attn_size], - [-1, self._attn_size * self._attn_length]) - attn_states = array_ops.reshape(attn_states, - [-1, self._attn_length, self._attn_size]) - input_size = self._input_size - if input_size is None: - input_size = inputs.get_shape().as_list()[1] - inputs = _linear([inputs, attns], input_size, True) - lstm_output, new_state = self._cell(inputs, state) - if self._state_is_tuple: - new_state_cat = array_ops.concat(nest.flatten(new_state), 1) - else: - new_state_cat = new_state - new_attns, new_attn_states = self._attention(new_state_cat, attn_states) - with vs.variable_scope("attn_output_projection"): - output = _linear([lstm_output, new_attns], self._attn_size, True) - new_attn_states = array_ops.concat( - [new_attn_states, array_ops.expand_dims(output, 1)], 1) - new_attn_states = array_ops.reshape( - new_attn_states, [-1, self._attn_length * self._attn_size]) - new_state = (new_state, new_attns, new_attn_states) - if not self._state_is_tuple: - new_state = array_ops.concat(list(new_state), 1) - return output, new_state + if self._state_is_tuple: + state, attns, attn_states = state + else: + states = state + state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) + attns = array_ops.slice( + states, [0, self._cell.state_size], [-1, self._attn_size]) + attn_states = array_ops.slice( + states, [0, self._cell.state_size + self._attn_size], + [-1, self._attn_size * self._attn_length]) + attn_states = array_ops.reshape(attn_states, + [-1, self._attn_length, self._attn_size]) + input_size = self._input_size + if input_size is None: + input_size = inputs.get_shape().as_list()[1] + inputs = _linear([inputs, attns], input_size, True) + lstm_output, new_state = self._cell(inputs, state) + if self._state_is_tuple: + new_state_cat = array_ops.concat(nest.flatten(new_state), 1) + else: + new_state_cat = new_state + new_attns, new_attn_states = self._attention(new_state_cat, attn_states) + with vs.variable_scope("attn_output_projection"): + output = _linear([lstm_output, new_attns], self._attn_size, True) + new_attn_states = array_ops.concat( + [new_attn_states, array_ops.expand_dims(output, 1)], 1) + new_attn_states = array_ops.reshape( + new_attn_states, [-1, self._attn_length * self._attn_size]) + new_state = (new_state, new_attns, new_attn_states) + if not self._state_is_tuple: + new_state = array_ops.concat(list(new_state), 1) + return output, new_state def _attention(self, query, attn_states): conv2d = nn_ops.conv2d @@ -1213,6 +1202,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. 
""" + super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) @@ -1256,34 +1246,31 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): out = nn_ops.bias_add(out, bias) return out - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """LSTM cell with layer normalization and recurrent dropout.""" + c, h = state + args = array_ops.concat([inputs, h], 1) + concat = self._linear(args) - with _checked_scope(self, scope or "layer_norm_basic_lstm_cell", - reuse=self._reuse): - c, h = state - args = array_ops.concat([inputs, h], 1) - concat = self._linear(args) + i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) + if self._layer_norm: + i = self._norm(i, "input") + j = self._norm(j, "transform") + f = self._norm(f, "forget") + o = self._norm(o, "output") - i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) - if self._layer_norm: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") + g = self._activation(j) + if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: + g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) - g = self._activation(j) - if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: - g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) + new_c = (c * math_ops.sigmoid(f + self._forget_bias) + + math_ops.sigmoid(i) * g) + if self._layer_norm: + new_c = self._norm(new_c, "state") + new_h = self._activation(new_c) * math_ops.sigmoid(o) - new_c = (c * math_ops.sigmoid(f + self._forget_bias) - + math_ops.sigmoid(i) * g) - if self._layer_norm: - new_c = self._norm(new_c, "state") - new_h = self._activation(new_c) * math_ops.sigmoid(o) - - new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) - return new_h, new_state + new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) + return new_h, new_state class NASCell(core_rnn_cell.RNNCell): @@ -1313,6 +1300,7 @@ class NASCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(NASCell, self).__init__(_reuse=reuse) self._num_units = num_units self._num_proj = num_proj self._use_biases = use_biases @@ -1333,14 +1321,13 @@ class NASCell(core_rnn_cell.RNNCell): def output_size(self): return self._output_size - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of NAS Cell. Args: inputs: input Tensor, 2D, batch x num_units. state: This must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - scope: VariableScope for the created subgraph; defaults to "nas_rnn". Returns: A tuple containing: @@ -1368,71 +1355,70 @@ class NASCell(core_rnn_cell.RNNCell): input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with _checked_scope(self, scope or "nas_rnn", reuse=self._reuse): - # Variables for the NAS cell. W_m is all matrices multiplying the - # hiddenstate and W_inputs is all matrices multiplying the inputs. - concat_w_m = vs.get_variable( - "recurrent_weights", [num_proj, 8 * self._num_units], - dtype) - concat_w_inputs = vs.get_variable( - "weights", [input_size.value, 8 * self._num_units], + # Variables for the NAS cell. 
W_m is all matrices multiplying the + # hiddenstate and W_inputs is all matrices multiplying the inputs. + concat_w_m = vs.get_variable( + "recurrent_weights", [num_proj, 8 * self._num_units], + dtype) + concat_w_inputs = vs.get_variable( + "weights", [input_size.value, 8 * self._num_units], + dtype) + + m_matrix = math_ops.matmul(m_prev, concat_w_m) + inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) + + if self._use_biases: + b = vs.get_variable( + "bias", + shape=[8 * self._num_units], + initializer=init_ops.zeros_initializer(), + dtype=dtype) + m_matrix = nn_ops.bias_add(m_matrix, b) + + # The NAS cell branches into 8 different splits for both the hiddenstate + # and the input + m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, + value=m_matrix) + inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, + value=inputs_matrix) + + # First layer + layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) + layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1]) + layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2]) + layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3]) + layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4]) + layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5]) + layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6]) + layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7]) + + # Second layer + l2_0 = tanh(layer1_0 * layer1_1) + l2_1 = tanh(layer1_2 + layer1_3) + l2_2 = tanh(layer1_4 * layer1_5) + l2_3 = sigmoid(layer1_6 + layer1_7) + + # Inject the cell + l2_0 = tanh(l2_0 + c_prev) + + # Third layer + l3_0_pre = l2_0 * l2_1 + new_c = l3_0_pre # create new cell + l3_0 = l3_0_pre + l3_1 = tanh(l2_2 + l2_3) + + # Final layer + new_m = tanh(l3_0 * l3_1) + + # Projection layer if specified + if self._num_proj is not None: + concat_w_proj = vs.get_variable( + "projection_weights", [self._num_units, self._num_proj], dtype) + new_m = math_ops.matmul(new_m, concat_w_proj) - m_matrix = math_ops.matmul(m_prev, concat_w_m) - inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) - - if self._use_biases: - b = vs.get_variable( - "bias", - shape=[8 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) - m_matrix = nn_ops.bias_add(m_matrix, b) - - # The NAS cell branches into 8 different splits for both the hiddenstate - # and the input - m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, - value=m_matrix) - inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, - value=inputs_matrix) - - # First layer - layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) - layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1]) - layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2]) - layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3]) - layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4]) - layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5]) - layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6]) - layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7]) - - # Second layer - l2_0 = tanh(layer1_0 * layer1_1) - l2_1 = tanh(layer1_2 + layer1_3) - l2_2 = tanh(layer1_4 * layer1_5) - l2_3 = sigmoid(layer1_6 + layer1_7) - - # Inject the cell - l2_0 = tanh(l2_0 + c_prev) - - # Third layer - l3_0_pre = l2_0 * l2_1 - new_c = l3_0_pre # create new cell - l3_0 = l3_0_pre - l3_1 = tanh(l2_2 + l2_3) - - # Final layer - new_m = tanh(l3_0 * l3_1) - - # Projection layer if specified - 
if self._num_proj is not None: - concat_w_proj = vs.get_variable( - "projection_weights", [self._num_units, self._num_proj], - dtype) - new_m = math_ops.matmul(new_m, concat_w_proj) - - new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m) - return new_m, new_state + new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m) + return new_m, new_state class UGRNNCell(core_rnn_cell.RNNCell): @@ -1467,6 +1453,7 @@ class UGRNNCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(UGRNNCell, self).__init__(_reuse=reuse) self._num_units = num_units self._initializer = initializer self._forget_bias = forget_bias @@ -1481,13 +1468,12 @@ class UGRNNCell(core_rnn_cell.RNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of UGRNN. Args: inputs: input Tensor, 2D, batch x input size. state: state Tensor, 2D, batch x num units. - scope: VariableScope for the created subgraph; defaults to "ugrnn_cell". Returns: new_output: batch x num units, Tensor representing the output of the UGRNN @@ -1506,8 +1492,8 @@ class UGRNNCell(core_rnn_cell.RNNCell): if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with _checked_scope(self, scope or "ugrnn_cell", - initializer=self._initializer, reuse=self._reuse): + with vs.variable_scope(vs.get_variable_scope(), + initializer=self._initializer): cell_inputs = array_ops.concat([inputs, state], 1) rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True) @@ -1567,6 +1553,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ + super(IntersectionRNNCell, self).__init__(_reuse=reuse) self._num_units = num_units self._initializer = initializer self._forget_bias = forget_bias @@ -1582,14 +1569,12 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Run one step of the Intersection RNN. Args: inputs: input Tensor, 2D, batch x input size. state: state Tensor, 2D, batch x num units. - scope: VariableScope for the created subgraph; defaults to - "intersection_rnn_cell" Returns: new_y: batch x num units, Tensor representing the output of the +RNN @@ -1610,8 +1595,8 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell): if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with _checked_scope(self, scope or "intersection_rnn_cell", - initializer=self._initializer, reuse=self._reuse): + with vs.variable_scope(vs.get_variable_scope(), + initializer=self._initializer): # read-in projections (should be used for first layer in deep +RNN # to transform size of inputs from I --> N) if input_size.value != self._num_units: @@ -1683,7 +1668,7 @@ class CompiledWrapper(core_rnn_cell.RNNCell): return not _REGISTERED_OPS[node_def.op].is_stateful with jit.experimental_jit_scope(compile_ops=compile_ops): - return self._cell(inputs, state, scope=scope) + return self._cell(inputs, state, scope) def _random_exp_initializer(minval, @@ -1753,6 +1738,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. 
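
For orientation, a usage sketch for the contrib cells refactored above (hypothetical sizes; assumes the TF 1.2-era module path of this file): a call goes through the inherited __call__, which opens the variable scope and dispatches to the new call(inputs, state).

import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

cell = contrib_rnn_cell.NASCell(num_units=16, num_proj=8)
x = tf.zeros([4, 10])
state = cell.zero_state(4, tf.float32)  # an LSTMStateTuple(c, m)
# Scoping and reuse are handled by the base class; `call` only builds
# the cell's math.
output, new_state = cell(x, state)  # output shape: [4, 8]
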
""" + super(PhasedLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._use_peepholes = use_peepholes self._leak = leak @@ -1782,7 +1768,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): cycle_ratio = self._mod(shifted_time, period_casted) / period_casted return math_ops.cast(cycle_ratio, dtype=dtypes.float32) - def __call__(self, inputs, state, scope=None): + def call(self, inputs, state): """Phased LSTM Cell. Args: @@ -1792,7 +1778,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): The second Tensor has shape [batch, features_size], and type float32. It stores the features. state: core_rnn_cell.LSTMStateTuple, state from previous timestep. - scope: string, id of the variable scope. Returns: A tuple containing: @@ -1801,61 +1786,60 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape [batch_size, num_units], representing the new state and the output. """ - with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse): - (c_prev, h_prev) = state - (time, x) = inputs + (c_prev, h_prev) = state + (time, x) = inputs - in_mask_gates = [x, h_prev] - if self._use_peepholes: - in_mask_gates.append(c_prev) + in_mask_gates = [x, h_prev] + if self._use_peepholes: + in_mask_gates.append(c_prev) - with vs.variable_scope("mask_gates"): - mask_gates = math_ops.sigmoid( - _linear(in_mask_gates, 2 * self._num_units, True)) - [input_gate, forget_gate] = array_ops.split( - axis=1, num_or_size_splits=2, value=mask_gates) + with vs.variable_scope("mask_gates"): + mask_gates = math_ops.sigmoid( + _linear(in_mask_gates, 2 * self._num_units, True)) + [input_gate, forget_gate] = array_ops.split( + axis=1, num_or_size_splits=2, value=mask_gates) - with vs.variable_scope("new_input"): - new_input = math_ops.tanh( - _linear([x, h_prev], self._num_units, True)) + with vs.variable_scope("new_input"): + new_input = math_ops.tanh( + _linear([x, h_prev], self._num_units, True)) - new_c = (c_prev * forget_gate + input_gate * new_input) + new_c = (c_prev * forget_gate + input_gate * new_input) - in_out_gate = [x, h_prev] - if self._use_peepholes: - in_out_gate.append(new_c) + in_out_gate = [x, h_prev] + if self._use_peepholes: + in_out_gate.append(new_c) - with vs.variable_scope("output_gate"): - output_gate = math_ops.sigmoid( - _linear(in_out_gate, self._num_units, True)) + with vs.variable_scope("output_gate"): + output_gate = math_ops.sigmoid( + _linear(in_out_gate, self._num_units, True)) - new_h = math_ops.tanh(new_c) * output_gate + new_h = math_ops.tanh(new_c) * output_gate - period = vs.get_variable( - "period", [self._num_units], - initializer=_random_exp_initializer( - self._period_init_min, self._period_init_max)) - phase = vs.get_variable( - "phase", [self._num_units], - initializer=init_ops.random_uniform_initializer( - 0., period.initial_value)) - ratio_on = vs.get_variable( - "ratio_on", [self._num_units], - initializer=init_ops.constant_initializer(self._ratio_on), - trainable=self._trainable_ratio_on) + period = vs.get_variable( + "period", [self._num_units], + initializer=_random_exp_initializer( + self._period_init_min, self._period_init_max)) + phase = vs.get_variable( + "phase", [self._num_units], + initializer=init_ops.random_uniform_initializer( + 0., period.initial_value)) + ratio_on = vs.get_variable( + "ratio_on", [self._num_units], + initializer=init_ops.constant_initializer(self._ratio_on), + trainable=self._trainable_ratio_on) - cycle_ratio = self._get_cycle_ratio(time, phase, period) + 
cycle_ratio = self._get_cycle_ratio(time, phase, period) - k_up = 2 * cycle_ratio / ratio_on - k_down = 2 - k_up - k_closed = self._leak * cycle_ratio + k_up = 2 * cycle_ratio / ratio_on + k_down = 2 - k_up + k_closed = self._leak * cycle_ratio - k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) - k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) + k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) + k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) - new_c = k * new_c + (1 - k) * c_prev - new_h = k * new_h + (1 - k) * h_prev + new_c = k * new_c + (1 - k) * c_prev + new_h = k * new_h + (1 - k) * h_prev - new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) + new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) - return new_h, new_state + return new_h, new_state diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 3b0568794dc..ec493b84635 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -56,14 +56,19 @@ class GatherTreeOp : public OpKernel { errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ", step_ids_shape.DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(sequence_length.shape()), - errors::InvalidArgument("sequence_length must be a vector, saw shape: ", + ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()), + errors::InvalidArgument("sequence_length must be a matrix, saw shape: ", sequence_length.shape().DebugString())); OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1), errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[1] (", + "Inconsistent batch sizes: sequence_length.shape[0] (", sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (", - step_ids_shape.dim_size(0), ")")); + step_ids_shape.dim_size(1), ")")); + OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2), + errors::InvalidArgument( + "Inconsistent batch sizes: sequence_length.shape[1] (", + sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (", + step_ids_shape.dim_size(2), ")")); OP_REQUIRES( ctx, step_ids_shape == parent_ids.shape(), errors::InvalidArgument( @@ -74,7 +79,7 @@ class GatherTreeOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); - typename TTypes::ConstVec seq_len_t = sequence_length.vec(); + typename TTypes::ConstMatrix seq_len_t = sequence_length.matrix(); typename TTypes::Tensor beams_t = beams->tensor(); functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, seq_len_t, beams_t); @@ -96,7 +101,7 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const CPUDevice& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstVec sequence_length, + typename TTypes::ConstMatrix sequence_length, typename TTypes::Tensor beams) { const int64 max_time = parent_ids.dimension(0); const int64 batch_size = parent_ids.dimension(1); @@ -104,15 +109,10 @@ struct GatherTree { beams.setConstant(-1); auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) { - int32 seq_len_b = -1; - int32 old_batch = -1; for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - if (batch != old_batch) { - seq_len_b = 
sequence_length(batch); - old_batch = batch; - } + int32 seq_len_b = sequence_length(batch, beam); if (seq_len_b == 0) { continue; } @@ -148,14 +148,14 @@ struct GatherTree { #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void GatherTree::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, \ - typename TTypes::ConstTensor step_ids, \ - typename TTypes::ConstTensor parent_ids, \ - typename TTypes::ConstVec sequence_length, \ - typename TTypes::Tensor beams); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void GatherTree::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes::ConstTensor step_ids, \ + typename TTypes::ConstTensor parent_ids, \ + typename TTypes::ConstMatrix sequence_length, \ + typename TTypes::Tensor beams); \ extern template struct GatherTree; DECLARE_GPU_SPEC(int32); diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 501a2eae848..124d07264e7 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -31,7 +31,7 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const Device& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstVec sequence_length, + typename TTypes::ConstMatrix sequence_length, typename TTypes::Tensor beams); }; diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index 8d8fc810015..e3c0d0bfa98 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -33,7 +33,7 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - const int32 seq_len_b = ldg(sequence_length + batch); + const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam); #define GET_IX(time_ix, beam_ix) \ (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam); @@ -59,7 +59,7 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstVec sequence_length, + typename TTypes::ConstMatrix sequence_length, typename TTypes::Tensor beams) { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index c167736d882..6c445cd4606 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -32,17 +32,20 @@ REGISTER_OP("GatherTree") ShapeHandle step_ids, parent_ids, sequence_length; // step_ids, parent_ids, and output are all shaped: - // [batch_size, max_time, beam_width]. - // sequence_length is shaped [batch_size]. + // [max_time, batch_size, beam_width]. + // sequence_length is shaped [batch_size, beam_width]. 
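
A NumPy reference for what the kernel computes under the new shapes: `step_ids` and `parent_ids` are `[max_time, batch_size, beam_width]`, `sequence_length` is now `[batch_size, beam_width]`, and entries past a beam's length stay -1. This is a sketch of the semantics, not the exact kernel.

import numpy as np

def gather_tree_reference(step_ids, parent_ids, sequence_length):
  max_time, batch_size, beam_width = step_ids.shape
  beams = np.full_like(step_ids, -1)
  for batch in range(batch_size):
    for beam in range(beam_width):
      seq_len = sequence_length[batch, beam]  # per-beam length now
      if seq_len == 0:
        continue
      parent = beam
      # Walk from the last step of this beam back to t=0, following
      # parent pointers to recover the selected path.
      for level in range(seq_len - 1, -1, -1):
        beams[level, batch, beam] = step_ids[level, batch, parent]
        parent = parent_ids[level, batch, parent]
  return beams
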
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length)); DimensionHandle batch_size = c->Dim(step_ids, 1); + DimensionHandle beam_width = c->Dim(step_ids, 2); TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids)); TF_RETURN_IF_ERROR( c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); + TF_RETURN_IF_ERROR( + c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width)); c->set_output(0, step_ids); return tensorflow::Status::OK(); @@ -58,7 +61,7 @@ TODO(ebrevdo): fill in step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. -sequence_length: `[batch_size]`. +sequence_length: `[batch_size, beam_width]`. beams: `[max_time, batch_size, beam_width]`. )doc"); diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index aa84ae060c9..888479e218e 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -109,7 +109,7 @@ class AttentionWrapperTest(test.TestCase): initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) - final_outputs, final_state = decoder.dynamic_decode(my_decoder) + final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder) self.assertTrue( isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 512df183171..a72d962d784 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -41,24 +42,32 @@ class TestGatherTree(test.TestCase): """Tests the gather_tree function.""" def test_gather_tree(self): - predicted_ids = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[2, 3, 4], [5, 6, 7], - [8, 9, 10]]]).transpose([1, 0, 2]) - parent_ids = np.array([ - [[0, 0, 0], [0, 1, 1], [2, 1, 2]], - [[0, 0, 0], [1, 2, 0], [2, 1, 1]], - ]).transpose([1, 0, 2]) - expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]], - [[2, 4, 4], [7, 6, 6], - [8, 9, 10]]]).transpose([1, 0, 2]) + # (max_time = 3, batch_size = 2, beam_width = 3) - res = beam_search_decoder._gather_tree( - ops.convert_to_tensor(predicted_ids), ops.convert_to_tensor(parent_ids)) + # create (batch_size, max_time, beam_width) matrix and transpose it + predicted_ids = np.array( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[2, 3, 4], [5, 6, 7], [8, 9, 10]]], + dtype=np.int32).transpose([1, 0, 2]) + parent_ids = np.array( + [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], + [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], + dtype=np.int32).transpose([1, 0, 2]) + + # sequence_lengths is shaped (batch_size = 2, beam_width = 3) + 
sequence_lengths = [[3, 3, 3], [3, 3, 3]] + + expected_result = np.array( + [[[2, 2, 2], [6, 5, 6], [7, 8, 9]], + [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2]) + + res = beam_search_ops.gather_tree( + predicted_ids, parent_ids, sequence_lengths) with self.test_session() as sess: res_ = sess.run(res) - np.testing.assert_array_equal(expected_result, res_) + self.assertAllEqual(expected_result, res_) class TestEosMasking(test.TestCase): @@ -80,18 +89,18 @@ class TestEosMasking(test.TestCase): probs = sess.run(probs) masked = sess.run(masked) - np.testing.assert_array_equal(probs[0][0], masked[0][0]) - np.testing.assert_array_equal(probs[0][2], masked[0][2]) - np.testing.assert_array_equal(probs[1][0], masked[1][0]) + self.assertAllEqual(probs[0][0], masked[0][0]) + self.assertAllEqual(probs[0][2], masked[0][2]) + self.assertAllEqual(probs[1][0], masked[1][0]) - np.testing.assert_equal(masked[0][1][0], 0) - np.testing.assert_equal(masked[1][1][0], 0) - np.testing.assert_equal(masked[1][2][0], 0) + self.assertEqual(masked[0][1][0], 0) + self.assertEqual(masked[1][1][0], 0) + self.assertEqual(masked[1][2][0], 0) for i in range(1, 5): - np.testing.assert_approx_equal(masked[0][1][i], np.finfo('float32').min) - np.testing.assert_approx_equal(masked[1][1][i], np.finfo('float32').min) - np.testing.assert_approx_equal(masked[1][2][i], np.finfo('float32').min) + self.assertAllClose(masked[0][1][i], np.finfo('float32').min) + self.assertAllClose(masked[1][1][i], np.finfo('float32').min) + self.assertAllClose(masked[1][2][i], np.finfo('float32').min) class TestBeamStep(test.TestCase): @@ -142,12 +151,11 @@ class TestBeamStep(test.TestCase): outputs_, next_state_, state_, log_probs_ = sess.run( [outputs, next_beam_state, beam_state, log_probs]) - np.testing.assert_array_equal(outputs_.predicted_ids, [[3, 3, 2], [2, 2, - 1]]) - np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]]) - np.testing.assert_array_equal(next_state_.lengths, [[3, 3, 3], [3, 3, 3]]) - np.testing.assert_array_equal(next_state_.finished, [[False, False, False], - [False, False, False]]) + self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]]) + self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]]) + self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]]) + self.assertAllEqual(next_state_.finished, [[False, False, False], + [False, False, False]]) expected_log_probs = [] expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) @@ -158,7 +166,7 @@ class TestBeamStep(test.TestCase): expected_log_probs[1][0] += log_probs_[1, 2, 2] expected_log_probs[1][1] += log_probs_[1, 1, 2] expected_log_probs[1][2] += log_probs_[1, 0, 1] - np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs) + self.assertAllEqual(next_state_.log_probs, expected_log_probs) def test_step_with_eos(self): dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) @@ -197,12 +205,11 @@ class TestBeamStep(test.TestCase): outputs_, next_state_, state_, log_probs_ = sess.run( [outputs, next_beam_state, beam_state, log_probs]) - np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]]) - np.testing.assert_array_equal(outputs_.predicted_ids, [[0, 3, 2], [2, 0, - 1]]) - np.testing.assert_array_equal(next_state_.lengths, [[1, 3, 3], [3, 1, 3]]) - np.testing.assert_array_equal(next_state_.finished, [[True, False, False], - [False, True, False]]) + self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]]) + self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 
2], [2, 0, 1]]) + self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]]) + self.assertAllEqual(next_state_.finished, [[True, False, False], + [False, True, False]]) expected_log_probs = [] expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) @@ -211,7 +218,7 @@ class TestBeamStep(test.TestCase): expected_log_probs[0][2] += log_probs_[0, 0, 2] expected_log_probs[1][0] += log_probs_[1, 1, 2] expected_log_probs[1][2] += log_probs_[1, 0, 1] - np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs) + self.assertAllEqual(next_state_.log_probs, expected_log_probs) class BeamSearchDecoderTest(test.TestCase): @@ -259,8 +266,9 @@ class BeamSearchDecoderTest(test.TestCase): output_layer=output_layer, length_penalty_weight=0.0) - final_outputs, final_state = decoder.dynamic_decode( - bsd, output_time_major=time_major, maximum_iterations=max_out) + final_outputs, final_state, final_sequence_lengths = ( + decoder.dynamic_decode( + bsd, output_time_major=time_major, maximum_iterations=max_out)) def _t(shape): if time_major: @@ -284,16 +292,18 @@ class BeamSearchDecoderTest(test.TestCase): sess.run(variables.global_variables_initializer()) sess_results = sess.run({ 'final_outputs': final_outputs, - 'final_state': final_state + 'final_state': final_state, + 'final_sequence_lengths': final_sequence_lengths }) - # Mostly a smoke test - time_steps = max_out + max_sequence_length = np.max(sess_results['final_sequence_lengths']) + + # A smoke test self.assertEqual( - _t((batch_size, time_steps, beam_width)), + _t((batch_size, max_sequence_length, beam_width)), sess_results['final_outputs'].beam_search_decoder_output.scores.shape) self.assertEqual( - _t((batch_size, time_steps, beam_width)), sess_results[ + _t((batch_size, max_sequence_length, beam_width)), sess_results[ 'final_outputs'].beam_search_decoder_output.predicted_ids.shape) def testDynamicDecodeRNNBatchMajorNoAttention(self): diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index 542254854a4..491d87f62d8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -38,7 +38,7 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [3] + sequence_length = [[3, 3, 3]] expected_result = _transpose_batch_time( [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) beams = beam_search_ops.gather_tree( @@ -54,7 +54,7 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [3] + sequence_length = [[3, 3, 3]] with ops.device("/cpu:0"): beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, @@ -73,7 +73,7 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [3] + sequence_length = [[3, 3, 3]] expected_result = _transpose_batch_time( [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) with ops.device("/gpu:0"): @@ -84,7 +84,8 @@ class GatherTreeTest(test.TestCase): self.assertAllEqual(expected_result, beams.eval()) def testGatherTreeBatch(self): - sequence_length = [0, 1, 2, 3] + 
# sequence_length is [batch_size, beam_width] = [4, 5] + sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5] with self.test_session(use_gpu=True): # (max_time = 4, batch_size = 4, beam_width = 5) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 340ec9bbb22..96dc7b4beee 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -60,9 +60,9 @@ class DynamicDecodeRNNTest(test.TestCase): initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) - final_outputs, final_state = decoder.dynamic_decode( - my_decoder, output_time_major=time_major, - maximum_iterations=maximum_iterations) + final_outputs, final_state, final_sequence_length = ( + decoder.dynamic_decode(my_decoder, output_time_major=time_major, + maximum_iterations=maximum_iterations)) def _t(shape): if time_major: @@ -73,6 +73,9 @@ class DynamicDecodeRNNTest(test.TestCase): isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple)) + self.assertEqual( + (batch_size,), + tuple(final_sequence_length.get_shape().as_list())) self.assertEqual( _t((batch_size, None, cell_depth)), tuple(final_outputs.rnn_output.get_shape().as_list())) @@ -83,7 +86,8 @@ class DynamicDecodeRNNTest(test.TestCase): sess.run(variables.global_variables_initializer()) sess_results = sess.run({ "final_outputs": final_outputs, - "final_state": final_state + "final_state": final_state, + "final_sequence_length": final_sequence_length, }) # Mostly a smoke test @@ -131,7 +135,7 @@ class DynamicDecodeRNNTest(test.TestCase): # Match the variable scope of dynamic_rnn below so we end up # using the same variables with vs.variable_scope("root") as scope: - final_decoder_outputs, final_decoder_state = decoder.dynamic_decode( + final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode( my_decoder, # impute_finished=True ensures outputs and final state # match those of dynamic_rnn called with sequence_length not None diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 4da87276c6f..37622af59f6 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -454,6 +454,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): up to the next cell in an RNN stack or to the top RNN output. name: Name to use when creating ops. """ + super(AttentionWrapper, self).__init__() if not isinstance(cell, core_rnn_cell.RNNCell): raise TypeError( "cell must be an RNNCell, saw type: %s" % type(cell).__name__) @@ -515,7 +516,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): dtype), alignment_history=alignment_history) - def __call__(self, inputs, state, tiling_factor=1, scope=None): + def __call__(self, inputs, state, tiling_factor=1): """Perform a step of attention-wrapped RNN. - Step 1: Mix the `inputs` and previous step's `attention` output via @@ -536,7 +537,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell): tensors from the previous time step. tiling_factor: An integer factor for which to tile the batch dimension. Used with BeamSearchDecoder. - scope: Must be `None`. 
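Stepping back from the `AttentionWrapper` docstring for a moment: the test updates earlier in this patch reflect two coordinated API changes. `decoder.dynamic_decode` now returns a third element, `final_sequence_lengths`, and `beam_search_ops.gather_tree` expects `sequence_length` shaped `[batch_size, beam_width]` rather than `[batch_size]`, because each beam may now finish at a different step. A minimal NumPy sketch of the shape migration (the length values come from `testGatherTreeBatch` above; everything else is illustrative, not part of the patch):

```python
import numpy as np

batch_size, beam_width = 4, 5

# Old convention: one length per batch entry.
old_lengths = np.array([0, 1, 2, 3])  # shape [batch_size]

# New convention: one length per (batch entry, beam) pair, matching
# sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5] in the test above.
new_lengths = np.tile(old_lengths[:, None], (1, beam_width))

assert new_lengths.shape == (batch_size, beam_width)
```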
Returns: A tuple `(attention_or_cell_output, next_state)`, where: @@ -548,50 +548,46 @@ class AttentionWrapper(core_rnn_cell.RNNCell): Raises: NotImplementedError: if `scope` is not `None`. """ - if scope is not None: - raise NotImplementedError("scope not None is not supported") + # Step 1: Calculate the true inputs to the cell based on the + # previous attention value. + cell_inputs = self._cell_input_fn(inputs, state.attention) + cell_state = state.cell_state + cell_output, next_cell_state = self._cell(cell_inputs, cell_state) - with variable_scope.variable_scope("attention"): - # Step 1: Calculate the true inputs to the cell based on the - # previous attention value. - cell_inputs = self._cell_input_fn(inputs, state.attention) - cell_state = state.cell_state - cell_output, next_cell_state = self._cell(cell_inputs, cell_state) + score = self._attention_mechanism(cell_output, tiling_factor) + alignments = self._probability_fn(score) - score = self._attention_mechanism(cell_output, tiling_factor) - alignments = self._probability_fn(score) + # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] + expanded_alignments = array_ops.expand_dims(alignments, 1) + # Context is the inner product of alignments and values along the + # memory time dimension. + # alignments shape is + # [batch_size, 1, memory_time] + # attention_mechanism.values shape is + # [batch_size, memory_time, attention_mechanism.num_units] + # the batched matmul is over memory_time, so the output shape is + # [batch_size, 1, attention_mechanism.num_units]. + # we then squeeze out the singleton dim. + attention_mechanism_values = _maybe_tile_batch( + self._attention_mechanism.values, tiling_factor) - # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] - expanded_alignments = array_ops.expand_dims(alignments, 1) - # Context is the inner product of alignments and values along the - # memory time dimension. - # alignments shape is - # [batch_size, 1, memory_time] - # attention_mechanism.values shape is - # [batch_size, memory_time, attention_mechanism.num_units] - # the batched matmul is over memory_time, so the output shape is - # [batch_size, 1, attention_mechanism.num_units]. - # we then squeeze out the singleton dim. 
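The comment block above traces the context computation shape by shape. As a sanity check, here is a NumPy sketch of the same batched matmul and concatenation; all dimension sizes are hypothetical and this snippet is a reading aid, not code from the patch:

```python
import numpy as np

batch_size, memory_time, num_units, cell_depth = 2, 7, 8, 6

alignments = np.random.rand(batch_size, memory_time)         # [B, T_mem]
values = np.random.rand(batch_size, memory_time, num_units)  # [B, T_mem, U]
cell_output = np.random.rand(batch_size, cell_depth)         # [B, D]

# [B, 1, T_mem] @ [B, T_mem, U] -> [B, 1, U], then squeeze the singleton dim.
expanded = alignments[:, None, :]
context = np.matmul(expanded, values).squeeze(1)             # [B, U]

# The attention layer then mixes cell output and context;
# np.concatenate stands in for array_ops.concat here.
attention_input = np.concatenate([cell_output, context], axis=1)
assert attention_input.shape == (batch_size, cell_depth + num_units)
```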
- attention_mechanism_values = _maybe_tile_batch( - self._attention_mechanism.values, tiling_factor) + context = math_ops.matmul(expanded_alignments, attention_mechanism_values) + context = array_ops.squeeze(context, [1]) - context = math_ops.matmul(expanded_alignments, attention_mechanism_values) - context = array_ops.squeeze(context, [1]) + attention = self._attention_layer( + array_ops.concat([cell_output, context], 1)) - attention = self._attention_layer( - array_ops.concat([cell_output, context], 1)) + if self._alignment_history: + alignment_history = state.alignment_history.write( + state.time, alignments) + else: + alignment_history = () - if self._alignment_history: - alignment_history = state.alignment_history.write( - state.time, alignments) - else: - alignment_history = () - - next_state = AttentionWrapperState( - time=state.time + 1, - cell_state=next_cell_state, - attention=attention, - alignment_history=alignment_history) + next_state = AttentionWrapperState( + time=state.time + 1, + cell_state=next_cell_state, + attention=attention, + alignment_history=alignment_history) if self._output_attention: return attention, next_state diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 8f1f74ab09d..55ef21a5a0d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -19,9 +19,9 @@ from __future__ import division from __future__ import print_function import collections -import numpy as np from tensorflow.contrib.rnn import core_rnn_cell +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import script_ops from tensorflow.python.util import nest @@ -202,20 +201,24 @@ class BeamSearchDecoder(decoder.Decoder): return (finished, start_inputs, initial_state) - def finalize(self, outputs, final_state): + def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. Args: outputs: An instance of BeamSearchDecoderOutput. final_state: An instance of BeamSearchDecoderState. Passed through to the output. + sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`. + The sequence lengths determined for each beam during decode. Returns: outputs: An instance of FinalBeamSearchDecoderOutput where the predicted_ids are the result of calling _gather_tree. final_state: The same input instance of BeamSearchDecoderState. """ - predicted_ids = _gather_tree(outputs.predicted_ids, outputs.parent_ids) + predicted_ids = beam_search_ops.gather_tree( + outputs.predicted_ids, outputs.parent_ids, + sequence_length=sequence_lengths) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state @@ -536,42 +539,6 @@ def _mask_probs(probs, eos_token, finished): return finished_examples + non_finished_examples -def _gather_tree_py(values, parents): - """Gathers path through a tree backwards from the leave nodes. - - Used to reconstruct beams given their parents. - - Args: - values: A [T, batch_size, beam_width] tensor of indices. 
- parents: A [T, batch_size, beam_width] tensor of parent beam ids. - - Returns: - The [T, batch_size, beam_width] numpy array of paths. For a given batch - entry b, the best path is given by ret[:, b, 0]. - """ - num_timesteps = values.shape[0] - num_beams = values.shape[2] - batch_size = values.shape[1] - ret = np.zeros_like(values) # [T, MB, BW] - ret[-1, :, :] = values[-1, :, :] - for beam_id in range(num_beams): - for batch in range(batch_size): - parent = parents[-1][batch][beam_id] - for timestep in reversed(range(num_timesteps - 1)): - ret[timestep, batch, beam_id] = values[timestep][batch][parent] - parent = parents[timestep][batch][parent] - # now we are going to return ret as a [ts, mb, bw] tensor - return np.array(ret).astype(values.dtype) - - -def _gather_tree(values, parents): - """Tensor version of _gather_tree_py.""" - ret = script_ops.py_func( - func=_gather_tree_py, inp=[values, parents], Tout=values.dtype) - ret.set_shape(values.get_shape().as_list()) - return ret - - def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size, final_shape): range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index ee287b0cf65..ff705715e01 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -154,11 +154,11 @@ def dynamic_decode(decoder, scope: Optional variable scope to use. Returns: - `(final_outputs, final_state)`. + `(final_outputs, final_state, final_sequence_lengths)`. Raises: TypeError: if `decoder` is not an instance of `Decoder`. - ValueError: if maximum_iterations is provided but is not a scalar. + ValueError: if `maximum_iterations` is provided but is not a scalar. """ if not isinstance(decoder, Decoder): raise TypeError("Expected decoder to be type Decoder, but saw: %s" % @@ -184,6 +184,8 @@ def dynamic_decode(decoder, if maximum_iterations is not None: initial_finished = math_ops.logical_or( initial_finished, 0 >= maximum_iterations) + initial_sequence_lengths = array_ops.zeros_like( + initial_finished, dtype=dtypes.int32) initial_time = constant_op.constant(0, dtype=dtypes.int32) def _shape(batch_size, from_shape): @@ -206,10 +208,10 @@ def dynamic_decode(decoder, decoder.output_dtype) def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, - finished): + finished, unused_sequence_lengths): return math_ops.logical_not(math_ops.reduce_all(finished)) - def body(time, outputs_ta, state, inputs, finished): + def body(time, outputs_ta, state, inputs, finished, sequence_lengths): """Internal while_loop body. Args: @@ -217,10 +219,13 @@ def dynamic_decode(decoder, outputs_ta: structure of TensorArray. state: (structure of) state tensors and TensorArrays. inputs: (structure of) input tensors. - finished: 1-D bool tensor. + finished: bool tensor (keeping track of what's finished). + sequence_lengths: int32 tensor (keeping track of time of finish). Returns: - `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`. + `(time + 1, outputs_ta, next_state, next_inputs, next_finished, + next_sequence_lengths)`. 
""" (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) @@ -228,6 +233,10 @@ def dynamic_decode(decoder, if maximum_iterations is not None: next_finished = math_ops.logical_or( next_finished, time + 1 >= maximum_iterations) + next_sequence_lengths = array_ops.where( + math_ops.logical_and(math_ops.logical_not(finished), next_finished), + array_ops.fill(array_ops.shape(sequence_lengths), time + 1), + sequence_lengths) nest.assert_same_structure(state, decoder_state) nest.assert_same_structure(outputs_ta, next_outputs) @@ -260,26 +269,30 @@ def dynamic_decode(decoder, outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit) - return (time + 1, outputs_ta, next_state, next_inputs, next_finished) + return (time + 1, outputs_ta, next_state, next_inputs, next_finished, + next_sequence_lengths) res = control_flow_ops.while_loop( condition, body, loop_vars=[ initial_time, initial_outputs_ta, initial_state, initial_inputs, - initial_finished + initial_finished, initial_sequence_lengths, ], parallel_iterations=parallel_iterations, swap_memory=swap_memory) final_outputs_ta = res[1] final_state = res[2] + final_sequence_lengths = res[5] final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta) + + if hasattr(decoder, "finalize"): + final_outputs, final_state = decoder.finalize( + final_outputs, final_state, final_sequence_lengths) + if not output_time_major: final_outputs = nest.map_structure(_transpose_batch_time, final_outputs) - if hasattr(decoder, "finalize"): - final_outputs, final_state = decoder.finalize(final_outputs, final_state) - - return final_outputs, final_state + return final_outputs, final_state, final_sequence_lengths diff --git a/tensorflow/contrib/specs/python/specs.py b/tensorflow/contrib/specs/python/specs.py index a9fba442db5..d5223b9b551 100644 --- a/tensorflow/contrib/specs/python/specs.py +++ b/tensorflow/contrib/specs/python/specs.py @@ -19,13 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - -import inspect - from six import exec_ from tensorflow.contrib.specs.python import params_ops from tensorflow.contrib.specs.python import specs_lib from tensorflow.contrib.specs.python import specs_ops +from tensorflow.python.util import tf_inspect def eval_params(params, environment=None): @@ -44,7 +42,8 @@ def eval_params(params, environment=None): """ specs_lib.check_keywords(params) bindings = {} - if environment: bindings.update(environment) + if environment: + bindings.update(environment) exec_(params, vars(params_ops), bindings) # pylint: disable=exec-used return bindings @@ -71,7 +70,8 @@ def eval_spec(spec, environment=None): """ specs_lib.check_keywords(spec) bindings = {} - if environment: bindings.update(environment) + if environment: + bindings.update(environment) exec_(spec, vars(specs_ops), bindings) # pylint: disable=exec-used return bindings @@ -141,7 +141,7 @@ class LocalImport(object): self.names = names def __enter__(self): - self.frame = inspect.currentframe() + self.frame = tf_inspect.currentframe() bindings = self.frame.f_back.f_globals self.old = {k: bindings.get(k, None) for k in self.names.keys()} bindings.update(self.names) @@ -151,7 +151,9 @@ class LocalImport(object): bindings = self.frame.f_back.f_globals bindings.update(self.old) for k, v in self.old.items(): - if v is None: del bindings[k] + if v is None: + del bindings[k] del self.frame + ops = LocalImport(specs_ops)
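Before moving on to the verbs changes, it is worth spelling out the new bookkeeping in `dynamic_decode` above: the loop body records, for each sequence, the first step at which it flips from unfinished to finished. A NumPy analogue of that `array_ops.where` update, with made-up tensor values:

```python
import numpy as np

time = 3
finished = np.array([True, False, False])       # state before this step
next_finished = np.array([True, True, False])   # state after decoder.step
sequence_lengths = np.array([2, 0, 0])          # lengths recorded so far

# Entries finishing at this step get length time + 1; all others keep
# their previous value, mirroring the array_ops.where call in the patch.
just_finished = ~finished & next_finished
next_sequence_lengths = np.where(
    just_finished, np.full_like(sequence_lengths, time + 1), sequence_lengths)

print(next_sequence_lengths)  # [2 4 0]
```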
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index e1302d9903b..e747fa4c9e4 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -1,6 +1,10 @@ # Description: # Verbs RDMA communication interfaces and implementations for TensorFlow. +package(default_visibility = [ + "//tensorflow:__subpackages__", +]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -31,13 +35,10 @@ load( "tf_proto_library_cc", ) -package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - tf_proto_library_cc( name = "verbs_service_proto", srcs = ["verbs_service.proto"], + has_services = 1, cc_api_version = 2, visibility = [ "//tensorflow:__subpackages__", diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.cc b/tensorflow/contrib/verbs/grpc_verbs_client.cc index be94af18670..608a9140d3d 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_client.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_client.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -namespace tensorflow { - +namespace tensorflow { + Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options, - const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response) { + const GetRemoteAddressRequest* request, + GetRemoteAddressResponse* response) { ::grpc::ClientContext ctx; ctx.set_fail_fast(false); SetDeadline(&ctx, call_options->GetTimeout()); @@ -31,14 +31,14 @@ Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options, } Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response) { + GetRemoteAddressResponse* response) { CallOptions call_options; - call_options.SetTimeout(-1); // no time out + call_options.SetTimeout(-1); // no time out return GetRemoteAddress(&call_options, request, response); } -void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx, - int64 time_in_ms) { +void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx, + int64 time_in_ms) { if (time_in_ms > 0) { ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN)); } diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h index 030710726fd..358977f9254 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_client.h +++ b/tensorflow/contrib/verbs/grpc_verbs_client.h @@ -16,11 +16,11 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ +#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" +#include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "tensorflow/contrib/verbs/verbs_service.pb.h" namespace tensorflow { @@ -28,24 +28,23 @@ namespace tensorflow { class GrpcVerbsClient { public: explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel) - : stub_(grpc::VerbsService::NewStub(client_channel)) {} + : stub_(grpc::VerbsService::NewStub(client_channel)) {} ~GrpcVerbsClient() {} Status GetRemoteAddress(CallOptions* call_options, - const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response); + const GetRemoteAddressRequest* request, + GetRemoteAddressResponse* response); Status GetRemoteAddress(const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response); - + GetRemoteAddressResponse* response); + private: std::unique_ptr stub_; void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms); - + TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient); }; } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ - diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 2b1cdec6b91..e73b2700bd9 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -26,10 +26,10 @@ limitations under the License. namespace tensorflow { GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env, - ::grpc::ServerBuilder* builder) - : is_shutdown_(false), worker_env_(worker_env) { - builder->RegisterService(&verbs_service_); - cq_ = builder->AddCompletionQueue().release(); + ::grpc::ServerBuilder* builder) + : is_shutdown_(false), worker_env_(worker_env) { + builder->RegisterService(&verbs_service_); + cq_ = builder->AddCompletionQueue().release(); } GrpcVerbsService::~GrpcVerbsService() { @@ -52,7 +52,7 @@ void GrpcVerbsService::Shutdown() { new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr); } } - + // This macro creates a new request for the given RPC method name // (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on // `this->cq_`. @@ -64,17 +64,17 @@ void GrpcVerbsService::Shutdown() { // The implementation of the request handler for each RPC method // must ensure that it calls ENQUEUE_REQUEST() for that RPC method, // to keep accepting new requests. -#define ENQUEUE_REQUEST(method, supports_cancel) \ - do { \ - mutex_lock l(shutdown_mu_); \ - if (!is_shutdown_) { \ - Call:: \ - EnqueueRequest(&verbs_service_, cq_, \ - &grpc::VerbsService::AsyncService::Request##method, \ - &GrpcVerbsService::method##Handler, \ - (supports_cancel)); \ - } \ +#define ENQUEUE_REQUEST(method, supports_cancel) \ + do { \ + mutex_lock l(shutdown_mu_); \ + if (!is_shutdown_) { \ + Call:: \ + EnqueueRequest(&verbs_service_, cq_, \ + &grpc::VerbsService::AsyncService::Request##method, \ + &GrpcVerbsService::method##Handler, \ + (supports_cancel)); \ + } \ } while (0) // This method blocks forever handling requests from the completion queue. 
@@ -97,8 +97,8 @@ void GrpcVerbsService::HandleRPCsLoop() { } } -void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall - * call) { +void GrpcVerbsService::GetRemoteAddressHandler( + WorkerCall* call) { Status s = GetRemoteAddressSync(&call->request, &call->response); call->SendResponse(ToGrpcStatus(s)); ENQUEUE_REQUEST(GetRemoteAddress, false); @@ -106,8 +106,8 @@ void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall // synchronous method Status GrpcVerbsService::GetRemoteAddressSync( - const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response) { + const GetRemoteAddressRequest* request, + GetRemoteAddressResponse* response) { // analyzing request // the channel setting part is redundant. const string remote_host_name = request->host_name(); @@ -115,7 +115,7 @@ Status GrpcVerbsService::GetRemoteAddressSync( CHECK(rc); RdmaAddress ra; ra.lid = request->channel().lid(); - ra.qpn = request->channel().qpn(); + ra.qpn = request->channel().qpn(); ra.psn = request->channel().psn(); rc->SetRemoteAddress(ra, false); rc->Connect(); @@ -140,8 +140,8 @@ Status GrpcVerbsService::GetRemoteAddressSync( CHECK(i == RdmaChannel::kNumMessageBuffers); // setting up response - response->set_host_name(worker_env_->session_mgr-> - LegacySession()->worker_name); + response->set_host_name( + worker_env_->session_mgr->LegacySession()->worker_name); Channel* channel_info = response->mutable_channel(); channel_info->set_lid(rc->self().lid); channel_info->set_qpn(rc->self().qpn); @@ -151,12 +151,12 @@ Status GrpcVerbsService::GetRemoteAddressSync( mr->set_remote_addr(reinterpret_cast(mb[i]->buffer())); mr->set_rkey(mb[i]->self()->rkey); } - return Status::OK(); + return Status::OK(); } // Create a GrpcVerbsService, then assign it to a given handle. void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env, - ::grpc::ServerBuilder* builder) { + ::grpc::ServerBuilder* builder) { *handle = new GrpcVerbsService(worker_env, builder); } diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h index dcc4518bb5d..aa509602b51 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service.h @@ -18,12 +18,12 @@ limitations under the License. 
#ifdef TENSORFLOW_USE_VERBS -#include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" +#include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" -#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/lib/core/refcount.h" namespace grpc { class ServerBuilder; @@ -44,27 +44,27 @@ class GrpcVerbsService : public AsyncServiceInterface { private: template using WorkerCall = Call; - void GetRemoteAddressHandler(WorkerCall - * call); + RequestMessage, ResponseMessage>; + void GetRemoteAddressHandler( + WorkerCall* call); Status GetRemoteAddressSync(const GetRemoteAddressRequest* request, - GetRemoteAddressResponse* response); - - ::grpc::ServerCompletionQueue* cq_; + GetRemoteAddressResponse* response); + + ::grpc::ServerCompletionQueue* cq_; grpc::VerbsService::AsyncService verbs_service_; mutex shutdown_mu_; bool is_shutdown_ GUARDED_BY(shutdown_mu_); ::grpc::Alarm* shutdown_alarm_; // not owned RdmaMgr* rdma_mgr_; - const WorkerEnv* const worker_env_; + const WorkerEnv* const worker_env_; TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService); }; // Create a GrpcVerbsService, then assign it to a given handle. void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env, - ::grpc::ServerBuilder* builder); + ::grpc::ServerBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 7aac8c7ab38..e0ba78dbfd5 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -43,7 +43,7 @@ VerbsService::Stub::Stub( const std::shared_ptr< ::grpc::ChannelInterface>& channel) : channel_(channel), rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0], - ::grpc::RpcMethod::NORMAL_RPC, channel) {} + ::grpc::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status VerbsService::Stub::GetRemoteAddress( ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index d9e5856cb30..f7ea774b661 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -48,16 +48,16 @@ class VerbsService GRPC_FINAL { class StubInterface { public: virtual ~StubInterface() {} - virtual ::grpc::Status GetRemoteAddress(::grpc::ClientContext* context, - const GetRemoteAddressRequest& request, - GetRemoteAddressResponse* response) = 0; + virtual ::grpc::Status GetRemoteAddress( + ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, + GetRemoteAddressResponse* response) = 0; }; class Stub GRPC_FINAL : public StubInterface { public: Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status GetRemoteAddress(::grpc::ClientContext* context, - const GetRemoteAddressRequest& request, - GetRemoteAddressResponse* response) GRPC_OVERRIDE; + ::grpc::Status GetRemoteAddress( + ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, + GetRemoteAddressResponse* response) GRPC_OVERRIDE; private: std::shared_ptr< ::grpc::ChannelInterface> channel_; diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 8063b299cd6..53d840f5d1c 100644 --- 
a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -15,16 +15,16 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include #include "tensorflow/contrib/verbs/rdma.h" +#include #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" @@ -35,12 +35,12 @@ namespace tensorflow { namespace { // hash name to 32-bit integer uint32_t NameHash(const string& name) { - return Hash32(name.data(), name.size(), 0x1234ABCD); + return Hash32(name.data(), name.size(), 0x1234ABCD); } // convenience function for printing message string MessageTypeToString(RdmaMessageType rmt) { - switch(rmt){ + switch (rmt) { case RDMA_MESSAGE_ACK: return "RDMA_MESSAGE_ACK"; break; @@ -59,11 +59,11 @@ string MessageTypeToString(RdmaMessageType rmt) { case RDMA_MESSAGE_TENSOR_WRITE: return "RDMA_MESSAGE_TENSOR_WRITE"; break; - default: + default: return "UNKNOWN MESSAGE"; } } -} +} // namespace ibv_context* open_default_device() { ibv_device** dev_list; @@ -89,29 +89,28 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) worker_env_(worker_env) { event_channel_ = ibv_create_comp_channel(context_); CHECK(event_channel_) << "Failed to create completion channel"; - cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_, 0); + cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_, + 0); CHECK(cq_) << "Failed to create completion queue"; CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification"; polling_thread_.reset(Env::Default()->StartThread( - ThreadOptions(), "RdmaAdapterCQThread", - [this] {Process_CQ(); })); - VLOG(2) << "Start RdmaAdapter: " << name(); + ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); })); + VLOG(2) << "Start RdmaAdapter: " << name(); } RdmaAdapter::~RdmaAdapter() { polling_thread_.reset(); CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; - CHECK(!ibv_destroy_comp_channel(event_channel_)) << "Failed to destroy channel"; + CHECK(!ibv_destroy_comp_channel(event_channel_)) + << "Failed to destroy channel"; CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD"; CHECK(!ibv_close_device(context_)) << "Failed to release context"; } -string RdmaAdapter::name() const { - return string(context_->device->name); -} +string RdmaAdapter::name() const { return string(context_->device->name); } // Function to process incoming messages -// There are two types of messages: +// There are two types of messages: // 1. IBV_WC_RECV_RDMA_WITH_IMM (receive) // 2. 
IBV_WC_RDMA_WRITE (send)) void RdmaAdapter::Process_CQ() { @@ -123,15 +122,14 @@ void RdmaAdapter::Process_CQ() { ibv_ack_cq_events(cq, 1); CHECK(!ibv_req_notify_cq(cq_, 0)); - int ne = ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, - static_cast(wc_)); + int ne = + ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast(wc_)); CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { - CHECK(wc_[i].status == IBV_WC_SUCCESS) << "Failed status \n" - << ibv_wc_status_str(wc_[i].status) - << " " << wc_[i].status << " " - << static_cast(wc_[i].wr_id) - << " "<< wc_[i].vendor_err; + CHECK(wc_[i].status == IBV_WC_SUCCESS) + << "Failed status \n" + << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " + << static_cast(wc_[i].wr_id) << " " << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast(wc_[i].wr_id); // put back a recv wr. @@ -142,8 +140,8 @@ void RdmaAdapter::Process_CQ() { RdmaMessage rm; RdmaMessage::ParseMessage(rm, rb->buffer_); VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_); - - if (rm.type_ == RDMA_MESSAGE_ACK) { + + if (rm.type_ == RDMA_MESSAGE_ACK) { // receive an ack to a message rb = rc->tx_message_buffer_; rb->SetBufferStatus(remote, idle); @@ -155,12 +153,12 @@ void RdmaAdapter::Process_CQ() { ab->SendNextItem(); // find or create buffer RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_); - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); + string key_with_step_id = + VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); tb->EnqueueItem(key_with_step_id); // send the next tensor - worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();}); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) { + worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); }); + } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) { // receive tensor-buffer-ready message // send ack to release remote tx message buffer RdmaBuffer* ab = rc->tx_ack_buffer_; @@ -168,7 +166,7 @@ void RdmaAdapter::Process_CQ() { // find buffer RdmaBuffer* tb = rc->FindBuffer(rm.name_); tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();}); + worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); }); } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) { // remote host requests to create a tensor buffer; // send ack to release remote tx message buffer @@ -194,31 +192,30 @@ void RdmaAdapter::Process_CQ() { mb->EnqueueItem(message); mb->SendNextItem(); } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) { - // remote creates a buffer and responds + // remote creates a buffer and responds // send ack to release remote tx message buffer RdmaBuffer* ab = rc->tx_ack_buffer_; ab->SendNextItem(); // find buffer RdmaBuffer* tb = rc->FindBuffer(rm.name_); - CHECK(rm.buffer_size_ == tb->size_) - << "rm.buffer_size = " << rm.buffer_size_ - << "tb->size_ = " << tb->size_ - << "rm.name_ = " << rm.name_; + CHECK(rm.buffer_size_ == tb->size_) + << "rm.buffer_size = " << rm.buffer_size_ + << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_; RemoteMR rmr; rmr.remote_addr = rm.remote_addr_; rmr.rkey = rm.rkey_; tb->SetRemoteMR(rmr, true); tb->SetBufferStatus(local, idle); tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();}); + worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); }); } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { // tensor RDMA write completed - worker_env_->compute_pool->Schedule([rm, 
rc](){ - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); + worker_env_->compute_pool->Schedule([rm, rc]() { + string key_with_step_id = + VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); rc->RunRecvCallback(key_with_step_id); - }); - } + }); + } } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) { RdmaBuffer* rb = reinterpret_cast(wc_[i].wr_id); rb->SetBufferStatus(local, idle); @@ -226,7 +223,7 @@ void RdmaAdapter::Process_CQ() { RdmaMessage::ParseMessage(rm, rb->buffer_); VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_); if (rm.type_ != RDMA_MESSAGE_ACK) { - worker_env_->compute_pool->Schedule([rb](){rb->SendNextItem();}); + worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); }); } } } @@ -235,9 +232,7 @@ void RdmaAdapter::Process_CQ() { RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, const string remote_name) - : adapter_(adapter), - local_name_(local_name), - remote_name_(remote_name) { + : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) { // Create queue pair { struct ibv_qp_init_attr attr; @@ -263,21 +258,21 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, attr.port_num = 1; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; - int mask = IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT - | IBV_QP_ACCESS_FLAGS; + int mask = + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT"; } // Local address { struct ibv_port_attr attr; - CHECK(!ibv_query_port(adapter_->context_, (uint8_t) 1, &attr)) + CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr)) << "Query port"; self_.lid = attr.lid; self_.qpn = qp_->qp_num; self_.psn = static_cast(random::New64()) & 0xffffff; } - + // create message and ack buffers, then initialize the tables. 
{ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", @@ -303,7 +298,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, buffer_index_name_table_.insert({index, buffer_names[i]}); buffer_name_index_table_.insert({buffer_names[i], index}); } - + // Initiate recv for (int i = 0; i < 100; i++) { Recv(); @@ -320,17 +315,17 @@ RdmaChannel::~RdmaChannel() { } void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { - mutex_lock lock{mu_}; - if ((override) || (!remote_set_)) { - remote_.lid = ra.lid; - remote_.qpn = ra.qpn; - remote_.psn = ra.psn; - remote_set_ = true; - } else { - CHECK(remote_.lid == ra.lid); - CHECK(remote_.qpn == ra.qpn); - CHECK(remote_.psn == ra.psn); - } + mutex_lock lock{mu_}; + if ((override) || (!remote_set_)) { + remote_.lid = ra.lid; + remote_.qpn = ra.qpn; + remote_.psn = ra.psn; + remote_set_ = true; + } else { + CHECK(remote_.lid == ra.lid); + CHECK(remote_.qpn == ra.qpn); + CHECK(remote_.psn == ra.psn); + } } // Adding tokens to the completion queue @@ -338,7 +333,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -347,12 +342,11 @@ void RdmaChannel::Recv() { // Args: // buffer_name: name of the buffer // Returns: -// 32-bit index -uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name){ - +// 32-bit index +uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) { mutex_lock lock{bt_mu_}; - BufferNameIndexTable::iterator iter = buffer_name_index_table_.find( - buffer_name); + BufferNameIndexTable::iterator iter = + buffer_name_index_table_.find(buffer_name); CHECK(iter != buffer_name_index_table_.end()); return iter->second; } @@ -380,14 +374,14 @@ RdmaBuffer* RdmaChannel::FindBuffer(const string& name) { } // Find a buffer if it exists, otherwise create one. -// The memory inside the created buffer is not allocated. -// Args: +// The memory inside the created buffer is not allocated. +// Args: // name: the name of the buffer // buffer_type: TENSOR, MESSAGE or ACK. // Returns: // the named buffer -RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, - BufferType buffer_type) { +RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, + BufferType buffer_type) { mutex_lock lock{bt_mu_}; RdmaBuffer* rb; // find index @@ -405,7 +399,7 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, } else if (buffer_type == MESSAGE) { rb = new RdmaMessageBuffer(this, name); } else if (buffer_type == ACK) { - rb = new RdmaAckBuffer(this, name); + rb = new RdmaAckBuffer(this, name); } buffer_name_index_table_.insert({name, index}); buffer_index_name_table_.insert({index, name}); @@ -417,20 +411,19 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, // Insert callback to the callback_table. // The callback is activated when the corresponding tensor is received. -// Arg: +// Arg: // key: the name of the tensor // recv_done: the callback associated with the tensor. // Returns: // None -void RdmaChannel::InsertRecvCallback(const string& key, - std::function recv_done) { - +void RdmaChannel::InsertRecvCallback(const string& key, + std::function recv_done) { mutex_lock lock{ct_mu_}; callback_table_.insert({key, recv_done}); } // Remove callback from the callback_table. 
-// Arg: +// Arg: // key: the name of the tensor // Returns: // None @@ -440,7 +433,7 @@ void RdmaChannel::RemoveRecvCallback(const string& key) { } // Run named callback in the callback_table. -// Arg: +// Arg: // key: the name of the tensor // Returns: // None @@ -484,17 +477,15 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.sl = 0; attr.ah_attr.src_path_bits = 0; attr.ah_attr.port_num = 1; - + int r; CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | - IBV_QP_AV | - IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | - IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; - + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) + << "QP to Ready to Receive " << r; + memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTS; attr.sq_psn = self_.psn; @@ -502,15 +493,13 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.retry_cnt = 7; attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - + CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | - IBV_QP_TIMEOUT | - IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | - IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; - + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) + << "QP to Ready to Send " << r; + connected_ = true; } else { LOG(INFO) << "channel already connected"; @@ -518,7 +507,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { } RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name) - : channel_(channel), name_(name) {} + : channel_(channel), name_(name) {} RdmaBuffer::~RdmaBuffer() { CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed"; @@ -528,9 +517,9 @@ RdmaBuffer::~RdmaBuffer() { void RdmaBuffer::FreeBuffer() { if ((buffer_ != nullptr) && buffer_on_host_) { free(buffer_); - } + } // TODO - // release buffer if it is on device. + // release buffer if it is on device. // We don't support RDMABuffer on device at this moment. 
} @@ -548,14 +537,12 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { if (local_status_ != none) { // delete existing buffer CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed"; - FreeBuffer(); + FreeBuffer(); } size_ = size; buffer_ = malloc(size_); - self_ = ibv_reg_mr(channel_->adapter_->pd_, - buffer_, size_, - IBV_ACCESS_LOCAL_WRITE | - IBV_ACCESS_REMOTE_WRITE); + self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); CHECK(self_) << "Failed to register memory region"; buffer_on_host_ = true; local_status_ = idle; @@ -572,53 +559,52 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { // None void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) { mutex_lock lock{mu_}; - if ((override) || (remote_status_ == none)) { + if ((override) || (remote_status_ == none)) { remote_.remote_addr = rmr.remote_addr; remote_.rkey = rmr.rkey; remote_status_ = idle; } else { CHECK(remote_.remote_addr == rmr.remote_addr); - CHECK(remote_.rkey == rmr.rkey); - } + CHECK(remote_.rkey == rmr.rkey); + } } // Put a task in the buffer's job queue -void RdmaBuffer::EnqueueItem(string item){ +void RdmaBuffer::EnqueueItem(string item) { mutex_lock lock{mu_}; queue_.push(item); } // Rdma-Write the content of the buffer void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { - struct ibv_sge list; - list.addr = (uint64_t) buffer_; + list.addr = (uint64_t)buffer_; list.length = buffer_size; list.lkey = self_->lkey; struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wr.send_flags = IBV_SEND_SIGNALED; wr.imm_data = imm_data; - wr.wr.rdma.remote_addr = (uint64_t) remote_.remote_addr; + wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr; wr.wr.rdma.rkey = remote_.rkey; - struct ibv_send_wr *bad_wr; + struct ibv_send_wr* bad_wr; CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send"; } RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - + : RdmaBuffer(channel, name) {} + RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} + : RdmaBuffer(channel, name) {} RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} + : RdmaBuffer(channel, name) {} // Send the next ack from the buffer's job queue. void RdmaAckBuffer::SendNextItem() { @@ -636,13 +622,12 @@ void RdmaAckBuffer::SendNextItem() { void RdmaMessageBuffer::SendNextItem() { uint32_t imm_data = LookupBufferIndex("rx_message_buffer"); mu_.lock(); - if (!queue_.empty() && (local_status_ == idle) - && (remote_status_ == idle)) { + if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) { local_status_ = busy; - remote_status_= busy; + remote_status_ = busy; string message = queue_.front(); queue_.pop(); - // local/remote_status_ won't be set back to idle + // local/remote_status_ won't be set back to idle // until Write() is successful mu_.unlock(); memcpy(buffer_, message.data(), message.size()); @@ -665,61 +650,56 @@ void RdmaTensorBuffer::SendNextItem() { } // send the tensor if a key is acquired.
if (key_with_step_id != "") { - VLOG(2) << "try to send tensor: " << key_with_step_id; + VLOG(2) << "try to send tensor: " << key_with_step_id; string key; int64 step_id; VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id); CHECK(key.compare(name_) == 0); Rendezvous::ParsedKey parsed; Rendezvous::ParseKey(key, &parsed); - Rendezvous::DoneCallback cb = [this, key_with_step_id, key, - step_id, parsed](const Status& status, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& in, bool is_dead) { - CHECK(status.ok()) << "RecvLocalAsync was not ok, key" - << key_with_step_id - << " error message: " << status.error_message(); + Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, + parsed](const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& in, bool is_dead) { + CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id + << " error message: " << status.error_message(); size_t buffer_size = RdmaMessage::kMessageTotalBytes; size_t tensor_bytes = 0; TensorProto proto; // Figures out which device the tensor is hosted on. Device* src_dev = nullptr; - Status s = - channel_->adapter_->worker_env_-> - device_mgr->LookupDevice(parsed.src_device, &src_dev); + Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice( + parsed.src_device, &src_dev); CHECK(s.ok()) << "src device not found"; // Does the device have the right incarnation number we expect? - CHECK(src_dev->attributes().incarnation() == - parsed.src_incarnation) - << "RecvTensor expects a different device incarnation: " - << parsed.src_incarnation - << " vs. " - << src_dev->attributes().incarnation() - << ". Your worker job was probably restarted. Check your " - << "worker job for the reason why it was restarted."; + CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation) + << "RecvTensor expects a different device incarnation: " + << parsed.src_incarnation << " vs. " + << src_dev->attributes().incarnation() + << ". Your worker job was probably restarted. Check your " + << "worker job for the reason why it was restarted."; Device* dst_dev = nullptr; // destination is on CPU. - s = channel_->adapter_->worker_env_-> - device_mgr->LookupDevice("CPU:0", &dst_dev); - CHECK(s.ok())<< "dst device not found"; + s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0", + &dst_dev); + CHECK(s.ok()) << "dst device not found"; AllocatorAttributes dst_alloc_attr; dst_alloc_attr.set_on_host(true); // string tensor needs to be serialized - if (src_dev->tensorflow_gpu_device_info() && - (!send_args.alloc_attrs.on_host())) { + if (src_dev->tensorflow_gpu_device_info() && + (!send_args.alloc_attrs.on_host())) { CHECK(send_args.device_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); // "val" is on a GPU. Uses GPUUtil to fill the proto. - s = VerbsUtil::SetProtoFromGPUSync(in, src_dev, - send_args.device_context, - &proto, is_dead); - CHECK(s.ok()) << "set proto from gpu sync"; + s = VerbsUtil::SetProtoFromGPUSync( + in, src_dev, send_args.device_context, &proto, is_dead); + CHECK(s.ok()) << "set proto from gpu sync"; } else { // tensor is in CPU memory. in.AsProtoTensorContent(&proto); - } + } tensor_bytes = proto.ByteSize(); // maybe some margin for string tensor? 
buffer_size += tensor_bytes; @@ -734,13 +714,12 @@ void RdmaTensorBuffer::SendNextItem() { rm.tensor_bytes_ = tensor_bytes; rm.buffer_size_ = buffer_size; mu_.lock(); - if (local_status_ == none || - (buffer_size > size_ && - local_status_ == idle && + if (local_status_ == none || + (buffer_size > size_ && local_status_ == idle && remote_status_ == idle)) { if ((local_status_ != none) && (buffer_size > size_)) { - CHECK(rm.data_type_ == DT_STRING) - << "Only string tensor allows to change size"; + CHECK(rm.data_type_ == DT_STRING) + << "Only string tensor allows to change size"; } CreateCPUBuffer(buffer_size, false); mu_.unlock(); @@ -752,29 +731,29 @@ void RdmaTensorBuffer::SendNextItem() { rm.rkey_ = self_->rkey; string message = RdmaMessage::CreateMessage(rm); channel_->tx_message_buffer_->EnqueueItem(message); - channel_->tx_message_buffer_->SendNextItem(); - } else if((local_status_ == idle) && (remote_status_ == idle)) { + channel_->tx_message_buffer_->SendNextItem(); + } else if ((local_status_ == idle) && (remote_status_ == idle)) { // both buffers are ready, send the tensor local_status_ = busy; remote_status_ = busy; - // local/remote_status_ won't be set back to idle + // local/remote_status_ won't be set back to idle // until Write() is successful mu_.unlock(); CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) || (buffer_size <= size_ && rm.data_type_ == DT_STRING)) - << "tensor and buffer size do not agree!" - << " buffer_size = " << size_ - << " requested tensor size = " << buffer_size - << in.DebugString(); + << "tensor and buffer size do not agree!" + << " buffer_size = " << size_ + << " requested tensor size = " << buffer_size << in.DebugString(); uint32_t imm_data = LookupBufferIndex(key); rm.type_ = RDMA_MESSAGE_TENSOR_WRITE; string message = RdmaMessage::CreateMessage(rm); memcpy(buffer_, message.data(), message.size()); if (!is_dead) { // copy the tensor buffer content - void* output = static_cast(static_cast( - buffer_) + RdmaMessage::kTensorBufferStartIndex); - CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); + void* output = + static_cast(static_cast(buffer_) + + RdmaMessage::kTensorBufferStartIndex); + CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); proto.SerializeToArray(output, tensor_bytes); } else { buffer_size = RdmaMessage::kMessageTotalBytes; @@ -789,8 +768,8 @@ void RdmaTensorBuffer::SendNextItem() { // Use default session (legacy_session_) // TODO use WorkerSessionForSession // need to pass in session handle - channel_->adapter_->worker_env_->session_mgr-> - LegacySession()->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb); + channel_->adapter_->worker_env_->session_mgr->LegacySession() + ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb); } } @@ -811,8 +790,10 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) { // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead // |data_type|tensor_shape|tensor_bytes // BUFFER_IDLE: type|name_size|buffer_name - // BUFFER_REQUEST: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| - // BUFFER_RESPONSE: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| + // BUFFER_REQUEST: + // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| + // BUFFER_RESPONSE: + // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| char message[kMessageTotalBytes]; // type message[kTypeStartIndex] = static_cast(rm.type_) & 0xff; @@ -821,32 +802,32 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) { // name
memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size()); // buffer_size, remote_addr, rkey - if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || + if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_, - sizeof(rm.buffer_size_)); - memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_, - sizeof(rm.remote_addr_)); - memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_)); + memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_, + sizeof(rm.buffer_size_)); + memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_, + sizeof(rm.remote_addr_)); + memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_)); } // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || + if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_)); } // is_dead, data_type, tensor_shape, tensor_bytes if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_)); - - memcpy(&message[kDataTypeStartIndex], &rm.data_type_, - sizeof(rm.data_type_)); - memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_, - sizeof(rm.tensor_shape_)); - memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_, - sizeof(rm.tensor_bytes_)); + + memcpy(&message[kDataTypeStartIndex], &rm.data_type_, + sizeof(rm.data_type_)); + memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_, + sizeof(rm.tensor_shape_)); + memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_, + sizeof(rm.tensor_bytes_)); } return string(message, kMessageTotalBytes); -} +} // Parse a RdmaMessage according to the pre-defined format // Args: @@ -865,27 +846,26 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) { // buffer_size, remote_addr, rkey if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex], - sizeof(rm.buffer_size_)); - memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex], - sizeof(rm.remote_addr_)); - memcpy(&rm.rkey_, &message[kRkeyStartIndex], - sizeof(rm.rkey_)); + memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex], + sizeof(rm.buffer_size_)); + memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex], + sizeof(rm.remote_addr_)); + memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_)); } // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || + if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_)); } - // data_type, tensor_bytes, tensor_shape, is_dead + // data_type, tensor_bytes, tensor_shape, is_dead if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_)); - memcpy(&rm.data_type_, &message[kDataTypeStartIndex], - sizeof(rm.data_type_)); + memcpy(&rm.data_type_, &message[kDataTypeStartIndex], + sizeof(rm.data_type_)); memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex], - sizeof(rm.tensor_shape_)); - memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex], - sizeof(rm.tensor_bytes_)); + sizeof(rm.tensor_shape_)); + memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex], + sizeof(rm.tensor_bytes_)); } } diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index 2f6cb402956..ae2aa63e3f6 100644 --- 
a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -19,11 +19,11 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS #include -#include // for shared_ptr -#include // for memset +#include // for memset +#include +#include // for shared_ptr #include #include -#include #include #include @@ -37,43 +37,46 @@ namespace tensorflow { // structure to save the address of remote channels. struct RdmaAddress { - uint32_t lid; - uint32_t qpn; - uint32_t psn; + uint32_t lid; + uint32_t qpn; + uint32_t psn; }; // structure to save information for remote memory regions. -struct RemoteMR{ - uint64_t remote_addr; - uint32_t rkey; +struct RemoteMR { + uint64_t remote_addr; + uint32_t rkey; +}; +enum BufferStatus { none, idle, busy }; +enum Location { local, remote }; +enum BufferType { ACK, MESSAGE, TENSOR }; +enum RdmaMessageType { + RDMA_MESSAGE_ACK, + RDMA_MESSAGE_BUFFER_IDLE, + RDMA_MESSAGE_BUFFER_REQUEST, + RDMA_MESSAGE_BUFFER_RESPONSE, + RDMA_MESSAGE_TENSOR_REQUEST, + RDMA_MESSAGE_TENSOR_WRITE }; -enum BufferStatus {none, idle, busy}; -enum Location {local, remote}; -enum BufferType {ACK, MESSAGE, TENSOR}; -enum RdmaMessageType {RDMA_MESSAGE_ACK, - RDMA_MESSAGE_BUFFER_IDLE, - RDMA_MESSAGE_BUFFER_REQUEST, - RDMA_MESSAGE_BUFFER_RESPONSE, - RDMA_MESSAGE_TENSOR_REQUEST, - RDMA_MESSAGE_TENSOR_WRITE}; class RdmaBuffer; // Class that represents the Rdma Adapter. // Responsible for creation of the completion queue, and handling -// of work completions. +// of work completions. class RdmaAdapter { - friend class RdmaChannel; - friend class RdmaBuffer; - friend class RdmaAckBuffer; - friend class RdmaMessageBuffer; - friend class RdmaTensorBuffer; - friend class RdmaMgr; - friend class RdmaRemoteRendezvous; + friend class RdmaChannel; + friend class RdmaBuffer; + friend class RdmaAckBuffer; + friend class RdmaMessageBuffer; + friend class RdmaTensorBuffer; + friend class RdmaMgr; + friend class RdmaRemoteRendezvous; + public: RdmaAdapter(const WorkerEnv* worker_env); ~RdmaAdapter(); // Adapter name, e.g. mlx5_0. string name() const; void Process_CQ(); - + protected: static const int MAX_CONCURRENT_WRITES = 1000; ibv_context* context_; @@ -94,36 +97,39 @@ class RdmaAdapter { // Class that represents a connection to a remote Rdma peer. // Responsible for connecting queue pairs. 
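Before the `RdmaChannel` class body, a note on the `RdmaMessageType` values defined above: they drive the handshake implemented in `RdmaAdapter::Process_CQ` earlier in this patch. The toy transcript below summarizes the effect each message has on the peer that receives it; it is a reading aid I have paraphrased from that handler, not code from the patch:

```python
# Effect of each RDMA message type on the receiving side, paraphrasing
# the dispatch in RdmaAdapter::Process_CQ (descriptions are illustrative).
EFFECTS = [
    ("RDMA_MESSAGE_TENSOR_REQUEST", "find or create the tensor buffer, schedule a send"),
    ("RDMA_MESSAGE_BUFFER_REQUEST", "create or resize the buffer, reply with BUFFER_RESPONSE"),
    ("RDMA_MESSAGE_BUFFER_RESPONSE", "record remote_addr and rkey, schedule the tensor write"),
    ("RDMA_MESSAGE_BUFFER_IDLE", "mark the named buffer's remote side idle, resume sending"),
    ("RDMA_MESSAGE_TENSOR_WRITE", "run the recv callback registered for the tensor"),
    ("RDMA_MESSAGE_ACK", "mark our tx_message_buffer idle so the next message can go out"),
]

for message_type, effect in EFFECTS:
    print(f"{message_type:30s} -> {effect}")
```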
class RdmaChannel { - friend class RdmaAdapter; - friend class RdmaBuffer; - friend class RdmaAckBuffer; - friend class RdmaMessageBuffer; - friend class RdmaTensorBuffer; - friend class RdmaMgr; - friend class RdmaRemoteRendezvous; + friend class RdmaAdapter; + friend class RdmaBuffer; + friend class RdmaAckBuffer; + friend class RdmaMessageBuffer; + friend class RdmaTensorBuffer; + friend class RdmaMgr; + friend class RdmaRemoteRendezvous; + public: - explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name, - const string remote_name_); + explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name, + const string remote_name_); ~RdmaChannel(); inline const RdmaAddress& self() { return self_; } RdmaAddress address() const; inline const std::vector<RdmaBuffer*>& message_buffers() const { - return message_buffers_;} + return message_buffers_; + } void Connect(const RdmaAddress& remoteAddr); void Connect(); void Recv(); RdmaBuffer* FindBuffer(const uint32_t index); RdmaBuffer* FindBuffer(const string& name); - RdmaBuffer* FindOrCreateBuffer(const string& name, + RdmaBuffer* FindOrCreateBuffer(const string& name, BufferType buffer_type = TENSOR); - uint32_t LookupBufferIndex (const string& buffer_name); + uint32_t LookupBufferIndex(const string& buffer_name); void SetRemoteAddress(const RdmaAddress& ra, bool override); void InsertRecvCallback(const string& key, std::function<void()> recv_done); void RemoveRecvCallback(const string& key); void RunRecvCallback(const string& key); static const int kNumMessageBuffers = 4; + protected: - const RdmaAdapter* adapter_; + const RdmaAdapter* adapter_; RdmaAddress self_; string local_name_; string remote_name_; @@ -151,10 +157,11 @@ class RdmaChannel { // Class that represents a buffer for Rdma writes and reads. class RdmaBuffer { - friend class RdmaChannel; - friend class RdmaAdapter; - friend class RdmaMgr; - friend class RdmaRemoteRendezvous; + friend class RdmaChannel; + friend class RdmaAdapter; + friend class RdmaMgr; + friend class RdmaRemoteRendezvous; + public: explicit RdmaBuffer(RdmaChannel* channel, string name); virtual ~RdmaBuffer(); @@ -173,10 +180,11 @@ class RdmaBuffer { void FreeBuffer(); void EnqueueItem(string Item); virtual void SendNextItem(){}; - void CreateCPUBuffer(size_t size, bool lock=true); + void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); - uint32_t LookupBufferIndex (const string& buffer_name) { - return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);} + uint32_t LookupBufferIndex(const string& buffer_name) { + return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name); + } void Write(uint32_t imm_data, size_t buffer_size); protected: @@ -188,7 +196,7 @@ class RdmaBuffer { ibv_mr* self_ = nullptr; mutex mu_; RemoteMR remote_; - std::queue<string> queue_ GUARDED_BY(mu_); + std::queue<string> queue_ GUARDED_BY(mu_); BufferStatus local_status_ GUARDED_BY(mu_) = none; BufferStatus remote_status_ GUARDED_BY(mu_) = none; }; @@ -201,8 +209,9 @@ class RdmaAckBuffer : public RdmaBuffer { }; class RdmaMessageBuffer : public RdmaBuffer { - friend class RdmaChannel; - friend class RdmaAapater; + friend class RdmaChannel; + friend class RdmaAdapter; + public: explicit RdmaMessageBuffer(RdmaChannel* channel, string name); virtual ~RdmaMessageBuffer() override {} @@ -228,40 +237,41 @@ struct RdmaMessage { DataType data_type_; TensorShape tensor_shape_; size_t tensor_bytes_; - -// type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
-// 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... -// ...|data_type|tensor_shape|tensor_bytes|tensor_buffer -// ...| XB | XB | 8B |... -// + + // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|... + // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... + // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer + // ...| XB | XB | 8B |... + // static const size_t kNameCapacity = 512; static const size_t kTypeStartIndex = 0; static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); - static const size_t kNameStartIndex = kNameSizeStartIndex + sizeof(name_size_); + static const size_t kNameStartIndex = + kNameSizeStartIndex + sizeof(name_size_); static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; - static const size_t kBufferSizeStartIndex = kStepIdStartIndex - + sizeof(step_id_); - static const size_t kRemoteAddrStartIndex = kBufferSizeStartIndex - + sizeof(buffer_size_); - static const size_t kRkeyStartIndex = kRemoteAddrStartIndex - + sizeof(remote_addr_); + static const size_t kBufferSizeStartIndex = + kStepIdStartIndex + sizeof(step_id_); + static const size_t kRemoteAddrStartIndex = + kBufferSizeStartIndex + sizeof(buffer_size_); + static const size_t kRkeyStartIndex = + kRemoteAddrStartIndex + sizeof(remote_addr_); static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); - static const size_t kDataTypeStartIndex = kIsDeadStartIndex - + sizeof(is_dead_); - static const size_t kTensorShapeStartIndex = kDataTypeStartIndex - + sizeof(data_type_); - static const size_t kTensorBytesStartIndex = kTensorShapeStartIndex - + sizeof(TensorShape); - static const size_t kTensorBufferStartIndex = kTensorBytesStartIndex - + sizeof(tensor_bytes_); + static const size_t kDataTypeStartIndex = + kIsDeadStartIndex + sizeof(is_dead_); + static const size_t kTensorShapeStartIndex = + kDataTypeStartIndex + sizeof(data_type_); + static const size_t kTensorBytesStartIndex = + kTensorShapeStartIndex + sizeof(TensorShape); + static const size_t kTensorBufferStartIndex = + kTensorBytesStartIndex + sizeof(tensor_bytes_); static const size_t kMessageTotalBytes = kTensorBufferStartIndex; static const size_t kRdmaMessageBufferSize = kMessageTotalBytes; static const size_t kRdmaAckBufferSize = kMessageTotalBytes; - static string CreateMessage(const RdmaMessage & rm); + static string CreateMessage(const RdmaMessage& rm); static void ParseMessage(RdmaMessage& rm, void* buffer); }; -} // namespace tensorflow +} // namespace tensorflow -#endif // TENSORFLOW_USE_VERBS +#endif // TENSORFLOW_USE_VERBS #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_ diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 7bbdcaf7653..e28b80c6f6b 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -15,8 +15,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include #include "tensorflow/contrib/verbs/rdma_mgr.h" +#include #include "tensorflow/contrib/verbs/grpc_verbs_client.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" @@ -25,7 +25,7 @@ limitations under the License. 
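// An illustrative round trip through the fixed-offset wire format defined in
// rdma.h (values hypothetical; only the fields relevant to a TENSOR_REQUEST
// message are set):
RdmaMessage out;
out.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
out.name_ = "tensor_key";
out.name_size_ = out.name_.size();
out.step_id_ = 42;
string wire = RdmaMessage::CreateMessage(out);  // kMessageTotalBytes long
RdmaMessage in;
RdmaMessage::ParseMessage(in, const_cast<char*>(wire.data()));
CHECK(in.type_ == RDMA_MESSAGE_TENSOR_REQUEST && in.step_id_ == 42);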
namespace tensorflow { -RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env, +RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env, GrpcChannelCache* const channel_cache) : worker_env_(worker_env), channel_cache_(channel_cache) { rdma_adapter_ = new RdmaAdapter(worker_env_); @@ -34,14 +34,15 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env, // need to pass in session handle local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name; std::vector workers; - worker_env_->session_mgr->LegacySession()-> - worker_cache->ListWorkers(&workers); - num_remote_workers_ = workers.size()-1; + worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers( + &workers); + num_remote_workers_ = workers.size() - 1; VLOG(2) << "rmda_mgr on local worker: " << local_worker_; for (size_t i = 0; i < workers.size(); i++) { if (local_worker_.compare(workers[i]) != 0) { - channel_table_.insert({workers[i], new RdmaChannel(rdma_adapter_, - local_worker_, workers[i])}); + channel_table_.insert( + {workers[i], + new RdmaChannel(rdma_adapter_, local_worker_, workers[i])}); } } } @@ -49,16 +50,16 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env, // Setup Rdma channels between peers. // This is done at the beginning of the server setup. -void RdmaMgr::SetupChannels() { +void RdmaMgr::SetupChannels() { for (const auto& p : channel_table_) { string worker_name = p.first; LOG(INFO) << "connecting to remote node " << worker_name; RdmaChannel* rc = p.second; GetRemoteAddressRequest req; - GetRemoteAddressResponse resp; + GetRemoteAddressResponse resp; // get the channel cache - SharedGrpcChannelPtr client_channel = channel_cache_ - ->FindWorkerChannel(worker_name); + SharedGrpcChannelPtr client_channel = + channel_cache_->FindWorkerChannel(worker_name); GrpcVerbsClient* client = new GrpcVerbsClient(client_channel); CHECK(client != nullptr) << "No worker known as " << worker_name; @@ -70,8 +71,8 @@ void RdmaMgr::SetupChannels() { channel_info->set_psn(rc->self_.psn); for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) { MemoryRegion* mr = req.add_mr(); - mr->set_remote_addr(reinterpret_cast( - rc->message_buffers_[i]->buffer_)); + mr->set_remote_addr( + reinterpret_cast(rc->message_buffers_[i]->buffer_)); mr->set_rkey(rc->message_buffers_[i]->self_->rkey); } // synchronous call @@ -79,10 +80,10 @@ void RdmaMgr::SetupChannels() { // save obtained remote addresses // connect to the remote channel if (s.ok()) { - CHECK(worker_name.compare(resp.host_name())==0); + CHECK(worker_name.compare(resp.host_name()) == 0); RdmaAddress ra; ra.lid = resp.channel().lid(); - ra.qpn = resp.channel().qpn(); + ra.qpn = resp.channel().qpn(); ra.psn = resp.channel().psn(); rc->SetRemoteAddress(ra, false); rc->Connect(); @@ -112,7 +113,7 @@ void RdmaMgr::SetupChannels() { RdmaMgr::~RdmaMgr() { for (const auto& p : channel_table_) delete p.second; - channel_table_.clear(); + channel_table_.clear(); delete rdma_adapter_; } diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h index 931cc55c0d2..b156f64096c 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.h +++ b/tensorflow/contrib/verbs/rdma_mgr.h @@ -28,9 +28,8 @@ limitations under the License. 
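// Condensed usage sketch (worker name illustrative): the constructor above
// creates one RdmaChannel per remote worker, and SetupChannels() later
// exchanges queue-pair and memory-region info over gRPC before connecting.
RdmaMgr* rdma_mgr = new RdmaMgr(worker_env, channel_cache);
rdma_mgr->SetupChannels();  // synchronous GetRemoteAddress handshake per peer
RdmaChannel* rc = rdma_mgr->FindChannel("/job:worker/replica:0/task:1");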
namespace tensorflow { class RdmaMgr { - public: - explicit RdmaMgr(const WorkerEnv* const worker_env, + explicit RdmaMgr(const WorkerEnv* const worker_env, GrpcChannelCache* const channel_cache); ~RdmaMgr(); RdmaChannel* FindChannel(const string& key); @@ -45,11 +44,11 @@ class RdmaMgr { RdmaAdapter* rdma_adapter_; typedef std::unordered_map ChannelTable; ChannelTable channel_table_; - + TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr); }; -} // namespace tensorflow +} // namespace tensorflow -#endif // TENSORFLOW_USE_VERBS +#endif // TENSORFLOW_USE_VERBS #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 2cfe7565ad5..8cbdfaa9439 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -15,8 +15,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include #include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h" +#include #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -29,14 +29,16 @@ namespace tensorflow { class RdmaRemoteRendezvous : public BaseRemoteRendezvous { public: - RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name, + RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name, int64 step_id, RdmaMgr* rdma_mgr) - : BaseRemoteRendezvous(env, worker_name, step_id, true), + : BaseRemoteRendezvous(env, worker_name, step_id, true), rdma_mgr_(rdma_mgr) {} + protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, DoneCallback done) override; + private: ~RdmaRemoteRendezvous() override {} RdmaMgr* rdma_mgr_; @@ -45,13 +47,13 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous { }; void RdmaRemoteRendezvous::RecvFromRemoteAsync( - const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, - DoneCallback done) { + const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, + DoneCallback done) { Status s; // parse src_name and dst_name string src_name, dst_name, unused; - if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, - &src_name, &unused)) { + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name, + &unused)) { s = errors::Internal("Could not parse src name."); } CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); @@ -59,8 +61,8 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( done(s, Args(), recv_args, Tensor{}, false); return; } - if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, - &dst_name, &unused)) { + if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, + &unused)) { s = errors::Internal("Could not parse dst name."); } CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); @@ -73,52 +75,52 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( string key(std::move(parsed.FullKey().ToString())); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); // insert callback - rc->InsertRecvCallback(key_with_step_id, - [this, key, key_with_step_id, rc, recv_args, parsed, done](){ - Status s; - Device* src_dev; - s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor(), true); - return; - } - Device* dst_dev; - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); - CHECK(s.ok()) 
<< "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor(), true); - return; - } - RdmaBuffer* rb = rc->FindBuffer(key); - RdmaMessage rm; - CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes); - RdmaMessage::ParseMessage(rm, rb->buffer_); - CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE); - Tensor val; - if (!rm.is_dead_) { - void* input = static_cast(rb->buffer_) + - RdmaMessage::kTensorBufferStartIndex; - TensorProto proto; - CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <= rb->size_); - CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_)) - << "fail to parse proto from array"; - s = dst_dev->MakeTensorFromProto(proto, - recv_args.alloc_attrs, &val); - } - - rc->RemoveRecvCallback(key_with_step_id); - // create message - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_IDLE; - br.name_size_ = key.size(); - br.name_ = key; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* tb = rc->tx_message_buffer_; - tb->EnqueueItem(message); - tb->SendNextItem(); - done(s, Args(), recv_args, val, rm.is_dead_); - }); + rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc, + recv_args, parsed, done]() { + Status s; + Device* src_dev; + s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); + CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), true); + return; + } + Device* dst_dev; + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); + CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), true); + return; + } + RdmaBuffer* rb = rc->FindBuffer(key); + RdmaMessage rm; + CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes); + RdmaMessage::ParseMessage(rm, rb->buffer_); + CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE); + Tensor val; + if (!rm.is_dead_) { + void* input = static_cast(rb->buffer_) + + RdmaMessage::kTensorBufferStartIndex; + TensorProto proto; + CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <= + rb->size_); + CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_)) + << "fail to parse proto from array"; + s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val); + } + + rc->RemoveRecvCallback(key_with_step_id); + // create message + RdmaMessage br; + br.type_ = RDMA_MESSAGE_BUFFER_IDLE; + br.name_size_ = key.size(); + br.name_ = key; + string message = RdmaMessage::CreateMessage(br); + RdmaBuffer* tb = rc->tx_message_buffer_; + tb->EnqueueItem(message); + tb->SendNextItem(); + done(s, Args(), recv_args, val, rm.is_dead_); + }); // append key to message queue RdmaBuffer* rb = rc->tx_message_buffer_; RdmaMessage rm; @@ -141,7 +143,7 @@ BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id, const string& worker_name) { return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_); } - + } // end namespace tensorflow #endif diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h index f23c31e7933..57cd4bf5e4e 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h @@ -47,12 +47,12 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr { public: explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name, WorkerCacheInterface* worker_cache); - void SetRdmaMgr(RdmaMgr* rdma_mgr) { - rdma_mgr_ = rdma_mgr; - } + void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; } + 
protected: BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& worker_name) override; + const string& worker_name) override; + private: RdmaMgr* rdma_mgr_; TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr); @@ -60,5 +60,5 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr { } // end namespace tensorflow -#endif // TENSORFLOW_USE_VERBS +#endif // TENSORFLOW_USE_VERBS #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index bb385d8fb98..b061c81d2d8 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -27,8 +27,9 @@ namespace tensorflow { namespace { // static utility function -RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env, - const string& worker_name, WorkerCacheInterface* worker_cache) { +RendezvousMgrInterface* NewRdmaRendezvousMgr( + const WorkerEnv* env, const string& worker_name, + WorkerCacheInterface* worker_cache) { return new RdmaRendezvousMgr(env, worker_name, worker_cache); } @@ -46,7 +47,7 @@ VerbsServer::~VerbsServer() { } Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, - GrpcChannelCache** channel_cache) { + GrpcChannelCache** channel_cache) { string name_prefix = strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:", server_def.task_index()); @@ -54,41 +55,43 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, GrpcChannelSpec channel_spec; TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); - *channel_cache = NewGrpcChannelCache(channel_spec, - GetChannelCreationFunction(server_def)); - + *channel_cache = + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def)); + const string host_port = (*channel_cache)->TranslateTask(name_prefix); int requested_port; if (!strings::safe_strto32(str_util::Split(host_port, ':')[1], &requested_port)) { return errors::Internal("Could not parse port for local server from \"", - (*channel_cache)->TranslateTask(name_prefix), "\"."); + (*channel_cache)->TranslateTask(name_prefix), + "\"."); } if (requested_port != bound_port()) { return errors::InvalidArgument("Requested port ", requested_port, - " differs from expected port ", bound_port()); + " differs from expected port ", + bound_port()); } - + return Status::OK(); } Status VerbsServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func) { + RendezvousMgrCreationFunction rendezvous_mgr_func) { Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); { mutex_lock l(mu_); CHECK_EQ(verbs_state_, DISCONNECTED); CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok()); rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_); - // set rdma_mgr for verbs_service and rdma_rendezvous_mgr + // set rdma_mgr for verbs_service and rdma_rendezvous_mgr verbs_service_->SetRdmaMgr(rdma_mgr_); // hardcoded to default session (legacy_session_) // TODO: use WorkerSessionForSession // need to pass in session handle - dynamic_cast(worker_env()->session_mgr-> - LegacySession()->rendezvous_mgr.get()) - ->SetRdmaMgr(rdma_mgr_); + dynamic_cast( + worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get()) + ->SetRdmaMgr(rdma_mgr_); } return s; } @@ -100,9 +103,9 @@ Status VerbsServer::Start() { if (verbs_state_ == DISCONNECTED) { // verbs_thread needs to be initiated // before rdma_mgr sets up the rdma channels. 
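// (Likely rationale for the ordering: SetupChannels() issues synchronous
// GetRemoteAddress RPCs and each peer answers from its own TF_verbs_service
// loop, so a server whose service thread is not yet running could stall a
// peer that is already trying to connect to it.)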
- verbs_thread_.reset( - worker_env()->env->StartThread(ThreadOptions(), "TF_verbs_service", - [this] { verbs_service_->HandleRPCsLoop(); })); + verbs_thread_.reset(worker_env()->env->StartThread( + ThreadOptions(), "TF_verbs_service", + [this] { verbs_service_->HandleRPCsLoop(); })); rdma_mgr_->SetupChannels(); verbs_state_ = CONNECTED; } @@ -124,10 +127,10 @@ Status VerbsServer::Join() { /* static */ Status VerbsServer::Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server) { + std::unique_ptr* out_server) { std::unique_ptr ret(new VerbsServer(server_def, Env::Default())); ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env, - ::grpc::ServerBuilder* builder) { + ::grpc::ServerBuilder* builder) { return SetNewVerbsService(&ret->verbs_service_, worker_env, builder); }; TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr)); diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h index 4e4f6683f7b..855380129f2 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.h +++ b/tensorflow/contrib/verbs/verbs_server_lib.h @@ -18,8 +18,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" +#include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" namespace tensorflow { @@ -27,7 +27,7 @@ namespace tensorflow { class VerbsServer : public GrpcServer { protected: VerbsServer(const ServerDef& server_def, Env* env); - + public: static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr* out_server); @@ -39,21 +39,22 @@ class VerbsServer : public GrpcServer { // Implementations of ServerInterface methods. Status Start() override; Status Join() override; - + protected: Status Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func); Status ChannelCacheFactory(const ServerDef& server_def, GrpcChannelCache** channel_cache); + private: RdmaMgr* rdma_mgr_; - + // Guards state transitions. 
mutex mu_; - + enum State { DISCONNECTED, CONNECTED }; State verbs_state_ GUARDED_BY(mu_); - + GrpcVerbsService* verbs_service_ = nullptr; std::unique_ptr verbs_thread_ GUARDED_BY(mu_); GrpcChannelCache* channel_cache_ = nullptr; @@ -61,5 +62,5 @@ class VerbsServer : public GrpcServer { } // namespace tensorflow -#endif // TENSORFLOW_USE_VERBS +#endif // TENSORFLOW_USE_VERBS #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto index a99efc124d1..b985febfb8c 100644 --- a/tensorflow/contrib/verbs/verbs_service.proto +++ b/tensorflow/contrib/verbs/verbs_service.proto @@ -29,7 +29,7 @@ option java_package = "org.tensorflow.contrib.verbs"; message Channel { int32 lid = 1; int32 qpn = 2; - int32 psn = 3; + int32 psn = 3; } message MemoryRegion { @@ -39,15 +39,14 @@ message MemoryRegion { message GetRemoteAddressRequest { string host_name = 1; Channel channel = 2; - repeated MemoryRegion mr = 3; + repeated MemoryRegion mr = 3; } message GetRemoteAddressResponse { string host_name = 1; Channel channel = 2; - repeated MemoryRegion mr = 3; -} - + repeated MemoryRegion mr = 3; +} //////////////////////////////////////////////////////////////////////////////// // @@ -56,5 +55,6 @@ message GetRemoteAddressResponse { //////////////////////////////////////////////////////////////////////////////// service VerbsService { - rpc GetRemoteAddress(GetRemoteAddressRequest) returns (GetRemoteAddressResponse); + rpc GetRemoteAddress(GetRemoteAddressRequest) + returns (GetRemoteAddressResponse); } diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc index ff4f9219da6..c3350f7958c 100644 --- a/tensorflow/contrib/verbs/verbs_util.cc +++ b/tensorflow/contrib/verbs/verbs_util.cc @@ -22,30 +22,27 @@ namespace tensorflow { // static sync wrapper: Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev, - const DeviceContext* device_context, - TensorProto* proto, bool is_dead) { + const DeviceContext* device_context, + TensorProto* proto, bool is_dead) { Notification n; Status status; - GPUUtil::SetProtoFromGPU(tensor, dev, - device_context, - proto, is_dead, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); return status; } -//static -string VerbsUtil::AppendStepidToKey(const string& key, - int64 step_id) { +// static +string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) { return strings::StrCat(key, ";", step_id); } // static -void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, - string& key, int64& step_id) { +void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key, + int64& step_id) { StringPiece s(key_with_step_id); // a key (with step_id) has exact 6 parts if split by ";" // part 1: src_device; @@ -55,10 +52,10 @@ void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, // part 5: frame_iter.frame_id:frame_iter.iter_id // part 6: step_id std::vector parts = str_util::Split(s, ';'); - CHECK(parts.size()==6) << "Key with step_id must have 6 parts"; + CHECK(parts.size() == 6) << "Key with step_id must have 6 parts"; strings::safe_strto64(parts[5], &step_id); - parts.pop_back(); // remove step_id - key.assign(str_util::Join(parts, ";")); // stitch them together + parts.pop_back(); // remove step_id + 
key.assign(str_util::Join(parts, ";")); // stitch them together } } // namespace tensorflow diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h index 757e559d966..cbc01adae49 100644 --- a/tensorflow/contrib/verbs/verbs_util.h +++ b/tensorflow/contrib/verbs/verbs_util.h @@ -28,14 +28,13 @@ class TensorProto; class VerbsUtil { public: - // synchronous wrapper of SetProtoFromGPU static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev, - const DeviceContext* device_context, - TensorProto* proto, bool is_dead); + const DeviceContext* device_context, + TensorProto* proto, bool is_dead); static string AppendStepidToKey(const string& key, int64 step_id); - static void GetKeyAndStepId(const string& key_with_step_id, string& key, - int64& step_id); + static void GetKeyAndStepId(const string& key_with_step_id, string& key, + int64& step_id); }; } // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bca81a1dc59..71fba99aad1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1570,6 +1570,7 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "//tensorflow/core/kernels:function_ops", ], alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h index c6d4bdad9c1..49566c8fa8f 100644 --- a/tensorflow/core/common_runtime/build_graph_options.h +++ b/tensorflow/core/common_runtime/build_graph_options.h @@ -30,6 +30,11 @@ struct BuildGraphOptions { // the former via "ref" fetch_endpoints. std::vector target_nodes; + // If `true`, uses Arg/Retval to implement feeds/fetches; otherwise + // uses Recv/Send to implement feeds/fetches. + // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval. + bool use_function_convention = false; + string DebugString() const; }; diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 8c4085425a1..3cd29c8e86e 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -43,7 +43,7 @@ namespace tensorflow { namespace { bool IsConstantFoldable(const Node* n, - std::function consider) { + const std::function& consider) { if (n->op_def().is_stateful()) { return false; } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index b25131b07b5..ffd37faca42 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -71,7 +71,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, if (ri.sender_device_type == src_device_type && ri.receiver_device_type == dst_device_type) { ri.copy_function(send_dev_context, recv_dev_context, src, dst, - src_alloc_attr, dst_alloc_attr, input, output, done); + src_alloc_attr, dst_alloc_attr, input, output, + std::move(done)); return; } } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index c05cceced11..002e246b80d 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { return Status::OK(); } -// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()". 
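// A minimal sketch of the call-frame convention that the rewritten Run()
// below relies on (dtypes and tensors illustrative): feeds are bound to
// _Arg ops via SetArgs(), and _Retval ops deposit fetches into the frame.
FunctionCallFrame frame({DT_FLOAT} /* feed types */, {DT_FLOAT} /* fetch types */);
TF_CHECK_OK(frame.SetArgs({feed_tensor}));
// ... executors run with args.call_frame = &frame ...
std::vector<Tensor> fetched;
TF_CHECK_OK(frame.ConsumeRetvals(&fetched));  // moves tensors out of the frame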
Status DirectSession::Run(const NamedTensorList& inputs, const std::vector& output_names, const std::vector& target_nodes, @@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options, executor_step_count, input_tensor_names, output_names, target_nodes)); } + // Configure a call frame for the step, which we use to feed and + // fetch values to and from the executors. + FunctionCallFrame call_frame(executors_and_keys->input_types, + executors_and_keys->output_types); + gtl::InlinedVector feed_args(inputs.size()); + for (const auto& it : inputs) { + if (it.second.dtype() == DT_RESOURCE) { + Tensor tensor_from_handle; + TF_RETURN_IF_ERROR( + ResourceHandleToInputTensor(it.second, &tensor_from_handle)); + feed_args[executors_and_keys->input_name_to_index[it.first]] = + tensor_from_handle; + } else { + feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; + } + } + Status s = call_frame.SetArgs(feed_args); + if (errors::IsInternal(s)) { + return errors::InvalidArgument(s.error_message()); + } else if (!s.ok()) { + return s; + } + // Create a run state and start execution. RunState run_state(args.step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); CancellationManager step_cancellation_manager; - - // Send inputs. - TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez)); + args.call_frame = &call_frame; // Start parallel Executors. const size_t num_executors = executors_and_keys->items.size(); @@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options, } // Receive outputs. - TF_RETURN_IF_ERROR( - RecvOutputs(output_names, executors_and_keys, &run_state, outputs)); + if (outputs) { + std::vector sorted_outputs; + Status s = call_frame.ConsumeRetvals(&sorted_outputs); + if (errors::IsInternal(s)) { + return errors::InvalidArgument(s.error_message()); + } else if (!s.ok()) { + return s; + } + outputs->clear(); + outputs->reserve(sorted_outputs.size()); + for (const string& output_name : output_names) { + outputs->emplace_back( + std::move(sorted_outputs[executors_and_keys + ->output_name_to_index[output_name]])); + } + } // Save the output tensors of this run we choose to keep. TF_RETURN_IF_ERROR( @@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, CheckFetch(inputs, output_names, executors_and_keys, run_state)); // Send inputs. - Status s = SendInputs(inputs, executors_and_keys, run_state->rendez); + Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez); // Receive outputs. if (s.ok()) { - s = RecvOutputs(output_names, executors_and_keys, run_state, outputs); + s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); } // Save the output tensors of this run we choose to keep. @@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, } } -Status DirectSession::SendInputs(const NamedTensorList& inputs, - const ExecutorsAndKeys* executors_and_keys, - IntraProcessRendezvous* rendez) { +Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, + const ExecutorsAndKeys* executors_and_keys, + IntraProcessRendezvous* rendez) { Status s; Rendezvous::ParsedKey parsed; // Insert the input tensors into the local rendezvous by their // rendezvous key. 
for (const auto& input : inputs) { - auto it = executors_and_keys->input_keys.find(input.first); - if (it == executors_and_keys->input_keys.end()) { + auto it = + executors_and_keys->input_name_to_rendezvous_key.find(input.first); + if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { return errors::Internal("'", input.first, "' is not a pre-defined feed."); } const string& input_key = it->second; @@ -808,10 +843,10 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs, return Status::OK(); } -Status DirectSession::RecvOutputs(const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, - RunState* run_state, - std::vector* outputs) { +Status DirectSession::RecvPRunOutputs( + const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + std::vector* outputs) { Status s; if (!output_names.empty()) { outputs->resize(output_names.size()); @@ -822,8 +857,9 @@ Status DirectSession::RecvOutputs(const std::vector& output_names, for (size_t output_offset = 0; output_offset < output_names.size(); ++output_offset) { const string& output_name = output_names[output_offset]; - auto it = executors_and_keys->output_keys.find(output_name); - if (it == executors_and_keys->output_keys.end()) { + auto it = + executors_and_keys->output_name_to_rendezvous_key.find(output_name); + if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { return errors::Internal("'", output_name, "' is not a pre-defined fetch."); } @@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors( options.feed_endpoints = inputs_sorted; options.fetch_endpoints = outputs_sorted; options.target_nodes = tn_sorted; + options.use_function_convention = !run_state_args->is_partial_run; std::shared_ptr ek(new ExecutorsAndKeys); // The executor_lock_ is intentionally released while executor is // being created. std::unordered_map> graphs; - TF_RETURN_IF_ERROR( - CreateGraphs(options, &graphs, &ek->flib_def, run_state_args)); + TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def, + run_state_args, &ek->input_types, + &ek->output_types)); if (run_state_args->is_partial_run) { ek->graph = std::move(run_state_args->graph); @@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors( item->executor.reset(executor); } - // Compute the rendezvous keys to avoid recomputing them every time. - // - // We always use the first device as the device name portion of the - // key, even if we're feeding another graph. - for (const string& input : inputs) { - ek->input_keys[input] = GetRendezvousKey( - input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); - } - for (const string& output : outputs) { - ek->output_keys[output] = GetRendezvousKey( - output, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + // Cache the mapping from input/output names to graph elements to + // avoid recomputing it every time. + if (!run_state_args->is_partial_run) { + // For regular `Run()`, we use the function calling convention, and so + // maintain a mapping from input/output names to + // argument/return-value ordinal index. 
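// (E.g., with feeds sorted as {"x:0", "y:0"}, input_name_to_index becomes
// {"x:0" -> 0, "y:0" -> 1}; the index is the _Arg ordinal that feed binds to.)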
+ for (size_t i = 0; i < inputs_sorted.size(); ++i) { + const string& input = inputs_sorted[i]; + ek->input_name_to_index[input] = i; + } + for (size_t i = 0; i < outputs_sorted.size(); ++i) { + const string& output = outputs_sorted[i]; + ek->output_name_to_index[output] = i; + } + } else { + // For `PRun()`, we use the rendezvous calling convention, and so + // maintain a mapping from input/output names to rendezvous keys. + // + // We always use the first device as the device name portion of the + // key, even if we're feeding another graph. + for (size_t i = 0; i < inputs_sorted.size(); ++i) { + const string& input = inputs_sorted[i]; + ek->input_name_to_rendezvous_key[input] = GetRendezvousKey( + input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + } + for (size_t i = 0; i < outputs_sorted.size(); ++i) { + const string& output = outputs_sorted[i]; + ek->output_name_to_rendezvous_key[output] = + GetRendezvousKey(output, device_set_.client_device()->attributes(), + FrameAndIter(0, 0)); + } } // Reacquire the lock, try to insert into the map. @@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs( const BuildGraphOptions& subgraph_options, std::unordered_map>* outputs, std::unique_ptr* flib_def, - RunStateArgs* run_state_args) { + RunStateArgs* run_state_args, DataTypeVector* input_types, + DataTypeVector* output_types) { mutex_lock l(graph_def_lock_); std::unique_ptr client_graph; @@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs( execution_state->BuildGraph(subgraph_options, &client_graph)); } + if (subgraph_options.feed_endpoints.size() != + client_graph->feed_types.size()) { + return errors::Internal( + "Graph pruning failed: requested number of feed endpoints = ", + subgraph_options.feed_endpoints.size(), + " versus number of pruned feed endpoints = ", + client_graph->feed_types.size()); + } + if (subgraph_options.fetch_endpoints.size() != + client_graph->fetch_types.size()) { + return errors::Internal( + "Graph pruning failed: requested number of fetch endpoints = ", + subgraph_options.fetch_endpoints.size(), + " versus number of pruned fetch endpoints = ", + client_graph->fetch_types.size()); + } + auto current_stateful_placements = execution_state->GetStatefulPlacements(); // Update our current state based on the execution_state's // placements. If there are any mismatches for a node, @@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs( } } *flib_def = std::move(client_graph->flib_def); + std::swap(*input_types, client_graph->feed_types); + std::swap(*output_types, client_graph->fetch_types); return s; } diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index b9d22ac522c..848ef3bc62d 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -132,8 +132,13 @@ class DirectSession : public Session { NameNodeMap name_to_node; std::unique_ptr flib_def; std::vector items; - std::unordered_map input_keys; - std::unordered_map output_keys; + std::unordered_map input_name_to_index; + std::unordered_map input_name_to_rendezvous_key; + std::unordered_map output_name_to_index; + std::unordered_map output_name_to_rendezvous_key; + + DataTypeVector input_types; + DataTypeVector output_types; }; // For each live partial execution, the session maintains a RunState. 
@@ -187,7 +192,8 @@ class DirectSession : public Session { const BuildGraphOptions& options, std::unordered_map>* outputs, std::unique_ptr* flib_def, - RunStateArgs* run_state_args); + RunStateArgs* run_state_args, DataTypeVector* input_types, + DataTypeVector* output_types); ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); @@ -196,17 +202,17 @@ class DirectSession : public Session { const Tensor& resource_tensor, Tensor* retrieved_tensor); // Feeds more inputs to the executors, triggering further execution. - ::tensorflow::Status SendInputs( + ::tensorflow::Status SendPRunInputs( const std::vector>& inputs, const ExecutorsAndKeys* executors_and_keys, IntraProcessRendezvous* rendez); // Fetches more outputs from the executors. It waits until the output // tensors are computed. - ::tensorflow::Status RecvOutputs(const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, - RunState* run_state, - std::vector* outputs); + ::tensorflow::Status RecvPRunOutputs( + const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + std::vector* outputs); // Check if the specified fetches can be computed from the feeds // that we have already provided. diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 561e185ac4e..ed5b87f2f22 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1434,7 +1434,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { } else { num_outstanding_ops_ = ready.size(); root_frame_->iterations[0]->outstanding_ops = ready.size(); - done_cb_ = done; + done_cb_ = std::move(done); // Schedule to run all the ready ops in thread pool. 
ScheduleReady(ready, nullptr); } @@ -2560,7 +2560,7 @@ bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview, } void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { - (new ExecutorState(args, this))->RunAsync(done); + (new ExecutorState(args, this))->RunAsync(std::move(done)); } } // end namespace diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 5f011c2ce94..0f2e24690f3 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -604,7 +604,7 @@ struct CustomCreatorSingleton { void Set(CustomKernelCreator cb) { mutex_lock l(mu); - custom_creator = cb; + custom_creator = std::move(cb); } CustomKernelCreator Get() { @@ -621,7 +621,7 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() { } // end namespace void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) { - GetCustomCreatorSingleton()->Set(cb); + GetCustomCreatorSingleton()->Set(std::move(cb)); } FunctionLibraryRuntime* NewFunctionLibraryRuntime( @@ -631,7 +631,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime( CustomKernelCreator custom_kernel_creator) { return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator); + std::move(custom_kernel_creator)); } FunctionLibraryRuntime* NewFunctionLibraryRuntime( diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 29ce157349a..bbf35590eb6 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -44,7 +44,7 @@ Status GetOpSig(const string& op, const OpDef** sig) { void FunctionTestSchedClosure(std::function fn) { static thread::ThreadPool* w = new thread::ThreadPool(Env::Default(), "Test", 8); - w->Schedule(fn); + w->Schedule(std::move(fn)); } void HasError(const Status& s, const string& substr) { @@ -654,7 +654,8 @@ namespace { bool DoNothing(Graph* g) { return false; } -string Optimize(std::function pass, const FunctionDef& fdef) { +string Optimize(const std::function& pass, + const FunctionDef& fdef) { InstantiationResult result; InstantiateAttrValueMap empty; TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 514a63590b1..a85fbbf88ff 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, } // Call RewriteGraphForExecution + subgraph::RewriteGraphMetadata metadata; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( graph_to_run.get(), input_names, output_names, {} /* target nodes */, - cpu_device_->attributes())); + cpu_device_->attributes(), false /* use_function_convention */, + &metadata)); // Create the local executor and the Rendezvous for fetching back the // constants. 
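// A sketch of the new pruning interface used above, inside a Status-returning
// helper (graph and endpoint names illustrative): the RewriteGraphMetadata
// out-parameter records the dtype of every feed and fetch endpoint the
// rewrite replaced, so callers can size and type-check call frames without
// re-deriving them from the graph.
subgraph::RewriteGraphMetadata metadata;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
    graph, {"x:0"} /* feeds */, {"y:0"} /* fetches */, {} /* targets */,
    device_attributes, true /* use_function_convention */, &metadata));
DCHECK_EQ(metadata.feed_types.size(), 1);   // dtype of the inserted _Arg
DCHECK_EQ(metadata.fetch_types.size(), 1);  // dtype of the inserted _Retval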
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 285ac7540c8..2a2b10c0cff 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -106,7 +106,7 @@ void IntraProcessRendezvous::SameWorkerRecvDone( CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context, recv_args.device_context, src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in, out, - done); + std::move(done)); } void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, @@ -132,7 +132,8 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, }; if (status.ok() && in.IsInitialized()) { - SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback); + SameWorkerRecvDone(parsed, send_args, recv_args, in, out, + std::move(final_callback)); } else { final_callback(status); } diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc index 85a29e11e23..c179e94c36b 100644 --- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc +++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc @@ -21,9 +21,9 @@ limitations under the License. namespace tensorflow { namespace { -// Replaces ReadVariableOp nodes which are only used by Sends and sinks with -// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve -// performance. +// Replaces ReadVariableOp nodes which are only used by Sends, sinks, +// and function Retvals with _UnsafeReadVariable nodes, as this +// transformation is safe and will improve performance. class ResourceVariableReadPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override { @@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass { if (n->type_string() == "ReadVariableOp") { bool skip = false; for (const Edge* e : n->out_edges()) { - if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") { + if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" && + e->dst()->name() != "_SINK") { skip = true; } } diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index c2ac15b345d..31e63a9ef75 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph( if (session_options_ && session_options_->config.graph_options().place_pruned_graph()) { // Rewrite the graph before placement. + rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata); TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( new_graph.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes())); + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, rewrite_metadata_.get())); } // Save stateful placements before placing. @@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph( std::unique_ptr ng(new Graph(flib_def_.get())); CopyGraph(*graph_, ng.get()); + subgraph::RewriteGraphMetadata rewrite_metadata; if (session_options_ == nullptr || !session_options_->config.graph_options().place_pruned_graph()) { // Extract the subset of the graph that needs to be run, adding feed/fetch // ops as needed. 
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( ng.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes())); + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, &rewrite_metadata)); + } else { + // This SimpleGraphExecutionState represents a graph that was + // pruned when this was constructed, so we copy the metadata from + // a member variable. + CHECK(rewrite_metadata_); + rewrite_metadata = *rewrite_metadata_; } + CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size()); + CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size()); + // Make a fresh copy of the function library for the client graph. std::unique_ptr flib( new FunctionLibraryDefinition(*flib_def_)); @@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph( // since the local CostModel used to record its stats is sized by // the largest node id. std::unique_ptr dense_copy( - new SimpleClientGraph(std::move(flib))); + new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types, + rewrite_metadata.fetch_types)); CopyGraph(*ng, &dense_copy->graph); // TODO(vrv): We should check invariants of the graph here. diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index 3b6ce23c754..00b5509fd78 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -39,6 +39,10 @@ struct SessionOptions; class StepStats; class Timeline; +namespace subgraph { +struct RewriteGraphMetadata; +} + struct SimpleGraphExecutionStateOptions { const DeviceSet* device_set = nullptr; const SessionOptions* session_options = nullptr; @@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions { // A SimpleClientGraph is simply a sub-graph of the full graph as induced by // BuildGraphOptions. struct SimpleClientGraph { - explicit SimpleClientGraph(std::unique_ptr flib) - : flib_def(std::move(flib)), graph(flib_def.get()) {} + explicit SimpleClientGraph(std::unique_ptr flib, + DataTypeVector feed_types, + DataTypeVector fetch_types) + : flib_def(std::move(flib)), + graph(flib_def.get()), + feed_types(std::move(feed_types)), + fetch_types(std::move(fetch_types)) {} // Each client-graph gets its own function library since optimization passes // post rewrite for execution might want to introduce new functions. std::unique_ptr flib_def; Graph graph; - int32 placement_version; + DataTypeVector feed_types; + DataTypeVector fetch_types; }; // SimpleGraphExecutionState is responsible for generating an @@ -190,6 +200,10 @@ class SimpleGraphExecutionState { // and may be updated by a graph optimization pass. std::unique_ptr flib_def_; + // `rewrite_metadata_` is only set for SimpleGraphExecutionState + // objects created by `MakeForPrunedGraph()`. + std::unique_ptr rewrite_metadata_; + // The dataflow graph owned by this object. 
Graph* graph_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 51f1d7f8f99..7160962b168 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -63,8 +63,9 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { }; // static utility function -RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env, - const string& worker_name, WorkerCacheInterface* worker_cache) { +RendezvousMgrInterface* NewRpcRendezvousMgr( + const WorkerEnv* env, const string& worker_name, + WorkerCacheInterface* worker_cache) { return new RpcRendezvousMgr(env, worker_name, worker_cache); } @@ -76,7 +77,7 @@ GrpcServer::GrpcServer(const ServerDef& server_def, Env* env) GrpcServer::~GrpcServer() { TF_CHECK_OK(Stop()); TF_CHECK_OK(Join()); - + delete master_service_; delete worker_service_; @@ -100,7 +101,7 @@ GrpcServer::~GrpcServer() { } Status GrpcServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendevous_mgr_func) { + RendezvousMgrCreationFunction rendevous_mgr_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -193,6 +194,8 @@ Status GrpcServer::Init(ServiceInitFunction service_func, // Set up worker environment. std::unique_ptr rendezvous_mgr( + rendevous_mgr_func == nullptr ? + new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) : rendevous_mgr_func(&worker_env_, name_prefix, worker_cache)); worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), @@ -222,6 +225,10 @@ Status GrpcServer::Init(ServiceInitFunction service_func, return Status::OK(); } +Status GrpcServer::Init() { + return Init(nullptr, nullptr); +} + Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, GrpcChannelSpec* channel_spec) { for (const auto& job : server_def.cluster().job()) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 7924fbfd45c..3b66291a9ab 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -37,14 +37,15 @@ class GrpcWorker; class Master; // function that creates a RendezvousMgr. -typedef std::function - RendezvousMgrCreationFunction; +typedef std::function + RendezvousMgrCreationFunction; // function that registers a service to the server. The service needs to // be registered before builder.BuildAndStart(). -typedef std::function - ServiceInitFunction; +typedef std::function + ServiceInitFunction; class GrpcServer : public ServerInterface { protected: @@ -68,6 +69,8 @@ class GrpcServer : public ServerInterface { Status Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func); + Status Init(); + // A subclass can override this method to support secure credentials. virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( const ServerDef& server_def) const; @@ -90,7 +93,7 @@ class GrpcServer : public ServerInterface { int bound_port() const { return bound_port_; } WorkerEnv* worker_env() { return &worker_env_; } - + const ServerDef& server_def() const { return server_def_; } private: @@ -115,7 +118,7 @@ class GrpcServer : public ServerInterface { // Stop(), Join() enum State { NEW, STARTED, STOPPED }; State state_ GUARDED_BY(mu_); - + // Implementation of a TensorFlow master, and RPC polling thread. 
MasterEnv master_env_; std::unique_ptr master_impl_; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index edb52737d94..8a7d96c38a9 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { - auto item = rets_[i]; + const auto& item = rets_[i]; if (item.has_val) { rets->push_back(item.val); } else { @@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { return Status::OK(); } +Status FunctionCallFrame::ConsumeRetvals(std::vector* rets) { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + if (rets_[i].has_val) { + rets->emplace_back(std::move(rets_[i].val)); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + Status FunctionCallFrame::GetArg(int index, Tensor* val) const { if (index < 0 || static_cast(index) >= args_.size()) { return errors::InvalidArgument("GetArg ", index, " is not within [0, ", diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 63c868ac9b8..210e5b949a5 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -259,6 +259,7 @@ class FunctionCallFrame { // Caller methods. Status SetArgs(gtl::ArraySlice args); Status GetRetvals(std::vector* rets) const; + Status ConsumeRetvals(std::vector* rets); // Callee methods. Status GetArg(int index, Tensor* val) const; diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 6b92e66f3c8..e45f156e1e5 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -126,29 +126,33 @@ FunctionDef XTimes16() { {{"y", "y:y:0"}}); } -FunctionDef WXPlusB() { - return FDH::Define( - // Name - "WXPlusB", - // Args - {"w: T", "x: T", "b: T"}, - // Return values - {"y: T"}, - // Attr def - {"T: {float, double}"}, - // Nodes - {{{"mm"}, - "MatMul", - {"w", "x"}, - {{"T", "$T"}, - {"transpose_a", false}, - {"transpose_b", false}, -#ifdef INTEL_MKL - }}, +FunctionDef WXPlusB(){return FDH::Define( + // Name + "WXPlusB", + // Args + {"w: T", "x: T", "b: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + { + {{"mm"}, + "MatMul", + {"w", "x"}, + { + {"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}, +#ifdef INTEL_MKL + }}, #else {"_kernel", "eigen"}}}, #endif - {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}}); + { + {"y"}, "Add", {"mm", "b"}, { + { "T", "$T" } + } + } + }); } FunctionDef Swap() { diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index d27a6702e92..09b632a1650 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -63,7 +63,7 @@ namespace tensorflow { // P = BiasAdd(O, C) // // We merge them into Conv2DWithBias as: -// P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m) +// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) // // The meaning of A_m, B_m and C_m is explained in B.1. // @@ -115,7 +115,7 @@ namespace tensorflow { // Since every rewritten node generates twice the number of inputs and // outputs, one could imagine various orderings among Tensorflow tensors // and Mkl tensors. 
E.g., assume an op 'Conv2D' that takes (A, B) as -// inputs, then the new op 'MklConv2D' can take inputs A, B, A_m and B_m +// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m // order. Among N inputs one can get N! permutations. // @@ -239,15 +239,15 @@ namespace tensorflow { // ------------------------------------------- // Consider BiasAddGrad op as: // -// O = MklConv2D(A, B, C, A_m, B_m, C_m) +// O = _MklConv2D(A, B, C, A_m, B_m, C_m) // P = BiasAddGrad(O) // // Then we rewrite it as: // // P = Conv2DWithBiasBackpropBias(O, O_m) // -// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is -// the context matching depth. If MklConv2DWithBias is not within the context +// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is +// the context matching depth. If _MklConv2DWithBias is not within the context // matching depth, then we do not rewrite BiasAddGrad. // How many hops do we search for matching node in the backward dataflow graph? @@ -261,74 +261,66 @@ class MklLayoutRewritePass : public GraphOptimizationPass { public: MklLayoutRewritePass() { // NOTE: names are alphabetically sorted. - csinfo_.avg_pool = "AvgPool"; - csinfo_.avg_pool_grad = "AvgPoolGrad"; - csinfo_.bias_add = "BiasAdd"; - csinfo_.bias_add_grad = "BiasAddGrad"; - csinfo_.concat = "Concat"; - csinfo_.concatv2 = "ConcatV2"; - csinfo_.conv2d = "Conv2D"; - csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; - csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; - csinfo_.fused_batch_norm = "FusedBatchNorm"; + csinfo_.avg_pool = "AvgPool"; + csinfo_.avg_pool_grad = "AvgPoolGrad"; + csinfo_.bias_add = "BiasAdd"; + csinfo_.bias_add_grad = "BiasAddGrad"; + csinfo_.concat = "Concat"; + csinfo_.concatv2 = "ConcatV2"; + csinfo_.conv2d = "Conv2D"; + csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; + csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; + csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; - csinfo_.lrn = "LRN"; - csinfo_.lrn_grad = "LRNGrad"; - csinfo_.matmul = "MatMul"; - csinfo_.max_pool = "MaxPool"; - csinfo_.max_pool_grad = "MaxPoolGrad"; - csinfo_.mkl_conv2d = "MklConv2D"; - csinfo_.mkl_conv2d_with_bias = "MklConv2DWithBias"; + csinfo_.lrn = "LRN"; + csinfo_.lrn_grad = "LRNGrad"; + csinfo_.matmul = "MatMul"; + csinfo_.max_pool = "MaxPool"; + csinfo_.max_pool_grad = "MaxPoolGrad"; + csinfo_.mkl_conv2d = "_MklConv2D"; + csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = - "MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.reshape = "Reshape"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.split = "Split"; + "_MklConv2DWithBiasBackpropBias"; + csinfo_.relu = "Relu"; + csinfo_.reshape = "Reshape"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.split = "Split"; // NOTE: names are alphabetically sorted. 
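// (Illustration: each rewrite rule below pairs a TensorFlow op with its
// Mkl counterpart through GetMklOpName(), defined later in this class;
// with kMklOpPrefix = "_Mkl", GetMklOpName(csinfo_.conv2d) yields
// "_MklConv2D", the op that Conv2D nodes are rewritten to.)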
- rinfo_.push_back({csinfo_.avg_pool, - GetMklOpName(csinfo_.avg_pool), - 1, CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1, + CopyAttrsPooling, AlwaysRewrite}); rinfo_.push_back({csinfo_.avg_pool_grad, - GetMklOpName(csinfo_.avg_pool_grad), - 2, CopyAttrsPooling, AlwaysRewrite}); - rinfo_.push_back({csinfo_.concat, - GetMklOpName(csinfo_.concat), - 0, CopyAttrsConcat, AlwaysRewrite}); - rinfo_.push_back({csinfo_.concatv2, - GetMklOpName(csinfo_.concatv2), - 0, CopyAttrsConcatV2, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2d, - GetMklOpName(csinfo_.conv2d), - 2, CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling, + AlwaysRewrite}); + rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0, + CopyAttrsConcat, AlwaysRewrite}); + rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0, + CopyAttrsConcatV2, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2, + CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter, - GetMklOpName(csinfo_.conv2d_grad_filter), - 3, CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_filter), 3, + CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_input, - GetMklOpName(csinfo_.conv2d_grad_input), - 3, CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_input), 3, + CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm, - GetMklOpName(csinfo_.fused_batch_norm), - 5, CopyAttrsFusedBatchNorm, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm), 5, + CopyAttrsFusedBatchNorm, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm_grad, - GetMklOpName(csinfo_.fused_batch_norm_grad), - 5, CopyAttrsFusedBatchNorm, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn, - GetMklOpName(csinfo_.lrn), - 1, CopyAttrsLRN, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn_grad, - GetMklOpName(csinfo_.lrn_grad), - 3, CopyAttrsLRN, AlwaysRewrite}); - rinfo_.push_back({csinfo_.max_pool, - GetMklOpName(csinfo_.max_pool), - 1, CopyAttrsPooling, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm_grad), 5, + CopyAttrsFusedBatchNorm, AlwaysRewrite}); + rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN, + AlwaysRewrite}); + rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3, + CopyAttrsLRN, AlwaysRewrite}); + rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1, + CopyAttrsPooling, AlwaysRewrite}); rinfo_.push_back({csinfo_.max_pool_grad, - GetMklOpName(csinfo_.max_pool_grad), - 3, CopyAttrsPooling, AlwaysRewrite}); - rinfo_.push_back({csinfo_.relu, - GetMklOpName(csinfo_.relu), - 1, CopyAttrsRelu, AlwaysRewrite}); + GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling, + AlwaysRewrite}); + rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1, + CopyAttrsRelu, AlwaysRewrite}); rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2, CopyAttrsReshape, AlwaysRewrite}); @@ -339,8 +331,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); // Add a rule for merging nodes - minfo_.push_back( - {csinfo_.mkl_conv2d, csinfo_.bias_add, 0, csinfo_.mkl_conv2d_with_bias}); + minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, + csinfo_.mkl_conv2d_with_bias}); // We use maxhop of 10 based on empirical observations. 
Also, these are // maxhops in backward data-flow graph. Since input of forward nodes @@ -374,7 +366,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // A function handler to copy attributes from an old node to a new node. std::function copy_attrs; std::function rewrite_rule; // A rule under which to - // rewrite this node. + // rewrite this node. } RewriteInfo; /// Structure to specify a forward op, a backward op, and the slot numbers @@ -477,7 +469,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // Concat, Split are vararg nodes. inline bool IsVarArgNode(Node* n) { - if (n->type_string() == csinfo_.concat || + if (n->type_string() == csinfo_.concat || n->type_string() == csinfo_.concatv2 || n->type_string() == csinfo_.split) { return true; @@ -496,9 +488,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { CHECK_EQ(ArgIsList(arg), true); int N = 0; - const string attr_name = !arg.type_list_attr().empty() ? - arg.type_list_attr() : - arg.number_attr(); + const string attr_name = !arg.type_list_attr().empty() + ? arg.type_list_attr() + : arg.number_attr(); if (!arg.type_list_attr().empty()) { std::vector value; TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); @@ -514,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // TODO(nhasabni) We should move this to mkl_util.h. inline string GetMklOpName(const string& name) const { // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "Mkl"; + const char* const kMklOpPrefix = "_Mkl"; return string(kMklOpPrefix) + name; } @@ -598,9 +590,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // @return None void GetNodesProducingTFTensorList( - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get nodes that will feed a list of Mkl tensors to the new // node that we are constructing. @@ -616,10 +608,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList(std::unique_ptr* g, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList( + std::unique_ptr* g, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -635,7 +628,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // will feed the tensor // @return None void GetNodeProducingMklTensor(std::unique_ptr* g, Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); + int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -648,11 +642,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // Returns Status::OK() if setting up inputs is successful, otherwise // returns appropriate status code. 
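// For example, under the contiguous ordering described at the top of this
// file, an op with TensorFlow inputs (A, B) is wired up as
// (A, B, A_m, B_m): all data tensors first, then their Mkl metadata
// tensors in the same relative order.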
- int SetUpContiguousInputs(std::unique_ptr* g, - const gtl::InlinedVector, 4>& old_node_inputs, - NodeBuilder* nb, Node* old_node, - std::vector* workspace_tensors, - bool are_workspace_tensors_available); + int SetUpContiguousInputs( + std::unique_ptr* g, + const gtl::InlinedVector, 4>& old_node_inputs, + NodeBuilder* nb, Node* old_node, + std::vector* workspace_tensors, + bool are_workspace_tensors_available); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'orig_node'. @@ -672,8 +667,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // tensors, if they need to be added, will be set into these tensors. // If we set workspace tensors, then are_ws_tensors_added should be true. void AddWorkSpaceEdgeIfNeeded(std::unique_ptr* g, Node* orig_node, - NodeBuilder* nb, std::vector* ws_tensors, - bool* are_ws_tensors_added); + NodeBuilder* nb, + std::vector* ws_tensors, + bool* are_ws_tensors_added); // Functions specific to operators to copy attributes // We need operator-specific function to copy attributes because the framework @@ -732,9 +728,8 @@ static void FillInputs(const Node* n, } void MklLayoutRewritePass::GetNodesProducingTFTensorList( - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -767,34 +762,33 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList( } // TODO(nhasabni) We should move this to mkl_util.h. -void MklLayoutRewritePass::GetDummyMklTensorNode( - std::unique_ptr* g, Node** out, Node* orig_node) { +void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, + Node** out, Node* orig_node) { // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent // dummy Mkl tensor. 8 = 2*size_t. const DataType dt = DataTypeToEnum::v(); TensorProto proto; proto.set_dtype(dt); uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - proto.set_tensor_content(const_cast( - static_cast(&zero)), 8); + proto.set_tensor_content(const_cast(static_cast(&zero)), + 8); TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } void MklLayoutRewritePass::GetNodesProducingMklTensorList( std::unique_ptr* g, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -819,8 +813,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // If it is a list, then create a list of Mkl dummy nodes. 
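// (E.g., a list input with N = 3 contributes three Mkl outputs here, each
// resolved through GetNodeProducingMklTensor below: either the metadata
// slot of an already-rewritten producer or a dummy Mkl tensor node.)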
for (int j = 0; j < N; j++) { GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); } (*input_idx)++; list_length -= N; @@ -829,8 +823,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( Node* mkl_node = nullptr; int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); (*input_idx)++; list_length--; } @@ -841,9 +835,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, - Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr* g, Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -859,8 +853,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, // output slot number for Mkl tensor would be N+slot number of TensorFlow // tensor, where N is total number of TensorFlow tensors. *mkl_node = n; - *mkl_node_output_slot = GetTensorMetaDataIndex(n_output_slot, - n->num_outputs()); + *mkl_node_output_slot = + GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); } else { // If we have not visited the node and rewritten it, then we need // to create a dummy node that will feed a dummy Mkl tensor to this node. @@ -872,7 +866,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, } } -int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr* g, +int MklLayoutRewritePass::SetUpContiguousInputs( + std::unique_ptr* g, const gtl::InlinedVector, 4>& old_node_inputs, NodeBuilder* nb, Node* old_node, std::vector* workspace_tensors, @@ -931,16 +926,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr* g, if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, - N, &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N, + &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { Node* mkl_node = nullptr; int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, - &mkl_node, &mkl_node_output_slot); + old_node_inputs[iidx].second, &mkl_node, + &mkl_node_output_slot); nb->Input(mkl_node, mkl_node_output_slot); iidx++; nn_slot_idx++; @@ -961,7 +956,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr* g, return nn_slot_idx; } -Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr* g, +Status MklLayoutRewritePass::SetUpInputs( + std::unique_ptr* g, const gtl::InlinedVector, 4>& old_node_inputs, NodeBuilder* nb, Node* old_node) { // Let's check if we need to add workspace tensors for this node. 
@@ -975,13 +971,14 @@ Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr* g, if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { // TODO(nhasabni): implement this function just for the sake of completeness. // We do not use interleaved ordering right now. - return Status(error::Code::UNIMPLEMENTED, - "Interleaved ordering of tensors is currently not supported."); + return Status( + error::Code::UNIMPLEMENTED, + "Interleaved ordering of tensors is currently not supported."); } else { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - new_node_input_slots = SetUpContiguousInputs(g, old_node_inputs, nb, - old_node, &workspace_tensors, - are_workspace_tensors_available); + new_node_input_slots = SetUpContiguousInputs( + g, old_node_inputs, nb, old_node, &workspace_tensors, + are_workspace_tensors_available); } // Sanity check @@ -1023,20 +1020,19 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } -void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr* g, - Node* orig_node, NodeBuilder* nb, - std::vector* ws_tensors, - bool* are_ws_tensors_added) { +void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( + std::unique_ptr* g, Node* orig_node, NodeBuilder* nb, + std::vector* ws_tensors, bool* are_ws_tensors_added) { bool workspace_edge_added = false; // Default initializer CHECK_NOTNULL(are_ws_tensors_added); *are_ws_tensors_added = false; // Default initializer @@ -1071,8 +1067,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr* g, nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp( - GetMklOpName(orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // its Mkl tensor edge between its corresponding fwd op and this // op. Corresponding fwd op is specified in 'fwd_op' field of @@ -1094,8 +1090,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr* g, // Add workspace edge between fwd op and bwd op. ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
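// (Illustration, assuming the contiguous layout described earlier: if the
// fwd op has two data outputs and ws.ws_fwd_slot is 1, then
// DataIndexToMetaDataIndex(1, 2) selects slot 3 = 2 + 1, the Mkl metadata
// twin of data output 1.)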
- ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), - DataIndexToMetaDataIndex(ws.ws_fwd_slot, e->src()->num_outputs()))); + ws_tensors->push_back(NodeBuilder::NodeOut( + e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, + e->src()->num_outputs()))); *are_ws_tensors_added = true; // In terms of input ordering, we add these calls to add Input // here because workspace edge (and its Mkl tensor) is the last @@ -1154,8 +1151,8 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", - &use_cudnn_on_gpu)); + TF_CHECK_OK( + GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); // Add attributes to new node. nb->Attr("T", T); @@ -1307,14 +1304,14 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, } void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { + NodeBuilder* nb) { DataType T; DataType Tshape; - + // Get all attributes from old node. TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); - + // Add attributes to new node. nb->Attr("T", T); nb->Attr("Tshape", Tshape); @@ -1435,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* succ, // 2. Get inputs from both nodes. // Find the 2 inputs from the conv and the bias from the BiasAdd. // Get operand 0, 1 of conv2D and their Mkl tensors. - CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs. + CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
// Get operand 1 of add_bias // BiasAdd must have 2 inputs: Conv, bias CHECK_EQ(succ->in_edges().size(), 2); @@ -1538,15 +1535,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, DataType orig_T, ctx_T; string orig_data_format, ctx_data_format; TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", - &orig_data_format)); + TF_CHECK_OK( + GetNodeAttr(orig_node->def(), "data_format", &orig_data_format)); TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T)); - TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "data_format", - &ctx_data_format)); + TF_CHECK_OK( + GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format)); if (orig_data_format != ctx_data_format || orig_T != ctx_T || orig_node->assigned_device_name() != - fwd_node->assigned_device_name() || + fwd_node->assigned_device_name() || orig_node->def().device() != fwd_node->def().device()) { return Status( error::Code::INVALID_ARGUMENT, @@ -1613,9 +1610,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, if (e->src_output() < 0) { (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); } else { - (*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), - e->src()->num_outputs()), - e->dst(), e->dst_input()); + (*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input()); } } diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 5b9201939da..6e72baf84e2 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -110,13 +110,11 @@ class MklLayoutPassTest : public ::testing::Test { }; REGISTER_OP("Input").Output("o: float").SetIsStateful(); -REGISTER_OP("InputList").Output("o: N * float") - .Attr("N: int") - .SetIsStateful(); +REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); -REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optimization @@ -137,16 +135,16 @@ TEST_F(MklLayoutPassTest, Basic) { // Test set 1: Conv2D + AddBias -// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering) -// C=MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering) +// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering) +// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" "node { name: 'M' op: '_MklInput'}" "node { name: 'N' op: '_MklInput'}" "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -163,22 +161,22 @@ TEST_F(MklLayoutPassTest,
NodeMerge_Conv2DWithBias_Positive) { " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);" - "M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" + "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" + "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1"); } -// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved) -// C=MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous) +// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved) +// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous) // Test for correct output slots selected TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput2'}" - "node { name: 'N' op: 'MklInput2'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput2'}" + "node { name: 'N' op: '_MklInput2'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -195,15 +193,15 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);" - "M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" + "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" + "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1"); } // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y); // This is a case of node rewrite followed by node merge. -// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D -// with BiasAdd to produce MklConv2DWithBias. +// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D +// with BiasAdd to produce _MklConv2DWithBias. TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( @@ -227,19 +225,19 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) { " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|" + "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|" "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;" "E->Z;Y->Z:1"); } -// Graph contains only MklConv2D, no AddBias. +// Graph contains only _MklConv2D, no AddBias. 
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -247,18 +245,18 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'M', 'N']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|" - "A->C;B->C:1;M->C:2;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|" + "A->C;B->C:1;M->C:2;N->C:3"); } -// MklConv2D output does not go to BiasAdd. +// _MklConv2D output does not go to BiasAdd. TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -270,21 +268,21 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { "node { name: 'F' op: 'BiasAdd'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd. + " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd. EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);" - "M(MklInput);N(MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" + "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3"); } -// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add). +// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add). // Merge should not be done in such case. 
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -302,8 +300,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);" - "G(Add);M(MklInput);N(MklInput)|A->C;B->C:1;C->G;D->F;" + "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" + "G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;" "E->F:1;E->G:1;M->C:2;N->C:3"); } @@ -313,9 +311,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -328,26 +326,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { " attr { key: 'data_format' value { s: 'NHCW' } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);" - "N(MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);" + "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3"); } // Disabling Conv2DBackpropBias test for now as we have disabled rewrite // of BiasAddGrad into BackpropBias #if 0 -// Test set 2: MklConv2D..BiasAddGrad -> MklConv2DWithBiasBackpropBias +// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias // rewrite tests -// D=MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E) +// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'O' op: 'MklInput'}" - "node { name: 'D' op: 'MklConv2DWithBias'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'O' op: '_MklInput'}" + "node { name: 'D' op: '_MklConv2DWithBias'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -362,25 +360,25 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(MklConv2DWithBias);DMT/_0(Const);" - "E(Sub);F(MklConv2DWithBiasBackpropBias);M(MklInput);N(MklInput);" - "O(MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;" + "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" + 
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" + "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;" "M->D:3;N->D:4;O->D:5"); } #endif -// No MklConv2D in context, but Conv2D in context. -// Only Conv2D would be rewritten to MklConv2D, but no rewrite +// No _MklConv2D in context, but Conv2D in context. +// Only Conv2D would be rewritten to _MklConv2D, but no rewrite // for BiasAddGrad should happen. -// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved) -// C=MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { +// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved) +// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous) +TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" @@ -395,8 +393,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);" - "M(MklInput);N(MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" + "A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);" + "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" "M->C:2;N->C:3"); } @@ -509,7 +507,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|" + "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|" "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); } @@ -536,7 +534,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);" + "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;" "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); } @@ -578,7 +576,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);" + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;" "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } @@ -617,8 +615,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);" - "F(MklConv2D);G(Const);H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;" + "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" + 
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1"); } @@ -652,8 +650,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);" - "H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" + "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" + "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;" "G->H;H->I:1"); } @@ -678,7 +676,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);" + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;" "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } @@ -719,8 +717,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);" - "F(MklConv2D);G(Const);H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;" + "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" + "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1"); } @@ -755,8 +753,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);" - "H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" + "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" + "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;" "G->H:2;H->I:1"); } @@ -804,9 +802,10 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklLRN);C(MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(MklMaxPoolGrad);F(Input);G(MklLRNGrad);H(Input);I(Mul)|" + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" + "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|" "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;" "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;" "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I"); @@ -837,8 +836,8 @@ TEST_F(MklLayoutPassTest, LRN_Positive) { "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(MklLRNGrad);F(Mul)|" + "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" + 
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|" "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;" "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); } @@ -858,7 +857,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) { "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklLRN);C(Mul);DMT/_0(Const)|" + "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|" "A->B;A->C;B->C:1;DMT/_0->B:1"); } @@ -879,7 +878,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(MklLRNGrad);DMT/_0(Const);" + "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;" "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); @@ -919,9 +918,9 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) { "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'F'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" + "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" - "DMT/_6(Const);E(MklLRNGrad);F(MklLRNGrad);G(Mul)|A->B;B->E:2;" + "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;" "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;" "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); @@ -950,8 +949,8 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) { "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|" + "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|" "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;" "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); } @@ -972,7 +971,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) { "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|" + "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|" "A->B;A->C;B->C:1;DMT/_0->B:1"); } @@ -994,7 +993,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);" + "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;" "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index affe9b31c37..55c280719c3 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -123,22 +123,24 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype)); if 
(src_datatype != dst_datatype) { - string err_msg = "T attribute of " + src->name() + " and " + - dst->name() + " do not match. Will not insert" + + string err_msg = "T attribute of " + src->name() + " and " + dst->name() + + " do not match. Will not insert" + " MklToTf node in such case."; return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); } // Build the conversion node and specify src as input. - TF_CHECK_OK(NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf") - .Input(src, e->src_output()) - .Input(src, DataIndexToMetaDataIndex( - e->src_output(), src->num_outputs())) // Get an Mkl tensor slot - // from the Tf tensor slot. - .Device(src->def().device()) // We want to get conversion node - // on same device as source node. - .Attr("T", src_datatype) - .Finalize(&**g, &conversion_node)); + TF_CHECK_OK( + NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf") + .Input(src, e->src_output()) + .Input(src, DataIndexToMetaDataIndex( + e->src_output(), + src->num_outputs())) // Get an Mkl tensor slot + // from the Tf tensor slot. + .Device(src->def().device()) // We want to get conversion node + // on same device as source node. + .Attr("T", src_datatype) + .Finalize(&**g, &conversion_node)); CHECK_NOTNULL(conversion_node); if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) { @@ -191,8 +193,8 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { // We skip adding MklToTf on an edge between X->MklToTf or // MklToTf->X, where X is any node. - if (src->type_string().compare("MklToTf") == 0 || - dst->type_string().compare("MklToTf") == 0) { + if (src->type_string().compare("_MklToTf") == 0 || + dst->type_string().compare("_MklToTf") == 0) { continue; } @@ -246,8 +248,7 @@ bool InsertMklToTfConversionNodes(std::unique_ptr* g) { return MklToTfConversionPass().RunPass(g); } -Status MklToTfConversionPass::Run( - const GraphOptimizationPassOptions& options) { +Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index 2fce0fbfc7c..bd2cb0989c1 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -15,8 +15,8 @@ limitations under the License. 
#ifdef INTEL_MKL -#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/graph/mkl_tfconversion_pass.h" +#include "tensorflow/core/graph/mkl_tfconversion_pass.h" +#include "tensorflow/core/util/mkl_util.h" #include #include @@ -110,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test { REGISTER_OP("Input").Output("o: float").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); -REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful(); +REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); TEST_F(MklToTfConversionPass, Basic) { InitGraph( @@ -131,47 +131,49 @@ TEST_F(MklToTfConversionPass, Basic) { TEST_F(MklToTfConversionPass, Positive) { if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'M', 'B', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Sub'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D']}"); + "node { name: 'A' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } " + "}" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'M', 'B', 'N']}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), - "A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);" - "Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;" - "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);" + "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;" + "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3"); } else { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Sub'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D']}"); + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } " + "}" + " attr { key: 'padding' value { s:
'SAME' } }" + " input: ['A', 'B', 'M', 'N']}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), - "A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);" - "Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;" - "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);" + "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;" + "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3"); } } @@ -182,55 +184,57 @@ TEST_F(MklToTfConversionPass, Positive) { TEST_F(MklToTfConversionPass, Negative_DoubleInsert) { if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'M', 'B', 'N']}" - "node { name: 'D' op: 'MklToTf'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C:0', 'C:1']}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'Sub'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'E']}"); + "node { name: 'A' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } " + "}" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'M', 'B', 'N']}" + "node { name: 'D' op: '_MklToTf'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['C:0', 'C:1']}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'E']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), - "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);" - "F(Sub);M(MklInput);N(MklInput)|" - "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);" + "F(Sub);M(_MklInput);N(_MklInput)|" + "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3"); } else { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: 'MklInput'}" - "node { name: 'N' op: 'MklInput'}" - "node { name: 'C' op: 'MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'MklToTf'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C:0', 'C:1']}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'Sub'" - " attr {key: 'T' value
{ type: DT_FLOAT } }" - " input: ['D', 'E']}"); + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } " + "}" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'M', 'N']}" + "node { name: 'D' op: '_MklToTf'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['C:0', 'C:1']}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'E']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), - "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);" - "F(Sub);M(MklInput);N(MklInput)|" - "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);" + "F(Sub);M(_MklInput);N(_MklInput)|" + "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3"); } } @@ -258,7 +262,7 @@ TEST_F(MklToTfConversionPass, Negative_NoMklLayer) { " input: ['E', 'Y']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|" - "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1"); + "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1"); } static void BM_RunMklToTfConversionPass(int iters, int op_nodes) { diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 91292500e1e..9849d9a1596 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -55,8 +55,13 @@ namespace { // state). static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, const gtl::ArraySlice& fed_outputs, - subgraph::NameIndex* name_index) { - for (const string& t : fed_outputs) { + bool use_function_convention, + subgraph::NameIndex* name_index, + DataTypeVector* out_feed_types) { + out_feed_types->clear(); + out_feed_types->reserve(fed_outputs.size()); + for (size_t i = 0; i < fed_outputs.size(); ++i) { + const string& t = fed_outputs[i]; TensorId id(ParseTensorName(t)); auto iter = name_index->find(id.first); @@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, } Node* recv_node; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second), - "_Recv") - .Attr("tensor_type", BaseType(n->output_type(id.second))) - .Attr("tensor_name", t) - .Attr("send_device", device_info.name()) - .Attr("recv_device", device_info.name()) - .Attr("send_device_incarnation", - static_cast(device_info.incarnation())) - .Attr("client_terminated", true) - .Finalize(g, &recv_node)); + + if (!use_function_convention) { + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second), + "_Recv") + .Attr("tensor_type", BaseType(n->output_type(id.second))) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &recv_node)); + } else { + // NOTE(mrry): We must include the index as part of the node + // name, because _Arg is a "stateful" kernel and therefore + // its name must uniquely identify a kernel instance across all + // graphs in the same session. 
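+      // For instance, feeding "a:0" as the first feed (i == 0) yields a
+      // node named "_arg_a_0_0", while feeding the same tensor at index 2
+      // in a different graph yields "_arg_a_0_2", keeping the kernel
+      // instances distinct.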
+ TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_", + id.second, "_", i), + "_Arg") + .Attr("T", BaseType(n->output_type(id.second))) + .Attr("index", static_cast(i)) + .Finalize(g, &recv_node)); + } recv_node->set_assigned_device_name(device_info.name()); // Copy the _output_shapes from the original node to the feed node, @@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, } g->RemoveEdge(e); } + out_feed_types->push_back(BaseType(n->output_type(id.second))); } return Status::OK(); } @@ -181,9 +201,14 @@ namespace subgraph { Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, const gtl::ArraySlice& fetch_outputs, - NameIndex* name_index, std::vector* fetch_nodes) { - fetch_nodes->clear(); - for (const string& t : fetch_outputs) { + bool use_function_convention, NameIndex* name_index, + std::vector* out_fetch_nodes, + DataTypeVector* out_fetch_types) { + out_fetch_nodes->clear(); + out_fetch_nodes->reserve(fetch_outputs.size()); + for (size_t i = 0; i < fetch_outputs.size(); ++i) { + const string& t = fetch_outputs[i]; + // Parse t into node_name and output_index. TensorId id(ParseTensorName(t)); @@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, // Create the fetch Node and connect it up Node* send_node; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second), - "_Send") - .Input(n, id.second) - .Attr("tensor_name", t) - .Attr("send_device", device_info.name()) - .Attr("recv_device", device_info.name()) - .Attr("send_device_incarnation", - static_cast(device_info.incarnation())) - .Attr("client_terminated", true) - .Finalize(g, &send_node)); + if (!use_function_convention) { + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second), + "_Send") + .Input(n, id.second) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &send_node)); + } else { + // NOTE(mrry): We must include the index as part of the node + // name, because _Retval is a "stateful" kernel and therefore + // its name must uniquely identify a kernel instance across all + // graphs in the same session. + TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_", + id.second, "_", i), + "_Retval") + .Input(n, id.second) + .Attr("T", BaseType(n->output_type(id.second))) + .Attr("index", static_cast(i)) + .Finalize(g, &send_node)); + } send_node->set_assigned_device_name(device_info.name()); - VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def()); // Update the index. 
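      // (Fetches follow the same convention: tensor "<name>:<port>" at
      // position i in the fetch list becomes a "_retval_<name>_<port>_<i>"
      // node, as exercised in subgraph_test.cc.)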
(*name_index)[send_node->name()] = send_node; g->AddControlEdge(send_node, g->sink_node()); - fetch_nodes->push_back(send_node); + out_fetch_nodes->push_back(send_node); + out_fetch_types->push_back(BaseType(n->output_type(id.second))); } return Status::OK(); @@ -241,7 +280,8 @@ Status RewriteGraphForExecution( Graph* g, const gtl::ArraySlice& fed_outputs, const gtl::ArraySlice& fetch_outputs, const gtl::ArraySlice& target_node_names, - const DeviceAttributes& device_info) { + const DeviceAttributes& device_info, bool use_function_convention, + RewriteGraphMetadata* out_metadata) { if (fetch_outputs.empty() && target_node_names.empty()) { return errors::InvalidArgument( "Must specify at least one target to fetch or execute."); @@ -274,18 +314,21 @@ Status RewriteGraphForExecution( // currently listed in "fetch_nodes". We pass "name_index" so the index is // kept up to date. if (!fed_outputs.empty()) { - TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index)); + TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, + use_function_convention, &name_index, + &out_metadata->feed_types)); } // Add the fetch nodes, also updating "name_index". std::vector fetch_nodes; if (!fetch_outputs.empty()) { - TF_RETURN_IF_ERROR( - FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes)); + TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs, + use_function_convention, &name_index, + &fetch_nodes, &out_metadata->fetch_types)); } // Prune the graph to only compute what is needed for the fetch nodes and the - // targets nodes. + // target nodes. if (!fetch_nodes.empty() || !target_node_names.empty()) { TF_RETURN_IF_ERROR( PruneForTargets(g, name_index, fetch_nodes, target_node_names)); diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h index d94d983d000..8ccc27914bc 100644 --- a/tensorflow/core/graph/subgraph.h +++ b/tensorflow/core/graph/subgraph.h @@ -26,6 +26,18 @@ limitations under the License. namespace tensorflow { namespace subgraph { +// Information about a graph rewritten by `RewriteGraphForExecution()`. +struct RewriteGraphMetadata { + // The element type of each tensor fed to this subgraph. The order + // of types corresponds to the order of tensor names in + // `fed_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector feed_types; + // The element type of each tensor fetched from this subgraph. The + // order of types corresponds to the order of tensor names in + // `fetch_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector fetch_types; +}; + // Rewrite the graph structure of "*g" to deal with feeding node // outputs, fetching node outputs, and only running a subset of the // graph. 
"fed_outputs" and "fetch_outputs" are both lists of @@ -56,7 +68,8 @@ Status RewriteGraphForExecution( Graph* g, const gtl::ArraySlice& fed_outputs, const gtl::ArraySlice& fetch_outputs, const gtl::ArraySlice& target_node_names, - const DeviceAttributes& device_info); + const DeviceAttributes& device_info, bool use_function_convention, + RewriteGraphMetadata* out_metadata); typedef std::unordered_map NameIndex; diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index ee4960121f5..3dc11b7a166 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test { } string Subgraph(const string& fed_str, const string& fetch_str, - const string& targets_str) { + const string& targets_str, + bool use_function_convention = false) { Graph* subgraph = new Graph(OpRegistry::Global()); CopyGraph(*g_, subgraph); std::vector fed = @@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test { std::vector targets = str_util::Split(targets_str, ',', str_util::SkipEmpty()); - Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets, - device_info_); + subgraph::RewriteGraphMetadata metadata; + Status s = subgraph::RewriteGraphForExecution( + subgraph, fed, fetch, targets, device_info_, use_function_convention, + &metadata); if (!s.ok()) { delete subgraph; return s.ToString(); } + EXPECT_EQ(fed.size(), metadata.feed_types.size()); + EXPECT_EQ(fetch.size(), metadata.fetch_types.size()); + // Replace the graph with the subgraph for the rest of the display program g_.reset(subgraph); return "OK"; @@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) { ExpectNodes("W1,W2,_recv_input_1,t1,t2"); } +TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", + Subgraph("input:1", "", "t2", true /* use_function_convention */)); + ExpectNodes("W1,W2,_arg_input_1_0,t1,t2"); +} + TEST_F(SubgraphTest, FedRefNode) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" @@ -189,7 +209,19 @@ TEST_F(SubgraphTest, FedRefNode) { EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); } -TEST_F(SubgraphTest, FedOutputs2) { +TEST_F(SubgraphTest, FedRefNode_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); + EXPECT_EQ("OK", + Subgraph("W1:0", "", "t1", true /* use_function_convention */)); + ExpectNodes("_arg_W1_0_0,W2,t1"); + Node* n = FindNode("_arg_W1_0_0"); + EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); +} + +TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" "node { name: 'W2' op: 'TestParams' }" @@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) { "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); // We feed input:1, but nothing connects to it, so the _recv(input:1) // node also disappears. 
- EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2")); - ExpectNodes("_recv_t1_0,_recv_W2_0,t2"); + EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2", + true /* use_function_convention */)); + ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2"); } TEST_F(SubgraphTest, FetchOutputs1) { @@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) { "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); } +TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2", + true /* use_function_convention */)); + ExpectNodes( + "W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_" + "retval_t2_0_3"); +} + TEST_F(SubgraphTest, FetchOutputs2) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" @@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) { ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); } +TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", + Subgraph("", "t3_a", "t2", true /* use_function_convention */)); + ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0"); +} + TEST_F(SubgraphTest, ChainOfFools) { ExpectOK( "node { name: 'a' op: 'TestParams' }" @@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { REGISTER_OP("In").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float"); -static void BM_Subgraph(int iters, int num_nodes) { +static void BM_SubgraphHelper(int iters, int num_nodes, + bool use_function_convention) { DeviceAttributes device_info; device_info.set_name("/job:a/replica:0/task:0/cpu:0"); device_info.set_device_type(DeviceType(DEVICE_CPU).type()); @@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) { while (--iters > 0) { Graph* subgraph = new Graph(OpRegistry::Global()); CopyGraph(g, subgraph); - TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch, - targets, device_info)); + subgraph::RewriteGraphMetadata metadata; + TF_CHECK_OK(subgraph::RewriteGraphForExecution( + subgraph, fed, fetch, targets, device_info, use_function_convention, + &metadata)); delete subgraph; } } + +static void BM_Subgraph(int iters, int num_nodes) { + BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */); +} +static void BM_SubgraphFunctionConvention(int iters, int num_nodes) { + BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */); +} BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); +BENCHMARK(BM_SubgraphFunctionConvention) + ->Arg(100) + ->Arg(1000) + ->Arg(10000) + ->Arg(100000); } // namespace } // namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e928e21264e..266d74976fe 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -18,6 
+18,13 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsDequeueOp(const NodeDef& node) { + static const std::set dequeue_ops = { + "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", + "QueueDequeue"}; + return dequeue_ops.count(node.op()) > 0; +} + bool IsPlaceholder(const NodeDef& node) { const auto op = node.op(); return op == "Placeholder" || op == "PlaceholderV2"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 2f83325c9da..2f58835628d 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -21,6 +21,7 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsDequeueOp(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); bool IsVariable(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 64d5815bf78..d7a7989dfad 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -40,6 +40,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", ], diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc index 96fb9d3792a..078fb10bc95 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/auto_parallel.h" + #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -94,22 +97,22 @@ Status AutoParallel::Initialize(const GrapplerItem& item) { VLOG(2) << "Variable: " << var->name(); } - std::set apply_gradients_ops = {"ApplyGradientDescent", - "ApplyProximalGradientDescent", - "ApplyAdadelta", - "ApplyAdagrad", - "ApplyProximalAdagrad", - "ApplyAdagradDA", - "ApplyFtrl", - "ApplyMomentum", - "ApplyAdam", - "ApplyRMSProp", - "ApplyCenteredRMSProp"}; + const std::set apply_gradients_ops = {"ApplyGradientDescent", + "ApplyProximalGradientDescent", + "ApplyAdadelta", + "ApplyAdagrad", + "ApplyProximalAdagrad", + "ApplyAdagradDA", + "ApplyFtrl", + "ApplyMomentum", + "ApplyAdam", + "ApplyRMSProp", + "ApplyCenteredRMSProp"}; const NodeDef* dequeue_node = nullptr; for (int i = 0; i < graph_.node_size(); i++) { all_nodes_.insert( std::make_pair(graph_.node(i).name(), graph_.mutable_node(i))); - if (graph_.node(i).op() == "QueueDequeueManyV2") { + if (IsDequeueOp(graph_.node(i))) { dequeue_node = graph_.mutable_node(i); } if (apply_gradients_ops.find(graph_.node(i).op()) != @@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) { for (const auto& fetch : item_->fetch) { AddNodeControl(fetch, {control->name()}, graph); } + *(graph->mutable_library()) = item_->graph.library(); LOG(INFO) << "Parallelized graph size: " << 
graph->node_size(); } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9af7e1e0baa..49b12df7aa9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3692,11 +3692,130 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "mfcc_dct", + srcs = ["mfcc_dct.cc"], + hdrs = ["mfcc_dct.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "mfcc_dct_test", + size = "small", + srcs = ["mfcc_dct_test.cc"], + deps = [ + ":mfcc_dct", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_test_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +cc_library( + name = "mfcc_mel_filterbank", + srcs = ["mfcc_mel_filterbank.cc"], + hdrs = ["mfcc_mel_filterbank.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "mfcc_mel_filterbank_test", + size = "small", + srcs = ["mfcc_mel_filterbank_test.cc"], + deps = [ + ":mfcc_mel_filterbank", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_test_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +cc_library( + name = "mfcc", + srcs = ["mfcc.cc"], + hdrs = ["mfcc.h"], + copts = tf_copts(), + deps = [ + ":mfcc_dct", + ":mfcc_mel_filterbank", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "mfcc_test", + size = "small", + srcs = ["mfcc_test.cc"], + deps = [ + ":mfcc", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_test_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "mfcc_op", + prefix = "mfcc_op", + deps = [ + ":mfcc", + "//tensorflow/core:audio_ops_op_lib", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "mfcc_op_test", + size = "small", + srcs = ["mfcc_op_test.cc"], + deps = [ + ":mfcc_op", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "audio", deps = [ ":decode_wav_op", ":encode_wav_op", + ":mfcc_op", ":spectrogram_op", ], ) diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc index ee13038325d..ce7fb9c332b 100644 --- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc +++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc @@ -86,7 +86,7 @@ class FixedLengthRecordReader : public ReaderBase { ++record_number_; if (hop_bytes_ > 0) { - input_buffer_->Seek(pos_before_read + hop_bytes_); + input_buffer_->Seek(pos_before_read + hop_bytes_).IgnoreError(); } return Status::OK(); @@ -118,7 +118,8 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel { public: explicit FixedLengthRecordReaderOp(OpKernelConstruction* context) : 
ReaderOpKernel(context) { - int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1, hop_bytes = -1; + int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1, + hop_bytes = -1; OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes)); OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes)); OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes)); @@ -132,15 +133,15 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel { OP_REQUIRES(context, footer_bytes >= 0, errors::InvalidArgument("footer_bytes must be >= 0 not ", footer_bytes)); - OP_REQUIRES(context, hop_bytes >= 0, - errors::InvalidArgument("hop_bytes must be >= 0 not ", - hop_bytes)); + OP_REQUIRES( + context, hop_bytes >= 0, + errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes)); Env* env = context->env(); - SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, - hop_bytes, env]() { - return new FixedLengthRecordReader(name(), header_bytes, record_bytes, - footer_bytes, hop_bytes, env); - }); + SetReaderFactory( + [this, header_bytes, record_bytes, footer_bytes, hop_bytes, env]() { + return new FixedLengthRecordReader(name(), header_bytes, record_bytes, + footer_bytes, hop_bytes, env); + }); } }; diff --git a/tensorflow/core/kernels/mfcc.cc b/tensorflow/core/kernels/mfcc.cc new file mode 100644 index 00000000000..2793005aa26 --- /dev/null +++ b/tensorflow/core/kernels/mfcc.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/core/kernels/mfcc.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +const double kDefaultUpperFrequencyLimit = 4000; +const double kDefaultLowerFrequencyLimit = 20; +const double kFilterbankFloor = 1e-12; +const int kDefaultFilterbankChannelCount = 40; +const int kDefaultDCTCoefficientCount = 13; + +Mfcc::Mfcc() : initialized_(false), + lower_frequency_limit_(kDefaultLowerFrequencyLimit), + upper_frequency_limit_(kDefaultUpperFrequencyLimit), + filterbank_channel_count_(kDefaultFilterbankChannelCount), + dct_coefficient_count_(kDefaultDCTCoefficientCount) { } + +bool Mfcc::Initialize(int input_length, + double input_sample_rate) { + bool initialized = mel_filterbank_.Initialize(input_length, + input_sample_rate, + filterbank_channel_count_, + lower_frequency_limit_, + upper_frequency_limit_); + initialized &= dct_.Initialize(filterbank_channel_count_, + dct_coefficient_count_); + initialized_ = initialized; + return initialized; +} + +void Mfcc::Compute(const std::vector<double>& spectrogram_frame, + std::vector<double>* output) const { + if (!initialized_) { + LOG(ERROR) << "Mfcc not initialized."; + return; + } + std::vector<double> working; + mel_filterbank_.Compute(spectrogram_frame, &working); + for (int i = 0; i < working.size(); ++i) { + double val = working[i]; + if (val < kFilterbankFloor) { + val = kFilterbankFloor; + } + working[i] = log(val); + } + dct_.Compute(working, output); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc.h b/tensorflow/core/kernels/mfcc.h new file mode 100644 index 00000000000..c39f1049909 --- /dev/null +++ b/tensorflow/core/kernels/mfcc.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for computing MFCCs from spectrogram slices. + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ + +#include + +#include "tensorflow/core/kernels/mfcc_dct.h" +#include "tensorflow/core/kernels/mfcc_mel_filterbank.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class Mfcc { + public: + Mfcc(); + bool Initialize(int input_length, + double input_sample_rate); + + // Input is a single magnitude spectrogram frame. The input spectrum + // is filtered into bands using a triangular mel filterbank and a + // discrete cosine transform (DCT) of the values is taken. Output is + // populated with the lowest dct_coefficient_count of these values. 
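+  //
+  // Illustrative usage sketch (not part of the original sources; the frame
+  // length and sample rate below are arbitrary example values):
+  //
+  //   Mfcc mfcc;
+  //   if (mfcc.Initialize(257 /* spectrogram bins */, 16000.0 /* Hz */)) {
+  //     std::vector<double> frame(257, 0.0);  // one magnitude-spectrogram row
+  //     std::vector<double> coefficients;
+  //     mfcc.Compute(frame, &coefficients);  // 13 values with default settings
+  //   }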
+ void Compute(const std::vector<double>& spectrogram_frame, + std::vector<double>* output) const; + + void set_upper_frequency_limit(double upper_frequency_limit) { + CHECK(!initialized_) << "Set frequency limits before calling Initialize."; + upper_frequency_limit_ = upper_frequency_limit; + } + + void set_lower_frequency_limit(double lower_frequency_limit) { + CHECK(!initialized_) << "Set frequency limits before calling Initialize."; + lower_frequency_limit_ = lower_frequency_limit; + } + + void set_filterbank_channel_count(int filterbank_channel_count) { + CHECK(!initialized_) << "Set channel count before calling Initialize."; + filterbank_channel_count_ = filterbank_channel_count; + } + + void set_dct_coefficient_count(int dct_coefficient_count) { + CHECK(!initialized_) << "Set coefficient count before calling Initialize."; + dct_coefficient_count_ = dct_coefficient_count; + } + + private: + MfccMelFilterbank mel_filterbank_; + MfccDct dct_; + bool initialized_; + double lower_frequency_limit_; + double upper_frequency_limit_; + int filterbank_channel_count_; + int dct_coefficient_count_; + TF_DISALLOW_COPY_AND_ASSIGN(Mfcc); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ diff --git a/tensorflow/core/kernels/mfcc_dct.cc b/tensorflow/core/kernels/mfcc_dct.cc new file mode 100644 index 00000000000..aa67a8d6499 --- /dev/null +++ b/tensorflow/core/kernels/mfcc_dct.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/mfcc_dct.h" + +#include +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +MfccDct::MfccDct() : initialized_(false) {} + +bool MfccDct::Initialize(int input_length, int coefficient_count) { + coefficient_count_ = coefficient_count; + input_length_ = input_length; + + if (coefficient_count_ < 1) { + LOG(ERROR) << "Coefficient count must be positive."; + return false; + } + + if (input_length < 1) { + LOG(ERROR) << "Input length must be positive."; + return false; + } + + if (coefficient_count_ > input_length_) { + LOG(ERROR) << "Coefficient count must be less than or equal to " + << "input length."; + return false; + } + + cosines_.resize(coefficient_count_); + double fnorm = sqrt(2.0 / input_length_); + // Some platforms don't have M_PI, so define a local constant here. 
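+  // (The loop below fills
+  //   cosines_[i][j] = sqrt(2 / N) * cos((pi / N) * i * (j + 0.5)),
+  // with N = input_length_, a DCT-II basis, so Compute() reduces to a
+  // matrix-vector product. Row 0 is scaled by sqrt(2 / N) rather than
+  // sqrt(1 / N), which is why the MATLAB comparison in mfcc_dct_test.cc
+  // rescales its first coefficient.)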
+ const double pi = std::atan(1) * 4; + double arg = pi / input_length_; + for (int i = 0; i < coefficient_count_; ++i) { + cosines_[i].resize(input_length_); + for (int j = 0; j < input_length_; ++j) { + cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5)); + } + } + initialized_ = true; + return true; +} + +void MfccDct::Compute(const std::vector<double> &input, + std::vector<double> *output) const { + if (!initialized_) { + LOG(ERROR) << "DCT not initialized."; + return; + } + + output->resize(coefficient_count_); + int length = input.size(); + if (length > input_length_) { + length = input_length_; + } + + for (int i = 0; i < coefficient_count_; ++i) { + double sum = 0.0; + for (int j = 0; j < length; ++j) { + sum += cosines_[i][j] * input[j]; + } + (*output)[i] = sum; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_dct.h b/tensorflow/core/kernels/mfcc_dct.h new file mode 100644 index 00000000000..4fa3c01628d --- /dev/null +++ b/tensorflow/core/kernels/mfcc_dct.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic minimal DCT class for MFCC speech processing. + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class MfccDct { + public: + MfccDct(); + bool Initialize(int input_length, int coefficient_count); + void Compute(const std::vector<double>& input, + std::vector<double>* output) const; + + private: + bool initialized_; + int coefficient_count_; + int input_length_; + std::vector<std::vector<double> > cosines_; + TF_DISALLOW_COPY_AND_ASSIGN(MfccDct); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ diff --git a/tensorflow/core/kernels/mfcc_dct_test.cc b/tensorflow/core/kernels/mfcc_dct_test.cc new file mode 100644 index 00000000000..7526278fe9e --- /dev/null +++ b/tensorflow/core/kernels/mfcc_dct_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/mfcc_dct.h" + +#include + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +TEST(MfccDctTest, AgreesWithMatlab) { + // This test verifies the DCT against MATLAB's dct function. + MfccDct dct; + std::vector<double> input = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + const int kCoefficientCount = 6; + ASSERT_TRUE(dct.Initialize(input.size(), kCoefficientCount)); + std::vector<double> output; + dct.Compute(input, &output); + // Note, the matlab dct function divides the first coefficient by + // sqrt(2), whereas we don't, so we multiply the first element of + // the matlab result by sqrt(2) to get the expected values below. + std::vector<double> expected = {12.1243556530, -4.1625617959, 0.0, + -0.4082482905, 0.0, -0.0800788912}; + ASSERT_EQ(output.size(), kCoefficientCount); + for (int i = 0; i < kCoefficientCount; ++i) { + EXPECT_NEAR(output[i], expected[i], 1e-10); + } +} + +TEST(MfccDctTest, InitializeFailsOnInvalidInput) { + MfccDct dct1; + EXPECT_FALSE(dct1.Initialize(-50, 1)); + MfccDct dct2; + EXPECT_FALSE(dct2.Initialize(10, -4)); + MfccDct dct3; + EXPECT_FALSE(dct3.Initialize(-1, -1)); + MfccDct dct4; + EXPECT_FALSE(dct4.Initialize(20, 21)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc new file mode 100644 index 00000000000..d68c60280d9 --- /dev/null +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This code resamples the FFT bins, and smooths them with triangle-shaped +// weights to create a mel-frequency filter bank. For filter i centered at f_i, +// there is a triangular weighting of the FFT bins that extends from +// filter f_i-1 (with a value of zero at the left edge of the triangle) to f_i +// (where the filter value is 1) to f_i+1 (where the filter value returns to +// zero). + +// Note: this code fails if you ask for too many channels. The algorithm used +// here assumes that each FFT bin contributes to at most two channels: the +// right side of a triangle for channel i, and the left side of the triangle +// for channel i+1. If you ask for so many channels that some of the +// resulting mel triangle filters are smaller than a single FFT bin, these +// channels may end up with no contributing FFT bins. The resulting mel +// spectrum output will have some channels that are always zero. 
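+
+// The mel scale used by FreqToMel() at the bottom of this file is the HTK
+// convention, mel(f) = 1127 * ln(1 + f / 700), which maps 1000 Hz to
+// roughly 1000 mel.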
+ +#include "tensorflow/core/kernels/mfcc_mel_filterbank.h" + +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {} + +bool MfccMelFilterbank::Initialize(int input_length, + double input_sample_rate, + int output_channel_count, + double lower_frequency_limit, + double upper_frequency_limit) { + num_channels_ = output_channel_count; + sample_rate_ = input_sample_rate; + input_length_ = input_length; + + if (num_channels_ < 1) { + LOG(ERROR) << "Number of filterbank channels must be positive."; + return false; + } + + if (sample_rate_ <= 0) { + LOG(ERROR) << "Sample rate must be positive."; + return false; + } + + if (input_length < 2) { + LOG(ERROR) << "Input length must be greater than 1."; + return false; + } + + if (lower_frequency_limit <= 0) { + LOG(ERROR) << "Lower frequency limit must be positive."; + return false; + } + + if (upper_frequency_limit <= lower_frequency_limit) { + LOG(ERROR) << "Upper frequency limit must be greater than " + << "lower frequency limit."; + return false; + } + + // An extra center frequency is computed at the top to get the upper + // limit on the high side of the final triangular filter. + center_frequencies_.resize(num_channels_ + 1); + const double mel_low = FreqToMel(lower_frequency_limit); + const double mel_hi = FreqToMel(upper_frequency_limit); + const double mel_span = mel_hi - mel_low; + const double mel_spacing = mel_span / static_cast<double>(num_channels_ + 1); + for (int i = 0; i < num_channels_ + 1; ++i) { + center_frequencies_[i] = mel_low + (mel_spacing * (i + 1)); + } + + // Always exclude DC; emulate HTK. + const double hz_per_sbin = 0.5 * sample_rate_ / + static_cast<double>(input_length_ - 1); + start_index_ = static_cast<int>(1.5 + (lower_frequency_limit / + hz_per_sbin)); + end_index_ = static_cast<int>(upper_frequency_limit / hz_per_sbin); + + // Maps the input spectrum bin indices to filter bank channels/indices. For + // each FFT bin, band_mapper tells us which channel this bin contributes to + // on the right side of the triangle. Thus this bin also contributes to the + // left side of the next channel's triangle response. + band_mapper_.resize(input_length_); + int channel = 0; + for (int i = 0; i < input_length_; ++i) { + double melf = FreqToMel(i * hz_per_sbin); + if ((i < start_index_) || (i > end_index_)) { + band_mapper_[i] = -2; // Indicate an unused Fourier coefficient. + } else { + while ((center_frequencies_[channel] < melf) && + (channel < num_channels_)) { + ++channel; + } + band_mapper_[i] = channel - 1; // Can be == -1 + } + } + + // Create the weighting functions to taper the band edges. The contribution + // of any one FFT bin is based on its distance along the continuum between two + // mel-channel center frequencies. This bin contributes weights_[i] to the + // current channel and 1-weights_[i] to the next channel. 
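+  // For example, an FFT bin whose mel frequency lies 30% of the way from
+  // center_frequencies_[channel] to center_frequencies_[channel + 1]
+  // receives a weight of 0.7 in that channel and 0.3 in the next one.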
+ weights_.resize(input_length_); + for (int i = 0; i < input_length_; ++i) { + channel = band_mapper_[i]; + if ((i < start_index_) || (i > end_index_)) { + weights_[i] = 0.0; + } else { + if (channel >= 0) { + weights_[i] = (center_frequencies_[channel + 1] - + FreqToMel(i * hz_per_sbin)) / + (center_frequencies_[channel + 1] - center_frequencies_[channel]); + } else { + weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) / + (center_frequencies_[0] - mel_low); + } + } + } + // Check the sum of FFT bin weights for every mel band to identify + // situations where the mel bands are so narrow that they don't get + // significant weight on enough (or any) FFT bins -- i.e., too many + // mel bands have been requested for the given FFT size. + std::vector<int> bad_channels; + for (int c = 0; c < num_channels_; ++c) { + float band_weights_sum = 0.0; + for (int i = 0; i < input_length_; ++i) { + if (band_mapper_[i] == c - 1) { + band_weights_sum += (1.0 - weights_[i]); + } else if (band_mapper_[i] == c) { + band_weights_sum += weights_[i]; + } + } + // The lowest mel channels have the fewest FFT bins and the lowest + // weights sum. But given that the target gain at the center frequency + // is 1.0, if the total sum of weights is 0.5, we're in bad shape. + if (band_weights_sum < 0.5) { + bad_channels.push_back(c); + } + } + if (!bad_channels.empty()) { + LOG(ERROR) << "Missing " << bad_channels.size() << " bands " << "starting at " << bad_channels[0] << " in mel-frequency design. " << "Perhaps too many channels or " << "not enough frequency resolution in spectrum. (" << "input_length: " << input_length << " input_sample_rate: " << input_sample_rate << " output_channel_count: " << output_channel_count << " lower_frequency_limit: " << lower_frequency_limit << " upper_frequency_limit: " << upper_frequency_limit << ")"; + } + initialized_ = true; + return true; +} + +// Compute the mel spectrum from the squared-magnitude FFT input by taking the +// square root, then summing FFT magnitudes under triangular integration windows +// whose widths increase with frequency. +void MfccMelFilterbank::Compute(const std::vector<double> &input, + std::vector<double> *output) const { + if (!initialized_) { + LOG(ERROR) << "Mel Filterbank not initialized."; + return; + } + + if (input.size() <= end_index_) { + LOG(ERROR) << "Input too short to compute filterbank"; + return; + } + + // Ensure output is right length and reset all values. + output->assign(num_channels_, 0.0); + + for (int i = start_index_; i <= end_index_; i++) { // For each FFT bin + double spec_val = sqrt(input[i]); + double weighted = spec_val * weights_[i]; + int channel = band_mapper_[i]; + if (channel >= 0) + (*output)[channel] += weighted; // Right side of triangle, downward slope + channel++; + if (channel < num_channels_) + (*output)[channel] += spec_val - weighted; // Left side of triangle + } +} + +double MfccMelFilterbank::FreqToMel(double freq) const { + return 1127.0 * log(1.0 + (freq / 700.0)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.h b/tensorflow/core/kernels/mfcc_mel_filterbank.h new file mode 100644 index 00000000000..33ea1bdb5bc --- /dev/null +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for applying a mel-scale filterbank to an input. + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ + +#include +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class MfccMelFilterbank { + public: + MfccMelFilterbank(); + bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1. + double input_sample_rate, + int output_channel_count, + double lower_frequency_limit, + double upper_frequency_limit); + + // Takes a magnitude spectrogram slice as input, computes a + // triangular mel filterbank and places the result in output. + void Compute(const std::vector<double>& input, + std::vector<double>* output) const; + + private: + double FreqToMel(double freq) const; + bool initialized_; + int num_channels_; + double sample_rate_; + int input_length_; + std::vector<double> center_frequencies_; // In mel, for each mel channel. + + // Each FFT bin b contributes to two triangular mel channels, with + // proportion weights_[b] going into mel channel band_mapper_[b], and + // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1. + // Thus, weights_ contains the weighting applied to each FFT bin for the + // upper-half of the triangular band. + std::vector<double> weights_; // Right-side weight for this fft bin. + + // FFT bin i contributes to the upper side of mel channel band_mapper_[i] + std::vector<int> band_mapper_; + int start_index_; // Lowest FFT bin used to calculate mel spectrum. + int end_index_; // Highest FFT bin used to calculate mel spectrum. + + TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc new file mode 100644 index 00000000000..c3a7e779403 --- /dev/null +++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/mfcc_mel_filterbank.h" + +#include + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) { + // This test verifies the Mel filterbank against "golden values". 
+ // Golden values are from an independent Python Mel implementation. + MfccMelFilterbank filterbank; + + std::vector<double> input; + const int kSampleCount = 513; + for (int i = 0; i < kSampleCount; ++i) { + input.push_back(i + 1); + } + const int kChannelCount = 20; + filterbank.Initialize(input.size(), + 22050 /* sample rate */, + kChannelCount /* channels */, + 20.0 /* lower frequency limit */, + 4000.0 /* upper frequency limit */); + + std::vector<double> output; + filterbank.Compute(input, &output); + + std::vector<double> expected = { + 7.38894574, 10.30330648, 13.72703292, 17.24158686, 21.35253118, + 25.77781089, 31.30624108, 37.05877236, 43.9436536, 51.80306637, + 60.79867148, 71.14363376, 82.90910141, 96.50069158, 112.08428368, + 129.96721968, 150.4277597, 173.74997634, 200.86037462, 231.59802942}; + + ASSERT_EQ(output.size(), kChannelCount); + + for (int i = 0; i < kChannelCount; ++i) { + EXPECT_NEAR(output[i], expected[i], 1e-04); + } +} + +TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) { + // Test for bug where the output vector was not cleared before + // accumulating next frame's weighted spectral values. + MfccMelFilterbank filterbank; + + const int kSampleCount = 513; + std::vector<double> input; + std::vector<double> output; + + filterbank.Initialize(kSampleCount, + 22050 /* sample rate */, + 20 /* channels */, + 20.0 /* lower frequency limit */, + 4000.0 /* upper frequency limit */); + + + // First call with nonzero input value, and an empty output vector, + // will resize the output and fill it with the correct, nonzero outputs. + input.assign(kSampleCount, 1.0); + filterbank.Compute(input, &output); + for (const double value : output) { + EXPECT_LE(0.0, value); + } + + // Second call with zero input should generate zero output. However, + // the output vector now is already the correct size, but full of nonzero + // values. Make sure these don't affect the output. + input.assign(kSampleCount, 0.0); + filterbank.Compute(input, &output); + for (const double value : output) { + EXPECT_EQ(0.0, value); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_op.cc b/tensorflow/core/kernels/mfcc_op.cc new file mode 100644 index 00000000000..02643857c1f --- /dev/null +++ b/tensorflow/core/kernels/mfcc_op.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/audio_ops.cc + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/mfcc.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Create a speech fingerprint from spectrogram data. 
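+// The op takes a [audio_channels, spectrogram_frames, spectrogram_bins]
+// float32 spectrogram plus a scalar sample rate, and emits a float32 tensor
+// of shape [audio_channels, spectrogram_frames, dct_coefficient_count].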
+class MfccOp : public OpKernel { + public: + explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("upper_frequency_limit", + &upper_frequency_limit_)); + OP_REQUIRES_OK(context, context->GetAttr("lower_frequency_limit", + &lower_frequency_limit_)); + OP_REQUIRES_OK(context, context->GetAttr("filterbank_channel_count", + &filterbank_channel_count_)); + OP_REQUIRES_OK(context, context->GetAttr("dct_coefficient_count", + &dct_coefficient_count_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& spectrogram = context->input(0); + OP_REQUIRES(context, spectrogram.dims() == 3, + errors::InvalidArgument("spectrogram must be 3-dimensional ", + spectrogram.shape().DebugString())); + const Tensor& sample_rate_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()), + errors::InvalidArgument( + "Input sample_rate should be a scalar tensor, got ", + sample_rate_tensor.shape().DebugString(), " instead.")); + const int32 sample_rate = sample_rate_tensor.scalar<int32>()(); + + const int spectrogram_channels = spectrogram.dim_size(2); + const int spectrogram_samples = spectrogram.dim_size(1); + const int audio_channels = spectrogram.dim_size(0); + + Mfcc mfcc; + mfcc.set_upper_frequency_limit(upper_frequency_limit_); + mfcc.set_lower_frequency_limit(lower_frequency_limit_); + mfcc.set_filterbank_channel_count(filterbank_channel_count_); + mfcc.set_dct_coefficient_count(dct_coefficient_count_); + OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate), + errors::InvalidArgument( + "Mfcc initialization failed for channel count ", + spectrogram_channels, " and sample rate ", sample_rate)); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, + TensorShape({audio_channels, spectrogram_samples, + dct_coefficient_count_}), + &output_tensor)); + + const float* spectrogram_flat = spectrogram.flat<float>().data(); + float* output_flat = output_tensor->flat<float>().data(); + + for (int audio_channel = 0; audio_channel < audio_channels; + ++audio_channel) { + for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples; + ++spectrogram_sample) { + const float* sample_data = + spectrogram_flat + + (audio_channel * spectrogram_samples * spectrogram_channels) + + (spectrogram_sample * spectrogram_channels); + std::vector<double> mfcc_input(sample_data, + sample_data + spectrogram_channels); + std::vector<double> mfcc_output; + mfcc.Compute(mfcc_input, &mfcc_output); + DCHECK_EQ(dct_coefficient_count_, mfcc_output.size()); + float* output_data = + output_flat + + (audio_channel * spectrogram_samples * dct_coefficient_count_) + + (spectrogram_sample * dct_coefficient_count_); + for (int i = 0; i < dct_coefficient_count_; ++i) { + output_data[i] = mfcc_output[i]; + } + } + } + } + + private: + float upper_frequency_limit_; + float lower_frequency_limit_; + int32 filterbank_channel_count_; + int32 dct_coefficient_count_; +}; +REGISTER_KERNEL_BUILDER(Name("Mfcc").Device(DEVICE_CPU), MfccOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_op_test.cc b/tensorflow/core/kernels/mfcc_op_test.cc new file mode 100644 index 00000000000..d16171d5265 --- /dev/null +++ b/tensorflow/core/kernels/mfcc_op_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/audio_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using namespace ops; // NOLINT(build/namespaces) + +TEST(MfccOpTest, SimpleTest) { + Scope root = Scope::NewRootScope(); + + Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513})); + test::FillIota<float>(&spectrogram_tensor, 1.0f); + + Output spectrogram_const_op = Const(root.WithOpName("spectrogram_const_op"), + Input::Initializer(spectrogram_tensor)); + + Output sample_rate_const_op = + Const(root.WithOpName("sample_rate_const_op"), 22050); + + Mfcc mfcc_op = Mfcc(root.WithOpName("mfcc_op"), spectrogram_const_op, + sample_rate_const_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector<Tensor> outputs; + + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {mfcc_op.output}, &outputs)); + + const Tensor& mfcc_tensor = outputs[0]; + + EXPECT_EQ(3, mfcc_tensor.dims()); + EXPECT_EQ(13, mfcc_tensor.dim_size(2)); + EXPECT_EQ(1, mfcc_tensor.dim_size(1)); + EXPECT_EQ(1, mfcc_tensor.dim_size(0)); + + test::ExpectTensorNear<float>( + mfcc_tensor, + test::AsTensor<float>( + {29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878, + -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029, + -0.0769791, -0.10806114, -0.06047613}, + TensorShape({1, 1, 13})), + 1e-3); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_test.cc b/tensorflow/core/kernels/mfcc_test.cc new file mode 100644 index 00000000000..9ab726e5b9c --- /dev/null +++ b/tensorflow/core/kernels/mfcc_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/mfcc.h" + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +TEST(MfccTest, AgreesWithPythonGoldenValues) { + Mfcc mfcc; + std::vector<double> input; + const int kSampleCount = 513; + for (int i = 0; i < kSampleCount; ++i) { + input.push_back(i + 1); + } + + ASSERT_TRUE(mfcc.Initialize(input.size(), 22050 /*sample rate*/)); + + std::vector<double> output; + mfcc.Compute(input, &output); + + std::vector<double> expected = {29.13970072, -6.41568601, -0.61903012, + -0.96778652, -0.26819878, -0.40907028, + -0.15614748, -0.23203119, -0.10481487, + -0.1543029, -0.0769791, -0.10806114, + -0.06047613}; + + ASSERT_EQ(expected.size(), output.size()); + for (int i = 0; i < output.size(); ++i) { + EXPECT_NEAR(output[i], expected[i], 1e-04); + } +} + +TEST(MfccTest, AvoidsNansWithZeroInput) { + Mfcc mfcc; + std::vector<double> input; + const int kSampleCount = 513; + for (int i = 0; i < kSampleCount; ++i) { + input.push_back(0.0); + } + + ASSERT_TRUE(mfcc.Initialize(input.size(), 22050 /*sample rate*/)); + + std::vector<double> output; + mfcc.Compute(input, &output); + + int expected_size = 13; + ASSERT_EQ(expected_size, output.size()); + for (const double value : output) { + EXPECT_FALSE(isnan(value)); + } +} + +TEST(MfccTest, SimpleInputSaneResult) { + Mfcc mfcc; + mfcc.set_lower_frequency_limit(125.0); + mfcc.set_upper_frequency_limit(3800.0); + mfcc.set_filterbank_channel_count(40); + mfcc.set_dct_coefficient_count(40); + const int kSpectrogramSize = 129; + std::vector<double> input(kSpectrogramSize, 0.0); + + // Simulate a low-frequency sinusoid from the spectrogram. + const int kHotBin = 10; + input[kHotBin] = 1.0; + ASSERT_TRUE(mfcc.Initialize(input.size(), 8000)); + + std::vector<double> output; + mfcc.Compute(input, &output); + + // For a single low-frequency input, output beyond c_0 should look like + // a slow cosine, with a slight delay. Largest value will be c_1. + EXPECT_EQ(output.begin() + 1, std::max_element(output.begin(), output.end())); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index 0d1aa57f652..d90baee069c 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -413,13 +413,13 @@ class MklAvgPoolingGradOp : public OpKernel { TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER(Name("MklAvgPool") +REGISTER_KERNEL_BUILDER(Name("_MklAvgPool") .Device(DEVICE_CPU) .TypeConstraint<float>("T") .Label(mkl_op_registry::kMklOpLabel), MklAvgPoolingOp<CPUDevice, float>); -REGISTER_KERNEL_BUILDER(Name("MklAvgPoolGrad") +REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad") .Device(DEVICE_CPU) .TypeConstraint<float>("T") .Label(mkl_op_registry::kMklOpLabel), diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 2f5f623d922..27930c44a65 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -29,8 +29,8 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -#include "third_party/mkl/include/mkl_dnn_types.h" #include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -75,8 +75,9 @@ class EigenConcatBaseOp : public OpKernel { const TensorShape& input_shape = values[0].shape(); int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(c, (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(c, + (0 <= axis && axis < input_dims) || + (allow_legacy_scalars() && concat_dim == 0), errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -101,8 +102,8 @@ class EigenConcatBaseOp : public OpKernel { c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); for (int j = 0; j < input_dims; ++j) { if (j == axis) { continue; @@ -111,8 +112,8 @@ class EigenConcatBaseOp : public OpKernel { c, in.dim_size(j) == input_shape.dim_size(j), errors::InvalidArgument( "ConcatOp : Dimensions of inputs should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); } if (in.NumElements() > 0) { int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; @@ -152,10 +153,10 @@ class MklConcatOp : public OpKernel { public: typedef std::vector::ConstMatrix>> - ConstMatrixVector; + ConstMatrixVector; - explicit MklConcatOp(OpKernelConstruction* c) : OpKernel(c), - eigen_concat_op_(c) {} + explicit MklConcatOp(OpKernelConstruction* c) + : OpKernel(c), eigen_concat_op_(c) {} void Compute(OpKernelContext* context) override { MklConcatOpContext mkl_context; @@ -170,18 +171,18 @@ class MklConcatOp : public OpKernel { // If this is Concat, then concat_dim is 0th input. // If this is ConcatV2, then axis is Nth input. - const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM ? - MklGetInput(context, 0) : - MklGetInput(context, N); + const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM + ? MklGetInput(context, 0) + : MklGetInput(context, N); // Sanity checks OP_REQUIRES( - context, IsLegacyScalar(concat_dim_tensor.shape()), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor.shape().DebugString())); + context, IsLegacyScalar(concat_dim_tensor.shape()), + errors::InvalidArgument( + "Concat dim tensor should be a scalar integer, but got shape ", + concat_dim_tensor.shape().DebugString())); int32 concat_dim = - internal::SubtleMustCopy(concat_dim_tensor.scalar()()); + internal::SubtleMustCopy(concat_dim_tensor.scalar()()); MklShape& inpshape0 = input_shapes[0]; @@ -215,10 +216,10 @@ class MklConcatOp : public OpKernel { if (invoke_eigen) { string msg = std::string("Invoking Eigen version of Concat. Reason:") + - (!is_concat_dim_channel ? - std::string("Concat dimension is not channel") : - std::string("Not all tensors are in Mkl layout")); - VLOG(1) << "MklConcatOp: " << msg; + (!is_concat_dim_channel + ? 
std::string("Concat dimension is not channel") + : std::string("Not all tensors are in Mkl layout")); + VLOG(1) << "_MklConcatOp: " << msg; CallEigenVersion(context, input_tensors, input_shapes); return; } @@ -235,12 +236,11 @@ class MklConcatOp : public OpKernel { int i = 0; for (auto& s : input_shapes) { size_t exp_dims = inpshape0.GetDimension(); - OP_REQUIRES( - context, s.GetDimension() == exp_dims, - errors::InvalidArgument( - "MklConcatOp : Ranks of all input tensors should match:" - " input dimensions = ", s.GetDimension(), - " vs. expected rank = ", exp_dims)); + OP_REQUIRES(context, s.GetDimension() == exp_dims, + errors::InvalidArgument( + "_MklConcatOp : Ranks of all input tensors should match:" + " input dimensions = ", + s.GetDimension(), " vs. expected rank = ", exp_dims)); for (int d = 0; d < exp_dims; ++d) { if (d == concat_dim) { @@ -248,10 +248,12 @@ class MklConcatOp : public OpKernel { } size_t exp_size = inpshape0.GetSizes()[d]; - OP_REQUIRES(context, exp_size == s.GetSizes()[d], - errors::InvalidArgument("MklConcatOp : Dimensions of inputs" - "should match: shape[0][", d, "]= ", exp_size, " vs. shape[", - i, "][", d, "] = ", s.GetSizes()[d])); + OP_REQUIRES( + context, exp_size == s.GetSizes()[d], + errors::InvalidArgument("_MklConcatOp : Dimensions of inputs" + "should match: shape[0][", + d, "]= ", exp_size, " vs. shape[", i, "][", + d, "] = ", s.GetSizes()[d])); } ++i; } @@ -259,8 +261,8 @@ class MklConcatOp : public OpKernel { // Use input MKL layout instead of creating new layouts. int64 output_concat_dim_size = 0; for (auto& s : input_shapes) { - output_concat_dim_size += s.GetDimension() > 0 ? - s.GetSizes()[concat_dim] : 1; + output_concat_dim_size += + s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1; } mkl_context.MklCreateInputLayouts(context, input_shapes); @@ -273,9 +275,10 @@ class MklConcatOp : public OpKernel { if (inpshape0.IsTensorInNHWCFormat()) { data_format = FORMAT_NHWC; } else { - OP_REQUIRES(context, inpshape0.IsTensorInNCHWFormat(), - errors::InvalidArgument( - "MklConcat only supports all inputs in NCHW or NHWC format ")); + OP_REQUIRES( + context, inpshape0.IsTensorInNCHWFormat(), + errors::InvalidArgument( + "_MklConcat only supports all inputs in NCHW or NHWC format ")); data_format = FORMAT_NCHW; } @@ -298,16 +301,18 @@ class MklConcatOp : public OpKernel { TensorShape mkl_output_tf_shape; mkl_output_tf_shape.AddDim(1); - mkl_output_tf_shape.AddDim(dnnLayoutGetMemorySize_F32( - static_cast(mkl_output_mkl_shape.GetMklLayout()))/sizeof(T)); + mkl_output_tf_shape.AddDim( + dnnLayoutGetMemorySize_F32( + static_cast(mkl_output_mkl_shape.GetMklLayout())) / + sizeof(T)); Tensor* output = nullptr; AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape, mkl_output_mkl_shape); // Set destination resource. 
- mkl_context.concat_res[dnnResourceDst] = const_cast( - static_cast(output->flat().data())); + mkl_context.concat_res[dnnResourceDst] = + const_cast(static_cast(output->flat().data())); mkl_context.mkl_tmp_tensors.resize(N); mkl_context.MklPrepareConcatInputs(context, input_tensors); @@ -325,7 +330,7 @@ class MklConcatOp : public OpKernel { size_t out_sizes[4]; size_t out_strides[4]; dnnPrimitive_t prim_concat; - void *concat_res[dnnResourceNumber]; + void* concat_res[dnnResourceNumber]; std::vector lt_inputs; std::vector mkl_tmp_tensors; @@ -335,7 +340,7 @@ class MklConcatOp : public OpKernel { MklShapeList& input_shapes) { for (auto& is : input_shapes) { CHECK_EQ(is.IsMklTensor(), true); - lt_inputs.push_back((dnnLayout_t) is.GetCurLayout()); + lt_inputs.push_back((dnnLayout_t)is.GetCurLayout()); } } @@ -348,34 +353,31 @@ class MklConcatOp : public OpKernel { dnnLayout_t mkl_lt_internal_input; void* mkl_buf_convert_input = nullptr; - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input, - prim_concat, - (dnnResourceType_t) - (dnnResourceMultipleSrc + i)), + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( + &mkl_lt_internal_input, prim_concat, + (dnnResourceType_t)(dnnResourceMultipleSrc + i)), E_SUCCESS); if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) { CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, - lt_inputs[i], - mkl_lt_internal_input), + lt_inputs[i], mkl_lt_internal_input), E_SUCCESS); - AllocTmpBuffer(context, &mkl_tmp_tensors[i], - mkl_lt_internal_input, + AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input, &mkl_buf_convert_input); - CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, - const_cast(static_cast( + CHECK_EQ(dnnConversionExecute_F32( + mkl_prim_convert_input, + const_cast(static_cast( input_tensors[i].flat().data())), - mkl_buf_convert_input), + mkl_buf_convert_input), E_SUCCESS); concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input; CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS); } else { - concat_res[dnnResourceMultipleSrc + i] = const_cast( - static_cast( - input_tensors[i].flat().data())); + concat_res[dnnResourceMultipleSrc + i] = const_cast( + static_cast(input_tensors[i].flat().data())); } CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS); @@ -401,9 +403,8 @@ class MklConcatOp : public OpKernel { for (int i = 0; i < input_shapes.size(); i++) { if (input_shapes[i].IsMklTensor()) { // If input tensor is Mkl, then do the conversion. - Tensor tmp_tensor = ConvertMklToTF(context, - values[i], - input_shapes[i]); + Tensor tmp_tensor = + ConvertMklToTF(context, values[i], input_shapes[i]); converted_values.push_back(tmp_tensor); } else { // If input tensor is TF already, then we do not need any conversion. 
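As a reading aid for the dispatch logic above: MklConcatOp only runs the MKL path when the concatenation happens along the channel dimension and every input already carries an MKL layout; otherwise it routes through CallEigenVersion. A sketch of that rule in Python pseudocode (the function name is illustrative, not TensorFlow API):

```python
def should_fall_back_to_eigen(is_concat_dim_channel, all_inputs_in_mkl_layout):
    # Mirrors the invoke_eigen condition in MklConcatOp::Compute: either
    # reason alone is enough to fall back to the Eigen implementation.
    return (not is_concat_dim_channel) or (not all_inputs_in_mkl_layout)
```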
@@ -421,27 +422,27 @@ class MklConcatOp : public OpKernel { mkl_tensor_mkl_shape.SetTfDimOrder(4); // Dimensions Tensor* mkl_tensor = nullptr; TensorShape mkl_tensor_tf_shape; - mkl_tensor_tf_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA( - mkl_tensor_mkl_shape.GetDimension())); + mkl_tensor_tf_shape.AddDim( + SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension())); int tf_output_index = 0; - context->allocate_output(GetTensorMetaDataIndex(tf_output_index, - context->num_outputs()), - mkl_tensor_tf_shape, &mkl_tensor); + context->allocate_output( + GetTensorMetaDataIndex(tf_output_index, context->num_outputs()), + mkl_tensor_tf_shape, &mkl_tensor); mkl_tensor_mkl_shape.SerializeMklShape( - mkl_tensor->flat().data(), - mkl_tensor->flat().size() * sizeof(uint8)); + mkl_tensor->flat().data(), + mkl_tensor->flat().size() * sizeof(uint8)); } }; /* Use optimized concat for float type only */ #define REGISTER_MKL_CPU(type) \ - REGISTER_KERNEL_BUILDER(Name("MklConcat") \ + REGISTER_KERNEL_BUILDER(Name("_MklConcat") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .HostMemory("concat_dim") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConcatOp) \ - REGISTER_KERNEL_BUILDER(Name("MklConcatV2") \ + REGISTER_KERNEL_BUILDER(Name("_MklConcatV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("Tidx") \ diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc index 82973340c38..8a1006a8e95 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc @@ -251,11 +251,11 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MklConv2DCustomBackpropBiasOp); }; -#define REGISTER_CPU_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBiasBackpropBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_CPU_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DCustomBackpropBiasOp); TF_CALL_float(REGISTER_CPU_KERNELS); diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 03a0dc8d999..6381b527a1b 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -408,11 +408,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { TensorFormat data_format_; }; -#define REGISTER_MKL_FILTER_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropFilter") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_FILTER_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DCustomBackpropFilterOp); TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 431abcf69d2..638ce4c0243 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -341,11 +341,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { TensorFormat data_format; }; -#define REGISTER_MKL_CPU_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropInput") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - 
.Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DCustomBackpropInputOp); TF_CALL_float(REGISTER_MKL_CPU_KERNELS); diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b0040351ed0..b818819b020 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -437,16 +437,16 @@ class MklConv2DOp : public OpKernel { TensorFormat data_format_; }; -#define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DOp); \ - REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConv2DOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DOp); TF_CALL_float(REGISTER_MKL_CPU); diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index b698af6e9eb..512e799d152 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -14,16 +14,16 @@ limitations under the License. ==============================================================================*/ #ifdef INTEL_MKL +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" // TODO(inteltf) Address comments from PR 8968. 
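The `_Mkl*` kernels registered in these files are internal: the MKL graph-rewrite pass substitutes them for the standard ops, so user code keeps calling the public API. A minimal sketch of the public op this fused-batch-norm kernel accelerates, using the standard TensorFlow 1.x API (independent of whether an MKL build is in use):

```python
import tensorflow as tf

# float32 NHWC input with per-channel scale and offset.
x = tf.random_normal([8, 32, 32, 16])
scale = tf.ones([16])
offset = tf.zeros([16])

# Returns the normalized output plus the batch statistics; on an MKL build
# the rewrite pass may swap in _MklFusedBatchNorm behind the scenes.
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, is_training=True)
```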
@@ -325,11 +325,11 @@ class MklFusedBatchNormOp : public OpKernel { } MklFusedBatchNormOpContext; }; -#define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklFusedBatchNorm") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklFusedBatchNormOp); TF_CALL_float(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU @@ -676,11 +676,11 @@ class MklFusedBatchNormGradOp : public OpKernel { } MklFusedBatchNormGradOpContext; }; -#define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklFusedBatchNormGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklFusedBatchNormGradOp); TF_CALL_float(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index 4aa00838c8a..edca8e2553d 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -22,6 +22,9 @@ limitations under the License. #define EIGEN_USE_THREADS #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -30,9 +33,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/mkl/include/mkl_dnn.h" -#include "third_party/mkl/include/mkl_dnn_types.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/util/work_sharder.h" @@ -66,10 +66,11 @@ class MklLRNOp : public OpKernel { explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); @@ -92,9 +93,10 @@ class MklLRNOp : public OpKernel { : input.dims(); OP_REQUIRES(context, mkl_context.in_dims == 4, errors::InvalidArgument("input must be 4-dimensional")); - OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("argument to LRN too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument("argument to LRN too large")); if (!input_in_mkl_format) { mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, @@ -334,10 +336,11 @@ class MklLRNGradOp : public OpKernel { explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, 
context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_)); @@ -701,12 +704,12 @@ class MklLRNGradOp : public OpKernel { }; #define REGISTER_MKL_LRN_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklLRN") \ + REGISTER_KERNEL_BUILDER(Name("_MklLRN") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklLRNOp); \ - REGISTER_KERNEL_BUILDER(Name("MklLRNGrad") \ + REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index af43c35af5a..ba2d347d941 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -482,13 +482,13 @@ class MklMaxPoolingGradOp : public OpKernel { bool workspace_enabled_; }; -REGISTER_KERNEL_BUILDER(Name("MklMaxPool") +REGISTER_KERNEL_BUILDER(Name("_MklMaxPool") .Device(DEVICE_CPU) .TypeConstraint("T") .Label(mkl_op_registry::kMklOpLabel), MklMaxPoolingOp); -REGISTER_KERNEL_BUILDER(Name("MklMaxPoolGrad") +REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad") .Device(DEVICE_CPU) .TypeConstraint("T") .Label(mkl_op_registry::kMklOpLabel), diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 02f0a5d6deb..25c8359cc53 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -379,16 +379,16 @@ void MklReluGradOp::Compute(OpKernelContext* context) { /* Register DNN kernels for supported operations and supported types - right now * it is only Relu and f32*/ -#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ - REGISTER_KERNEL_BUILDER(Name("MklRelu") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklReluOp); \ - REGISTER_KERNEL_BUILDER(Name("MklReluGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ + REGISTER_KERNEL_BUILDER(Name("_MklRelu") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklReluOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklReluGradOp); TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES); diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index ea3a605564d..753a8b52b42 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -24,9 +24,9 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; @@ -36,7 +36,6 @@ class MklReshapeOp : public OpKernel { explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - const Tensor& input = MklGetInput(context, 0); const Tensor& sizes = MklGetInput(context, 1); @@ -94,12 +93,12 @@ class MklReshapeOp : public OpKernel { GetMklShape(context, 0, &mkl_shape_input); bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); if (input_in_mkl_format) { - TensorShape & shape_to = shape; + TensorShape& shape_to = shape; TensorShape shape_from; for (size_t i = 0; i < mkl_shape_input.GetDimension(); i++) { // Outermost to innermost dimension - shape_from.AddDim(mkl_shape_input.GetSizes()[ - mkl_shape_input.tf_dim_idx(i)]); + shape_from.AddDim( + mkl_shape_input.GetSizes()[mkl_shape_input.tf_dim_idx(i)]); } if (shape_from == shape_to) { @@ -114,17 +113,17 @@ class MklReshapeOp : public OpKernel { mkl_shape_output); // Get output layout pointer. - dnnLayout_t output_layout = static_cast( - mkl_shape_input.GetTfLayout()); + dnnLayout_t output_layout = + static_cast(mkl_shape_input.GetTfLayout()); // Execute DNNConversion. // Note: we assume an MKL tensor always have float as its data type. - void *input_buffer = static_cast(const_cast( - input.flat().data())); - void *output_buffer = static_cast(const_cast( - output_tensor->flat().data())); + void* input_buffer = + static_cast(const_cast(input.flat().data())); + void* output_buffer = static_cast( + const_cast(output_tensor->flat().data())); mkl_shape_input.GetConvertedFlatData(output_layout, input_buffer, - output_buffer); + output_buffer); VLOG(1) << "MKLToTFConversion complete successfully."; return; @@ -133,16 +132,15 @@ class MklReshapeOp : public OpKernel { CopyTFTensorInToOut(context, 0, 0, shape); } } - }; -#define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklReshape") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint("T") \ - .TypeConstraint("Tshape") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklReshape") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("Tshape") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklReshapeOp); TF_CALL_float(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc index 09529e4a705..c31ef5c2554 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.cc +++ b/tensorflow/core/kernels/mkl_tfconv_op.cc @@ -105,11 +105,11 @@ class MklToTfOp : public OpKernel { // Register kernel /////////////////////////////////////////////////////////// -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklToTf") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("MklToTf") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklToTfOp); TF_CALL_float(REGISTER_CPU); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 0ad2ab7e4f7..30026f222a6 
100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -139,16 +139,14 @@ class SparseTensorDenseMatMulOp : public OpKernel { TensorShape({0}), &scratch)); } -#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ - if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ - Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ - Device, T, Tindices, ADJ_A, ADJ_B>::Compute(ctx->eigen_device(), \ - out->matrix(), \ - a_indices->matrix(), \ - a_values->vec(), \ - b->matrix(), \ - scratch.vec()); \ - OP_REQUIRES_OK(ctx, functor_status); \ +#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ + if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ + Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ + Device, T, Tindices, ADJ_A, \ + ADJ_B>::Compute(ctx->eigen_device(), out->matrix(), \ + a_indices->matrix(), a_values->vec(), \ + b->matrix(), scratch.vec()); \ + OP_REQUIRES_OK(ctx, functor_status); \ } MAYBE_ADJOINT(false, false); @@ -164,17 +162,17 @@ class SparseTensorDenseMatMulOp : public OpKernel { bool adjoint_b_; }; -#define REGISTER_CPU(TypeT, TypeIndex) \ - REGISTER_KERNEL_BUILDER( \ - Name("SparseTensorDenseMatMul") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices") \ - .HostMemory("a_shape"), \ - SparseTensorDenseMatMulOp); \ +#define REGISTER_CPU(TypeT, TypeIndex) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseTensorDenseMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices") \ + .HostMemory("a_shape"), \ + SparseTensorDenseMatMulOp); -#define REGISTER_KERNELS_CPU(T) \ - REGISTER_CPU(T, int64); \ +#define REGISTER_KERNELS_CPU(T) \ + REGISTER_CPU(T, int64); \ REGISTER_CPU(T, int32) REGISTER_KERNELS_CPU(float); @@ -186,16 +184,17 @@ REGISTER_KERNELS_CPU(complex128); #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B) \ - template <> \ - Status SparseTensorDenseMatMulFunctor::Compute( \ - const GPUDevice& d, typename TTypes::Matrix out, \ - typename TTypes::ConstMatrix a_indices, \ - typename TTypes::ConstVec a_values, \ - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch); \ - extern template struct SparseTensorDenseMatMulFunctor; +#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B) \ + template <> \ + Status SparseTensorDenseMatMulFunctor< \ + GPUDevice, T, Tindices, ADJ_A, \ + ADJ_B>::Compute(const GPUDevice& d, typename TTypes::Matrix out, \ + typename TTypes::ConstMatrix a_indices, \ + typename TTypes::ConstVec a_values, \ + typename TTypes::ConstMatrix b, \ + typename TTypes::Vec scratch); \ + extern template struct SparseTensorDenseMatMulFunctor< \ + GPUDevice, T, Tindices, ADJ_A, ADJ_B>; #define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B) \ DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \ @@ -223,8 +222,8 @@ DECLARE_ADJOINT_GPU_SPEC(float); .HostMemory("a_shape"), \ SparseTensorDenseMatMulOp); -#define REGISTER_KERNELS_GPU(T) \ - REGISTER_GPU(T, int64); \ +#define REGISTER_KERNELS_GPU(T) \ + REGISTER_GPU(T, int64); \ REGISTER_GPU(T, int32) REGISTER_KERNELS_GPU(float); @@ -254,10 +253,10 @@ struct SparseTensorDenseMatMulFunctor { static const std::size_t kNumVectorize = 32; static Status Compute(const CPUDevice& d, typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, - typename TTypes::Vec scratch) { + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, + typename TTypes::ConstMatrix b, + typename 
TTypes::Vec scratch) { const std::size_t nnz = a_values.size(); const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1)); const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0)); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h index ebf52e020ab..e707743f782 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -25,16 +25,14 @@ namespace tensorflow { namespace functor { -template struct SparseTensorDenseMatMulFunctor { - static EIGEN_ALWAYS_INLINE Status Compute( - const Device& d, - typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, - typename TTypes::Vec scratch); + static EIGEN_ALWAYS_INLINE Status + Compute(const Device& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, + typename TTypes::ConstMatrix b, typename TTypes::Vec scratch); }; template diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc index 209a9be6367..7266e0cf812 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc @@ -90,11 +90,11 @@ namespace functor { template struct SparseTensorDenseMatMulFunctor { - static EIGEN_ALWAYS_INLINE Status Compute( - const GPUDevice& d, typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch) { + static EIGEN_ALWAYS_INLINE Status + Compute(const GPUDevice& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, + typename TTypes::ConstMatrix b, typename TTypes::Vec scratch) { generator::SparseTensorDenseMatMulGPUGenerator sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices), To32Bit(a_values), To32Bit(b)); @@ -146,14 +146,14 @@ struct SparseTensorDenseMatMulFunctor { } // namespace functor -#define DEFINE(T, Tindices) \ - template struct functor::SparseTensorDenseMatMulFunctor< \ - GPUDevice, T, Tindices, false, false>; \ - template struct functor::SparseTensorDenseMatMulFunctor< \ - GPUDevice, T, Tindices, false, true>; \ - template struct functor::SparseTensorDenseMatMulFunctor< \ - GPUDevice, T, Tindices, true, false>; \ - template struct functor::SparseTensorDenseMatMulFunctor< \ +#define DEFINE(T, Tindices) \ + template struct functor::SparseTensorDenseMatMulFunctor< \ + GPUDevice, T, Tindices, false, false>; \ + template struct functor::SparseTensorDenseMatMulFunctor< \ + GPUDevice, T, Tindices, false, true>; \ + template struct functor::SparseTensorDenseMatMulFunctor< \ + GPUDevice, T, Tindices, true, false>; \ + template struct functor::SparseTensorDenseMatMulFunctor< \ GPUDevice, T, Tindices, true, true>; DEFINE(float, int32); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 00a1566ae41..e540ecfa8d9 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -394,8 +394,10 @@ output: A `Tensor` with the concatenation of values stacked along the in `concat_dim` where it has the sum of the sizes. 
)doc"); +// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops +// are not to be made user-accessible. #ifdef INTEL_MKL -REGISTER_OP("MklConcatV2") +REGISTER_OP("_MklConcatV2") .Input("values: N * T") .Input("axis: Tidx") .Input("mkl_values: N * uint8") @@ -1659,7 +1661,7 @@ shape: Defines the shape of the output tensor. )Doc"); #ifdef INTEL_MKL -REGISTER_OP("MklReshape") +REGISTER_OP("_MklReshape") .Input("tensor: T") .Input("shape: Tshape") .Input("mkl_tensor: uint8") @@ -1671,7 +1673,7 @@ REGISTER_OP("MklReshape") .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); }) .Doc(R"Doc( MKL implementation of ReshapeOp. )Doc"); -#endif // INTEL_MKL +#endif // INTEL_MKL // -------------------------------------------------------------------------- REGISTER_OP("InvertPermutation") @@ -5001,7 +5003,7 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`: )doc"); #ifdef INTEL_MKL -REGISTER_OP("MklConcat") +REGISTER_OP("_MklConcat") .Input("concat_dim: int32") .Input("values: N * T") .Input("mkl_concat_dim: uint8") diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc index 2f55e45e377..02b13a455ce 100644 --- a/tensorflow/core/ops/audio_ops.cc +++ b/tensorflow/core/ops/audio_ops.cc @@ -100,6 +100,26 @@ Status SpectrogramShapeFn(InferenceContext* c) { return Status::OK(); } +Status MfccShapeFn(InferenceContext* c) { + ShapeHandle spectrogram; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + + int32 dct_coefficient_count; + TF_RETURN_IF_ERROR( + c->GetAttr("dct_coefficient_count", &dct_coefficient_count)); + + DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0); + DimensionHandle spectrogram_length = c->Dim(spectrogram, 1); + + DimensionHandle output_channels = c->MakeDim(dct_coefficient_count); + + c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length, + output_channels})); + return Status::OK(); +} + } // namespace REGISTER_OP("DecodeWav") @@ -200,4 +220,34 @@ magnitude_squared: Whether to return the squared magnitude or just the spectrogram: 3D representation of the audio frequencies as an image. )doc"); +REGISTER_OP("Mfcc") + .Input("spectrogram: float") + .Input("sample_rate: int32") + .Attr("upper_frequency_limit: float = 4000") + .Attr("lower_frequency_limit: float = 20") + .Attr("filterbank_channel_count: int = 40") + .Attr("dct_coefficient_count: int = 13") + .Output("output: float") + .SetShapeFn(MfccShapeFn) + .Doc(R"doc( +Transforms a spectrogram into a form that's useful for speech recognition. + +Mel Frequency Cepstral Coefficients are a way of representing audio data that's +been effective as an input feature for machine learning. They are created by +taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +higher frequencies that are less significant to the human ear. They have a long +history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +is a good resource to learn more. + +spectrogram: Typically produced by the Spectrogram op, with magnitude_squared + set to true. +sample_rate: How many samples per second the source audio used. +upper_frequency_limit: The highest frequency to use when calculating the + ceptstrum. +lower_frequency_limit: The lowest frequency to use when calculating the + ceptstrum. +filterbank_channel_count: Resolution of the Mel bank used internally. 
+dct_coefficient_count: How many output channels to produce per time slice. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 50e49713140..1781f778b49 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -11316,6 +11316,49 @@ op { } } } +op { + name: "Mfcc" + input_arg { + name: "spectrogram" + type: DT_FLOAT + } + input_arg { + name: "sample_rate" + type: DT_INT32 + } + output_arg { + name: "output" + type: DT_FLOAT + } + attr { + name: "upper_frequency_limit" + type: "float" + default_value { + f: 4000 + } + } + attr { + name: "lower_frequency_limit" + type: "float" + default_value { + f: 20 + } + } + attr { + name: "filterbank_channel_count" + type: "int" + default_value { + i: 40 + } + } + attr { + name: "dct_coefficient_count" + type: "int" + default_value { + i: 13 + } + } +} op { name: "Min" input_arg { diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 4dfc51490b8..0bce6fc0ea8 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -449,6 +449,11 @@ REGISTER_OP("FixedLengthRecordReader") A Reader that outputs fixed-length records from a file. reader_handle: The handle to reference the Reader. +header_bytes: Number of bytes in the header, defaults to 0. +record_bytes: Number of bytes in the record. +footer_bytes: Number of bytes in the footer, defaults to 0. +hop_bytes: Number of bytes to hop before each read. Default of 0 means using + record_bytes. container: If non-empty, this reader is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this reader is named in the given bucket @@ -469,6 +474,11 @@ REGISTER_OP("FixedLengthRecordReaderV2") A Reader that outputs fixed-length records from a file. reader_handle: The handle to reference the Reader. +header_bytes: Number of bytes in the header, defaults to 0. +record_bytes: Number of bytes in the record. +footer_bytes: Number of bytes in the footer, defaults to 0. +hop_bytes: Number of bytes to hop before each read. Default of 0 means using + record_bytes. container: If non-empty, this reader is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this reader is named in the given bucket diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 49379773254..932113bf2c4 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -2612,7 +2612,7 @@ scale_after_normalization: A bool indicating whether the resulted tensor )doc"); #ifdef INTEL_MKL -REGISTER_OP("MklConv2D") +REGISTER_OP("_MklConv2D") .Input("input: T") .Input("filter: T") .Input("mkl_input: uint8") @@ -2632,7 +2632,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklConv2DWithBias") +REGISTER_OP("_MklConv2DWithBias") .Input("input: T") .Input("filter: T") .Input("bias: T") @@ -2654,7 +2654,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklConv2DBackpropFilter") +REGISTER_OP("_MklConv2DBackpropFilter") .Input("input: T") .Input("filter_sizes: int32") .Input("out_backprop: T") @@ -2679,7 +2679,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. 
)doc"); -REGISTER_OP("MklConv2DWithBiasBackpropBias") +REGISTER_OP("_MklConv2DWithBiasBackpropBias") .Input("out_backprop: T") .Input("mkl_out_backprop: uint8") .Output("output: T") @@ -2695,7 +2695,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklConv2DBackpropInput") +REGISTER_OP("_MklConv2DBackpropInput") .Input("input_sizes: int32") .Input("filter: T") .Input("out_backprop: T") @@ -2720,7 +2720,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklRelu") +REGISTER_OP("_MklRelu") .Input("features: T") .Input("mkl_features: uint8") .Output("activations: T") @@ -2734,7 +2734,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklReluGrad") +REGISTER_OP("_MklReluGrad") .Input("gradients: T") .Input("features: T") .Input("mkl_gradients: uint8") @@ -2751,7 +2751,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklMaxPool") +REGISTER_OP("_MklMaxPool") .Attr("T: {float, half} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -2773,7 +2773,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklMaxPoolGrad") +REGISTER_OP("_MklMaxPoolGrad") .Attr("T: {float, half} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -2801,7 +2801,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklAvgPool") +REGISTER_OP("_MklAvgPool") .Input("value: T") .Input("mkl_input: uint8") .Output("output: T") @@ -2820,7 +2820,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklAvgPoolGrad") +REGISTER_OP("_MklAvgPoolGrad") .Input("orig_input_shape: int32") .Input("grad: T") .Input("mkl_orig_input: uint8") @@ -2843,7 +2843,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklLRN") +REGISTER_OP("_MklLRN") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") @@ -2867,7 +2867,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklLRNGrad") +REGISTER_OP("_MklLRNGrad") .Input("input_grads: T") .Input("input_image: T") .Input("output_image: T") @@ -2900,7 +2900,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklFusedBatchNorm") +REGISTER_OP("_MklFusedBatchNorm") .Input("x: T") .Input("scale: T") .Input("offset: T") @@ -2966,7 +2966,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); -REGISTER_OP("MklFusedBatchNormGrad") +REGISTER_OP("_MklFusedBatchNormGrad") .Input("y_backprop: T") .Input("x: T") .Input("scale: T") @@ -3048,7 +3048,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. 
)doc"); -REGISTER_OP("MklToTf") +REGISTER_OP("_MklToTf") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 17cc2517cbb..cbbabe0b876 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -11539,6 +11539,57 @@ op { summary: "V2 format specific: merges the metadata files of sharded checkpoints. The" description: "result is one logical checkpoint, with one physical metadata file and renamed\ndata files.\n\nIntended for \"grouping\" multiple checkpoints in a sharded checkpoint setup.\n\nIf delete_old_dirs is true, attempts to delete recursively the dirname of each\npath in the input checkpoint_prefixes. This is useful when those paths are non\nuser-facing temporary locations." } +op { + name: "Mfcc" + input_arg { + name: "spectrogram" + description: "Typically produced by the Spectrogram op, with magnitude_squared\nset to true." + type: DT_FLOAT + } + input_arg { + name: "sample_rate" + description: "How many samples per second the source audio used." + type: DT_INT32 + } + output_arg { + name: "output" + type: DT_FLOAT + } + attr { + name: "upper_frequency_limit" + type: "float" + default_value { + f: 4000 + } + description: "The highest frequency to use when calculating the\nceptstrum." + } + attr { + name: "lower_frequency_limit" + type: "float" + default_value { + f: 20 + } + description: "The lowest frequency to use when calculating the\nceptstrum." + } + attr { + name: "filterbank_channel_count" + type: "int" + default_value { + i: 40 + } + description: "Resolution of the Mel bank used internally." + } + attr { + name: "dct_coefficient_count" + type: "int" + default_value { + i: 13 + } + description: "How many output channels to produce per time slice." + } + summary: "Transforms a spectrogram into a form that\'s useful for speech recognition." + description: "Mel Frequency Cepstral Coefficients are a way of representing audio data that\'s\nbeen effective as an input feature for machine learning. They are created by\ntaking the spectrum of a spectrogram (a \'cepstrum\'), and discarding some of the\nhigher frequencies that are less significant to the human ear. They have a long\nhistory in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum\nis a good resource to learn more." +} op { name: "Min" input_arg { diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 2493ea9c761..897b174eff2 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -218,7 +218,7 @@ class MklShape { (IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_ #define SIZES_OFFSET(dims) \ (DIMS_OFFSET + \ - sizeof(size_t)) // Location of sizes. Note dim is not used here, left here + sizeof(size_t)) // Location of sizes. Note dim is not used here, left here // to make macros consistent. #define STRIDES_OFFSET(dims) \ (SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides @@ -228,7 +228,7 @@ class MklShape { (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_ #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \ (TF_LAYOUT_OFFSET(dims) + \ - SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_ + SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_ // TODO(agramesh1) make sure to create a const to share with rewrite pass // for min size of MKL metadata tensor. 
@@ -315,25 +315,22 @@ inline bool AreAllMklTensors(const MklShapeList& shapes) { } template -inline Tensor ConvertMklToTF(OpKernelContext *context, - const Tensor& mkl_tensor, +inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, const MklShape& mkl_shape) { Tensor output_tensor; TensorShape output_shape; for (size_t j = 0; j < mkl_shape.GetDimension(); j++) { - // Outermost to innermost dimension - output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]); + // Outermost to innermost dimension + output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]); } // Allocate output tensor. - context->allocate_temp(DataTypeToEnum::v(), - output_shape, &output_tensor); + context->allocate_temp(DataTypeToEnum::v(), output_shape, &output_tensor); - dnnLayout_t output_layout = static_cast( - mkl_shape.GetTfLayout()); - void *input_buffer = const_cast(mkl_tensor.flat().data()); - void *output_buffer = const_cast(output_tensor.flat().data()); + dnnLayout_t output_layout = static_cast(mkl_shape.GetTfLayout()); + void* input_buffer = const_cast(mkl_tensor.flat().data()); + void* output_buffer = const_cast(output_tensor.flat().data()); if (mkl_tensor.NumElements() != 0) { mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer); @@ -394,14 +391,15 @@ int inline GetTensorMetaDataIndex(int n, int total_tensors) { return DataIndexToMetaDataIndex(tidx, total_tensors); } - // Get the MKL shape from the second string tensor inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) { mklshape->DeSerializeMklShape( - ctext->input( - GetTensorMetaDataIndex(n, ctext->num_inputs())).flat().data(), - ctext->input( - GetTensorMetaDataIndex(n, ctext->num_inputs())).flat().size() * + ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs())) + .flat() + .data(), + ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs())) + .flat() + .size() * sizeof(uint8)); } @@ -410,8 +408,8 @@ inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) { return ctext->input(GetTensorDataIndex(n, ctext->num_inputs())); } -inline void GetMklInputList(OpKernelContext* ctext, - StringPiece name, OpInputList* input_tensors) { +inline void GetMklInputList(OpKernelContext* ctext, StringPiece name, + OpInputList* input_tensors) { CHECK_NOTNULL(input_tensors); ctext->input_list(name, input_tensors); } @@ -423,8 +421,8 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name, for (int i = 0; i < input_mkl_tensors.size(); i++) { (*mkl_shapes)[i].DeSerializeMklShape( - input_mkl_tensors[i].flat().data(), - input_mkl_tensors[i].flat().size() * sizeof(uint8)); + input_mkl_tensors[i].flat().data(), + input_mkl_tensors[i].flat().size() * sizeof(uint8)); } } @@ -435,9 +433,9 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, Tensor* second_tensor = nullptr; TensorShape second_shape; second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension())); - OP_REQUIRES_OK(ctext, ctext->allocate_output(GetTensorMetaDataIndex(n, - ctext->num_outputs()), - second_shape, &second_tensor)); + OP_REQUIRES_OK(ctext, ctext->allocate_output( + GetTensorMetaDataIndex(n, ctext->num_outputs()), + second_shape, &second_tensor)); mkl_shape.SerializeMklShape( second_tensor->flat().data(), second_tensor->flat().size() * sizeof(uint8)); @@ -453,13 +451,11 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, TensorShape second_shape; second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension())); OP_REQUIRES_OK( - ctext, 
- ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()), - tf_shape, output)); - OP_REQUIRES_OK( - ctext, - ctext->allocate_output(GetTensorMetaDataIndex(n, ctext->num_outputs()), - second_shape, &second_tensor)); + ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()), + tf_shape, output)); + OP_REQUIRES_OK(ctext, ctext->allocate_output( + GetTensorMetaDataIndex(n, ctext->num_outputs()), + second_shape, &second_tensor)); mkl_shape.SerializeMklShape( second_tensor->flat().data(), second_tensor->flat().size() * sizeof(uint8)); @@ -499,7 +495,8 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides, inline void MklSizesToTFSizes(OpKernelContext* context, TensorFormat data_format_, - const MklShape& mkl_shape, TensorShape* tf_shape) { + const MklShape& mkl_shape, + TensorShape* tf_shape) { size_t tf_dim = mkl_shape.GetDimension(); const size_t* tf_sizes = mkl_shape.GetSizes(); @@ -545,8 +542,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) { return mkl_shape.dim_size(index); } -inline void CopyMklTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -566,9 +563,8 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, context->set_output(idx_meta_out, meta_output); } -inline void CopyTFTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out, - const TensorShape& shape) { +inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out, const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); diff --git a/tensorflow/docs_src/api_guides/python/index.md b/tensorflow/docs_src/api_guides/python/index.md index 0e624df55b7..177f19bc80d 100644 --- a/tensorflow/docs_src/api_guides/python/index.md +++ b/tensorflow/docs_src/api_guides/python/index.md @@ -43,6 +43,7 @@ * [Random variable transformations (contrib)](contrib.distributions.bijector.md) * [RNN and Cells (contrib)](contrib.rnn.md) * [Seq2seq Library (contrib)](contrib.seq2seq.md) +* [Staging (contrib)](contrib.staging.md) * [Statistical Distributions (contrib)](contrib.distributions.md) * [Training (contrib)](contrib.training.md) * [Utilities (contrib)](contrib.util.md) diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index 90497fae092..dc0d8703158 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -197,11 +197,15 @@ if __name__ == '__main__': help='Initial learning rate') parser.add_argument('--dropout', type=float, default=0.9, help='Keep probability for training dropout.') - parser.add_argument('--data_dir', type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - parser.add_argument('--log_dir', type=str, - default='/tmp/tensorflow/mnist/logs/mnist_with_summaries', - help='Summaries log directory') + parser.add_argument( + '--data_dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + parser.add_argument( + '--log_dir', + type=str, + default='/tmp/tensorflow/mnist/logs/mnist_with_summaries', + 
        help='Summaries log directory')
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e832690e183..b21e8fd4481 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2456,6 +2456,83 @@ func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) {
 	return op.Output(0)
 }
 
+// MfccAttr is an optional argument to Mfcc.
+type MfccAttr func(optionalAttr)
+
+// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
+//
+// value: The highest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 4000
+func MfccUpperFrequencyLimit(value float32) MfccAttr {
+	return func(m optionalAttr) {
+		m["upper_frequency_limit"] = value
+	}
+}
+
+// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
+//
+// value: The lowest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 20
+func MfccLowerFrequencyLimit(value float32) MfccAttr {
+	return func(m optionalAttr) {
+		m["lower_frequency_limit"] = value
+	}
+}
+
+// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
+//
+// value: Resolution of the Mel bank used internally.
+// If not specified, defaults to 40
+func MfccFilterbankChannelCount(value int64) MfccAttr {
+	return func(m optionalAttr) {
+		m["filterbank_channel_count"] = value
+	}
+}
+
+// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
+//
+// value: How many output channels to produce per time slice.
+// If not specified, defaults to 13
+func MfccDctCoefficientCount(value int64) MfccAttr {
+	return func(m optionalAttr) {
+		m["dct_coefficient_count"] = value
+	}
+}
+
+// Transforms a spectrogram into a form that's useful for speech recognition.
+//
+// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+// been effective as an input feature for machine learning. They are created by
+// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+// higher frequencies that are less significant to the human ear. They have a long
+// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+// is a good resource to learn more.
+//
+// Arguments:
+//	spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+// set to true.
+//	sample_rate: How many samples per second the source audio used.
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Mfcc",
+		Input: []tf.Input{
+			spectrogram, sample_rate,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // UniqueAttr is an optional argument to Unique.
type UniqueAttr func(optionalAttr) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c9449d83ce7..c367d20f816 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1670,6 +1670,7 @@ py_library( deps = [ ":array_ops", ":framework_for_generated_wrappers", + ":layers_base", ":util", ], ) @@ -2380,6 +2381,39 @@ py_test( ], ) +py_test( + name = "tf_contextlib_test", + size = "small", + srcs = ["util/tf_contextlib_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + ], +) + +py_test( + name = "tf_decorator_test", + size = "small", + srcs = ["util/tf_decorator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + ], +) + +py_test( + name = "tf_inspect_test", + size = "small", + srcs = ["util/tf_inspect_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + ], +) + py_library( name = "util_example_parser_configuration", srcs = ["util/example_parser_configuration.py"], @@ -3210,15 +3244,10 @@ py_tests( ) py_library( - name = "layers", + name = "layers_base", srcs = [ "layers/__init__.py", "layers/base.py", - "layers/convolutional.py", - "layers/core.py", - "layers/layers.py", - "layers/normalization.py", - "layers/pooling.py", "layers/utils.py", ], srcs_version = "PY2AND3", @@ -3228,6 +3257,31 @@ py_library( ":framework", ":framework_for_generated_wrappers", ":init_ops", + ":util", + ":variable_scope", + ":variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "layers", + srcs = [ + "layers/convolutional.py", + "layers/core.py", + "layers/layers.py", + "layers/normalization.py", + "layers/pooling.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":control_flow_ops", + ":framework", + ":framework_for_generated_wrappers", + ":init_ops", + ":layers_base", ":math_ops", ":nn", ":standard_ops", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index d2945adf75c..864a96ef348 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -26,11 +26,9 @@ import tensorflow as tf import ctypes import importlib -import inspect import sys import traceback - # TODO(drpng): write up instructions for editing this file in a doc and point to # the doc instead. # If you want to edit this file to expose modules in public tensorflow API, you @@ -170,7 +168,7 @@ _allowed_symbols.extend([ 'parse_single_sequence_example', 'serialize_many_sparse', 'serialize_sparse', - 'sparse_matmul', ## use tf.matmul instead. + 'sparse_matmul', ## use tf.matmul instead. ]) # This is needed temporarily because we import it explicitly. diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 3813aa996b3..f7e17f1c53d 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -601,6 +601,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python:util", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py index d4a84f62cde..01e36f754c2 100644 --- a/tensorflow/python/debug/__init__.py +++ b/tensorflow/python/debug/__init__.py @@ -25,6 +25,7 @@ See the @{$python/tfdbg} guide. 
@@has_inf_or_nan @@DumpingDebugHook @@DumpingDebugWrapperSession +@@GrpcDebugWrapperSession @@LocalCLIDebugHook @@LocalCLIDebugWrapperSession @@WatchOptions @@ -46,6 +47,7 @@ from tensorflow.python.debug.lib.debug_utils import watch_graph_with_blacklists from tensorflow.python.debug.wrappers.dumping_wrapper import DumpingDebugWrapperSession from tensorflow.python.debug.wrappers.framework import WatchOptions +from tensorflow.python.debug.wrappers.grpc_wrapper import GrpcDebugWrapperSession from tensorflow.python.debug.wrappers.hooks import DumpingDebugHook from tensorflow.python.debug.wrappers.hooks import LocalCLIDebugHook from tensorflow.python.debug.wrappers.local_cli_wrapper import LocalCLIDebugWrapperSession diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index e62a0f611f7..8b191f332e8 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import os import shutil import tempfile @@ -41,10 +40,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect def line_number_above(): - return inspect.stack()[1][2] - 1 + return tf_inspect.stack()[1][2] - 1 def parse_op_and_node(line): @@ -503,7 +503,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): cls._main_device = "/job:localhost/replica:0/task:0/cpu:0" cls._curr_file_path = os.path.abspath( - inspect.getfile(inspect.currentframe())) + tf_inspect.getfile(tf_inspect.currentframe())) cls._sess = session.Session() with cls._sess as sess: diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index a76dd4f6d60..bb457a01b23 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -820,7 +820,7 @@ class DebugDumpDir(object): self._node_op_types[node.name] = node.op for inp in node.input: - if is_copy_node(inp) and node.op == "_Send": + if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"): self._copy_send_nodes.append(node.name) if inp.startswith("^"): diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index 6010723d46c..a4fb0d99109 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import os import shutil import tempfile @@ -37,10 +36,11 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.util import tf_inspect def line_number_above(): - return inspect.stack()[1][2] - 1 + return tf_inspect.stack()[1][2] - 1 class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): @@ -52,21 +52,21 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): ops.reset_default_graph() def testGuessedBaseDirIsProbablyCorrect(self): - self.assertEqual( - "tensorflow", os.path.basename(source_utils._TENSORFLOW_BASEDIR)) + self.assertEqual("tensorflow", + 
os.path.basename(source_utils._TENSORFLOW_BASEDIR)) def testUnitTestFileReturnsFalse(self): - self.assertFalse(source_utils._guess_is_tensorflow_py_library( - self.curr_file_path)) + self.assertFalse( + source_utils._guess_is_tensorflow_py_library(self.curr_file_path)) def testSourceUtilModuleReturnsTrue(self): - self.assertTrue(source_utils._guess_is_tensorflow_py_library( - source_utils.__file__)) + self.assertTrue( + source_utils._guess_is_tensorflow_py_library(source_utils.__file__)) def testFileInPythonKernelsPathReturnsTrue(self): x = constant_op.constant(42.0, name="x") - self.assertTrue(source_utils._guess_is_tensorflow_py_library( - x.op.traceback[-1][0])) + self.assertTrue( + source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0])) def testNonPythonFileRaisesException(self): with self.assertRaisesRegexp(ValueError, r"is not a Python source file"): @@ -85,7 +85,7 @@ class SourceHelperTest(test_util.TensorFlowTestCase): self.dump_root = self.get_temp_dir() self.curr_file_path = os.path.abspath( - inspect.getfile(inspect.currentframe())) + tf_inspect.getfile(tf_inspect.currentframe())) # Run a simple TF graph to generate some debug dumps that can be used in # source annotation. @@ -135,27 +135,21 @@ class SourceHelperTest(test_util.TensorFlowTestCase): self.assertIn(self.u_init.op.name, source_annotation[self.u_init_line_number]) - self.assertIn(self.u.op.name, - source_annotation[self.u_line_number]) + self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) self.assertIn(self.v_init.op.name, source_annotation[self.v_init_line_number]) - self.assertIn(self.v.op.name, - source_annotation[self.v_line_number]) - self.assertIn(self.w.op.name, - source_annotation[self.w_line_number]) + self.assertIn(self.v.op.name, source_annotation[self.v_line_number]) + self.assertIn(self.w.op.name, source_annotation[self.w_line_number]) # In the non-stack-top (default) mode, the helper line should be annotated # with all the ops as well. self.assertIn(self.u_init.op.name, source_annotation[self.helper_line_number]) - self.assertIn(self.u.op.name, - source_annotation[self.helper_line_number]) + self.assertIn(self.u.op.name, source_annotation[self.helper_line_number]) self.assertIn(self.v_init.op.name, source_annotation[self.helper_line_number]) - self.assertIn(self.v.op.name, - source_annotation[self.helper_line_number]) - self.assertIn(self.w.op.name, - source_annotation[self.helper_line_number]) + self.assertIn(self.v.op.name, source_annotation[self.helper_line_number]) + self.assertIn(self.w.op.name, source_annotation[self.helper_line_number]) def testAnnotateWithStackTopGivesCorrectResult(self): source_annotation = source_utils.annotate_source( @@ -163,14 +157,11 @@ class SourceHelperTest(test_util.TensorFlowTestCase): self.assertIn(self.u_init.op.name, source_annotation[self.u_init_line_number]) - self.assertIn(self.u.op.name, - source_annotation[self.u_line_number]) + self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) self.assertIn(self.v_init.op.name, source_annotation[self.v_init_line_number]) - self.assertIn(self.v.op.name, - source_annotation[self.v_line_number]) - self.assertIn(self.w.op.name, - source_annotation[self.w_line_number]) + self.assertIn(self.v.op.name, source_annotation[self.v_line_number]) + self.assertIn(self.w.op.name, source_annotation[self.w_line_number]) # In the stack-top mode, the helper line should not have been annotated. 
self.assertNotIn(self.helper_line_number, source_annotation) @@ -182,8 +173,7 @@ class SourceHelperTest(test_util.TensorFlowTestCase): min_line=self.u_line_number, max_line=self.u_line_number + 1) - self.assertIn(self.u.op.name, - source_annotation[self.u_line_number]) + self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) self.assertNotIn(self.v_line_number, source_annotation) def testAnnotateDumpedTensorsGivesCorrectResult(self): @@ -192,26 +182,17 @@ class SourceHelperTest(test_util.TensorFlowTestCase): # Note: Constant Tensors u_init and v_init may not get dumped due to # constant-folding. - self.assertIn(self.u.name, - source_annotation[self.u_line_number]) - self.assertIn(self.v.name, - source_annotation[self.v_line_number]) - self.assertIn(self.w.name, - source_annotation[self.w_line_number]) + self.assertIn(self.u.name, source_annotation[self.u_line_number]) + self.assertIn(self.v.name, source_annotation[self.v_line_number]) + self.assertIn(self.w.name, source_annotation[self.w_line_number]) - self.assertNotIn(self.u.op.name, - source_annotation[self.u_line_number]) - self.assertNotIn(self.v.op.name, - source_annotation[self.v_line_number]) - self.assertNotIn(self.w.op.name, - source_annotation[self.w_line_number]) + self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number]) + self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number]) + self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number]) - self.assertIn(self.u.name, - source_annotation[self.helper_line_number]) - self.assertIn(self.v.name, - source_annotation[self.helper_line_number]) - self.assertIn(self.w.name, - source_annotation[self.helper_line_number]) + self.assertIn(self.u.name, source_annotation[self.helper_line_number]) + self.assertIn(self.v.name, source_annotation[self.helper_line_number]) + self.assertIn(self.w.name, source_annotation[self.helper_line_number]) def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self): self.dump.set_python_graph(None) @@ -224,8 +205,9 @@ class SourceHelperTest(test_util.TensorFlowTestCase): with open(unrelated_source_path, "wt") as source_file: source_file.write("print('hello, world')\n") - self.assertEqual( - {}, source_utils.annotate_source(self.dump, unrelated_source_path)) + self.assertEqual({}, + source_utils.annotate_source(self.dump, + unrelated_source_path)) # Clean up unrelated source file. os.remove(unrelated_source_path) @@ -238,7 +220,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): self.dump_root = self.get_temp_dir() self.curr_file_path = os.path.abspath( - inspect.getfile(inspect.currentframe())) + tf_inspect.getfile(tf_inspect.currentframe())) # Run a simple TF graph to generate some debug dumps that can be used in # source annotation. 
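The estimator change that follows replaces the ad-hoc `_get_arguments` helper with `_model_fn_args`, which first unwraps any TF decorators and then special-cases `functools.partial` objects: positionally bound arguments are sliced off the front of the wrapped function's argspec, and keyword-bound arguments are filtered out. A minimal pure-Python sketch of that lookup, using the standard `inspect` module in place of `tf_inspect` and a hypothetical `model_fn` signature:

    import functools
    import inspect

    def model_fn_args(fn):
        # functools.partial exposes .func/.args/.keywords; drop the bound ones.
        if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
            return tuple(arg
                         for arg in inspect.getargspec(fn.func).args[len(fn.args):]
                         if arg not in set(fn.keywords.keys()))
        return tuple(inspect.getargspec(fn).args)

    def model_fn(features, labels, foo, mode, params, config, bar):
        pass

    partial_fn = functools.partial(model_fn, foo=45., bar=46.)
    print(model_fn_args(partial_fn))
    # ('features', 'labels', 'mode', 'params', 'config')

This is why the new `test_partial_model_fn_args` test below can bind `foo` and `bar` ahead of time and still have `mode`, `params`, and `config` supplied by the Estimator.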
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 449cb54c841..ac3cda4ff16 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import copy -import inspect import os import tempfile @@ -48,6 +47,9 @@ from tensorflow.python.training import monitored_session from tensorflow.python.training import saver from tensorflow.python.training import training from tensorflow.python.util import compat +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + _VALID_MODEL_FN_ARGS = set( ['features', 'labels', 'mode', 'params', 'config']) @@ -524,7 +526,7 @@ class Estimator(object): Raises: ValueError: if model_fn returns invalid objects. """ - model_fn_args = _get_arguments(self._model_fn).args + model_fn_args = _model_fn_args(self._model_fn) kwargs = {} if 'mode' in model_fn_args: kwargs['mode'] = mode @@ -704,35 +706,45 @@ def _get_replica_device_setter(config): return None -def _get_arguments(func): - """Returns a spec of given func.""" - if hasattr(func, '__code__'): - # Regular function. - return inspect.getargspec(func) - elif hasattr(func, '__call__'): - # Callable object. - return _get_arguments(func.__call__) - elif hasattr(func, 'func'): - # Partial function. - return _get_arguments(func.func) +def _model_fn_args(fn): + """Get argument names for function-like object. + + Args: + fn: Function, or function-like object (e.g., result of `functools.partial`). + + Returns: + `tuple` of string argument names. + + Raises: + ValueError: if partial function has positionally bound arguments + """ + _, fn = tf_decorator.unwrap(fn) + if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): + # Handle functools.partial and similar objects. + return tuple([ + arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):] + if arg not in set(fn.keywords.keys()) + ]) + # Handle function. + return tuple(tf_inspect.getargspec(fn).args) def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" - fn_spec = _get_arguments(model_fn) - if 'features' not in fn_spec.args: + args = _model_fn_args(model_fn) + if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) - if 'labels' not in fn_spec.args: + if 'labels' not in args: raise ValueError('model_fn (%s) must include labels argument.' % model_fn) - if params is not None and 'params' not in fn_spec.args: + if params is not None and 'params' not in args: raise ValueError('model_fn (%s) does not include params argument, ' 'but params (%s) is passed to Estimator.' 
% (model_fn, params)) - if params is None and 'params' in fn_spec.args: + if params is None and 'params' in args: logging.warning('Estimator\'s model_fn (%s) includes params ' 'argument, but params are not passed to Estimator.', model_fn) - non_valid_args = list(set(fn_spec.args) - _VALID_MODEL_FN_ARGS) + non_valid_args = list(set(args) - _VALID_MODEL_FN_ARGS) if non_valid_args: raise ValueError('model_fn (%s) has following not expected args: %s' % (model_fn, non_valid_args)) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 89a9483e201..3b46db59e30 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os import tempfile import numpy as np +import six from google.protobuf import text_format @@ -38,6 +40,7 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import layers from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops @@ -262,8 +265,120 @@ def model_fn_global_step_incrementer(features, labels, mode): train_op=state_ops.assign_add(global_step, 1)) +def _estimator_spec( + expected_features, expected_labels, actual_features, actual_labels, mode): + assert_ops = tuple([ + check_ops.assert_equal( + expected_features[k], actual_features[k], name='assert_%s' % k) + for k in expected_features + ] + [ + check_ops.assert_equal( + expected_labels, actual_labels, name='assert_labels') + ]) + with ops.control_dependencies(assert_ops): + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant(0.), + loss=constant_op.constant(0.), + train_op=constant_op.constant(0.)) + + +def _make_input_fn(features, labels): + def _input_fn(): + return { + k: constant_op.constant(v) + for k, v in six.iteritems(features) + }, constant_op.constant(labels) + return _input_fn + + class EstimatorTrainTest(test.TestCase): + def test_minimal_model_fn_args(self): + expected_features = {'x': 42., 'y': 43.} + expected_labels = 44. + + # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments + # doesn't work with mock fns. + model_fn_call_count = [0] + + def _model_fn(features, labels): + model_fn_call_count[0] += 1 + self.assertItemsEqual(expected_features.keys(), features.keys()) + return _estimator_spec( + expected_features, expected_labels, features, labels, + model_fn_lib.ModeKeys.TRAIN) + + with self.assertRaisesRegexp(ValueError, 'does not include params'): + estimator.Estimator(model_fn=_model_fn, params={'a': 'b'}) + est = estimator.Estimator(model_fn=_model_fn, config=run_config.RunConfig()) + self.assertEqual(0, model_fn_call_count[0]) + est.train( + input_fn=_make_input_fn(expected_features, expected_labels), steps=1) + self.assertEqual(1, model_fn_call_count[0]) + + def test_all_model_fn_args(self): + expected_features = {'x': 42., 'y': 43.} + expected_labels = 44. + expected_params = {'some_param': 'some_value'} + expected_config = run_config.RunConfig() + expected_config.i_am_test = True + + # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments + # doesn't work with mock fns. 
+ model_fn_call_count = [0] + + # Note that args are all passed by keyword, so can be in any order. + def _model_fn(mode, params, features, labels, config): + model_fn_call_count[0] += 1 + self.assertItemsEqual(expected_features.keys(), features.keys()) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, mode) + self.assertEqual(expected_params, params) + self.assertTrue(config.i_am_test) + return _estimator_spec( + expected_features, expected_labels, features, labels, mode) + + est = estimator.Estimator( + model_fn=_model_fn, params=expected_params, config=expected_config) + self.assertEqual(0, model_fn_call_count[0]) + est.train( + input_fn=_make_input_fn(expected_features, expected_labels), steps=1) + self.assertEqual(1, model_fn_call_count[0]) + + def test_partial_model_fn_args(self): + expected_features = {'x': 42., 'y': 43.} + expected_labels = 44. + expected_params = {'some_param': 'some_value'} + expected_config = run_config.RunConfig() + expected_config.i_am_test = True + expected_foo = 45. + expected_bar = 46. + + # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments + # doesn't work with mock fns. + model_fn_call_count = [0] + + def _model_fn(features, labels, foo, mode, params, config, bar): + model_fn_call_count[0] += 1 + self.assertEqual(expected_foo, foo) + self.assertEqual(expected_bar, bar) + self.assertItemsEqual(expected_features.keys(), features.keys()) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, mode) + self.assertEqual(expected_params, params) + self.assertTrue(config.i_am_test) + return _estimator_spec( + expected_features, expected_labels, features, labels, mode) + partial_model_fn = functools.partial( + _model_fn, foo=expected_foo, bar=expected_bar) + + est = estimator.Estimator( + model_fn=partial_model_fn, params=expected_params, + config=expected_config) + self.assertEqual(0, model_fn_call_count[0]) + est.train( + input_fn=_make_input_fn(expected_features, expected_labels), steps=1) + self.assertEqual(1, model_fn_call_count[0]) + def test_model_fn_must_return_estimator_spec(self): def model_fn(features, labels): diff --git a/tensorflow/python/framework/contrib_test.py b/tensorflow/python/framework/contrib_test.py index 8ca0c69d775..f2eaf7c2eea 100644 --- a/tensorflow/python/framework/contrib_test.py +++ b/tensorflow/python/framework/contrib_test.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect class ContribTest(test.TestCase): @@ -29,17 +28,17 @@ class ContribTest(test.TestCase): # pylint: disable=g-import-not-at-top import tensorflow as tf _ = tf.contrib.layers # `tf.contrib` is loaded lazily on first use. 
- assert inspect.ismodule(tf.contrib) + assert tf_inspect.ismodule(tf.contrib) def testLayers(self): # pylint: disable=g-import-not-at-top import tensorflow as tf - assert inspect.ismodule(tf.contrib.layers) + assert tf_inspect.ismodule(tf.contrib.layers) def testLinearOptimizer(self): # pylint: disable=g-import-not-at-top import tensorflow as tf - assert inspect.ismodule(tf.contrib.linear_optimizer) + assert tf_inspect.ismodule(tf.contrib.linear_optimizer) if __name__ == '__main__': diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 8b156db6dc4..2a1389b91ff 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -23,7 +23,6 @@ from __future__ import print_function import collections import hashlib -import inspect import re from tensorflow.core.framework import attr_value_pb2 @@ -36,6 +35,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect def _make_argname_from_tensor_name(name): @@ -259,10 +260,11 @@ def _call(sig, *inputs, **kwargs): def _get_func_name(func): + _, func = tf_decorator.unwrap(func) if callable(func): - if inspect.isfunction(func): + if tf_inspect.isfunction(func): return func.__name__ - elif inspect.ismethod(func): + elif tf_inspect.ismethod(func): return "%s.%s" % (func.__self__.__name__, func.__name__) else: # Probably a class instance with __call__ return type(func) @@ -955,7 +957,7 @@ class Defun(object): raise ValueError("func %s must be callable" % func) # Func should not use kwargs and defaults. - argspec = inspect.getargspec(func) + argspec = tf_inspect.getargspec(func) if argspec.keywords or argspec.defaults: raise ValueError("Functions with argument defaults or keyword " "arguments are not supported.") @@ -966,7 +968,7 @@ class Defun(object): if argspec.varargs: max_args = 1000000 argnames = argspec.args - if inspect.ismethod(func): + if tf_inspect.ismethod(func): # 1st argument is the "class" type. min_args -= 1 argnames = argnames[1:] diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 7f2b03e3509..2c39f5b0e37 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -19,8 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - import six from tensorflow.core.framework import attr_value_pb2 @@ -33,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util import tf_contextlib def _Attr(op_def, name): @@ -241,7 +240,7 @@ class _OpInfo(object): # pylint: disable=g-doc-return-or-yield -@contextlib.contextmanager +@tf_contextlib.contextmanager def _MaybeColocateWith(inputs): """A context manager for (maybe) colocating with a list of input tensors. 
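This and several files below (`ops.py`, `variable_scope.py`) make the same substitution of `tf_contextlib.contextmanager` for `contextlib.contextmanager`. The decorated functions keep the standard generator-based contract; the TF wrapper presumably exists so that `tf_decorator.unwrap` and `tf_inspect` can still see through the decoration. A small sketch of the pattern these call sites rely on, using only the standard library (the scope name is illustrative, not TF API):

    import contextlib

    @contextlib.contextmanager
    def attr_scope(label):
        # Code before the yield runs on entry; the finally block runs on
        # exit, even if the body of the `with` statement raises.
        print('entering %s' % label)
        try:
            yield label
        finally:
            print('leaving %s' % label)

    with attr_scope('demo') as name:
        print('inside %s' % name)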
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ebab40c0aab..6d2a38b3a6c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import collections -import contextlib import copy import linecache import re @@ -44,6 +43,7 @@ from tensorflow.python.framework import versions from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import tf_contextlib def _override_helper(clazz_object, operator, func): @@ -2725,7 +2725,7 @@ class Graph(object): if name in self._collections: del self._collections[name] - @contextlib.contextmanager + @tf_contextlib.contextmanager def _original_op(self, op): """Python 'with' handler to help annotate ops with their originator. @@ -2751,7 +2751,7 @@ class Graph(object): self._default_original_op = old_original_op # pylint: disable=g-doc-return-or-yield - @contextlib.contextmanager + @tf_contextlib.contextmanager def name_scope(self, name): r"""Returns a context manager that creates hierarchical names for operations. @@ -2924,7 +2924,7 @@ class Graph(object): """ return self._name_stack - @contextlib.contextmanager + @tf_contextlib.contextmanager def colocate_with(self, op, ignore_existing=False): """Returns a context manager that specifies an op to colocate with. @@ -2999,7 +2999,7 @@ class Graph(object): if ignore_existing: self._colocation_stack = current_stack - @contextlib.contextmanager + @tf_contextlib.contextmanager def device(self, device_name_or_function): """Returns a context manager that specifies the default device to use. @@ -3081,7 +3081,7 @@ class Graph(object): op._set_device(device_function(op)) # pylint: disable=g-doc-return-or-yield - @contextlib.contextmanager + @tf_contextlib.contextmanager def container(self, container_name): """Returns a context manager that specifies the resource container to use. @@ -3349,7 +3349,7 @@ class Graph(object): return self._ControlDependenciesController(self, control_ops) # pylint: disable=g-doc-return-or-yield - @contextlib.contextmanager + @tf_contextlib.contextmanager def _attr_scope(self, attr_map): """EXPERIMENTAL: A context manager for setting attributes on operators. @@ -3414,7 +3414,7 @@ class Graph(object): # pylint: enable=g-doc-return-or-yield # pylint: disable=g-doc-return-or-yield - @contextlib.contextmanager + @tf_contextlib.contextmanager def _kernel_label_map(self, op_to_kernel_label_map): """EXPERIMENTAL: A context manager for setting kernel labels. @@ -3476,7 +3476,7 @@ class Graph(object): # pylint: enable=g-doc-return-or-yield # pylint: disable=g-doc-return-or-yield - @contextlib.contextmanager + @tf_contextlib.contextmanager def gradient_override_map(self, op_type_map): """EXPERIMENTAL: A context manager for overriding gradient functions. @@ -3634,7 +3634,7 @@ class _DefaultStack(threading.local): def enforce_nesting(self, value): self._enforce_nesting = value - @contextlib.contextmanager + @tf_contextlib.contextmanager def get_controller(self, default): """A context manager for manipulating a default stack.""" try: @@ -4137,7 +4137,7 @@ def get_all_collection_keys(): # pylint: disable=g-doc-return-or-yield -@contextlib.contextmanager +@tf_contextlib.contextmanager def name_scope(name, default_name=None, values=None): """Returns a context manager for use when defining a Python op. 
@@ -4227,7 +4227,7 @@ def prepend_name_scope(name, import_scope): # pylint: disable=g-doc-return-or-yield -@contextlib.contextmanager +@tf_contextlib.contextmanager def op_scope(values, name, default_name=None): """DEPRECATED. Same as name_scope above, just different argument order.""" logging.warn("tf.op_scope(values, name, default_name) is deprecated," diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6c7cbbff9cb..00f6cc0d6d9 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -196,7 +196,7 @@ class ControlFlowTest(test.TestCase): with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, - lambda e: "The tensor returned for" in str(e)): + lambda e: "Retval[0] does not have value" in str(e)): dead_branch.eval() def testSwitchMergeLess(self): diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 15eeb762d8f..40fddd76ffa 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -447,7 +447,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase): [-100.0, -100.0, 100.0]]) labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) loss = losses.sigmoid_cross_entropy(labels, logits) - self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertEquals(logits.dtype, loss.dtype) + self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) self.assertAlmostEqual(0.0, loss.eval(), 3) def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self): @@ -456,6 +457,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): weights = array_ops.ones_like(logits, dtype=dtypes.float32) loss = losses.sigmoid_cross_entropy(labels, logits, weights) + self.assertEquals(logits.dtype, loss.dtype) with self.test_session() as sess: loss = sess.run(loss, @@ -471,6 +473,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): weights = array_ops.ones_like(logits, dtype=dtypes.float32) loss = losses.sigmoid_cross_entropy(labels, logits, weights) + self.assertEquals(logits.dtype, loss.dtype) with self.test_session() as sess: loss = sess.run(loss, @@ -487,7 +490,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase): [-100.0, -100.0, 100.0]]) labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) loss = losses.sigmoid_cross_entropy(labels, logits) - self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertEquals(logits.dtype, loss.dtype) + self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3) def testAllWrongSigmoidWithMeasurementSpecificWeights(self): @@ -498,7 +502,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]]) loss = losses.sigmoid_cross_entropy(labels, logits, weights) - self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertEquals(logits.dtype, loss.dtype) + self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) self.assertAlmostEqual(1700.0 / 7.0, loss.eval(), 3) def testMultiCorrectSigmoid(self): @@ -507,10 +512,43 @@ class SigmoidCrossEntropyLossTest(test.TestCase): [-100.0, 100.0, 100.0]]) labels = constant_op.constant([[1, 0, 1], [1, 1, 0], [0, 1, 1]]) loss = losses.sigmoid_cross_entropy(labels, logits) - 
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertEquals(logits.dtype, loss.dtype) + self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) with self.test_session(): - self.assertAlmostEqual(loss.eval(), 0.0, 3) + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testSigmoidFloat64(self): + logits = constant_op.constant(( + (100.0, -100.0, 100.0), + (100.0, -100.0, 100.0), + (100.0, 100.0, -100.0) + ), dtype=dtypes.float64) + labels = constant_op.constant(( + (1, 0, 1), (1, 1, 0), (0, 1, 1) + ), dtype=dtypes.int64) + loss = losses.sigmoid_cross_entropy(labels, logits) + self.assertEquals(logits.dtype, loss.dtype) + + with self.test_session(): + self.assertAlmostEqual(44.444, loss.eval(), 3) + + def testSigmoidNoReduction(self): + logits = constant_op.constant(( + (100.0, -100.0, 100.0), + (100.0, -100.0, 100.0), + (100.0, 100.0, -100.0))) + labels = constant_op.constant(((1, 0, 1), (1, 1, 0), (0, 1, 1))) + loss = losses.sigmoid_cross_entropy( + labels, logits, reduction=losses.Reduction.NONE) + self.assertEquals(logits.dtype, loss.dtype) + + with self.test_session(): + self.assertAllClose(( + (0., 0., 0.), + (0., 100., 100.), + (100., 0., 100.) + ), loss.eval(), 3) def testSigmoidLabelSmoothingCorrect(self): with self.test_session(): @@ -530,7 +568,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase): label_smoothing = 0.1 loss = losses.sigmoid_cross_entropy( labels, logits, label_smoothing=label_smoothing) - self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertEquals(logits.dtype, loss.dtype) + self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) expected_value = (100.0 + 50.0 * label_smoothing) / 3.0 self.assertAlmostEqual(loss.eval(), expected_value, 3) @@ -541,6 +580,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): sigmoid_labels = constant_op.constant([[1, 0, 1]]) sigmoid_loss = losses.sigmoid_cross_entropy( sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing) + self.assertEquals(sigmoid_logits.dtype, sigmoid_loss.dtype) softmax_logits = constant_op.constant( [[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]]) @@ -1254,10 +1294,14 @@ class ComputeWeightedLossTest(test.TestCase): self.assertEqual(9, len(util.get_losses())) with self.test_session(g): for unweighted_loss in unweighted_losses: - if reduction == losses.Reduction.WEIGHTED_SUM: + if reduction == losses.Reduction.NONE: + self.assertAllClose(self._raw_losses, unweighted_loss.eval()) + elif reduction == losses.Reduction.SUM: self.assertAllClose( np.sum(self._raw_losses), unweighted_loss.eval()) - else: # losses.Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS + else: + # reduction one of losses.Reduction.MEAN and + # losses.Reduction.SUM_BY_NONZERO_WEIGHTS. 
self.assertAllClose( np.mean(self._raw_losses), unweighted_loss.eval()) @@ -1341,13 +1385,20 @@ class ComputeWeightedLossTest(test.TestCase): with self.test_session(g): weighted_losses = weights * self._raw_losses weighted_sum = np.sum(weighted_losses) - if reduction == losses.Reduction.WEIGHTED_SUM: + if reduction == losses.Reduction.NONE: + self.assertAllClose(weighted_losses, weighted_loss.eval()) + elif reduction == losses.Reduction.SUM: self.assertAllClose(weighted_sum, weighted_loss.eval()) - else: # losses.Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS + else: broadcast_weights = weights * np.ones_like(self._raw_losses) - self.assertAllClose( - weighted_sum / np.count_nonzero(broadcast_weights), - weighted_loss.eval()) + if reduction == losses.Reduction.MEAN: + self.assertAllClose( + weighted_sum / np.sum(broadcast_weights), + weighted_loss.eval()) + elif reduction == losses.Reduction.SUM_BY_NONZERO_WEIGHTS: + self.assertAllClose( + weighted_sum / np.count_nonzero(broadcast_weights), + weighted_loss.eval()) def test1x1x1Weight(self): self._test_valid_weights((((17.0,),),)) diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index d4950fa2830..10f34751d0b 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import gzip -import math import os import threading import zlib @@ -360,9 +359,11 @@ class FixedLengthRecordReaderTest(test.TestCase): return compat.as_bytes(str(f * 2 + r) * self._record_bytes) def _OverlappedRecord(self, f, r): - record_str = "".join( - [str(i)[0] for i in range( - r * self._hop_bytes, r * self._hop_bytes + self._record_bytes)]) + record_str = "".join([ + str(i)[0] + for i in range(r * self._hop_bytes, + r * self._hop_bytes + self._record_bytes) + ]) return compat.as_bytes(record_str) def _CreateFiles(self): @@ -380,13 +381,16 @@ class FixedLengthRecordReaderTest(test.TestCase): def _CreateOverlappedRecordFiles(self): filenames = [] for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_overlapped_record.%d.txt" % i) + fn = os.path.join(self.get_temp_dir(), + "fixed_length_overlapped_record.%d.txt" % i) filenames.append(fn) with open(fn, "wb") as f: f.write(b"H" * self._header_bytes) - all_records_str = "".join( - [str(i)[0] for i in range( - self._record_bytes + self._hop_bytes * (self._num_overlapped_records - 1))]) + all_records_str = "".join([ + str(i)[0] + for i in range(self._record_bytes + self._hop_bytes * + (self._num_overlapped_records - 1)) + ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) return filenames @@ -440,6 +444,7 @@ class FixedLengthRecordReaderTest(test.TestCase): "\\(requested 1, current size 0\\)"): k, v = sess.run([key, value]) + class TFRecordReaderTest(test.TestCase): def setUp(self): diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py index 3af73907a6f..e8b94294b1b 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py @@ -41,7 +41,11 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase): return sparse_tensor.SparseTensor( indices=x_indices, values=x_values, dense_shape=x_shape), len(x_values) - def _randomTensor(self, size, values_dtype, adjoint=False, 
sparse=False, + def _randomTensor(self, + size, + values_dtype, + adjoint=False, + sparse=False, indices_dtype=np.int64): n, m = size x = np.random.randn(n, m).astype(values_dtype) @@ -58,8 +62,11 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase): indices_dtype): n, k, m = np.random.randint(1, 10, size=3) sp_t, nnz = self._randomTensor( - [n, k], values_dtype, adjoint=adjoint_a, sparse=True, - indices_dtype=indices_dtype) + [n, k], + values_dtype, + adjoint=adjoint_a, + sparse=True, + indices_dtype=indices_dtype) dense_t = self._randomTensor([k, m], values_dtype, adjoint=adjoint_b) matmul = sparse_ops.sparse_tensor_dense_matmul( @@ -78,7 +85,7 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase): for adjoint_a in [True, False]: for adjoint_b in [True, False]: name = "sparse_tensor_dense_matmul_%s_%s_%s_%s" % ( - adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__) + adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__) self._testGradients(adjoint_a, adjoint_b, name, values_dtype, indices_dtype) diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index e08cf6092c7..80991751860 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -45,7 +45,11 @@ def _maybe_complex(x): class SparseTensorDenseMatMulTest(test.TestCase): - def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False, + def _testMatmul(self, + x, + y, + adjoint_a=False, + adjoint_b=False, indices_dtype=np.int64): x_mat = np.matrix(x) if adjoint_a: diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 9b76585f9fb..ff9a777f191 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -25,7 +25,6 @@ from __future__ import print_function import copy import functools -import inspect import re from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np @@ -36,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect class _Layer(object): @@ -280,8 +280,7 @@ class _Layer(object): inputs: input tensor(s). *args: additional positional arguments to be passed to `self.call`. **kwargs: additional keyword arguments to be passed to `self.call`. - **Note**, the kwarg 'scope' is reserved for use by the Layer. - + **Note**: kwarg `scope` is reserved for use by the layer. Returns: Output tensor(s). """ @@ -329,6 +328,8 @@ class _Layer(object): else: self.build(input_shapes) self._built = True + if 'scope' in tf_inspect.getargspec(self.call).args: + kwargs['scope'] = scope outputs = self.call(inputs, *args, **kwargs) # Apply activity regularization. @@ -366,19 +367,20 @@ class _Layer(object): setattr(result, k, copy.deepcopy(v, memo)) return result - def apply(self, inputs, **kwargs): + def apply(self, inputs, *args, **kwargs): """Apply the layer on a input. This simply wraps `self.__call__`. Arguments: inputs: Input tensor(s). + *args: additional positional arguments to be passed to `self.call`. **kwargs: additional keyword arguments to be passed to `self.call`. Returns: Output tensor(s). 
""" - return self.__call__(inputs, **kwargs) + return self.__call__(inputs, *args, **kwargs) def _to_snake_case(name): diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index b8052379105..68ecc219e4f 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -391,8 +391,12 @@ class FixedLengthRecordReader(ReaderBase): """ # TODO(josh11b): Support serializing and restoring state. - def __init__(self, record_bytes, header_bytes=None, footer_bytes=None, - hop_bytes=None, name=None): + def __init__(self, + record_bytes, + header_bytes=None, + footer_bytes=None, + hop_bytes=None, + name=None): """Create a FixedLengthRecordReader. Args: @@ -403,8 +407,11 @@ class FixedLengthRecordReader(ReaderBase): name: A name for the operation (optional). """ rr = gen_io_ops._fixed_length_record_reader_v2( - record_bytes=record_bytes, header_bytes=header_bytes, - footer_bytes=footer_bytes, hop_bytes=hop_bytes, name=name) + record_bytes=record_bytes, + header_bytes=header_bytes, + footer_bytes=footer_bytes, + hop_bytes=hop_bytes, + name=name) super(FixedLengthRecordReader, self).__init__(rr) diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 45075d4d3c4..fc54553b0c3 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -27,21 +27,31 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util +from tensorflow.python.platform import tf_logging as logging -# TODO(ptucker): Per-example? Divided by batch_size? Divided by sum of weights? class Reduction(object): """Types of loss reduction.""" - # Batch sum of weighted losses. - WEIGHTED_SUM = "weighted_sum" + # Un-reduced weighted losses with the same shape as input. + NONE = "none" - # `WEIGHTED_SUM` divided by number of non-zero weights. - WEIGHTED_SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights" + # Scalar sum of `NONE`. + SUM = "weighted_sum" + + # Scalar `SUM` divided by sum of weights. + MEAN = "weighted_mean" + + # Scalar `SUM` divided by number of non-zero weights. + SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights" @classmethod def all(cls): - return (cls.WEIGHTED_SUM, cls.WEIGHTED_SUM_BY_NONZERO_WEIGHTS) + return ( + cls.NONE, + cls.SUM, + cls.MEAN, + cls.SUM_BY_NONZERO_WEIGHTS) @classmethod def validate(cls, key): @@ -127,7 +137,7 @@ def _num_present(losses, weights, per_batch=False): def compute_weighted_loss( losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Computes the weighted loss. Args: @@ -140,7 +150,8 @@ def compute_weighted_loss( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss `Tensor` of the same type as `losses`. If `reduction` is + `NONE`, this has the same shape as `losses`; otherwise, it is scalar. 
Raises: ValueError: If `weights` is `None` or the shape is not compatible with @@ -156,9 +167,16 @@ def compute_weighted_loss( losses = math_ops.to_float(losses) weights = math_ops.to_float(weights) weighted_losses = math_ops.multiply(losses, weights) - loss = math_ops.reduce_sum(weighted_losses) - if reduction == Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS: - loss = _safe_mean(loss, _num_present(losses, weights)) + if reduction == Reduction.NONE: + loss = weighted_losses + else: + loss = math_ops.reduce_sum(weighted_losses) + if reduction == Reduction.MEAN: + loss = _safe_mean( + loss, + math_ops.reduce_sum(array_ops.ones_like(losses) * weights)) + elif reduction == Reduction.SUM_BY_NONZERO_WEIGHTS: + loss = _safe_mean(loss, _num_present(losses, weights)) # Convert the result back to the input type. loss = math_ops.cast(loss, input_dtype) @@ -169,7 +187,7 @@ def compute_weighted_loss( def absolute_difference( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds an Absolute Difference loss to the training procedure. `weights` acts as a coefficient for the loss. If a scalar is provided, then @@ -191,7 +209,8 @@ def absolute_difference( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If the shape of `predictions` doesn't match that of `labels` or @@ -210,7 +229,7 @@ def absolute_difference( def cosine_distance( labels, predictions, dim=None, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds a cosine-distance loss to the training procedure. Note that the function assumes that `predictions` and `labels` are already @@ -228,7 +247,8 @@ def cosine_distance( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If `predictions` shape doesn't match `labels` shape, or @@ -250,7 +270,7 @@ def cosine_distance( def hinge_loss(labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds a hinge loss to the training procedure. Args: @@ -265,7 +285,8 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If the shapes of `logits` and `labels` don't match. @@ -285,7 +306,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds a Huber Loss term to the training procedure. 
For each value x in `error=labels-predictions`, the following is calculated: @@ -320,7 +341,8 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If the shape of `predictions` doesn't match that of `labels` or @@ -347,7 +369,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds a Log Loss term to the training procedure. `weights` acts as a coefficient for the loss. If a scalar is provided, then @@ -370,7 +392,8 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If the shape of `predictions` doesn't match that of `labels` or @@ -474,7 +497,7 @@ def mean_pairwise_squared_error( def mean_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Adds a Sum-of-Squares loss to the training procedure. `weights` acts as a coefficient for the loss. If a scalar is provided, then @@ -496,7 +519,8 @@ def mean_squared_error( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same + shape as `labels`; otherwise, it is scalar. Raises: ValueError: If the shape of `predictions` doesn't match that of `labels` or @@ -515,7 +539,7 @@ def mean_squared_error( def sigmoid_cross_entropy( multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits. `weights` acts as a coefficient for the loss. If a scalar is provided, @@ -531,7 +555,7 @@ def sigmoid_cross_entropy( Args: multi_class_labels: `[batch_size, num_classes]` target integer labels in `(0, 1)`. - logits: `[batch_size, num_classes]` logits outputs of the network. + logits: Float `[batch_size, num_classes]` logits outputs of the network. weights: Optional `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `losses` dimension). @@ -541,7 +565,8 @@ def sigmoid_cross_entropy( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss `Tensor` of the same type as `logits`. If `reduction` is + `NONE`, this has the same shape as `logits`; otherwise, it is scalar. 
Raises: ValueError: If the shape of `logits` doesn't match that of @@ -551,7 +576,9 @@ def sigmoid_cross_entropy( with ops.name_scope(scope, "sigmoid_cross_entropy_loss", (logits, multi_class_labels, weights)) as scope: logits = ops.convert_to_tensor(logits) + logging.info("logits.dtype=%s.", logits.dtype) multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) + logging.info("multi_class_labels.dtype=%s.", multi_class_labels.dtype) logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) if label_smoothing > 0: @@ -561,6 +588,7 @@ def sigmoid_cross_entropy( losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels, logits=logits, name="xentropy") + logging.info("losses.dtype=%s.", losses.dtype) return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) @@ -568,7 +596,7 @@ def sigmoid_cross_entropy( def softmax_cross_entropy( onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits. `weights` acts as a coefficient for the loss. If a scalar is provided, @@ -593,7 +621,8 @@ def softmax_cross_entropy( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss `Tensor` of the same type as `logits`. If `reduction` is + `NONE`, this has shape `[batch_size]`; otherwise, it is scalar. Raises: ValueError: If the shape of `logits` doesn't match that of `onehot_labels` @@ -673,7 +702,7 @@ def _remove_squeezable_dimensions( def sparse_softmax_cross_entropy( labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, - reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS): + reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`. `weights` acts as a coefficient for the loss. If a scalar is provided, @@ -696,7 +725,8 @@ def sparse_softmax_cross_entropy( reduction: Type of reduction to apply to loss. Returns: - A scalar `Tensor` that returns the weighted loss. + Weighted loss `Tensor` of the same type as `logits`. If `reduction` is + `NONE`, this has the same shape as `labels`; otherwise, it is scalar. 
Raises: ValueError: If the shapes of logits, labels, and weight are incompatible, or diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index c1ee26d037c..7c17cf2cb61 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -640,9 +640,10 @@ def moments(x, axes, shift=None, name=None, keep_dims=False): else: shift = math_ops.cast(shift, y.dtype) shifted_mean = math_ops.reduce_mean( - math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean") + math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean") variance = math_ops.subtract( - math_ops.reduce_mean(math_ops.squared_difference(y, shift), axes, keep_dims=True), + math_ops.reduce_mean( + math_ops.squared_difference(y, shift), axes, keep_dims=True), math_ops.square(shifted_mean), name="variance") mean = math_ops.add(shifted_mean, shift, name="mean") @@ -650,11 +651,12 @@ def moments(x, axes, shift=None, name=None, keep_dims=False): mean = array_ops.squeeze(mean, axes) variance = array_ops.squeeze(variance, axes) if x.dtype == dtypes.float16: - return (math_ops.cast(mean, dtypes.float16), - math_ops.cast(variance, dtypes.float16)) + return (math_ops.cast(mean, dtypes.float16), math_ops.cast( + variance, dtypes.float16)) else: return (mean, variance) + def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False): """Returns the frequency-weighted mean and variance of `x`. diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c3dddf85f3d..32ebe0c2e84 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -26,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.util import nest @@ -74,7 +75,7 @@ def _zero_state_tensors(state_size, batch_size, dtype): return zeros -class _RNNCell(object): +class _RNNCell(base_layer._Layer): # pylint: disable=protected-access """Abstract object representing an RNN cell. Every `RNNCell` must have the properties below and implement `__call__` with @@ -111,7 +112,7 @@ class _RNNCell(object): - New state: Either a single `2-D` tensor, or a tuple of tensors matching the arity and shapes of `state`. """ - raise NotImplementedError("Abstract method") + return super(_RNNCell, self).__call__(inputs, state, scope=scope) @property def state_size(self): diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 6a6158620f4..b8e356c78cc 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -151,8 +151,8 @@ def _SparseTensorDenseMatMulGrad(op, grad): "complex gradients.") # gradient w.r.t. 
dense - b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( - a_indices, a_values, a_shape, grad, adjoint_a=not adj_a) + b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access + a_indices, a_values, a_shape, grad, adjoint_a=not adj_a) if adj_b: b_grad = array_ops.transpose(b_grad) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index ff9f134a0e3..f81837b73ac 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import collections as collections_lib -import contextlib import copy import functools import traceback @@ -36,6 +35,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_contextlib __all__ = ["VariableScope", "get_variable_scope", "get_variable", "get_local_variable", "variable_scope", @@ -1250,7 +1250,7 @@ def _get_partitioned_variable(name, # pylint: enable=protected-access -@contextlib.contextmanager +@tf_contextlib.contextmanager def _pure_variable_scope(name_or_scope, reuse=None, initializer=None, @@ -1409,7 +1409,7 @@ def _get_unique_variable_scope(prefix): # pylint: disable=g-doc-return-or-yield -@contextlib.contextmanager +@tf_contextlib.contextmanager def variable_scope(name_or_scope, default_name=None, values=None, @@ -1582,7 +1582,7 @@ def variable_scope(name_or_scope, # pylint: disable=g-doc-return-or-yield -@contextlib.contextmanager +@tf_contextlib.contextmanager def variable_op_scope(values, name_or_scope, default_name=None, diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index ea29399ed2f..aa74a419d8e 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import numbers import os import re @@ -33,6 +32,8 @@ from tensorflow.python.client import timeline from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect + # When a subclass of the Benchmark class is created, it is added to # the registry automatically @@ -135,7 +136,7 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)): """Returns full name of class and method calling report_benchmark.""" # Find the caller method (outermost Benchmark class) - stack = inspect.stack() + stack = tf_inspect.stack() calling_class = None name = None for frame in stack[::-1]: diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index 1e74b1512b8..96219faab71 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import atexit -import inspect import itertools import os import sys @@ -35,6 +34,9 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.platform import app from tensorflow.python.platform import benchmark from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import 
tf_inspect + Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name @@ -101,9 +103,9 @@ def GetTempDir(): """Return a temporary directory for tests to use.""" global _googletest_temp_dir if not _googletest_temp_dir: - first_frame = inspect.stack()[-1][0] - temp_dir = os.path.join( - tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame))) + first_frame = tf_inspect.stack()[-1][0] + temp_dir = os.path.join(tempfile.gettempdir(), + os.path.basename(tf_inspect.getfile(first_frame))) temp_dir = tempfile.mkdtemp(prefix=temp_dir.rstrip('.py')) def delete_temp_dir(dirname=temp_dir): @@ -204,15 +206,16 @@ class StubOutForTesting(object): Raises: AttributeError: If the attribute cannot be found. """ - if (inspect.ismodule(obj) or - (not inspect.isclass(obj) and attr_name in obj.__dict__)): + _, obj = tf_decorator.unwrap(obj) + if (tf_inspect.ismodule(obj) or + (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)): orig_obj = obj orig_attr = getattr(obj, attr_name) else: - if not inspect.isclass(obj): - mro = list(inspect.getmro(obj.__class__)) + if not tf_inspect.isclass(obj): + mro = list(tf_inspect.getmro(obj.__class__)) else: - mro = list(inspect.getmro(obj)) + mro = list(tf_inspect.getmro(obj)) mro.reverse() diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py index a53fc541cb7..2455acb4c0c 100644 --- a/tensorflow/python/platform/resource_loader.py +++ b/tensorflow/python/platform/resource_loader.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Resource management library. @@get_data_files_path @@ -25,10 +24,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect as _inspect import os as _os import sys as _sys +from tensorflow.python.util import tf_inspect as _inspect from tensorflow.python.util.all_util import remove_undocumented @@ -44,9 +43,8 @@ def load_resource(path): Raises: IOError: If the path is not found, or the resource can't be opened. 
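
  Example (an illustrative sketch; the path is resolved relative to the
  `tensorflow/` source root and must name an existing file):

    data = load_resource('python/platform/resource_loader.py')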
""" - tensorflow_root = ( - _os.path.join( - _os.path.dirname(__file__), _os.pardir, _os.pardir)) + tensorflow_root = (_os.path.join( + _os.path.dirname(__file__), _os.pardir, _os.pardir)) path = _os.path.join(tensorflow_root, path) path = _os.path.abspath(path) with open(path, 'rb') as f: diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index eaf0a5c837b..48b84f9a96e 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -205,6 +205,7 @@ py_binary( deps = [ "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python", + "//tensorflow/python/debug:local_cli_wrapper", ], ) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index d14748b492f..17ef8ef9c23 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -66,6 +66,12 @@ tensors to files: --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy --outdir /tmp/out +To observe the intermediate Tensor values in the runtime graph, use the +--tf_debug flag, e.g.: + $saved_model_cli run --dir /tmp/saved_model --tag_set serve + --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy + --outdir /tmp/out --tf_debug + To build this tool from source, run: $bazel build tensorflow/python/tools:saved_model_cli @@ -87,6 +93,7 @@ from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session +from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import ops as ops_lib from tensorflow.python.platform import app from tensorflow.python.saved_model import loader @@ -282,7 +289,7 @@ def get_signature_def_map(saved_model_dir, tag_set): def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, input_tensor_key_feed_dict, outdir, - overwrite_flag): + overwrite_flag, tf_debug=False): """Runs SavedModel and fetch all outputs. Runs the input dictionary through the MetaGraphDef within a SavedModel @@ -300,6 +307,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, it will be created. overwrite_flag: A boolean flag to allow overwrite output file if file with the same name exists. + tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the + intermediate Tensor values and runtime GraphDefs while running the + SavedModel. 
Raises: RuntimeError: An error when output file already exists and overwrite is not @@ -329,6 +339,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, with session.Session(graph=ops_lib.Graph()) as sess: loader.load(sess, tag_set.split(','), saved_model_dir) + if tf_debug: + sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess) + outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict) for i, output in enumerate(outputs): @@ -520,7 +533,7 @@ def run(args): tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs) run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, tensor_key_feed_dict, args.outdir, - args.overwrite) + args.overwrite, tf_debug=args.tf_debug) def create_parser(): @@ -620,6 +633,12 @@ def create_parser(): '--overwrite', action='store_true', help='if set, output file will be overwritten if it already exists.') + parser_run.add_argument( + '--tf_debug', + action='store_true', + help='if set, will use TensorFlow Debugger (tfdbg) to watch the ' + 'intermediate Tensors and runtime GraphDefs while running the ' + 'SavedModel.') parser_run.set_defaults(func=run) return parser diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index b9d28794cc4..c481dba2e9a 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -28,6 +28,7 @@ import sys import numpy as np from six import StringIO +from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.platform import test from tensorflow.python.tools import saved_model_cli @@ -299,9 +300,9 @@ Method name is: tensorflow/serving/predict""" test.get_temp_dir() ]) saved_model_cli.run(args) - y = np.load(output_file) - y_exp = np.array([[3.5], [4.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(output_file) + y_expected = np.array([[3.5], [4.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandNewOutdir(self): self.parser = saved_model_cli.create_parser() @@ -320,9 +321,9 @@ Method name is: tensorflow/serving/predict""" output_dir ]) saved_model_cli.run(args) - y = np.load(os.path.join(output_dir, 'y.npy')) - y_exp = np.array([[2.5], [3.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(os.path.join(output_dir, 'y.npy')) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandOutOverwrite(self): self.parser = saved_model_cli.create_parser() @@ -340,9 +341,9 @@ Method name is: tensorflow/serving/predict""" test.get_temp_dir(), '--overwrite' ]) saved_model_cli.run(args) - y = np.load(output_file) - y_exp = np.array([[2.5], [3.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(output_file) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandOutputFileExistError(self): self.parser = saved_model_cli.create_parser() @@ -362,6 +363,37 @@ Method name is: tensorflow/serving/predict""" with self.assertRaises(RuntimeError): saved_model_cli.run(args) + def testRunCommandWithDebuggerEnabled(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + x = np.array([[1], [2]]) + x_notused = np.zeros((6, 3)) + input_path = os.path.join(test.get_temp_dir(), + 'testRunCommandNewOutdir_inputs.npz') + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + if os.path.isdir(output_dir): + shutil.rmtree(output_dir) + 
np.savez(input_path, x0=x, x1=x_notused) + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir', + output_dir, '--tf_debug' + ]) + + def fake_wrapper_session(sess): + return sess + + with test.mock.patch.object(local_cli_wrapper, + 'LocalCLIDebugWrapperSession', + side_effect=fake_wrapper_session, + autospec=True) as fake: + saved_model_cli.run(args) + fake.assert_called_with(test.mock.ANY) + + y_actual = np.load(os.path.join(output_dir, 'y.npy')) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py index 08f33657510..50d480f8707 100644 --- a/tensorflow/python/util/all_util.py +++ b/tensorflow/python/util/all_util.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect as _inspect import re as _re import sys as _sys +from tensorflow.python.util import tf_inspect as _tf_inspect + + _reference_pattern = _re.compile(r'^@@(\w+)$', flags=_re.MULTILINE) @@ -45,7 +47,7 @@ def make_all(module_name, doc_string_modules=None): if doc_string_modules is None: doc_string_modules = [_sys.modules[module_name]] cur_members = set([name for name, _ - in _inspect.getmembers(_sys.modules[module_name])]) + in _tf_inspect.getmembers(_sys.modules[module_name])]) results = set() for doc_module in doc_string_modules: diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 60b559b5f4e..73fc3e24087 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -20,11 +20,12 @@ from __future__ import print_function import collections import functools -import inspect import re from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import decorator_utils +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect def _add_deprecated_function_notice_to_docstring(doc, date, instructions): @@ -59,7 +60,7 @@ def _validate_deprecation_args(date, instructions): def _call_location(): """Returns call location given level up from current call.""" - frame = inspect.currentframe() + frame = tf_inspect.currentframe() if frame: # CPython internals are available, use them for performance. # walk back two frames to get to deprecated function caller. @@ -69,7 +70,7 @@ def _call_location(): return '%s:%d' % (frame.f_code.co_filename, frame.f_lineno) else: # Slow fallback path - stack = inspect.stack(0) # 0 avoids generating unused context + stack = tf_inspect.stack(0) # 0 avoids generating unused context entry = stack[2] return '%s:%d' % (entry[1], entry[2]) @@ -119,9 +120,10 @@ def deprecated(date, instructions): 'in a future version' if date is None else ('after %s' % date), instructions) return func(*args, **kwargs) - new_func.__doc__ = _add_deprecated_function_notice_to_docstring( - func.__doc__, date, instructions) - return new_func + return tf_decorator.make_decorator( + func, new_func, 'deprecated', + _add_deprecated_function_notice_to_docstring(func.__doc__, date, + instructions)) return deprecated_wrapper @@ -193,7 +195,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): Args: names_to_ok_vals: dict from string arg_name to a list of values, possibly empty, which should not elicit a warning. 
- arg_spec: Output from inspect.getargspec on the called function. + arg_spec: Output from tf_inspect.getargspec on the called function. Returns: Dictionary from arg_name to DeprecatedArgSpec. @@ -213,7 +215,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): decorator_utils.validate_callable(func, 'deprecated_args') deprecated_arg_names = _get_arg_names_to_ok_vals() - arg_spec = inspect.getargspec(func) + arg_spec = tf_inspect.getargspec(func) deprecated_positions = _get_deprecated_positional_arguments( deprecated_arg_names, arg_spec) @@ -260,7 +262,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): def new_func(*args, **kwargs): """Deprecation wrapper.""" invalid_args = [] - named_args = inspect.getcallargs(func, *args, **kwargs) + named_args = tf_inspect.getcallargs(func, *args, **kwargs) for arg_name, spec in iter(deprecated_positions.items()): if (spec.position < len(args) and not (spec.has_ok_value and @@ -285,9 +287,9 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): 'in a future version' if date is None else ('after %s' % date), instructions) return func(*args, **kwargs) - new_func.__doc__ = _add_deprecated_arg_notice_to_docstring( - func.__doc__, date, instructions) - return new_func + return tf_decorator.make_decorator(func, new_func, 'deprecated', + _add_deprecated_arg_notice_to_docstring( + func.__doc__, date, instructions)) return deprecated_wrapper @@ -332,7 +334,7 @@ def deprecated_arg_values(date, instructions, **deprecated_kwargs): @functools.wraps(func) def new_func(*args, **kwargs): """Deprecation wrapper.""" - named_args = inspect.getcallargs(func, *args, **kwargs) + named_args = tf_inspect.getcallargs(func, *args, **kwargs) for arg_name, arg_value in deprecated_kwargs.items(): if arg_name in named_args and named_args[arg_name] == arg_value: logging.warning( @@ -343,9 +345,9 @@ def deprecated_arg_values(date, instructions, **deprecated_kwargs): 'in a future version' if date is None else ('after %s' % date), instructions) return func(*args, **kwargs) - new_func.__doc__ = _add_deprecated_arg_notice_to_docstring( - func.__doc__, date, instructions) - return new_func + return tf_decorator.make_decorator(func, new_func, 'deprecated', + _add_deprecated_arg_notice_to_docstring( + func.__doc__, date, instructions)) return deprecated_wrapper diff --git a/tensorflow/python/util/tf_contextlib.py b/tensorflow/python/util/tf_contextlib.py new file mode 100644 index 00000000000..3830014d4ac --- /dev/null +++ b/tensorflow/python/util/tf_contextlib.py @@ -0,0 +1,36 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""TFDecorator-aware replacements for the contextlib module.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib as _contextlib + +from tensorflow.python.util import tf_decorator + + +def contextmanager(target): + """A tf_decorator-aware wrapper for `contextlib.contextmanager`. + + Usage is identical to `contextlib.contextmanager`. + + Args: + target: A callable to be wrapped in a contextmanager. + Returns: + A callable that can be used inside of a `with` statement. + """ + context_manager = _contextlib.contextmanager(target) + return tf_decorator.make_decorator(target, context_manager, 'contextmanager') diff --git a/tensorflow/python/util/tf_contextlib_test.py b/tensorflow/python/util/tf_contextlib_test.py new file mode 100644 index 00000000000..4a5bf388a63 --- /dev/null +++ b/tensorflow/python/util/tf_contextlib_test.py @@ -0,0 +1,92 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for tf_contextlib.""" + +# pylint: disable=unused-import +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +from tensorflow.python.util import tf_contextlib +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + + +@tf_contextlib.contextmanager +def test_yield_append_before_and_after_yield(x, before, after): + x.append(before) + yield + x.append(after) + + +@tf_contextlib.contextmanager +def test_yield_return_x_plus_1(x): + yield x + 1 + + +@tf_contextlib.contextmanager +def test_params_and_defaults(a, b=2, c=True, d='hello'): + return [a, b, c, d] + + +class TfContextlibTest(test.TestCase): + + def testRunsCodeBeforeYield(self): + x = [] + with test_yield_append_before_and_after_yield(x, 'before', ''): + self.assertEqual('before', x[-1]) + + def testRunsCodeAfterYield(self): + x = [] + with test_yield_append_before_and_after_yield(x, '', 'after'): + pass + self.assertEqual('after', x[-1]) + + def testNestedWith(self): + x = [] + with test_yield_append_before_and_after_yield(x, 'before', 'after'): + with test_yield_append_before_and_after_yield(x, 'inner', 'outer'): + with test_yield_return_x_plus_1(1) as var: + x.append(var) + self.assertEqual(['before', 'inner', 2, 'outer', 'after'], x) + + def testMultipleCallsOfSeparateInstances(self): + x = [] + with test_yield_append_before_and_after_yield(x, 1, 2): + pass + with test_yield_append_before_and_after_yield(x, 3, 4): + pass + self.assertEqual([1, 2, 3, 4], x) + + def testReturnsResultFromYield(self): + with test_yield_return_x_plus_1(3) as result: + self.assertEqual(4, result) + + def testUnwrapContextManager(self): + decorators, target = tf_decorator.unwrap(test_params_and_defaults) + 
self.assertEqual(1, len(decorators))
+    self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator))
+    self.assertEqual('contextmanager', decorators[0].decorator_name)
+    self.assertFalse(isinstance(target, tf_decorator.TFDecorator))
+
+  def testGetArgSpecReturnsWrappedArgSpec(self):
+    argspec = tf_inspect.getargspec(test_params_and_defaults)
+    self.assertEqual(['a', 'b', 'c', 'd'], argspec.args)
+    self.assertEqual((2, True, 'hello'), argspec.defaults)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py
new file mode 100644
index 00000000000..a5d979e376c
--- /dev/null
+++ b/tensorflow/python/util/tf_decorator.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base TFDecorator class and utility functions for working with decorators.
+
+There are two ways to create decorators that TensorFlow can introspect into.
+This is important for documentation generation purposes, so that function
+signatures aren't obscured by the (*args, **kwds) signature that decorators
+often provide.
+
+1. Call `tf_decorator.make_decorator` on your wrapper function. If your
+decorator is stateless, or can capture all of the variables it needs to work
+with through lexical closure, this is the simplest option. Create your wrapper
+function as usual, but instead of returning it, return
+`tf_decorator.make_decorator(target, your_wrapper)`. This will attach some
+decorator introspection metadata onto your wrapper and return it.
+
+Example:
+
+  def print_hello_before_calling(target):
+    def wrapper(*args, **kwargs):
+      print('hello')
+      return target(*args, **kwargs)
+    return tf_decorator.make_decorator(target, wrapper)
+
+2. Derive from TFDecorator. If your decorator needs to be stateful, you can
+implement it in terms of a TFDecorator. Store whatever state you need in your
+derived class, and implement the `__call__` method to do your work before
+calling into your target. You can retrieve the target via
+`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
+parameters it needs.
+
+Example:
+
+  class CallCounter(tf_decorator.TFDecorator):
+    def __init__(self, target):
+      super(CallCounter, self).__init__('count_calls', target)
+      self.call_count = 0
+
+    def __call__(self, *args, **kwargs):
+      self.call_count += 1
+      return super(CallCounter, self).decorated_target(*args, **kwargs)
+
+  def count_calls(target):
+    return CallCounter(target)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools as _functools
+import inspect as _inspect
+
+
+def make_decorator(target,
+                   decorator_func,
+                   decorator_name=None,
+                   decorator_doc='',
+                   decorator_argspec=None):
+  """Make a decorator from a wrapper and a target.
+
+  Args:
+    target: The final callable to be wrapped.
+ decorator_func: The wrapper function. + decorator_name: The name of the decorator. If `None`, the name of the + function calling make_decorator. + decorator_doc: Documentation specific to this application of + `decorator_func` to `target`. + decorator_argspec: The new callable signature of this decorator. + + Returns: + The `decorator_func` argument with new metadata attached. + """ + if decorator_name is None: + decorator_name = _inspect.stack()[1][3] # Caller's name. + decorator = TFDecorator(decorator_name, target, decorator_doc, + decorator_argspec) + setattr(decorator_func, '_tf_decorator', decorator) + decorator_func.__name__ = target.__name__ + decorator_func.__doc__ = decorator.__doc__ + decorator_func.__wrapped__ = target + return decorator_func + + +def unwrap(maybe_tf_decorator): + """Unwraps an object into a list of TFDecorators and a final target. + + Args: + maybe_tf_decorator: Any callable object. + + Returns: + A tuple whose first element is an list of TFDecorator-derived objects that + were applied to the final callable target, and whose second element is the + final undecorated callable target. If the `maybe_tf_decorator` parameter is + not decorated by any TFDecorators, the first tuple element will be an empty + list. The `TFDecorator` list is ordered from outermost to innermost + decorators. + """ + decorators = [] + cur = maybe_tf_decorator + while True: + if isinstance(cur, TFDecorator): + decorators.append(cur) + elif hasattr(cur, '_tf_decorator'): + decorators.append(getattr(cur, '_tf_decorator')) + else: + break + cur = decorators[-1].decorated_target + return decorators, cur + + +class TFDecorator(object): + """Base class for all TensorFlow decorators. + + TFDecorator captures and exposes the wrapped target, and provides details + about the current decorator. + """ + + def __init__(self, + decorator_name, + target, + decorator_doc='', + decorator_argspec=None): + self._decorated_target = target + self._decorator_name = decorator_name + self._decorator_doc = decorator_doc + self._decorator_argspec = decorator_argspec + self.__name__ = target.__name__ + if self._decorator_doc: + self.__doc__ = self._decorator_doc + elif target.__doc__: + self.__doc__ = target.__doc__ + else: + self.__doc__ = '' + + def __get__(self, obj, objtype): + return _functools.partial(self.__call__, obj) + + def __call__(self, *args, **kwargs): + return self._decorated_target(*args, **kwargs) + + @property + def decorated_target(self): + return self._decorated_target + + @property + def decorator_name(self): + return self._decorator_name + + @property + def decorator_doc(self): + return self._decorator_doc + + @property + def decorator_argspec(self): + return self._decorator_argspec diff --git a/tensorflow/python/util/tf_decorator_test.py b/tensorflow/python/util/tf_decorator_test.py new file mode 100644 index 00000000000..3f6a10b4408 --- /dev/null +++ b/tensorflow/python/util/tf_decorator_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for tf_decorator."""
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def test_tfdecorator(decorator_name, decorator_doc=None):
+
+  def make_tf_decorator(target):
+    return tf_decorator.TFDecorator(decorator_name, target, decorator_doc)
+
+  return make_tf_decorator
+
+
+def test_decorator_increment_first_int_arg(target):
+  """This test decorator skips past `self` as args[0] in the bound case."""
+
+  def wrapper(*args, **kwargs):
+    new_args = []
+    found = False
+    for arg in args:
+      if not found and isinstance(arg, int):
+        new_args.append(arg + 1)
+        found = True
+      else:
+        new_args.append(arg)
+    return target(*new_args, **kwargs)
+
+  return tf_decorator.make_decorator(target, wrapper)
+
+
+def test_function(x):
+  """Test Function Docstring."""
+  return x + 1
+
+
+@test_tfdecorator('decorator 1')
+@test_decorator_increment_first_int_arg
+@test_tfdecorator('decorator 3', 'decorator 3 documentation')
+def test_decorated_function(x):
+  """Test Decorated Function Docstring."""
+  return x * 2
+
+
+@test_tfdecorator('decorator')
+class TestDecoratedClass(object):
+  """Test Decorated Class."""
+
+  def __init__(self, two_attr=2):
+    self.two_attr = two_attr
+
+  @property
+  def two_prop(self):
+    return 2
+
+  def two_func(self):
+    return 2
+
+  @test_decorator_increment_first_int_arg
+  def return_params(self, a, b, c):
+    """Return parameters."""
+    return [a, b, c]
+
+
+class TfDecoratorTest(test.TestCase):
+
+  def testInitCapturesTarget(self):
+    self.assertIs(test_function,
+                  tf_decorator.TFDecorator('', test_function).decorated_target)
+
+  def testInitCapturesDecoratorName(self):
+    self.assertEqual('decorator name',
+                     tf_decorator.TFDecorator('decorator name',
+                                              test_function).decorator_name)
+
+  def testInitCapturesDecoratorDoc(self):
+    self.assertEqual('decorator doc',
+                     tf_decorator.TFDecorator('', test_function,
+                                              'decorator doc').decorator_doc)
+
+  def testInitCapturesNonNoneArgspec(self):
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+    self.assertIs(argspec,
+                  tf_decorator.TFDecorator('', test_function, '',
+                                           argspec).decorator_argspec)
+
+  def testInitSetsDecoratorNameToTargetName(self):
+    self.assertEqual('test_function',
+                     tf_decorator.TFDecorator('', test_function).__name__)
+
+  def testInitSetsDecoratorDocToTargetDoc(self):
+    self.assertEqual('Test Function Docstring.',
+                     tf_decorator.TFDecorator('', test_function).__doc__)
+
+  def testCallingATFDecoratorCallsTheTarget(self):
+    self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123))
+
+  def testCallingADecoratedFunctionCallsTheTarget(self):
+    self.assertEqual((2 + 1) * 2, test_decorated_function(2))
+
+  def testInitializingDecoratedClassWithInitParamsDoesntRaise(self):
+    try:
+      TestDecoratedClass(2)
+    except TypeError:
+      self.fail()
+
+  def testReadingClassAttributeOnDecoratedClass(self):
+    self.assertEqual(2, TestDecoratedClass().two_attr)
+
+  def testCallingClassMethodOnDecoratedClass(self):
+    self.assertEqual(2, TestDecoratedClass().two_func())
+
+  def
testReadingClassPropertyOnDecoratedClass(self): + self.assertEqual(2, TestDecoratedClass().two_prop) + + def testNameOnBoundProperty(self): + self.assertEqual('return_params', + TestDecoratedClass().return_params.__name__) + + def testDocstringOnBoundProperty(self): + self.assertEqual('Return parameters.', + TestDecoratedClass().return_params.__doc__) + + +def test_wrapper(*args, **kwargs): + return test_function(*args, **kwargs) + + +class TfMakeDecoratorTest(test.TestCase): + + def testAttachesATFDecoratorAttr(self): + decorated = tf_decorator.make_decorator(test_function, test_wrapper) + decorator = getattr(decorated, '_tf_decorator') + self.assertIsInstance(decorator, tf_decorator.TFDecorator) + + def testAttachesWrappedAttr(self): + decorated = tf_decorator.make_decorator(test_function, test_wrapper) + wrapped_attr = getattr(decorated, '__wrapped__') + self.assertIs(test_function, wrapped_attr) + + def testSetsTFDecoratorNameToDecoratorNameArg(self): + decorated = tf_decorator.make_decorator(test_function, test_wrapper, + 'test decorator name') + decorator = getattr(decorated, '_tf_decorator') + self.assertEqual('test decorator name', decorator.decorator_name) + + def testSetsTFDecoratorDocToDecoratorDocArg(self): + decorated = tf_decorator.make_decorator( + test_function, test_wrapper, decorator_doc='test decorator doc') + decorator = getattr(decorated, '_tf_decorator') + self.assertEqual('test decorator doc', decorator.decorator_doc) + + def testSetsTFDecoratorArgSpec(self): + argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '', + argspec) + decorator = getattr(decorated, '_tf_decorator') + self.assertEqual(argspec, decorator.decorator_argspec) + + def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self): + + def test_decorator_name(wrapper): + return tf_decorator.make_decorator(test_function, wrapper) + + decorated = test_decorator_name(test_wrapper) + decorator = getattr(decorated, '_tf_decorator') + self.assertEqual('test_decorator_name', decorator.decorator_name) + + +class TfDecoratorUnwrapTest(test.TestCase): + + def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self): + decorators, _ = tf_decorator.unwrap(test_function) + self.assertEqual(0, len(decorators)) + + def testUnwrapReturnsUndecoratedFunctionAsTarget(self): + _, target = tf_decorator.unwrap(test_function) + self.assertIs(test_function, target) + + def testUnwrapReturnsFinalFunctionAsTarget(self): + self.assertEqual((4 + 1) * 2, test_decorated_function(4)) + _, target = tf_decorator.unwrap(test_decorated_function) + self.assertTrue(tf_inspect.isfunction(target)) + self.assertEqual(4 * 2, target(4)) + + def testUnwrapReturnsListOfUniqueTFDecorators(self): + decorators, _ = tf_decorator.unwrap(test_decorated_function) + self.assertEqual(3, len(decorators)) + self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator)) + self.assertTrue(isinstance(decorators[1], tf_decorator.TFDecorator)) + self.assertTrue(isinstance(decorators[2], tf_decorator.TFDecorator)) + self.assertIsNot(decorators[0], decorators[1]) + self.assertIsNot(decorators[1], decorators[2]) + self.assertIsNot(decorators[2], decorators[0]) + + def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self): + decorators, _ = tf_decorator.unwrap(test_decorated_function) + self.assertEqual('decorator 1', decorators[0].decorator_name) + self.assertEqual('test_decorator_increment_first_int_arg', + 
decorators[1].decorator_name) + self.assertEqual('decorator 3', decorators[2].decorator_name) + self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc) + + def testUnwrapBoundMethods(self): + test_decorated_class = TestDecoratedClass() + self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3)) + decorators, target = tf_decorator.unwrap(test_decorated_class.return_params) + self.assertEqual('test_decorator_increment_first_int_arg', + decorators[0].decorator_name) + self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py new file mode 100644 index 00000000000..977b0df08b5 --- /dev/null +++ b/tensorflow/python/util/tf_inspect.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFDecorator-aware replacements for the inspect module.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect as _inspect + +from tensorflow.python.util import tf_decorator + +ArgSpec = _inspect.ArgSpec + + +def currentframe(): + """TFDecorator-aware replacement for inspect.currentframe.""" + return _inspect.stack()[1][0] + + +def getargspec(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getargspec. + + Args: + object: A callable, possibly decorated. + + Returns: + The `ArgSpec` that describes the signature of the outermost decorator that + changes the callable's signature. If the callable is not decorated, + `inspect.getargspec()` will be called directly on the callable. + """ + decorators, target = tf_decorator.unwrap(object) + return next((d.decorator_argspec for d in decorators + if d.decorator_argspec is not None), _inspect.getargspec(target)) + + +def getcallargs(func, *positional, **named): + """TFDecorator-aware replacement for inspect.getcallargs. + + Args: + func: A callable, possibly decorated + *positional: The positional arguments that would be passed to `func`. + **named: The named argument dictionary that would be passed to `func`. + + Returns: + A dictionary mapping `func`'s named arguments to the values they would + receive if `func(*positional, **named)` were called. + + `getcallargs` will use the argspec from the outermost decorator that provides + it. If no attached decorators modify argspec, the final unwrapped target's + argspec will be used. 
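+
+  Example (an illustrative sketch):
+
+    def func(a, b=2):
+      return (a, b)
+
+    getcallargs(func, 1)       # => {'a': 1, 'b': 2}
+    getcallargs(func, 1, b=5)  # => {'a': 1, 'b': 5}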
+ """ + argspec = getargspec(func) + call_args = named.copy() + this = getattr(func, 'im_self', None) or getattr(func, '__self__', None) + if ismethod(func) and this: + positional = (this,) + positional + remaining_positionals = [arg for arg in argspec.args if arg not in call_args] + call_args.update(dict(zip(remaining_positionals, positional))) + default_count = 0 if not argspec.defaults else len(argspec.defaults) + if default_count: + for arg, value in zip(argspec.args[-default_count:], argspec.defaults): + if arg not in call_args: + call_args[arg] = value + return call_args + + +def getdoc(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getdoc. + + Args: + object: An object, possibly decorated. + + Returns: + The docstring associated with the object. + + The outermost-decorated object is intended to have the most complete + documentation, so the decorated parameter is not unwrapped. + """ + return _inspect.getdoc(object) + + +def getfile(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getfile.""" + return _inspect.getfile(tf_decorator.unwrap(object)[1]) + + +def getmembers(object, predicate=None): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getmembers.""" + return _inspect.getmembers(object, predicate) + + +def getmro(cls): + """TFDecorator-aware replacement for inspect.getmro.""" + return _inspect.getmro(cls) + + +def getsource(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getsource.""" + return _inspect.getsource(tf_decorator.unwrap(object)[1]) + + +def isclass(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isclass.""" + return _inspect.isclass(tf_decorator.unwrap(object)[1]) + + +def isfunction(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isfunction.""" + return _inspect.isfunction(tf_decorator.unwrap(object)[1]) + + +def ismethod(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.ismethod.""" + return _inspect.ismethod(tf_decorator.unwrap(object)[1]) + + +def ismodule(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.ismodule.""" + return _inspect.ismodule(tf_decorator.unwrap(object)[1]) + + +def isroutine(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isroutine.""" + return _inspect.isroutine(tf_decorator.unwrap(object)[1]) + + +def stack(context=1): + """TFDecorator-aware replacement for inspect.stack.""" + return _inspect.stack(context)[1:] diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py new file mode 100644 index 00000000000..a9e8ffb30c3 --- /dev/null +++ b/tensorflow/python/util/tf_inspect_test.py @@ -0,0 +1,327 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unit tests for tf_inspect.""" + +# pylint: disable=unused-import +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + + +def test_decorator(decorator_name, decorator_doc=None): + + def make_tf_decorator(target): + return tf_decorator.TFDecorator(decorator_name, target, decorator_doc) + + return make_tf_decorator + + +def test_undecorated_function(): + pass + + +@test_decorator('decorator 1') +@test_decorator('decorator 2') +@test_decorator('decorator 3') +def test_decorated_function(x): + """Test Decorated Function Docstring.""" + return x * 2 + + +@test_decorator('decorator') +def test_decorated_function_with_defaults(a, b=2, c='Hello'): + """Test Decorated Function With Defaults Docstring.""" + return [a, b, c] + + +@test_decorator('decorator') +class TestDecoratedClass(object): + """Test Decorated Class.""" + + def __init__(self): + pass + + def two(self): + return 2 + + +class TfInspectTest(test.TestCase): + + def testCurrentFrame(self): + self.assertEqual(inspect.currentframe(), tf_inspect.currentframe()) + + def testGetArgSpecOnDecoratorsThatDontProvideArgspec(self): + argspec = tf_inspect.getargspec(test_decorated_function_with_defaults) + self.assertEqual(['a', 'b', 'c'], argspec.args) + self.assertEqual((2, 'Hello'), argspec.defaults) + + def testGetArgSpecOnDecoratorThatChangesArgspec(self): + argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + + decorator = tf_decorator.TFDecorator('', test_undecorated_function, '', + argspec) + self.assertEqual(argspec, tf_inspect.getargspec(decorator)) + + def testGetArgSpecIgnoresDecoratorsThatDontProvideArgspec(self): + argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + + inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function, + '', argspec) + outer_decorator = tf_decorator.TFDecorator('', inner_decorator) + self.assertEqual(argspec, tf_inspect.getargspec(outer_decorator)) + + def testGetArgSpecReturnsOutermostDecoratorThatChangesArgspec(self): + outer_argspec = tf_inspect.ArgSpec( + args=['a'], varargs=None, keywords=None, defaults=None) + inner_argspec = tf_inspect.ArgSpec( + args=['b'], varargs=None, keywords=None, defaults=None) + + inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function, + '', inner_argspec) + outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '', + outer_argspec) + self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator)) + + def testGetDoc(self): + self.assertEqual('Test Decorated Function With Defaults Docstring.', + tf_inspect.getdoc(test_decorated_function_with_defaults)) + + def testGetFile(self): + self.assertTrue('tf_inspect_test.py' in tf_inspect.getfile( + test_decorated_function_with_defaults)) + self.assertTrue('tf_decorator.py' in tf_inspect.getfile( + test_decorator('decorator')(tf_decorator.unwrap))) + + def testGetMembers(self): + self.assertEqual( + inspect.getmembers(TestDecoratedClass), + tf_inspect.getmembers(TestDecoratedClass)) + + def testGetSource(self): + expected = '''@test_decorator('decorator') +def test_decorated_function_with_defaults(a, b=2, 
c='Hello'): + """Test Decorated Function With Defaults Docstring.""" + return [a, b, c] +''' + self.assertEqual( + expected, tf_inspect.getsource(test_decorated_function_with_defaults)) + + def testIsClass(self): + self.assertTrue(tf_inspect.isclass(TestDecoratedClass)) + self.assertFalse(tf_inspect.isclass(test_decorated_function)) + + def testIsFunction(self): + self.assertTrue(tf_inspect.isfunction(test_decorated_function)) + self.assertFalse(tf_inspect.isfunction(TestDecoratedClass)) + + def testIsMethod(self): + self.assertTrue(tf_inspect.ismethod(TestDecoratedClass().two)) + self.assertFalse(tf_inspect.ismethod(test_decorated_function)) + + def testIsModule(self): + self.assertTrue( + tf_inspect.ismodule(inspect.getmodule(inspect.currentframe()))) + self.assertFalse(tf_inspect.ismodule(test_decorated_function)) + + def testIsRoutine(self): + self.assertTrue(tf_inspect.isroutine(len)) + self.assertFalse(tf_inspect.isroutine(TestDecoratedClass)) + + def testStack(self): + expected_stack = inspect.stack() + actual_stack = tf_inspect.stack() + self.assertEqual(len(expected_stack), len(actual_stack)) + self.assertEqual(expected_stack[0][0], actual_stack[0][0]) # Frame object + self.assertEqual(expected_stack[0][1], actual_stack[0][1]) # Filename + self.assertEqual(expected_stack[0][2], + actual_stack[0][2] - 1) # Line number + self.assertEqual(expected_stack[0][3], actual_stack[0][3]) # Function name + self.assertEqual(expected_stack[1:], actual_stack[1:]) + + +class TfInspectGetCallArgsTest(test.TestCase): + + def testReturnsEmptyWhenUnboundFuncHasNoParameters(self): + + def empty(): + pass + + self.assertEqual({}, tf_inspect.getcallargs(empty)) + + def testUnboundFuncWithOneParamPositional(self): + + def func(a): + return a + + self.assertEqual({'a': 5}, tf_inspect.getcallargs(func, 5)) + + def testUnboundFuncWithTwoParamsPositional(self): + + def func(a, b): + return (a, b) + + self.assertEqual({'a': 10, 'b': 20}, tf_inspect.getcallargs(func, 10, 20)) + + def testUnboundFuncWithOneParamKeyword(self): + + def func(a): + return a + + self.assertEqual({'a': 5}, tf_inspect.getcallargs(func, a=5)) + + def testUnboundFuncWithTwoParamsKeyword(self): + + def func(a, b): + return (a, b) + + self.assertEqual({'a': 6, 'b': 7}, tf_inspect.getcallargs(func, a=6, b=7)) + + def testUnboundFuncWithOneParamDefault(self): + + def func(a=13): + return a + + self.assertEqual({'a': 13}, tf_inspect.getcallargs(func)) + + def testUnboundFuncWithOneParamDefaultOnePositional(self): + + def func(a=0): + return a + + self.assertEqual({'a': 1}, tf_inspect.getcallargs(func, 1)) + + def testUnboundFuncWithTwoParamsDefaultOnePositional(self): + + def func(a=1, b=2): + return (a, b) + + self.assertEqual({'a': 5, 'b': 2}, tf_inspect.getcallargs(func, 5)) + + def testUnboundFuncWithTwoParamsDefaultTwoPositional(self): + + def func(a=1, b=2): + return (a, b) + + self.assertEqual({'a': 3, 'b': 4}, tf_inspect.getcallargs(func, 3, 4)) + + def testUnboundFuncWithOneParamDefaultOneKeyword(self): + + def func(a=1): + return a + + self.assertEqual({'a': 3}, tf_inspect.getcallargs(func, a=3)) + + def testUnboundFuncWithTwoParamsDefaultOneKeywordFirst(self): + + def func(a=1, b=2): + return (a, b) + + self.assertEqual({'a': 3, 'b': 2}, tf_inspect.getcallargs(func, a=3)) + + def testUnboundFuncWithTwoParamsDefaultOneKeywordSecond(self): + + def func(a=1, b=2): + return (a, b) + + self.assertEqual({'a': 1, 'b': 4}, tf_inspect.getcallargs(func, b=4)) + + def testUnboundFuncWithTwoParamsDefaultTwoKeywords(self): + + def func(a=1, 
b=2): + return (a, b) + + self.assertEqual({'a': 3, 'b': 4}, tf_inspect.getcallargs(func, a=3, b=4)) + + def testBoundFuncWithOneParam(self): + + class Test(object): + + def bound(self): + pass + + t = Test() + self.assertEqual({'self': t}, tf_inspect.getcallargs(t.bound)) + + def testBoundFuncWithManyParamsAndDefaults(self): + + class Test(object): + + def bound(self, a, b=2, c='Hello'): + return (a, b, c) + + t = Test() + self.assertEqual({ + 'self': t, + 'a': 3, + 'b': 2, + 'c': 'Goodbye' + }, tf_inspect.getcallargs(t.bound, 3, c='Goodbye')) + + def testClassMethod(self): + + class Test(object): + + @classmethod + def test(cls, a, b=3, c='hello'): + return (a, b, c) + + self.assertEqual({ + 'cls': Test, + 'a': 5, + 'b': 3, + 'c': 'goodbye' + }, tf_inspect.getcallargs(Test.test, 5, c='goodbye')) + + def testUsesOutermostDecoratorsArgSpec(self): + + def func(): + pass + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + decorated = tf_decorator.make_decorator( + func, + wrapper, + decorator_argspec=tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(3, 'hello'))) + + self.assertEqual({ + 'a': 4, + 'b': 3, + 'c': 'goodbye' + }, tf_inspect.getcallargs(decorated, 4, c='goodbye')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG index 59343b09ec7..fb1e7bc8699 100644 --- a/tensorflow/tensorboard/TAG +++ b/tensorflow/tensorboard/TAG @@ -1 +1 @@ -53 +54 diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD new file mode 100644 index 00000000000..8c222be10e9 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD @@ -0,0 +1,19 @@ +# Description: +# Package for the Embedding Projector component. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts new file mode 100644 index 00000000000..aa1f86927da --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts @@ -0,0 +1,67 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {ProjectionType} from './data'; + +export class AnalyticsLogger { + private eventLogging: boolean; + private pageViewLogging: boolean; + + /** + * Constructs an event logger using Google Analytics. It assumes there is a + * Google Analytics script added to the page elsewhere. If there is no such + * script, the logger acts as a no-op. + * + * @param pageViewLogging Whether to log page views. 
+   * @param eventLogging Whether to log user interaction.
+   */
+  constructor(pageViewLogging: boolean, eventLogging: boolean) {
+    if (typeof ga === 'undefined' || ga == null) {
+      this.eventLogging = false;
+      this.pageViewLogging = false;
+      return;
+    }
+    this.eventLogging = eventLogging;
+    this.pageViewLogging = pageViewLogging;
+  }
+
+  logPageView(pageTitle: string) {
+    if (this.pageViewLogging) {
+      // Always send a page view.
+      ga('send', {hitType: 'pageview', page: `/v/${pageTitle}`});
+    }
+  }
+
+  logProjectionChanged(projection: ProjectionType) {
+    if (this.eventLogging) {
+      ga('send', {
+        hitType: 'event',
+        eventCategory: 'Projection',
+        eventAction: 'click',
+        eventLabel: projection
+      });
+    }
+  }
+
+  logWebGLDisabled() {
+    if (this.eventLogging) {
+      ga('send', {
+        hitType: 'event',
+        eventCategory: 'Error',
+        eventAction: 'PageLoad',
+        eventLabel: 'WebGL_disabled'
+      });
+    }
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts
new file mode 100644
index 00000000000..9d2df65f560
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts
@@ -0,0 +1,472 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/**
+ * This is a fork of Karpathy's TSNE.js (original license below).
+ * This fork implements the Barnes-Hut approximation and runs in O(N log N)
+ * time, as opposed to Karpathy's O(N^2) version.
+ *
+ * @author smilkov@google.com (Daniel Smilkov)
+ */
+
+/**
+ * The MIT License (MIT)
+ * Copyright (c) 2015 Andrej Karpathy
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+import {SPNode, SPTree} from './sptree';
+
+type AugmSPNode = SPNode&{numCells: number, yCell: number[], rCell: number};
+
+/**
+ * Barnes-Hut approximation level. Higher means more approximation and faster
+ * results. The recommended value mentioned in the paper is 0.8.
+ */
+const THETA = 0.8;
+
+const MIN_POSSIBLE_PROB = 1E-9;
+
+// Variables used for memoizing the second random number, since a single run of
+// gaussRandom() generates two random numbers at roughly the cost of one
+// computation. Caching the second value results in a 2X speed-up of the
+// generator.
+let return_v = false;
+let v_val = 0.0;
+
+/** Returns the square euclidean distance between two vectors. */
+export function dist2(a: number[], b: number[]): number {
+  if (a.length !== b.length) {
+    throw new Error('Vectors a and b must be of same length');
+  }
+
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    let diff = a[i] - b[i];
+    result += diff * diff;
+  }
+  return result;
+}
+
+/** Returns the square euclidean distance between two 2D points. */
+export function dist2_2D(a: number[], b: number[]): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  return dX * dX + dY * dY;
+}
+
+/** Returns the square euclidean distance between two 3D points. */
+export function dist2_3D(a: number[], b: number[]): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  let dZ = a[2] - b[2];
+  return dX * dX + dY * dY + dZ * dZ;
+}
+
+function gaussRandom(rng: () => number): number {
+  if (return_v) {
+    return_v = false;
+    return v_val;
+  }
+  let u = 2 * rng() - 1;
+  let v = 2 * rng() - 1;
+  let r = u * u + v * v;
+  if (r === 0 || r > 1) {
+    return gaussRandom(rng);
+  }
+  let c = Math.sqrt(-2 * Math.log(r) / r);
+  v_val = v * c;  // Cache this for the next function call for efficiency.
+  return_v = true;
+  return u * c;
+}
+
+// Returns a random number drawn from the normal distribution N(mu, std^2).
+function randn(rng: () => number, mu: number, std: number) {
+  return mu + gaussRandom(rng) * std;
+}
+
+// Utility that creates a contiguous vector of zeros of size n.
+function zeros(n: number): Float64Array {
+  return new Float64Array(n);
+}
+
+// Utility that returns a matrix filled with random numbers
+// generated by the provided generator.
+function randnMatrix(n: number, d: number, rng: () => number) {
+  let nd = n * d;
+  let x = zeros(nd);
+  for (let i = 0; i < nd; ++i) {
+    x[i] = randn(rng, 0.0, 1E-4);
+  }
+  return x;
+}
+
+// Utility that returns a matrix filled with the provided value.
+function arrayofs(n: number, d: number, val: number) {
+  let x: number[][] = [];
+  for (let i = 0; i < n; ++i) {
+    x.push(d === 3 ? [val, val, val] : [val, val]);
+  }
+  return x;
+}
+
+// Computes (p_{i|j} + p_{j|i})/(2n).
+function nearest2P(
+    nearest: {index: number, dist: number}[][], perplexity: number,
+    tol: number) {
+  let N = nearest.length;
+  let Htarget = Math.log(perplexity);  // Target entropy of the distribution.
+  let P = zeros(N * N);                // Temporary probability matrix.
+  let K = nearest[0].length;
+  let pRow: number[] = new Array(K);   // pij[].
+
+  for (let i = 0; i < N; ++i) {
+    let neighbors = nearest[i];
+    let betaMin = -Infinity;
+    let betaMax = Infinity;
+    let beta = 1;  // Initial value of precision.
+    let maxTries = 50;
+
+    // Perform a binary search to find a suitable precision beta
+    // so that the entropy of the distribution is appropriate.
+    let numTries = 0;
+    while (true) {
+      // Compute the entropy and kernel row with beta precision.
+      let psum = 0.0;
+      for (let k = 0; k < neighbors.length; ++k) {
+        let neighbor = neighbors[k];
+        let pij = (i === neighbor.index) ? 0 : Math.exp(-neighbor.dist * beta);
+        pij = Math.max(pij, MIN_POSSIBLE_PROB);
+        pRow[k] = pij;
+        psum += pij;
+      }
+      // Normalize p and compute the entropy.
+      let Hhere = 0.0;
+      for (let k = 0; k < pRow.length; ++k) {
+        pRow[k] /= psum;
+        let pij = pRow[k];
+        if (pij > 1E-7) {
+          Hhere -= pij * Math.log(pij);
+        }
+      }
+
+      // Adjust beta based on the result.
+      if (Hhere > Htarget) {
+        // The entropy was too high (distribution too diffuse), so we need to
+        // increase the precision for a more peaky distribution.
+        betaMin = beta;  // Move up the bounds.
+        if (betaMax === Infinity) {
+          beta = beta * 2;
+        } else {
+          beta = (beta + betaMax) / 2;
+        }
+
+      } else {
+        // Converse case: make the distribution less peaky.
+        betaMax = beta;
+        if (betaMin === -Infinity) {
+          beta = beta / 2;
+        } else {
+          beta = (beta + betaMin) / 2;
+        }
+      }
+      numTries++;
+      // Stopping conditions: too many tries, or we got a good precision.
+      if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) {
+        break;
+      }
+    }
+
+    // Copy over the final pRow to P at row i.
+    for (let k = 0; k < pRow.length; ++k) {
+      let pij = pRow[k];
+      let j = neighbors[k].index;
+      P[i * N + j] = pij;
+    }
+  }  // End loop over examples i.
+
+  // Symmetrize P and normalize it to sum to 1 over all ij.
+  let N2 = N * 2;
+  for (let i = 0; i < N; ++i) {
+    for (let j = i + 1; j < N; ++j) {
+      let i_j = i * N + j;
+      let j_i = j * N + i;
+      let value = (P[i_j] + P[j_i]) / N2;
+      P[i_j] = value;
+      P[j_i] = value;
+    }
+  }
+  return P;
+}
+
+// Helper function.
+function sign(x: number) {
+  return x > 0 ? 1 : x < 0 ? -1 : 0;
+}
+
+function computeForce_2d(
+    force: number[], mult: number, pointA: number[], pointB: number[]) {
+  force[0] += mult * (pointA[0] - pointB[0]);
+  force[1] += mult * (pointA[1] - pointB[1]);
+}
+
+function computeForce_3d(
+    force: number[], mult: number, pointA: number[], pointB: number[]) {
+  force[0] += mult * (pointA[0] - pointB[0]);
+  force[1] += mult * (pointA[1] - pointB[1]);
+  force[2] += mult * (pointA[2] - pointB[2]);
+}
+
+export interface TSNEOptions {
+  /** How many dimensions. */
+  dim: number;
+  /** Roughly how many neighbors each point influences. */
+  perplexity?: number;
+  /** Learning rate. */
+  epsilon?: number;
+  /** A random number generator. */
+  rng?: () => number;
+}
+
+export class TSNE {
+  private perplexity: number;
+  private epsilon: number;
+  /** Random generator */
+  private rng: () => number;
+  private iter = 0;
+  private Y: Float64Array;
+  private N: number;
+  private P: Float64Array;
+  private gains: number[][];
+  private ystep: number[][];
+  private nearest: {index: number, dist: number}[][];
+  private dim: number;
+  private dist2: (a: number[], b: number[]) => number;
+  private computeForce:
+      (force: number[], mult: number, pointA: number[],
+       pointB: number[]) => void;
+
+  constructor(opt: TSNEOptions) {
+    opt = opt || {dim: 2};
+    this.perplexity = opt.perplexity || 30;
+    this.epsilon = opt.epsilon || 10;
+    this.rng = opt.rng || Math.random;
+    this.dim = opt.dim;
+    if (opt.dim === 2) {
+      this.dist2 = dist2_2D;
+      this.computeForce = computeForce_2d;
+    } else if (opt.dim === 3) {
+      this.dist2 = dist2_3D;
+      this.computeForce = computeForce_3d;
+    } else {
+      throw new Error('Only 2D and 3D are supported');
+    }
+  }
+
+  // This function takes the list of nearest neighbors (with distances) for
+  // each point and computes the joint probability matrix P from it.
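+  //
+  // A minimal usage sketch with hypothetical values (callers such as data.ts
+  // pass k = floor(3 * perplexity) neighbors per point):
+  //   tsne.initDataDist([
+  //     [{index: 1, dist: 0.5}, {index: 2, dist: 1.2}],  // neighbors of 0
+  //     [{index: 0, dist: 0.5}, {index: 2, dist: 0.3}],  // neighbors of 1
+  //     [{index: 1, dist: 0.3}, {index: 0, dist: 1.2}],  // neighbors of 2
+  //   ]);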
+ initDataDist(nearest: {index: number, dist: number}[][]) { + let N = nearest.length; + this.nearest = nearest; + this.P = nearest2P(nearest, this.perplexity, 1E-4); + this.N = N; + this.initSolution(); // refresh this + } + + // (re)initializes the solution to random + initSolution() { + // generate random solution to t-SNE + this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution + this.gains = arrayofs(this.N, this.dim, 1.0); // step gains + // to accelerate progress in unchanging directions + this.ystep = arrayofs(this.N, this.dim, 0.0); // momentum accumulator + this.iter = 0; + } + + // return pointer to current solution + getSolution() { return this.Y; } + + // perform a single step of optimization to improve the embedding + step() { + this.iter += 1; + let N = this.N; + + let grad = this.costGrad(this.Y); // evaluate gradient + + // perform gradient step + let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + let gid = grad[i][d]; + let sid = this.ystep[i][d]; + let gainid = this.gains[i][d]; + + // compute gain update + let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2; + if (newgain < 0.01) { + newgain = 0.01; // clamp + } + this.gains[i][d] = newgain; // store for next turn + + // compute momentum step direction + let momval = this.iter < 250 ? 0.5 : 0.8; + let newsid = momval * sid - this.epsilon * newgain * grad[i][d]; + this.ystep[i][d] = newsid; // remember the step we took + + // step! + let i_d = i * this.dim + d; + this.Y[i_d] += newsid; + ymean[d] += this.Y[i_d]; // accumulate mean so that we + // can center later + } + } + + // reproject Y to be zero mean + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + this.Y[i * this.dim + d] -= ymean[d] / N; + } + } + } + + // return cost and gradient, given an arrangement + costGrad(Y: Float64Array): number[][] { + let N = this.N; + let P = this.P; + + // Trick that helps with local optima. + let alpha = this.iter < 100 ? 4 : 1; + + // Make data for the SP tree. + let points: number[][] = new Array(N); // (x, y)[] + for (let i = 0; i < N; ++i) { + let iTimesD = i * this.dim; + let row = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + row[d] = Y[iTimesD + d]; + } + points[i] = row; + } + + // Make a tree. + let tree = new SPTree(points); + let root = tree.root as AugmSPNode; + // Annotate the tree. + + let annotateTree = + (node: AugmSPNode): {numCells: number, yCell: number[]} => { + let numCells = 1; + if (node.children == null) { + // Update the current node and tell the parent. + node.numCells = numCells; + node.yCell = node.point; + return {numCells, yCell: node.yCell}; + } + // node.point is a 2 or 3-dim number[], so slice() makes a copy. + let yCell = node.point.slice(); + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child == null) { + continue; + } + let result = annotateTree(child as AugmSPNode); + numCells += result.numCells; + for (let d = 0; d < this.dim; ++d) { + yCell[d] += result.yCell[d]; + } + } + // Update the node and tell the parent. + node.numCells = numCells; + node.yCell = yCell.map(v => v / numCells); + return {numCells, yCell}; + }; + + // Augment the tree with more info. 
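+    // After annotation, every node carries numCells (the number of points in
+    // its subtree), yCell (their center of mass) and, via the visit below,
+    // rCell (the cell's edge length). The force loop then treats a whole cell
+    // as a single point whenever rCell / distance < THETA, which is what
+    // brings the cost down from O(N^2) to O(N log N).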
+ annotateTree(root); + tree.visit((node: AugmSPNode, low: number[], high: number[]) => { + node.rCell = high[0] - low[0]; + return false; + }); + // compute current Q distribution, unnormalized first + let grad: number[][] = []; + let Z = 0; + let forces: [number[], number[]][] = new Array(N); + for (let i = 0; i < N; ++i) { + let pointI = points[i]; + // Compute the positive forces for the i-th node. + let Fpos = this.dim === 3 ? [0, 0, 0] : [0, 0]; + let neighbors = this.nearest[i]; + for (let k = 0; k < neighbors.length; ++k) { + let j = neighbors[k].index; + let pij = P[i * N + j]; + let pointJ = points[j]; + let squaredDistItoJ = this.dist2(pointI, pointJ); + let premult = pij / (1 + squaredDistItoJ); + this.computeForce(Fpos, premult, pointI, pointJ); + } + // Compute the negative forces for the i-th node. + let FnegZ = this.dim === 3 ? [0, 0, 0] : [0, 0]; + tree.visit((node: AugmSPNode) => { + let squaredDistToCell = this.dist2(pointI, node.yCell); + // Squared distance from point i to cell. + if (node.children == null || + (squaredDistToCell > 0 && + node.rCell / Math.sqrt(squaredDistToCell) < THETA)) { + let qijZ = 1 / (1 + squaredDistToCell); + let dZ = node.numCells * qijZ; + Z += dZ; + dZ *= qijZ; + this.computeForce(FnegZ, dZ, pointI, node.yCell); + return true; + } + // Cell is too close to approximate. + let squaredDistToPoint = this.dist2(pointI, node.point); + let qijZ = 1 / (1 + squaredDistToPoint); + Z += qijZ; + qijZ *= qijZ; + this.computeForce(FnegZ, qijZ, pointI, node.point); + return false; + }, true); + forces[i] = [Fpos, FnegZ]; + } + // Normalize the negative forces and compute the gradient. + const A = 4 * alpha; + const B = 4 / Z; + for (let i = 0; i < N; ++i) { + let [FPos, FNegZ] = forces[i]; + let gsum = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + gsum[d] = A * FPos[d] - B * FNegZ[d]; + } + grad.push(gsum); + } + return grad; + } +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts new file mode 100644 index 00000000000..1410a84a8e4 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts @@ -0,0 +1,127 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataSet, SpriteAndMetadataInfo, State} from './data'; +import {ProjectorConfig, DataProvider, EmbeddingInfo, TENSORS_MSG_ID} from './data-provider'; +import * as dataProvider from './data-provider'; +import * as logging from './logging'; + +const BYTES_EXTENSION = '.bytes'; + +/** Data provider that loads data from a demo folder. 
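+ * The folder is described by a projector config JSON file that lists the
+ * available embeddings along with their tensor, metadata, bookmark and
+ * sprite-image paths; tensor files ending in '.bytes' are fetched as raw
+ * floats rather than parsed as TSV.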
 */
+export class DemoDataProvider implements DataProvider {
+  private projectorConfigPath: string;
+  private projectorConfig: ProjectorConfig;
+
+  constructor(projectorConfigPath: string) {
+    this.projectorConfigPath = projectorConfigPath;
+  }
+
+  private getEmbeddingInfo(tensorName: string): EmbeddingInfo {
+    let embeddings = this.projectorConfig.embeddings;
+    for (let i = 0; i < embeddings.length; i++) {
+      let embedding = embeddings[i];
+      if (embedding.tensorName === tensorName) {
+        return embedding;
+      }
+    }
+    return null;
+  }
+
+  retrieveRuns(callback: (runs: string[]) => void): void {
+    callback(['Demo']);
+  }
+
+  retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void)
+      : void {
+    const msgId = logging.setModalMessage('Fetching projector config...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', this.projectorConfigPath);
+    xhr.onerror = (err) => {
+      let errorMessage = err.message;
+      // If the failed request still produced response text, this is likely a
+      // cross-origin request error.
+      if (xhr.responseText != null) {
+        errorMessage = 'Cannot fetch projector config, possibly a ' +
+            'Cross-Origin request error.';
+      }
+      logging.setErrorMessage(errorMessage, 'fetching projector config');
+    };
+    xhr.onload = () => {
+      const projectorConfig = JSON.parse(xhr.responseText) as ProjectorConfig;
+      logging.setModalMessage(null, msgId);
+      this.projectorConfig = projectorConfig;
+      callback(projectorConfig);
+    };
+    xhr.send();
+  }
+
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let url = `${embedding.tensorPath}`;
+    if (embedding.tensorPath.substr(-1 * BYTES_EXTENSION.length) ===
+        BYTES_EXTENSION) {
+      dataProvider.retrieveTensorAsBytes(
+          this, this.getEmbeddingInfo(tensorName), run, tensorName, url,
+          callback);
+    } else {
+      logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID);
+      const request = new XMLHttpRequest();
+      request.open('GET', url);
+      request.responseType = 'arraybuffer';
+
+      request.onerror = () => {
+        logging.setErrorMessage(request.responseText, 'fetching tensors');
+      };
+      request.onload = () => {
+        dataProvider.parseTensors(request.response).then(points => {
+          callback(new DataSet(points));
+        });
+      };
+      request.send();
+    }
+  }
+
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let spriteImagePath = null;
+    if (embedding.sprite && embedding.sprite.imagePath) {
+      spriteImagePath = embedding.sprite.imagePath;
+    }
+    dataProvider.retrieveSpriteAndMetadataInfo(
+        embedding.metadataPath, spriteImagePath, embedding.sprite, callback);
+  }
+
+  getBookmarks(
+      run: string, tensorName: string, callback: (r: State[]) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let msgId = logging.setModalMessage('Fetching bookmarks...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', embedding.bookmarksPath);
+    xhr.onerror = (err) => {
+      logging.setErrorMessage(xhr.responseText);
+    };
+    xhr.onload = () => {
+      const bookmarks = JSON.parse(xhr.responseText) as State[];
+      logging.setModalMessage(null, msgId);
+      callback(bookmarks);
+    };
+    xhr.send();
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts
new file mode 100644
index 00000000000..67124a92323
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataPoint, DataProto, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data';
+import {analyzeMetadata, ProjectorConfig, DataProvider} from './data-provider';
+
+
+export class ProtoDataProvider implements DataProvider {
+  private dataProto: DataProto;
+
+  constructor(dataProto: DataProto) {
+    this.dataProto = dataProto;
+  }
+
+  retrieveRuns(callback: (runs: string[]) => void): void {
+    callback(['proto']);
+  }
+
+  retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) {
+    callback({
+      modelCheckpointPath: 'proto',
+      embeddings: [{
+        tensorName: 'proto',
+        tensorShape: this.dataProto.shape,
+        metadataPath: 'proto'
+      }]
+    });
+  }
+
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void) {
+    callback(this.flatArrayToDataset(this.dataProto.tensor));
+  }
+
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void): void {
+    let columnNames = this.dataProto.metadata.columns.map(c => c.name);
+    let n = this.dataProto.shape[0];
+    let pointsMetadata: PointMetadata[] = new Array(n);
+    this.dataProto.metadata.columns.forEach(c => {
+      let values = c.numericValues || c.stringValues;
+      for (let i = 0; i < n; i++) {
+        pointsMetadata[i] = pointsMetadata[i] || {};
+        pointsMetadata[i][c.name] = values[i];
+      }
+    });
+    callback({
+      stats: analyzeMetadata(columnNames, pointsMetadata),
+      pointsInfo: pointsMetadata
+    });
+  }
+
+  getBookmarks(run: string, tensorName: string,
+      callback: (r: State[]) => void): void {
+    return callback([]);
+  }
+
+  private flatArrayToDataset(tensor: number[]): DataSet {
+    let points: DataPoint[] = [];
+    let n = this.dataProto.shape[0];
+    let d = this.dataProto.shape[1];
+    if (n * d !== tensor.length) {
+      throw new Error(
+          'The shape doesn\'t match the length of the flattened array');
+    }
+    for (let i = 0; i < n; i++) {
+      let offset = i * d;
+      points.push({
+        vector: new Float32Array(tensor.slice(offset, offset + d)),
+        metadata: {},
+        projections: null,
+        index: i
+      });
+    }
+    return new DataSet(points);
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts
new file mode 100644
index 00000000000..02720ebf6a7
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataSet, SpriteAndMetadataInfo, State} from './data'; +import * as dataProvider from './data-provider'; +import {DataProvider, EmbeddingInfo, ProjectorConfig} from './data-provider'; +import * as logging from './logging'; + +// Limit for the number of data points we receive from the server. +export const LIMIT_NUM_POINTS = 100000; + +/** + * Data provider that loads data provided by a python server (usually backed + * by a checkpoint file). + */ +export class ServerDataProvider implements DataProvider { + private routePrefix: string; + private runProjectorConfigCache: {[run: string]: ProjectorConfig} = {}; + + constructor(routePrefix: string) { + this.routePrefix = routePrefix; + } + + private getEmbeddingInfo(run: string, tensorName: string, + callback: (e: EmbeddingInfo) => void): void { + this.retrieveProjectorConfig(run, config => { + const embeddings = config.embeddings; + for (let i = 0; i < embeddings.length; i++) { + const embedding = embeddings[i]; + if (embedding.tensorName === tensorName) { + callback(embedding); + return; + } + } + callback(null); + }); + } + + retrieveRuns(callback: (runs: string[]) => void): void { + const msgId = logging.setModalMessage('Fetching runs...'); + + const xhr = new XMLHttpRequest(); + xhr.open('GET', `${this.routePrefix}/runs`); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching runs'); + }; + xhr.onload = () => { + const runs = JSON.parse(xhr.responseText); + logging.setModalMessage(null, msgId); + callback(runs); + }; + xhr.send(); + } + + retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) + : void { + if (run in this.runProjectorConfigCache) { + callback(this.runProjectorConfigCache[run]); + return; + } + + const msgId = logging.setModalMessage('Fetching projector config...'); + + const xhr = new XMLHttpRequest(); + xhr.open('GET', `${this.routePrefix}/info?run=${run}`); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching projector config'); + }; + xhr.onload = () => { + const config = JSON.parse(xhr.responseText) as ProjectorConfig; + logging.setModalMessage(null, msgId); + this.runProjectorConfigCache[run] = config; + callback(config); + }; + xhr.send(); + } + + retrieveTensor(run: string, tensorName: string, + callback: (ds: DataSet) => void) { + this.getEmbeddingInfo(run, tensorName, embedding => { + dataProvider.retrieveTensorAsBytes( + this, embedding, run, tensorName, + `${this.routePrefix}/tensor?run=${run}&name=${tensorName}` + + `&num_rows=${LIMIT_NUM_POINTS}`, + callback); + }); + } + + retrieveSpriteAndMetadata(run: string, tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void) { + this.getEmbeddingInfo(run, tensorName, embedding => { + let metadataPath = null; + if (embedding.metadataPath) { + metadataPath = + `${this.routePrefix}/metadata?` + + `run=${run}&name=${tensorName}&num_rows=${LIMIT_NUM_POINTS}`; + } + let spriteImagePath = null; + if (embedding.sprite && embedding.sprite.imagePath) { + spriteImagePath = + 
`${this.routePrefix}/sprite_image?run=${run}&name=${tensorName}`; + } + dataProvider.retrieveSpriteAndMetadataInfo(metadataPath, spriteImagePath, + embedding.sprite, callback); + }); + } + + getBookmarks( + run: string, tensorName: string, callback: (r: State[]) => void) { + const msgId = logging.setModalMessage('Fetching bookmarks...'); + + const xhr = new XMLHttpRequest(); + xhr.open( + 'GET', `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}`); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching bookmarks'); + }; + xhr.onload = () => { + logging.setModalMessage(null, msgId); + const bookmarks = JSON.parse(xhr.responseText); + callback(bookmarks); + }; + xhr.send(); + } +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts new file mode 100644 index 00000000000..c8eede798c6 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts @@ -0,0 +1,429 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ColumnStats, DataPoint, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data'; +import * as logging from './logging'; +import {runAsyncTask} from './util'; + +/** Maximum number of colors supported in the color map. */ +const NUM_COLORS_COLOR_MAP = 50; +const MAX_SPRITE_IMAGE_SIZE_PX = 8192; + +export const METADATA_MSG_ID = 'metadata'; +export const TENSORS_MSG_ID = 'tensors'; + +/** Matches the json format of `projector_config.proto` */ +export interface SpriteMetadata { + imagePath: string; + singleImageDim: [number, number]; +} + +/** Matches the json format of `projector_config.proto` */ +export interface EmbeddingInfo { + /** Name of the tensor. */ + tensorName: string; + /** The shape of the tensor. */ + tensorShape: [number, number]; + /** + * The path to the tensors TSV file. If empty, it is assumed that the tensor + * is stored in the checkpoint file. + */ + tensorPath?: string; + /** The path to the metadata file associated with the tensor. */ + metadataPath?: string; + /** The path to the bookmarks file associated with the tensor. */ + bookmarksPath?: string; + sprite?: SpriteMetadata; +} + +/** + * Matches the json format of `projector_config.proto` + * This should be kept in sync with the code in vz-projector-data-panel which + * holds a template for users to build a projector config JSON object from the + * projector UI. + */ +export interface ProjectorConfig { + embeddings: EmbeddingInfo[]; + modelCheckpointPath?: string; +} + +export type ServingMode = 'demo' | 'server' | 'proto'; + +/** Interface between the data storage and the UI. */ +export interface DataProvider { + /** Returns a list of run names that have embedding config files. 
*/ + retrieveRuns(callback: (runs: string[]) => void): void; + + /** + * Returns the projector configuration: number of tensors, their shapes, + * and their associated metadata files. + */ + retrieveProjectorConfig(run: string, + callback: (d: ProjectorConfig) => void): void; + + /** Fetches and returns the tensor with the specified name. */ + retrieveTensor(run: string, tensorName: string, + callback: (ds: DataSet) => void); + + /** + * Fetches the metadata for the specified tensor. + */ + retrieveSpriteAndMetadata(run: string, tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void): void; + + getBookmarks(run: string, tensorName: string, callback: (r: State[]) => void): + void; +} + +export function retrieveTensorAsBytes( + dp: DataProvider, embedding: EmbeddingInfo, run: string, tensorName: string, + tensorsPath: string, callback: (ds: DataSet) => void) { + // Get the tensor. + logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID); + let xhr = new XMLHttpRequest(); + xhr.open('GET', tensorsPath); + xhr.responseType = 'arraybuffer'; + xhr.onprogress = (ev) => { + if (ev.lengthComputable) { + let percent = (ev.loaded * 100 / ev.total).toFixed(1); + logging.setModalMessage( + 'Fetching tensor values: ' + percent + '%', TENSORS_MSG_ID); + } + }; + xhr.onload = () => { + if (xhr.status !== 200) { + let msg = String.fromCharCode.apply(null, new Uint8Array(xhr.response)); + logging.setErrorMessage(msg, 'fetching tensors'); + return; + } + let data: Float32Array; + try { + data = new Float32Array(xhr.response); + } catch (e) { + logging.setErrorMessage(e, 'parsing tensor bytes'); + return; + } + + let dim = embedding.tensorShape[1]; + let N = data.length / dim; + if (embedding.tensorShape[0] > N) { + logging.setWarningMessage( + `Showing the first ${N.toLocaleString()}` + + ` of ${embedding.tensorShape[0].toLocaleString()} data points`); + } + parseTensorsFromFloat32Array(data, dim).then(dataPoints => { + callback(new DataSet(dataPoints)); + }); + }; + xhr.send(); +} + +export function parseRawTensors( + content: ArrayBuffer, callback: (ds: DataSet) => void) { + parseTensors(content).then(data => { + callback(new DataSet(data)); + }); +} + +export function parseRawMetadata( + contents: ArrayBuffer, callback: (r: SpriteAndMetadataInfo) => void) { + parseMetadata(contents).then(result => callback(result)); +} + +/** + * Parse an ArrayBuffer in a streaming fashion line by line (or custom delim). + * Can handle very large files. + * + * @param content The array buffer. + * @param callback The callback called on each line. + * @param chunkSize The size of each read chunk, defaults to ~1MB. (optional) + * @param delim The delimiter used to split a line, defaults to '\n'. (optional) + * @returns A promise for when it is finished. 
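+ *
+ * A small usage sketch (`buffer` is a hypothetical ArrayBuffer):
+ *   const rows: string[] = [];
+ *   streamParse(buffer, line => rows.push(line))
+ *       .then(() => console.log(`Parsed ${rows.length} lines.`));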
+ */
+function streamParse(
+    content: ArrayBuffer, callback: (line: string) => void,
+    chunkSize = 1000000, delim = '\n'): Promise<void> {
+  return new Promise<void>((resolve, reject) => {
+    let offset = 0;
+    let bufferSize = content.byteLength - 1;
+    let data = '';
+
+    function readHandler(str: string) {
+      offset += chunkSize;
+      let parts = str.split(delim);
+      let first = data + parts[0];
+      if (parts.length === 1) {
+        data = first;
+        readChunk(offset, chunkSize);
+        return;
+      }
+      data = parts[parts.length - 1];
+      callback(first);
+      for (let i = 1; i < parts.length - 1; i++) {
+        callback(parts[i]);
+      }
+      if (offset >= bufferSize) {
+        if (data) {
+          callback(data);
+        }
+        resolve();
+        return;
+      }
+      readChunk(offset, chunkSize);
+    }
+
+    function readChunk(offset: number, size: number) {
+      const contentChunk = content.slice(offset, offset + size);
+
+      const blob = new Blob([contentChunk]);
+      const file = new FileReader();
+      file.onload = (e: any) => readHandler(e.target.result);
+      file.readAsText(blob);
+    }
+
+    readChunk(offset, chunkSize);
+  });
+}
+
+/** Parses a tsv text file. */
+export function parseTensors(
+    content: ArrayBuffer, valueDelim = '\t'): Promise<DataPoint[]> {
+  logging.setModalMessage('Parsing tensors...', TENSORS_MSG_ID);
+
+  return new Promise<DataPoint[]>((resolve, reject) => {
+    const data: DataPoint[] = [];
+    let numDim: number;
+
+    streamParse(content, (line: string) => {
+      line = line.trim();
+      if (line === '') {
+        return;
+      }
+      const row = line.split(valueDelim);
+      const dataPoint: DataPoint = {
+        metadata: {},
+        vector: null,
+        index: data.length,
+        projections: null,
+      };
+      // If the first column is not a number, take it as the point's label.
+      if (isNaN(row[0] as any) || numDim === row.length - 1) {
+        dataPoint.metadata['label'] = row[0];
+        dataPoint.vector = new Float32Array(row.slice(1).map(Number));
+      } else {
+        dataPoint.vector = new Float32Array(row.map(Number));
+      }
+      data.push(dataPoint);
+      if (numDim == null) {
+        numDim = dataPoint.vector.length;
+      }
+      if (numDim !== dataPoint.vector.length) {
+        logging.setModalMessage(
+            'Parsing failed. Vector dimensions do not match');
+        throw Error('Parsing failed');
+      }
+      if (numDim <= 1) {
+        logging.setModalMessage(
+            'Parsing failed. Found a vector with only one dimension?');
+        throw Error('Parsing failed');
+      }
+    }).then(() => {
+      logging.setModalMessage(null, TENSORS_MSG_ID);
+      resolve(data);
+    });
+  });
+}
+
+/** Converts a flat Float32Array of concatenated vectors into data points.
*/ +export function parseTensorsFromFloat32Array(data: Float32Array, + dim: number): Promise { + return runAsyncTask('Parsing tensors...', () => { + const N = data.length / dim; + const dataPoints: DataPoint[] = []; + let offset = 0; + for (let i = 0; i < N; ++i) { + dataPoints.push({ + metadata: {}, + vector: data.subarray(offset, offset + dim), + index: i, + projections: null, + }); + offset += dim; + } + return dataPoints; + }, TENSORS_MSG_ID).then(dataPoints => { + logging.setModalMessage(null, TENSORS_MSG_ID); + return dataPoints; + }); +} + +export function analyzeMetadata( + columnNames, pointsMetadata: PointMetadata[]): ColumnStats[] { + const columnStats: ColumnStats[] = columnNames.map(name => { + return { + name: name, + isNumeric: true, + tooManyUniqueValues: false, + min: Number.POSITIVE_INFINITY, + max: Number.NEGATIVE_INFINITY + }; + }); + + const mapOfValues: [{[value: string]: number}] = + columnNames.map(() => new Object()); + + pointsMetadata.forEach(metadata => { + columnNames.forEach((name: string, colIndex: number) => { + const stats = columnStats[colIndex]; + const map = mapOfValues[colIndex]; + const value = metadata[name]; + + // Skip missing values. + if (value == null) { + return; + } + + if (!stats.tooManyUniqueValues) { + if (value in map) { + map[value]++; + } else { + map[value] = 1; + } + if (Object.keys(map).length > NUM_COLORS_COLOR_MAP) { + stats.tooManyUniqueValues = true; + } + } + if (isNaN(value as any)) { + stats.isNumeric = false; + } else { + metadata[name] = +value; + stats.min = Math.min(stats.min, +value); + stats.max = Math.max(stats.max, +value); + } + }); + }); + columnStats.forEach((stats, colIndex) => { + stats.uniqueEntries = Object.keys(mapOfValues[colIndex]).map(label => { + return {label, count: mapOfValues[colIndex][label]}; + }); + }); + return columnStats; +} + +export function parseMetadata(content: ArrayBuffer): + Promise { + logging.setModalMessage('Parsing metadata...', METADATA_MSG_ID); + + return new Promise((resolve, reject) => { + let pointsMetadata: PointMetadata[] = []; + let hasHeader = false; + let lineNumber = 0; + let columnNames = ['label']; + streamParse(content, (line: string) => { + if (line.trim().length === 0) { + return; + } + if (lineNumber === 0) { + hasHeader = line.indexOf('\t') >= 0; + + // If the first row doesn't contain metadata keys, we assume that the + // values are labels. + if (hasHeader) { + columnNames = line.split('\t'); + lineNumber++; + return; + } + } + + lineNumber++; + + let rowValues = line.split('\t'); + let metadata: PointMetadata = {}; + pointsMetadata.push(metadata); + columnNames.forEach((name: string, colIndex: number) => { + let value = rowValues[colIndex]; + // Normalize missing values. + value = (value === '' ? 
null : value); + metadata[name] = value; + }); + }).then(() => { + logging.setModalMessage(null, METADATA_MSG_ID); + resolve({ + stats: analyzeMetadata(columnNames, pointsMetadata), + pointsInfo: pointsMetadata + }); + }); + }); +} + +export function fetchImage(url: string): Promise { + return new Promise((resolve, reject) => { + let image = new Image(); + image.onload = () => resolve(image); + image.onerror = (err) => reject(err); + image.crossOrigin = ''; + image.src = url; + }); +} + +export function retrieveSpriteAndMetadataInfo(metadataPath: string, + spriteImagePath: string, spriteMetadata: SpriteMetadata, + callback: (r: SpriteAndMetadataInfo) => void) { + let metadataPromise: Promise = Promise.resolve({}); + if (metadataPath) { + metadataPromise = new Promise((resolve, reject) => { + logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); + + const request = new XMLHttpRequest(); + request.open('GET', metadataPath); + request.responseType = 'arraybuffer'; + + request.onerror = () => { + logging.setErrorMessage(request.responseText, 'fetching metadata'); + reject(); + }; + request.onload = () => { + resolve(parseMetadata(request.response)); + }; + request.send(null); + }); + } + let spriteMsgId = null; + let spritesPromise: Promise = null; + if (spriteImagePath) { + spriteMsgId = logging.setModalMessage('Fetching sprite image...'); + spritesPromise = fetchImage(spriteImagePath); + } + + // Fetch the metadata and the image in parallel. + Promise.all([metadataPromise, spritesPromise]).then(values => { + if (spriteMsgId) { + logging.setModalMessage(null, spriteMsgId); + } + const [metadata, spriteImage] = values; + + if (spriteImage && (spriteImage.height > MAX_SPRITE_IMAGE_SIZE_PX || + spriteImage.width > MAX_SPRITE_IMAGE_SIZE_PX)) { + logging.setModalMessage( + `Error: Sprite image of dimensions ${spriteImage.width}px x ` + + `${spriteImage.height}px exceeds maximum dimensions ` + + `${MAX_SPRITE_IMAGE_SIZE_PX}px x ${MAX_SPRITE_IMAGE_SIZE_PX}px`); + } else { + metadata.spriteImage = spriteImage; + metadata.spriteMetadata = spriteMetadata; + callback(metadata); + } + }); +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts new file mode 100644 index 00000000000..01b89ca7001 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts @@ -0,0 +1,96 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataPoint, SpriteAndMetadataInfo} from './data'; +import * as data_provider from './data-provider'; + +/** + * Converts a string to an ArrayBuffer. 
+ */
+function stringToArrayBuffer(str: string): Promise<ArrayBuffer> {
+  return new Promise<ArrayBuffer>((resolve, reject) => {
+    let blob = new Blob([str]);
+    let file = new FileReader();
+    file.onload = (e: any) => {
+      resolve(e.target.result);
+    };
+    file.readAsArrayBuffer(blob);
+  });
+}
+
+/**
+ * Converts a data array to TSV format.
+ */
+function dataToTsv(data: string[][]|number[][]) {
+  let lines = [];
+  for (let i = 0; i < data.length; i++) {
+    lines.push(data[i].join('\t'));
+  }
+  return lines.join('\n');
+}
+
+describe('parse tensors', () => {
+  it('parseTensors', (doneFn) => {
+    let tensors = [[1.0, 2.0], [2.0, 3.0]];
+    stringToArrayBuffer(dataToTsv(tensors))
+        .then((tensorsArrayBuffer: ArrayBuffer) => {
+          data_provider.parseTensors(tensorsArrayBuffer)
+              .then((data: DataPoint[]) => {
+                expect(data.length).toBe(2);
+
+                expect(data[0].vector).toEqual(new Float32Array(tensors[0]));
+                expect(data[0].index).toEqual(0);
+                expect(data[0].projections).toBeNull();
+
+                expect(data[1].vector).toEqual(new Float32Array(tensors[1]));
+                expect(data[1].index).toEqual(1);
+                expect(data[1].projections).toBeNull();
+                doneFn();
+              });
+        });
+  });
+  it('parseMetadata', (doneFn) => {
+    let metadata = [['label', 'fakecol'], ['Г', '0'], ['label1', '1']];
+
+    stringToArrayBuffer(dataToTsv(metadata))
+        .then((metadataArrayBuffer: ArrayBuffer) => {
+          data_provider.parseMetadata(metadataArrayBuffer)
+              .then((spriteAndMetadataInfo: SpriteAndMetadataInfo) => {
+                expect(spriteAndMetadataInfo.stats.length).toBe(2);
+                expect(spriteAndMetadataInfo.stats[0].name)
+                    .toBe(metadata[0][0]);
+                expect(spriteAndMetadataInfo.stats[0].isNumeric).toBe(false);
+                expect(spriteAndMetadataInfo.stats[0].tooManyUniqueValues)
+                    .toBe(false);
+                expect(spriteAndMetadataInfo.stats[1].name)
+                    .toBe(metadata[0][1]);
+                expect(spriteAndMetadataInfo.stats[1].isNumeric).toBe(true);
+                expect(spriteAndMetadataInfo.stats[1].tooManyUniqueValues)
+                    .toBe(false);
+
+                expect(spriteAndMetadataInfo.pointsInfo.length).toBe(2);
+                expect(spriteAndMetadataInfo.pointsInfo[0]['label'])
+                    .toBe(metadata[1][0]);
+                expect(spriteAndMetadataInfo.pointsInfo[0]['fakecol'])
+                    .toBe(+metadata[1][1]);
+                expect(spriteAndMetadataInfo.pointsInfo[1]['label'])
+                    .toBe(metadata[2][0]);
+                expect(spriteAndMetadataInfo.pointsInfo[1]['fakecol'])
+                    .toBe(+metadata[2][1]);
+                doneFn();
+              });
+        });
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts
new file mode 100644
index 00000000000..c4e81985fc8
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts
@@ -0,0 +1,547 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +import {TSNE} from './bh_tsne'; +import {SpriteMetadata} from './data-provider'; +import * as knn from './knn'; +import * as logging from './logging'; +import * as scatterPlot from './scatterPlot'; +import * as util from './util'; +import * as vector from './vector'; + +export type DistanceFunction = (a: number[], b: number[]) => number; +export type ProjectionComponents3D = [string, string, string]; + +export interface PointMetadata { [key: string]: number|string; } + +export interface DataProto { + shape: [number, number]; + tensor: number[]; + metadata: { + columns: Array< + {name: string; stringValues: string[]; numericValues: number[];}>; + }; +} + +/** Statistics for a metadata column. */ +export interface ColumnStats { + name: string; + isNumeric: boolean; + tooManyUniqueValues: boolean; + uniqueEntries?: Array<{label: string, count: number}>; + min: number; + max: number; +} + +export interface SpriteAndMetadataInfo { + stats?: ColumnStats[]; + pointsInfo?: PointMetadata[]; + spriteImage?: HTMLImageElement; + spriteMetadata?: SpriteMetadata; +} + +/** A single collection of points which make up a sequence through space. */ +export interface Sequence { + /** Indices into the DataPoints array in the Data object. */ + pointIndices: number[]; +} + +export interface DataPoint { + /** The point in the original space. */ + vector: Float32Array; + + /* + * Metadata for each point. Each metadata is a set of key/value pairs + * where the value can be a string or a number. + */ + metadata: PointMetadata; + + /** index of the sequence, used for highlighting on click */ + sequenceIndex?: number; + + /** index in the original data source */ + index: number; + + /** This is where the calculated projections space are cached */ + projections: {[key: string]: number}; +} + +const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0; +/** Controls whether nearest neighbors computation is done on the GPU or CPU. */ +const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX; + +export const TSNE_SAMPLE_SIZE = 10000; +export const PCA_SAMPLE_SIZE = 50000; +/** Number of dimensions to sample when doing approximate PCA. */ +export const PCA_SAMPLE_DIM = 200; +/** Number of pca components to compute. */ +const NUM_PCA_COMPONENTS = 10; +/** + * Reserved metadata attributes used for sequence information + * NOTE: Use "__seq_next__" as "__next__" is deprecated. + */ +const SEQUENCE_METADATA_ATTRS = ['__next__', '__seq_next__']; + +function getSequenceNextPointIndex(pointMetadata: PointMetadata): number|null { + let sequenceAttr = null; + for (let metadataAttr of SEQUENCE_METADATA_ATTRS) { + if (metadataAttr in pointMetadata && pointMetadata[metadataAttr] !== '') { + sequenceAttr = pointMetadata[metadataAttr]; + break; + } + } + if (sequenceAttr == null) { + return null; + } + return +sequenceAttr; +} + +/** + * Dataset contains a DataPoints array that should be treated as immutable. This + * acts as a working subset of the original data, with cached properties + * from computationally expensive operations. Because creating a subset + * requires normalizing and shifting the vector space, we make a copy of the + * data so we can still always create new subsets based on the original data. 
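+ *
+ * For example (with a hypothetical `points` array):
+ *   const ds = new DataSet(points);
+ *   const firstTen = ds.getSubset(util.range(10));  // copies points 0..9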
+ */ +export class DataSet { + points: DataPoint[]; + sequences: Sequence[]; + + shuffledDataIndices: number[] = []; + + /** + * This keeps a list of all current projections so you can easily test to see + * if it's been calculated already. + */ + projections: {[projection: string]: boolean} = {}; + nearest: knn.NearestEntry[][]; + nearestK: number; + tSNEIteration: number = 0; + tSNEShouldStop = true; + dim: [number, number] = [0, 0]; + hasTSNERun: boolean = false; + spriteAndMetadataInfo: SpriteAndMetadataInfo; + fracVariancesExplained: number[]; + + private tsne: TSNE; + + /** Creates a new Dataset */ + constructor( + points: DataPoint[], spriteAndMetadataInfo?: SpriteAndMetadataInfo) { + this.points = points; + this.shuffledDataIndices = util.shuffle(util.range(this.points.length)); + this.sequences = this.computeSequences(points); + this.dim = [this.points.length, this.points[0].vector.length]; + this.spriteAndMetadataInfo = spriteAndMetadataInfo; + } + + private computeSequences(points: DataPoint[]) { + // Keep a list of indices seen so we don't compute sequences for a given + // point twice. + let indicesSeen = new Int8Array(points.length); + // Compute sequences. + let indexToSequence: {[index: number]: Sequence} = {}; + let sequences: Sequence[] = []; + for (let i = 0; i < points.length; i++) { + if (indicesSeen[i]) { + continue; + } + indicesSeen[i] = 1; + + // Ignore points without a sequence attribute. + let next = getSequenceNextPointIndex(points[i].metadata); + if (next == null) { + continue; + } + if (next in indexToSequence) { + let existingSequence = indexToSequence[next]; + // Pushing at the beginning of the array. + existingSequence.pointIndices.unshift(i); + indexToSequence[i] = existingSequence; + continue; + } + // The current point is pointing to a new/unseen sequence. + let newSequence: Sequence = {pointIndices: []}; + indexToSequence[i] = newSequence; + sequences.push(newSequence); + let currentIndex = i; + while (points[currentIndex]) { + newSequence.pointIndices.push(currentIndex); + let next = getSequenceNextPointIndex(points[currentIndex].metadata); + if (next != null) { + indicesSeen[next] = 1; + currentIndex = next; + } else { + currentIndex = -1; + } + } + } + return sequences; + } + + projectionCanBeRendered(projection: ProjectionType): boolean { + if (projection !== 'tsne') { + return true; + } + return this.tSNEIteration > 0; + } + + /** + * Returns a new subset dataset by copying out data. We make a copy because + * we have to modify the vectors by normalizing them. + * + * @param subset Array of indices of points that we want in the subset. + * + * @return A subset of the original dataset. + */ + getSubset(subset?: number[]): DataSet { + const pointsSubset = ((subset != null) && (subset.length > 0)) ? + subset.map(i => this.points[i]) : + this.points; + let points = pointsSubset.map(dp => { + return { + metadata: dp.metadata, + index: dp.index, + vector: dp.vector.slice(), + projections: {} as {[key: string]: number} + }; + }); + return new DataSet(points, this.spriteAndMetadataInfo); + } + + /** + * Computes the centroid, shifts all points to that centroid, + * then makes them all unit norm. + */ + normalize() { + // Compute the centroid of all data points. + let centroid = vector.centroid(this.points, a => a.vector); + if (centroid == null) { + throw Error('centroid should not be null'); + } + // Shift all points by the centroid and make them unit norm. 
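+    // That is, each vector v becomes unit(v - centroid), which keeps the
+    // cosine-distance computations used elsewhere (e.g. vector.cosDistNorm in
+    // the kNN step) well defined.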
+ for (let id = 0; id < this.points.length; ++id) { + let dataPoint = this.points[id]; + dataPoint.vector = vector.sub(dataPoint.vector, centroid); + vector.unit(dataPoint.vector); + } + } + + /** Projects the dataset onto a given vector and caches the result. */ + projectLinear(dir: vector.Vector, label: string) { + this.projections[label] = true; + this.points.forEach(dataPoint => { + dataPoint.projections[label] = vector.dot(dataPoint.vector, dir); + }); + } + + /** Projects the dataset along the top 10 principal components. */ + projectPCA(): Promise { + if (this.projections['pca-0'] != null) { + return Promise.resolve(null); + } + return util.runAsyncTask('Computing PCA...', () => { + // Approximate pca vectors by sampling the dimensions. + let dim = this.points[0].vector.length; + let vectors = this.shuffledDataIndices.map(i => this.points[i].vector); + if (dim > PCA_SAMPLE_DIM) { + vectors = vector.projectRandom(vectors, PCA_SAMPLE_DIM); + } + let sampledVectors = vectors.slice(0, PCA_SAMPLE_SIZE); + + let sigma = numeric.div( + numeric.dot(numeric.transpose(sampledVectors), sampledVectors), + sampledVectors.length); + let svd = numeric.svd(sigma); + + let variances: number[] = svd.S; + let totalVariance = 0; + for (let i = 0; i < variances.length; ++i) { + totalVariance += variances[i]; + } + for (let i = 0; i < variances.length; ++i) { + variances[i] /= totalVariance; + } + this.fracVariancesExplained = variances; + + let U: number[][] = svd.U; + let pcaVectors = vectors.map(vector => { + let newV = new Float32Array(NUM_PCA_COMPONENTS); + for (let newDim = 0; newDim < NUM_PCA_COMPONENTS; newDim++) { + let dot = 0; + for (let oldDim = 0; oldDim < vector.length; oldDim++) { + dot += vector[oldDim] * U[oldDim][newDim]; + } + newV[newDim] = dot; + } + return newV; + }); + for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { + let label = 'pca-' + d; + this.projections[label] = true; + for (let i = 0; i < pcaVectors.length; i++) { + let pointIndex = this.shuffledDataIndices[i]; + this.points[pointIndex].projections[label] = pcaVectors[i][d]; + } + } + }); + } + + /** Runs tsne on the data. */ + projectTSNE( + perplexity: number, learningRate: number, tsneDim: number, + stepCallback: (iter: number) => void) { + this.hasTSNERun = true; + let k = Math.floor(3 * perplexity); + let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim}; + this.tsne = new TSNE(opt); + this.tSNEShouldStop = false; + this.tSNEIteration = 0; + + let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE); + let step = () => { + if (this.tSNEShouldStop) { + stepCallback(null); + this.tsne = null; + return; + } + this.tsne.step(); + let result = this.tsne.getSolution(); + sampledIndices.forEach((index, i) => { + let dataPoint = this.points[index]; + + dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; + dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; + if (tsneDim === 3) { + dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; + } + }); + this.tSNEIteration++; + stepCallback(this.tSNEIteration); + requestAnimationFrame(step); + }; + + // Nearest neighbors calculations. + let knnComputation: Promise; + + if (this.nearest != null && k === this.nearestK) { + // We found the nearest neighbors before and will reuse them. + knnComputation = Promise.resolve(this.nearest); + } else { + let sampledData = sampledIndices.map(i => this.points[i]); + this.nearestK = k; + knnComputation = KNN_GPU_ENABLED ? 
+          knn.findKNNGPUCosine(sampledData, k, (d => d.vector)) :
+          knn.findKNN(
+              sampledData, k, (d => d.vector),
+              (a, b, limit) => vector.cosDistNorm(a, b));
+    }
+    knnComputation.then(nearest => {
+      this.nearest = nearest;
+      util.runAsyncTask('Initializing T-SNE...', () => {
+        this.tsne.initDataDist(this.nearest);
+      }).then(step);
+    });
+  }
+
+  /**
+   * Merges metadata into the dataset and returns whether it succeeded.
+   */
+  mergeMetadata(metadata: SpriteAndMetadataInfo): boolean {
+    if (metadata.pointsInfo.length !== this.points.length) {
+      let errorMessage = `The number of tensors (${this.points.length}) does` +
+          ` not match the number of lines in metadata` +
+          ` (${metadata.pointsInfo.length}).`;
+
+      if (metadata.stats.length === 1 &&
+          this.points.length + 1 === metadata.pointsInfo.length) {
+        // If there is only one column of metadata and the number of points is
+        // exactly one less than the number of metadata lines, this is due to
+        // an unnecessary header line in the metadata and we can show a
+        // meaningful error.
+        logging.setErrorMessage(
+            errorMessage + ' Single column metadata should not have a header ' +
+                'row.',
+            'merging metadata');
+        return false;
+      } else if (
+          metadata.stats.length > 1 &&
+          this.points.length - 1 === metadata.pointsInfo.length) {
+        // If there are multiple columns of metadata and the number of points
+        // is exactly one greater than the number of lines in the metadata,
+        // this means there is a missing metadata header.
+        logging.setErrorMessage(
+            errorMessage + ' Multi-column metadata should have a header ' +
+                'row with column labels.',
+            'merging metadata');
+        return false;
+      }
+
+      logging.setWarningMessage(errorMessage);
+    }
+    this.spriteAndMetadataInfo = metadata;
+    metadata.pointsInfo.slice(0, this.points.length)
+        .forEach((m, i) => this.points[i].metadata = m);
+    return true;
+  }
+
+  stopTSNE() {
+    this.tSNEShouldStop = true;
+  }
+
+  /**
+   * Finds the nearest neighbors of the query point using a
+   * user-specified distance metric.
+   */
+  findNeighbors(pointIndex: number, distFunc: DistanceFunction, numNN: number):
+      knn.NearestEntry[] {
+    // Find the nearest neighbors of a particular point.
+    let neighbors = knn.findKNNofPoint(
+        this.points, pointIndex, numNN, (d => d.vector), distFunc);
+    // TODO(smilkov): Figure out why we slice.
+    let result = neighbors.slice(0, numNN);
+    return result;
+  }
+
+  /**
+   * Searches the dataset based on a metadata field.
+   */
+  query(query: string, inRegexMode: boolean, fieldName: string): number[] {
+    let predicate = util.getSearchPredicate(query, inRegexMode, fieldName);
+    let matches: number[] = [];
+    this.points.forEach((point, id) => {
+      if (predicate(point)) {
+        matches.push(id);
+      }
+    });
+    return matches;
+  }
+}
+
+export type ProjectionType = 'tsne' | 'pca' | 'custom';
+
+export class Projection {
+  constructor(
+      public projectionType: ProjectionType,
+      public projectionComponents: ProjectionComponents3D,
+      public dimensionality: number, public dataSet: DataSet) {}
+}
+
+export interface ColorOption {
+  name: string;
+  desc?: string;
+  map?: (value: string|number) => string;
+  /** List of items for the color map. Defined only for categorical map. */
+  items?: {label: string, count: number}[];
+  /** Threshold values and their colors. Defined for gradient color map. */
+  thresholds?: {value: number, color: string}[];
+  isSeparator?: boolean;
+  tooManyUniqueValues?: boolean;
+}
+
+/**
+ * An interface that holds all the data for serializing the current state of
+ * the world.
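+ * States are what the bookmarks pane saves and restores: data providers
+ * return them from getBookmarks(), parsed straight from JSON.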
+ */ +export class State { + /** A label identifying this state. */ + label: string = ''; + + /** Whether this State is selected in the bookmarks pane. */ + isSelected: boolean = false; + + /** The selected projection tab. */ + selectedProjection: ProjectionType; + + /** Dimensions of the DataSet. */ + dataSetDimensions: [number, number]; + + /** t-SNE parameters */ + tSNEIteration: number = 0; + tSNEPerplexity: number = 0; + tSNELearningRate: number = 0; + tSNEis3d: boolean = true; + + /** PCA projection component dimensions */ + pcaComponentDimensions: number[] = []; + + /** Custom projection parameters */ + customSelectedSearchByMetadataOption: string; + customXLeftText: string; + customXLeftRegex: boolean; + customXRightText: string; + customXRightRegex: boolean; + customYUpText: string; + customYUpRegex: boolean; + customYDownText: string; + customYDownRegex: boolean; + + /** The computed projections of the tensors. */ + projections: Array<{[key: string]: number}> = []; + + /** Filtered dataset indices. */ + filteredPoints: number[]; + + /** The indices of selected points. */ + selectedPoints: number[] = []; + + /** Camera state (2d/3d, position, target, zoom, etc). */ + cameraDef: scatterPlot.CameraDef; + + /** Color by option. */ + selectedColorOptionName: string; + forceCategoricalColoring: boolean; + + /** Label by option. */ + selectedLabelOption: string; +} + +export function getProjectionComponents( + projection: ProjectionType, + components: (number|string)[]): ProjectionComponents3D { + if (components.length > 3) { + throw new RangeError('components length must be <= 3'); + } + const projectionComponents: [string, string, string] = [null, null, null]; + const prefix = (projection === 'custom') ? 'linear' : projection; + for (let i = 0; i < components.length; ++i) { + if (components[i] == null) { + continue; + } + projectionComponents[i] = prefix + '-' + components[i]; + } + return projectionComponents; +} + +export function stateGetAccessorDimensions(state: State): Array { + let dimensions: Array; + switch (state.selectedProjection) { + case 'pca': + dimensions = state.pcaComponentDimensions.slice(); + break; + case 'tsne': + dimensions = [0, 1]; + if (state.tSNEis3d) { + dimensions.push(2); + } + break; + case 'custom': + dimensions = ['x', 'y']; + break; + default: + throw new Error('Unexpected fallthrough'); + } + return dimensions; +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts new file mode 100644 index 00000000000..924ae3a929f --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts @@ -0,0 +1,104 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataPoint, DataSet, State, stateGetAccessorDimensions} from './data'; + +/** + * Helper method that makes a list of points given an array of + * sequence indexes. 
+ * + * @param sequences The i-th entry holds the 'next' attribute for the i-th + * point. + */ +function makePointsWithSequences( + sequences: number[], nextAttr = '__seq_next__') { + let points: DataPoint[] = []; + sequences.forEach((t, i) => { + let metadata: {[key: string]: any} = {}; + metadata[nextAttr] = t >= 0 ? t : null; + points.push({ + vector: new Float32Array(0), + metadata: metadata, + projections: {}, + index: i + }); + }); + return points; +} + +describe('constructor_with_sequences', () => { + it('Simple forward pointing sequences, __seq_next__ metadata format', () => { + // The input is: 0->2, 1->None, 2->3, 3->None. This should return + // one sequence 0->2->3. + const points = makePointsWithSequences([2, -1, 3, -1]); + let dataset = new DataSet(points); + expect(dataset.sequences.length).toEqual(1); + expect(dataset.sequences[0].pointIndices).toEqual([0, 2, 3]); + }); + + it('Simple forward pointing sequences, __next__ metadata format', () => { + // The input is: 0->2, 1->None, 2->3, 3->None. This should return + // one sequence 0->2->3. + const points = makePointsWithSequences([2, -1, 3, -1], '__next__'); + let dataset = new DataSet(points); + expect(dataset.sequences.length).toEqual(1); + expect(dataset.sequences[0].pointIndices).toEqual([0, 2, 3]); + }); + + it('No sequences', () => { + let points = makePointsWithSequences([-1, -1, -1, -1]); + let dataset = new DataSet(points); + expect(dataset.sequences.length).toEqual(0); + }); + + it('A sequence that goes backwards and forward in the array', () => { + // The input is: 0->2, 1->0, 2->nothing, 3->1. This should return + // one sequence 3->1->0->2. + let points = makePointsWithSequences([2, 0, -1, 1]); + let dataset = new DataSet(points); + expect(dataset.sequences.length).toEqual(1); + expect(dataset.sequences[0].pointIndices).toEqual([3, 1, 0, 2]); + }); +}); + +describe('stateGetAccessorDimensions', () => { + it('returns [0, 1] for 2d t-SNE', () => { + const state = new State(); + state.selectedProjection = 'tsne'; + state.tSNEis3d = false; + expect(stateGetAccessorDimensions(state)).toEqual([0, 1]); + }); + + it('returns [0, 1, 2] for 3d t-SNE', () => { + const state = new State(); + state.selectedProjection = 'tsne'; + state.tSNEis3d = true; + expect(stateGetAccessorDimensions(state)).toEqual([0, 1, 2]); + }); + + it('returns pca component dimensions array for pca', () => { + const state = new State(); + state.selectedProjection = 'pca'; + state.pcaComponentDimensions = [13, 12, 11, 10]; + expect(stateGetAccessorDimensions(state)) + .toEqual(state.pcaComponentDimensions); + }); + + it('returns ["x", "y"] for custom projections', () => { + const state = new State(); + state.selectedProjection = 'custom'; + expect(stateGetAccessorDimensions(state)).toEqual(['x', 'y']); + }); +}); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts new file mode 100644 index 00000000000..cbc1512c215 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(smilkov): Split into weblas.d.ts and numeric.d.ts and write
+// typings for numeric.
+interface Tensor {
+  new(size: [number, number], data: Float32Array);
+  transfer(): Float32Array;
+  delete(): void;
+}
+
+interface Weblas {
+  sgemm(M: number, N: number, K: number, alpha: number,
+      A: Float32Array, B: Float32Array, beta: number, C: Float32Array):
+      Float32Array;
+  pipeline: {
+    Tensor: Tensor;
+    sgemm(alpha: number, A: Tensor, B: Tensor, beta: number,
+        C: Tensor): Tensor;
+  };
+  util: {
+    transpose(M: number, N: number, data: Float32Array): Tensor;
+  };
+}
+
+declare let numeric: any;
+declare let weblas: Weblas;
+
+interface AnalyticsEventType {
+  hitType: string;
+  page?: string;
+  eventCategory?: string;
+  eventAction?: string;
+  eventLabel?: string;
+  eventValue?: number;
+}
+
+declare let ga: (command: string, eventObj: AnalyticsEventType) => void;
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts
new file mode 100644
index 00000000000..ac3144e6493
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts
@@ -0,0 +1,146 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** Min key heap. */
+export type HeapItem<T> = {
+  key: number,
+  value: T
+};
+
+/**
+ * Min-heap data structure. Provides O(1) for peek, returning the smallest key.
+ */
+// TODO(jart): Rename to Heap and use Comparator.
+export class MinHeap<T> {
+  private arr: HeapItem<T>[] = [];
+
+  /** Push an element with the provided key. */
+  push(key: number, value: T): void {
+    this.arr.push({key, value});
+    this.bubbleUp(this.arr.length - 1);
+  }
+
+  /** Pop the element with the smallest key. */
+  pop(): HeapItem<T> {
+    if (this.arr.length === 0) {
+      throw new Error('pop() called on empty binary heap');
+    }
+    let item = this.arr[0];
+    let last = this.arr.length - 1;
+    this.arr[0] = this.arr[last];
+    this.arr.pop();
+    if (last > 0) {
+      this.bubbleDown(0);
+    }
+    return item;
+  }
+
+  /** Returns, but doesn't remove, the element with the smallest key. */
+  peek(): HeapItem<T> { return this.arr[0]; }
+
+  /**
+   * Pops the element with the smallest key and at the same time
+   * adds the newly provided element. This is faster than calling
+   * pop() and push() separately.
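+   *
+   * For example, if the heap holds the keys [1, 5, 7], then popPush(3, v)
+   * returns the entry with key 1 and leaves the heap holding [3, 5, 7].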
+   */
+  popPush(key: number, value: T): HeapItem<T> {
+    if (this.arr.length === 0) {
+      throw new Error('popPush() called on empty binary heap');
+    }
+    let item = this.arr[0];
+    this.arr[0] = {key, value};
+    if (this.arr.length > 0) {
+      this.bubbleDown(0);
+    }
+    return item;
+  }
+
+  /** Returns the number of elements in the heap. */
+  size(): number { return this.arr.length; }
+
+  /** Returns all the items in the heap. */
+  items(): HeapItem<T>[] { return this.arr; }
+
+  private swap(a: number, b: number) {
+    let temp = this.arr[a];
+    this.arr[a] = this.arr[b];
+    this.arr[b] = temp;
+  }
+
+  private bubbleDown(pos: number) {
+    let left = (pos << 1) + 1;
+    let right = left + 1;
+    let smallest = pos;
+    if (left < this.arr.length &&
+        this.arr[left].key < this.arr[smallest].key) {
+      smallest = left;
+    }
+    if (right < this.arr.length &&
+        this.arr[right].key < this.arr[smallest].key) {
+      smallest = right;
+    }
+    if (smallest !== pos) {
+      this.swap(smallest, pos);
+      this.bubbleDown(smallest);
+    }
+  }
+
+  private bubbleUp(pos: number) {
+    if (pos <= 0) {
+      return;
+    }
+    let parent = ((pos - 1) >> 1);
+    if (this.arr[pos].key < this.arr[parent].key) {
+      this.swap(pos, parent);
+      this.bubbleUp(parent);
+    }
+  }
+}
+
+/** List that keeps the K elements with the smallest keys. */
+export class KMin<T> {
+  private k: number;
+  private maxHeap = new MinHeap<T>();
+
+  /** Constructs a new k-min data structure with the provided k. */
+  constructor(k: number) { this.k = k; }
+
+  /** Adds an element to the list. */
+  add(key: number, value: T) {
+    if (this.maxHeap.size() < this.k) {
+      this.maxHeap.push(-key, value);
+      return;
+    }
+    let largest = this.maxHeap.peek();
+    // If the new element is smaller, replace the largest with the new element.
+    if (key < -largest.key) {
+      this.maxHeap.popPush(-key, value);
+    }
+  }
+
+  /** Returns the k items with the smallest keys. */
+  getMinKItems(): T[] {
+    let items = this.maxHeap.items();
+    items.sort((a, b) => b.key - a.key);
+    return items.map(a => a.value);
+  }
+
+  /** Returns the size of the list. */
+  getSize(): number { return this.maxHeap.size(); }
+
+  /** Returns the largest key in the list. */
+  getLargestKey(): number {
+    return this.maxHeap.size() === 0 ? null : -this.maxHeap.peek().key;
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts
new file mode 100644
index 00000000000..906e077b5d7
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts
@@ -0,0 +1,235 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {runAsyncTask} from './util';
+import * as logging from './logging';
+import {KMin} from './heap';
+import {Vector} from './vector';
+import * as vector from './vector';
+
+export type NearestEntry = {
+  index: number,
+  dist: number
+};
+
+/**
+ * Optimal size for the height of the matrix when doing computation on the GPU
+ * using WebGL.
+ * This was found experimentally.
+ *
+ * This also guarantees that for computing pair-wise distance for up to 10K
+ * vectors, no more than 40MB will be allocated in the GPU. Without the
+ * allocation limit, we can freeze the graphics of the whole OS.
+ */
+const OPTIMAL_GPU_BLOCK_SIZE = 256;
+/** Id of message box used for knn gpu progress bar. */
+const KNN_GPU_MSG_ID = 'knn-gpu';
+
+/**
+ * Returns the K nearest neighbors for each vector where the distance
+ * computation is done on the GPU (WebGL) using cosine distance.
+ *
+ * @param dataPoints List of data points, where each data point holds an
+ *   n-dimensional vector.
+ * @param k Number of nearest neighbors to find.
+ * @param accessor A method that returns the vector, given the data point.
+ */
+export function findKNNGPUCosine<T>(
+    dataPoints: T[], k: number,
+    accessor: (dataPoint: T) => Float32Array): Promise<NearestEntry[][]> {
+  let N = dataPoints.length;
+  let dim = accessor(dataPoints[0]).length;
+
+  // The goal is to compute a large matrix multiplication A*A.T where A is of
+  // size NxD and A.T is its transpose. This results in a NxN matrix which
+  // could be too big to store on the GPU memory. To avoid memory overflow, we
+  // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
+  // smaller than N). This results in storing only NxB size matrices on the GPU
+  // at a given time.
+
+  // A*A.T will give us NxN matrix holding the cosine distance between every
+  // pair of points, which we sort using KMin data structure to obtain the
+  // K nearest neighbors for each point.
+  let typedArray = vector.toTypedArray(dataPoints, accessor);
+  let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray);
+  let nearest: NearestEntry[][] = new Array(N);
+  let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
+  let M = Math.floor(N / numPieces);
+  let modulo = N % numPieces;
+  let offset = 0;
+  let progress = 0;
+  let progressDiff = 1 / (2 * numPieces);
+  let piece = 0;
+
+  function step(resolve: (result: NearestEntry[][]) => void) {
+    let progressMsg =
+        'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
+    runAsyncTask(progressMsg, () => {
+      let B = piece < modulo ? M + 1 : M;
+      let typedB = new Float32Array(B * dim);
+      for (let i = 0; i < B; ++i) {
+        let vector = accessor(dataPoints[offset + i]);
+        for (let d = 0; d < dim; ++d) {
+          typedB[i * dim + d] = vector[d];
+        }
+      }
+      let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB);
+      // Result is N x B matrix.
+      let result =
+          weblas.pipeline.sgemm(1, bigMatrix, partialMatrix, null, null);
+      let partial = result.transfer();
+      partialMatrix.delete();
+      result.delete();
+      progress += progressDiff;
+      for (let i = 0; i < B; i++) {
+        let kMin = new KMin<NearestEntry>(k);
+        let iReal = offset + i;
+        for (let j = 0; j < N; j++) {
+          if (j === iReal) {
+            continue;
+          }
+          let cosDist = 1 - partial[j * B + i];  // [j, i];
+          kMin.add(cosDist, {index: j, dist: cosDist});
+        }
+        nearest[iReal] = kMin.getMinKItems();
+      }
+      progress += progressDiff;
+      offset += B;
+      piece++;
+    }, KNN_GPU_MSG_ID).then(() => {
+      if (piece < numPieces) {
+        step(resolve);
+      } else {
+        logging.setModalMessage(null, KNN_GPU_MSG_ID);
+        bigMatrix.delete();
+        resolve(nearest);
+      }
+    }, error => {
+      // GPU failed. Falling back to the CPU.
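+      // (Reading aid: the CPU fallback reuses findKNN with a cosine metric.
+      // vector.cosDistNorm presumably expects unit-normalized vectors, as its
+      // name suggests, and the `limit` early-exit argument is ignored here,
+      // which is safe because findKNN only skips results that come back
+      // negative.)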
+      logging.setModalMessage(null, KNN_GPU_MSG_ID);
+      let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
+      findKNN(dataPoints, k, accessor, distFunc).then(nearest => {
+        resolve(nearest);
+      });
+    });
+  }
+  return new Promise<NearestEntry[][]>(resolve => step(resolve));
+}
+
+/**
+ * Returns the K nearest neighbors for each vector where the distance
+ * computation is done on the CPU using a user-specified distance method.
+ *
+ * @param dataPoints List of data points, where each data point holds an
+ *   n-dimensional vector.
+ * @param k Number of nearest neighbors to find.
+ * @param accessor A method that returns the vector, given the data point.
+ * @param dist Method that takes two vectors and a limit, and computes the
+ *   distance between two vectors, with the ability to stop early if the
+ *   distance is above the limit.
+ */
+export function findKNN<T>(
+    dataPoints: T[], k: number, accessor: (dataPoint: T) => Float32Array,
+    dist: (a: Vector, b: Vector, limit: number) =>
+        number): Promise<NearestEntry[][]> {
+  return runAsyncTask('Finding nearest neighbors...', () => {
+    let N = dataPoints.length;
+    let nearest: NearestEntry[][] = new Array(N);
+    // Find the distances from node i.
+    let kMin: KMin<NearestEntry>[] = new Array(N);
+    for (let i = 0; i < N; i++) {
+      kMin[i] = new KMin<NearestEntry>(k);
+    }
+    for (let i = 0; i < N; i++) {
+      let a = accessor(dataPoints[i]);
+      let kMinA = kMin[i];
+      for (let j = i + 1; j < N; j++) {
+        let kMinB = kMin[j];
+        let limitI = kMinA.getSize() === k ?
+            kMinA.getLargestKey() || Number.MAX_VALUE :
+            Number.MAX_VALUE;
+        let limitJ = kMinB.getSize() === k ?
+            kMinB.getLargestKey() || Number.MAX_VALUE :
+            Number.MAX_VALUE;
+        let limit = Math.max(limitI, limitJ);
+        let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit);
+        if (dist2ItoJ >= 0) {
+          kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ});
+          kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ});
+        }
+      }
+    }
+    for (let i = 0; i < N; i++) {
+      nearest[i] = kMin[i].getMinKItems();
+    }
+    return nearest;
+  });
+}
+
+/** Calculates the minimum distance between a search point and a rectangle. */
+function minDist(
+    point: [number, number], x1: number, y1: number, x2: number, y2: number) {
+  let x = point[0];
+  let y = point[1];
+  let dx1 = x - x1;
+  let dx2 = x - x2;
+  let dy1 = y - y1;
+  let dy2 = y - y2;
+
+  if (dx1 * dx2 <= 0) {  // x is between x1 and x2
+    if (dy1 * dy2 <= 0) {  // (x,y) is inside the rectangle
+      return 0;  // return 0 as point is in rect
+    }
+    return Math.min(Math.abs(dy1), Math.abs(dy2));
+  }
+  if (dy1 * dy2 <= 0) {  // y is between y1 and y2
+    // y is within range but x is not, so the nearest edge is vertical.
+    return Math.min(Math.abs(dx1), Math.abs(dx2));
+  }
+  let corner: [number, number];
+  if (x > x2) {
+    // Upper-right vs lower-right.
+    corner = y > y2 ? [x2, y2] : [x2, y1];
+  } else {
+    // Upper-left vs lower-left.
+    corner = y > y2 ? [x1, y2] : [x1, y1];
+  }
+  return Math.sqrt(vector.dist22D([x, y], corner));
+}
+
+/**
+ * Returns the nearest neighbors of a particular point.
+ *
+ * @param dataPoints List of data points.
+ * @param pointIndex The index of the point we need the nearest neighbors of.
+ * @param k Number of nearest neighbors to search for.
+ * @param accessor Method that maps a data point => vector (array of numbers).
+ * @param distance Method that takes two vectors and returns their distance.
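+ *
+ * Illustrative call (the accessor and metric are only examples):
+ *   findKNNofPoint(points, 0, 5, p => p.vector,
+ *       (a, b) => vector.cosDistNorm(a, b));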
+ */
+export function findKNNofPoint<T>(
+    dataPoints: T[], pointIndex: number, k: number,
+    accessor: (dataPoint: T) => Float32Array,
+    distance: (a: Vector, b: Vector) => number) {
+  let kMin = new KMin<NearestEntry>(k);
+  let a = accessor(dataPoints[pointIndex]);
+  for (let i = 0; i < dataPoints.length; ++i) {
+    if (i === pointIndex) {
+      continue;
+    }
+    let b = accessor(dataPoints[i]);
+    let dist = distance(a, b);
+    kMin.add(dist, {index: i, dist: dist});
+  }
+  return kMin.getMinKItems();
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts
new file mode 100644
index 00000000000..67987f06ea3
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts
@@ -0,0 +1,151 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+export interface BoundingBox {
+  loX: number;
+  loY: number;
+  hiX: number;
+  hiY: number;
+}
+
+/**
+ * Accelerates label placement by dividing the view into a uniform grid.
+ * Labels only need to be tested for collision with other labels that overlap
+ * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}.
+ */
+export class CollisionGrid {
+  private numHorizCells: number;
+  private numVertCells: number;
+  private grid: BoundingBox[][];
+  private bound: BoundingBox;
+  private cellWidth: number;
+  private cellHeight: number;
+
+  /**
+   * Constructs a new Collision grid.
+   *
+   * @param bound The bound of the grid. Labels out of bounds will be rejected.
+   * @param cellWidth Width of a cell in the grid.
+   * @param cellHeight Height of a cell in the grid.
+   */
+  constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) {
+    /** The bound of the grid. Labels out of bounds will be rejected. */
+    this.bound = bound;
+
+    /** Width of a cell in the grid. */
+    this.cellWidth = cellWidth;
+
+    /** Height of a cell in the grid. */
+    this.cellHeight = cellHeight;
+
+    /** Number of grid cells along the x axis. */
+    this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth);
+
+    /** Number of grid cells along the y axis. */
+    this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight);
+
+    /**
+     * The 2d grid (stored as a 1d array.) Each cell consists of an array of
+     * BoundingBoxes for objects that are in the cell.
+     */
+    this.grid = new Array(this.numHorizCells * this.numVertCells);
+  }
+
+  private boundWidth(bound: BoundingBox) { return bound.hiX - bound.loX; }
+
+  private boundHeight(bound: BoundingBox) { return bound.hiY - bound.loY; }
+
+  private boundsIntersect(a: BoundingBox, b: BoundingBox) {
+    return !(a.loX > b.hiX || a.loY > b.hiY || a.hiX < b.loX || a.hiY < b.loY);
+  }
+
+  /**
+   * Checks if a given bounding box has any conflicts in the grid and inserts
+   * it if none are found.
+   *
+   * @param bound The bound to insert.
+   * @param justTest If true, just test if it conflicts, without inserting.
+ * @return True if the bound was successfully inserted; false if it + * could not be inserted due to a conflict. + */ + insert(bound: BoundingBox, justTest = false): boolean { + // Reject if the label is out of bounds. + if ((bound.hiX < this.bound.loX) || (bound.loX > this.bound.hiX) || + (bound.hiY < this.bound.loY) || (bound.loY > this.bound.hiY)) { + return false; + } + + let minCellX = this.getCellX(bound.loX); + let maxCellX = this.getCellX(bound.hiX); + let minCellY = this.getCellY(bound.loY); + let maxCellY = this.getCellY(bound.hiY); + + // Check all overlapped cells to verify that we can insert. + let baseIdx = minCellY * this.numHorizCells + minCellX; + let idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + let cell = this.grid[idx++]; + if (cell) { + for (let k = 0; k < cell.length; k++) { + if (this.boundsIntersect(bound, cell[k])) { + return false; + } + } + } + } + idx += this.numHorizCells - (maxCellX - minCellX + 1); + } + + if (justTest) { + return true; + } + + // Insert into the overlapped cells. + idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + if (!this.grid[idx]) { + this.grid[idx] = [bound]; + } else { + this.grid[idx].push(bound); + } + idx++; + } + idx += this.numHorizCells - (maxCellX - minCellX + 1); + } + return true; + } + + /** + * Returns the x index of the grid cell where the given x coordinate falls. + * + * @param x the coordinate, in world space. + * @return the x index of the cell. + */ + private getCellX(x: number) { + return Math.floor((x - this.bound.loX) / this.cellWidth); + }; + + /** + * Returns the y index of the grid cell where the given y coordinate falls. + * + * @param y the coordinate, in world space. + * @return the y index of the cell. + */ + private getCellY(y: number) { + return Math.floor((y - this.bound.loY) / this.cellHeight); + }; +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts new file mode 100644 index 00000000000..59f37206012 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts @@ -0,0 +1,103 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Duration in ms for showing warning messages to the user */ +const WARNING_DURATION_MS = 10000; + +let dom: HTMLElement = null; +let msgId = 0; +let numActiveMessages = 0; + +export function setDomContainer(domElement: HTMLElement) { + dom = domElement; +} + +/** + * Updates the user message with the provided id. + * + * @param msg The message shown to the user. If null, the message is removed. + * @param id The id of an existing message. If no id is provided, a unique id + * is assigned. + * @param title The title of the notification. 
+ * @param isErrorMsg If true, the message is error and the dialog will have a + * close button. + * @return The id of the message. + */ +export function setModalMessage( + msg: string, id: string = null, title = null, isErrorMsg = false): string { + if (dom == null) { + console.warn('Can\'t show modal message before the dom is initialized'); + return; + } + if (id == null) { + id = (msgId++).toString(); + } + let dialog = dom.querySelector('#notification-dialog') as any; + dialog.querySelector('.close-button').style.display = + isErrorMsg ? null : 'none'; + let spinner = dialog.querySelector('.progress'); + spinner.style.display = isErrorMsg ? 'none' : null; + spinner.active = isErrorMsg ? null : true; + dialog.querySelector('#notification-title').innerHTML = title; + let msgsContainer = dialog.querySelector('#notify-msgs') as HTMLElement; + if (isErrorMsg) { + msgsContainer.innerHTML = ''; + } else { + const errors = msgsContainer.querySelectorAll('.error'); + for (let i = 0; i < errors.length; i++) { + msgsContainer.removeChild(errors[i]); + } + } + let divId = `notify-msg-${id}`; + let msgDiv = dialog.querySelector('#' + divId) as HTMLDivElement; + if (msgDiv == null) { + msgDiv = document.createElement('div'); + msgDiv.className = 'notify-msg ' + (isErrorMsg ? 'error' : ''); + msgDiv.id = divId; + + msgsContainer.insertBefore(msgDiv, msgsContainer.firstChild); + + if (!isErrorMsg) { + numActiveMessages++; + } else { + numActiveMessages = 0; + } + } + if (msg == null) { + numActiveMessages--; + if (numActiveMessages === 0) { + dialog.close(); + } + msgDiv.remove(); + } else { + msgDiv.innerText = msg; + dialog.open(); + } + return id; +} + +export function setErrorMessage(errMsg: string, task?: string) { + setModalMessage(errMsg, null, 'Error ' + (task != null ? task : ''), true); +} + +/** + * Shows a warning message to the user for a certain amount of time. + */ +export function setWarningMessage(msg: string): void { + let toast = dom.querySelector('#toast') as any; + toast.text = msg; + toast.duration = WARNING_DURATION_MS; + toast.open(); +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts new file mode 100644 index 00000000000..36f5c4c5841 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts @@ -0,0 +1,45 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +import {DistanceFunction, Projection} from './data'; +import {NearestEntry} from './knn'; + +export type HoverListener = (index: number) => void; +export type SelectionChangedListener = + (selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) => + void; +export type ProjectionChangedListener = (projection: Projection) => void; +export type DistanceMetricChangedListener = + (distanceMetric: DistanceFunction) => void; +export interface ProjectorEventContext { + /** Register a callback to be invoked when the mouse hovers over a point. */ + registerHoverListener(listener: HoverListener); + /** Notify the hover system that a point is under the mouse. */ + notifyHoverOverPoint(pointIndex: number); + /** Registers a callback to be invoked when the selection changes. */ + registerSelectionChangedListener(listener: SelectionChangedListener); + /** + * Notify the selection system that a client has changed the selected point + * set. + */ + notifySelectionChanged(newSelectedPointIndices: number[]); + /** Registers a callback to be invoked when the projection changes. */ + registerProjectionChangedListener(listener: ProjectionChangedListener); + /** Notify listeners that a reprojection occurred. */ + notifyProjectionChanged(projection: Projection); + registerDistanceMetricChangedListener(listener: + DistanceMetricChangedListener); + notifyDistanceMetricChanged(distMetric: DistanceFunction); +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts new file mode 100644 index 00000000000..bb09e2b153a --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts @@ -0,0 +1,713 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+import * as d3 from 'd3';  // from //third_party/javascript/typings/d3_v4
+
+import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data';
+import {NearestEntry} from './knn';
+import {ProjectorEventContext} from './projectorEventContext';
+import {LabelRenderParams} from './renderContext';
+import {ScatterPlot} from './scatterPlot';
+import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
+import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
+import {ScatterPlotVisualizerPolylines} from './scatterPlotVisualizerPolylines';
+import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
+import * as vector from './vector';
+
+const LABEL_FONT_SIZE = 10;
+const LABEL_SCALE_DEFAULT = 1.0;
+const LABEL_SCALE_LARGE = 2;
+const LABEL_FILL_COLOR_SELECTED = 0x000000;
+const LABEL_FILL_COLOR_HOVER = 0x000000;
+const LABEL_FILL_COLOR_NEIGHBOR = 0x000000;
+const LABEL_STROKE_COLOR_SELECTED = 0xFFFFFF;
+const LABEL_STROKE_COLOR_HOVER = 0xFFFFFF;
+const LABEL_STROKE_COLOR_NEIGHBOR = 0xFFFFFF;
+
+const POINT_COLOR_UNSELECTED = 0xE3E3E3;
+const POINT_COLOR_NO_SELECTION = 0x7575D9;
+const POINT_COLOR_SELECTED = 0xFA6666;
+const POINT_COLOR_HOVER = 0x760B4F;
+
+const POINT_SCALE_DEFAULT = 1.0;
+const POINT_SCALE_SELECTED = 1.2;
+const POINT_SCALE_NEIGHBOR = 1.2;
+const POINT_SCALE_HOVER = 1.2;
+
+const LABELS_3D_COLOR_UNSELECTED = 0xFFFFFF;
+const LABELS_3D_COLOR_NO_SELECTION = 0xFFFFFF;
+
+const SPRITE_IMAGE_COLOR_UNSELECTED = 0xFFFFFF;
+const SPRITE_IMAGE_COLOR_NO_SELECTION = 0xFFFFFF;
+
+const POLYLINE_START_HUE = 60;
+const POLYLINE_END_HUE = 360;
+const POLYLINE_SATURATION = 1;
+const POLYLINE_LIGHTNESS = .3;
+
+const POLYLINE_DEFAULT_OPACITY = .2;
+const POLYLINE_DEFAULT_LINEWIDTH = 2;
+const POLYLINE_SELECTED_OPACITY = .9;
+const POLYLINE_SELECTED_LINEWIDTH = 3;
+const POLYLINE_DESELECTED_OPACITY = .05;
+
+const SCATTER_PLOT_CUBE_LENGTH = 2;
+
+/** Color scale for nearest neighbors. */
+const NN_COLOR_SCALE =
+    d3.scaleLinear<string, string>()
+        .domain([1, 0.7, 0.4])
+        .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)'])
+        .clamp(true);
+
+/**
+ * Interprets projector events and assembles the arrays and commands necessary
+ * to use the ScatterPlot to render the current projected data set.
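+ *
+ * Rough flow, as wired up in the constructor below: an event arrives via
+ * ProjectorEventContext (projection, selection, hover, or distance-metric
+ * change), the adapter regenerates the affected position/color/scale/label
+ * arrays, pushes them into the ScatterPlot, and calls render().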
+ */ +export class ProjectorScatterPlotAdapter { + public scatterPlot: ScatterPlot; + private projection: Projection; + private hoverPointIndex: number; + private selectedPointIndices: number[]; + private neighborsOfFirstSelectedPoint: NearestEntry[]; + private renderLabelsIn3D: boolean = false; + private labelPointAccessor: string; + private legendPointColorer: (ds: DataSet, index: number) => string; + private distanceMetric: DistanceFunction; + + private spriteVisualizer: ScatterPlotVisualizerSprites; + private labels3DVisualizer: ScatterPlotVisualizer3DLabels; + private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels; + private polylineVisualizer: ScatterPlotVisualizerPolylines; + + constructor( + private scatterPlotContainer: HTMLElement, + projectorEventContext: ProjectorEventContext) { + this.scatterPlot = + new ScatterPlot(scatterPlotContainer, projectorEventContext); + projectorEventContext.registerProjectionChangedListener(projection => { + this.projection = projection; + this.updateScatterPlotWithNewProjection(projection); + }); + projectorEventContext.registerSelectionChangedListener( + (selectedPointIndices, neighbors) => { + this.selectedPointIndices = selectedPointIndices; + this.neighborsOfFirstSelectedPoint = neighbors; + this.updateScatterPlotPositions(); + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + projectorEventContext.registerHoverListener(hoverPointIndex => { + this.hoverPointIndex = hoverPointIndex; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + projectorEventContext.registerDistanceMetricChangedListener( + distanceMetric => { + this.distanceMetric = distanceMetric; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + this.createVisualizers(false); + } + + notifyProjectionPositionsUpdated() { + this.updateScatterPlotPositions(); + this.scatterPlot.render(); + } + + setDataSet(dataSet: DataSet) { + if (this.projection != null) { + // TODO(nicholsonc): setDataSet needs to go away, the projection is the + // atomic unit of update. + this.projection.dataSet = dataSet; + } + if (this.polylineVisualizer != null) { + this.polylineVisualizer.setDataSet(dataSet); + } + if (this.labels3DVisualizer != null) { + this.labels3DVisualizer.setLabelStrings( + this.generate3DLabelsArray(dataSet, this.labelPointAccessor)); + } + if (this.spriteVisualizer == null) { + return; + } + this.spriteVisualizer.clearSpriteAtlas(); + if ((dataSet == null) || (dataSet.spriteAndMetadataInfo == null)) { + return; + } + const metadata = dataSet.spriteAndMetadataInfo; + if ((metadata.spriteImage == null) || (metadata.spriteMetadata == null)) { + return; + } + const n = dataSet.points.length; + const spriteIndices = new Float32Array(n); + for (let i = 0; i < n; ++i) { + spriteIndices[i] = dataSet.points[i].index; + } + this.spriteVisualizer.setSpriteAtlas( + metadata.spriteImage, metadata.spriteMetadata.singleImageDim, + spriteIndices); + } + + set3DLabelMode(renderLabelsIn3D: boolean) { + this.renderLabelsIn3D = renderLabelsIn3D; + this.createVisualizers(renderLabelsIn3D); + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + } + + setLegendPointColorer( + legendPointColorer: (ds: DataSet, index: number) => string) { + this.legendPointColorer = legendPointColorer; + } + + setLabelPointAccessor(labelPointAccessor: string) { + this.labelPointAccessor = labelPointAccessor; + if (this.labels3DVisualizer != null) { + const ds = (this.projection == null) ? 
null : this.projection.dataSet; + this.labels3DVisualizer.setLabelStrings( + this.generate3DLabelsArray(ds, labelPointAccessor)); + } + } + + resize() { + this.scatterPlot.resize(); + } + + populateBookmarkFromUI(state: State) { + state.cameraDef = this.scatterPlot.getCameraDef(); + } + + restoreUIFromBookmark(state: State) { + this.scatterPlot.setCameraParametersForNextCameraCreation( + state.cameraDef, false); + } + + updateScatterPlotPositions() { + const ds = (this.projection == null) ? null : this.projection.dataSet; + const projectionComponents = + (this.projection == null) ? null : this.projection.projectionComponents; + const newPositions = + this.generatePointPositionArray(ds, projectionComponents); + this.scatterPlot.setPointPositions(newPositions); + } + + updateScatterPlotAttributes() { + if (this.projection == null) { + return; + } + const dataSet = this.projection.dataSet; + const selectedSet = this.selectedPointIndices; + const hoverIndex = this.hoverPointIndex; + const neighbors = this.neighborsOfFirstSelectedPoint; + const pointColorer = this.legendPointColorer; + + const pointColors = this.generatePointColorArray( + dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors, + hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode()); + const pointScaleFactors = this.generatePointScaleFactorArray( + dataSet, selectedSet, neighbors, hoverIndex); + const labels = this.generateVisibleLabelRenderParams( + dataSet, selectedSet, neighbors, hoverIndex); + const polylineColors = + this.generateLineSegmentColorMap(dataSet, pointColorer); + const polylineOpacities = + this.generateLineSegmentOpacityArray(dataSet, selectedSet); + const polylineWidths = + this.generateLineSegmentWidthArray(dataSet, selectedSet); + + this.scatterPlot.setPointColors(pointColors); + this.scatterPlot.setPointScaleFactors(pointScaleFactors); + this.scatterPlot.setLabels(labels); + this.scatterPlot.setPolylineColors(polylineColors); + this.scatterPlot.setPolylineOpacities(polylineOpacities); + this.scatterPlot.setPolylineWidths(polylineWidths); + } + + render() { + this.scatterPlot.render(); + } + + generatePointPositionArray( + ds: DataSet, projectionComponents: ProjectionComponents3D): Float32Array { + if (ds == null) { + return null; + } + + const xScaler = d3.scaleLinear(); + const yScaler = d3.scaleLinear(); + let zScaler = null; + { + // Determine max and min of each axis of our data. 
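+      // (Sketch of the resulting mapping: with [lo, hi] = xExtent, a
+      // projected value p lands at
+      //   xScaler(p) = -1 + 2 * (p - lo) / (hi - lo)
+      // because range below is [-1, 1], given SCATTER_PLOT_CUBE_LENGTH = 2.)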
+      const xExtent = d3.extent(
+          ds.points,
+          (p, i) => ds.points[i].projections[projectionComponents[0]]);
+      const yExtent = d3.extent(
+          ds.points,
+          (p, i) => ds.points[i].projections[projectionComponents[1]]);
+
+      const range =
+          [-SCATTER_PLOT_CUBE_LENGTH / 2, SCATTER_PLOT_CUBE_LENGTH / 2];
+
+      xScaler.domain(xExtent).range(range);
+      yScaler.domain(yExtent).range(range);
+
+      if (projectionComponents[2] != null) {
+        const zExtent = d3.extent(
+            ds.points,
+            (p, i) => ds.points[i].projections[projectionComponents[2]]);
+        zScaler = d3.scaleLinear();
+        zScaler.domain(zExtent).range(range);
+      }
+    }
+
+    const positions = new Float32Array(ds.points.length * 3);
+    let dst = 0;
+
+    ds.points.forEach((d, i) => {
+      positions[dst++] =
+          xScaler(ds.points[i].projections[projectionComponents[0]]);
+      positions[dst++] =
+          yScaler(ds.points[i].projections[projectionComponents[1]]);
+      positions[dst++] = 0.0;
+    });
+
+    if (zScaler) {
+      dst = 2;
+      ds.points.forEach((d, i) => {
+        positions[dst] =
+            zScaler(ds.points[i].projections[projectionComponents[2]]);
+        dst += 3;
+      });
+    }
+
+    return positions;
+  }
+
+  generateVisibleLabelRenderParams(
+      ds: DataSet, selectedPointIndices: number[],
+      neighborsOfFirstPoint: NearestEntry[],
+      hoverPointIndex: number): LabelRenderParams {
+    if (ds == null) {
+      return null;
+    }
+
+    const selectedPointCount =
+        (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+    const neighborCount =
+        (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+    const n = selectedPointCount + neighborCount +
+        ((hoverPointIndex != null) ? 1 : 0);
+
+    const visibleLabels = new Uint32Array(n);
+    const scale = new Float32Array(n);
+    const opacityFlags = new Int8Array(n);
+    const fillColors = new Uint8Array(n * 3);
+    const strokeColors = new Uint8Array(n * 3);
+    const labelStrings: string[] = [];
+
+    scale.fill(LABEL_SCALE_DEFAULT);
+    opacityFlags.fill(1);
+
+    let dst = 0;
+
+    if (hoverPointIndex != null) {
+      labelStrings.push(
+          this.getLabelText(ds, hoverPointIndex, this.labelPointAccessor));
+      visibleLabels[dst] = hoverPointIndex;
+      scale[dst] = LABEL_SCALE_LARGE;
+      opacityFlags[dst] = 0;
+      const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER);
+      packRgbIntoUint8Array(
+          fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
+      const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER);
+      packRgbIntoUint8Array(
+          strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
+      ++dst;
+    }
+
+    // Selected points
+    {
+      const n = selectedPointCount;
+      const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
+      const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
+      for (let i = 0; i < n; ++i) {
+        const labelIndex = selectedPointIndices[i];
+        labelStrings.push(
+            this.getLabelText(ds, labelIndex, this.labelPointAccessor));
+        visibleLabels[dst] = labelIndex;
+        scale[dst] = LABEL_SCALE_LARGE;
+        opacityFlags[dst] = (n === 1) ?
0 : 1; + packRgbIntoUint8Array( + fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); + packRgbIntoUint8Array( + strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); + ++dst; + } + } + + // Neighbors + { + const n = neighborCount; + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); + for (let i = 0; i < n; ++i) { + const labelIndex = neighborsOfFirstPoint[i].index; + labelStrings.push( + this.getLabelText(ds, labelIndex, this.labelPointAccessor)); + visibleLabels[dst] = labelIndex; + packRgbIntoUint8Array( + fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); + packRgbIntoUint8Array( + strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); + ++dst; + } + } + + return new LabelRenderParams( + visibleLabels, labelStrings, scale, opacityFlags, LABEL_FONT_SIZE, + fillColors, strokeColors); + } + + generatePointScaleFactorArray( + ds: DataSet, selectedPointIndices: number[], + neighborsOfFirstPoint: NearestEntry[], + hoverPointIndex: number): Float32Array { + if (ds == null) { + return new Float32Array(0); + } + + const scale = new Float32Array(ds.points.length); + scale.fill(POINT_SCALE_DEFAULT); + + const selectedPointCount = + (selectedPointIndices == null) ? 0 : selectedPointIndices.length; + const neighborCount = + (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; + + // Scale up all selected points. + { + const n = selectedPointCount; + for (let i = 0; i < n; ++i) { + const p = selectedPointIndices[i]; + scale[p] = POINT_SCALE_SELECTED; + } + } + + // Scale up the neighbor points. + { + const n = neighborCount; + for (let i = 0; i < n; ++i) { + const p = neighborsOfFirstPoint[i].index; + scale[p] = POINT_SCALE_NEIGHBOR; + } + } + + // Scale up the hover point. 
+ if (hoverPointIndex != null) { + scale[hoverPointIndex] = POINT_SCALE_HOVER; + } + + return scale; + } + + generateLineSegmentColorMap( + ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string): + {[polylineIndex: number]: Float32Array} { + let polylineColorArrayMap: {[polylineIndex: number]: Float32Array} = {}; + if (ds == null) { + return polylineColorArrayMap; + } + + for (let i = 0; i < ds.sequences.length; i++) { + let sequence = ds.sequences[i]; + let colors = new Float32Array(2 * (sequence.pointIndices.length - 1) * 3); + let colorIndex = 0; + + if (legendPointColorer) { + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const c1 = + new THREE.Color(legendPointColorer(ds, sequence.pointIndices[j])); + const c2 = new THREE.Color( + legendPointColorer(ds, sequence.pointIndices[j + 1])); + colors[colorIndex++] = c1.r; + colors[colorIndex++] = c1.g; + colors[colorIndex++] = c1.b; + colors[colorIndex++] = c2.r; + colors[colorIndex++] = c2.g; + colors[colorIndex++] = c2.b; + } + } else { + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const c1 = + getDefaultPointInPolylineColor(j, sequence.pointIndices.length); + const c2 = getDefaultPointInPolylineColor( + j + 1, sequence.pointIndices.length); + colors[colorIndex++] = c1.r; + colors[colorIndex++] = c1.g; + colors[colorIndex++] = c1.b; + colors[colorIndex++] = c2.r; + colors[colorIndex++] = c2.g; + colors[colorIndex++] = c2.b; + } + } + + polylineColorArrayMap[i] = colors; + } + + return polylineColorArrayMap; + } + + generateLineSegmentOpacityArray(ds: DataSet, selectedPoints: number[]): + Float32Array { + if (ds == null) { + return new Float32Array(0); + } + const opacities = new Float32Array(ds.sequences.length); + const selectedPointCount = + (selectedPoints == null) ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { + opacities.fill(POLYLINE_DESELECTED_OPACITY); + const i = ds.points[selectedPoints[0]].sequenceIndex; + opacities[i] = POLYLINE_SELECTED_OPACITY; + } else { + opacities.fill(POLYLINE_DEFAULT_OPACITY); + } + return opacities; + } + + generateLineSegmentWidthArray(ds: DataSet, selectedPoints: number[]): + Float32Array { + if (ds == null) { + return new Float32Array(0); + } + const widths = new Float32Array(ds.sequences.length); + widths.fill(POLYLINE_DEFAULT_LINEWIDTH); + const selectedPointCount = + (selectedPoints == null) ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { + const i = ds.points[selectedPoints[0]].sequenceIndex; + widths[i] = POLYLINE_SELECTED_LINEWIDTH; + } + return widths; + } + + generatePointColorArray( + ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string, + distFunc: DistanceFunction, selectedPointIndices: number[], + neighborsOfFirstPoint: NearestEntry[], hoverPointIndex: number, + label3dMode: boolean, spriteImageMode: boolean): Float32Array { + if (ds == null) { + return new Float32Array(0); + } + + const selectedPointCount = + (selectedPointIndices == null) ? 0 : selectedPointIndices.length; + const neighborCount = + (neighborsOfFirstPoint == null) ? 
0 : neighborsOfFirstPoint.length; + const colors = new Float32Array(ds.points.length * 3); + + let unselectedColor = POINT_COLOR_UNSELECTED; + let noSelectionColor = POINT_COLOR_NO_SELECTION; + + if (label3dMode) { + unselectedColor = LABELS_3D_COLOR_UNSELECTED; + noSelectionColor = LABELS_3D_COLOR_NO_SELECTION; + } + + if (spriteImageMode) { + unselectedColor = SPRITE_IMAGE_COLOR_UNSELECTED; + noSelectionColor = SPRITE_IMAGE_COLOR_NO_SELECTION; + } + + // Give all points the unselected color. + { + const n = ds.points.length; + let dst = 0; + if (selectedPointCount > 0) { + const c = new THREE.Color(unselectedColor); + for (let i = 0; i < n; ++i) { + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } else { + if (legendPointColorer != null) { + for (let i = 0; i < n; ++i) { + const c = new THREE.Color(legendPointColorer(ds, i)); + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } else { + const c = new THREE.Color(noSelectionColor); + for (let i = 0; i < n; ++i) { + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } + } + } + + // Color the selected points. + { + const n = selectedPointCount; + const c = new THREE.Color(POINT_COLOR_SELECTED); + for (let i = 0; i < n; ++i) { + let dst = selectedPointIndices[i] * 3; + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } + + // Color the neighbors. + { + const n = neighborCount; + let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0; + for (let i = 0; i < n; ++i) { + const c = new THREE.Color( + dist2color(distFunc, neighborsOfFirstPoint[i].dist, minDist)); + let dst = neighborsOfFirstPoint[i].index * 3; + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } + + // Color the hover point. + if (hoverPointIndex != null) { + const c = new THREE.Color(POINT_COLOR_HOVER); + let dst = hoverPointIndex * 3; + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + + return colors; + } + + generate3DLabelsArray(ds: DataSet, accessor: string) { + if ((ds == null) || (accessor == null)) { + return null; + } + let labels: string[] = []; + const n = ds.points.length; + for (let i = 0; i < n; ++i) { + labels.push(this.getLabelText(ds, i, accessor)); + } + return labels; + } + + private getLabelText(ds: DataSet, i: number, accessor: string) { + return ds.points[i].metadata[accessor].toString(); + } + + private updateScatterPlotWithNewProjection(projection: Projection) { + if (projection == null) { + this.createVisualizers(this.renderLabelsIn3D); + this.scatterPlot.render(); + return; + } + this.setDataSet(projection.dataSet); + this.scatterPlot.setDimensions(projection.dimensionality); + if (projection.dataSet.projectionCanBeRendered(projection.projectionType)) { + this.updateScatterPlotAttributes(); + this.notifyProjectionPositionsUpdated(); + } + this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); + } + + private createVisualizers(inLabels3DMode: boolean) { + const ds = (this.projection == null) ? 
+        null : this.projection.dataSet;
+    const scatterPlot = this.scatterPlot;
+    scatterPlot.removeAllVisualizers();
+    this.labels3DVisualizer = null;
+    this.canvasLabelsVisualizer = null;
+    this.spriteVisualizer = null;
+    this.polylineVisualizer = null;
+    if (inLabels3DMode) {
+      this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels();
+      this.labels3DVisualizer.setLabelStrings(
+          this.generate3DLabelsArray(ds, this.labelPointAccessor));
+    } else {
+      this.spriteVisualizer = new ScatterPlotVisualizerSprites();
+      this.canvasLabelsVisualizer =
+          new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer);
+    }
+    this.polylineVisualizer = new ScatterPlotVisualizerPolylines();
+    this.setDataSet(ds);
+    if (this.spriteVisualizer) {
+      scatterPlot.addVisualizer(this.spriteVisualizer);
+    }
+    if (this.labels3DVisualizer) {
+      scatterPlot.addVisualizer(this.labels3DVisualizer);
+    }
+    if (this.canvasLabelsVisualizer) {
+      scatterPlot.addVisualizer(this.canvasLabelsVisualizer);
+    }
+    scatterPlot.addVisualizer(this.polylineVisualizer);
+  }
+
+  private getSpriteImageMode(): boolean {
+    if (this.projection == null) {
+      return false;
+    }
+    const ds = this.projection.dataSet;
+    if ((ds == null) || (ds.spriteAndMetadataInfo == null)) {
+      return false;
+    }
+    return ds.spriteAndMetadataInfo.spriteImage != null;
+  }
+}
+
+function packRgbIntoUint8Array(
+    rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) {
+  rgbArray[labelIndex * 3] = r;
+  rgbArray[labelIndex * 3 + 1] = g;
+  rgbArray[labelIndex * 3 + 2] = b;
+}
+
+function styleRgbFromHexColor(hex: number): [number, number, number] {
+  const c = new THREE.Color(hex);
+  return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
+}
+
+function getDefaultPointInPolylineColor(
+    index: number, totalPoints: number): THREE.Color {
+  let hue = POLYLINE_START_HUE +
+      (POLYLINE_END_HUE - POLYLINE_START_HUE) * index / totalPoints;
+
+  let rgb = d3.hsl(hue, POLYLINE_SATURATION, POLYLINE_LIGHTNESS).rgb();
+  return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
+}
+
+/**
+ * Normalizes the distance so it can be visually encoded with color.
+ * The normalization depends on the distance metric (cosine vs euclidean).
+ */
+export function normalizeDist(
+    distFunc: DistanceFunction, d: number, minDist: number): number {
+  return (distFunc === vector.dist) ? (minDist / d) : (1 - d);
+}
+
+/** Normalizes and encodes the provided distance with color. */
+export function dist2color(
+    distFunc: DistanceFunction, d: number, minDist: number): string {
+  return NN_COLOR_SCALE(normalizeDist(distFunc, d, minDist));
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts
new file mode 100644
index 00000000000..8d5232a8048
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * LabelRenderParams describes the set of points that should have labels + * rendered next to them. + */ +export class LabelRenderParams { + constructor( + public pointIndices: Float32Array, public labelStrings: string[], + public scaleFactors: Float32Array, public useSceneOpacityFlags: Int8Array, + public defaultFontSize: number, public fillColors: Uint8Array, + public strokeColors: Uint8Array) {} +} + +/** Details about the camera projection being used to render the scene. */ +export enum CameraType { + Perspective, + Orthographic +} + +/** + * RenderContext contains all of the state required to color and render the data + * set. ScatterPlot passes this to every attached visualizer as part of the + * render callback. + * TODO(nicholsonc): This should only contain the data that's changed between + * each frame. Data like colors / scale factors / labels should be reapplied + * only when they change. + */ +export class RenderContext { + constructor( + public camera: THREE.Camera, public cameraType: CameraType, + public cameraTarget: THREE.Vector3, public screenWidth: number, + public screenHeight: number, public nearestCameraSpacePointZ: number, + public farthestCameraSpacePointZ: number, public backgroundColor: number, + public pointColors: Float32Array, public pointScaleFactors: Float32Array, + public labels: LabelRenderParams, + public polylineColors: {[polylineIndex: number]: Float32Array}, + public polylineOpacities: Float32Array, + public polylineWidths: Float32Array) {} +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts new file mode 100644 index 00000000000..283b608e836 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts @@ -0,0 +1,723 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ProjectorEventContext} from './projectorEventContext'; +import {CameraType, LabelRenderParams, RenderContext} from './renderContext'; +import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; +import {Point2D, Point3D} from './vector'; + +const BACKGROUND_COLOR = 0xffffff; + +/** + * The length of the cube (diameter of the circumscribing sphere) where all the + * points live. + */ +const CUBE_LENGTH = 2; +const MAX_ZOOM = 5 * CUBE_LENGTH; +const MIN_ZOOM = 0.025 * CUBE_LENGTH; + +// Constants relating to the camera parameters. +const PERSP_CAMERA_FOV_VERTICAL = 70; +const PERSP_CAMERA_NEAR_CLIP_PLANE = 0.01; +const PERSP_CAMERA_FAR_CLIP_PLANE = 100; +const ORTHO_CAMERA_FRUSTUM_HALF_EXTENT = 1.2; + +// Key presses. 
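+// (These are legacy KeyboardEvent.keyCode values: 16 is Shift, 17 is Ctrl.)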
+const SHIFT_KEY = 16; +const CTRL_KEY = 17; + +const START_CAMERA_POS_3D = new THREE.Vector3(0.45, 0.9, 1.6); +const START_CAMERA_TARGET_3D = new THREE.Vector3(0, 0, 0); +const START_CAMERA_POS_2D = new THREE.Vector3(0, 0, 4); +const START_CAMERA_TARGET_2D = new THREE.Vector3(0, 0, 0); + +const ORBIT_MOUSE_ROTATION_SPEED = 1; +const ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 7; + +export type OnCameraMoveListener = + (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => void; + +/** Supported modes of interaction. */ +export enum MouseMode { + AREA_SELECT, + CAMERA_AND_CLICK_SELECT +} + +/** Defines a camera, suitable for serialization. */ +export class CameraDef { + orthographic: boolean = false; + position: Point3D; + target: Point3D; + zoom: number; +} + +/** + * Maintains a three.js instantiation and context, + * animation state, and all other logic that's + * independent of how a 3D scatter plot is actually rendered. Also holds an + * array of visualizers and dispatches application events to them. + */ +export class ScatterPlot { + private visualizers: ScatterPlotVisualizer[] = []; + + private onCameraMoveListeners: OnCameraMoveListener[] = []; + + private height: number; + private width: number; + + private mouseMode: MouseMode; + private backgroundColor: number = BACKGROUND_COLOR; + + private dimensionality: number = 3; + private renderer: THREE.WebGLRenderer; + + private scene: THREE.Scene; + private pickingTexture: THREE.WebGLRenderTarget; + private light: THREE.PointLight; + + private cameraDef: CameraDef = null; + private camera: THREE.Camera; + private orbitAnimationOnNextCameraCreation: boolean = false; + private orbitCameraControls: any; + private orbitAnimationId: number; + + private worldSpacePointPositions: Float32Array; + private pointColors: Float32Array; + private pointScaleFactors: Float32Array; + private labels: LabelRenderParams; + private polylineColors: {[polylineIndex: number]: Float32Array}; + private polylineOpacities: Float32Array; + private polylineWidths: Float32Array; + + private selecting = false; + private nearestPoint: number; + private mouseIsDown = false; + private isDragSequence = false; + private rectangleSelector: ScatterPlotRectangleSelector; + + constructor( + private container: HTMLElement, + private projectorEventContext: ProjectorEventContext) { + this.getLayoutValues(); + + this.scene = new THREE.Scene(); + this.renderer = new THREE.WebGLRenderer( + {alpha: true, premultipliedAlpha: false, antialias: false}); + this.renderer.setClearColor(BACKGROUND_COLOR, 1); + this.container.appendChild(this.renderer.domElement); + this.light = new THREE.PointLight(0xFFECBF, 1, 0); + this.scene.add(this.light); + + this.setDimensions(3); + this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); + this.renderer.render(this.scene, this.camera); + + this.rectangleSelector = new ScatterPlotRectangleSelector( + this.container, + (boundingBox: BoundingBox) => this.selectBoundingBox(boundingBox)); + this.addInteractionListeners(); + } + + private addInteractionListeners() { + this.container.addEventListener('mousemove', this.onMouseMove.bind(this)); + this.container.addEventListener('mousedown', this.onMouseDown.bind(this)); + this.container.addEventListener('mouseup', this.onMouseUp.bind(this)); + this.container.addEventListener('click', this.onClick.bind(this)); + window.addEventListener('keydown', this.onKeyDown.bind(this), false); + window.addEventListener('keyup', this.onKeyUp.bind(this), false); + } + + private 
addCameraControlsEventListeners(cameraControls: any) {
+    // Start is called when the user starts interacting with the
+    // controls.
+    cameraControls.addEventListener('start', () => {
+      this.stopOrbitAnimation();
+      this.onCameraMoveListeners.forEach(
+          l => l(this.camera.position, cameraControls.target));
+    });
+
+    // Change is called every time the user interacts with the controls.
+    cameraControls.addEventListener('change', () => {
+      this.render();
+    });
+
+    // End is called when the user stops interacting with the
+    // controls (e.g. on mouse up, after dragging).
+    cameraControls.addEventListener('end', () => {});
+  }
+
+  private makeOrbitControls(
+      camera: THREE.Camera, cameraDef: CameraDef, cameraIs3D: boolean) {
+    if (this.orbitCameraControls != null) {
+      this.orbitCameraControls.dispose();
+    }
+    const occ =
+        new (THREE as any).OrbitControls(camera, this.renderer.domElement);
+    occ.target0 = new THREE.Vector3(
+        cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+    occ.position0 = new THREE.Vector3().copy(camera.position);
+    occ.zoom0 = cameraDef.zoom;
+    occ.enableRotate = cameraIs3D;
+    occ.autoRotate = false;
+    occ.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED;
+    if (cameraIs3D) {
+      occ.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      occ.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    } else {
+      occ.mouseButtons.ORBIT = null;
+      occ.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+    occ.reset();
+
+    this.camera = camera;
+    this.orbitCameraControls = occ;
+    this.addCameraControlsEventListeners(this.orbitCameraControls);
+  }
+
+  private makeCamera3D(cameraDef: CameraDef, w: number, h: number) {
+    let camera: THREE.PerspectiveCamera;
+    {
+      const aspectRatio = w / h;
+      camera = new THREE.PerspectiveCamera(
+          PERSP_CAMERA_FOV_VERTICAL, aspectRatio, PERSP_CAMERA_NEAR_CLIP_PLANE,
+          PERSP_CAMERA_FAR_CLIP_PLANE);
+      camera.position.set(
+          cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]);
+      const at = new THREE.Vector3(
+          cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+      camera.lookAt(at);
+      camera.zoom = cameraDef.zoom;
+      camera.updateProjectionMatrix();
+    }
+    this.camera = camera;
+    this.makeOrbitControls(camera, cameraDef, true);
+  }
+
+  private makeCamera2D(cameraDef: CameraDef, w: number, h: number) {
+    let camera: THREE.OrthographicCamera;
+    const target = new THREE.Vector3(
+        cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+    {
+      const aspectRatio = w / h;
+      let left = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let right = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let bottom = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let top = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      // Scale up the larger of (w, h) to match the aspect ratio.
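+      // (Worked example: for a 1600x900 viewport, aspectRatio ~= 1.78, so
+      // left/right widen to about +-2.13 while top/bottom stay at +-1.2,
+      // keeping world-space units square on screen.)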
+      if (aspectRatio > 1) {
+        left *= aspectRatio;
+        right *= aspectRatio;
+      } else {
+        top /= aspectRatio;
+        bottom /= aspectRatio;
+      }
+      camera =
+          new THREE.OrthographicCamera(left, right, top, bottom, -1000, 1000);
+      camera.position.set(
+          cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]);
+      camera.up = new THREE.Vector3(0, 1, 0);
+      camera.lookAt(target);
+      camera.zoom = cameraDef.zoom;
+      camera.updateProjectionMatrix();
+    }
+    this.camera = camera;
+    this.makeOrbitControls(camera, cameraDef, false);
+  }
+
+  private makeDefaultCameraDef(dimensionality: number): CameraDef {
+    const def = new CameraDef();
+    def.orthographic = (dimensionality === 2);
+    def.zoom = 1.0;
+    if (def.orthographic) {
+      def.position =
+          [START_CAMERA_POS_2D.x, START_CAMERA_POS_2D.y, START_CAMERA_POS_2D.z];
+      def.target = [
+        START_CAMERA_TARGET_2D.x, START_CAMERA_TARGET_2D.y,
+        START_CAMERA_TARGET_2D.z
+      ];
+    } else {
+      def.position =
+          [START_CAMERA_POS_3D.x, START_CAMERA_POS_3D.y, START_CAMERA_POS_3D.z];
+      def.target = [
+        START_CAMERA_TARGET_3D.x, START_CAMERA_TARGET_3D.y,
+        START_CAMERA_TARGET_3D.z
+      ];
+    }
+    return def;
+  }
+
+  /** Recreate the scatter plot camera from a definition structure. */
+  recreateCamera(cameraDef: CameraDef) {
+    if (cameraDef.orthographic) {
+      this.makeCamera2D(cameraDef, this.width, this.height);
+    } else {
+      this.makeCamera3D(cameraDef, this.width, this.height);
+    }
+    this.orbitCameraControls.minDistance = MIN_ZOOM;
+    this.orbitCameraControls.maxDistance = MAX_ZOOM;
+    this.orbitCameraControls.update();
+    if (this.orbitAnimationOnNextCameraCreation) {
+      this.startOrbitAnimation();
+    }
+  }
+
+  private onClick(e?: MouseEvent, notify = true) {
+    if (e && this.selecting) {
+      return;
+    }
+    // Only call event handlers if the click originated from the scatter plot.
+    if (!this.isDragSequence && notify) {
+      const selection = (this.nearestPoint != null) ? [this.nearestPoint] : [];
+      this.projectorEventContext.notifySelectionChanged(selection);
+    }
+    this.isDragSequence = false;
+    this.render();
+  }
+
+  private onMouseDown(e: MouseEvent) {
+    this.isDragSequence = false;
+    this.mouseIsDown = true;
+    if (this.selecting) {
+      this.orbitCameraControls.enabled = false;
+      this.rectangleSelector.onMouseDown(e.offsetX, e.offsetY);
+      this.setNearestPointToMouse(e);
+    } else if (
+        !e.ctrlKey && this.sceneIs3D() &&
+        this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.RIGHT) {
+      // The user pressed the ctrl key while the tab was active, released it
+      // while the tab was inactive, and has now returned to the projector
+      // tab. Restore the default button mapping.
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    } else if (
+        e.ctrlKey && this.sceneIs3D() &&
+        this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.LEFT) {
+      // Similar to the situation above, but with the ctrl key held down.
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+  }
+
+  /** When we stop dragging/zooming, return to normal behavior. */
+  private onMouseUp(e: any) {
+    if (this.selecting) {
+      this.orbitCameraControls.enabled = true;
+      this.rectangleSelector.onMouseUp();
+      this.render();
+    }
+    this.mouseIsDown = false;
+  }
+
+  /**
+   * When the mouse moves, find the nearest point (if any) and send it to the
+   * hover listeners (usually called from embedding.ts).
+   */
+  private onMouseMove(e: MouseEvent) {
+    this.isDragSequence = this.mouseIsDown;
+    // Depending on whether we're selecting or just navigating, handle
+    // accordingly.
+    if (this.selecting && this.mouseIsDown) {
+      this.rectangleSelector.onMouseMove(e.offsetX, e.offsetY);
+      this.render();
+    } else if (!this.mouseIsDown) {
+      this.setNearestPointToMouse(e);
+      this.projectorEventContext.notifyHoverOverPoint(this.nearestPoint);
+    }
+  }
+
+  /** For using ctrl + left click as right click, and for area select. */
+  private onKeyDown(e: any) {
+    // If ctrl is pressed, use left click to orbit.
+    if (e.keyCode === CTRL_KEY && this.sceneIs3D()) {
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+
+    // If shift is pressed, start selecting.
+    if (e.keyCode === SHIFT_KEY) {
+      this.selecting = true;
+      this.container.style.cursor = 'crosshair';
+    }
+  }
+
+  /** For using ctrl + left click as right click, and for area select. */
+  private onKeyUp(e: any) {
+    if (e.keyCode === CTRL_KEY && this.sceneIs3D()) {
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    }
+
+    // If shift is released, stop selecting.
+    if (e.keyCode === SHIFT_KEY) {
+      this.selecting = (this.getMouseMode() === MouseMode.AREA_SELECT);
+      if (!this.selecting) {
+        this.container.style.cursor = 'default';
+      }
+      this.render();
+    }
+  }
+
+  /**
+   * Returns a list of indices of points in a bounding box from the picking
+   * texture.
+   * @param boundingBox The bounding box to select from.
+   */
+  private getPointIndicesFromPickingTexture(boundingBox: BoundingBox):
+      number[] {
+    if (this.worldSpacePointPositions == null) {
+      return null;
+    }
+    const pointCount = this.worldSpacePointPositions.length / 3;
+    const dpr = window.devicePixelRatio || 1;
+    const x = Math.floor(boundingBox.x * dpr);
+    const y = Math.floor(boundingBox.y * dpr);
+    const width = Math.floor(boundingBox.width * dpr);
+    const height = Math.floor(boundingBox.height * dpr);
+
+    // Create buffer for reading all of the pixels from the texture.
+    let pixelBuffer = new Uint8Array(width * height * 4);
+
+    // Read the pixels from the bounding box.
+    this.renderer.readRenderTargetPixels(
+        this.pickingTexture, x, this.pickingTexture.height - y, width, height,
+        pixelBuffer);
+
+    // Keep a flat list of each point and whether it is selected or not. This
+    // approach is more efficient than using an object keyed by the index.
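+    // Each point's index was encoded as a 24-bit RGB color during the
+    // picking pass (see render()), so a pixel decodes back to an id as
+    // (r << 16) | (g << 8) | b; e.g. rgb(1, 4, 255) is point 0x0104ff.
+    // Pure white (0xffffff, the background clear color) is skipped below.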
+    let pointIndicesSelection = new Uint8Array(pointCount);
+    for (let i = 0; i < width * height; i++) {
+      const id = (pixelBuffer[i * 4] << 16) | (pixelBuffer[i * 4 + 1] << 8) |
+          pixelBuffer[i * 4 + 2];
+      if (id !== 0xffffff && (id < pointCount)) {
+        pointIndicesSelection[id] = 1;
+      }
+    }
+    let pointIndices: number[] = [];
+    for (let i = 0; i < pointIndicesSelection.length; i++) {
+      if (pointIndicesSelection[i] === 1) {
+        pointIndices.push(i);
+      }
+    }
+
+    return pointIndices;
+  }
+
+  private selectBoundingBox(boundingBox: BoundingBox) {
+    let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
+    this.projectorEventContext.notifySelectionChanged(pointIndices);
+  }
+
+  private setNearestPointToMouse(e: MouseEvent) {
+    if (this.pickingTexture == null) {
+      this.nearestPoint = null;
+      return;
+    }
+    const boundingBox:
+        BoundingBox = {x: e.offsetX, y: e.offsetY, width: 1, height: 1};
+    const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
+    this.nearestPoint = (pointIndices != null) ? pointIndices[0] : null;
+  }
+
+  private getLayoutValues(): Point2D {
+    this.width = this.container.offsetWidth;
+    this.height = Math.max(1, this.container.offsetHeight);
+    return [this.width, this.height];
+  }
+
+  private sceneIs3D(): boolean {
+    return this.dimensionality === 3;
+  }
+
+  private remove3dAxisFromScene(): THREE.Object3D {
+    const axes = this.scene.getObjectByName('axes');
+    if (axes != null) {
+      this.scene.remove(axes);
+    }
+    return axes;
+  }
+
+  private add3dAxis() {
+    const axes = new THREE.AxisHelper();
+    axes.name = 'axes';
+    this.scene.add(axes);
+  }
+
+  /** Set 2d vs 3d mode. */
+  setDimensions(dimensionality: number) {
+    if ((dimensionality !== 2) && (dimensionality !== 3)) {
+      throw new RangeError('dimensionality must be 2 or 3');
+    }
+    this.dimensionality = dimensionality;
+
+    const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality);
+    this.recreateCamera(def);
+
+    this.remove3dAxisFromScene();
+    if (dimensionality === 3) {
+      this.add3dAxis();
+    }
+  }
+
+  /** Gets the current camera information, suitable for serialization. */
+  getCameraDef(): CameraDef {
+    const def = new CameraDef();
+    const pos = this.camera.position;
+    const tgt = this.orbitCameraControls.target;
+    def.orthographic = !this.sceneIs3D();
+    def.position = [pos.x, pos.y, pos.z];
+    def.target = [tgt.x, tgt.y, tgt.z];
+    def.zoom = (this.camera as any).zoom;
+    return def;
+  }
+
+  /** Sets parameters for the next camera recreation. */
+  setCameraParametersForNextCameraCreation(
+      def: CameraDef, orbitAnimation: boolean) {
+    this.cameraDef = def;
+    this.orbitAnimationOnNextCameraCreation = orbitAnimation;
+  }
+
+  /** Gets the current camera position. */
+  getCameraPosition(): Point3D {
+    const currPos = this.camera.position;
+    return [currPos.x, currPos.y, currPos.z];
+  }
+
+  /** Gets the current camera target. */
+  getCameraTarget(): Point3D {
+    let currTarget = this.orbitCameraControls.target;
+    return [currTarget.x, currTarget.y, currTarget.z];
+  }
+
+  /** Sets up the camera from given position and target coordinates. */
+  setCameraPositionAndTarget(position: Point3D, target: Point3D) {
+    this.stopOrbitAnimation();
+    this.camera.position.set(position[0], position[1], position[2]);
+    this.orbitCameraControls.target.set(target[0], target[1], target[2]);
+    this.orbitCameraControls.update();
+    this.render();
+  }
+
+  /** Starts orbiting the camera around its current lookat target.
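+   * The animation runs via requestAnimationFrame (see updateOrbitAnimation)
+   * until stopOrbitAnimation() cancels it; any new interaction with the
+   * camera controls also stops it, through the 'start' listener registered
+   * in addCameraControlsEventListeners().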
+   */
+  startOrbitAnimation() {
+    if (!this.sceneIs3D()) {
+      return;
+    }
+    if (this.orbitAnimationId != null) {
+      this.stopOrbitAnimation();
+    }
+    this.orbitCameraControls.autoRotate = true;
+    this.orbitCameraControls.rotateSpeed =
+        ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS;
+    this.updateOrbitAnimation();
+  }
+
+  private updateOrbitAnimation() {
+    this.orbitCameraControls.update();
+    this.orbitAnimationId =
+        requestAnimationFrame(() => this.updateOrbitAnimation());
+  }
+
+  /** Stops the orbiting animation on the camera. */
+  stopOrbitAnimation() {
+    this.orbitCameraControls.autoRotate = false;
+    this.orbitCameraControls.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED;
+    if (this.orbitAnimationId != null) {
+      cancelAnimationFrame(this.orbitAnimationId);
+      this.orbitAnimationId = null;
+    }
+  }
+
+  /** Adds a visualizer to the set and starts dispatching events to it. */
+  addVisualizer(visualizer: ScatterPlotVisualizer) {
+    if (this.scene) {
+      visualizer.setScene(this.scene);
+    }
+    visualizer.onResize(this.width, this.height);
+    visualizer.onPointPositionsChanged(this.worldSpacePointPositions);
+    this.visualizers.push(visualizer);
+  }
+
+  /** Removes all visualizers attached to this scatter plot. */
+  removeAllVisualizers() {
+    this.visualizers.forEach(v => v.dispose());
+    this.visualizers = [];
+  }
+
+  /** Update scatter plot with a new array of packed xyz point positions. */
+  setPointPositions(worldSpacePointPositions: Float32Array) {
+    this.worldSpacePointPositions = worldSpacePointPositions;
+    this.visualizers.forEach(
+        v => v.onPointPositionsChanged(worldSpacePointPositions));
+  }
+
+  render() {
+    {
+      const lightPos = this.camera.position.clone();
+      lightPos.x += 1;
+      lightPos.y += 1;
+      this.light.position.set(lightPos.x, lightPos.y, lightPos.z);
+    }
+
+    const cameraType = (this.camera instanceof THREE.PerspectiveCamera) ?
+        CameraType.Perspective :
+        CameraType.Orthographic;
+
+    let cameraSpacePointExtents: [number, number] = [0, 0];
+    if (this.worldSpacePointPositions != null) {
+      cameraSpacePointExtents = util.getNearFarPoints(
+          this.worldSpacePointPositions, this.camera.position,
+          this.orbitCameraControls.target);
+    }
+
+    const rc = new RenderContext(
+        this.camera, cameraType, this.orbitCameraControls.target, this.width,
+        this.height, cameraSpacePointExtents[0], cameraSpacePointExtents[1],
+        this.backgroundColor, this.pointColors, this.pointScaleFactors,
+        this.labels, this.polylineColors, this.polylineOpacities,
+        this.polylineWidths);
+
+    // Render first pass to picking target. This render fills pickingTexture
+    // with colors that are actually point ids, so that sampling the texture at
+    // the mouse's current x,y coordinates will reveal the data point that the
+    // mouse is over.
+    this.visualizers.forEach(v => v.onPickingRender(rc));
+
+    {
+      const axes = this.remove3dAxisFromScene();
+      this.renderer.render(this.scene, this.camera, this.pickingTexture);
+      if (axes != null) {
+        this.scene.add(axes);
+      }
+    }
+
+    // Render second pass to color buffer, to be displayed on the canvas.
+    this.visualizers.forEach(v => v.onRender(rc));
+
+    this.renderer.render(this.scene, this.camera);
+  }
+
+  setMouseMode(mouseMode: MouseMode) {
+    this.mouseMode = mouseMode;
+    if (mouseMode === MouseMode.AREA_SELECT) {
+      this.selecting = true;
+      this.container.style.cursor = 'crosshair';
+    } else {
+      this.selecting = false;
+      this.container.style.cursor = 'default';
+    }
+  }
+
+  /** Set the colors for every data point. (RGB triplets) */
+  setPointColors(colors: Float32Array) {
+    this.pointColors = colors;
+  }
+
+  /** Set the scale factors for every data point. (scalars) */
+  setPointScaleFactors(scaleFactors: Float32Array) {
+    this.pointScaleFactors = scaleFactors;
+  }
+
+  /** Set the labels to be rendered. */
+  setLabels(labels: LabelRenderParams) {
+    this.labels = labels;
+  }
+
+  /** Set the colors for every data polyline. (RGB triplets) */
+  setPolylineColors(colors: {[polylineIndex: number]: Float32Array}) {
+    this.polylineColors = colors;
+  }
+
+  setPolylineOpacities(opacities: Float32Array) {
+    this.polylineOpacities = opacities;
+  }
+
+  setPolylineWidths(widths: Float32Array) {
+    this.polylineWidths = widths;
+  }
+
+  getMouseMode(): MouseMode {
+    return this.mouseMode;
+  }
+
+  resetZoom() {
+    this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality));
+    this.render();
+  }
+
+  setDayNightMode(isNight: boolean) {
+    const canvases = this.container.querySelectorAll('canvas');
+    const filterValue = isNight ? 'invert(100%)' : null;
+    for (let i = 0; i < canvases.length; i++) {
+      canvases[i].style.filter = filterValue;
+    }
+  }
+
+  resize(render = true) {
+    const [oldW, oldH] = [this.width, this.height];
+    const [newW, newH] = this.getLayoutValues();
+
+    if (this.dimensionality === 3) {
+      const camera = (this.camera as THREE.PerspectiveCamera);
+      camera.aspect = newW / newH;
+      camera.updateProjectionMatrix();
+    } else {
+      const camera = (this.camera as THREE.OrthographicCamera);
+      // Scale the ortho frustum by however much the window changed.
+      const scaleW = newW / oldW;
+      const scaleH = newH / oldH;
+      const newCamHalfWidth = ((camera.right - camera.left) * scaleW) / 2;
+      const newCamHalfHeight = ((camera.top - camera.bottom) * scaleH) / 2;
+      camera.top = newCamHalfHeight;
+      camera.bottom = -newCamHalfHeight;
+      camera.left = -newCamHalfWidth;
+      camera.right = newCamHalfWidth;
+      camera.updateProjectionMatrix();
+    }
+
+    // Accounting for retina displays.
+    const dpr = window.devicePixelRatio || 1;
+    this.renderer.setPixelRatio(dpr);
+    this.renderer.setSize(newW, newH);
+
+    // The picking texture needs to be exactly the same as the render texture.
+    {
+      const renderCanvasSize = this.renderer.getSize();
+      const pixelRatio = this.renderer.getPixelRatio();
+      this.pickingTexture = new THREE.WebGLRenderTarget(
+          renderCanvasSize.width * pixelRatio,
+          renderCanvasSize.height * pixelRatio);
+      this.pickingTexture.texture.minFilter = THREE.LinearFilter;
+    }
+
+    this.visualizers.forEach(v => v.onResize(newW, newH));
+
+    if (render) {
+      this.render();
+    }
+  }
+
+  onCameraMove(listener: OnCameraMoveListener) {
+    this.onCameraMoveListeners.push(listener);
+  }
+
+  clickOnPoint(pointIndex: number) {
+    this.nearestPoint = pointIndex;
+    this.onClick(null, false);
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts
new file mode 100644
index 00000000000..a781877014e
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +const FILL = '#dddddd'; +const FILL_OPACITY = .2; +const STROKE = '#aaaaaa'; +const STROKE_WIDTH = 2; +const STROKE_DASHARRAY = '10 5'; + +export interface BoundingBox { + // The bounding box (x, y) position refers to the bottom left corner of the + // rect. + x: number; + y: number; + width: number; + height: number; +} + +/** + * A class that manages and renders a data selection rectangle. + */ +export class ScatterPlotRectangleSelector { + private svgElement: SVGElement; + private rectElement: SVGRectElement; + + private isMouseDown: boolean; + private startCoordinates: [number, number]; + private lastBoundingBox: BoundingBox; + + private selectionCallback: (boundingBox: BoundingBox) => void; + + /** + * @param container The container HTML element that the selection SVG rect + * will be a child of. + * @param selectionCallback The callback that accepts a bounding box to be + * called when selection changes. Currently, we only call the callback on + * mouseUp. + */ + constructor( + container: HTMLElement, + selectionCallback: (boundingBox: BoundingBox) => void) { + this.svgElement = container.querySelector('#selector') as SVGElement; + this.rectElement = + document.createElementNS('http://www.w3.org/2000/svg', 'rect'); + this.rectElement.style.stroke = STROKE; + this.rectElement.style.strokeDasharray = STROKE_DASHARRAY; + this.rectElement.style.strokeWidth = '' + STROKE_WIDTH; + this.rectElement.style.fill = FILL; + this.rectElement.style.fillOpacity = '' + FILL_OPACITY; + this.svgElement.appendChild(this.rectElement); + + this.selectionCallback = selectionCallback; + this.isMouseDown = false; + } + + onMouseDown(offsetX: number, offsetY: number) { + this.isMouseDown = true; + this.rectElement.style.display = 'block'; + + this.startCoordinates = [offsetX, offsetY]; + this.lastBoundingBox = { + x: this.startCoordinates[0], + y: this.startCoordinates[1], + width: 1, + height: 1 + }; + } + + onMouseMove(offsetX: number, offsetY: number) { + if (!this.isMouseDown) { + return; + } + + this.lastBoundingBox.x = Math.min(offsetX, this.startCoordinates[0]); + this.lastBoundingBox.y = Math.max(offsetY, this.startCoordinates[1]); + this.lastBoundingBox.width = + Math.max(offsetX, this.startCoordinates[0]) - this.lastBoundingBox.x; + this.lastBoundingBox.height = + this.lastBoundingBox.y - Math.min(offsetY, this.startCoordinates[1]); + + this.rectElement.setAttribute('x', '' + this.lastBoundingBox.x); + this.rectElement.setAttribute( + 'y', '' + (this.lastBoundingBox.y - this.lastBoundingBox.height)); + this.rectElement.setAttribute('width', '' + this.lastBoundingBox.width); + this.rectElement.setAttribute('height', '' + this.lastBoundingBox.height); + } + + onMouseUp() { + this.isMouseDown = false; + this.rectElement.style.display = 'none'; + this.rectElement.setAttribute('width', '0'); + this.rectElement.setAttribute('height', '0'); + this.selectionCallback(this.lastBoundingBox); + } +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts 
b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts new file mode 100644 index 00000000000..91cb10a97eb --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts @@ -0,0 +1,69 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector'; + +describe('selector callbacks make bounding box start bottom left', () => { + let containerElement: HTMLElement; + let selectionCallback: (boundingBox: BoundingBox) => void; + let selection: ScatterPlotRectangleSelector; + + beforeEach(() => { + containerElement = document.createElement('div'); + const selector = document.createElement('svg'); + selector.id = 'selector'; + containerElement.appendChild(selector); + + selectionCallback = jasmine.createSpy('selectionCallback'); + selection = + new ScatterPlotRectangleSelector(containerElement, selectionCallback); + }); + + it('Simple mouse event starting top left', () => { + selection.onMouseDown(0, 0); + selection.onMouseMove(10, 10); + selection.onMouseUp(); + + expect(selectionCallback) + .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); + }); + + it('Simple mouse event starting bottom left', () => { + selection.onMouseDown(0, 10); + selection.onMouseMove(10, 0); + selection.onMouseUp(); + + expect(selectionCallback) + .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); + }); + + it('Simple mouse event starting top right', () => { + selection.onMouseDown(10, 0); + selection.onMouseMove(0, 10); + selection.onMouseUp(); + + expect(selectionCallback) + .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); + }); + + it('Simple mouse event starting bottom right', () => { + selection.onMouseDown(10, 10); + selection.onMouseMove(0, 0); + selection.onMouseUp(); + + expect(selectionCallback) + .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); + }); +}); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts new file mode 100644 index 00000000000..b0974a20538 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +import {RenderContext} from './renderContext'; + +/** + * ScatterPlotVisualizer is an interface used by ScatterPlotContainer + * to manage and aggregate any number of concurrent visualization behaviors. + * To add a new visualization to the 3D scatter plot, create a new class that + * implements this interface and attach it to the ScatterPlotContainer. + */ +export interface ScatterPlotVisualizer { + /** Called to initialize the visualizer with the primary scene. */ + setScene(scene: THREE.Scene); + /** + * Called when the main scatter plot tears down the visualizer. Remove all + * objects from the scene, and dispose any heavy resources. + */ + dispose(); + /** + * Called when the positions of the scatter plot points have changed. + */ + onPointPositionsChanged(newWorldSpacePointPositions: Float32Array); + /** + * Called immediately before the main scatter plot performs a picking + * (selection) render. Set up render state for any geometry to use picking IDs + * instead of visual colors. + */ + onPickingRender(renderContext: RenderContext); + /** + * Called immediately before the main scatter plot performs a color (visual) + * render. Set up render state, lights, etc here. + */ + onRender(renderContext: RenderContext); + /** + * Called when the canvas size changes. + */ + onResize(newWidth: number, newHeight: number); +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts new file mode 100644 index 00000000000..cbd9785e2f6 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts @@ -0,0 +1,367 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const FONT_SIZE = 80; +const ONE_OVER_FONT_SIZE = 1 / FONT_SIZE; +const LABEL_SCALE = 2.2; // at 1:1 texel/pixel ratio +const LABEL_COLOR = 'black'; +const LABEL_BACKGROUND = 'white'; +const MAX_CANVAS_DIMENSION = 8192; +const NUM_GLYPHS = 256; +const RGB_ELEMENTS_PER_ENTRY = 3; +const XYZ_ELEMENTS_PER_ENTRY = 3; +const UV_ELEMENTS_PER_ENTRY = 2; +const VERTICES_PER_GLYPH = 2 * 3; // 2 triangles, 3 verts per triangle + +/** + * Each label is made up of triangles (two per letter.) Each vertex, then, is + * the corner of one of these triangles (and thus the corner of a letter + * rectangle.) + * Each has the following attributes: + * posObj: The (x, y) position of the vertex within the label, where the + * bottom center of the word is positioned at (0, 0); + * position: The position of the label in worldspace. + * vUv: The (u, v) coordinates that index into the glyphs sheet (range 0, 1.) 
+ * color: The color of the label (matches the corresponding point's color.)
+ * wordShown: Boolean. Whether or not the label is visible.
+ */
+
+const VERTEX_SHADER = `
+  attribute vec2 posObj;
+  attribute vec3 color;
+  varying vec2 vUv;
+  varying vec3 vColor;
+
+  void main() {
+    vUv = uv;
+    vColor = color;
+
+    // Rotate label to face camera.
+
+    vec4 vRight = vec4(
+      modelViewMatrix[0][0], modelViewMatrix[1][0], modelViewMatrix[2][0], 0);
+
+    vec4 vUp = vec4(
+      modelViewMatrix[0][1], modelViewMatrix[1][1], modelViewMatrix[2][1], 0);
+
+    vec4 vAt = -vec4(
+      modelViewMatrix[0][2], modelViewMatrix[1][2], modelViewMatrix[2][2], 0);
+
+    mat4 pointToCamera = mat4(vRight, vUp, vAt, vec4(0, 0, 0, 1));
+
+    vec2 scaledPos = posObj * ${ONE_OVER_FONT_SIZE} * ${LABEL_SCALE};
+
+    vec4 posRotated = pointToCamera * vec4(scaledPos, 0, 1);
+    vec4 mvPosition = modelViewMatrix * (vec4(position, 0) + posRotated);
+    gl_Position = projectionMatrix * mvPosition;
+  }`;
+
+const FRAGMENT_SHADER = `
+  uniform sampler2D texture;
+  uniform bool picking;
+  varying vec2 vUv;
+  varying vec3 vColor;
+
+  void main() {
+    if (picking) {
+      gl_FragColor = vec4(vColor, 1.0);
+    } else {
+      vec4 fromTexture = texture2D(texture, vUv);
+      gl_FragColor = vec4(vColor, 1.0) * fromTexture;
+    }
+  }`;
+
+type GlyphTexture = {
+  texture: THREE.Texture; lengths: Float32Array; offsets: Float32Array;
+};
+
+/**
+ * Renders the text labels as 3d geometry in the world.
+ */
+export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
+  private scene: THREE.Scene;
+  private labelStrings: string[];
+  private geometry: THREE.BufferGeometry;
+  private worldSpacePointPositions: Float32Array;
+  private pickingColors: Float32Array;
+  private renderColors: Float32Array;
+  private material: THREE.ShaderMaterial;
+  private uniforms: Object;
+  private labelsMesh: THREE.Mesh;
+  private positions: THREE.BufferAttribute;
+  private totalVertexCount: number;
+  private labelVertexMap: number[][];
+  private glyphTexture: GlyphTexture;
+
+  private createGlyphTexture(): GlyphTexture {
+    let canvas = document.createElement('canvas');
+    canvas.width = MAX_CANVAS_DIMENSION;
+    canvas.height = FONT_SIZE;
+    let ctx = canvas.getContext('2d');
+    ctx.font = 'bold ' + FONT_SIZE * 0.75 + 'px roboto';
+    ctx.textBaseline = 'top';
+    ctx.fillStyle = LABEL_BACKGROUND;
+    ctx.rect(0, 0, canvas.width, canvas.height);
+    ctx.fill();
+    ctx.fillStyle = LABEL_COLOR;
+    let spaceOffset = ctx.measureText(' ').width;
+    // For each letter, store its length and left offset at the encoded index.
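+    // The atlas is a single row of glyphs indexed by char code; e.g. 'A'
+    // (char code 65) occupies [offsets[65], offsets[65] + lengths[65]) in
+    // texels. createLabels() later divides these by MAX_CANVAS_DIMENSION
+    // to turn them into (u, v) texture coordinates.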
+ let glyphLengths = new Float32Array(NUM_GLYPHS); + let glyphOffset = new Float32Array(NUM_GLYPHS); + let leftCoord = 0; + for (let i = 0; i < NUM_GLYPHS; i++) { + let text = ' ' + String.fromCharCode(i); + let textLength = ctx.measureText(text).width; + glyphLengths[i] = textLength - spaceOffset; + glyphOffset[i] = leftCoord; + ctx.fillText(text, leftCoord - spaceOffset, 0); + leftCoord += textLength; + } + const tex = util.createTexture(canvas); + return {texture: tex, lengths: glyphLengths, offsets: glyphOffset}; + } + + private processLabelVerts(pointCount: number) { + let numTotalLetters = 0; + this.labelVertexMap = []; + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; + let vertsArray: number[] = []; + for (let j = 0; j < label.length; j++) { + for (let k = 0; k < VERTICES_PER_GLYPH; k++) { + vertsArray.push(numTotalLetters * VERTICES_PER_GLYPH + k); + } + numTotalLetters++; + } + this.labelVertexMap.push(vertsArray); + } + this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH; + } + + private createColorBuffers(pointCount: number) { + this.pickingColors = + new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); + this.renderColors = + new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); + for (let i = 0; i < pointCount; i++) { + let color = new THREE.Color(i); + this.labelVertexMap[i].forEach((j) => { + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r; + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = color.g; + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = color.b; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j] = 1.0; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = 1.0; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = 1.0; + }); + } + } + + private createLabels() { + if ((this.labelStrings == null) || + (this.worldSpacePointPositions == null)) { + return; + } + const pointCount = + this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY; + if (pointCount !== this.labelStrings.length) { + return; + } + this.glyphTexture = this.createGlyphTexture(); + + this.uniforms = { + texture: {type: 't'}, + picking: {type: 'bool'}, + }; + + this.material = new THREE.ShaderMaterial({ + uniforms: this.uniforms, + transparent: true, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER, + }); + + this.processLabelVerts(pointCount); + this.createColorBuffers(pointCount); + + let positionArray = + new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY); + this.positions = + new THREE.BufferAttribute(positionArray, XYZ_ELEMENTS_PER_ENTRY); + + let posArray = + new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY); + let uvArray = + new Float32Array(this.totalVertexCount * UV_ELEMENTS_PER_ENTRY); + let colorsArray = + new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); + let positionObject = new THREE.BufferAttribute(posArray, 2); + let uv = new THREE.BufferAttribute(uvArray, UV_ELEMENTS_PER_ENTRY); + let colors = new THREE.BufferAttribute(colorsArray, RGB_ELEMENTS_PER_ENTRY); + + this.geometry = new THREE.BufferGeometry(); + this.geometry.addAttribute('posObj', positionObject); + this.geometry.addAttribute('position', this.positions); + this.geometry.addAttribute('uv', uv); + this.geometry.addAttribute('color', colors); + + let lettersSoFar = 0; + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; + let leftOffset = 0; + // Determine length of word in pixels. 
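+      // The per-glyph widths are summed, then halved and negated below, so
+      // that the word ends up centered on the label origin; e.g. a word
+      // 40 texels wide starts at leftOffset = -20.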
+      for (let j = 0; j < label.length; j++) {
+        let letterCode = label.charCodeAt(j);
+        leftOffset += this.glyphTexture.lengths[letterCode];
+      }
+      leftOffset /= -2;  // centers text horizontally around the origin
+      for (let j = 0; j < label.length; j++) {
+        let letterCode = label.charCodeAt(j);
+        let letterWidth = this.glyphTexture.lengths[letterCode];
+        let scale = FONT_SIZE;
+        let right = (leftOffset + letterWidth) / scale;
+        let left = (leftOffset) / scale;
+        let top = FONT_SIZE / scale;
+
+        // First triangle
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, left, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, right, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, left, top);
+
+        // Second triangle
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, left, top);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, right, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, right, top);
+
+        // Set UVs based on letter.
+        let uLeft = (this.glyphTexture.offsets[letterCode]);
+        let uRight = (this.glyphTexture.offsets[letterCode] + letterWidth);
+        // Scale so that uvs lie between 0 and 1 on the texture.
+        uLeft /= MAX_CANVAS_DIMENSION;
+        uRight /= MAX_CANVAS_DIMENSION;
+        let vTop = 1;
+        let vBottom = 0;
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, uLeft, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, uRight, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, uLeft, vBottom);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, uLeft, vBottom);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, uRight, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, uRight, vBottom);
+
+        lettersSoFar++;
+        leftOffset += letterWidth;
+      }
+    }
+
+    for (let i = 0; i < pointCount; i++) {
+      const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i);
+      this.labelVertexMap[i].forEach((j) => {
+        this.positions.setXYZ(j, p.x, p.y, p.z);
+      });
+    }
+
+    this.labelsMesh = new THREE.Mesh(this.geometry, this.material);
+    this.labelsMesh.frustumCulled = false;
+    this.scene.add(this.labelsMesh);
+  }
+
+  private colorLabels(pointColors: Float32Array) {
+    if (this.labelStrings == null || this.geometry == null ||
+        pointColors == null) {
+      return;
+    }
+
+    const colors =
+        this.geometry.getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.renderColors;
+
+    const n = pointColors.length / RGB_ELEMENTS_PER_ENTRY;
+    let src = 0;
+    for (let i = 0; i < n; ++i) {
+      const c = new THREE.Color(
+          pointColors[src], pointColors[src + 1], pointColors[src + 2]);
+      const m = this.labelVertexMap[i].length;
+      for (let j = 0; j < m; ++j) {
+        colors.setXYZ(this.labelVertexMap[i][j], c.r, c.g, c.b);
+      }
+      src += RGB_ELEMENTS_PER_ENTRY;
+    }
+    colors.needsUpdate = true;
+  }
+
+  setScene(scene: THREE.Scene) {
+    this.scene = scene;
+  }
+
+  dispose() {
+    if (this.labelsMesh) {
+      if (this.scene) {
+        this.scene.remove(this.labelsMesh);
+      }
+      this.labelsMesh = null;
+    }
+    if (this.geometry) {
+      this.geometry.dispose();
+      this.geometry = null;
+    }
+    if ((this.glyphTexture != null) && (this.glyphTexture.texture != null)) {
+      this.glyphTexture.texture.dispose();
+      this.glyphTexture.texture = null;
+    }
+  }
+
+  onPickingRender(rc: RenderContext) {
+    if (this.geometry == null) {
+      this.createLabels();
+    }
+    if (this.geometry == null) {
+      return;
+    }
+    this.material.uniforms.texture.value = this.glyphTexture.texture;
+    this.material.uniforms.picking.value = true;
+    const colors = this.geometry.getAttribute('color') as
THREE.BufferAttribute; + colors.array = this.pickingColors; + colors.needsUpdate = true; + } + + onRender(rc: RenderContext) { + if (this.geometry == null) { + this.createLabels(); + } + if (this.geometry == null) { + return; + } + this.colorLabels(rc.pointColors); + this.material.uniforms.texture.value = this.glyphTexture.texture; + this.material.uniforms.picking.value = false; + const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + colors.array = this.renderColors; + colors.needsUpdate = true; + } + + onPointPositionsChanged(newPositions: Float32Array) { + this.worldSpacePointPositions = newPositions; + this.dispose(); + } + + setLabelStrings(labelStrings: string[]) { + this.labelStrings = labelStrings; + this.dispose(); + } + + onResize(newWidth: number, newHeight: number) {} +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts new file mode 100644 index 00000000000..ece4d84ef28 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts @@ -0,0 +1,187 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 +import {BoundingBox, CollisionGrid} from './label'; +import {CameraType, RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const MAX_LABELS_ON_SCREEN = 10000; +const LABEL_STROKE_WIDTH = 3; +const LABEL_FILL_WIDTH = 6; + +/** + * Creates and maintains a 2d canvas on top of the GL canvas. All labels, when + * active, are rendered to the 2d canvas as part of the visible render pass. + */ +export class ScatterPlotVisualizerCanvasLabels implements + ScatterPlotVisualizer { + private worldSpacePointPositions: Float32Array; + private gc: CanvasRenderingContext2D; + private canvas: HTMLCanvasElement; + private labelsActive: boolean = true; + + constructor(container: HTMLElement) { + this.canvas = document.createElement('canvas'); + container.appendChild(this.canvas); + + this.gc = this.canvas.getContext('2d'); + this.canvas.style.position = 'absolute'; + this.canvas.style.left = '0'; + this.canvas.style.top = '0'; + this.canvas.style.pointerEvents = 'none'; + } + + private removeAllLabels() { + const pixelWidth = this.canvas.width * window.devicePixelRatio; + const pixelHeight = this.canvas.height * window.devicePixelRatio; + this.gc.clearRect(0, 0, pixelWidth, pixelHeight); + } + + /** Render all of the non-overlapping visible labels to the canvas. 
*/ + private makeLabels(rc: RenderContext) { + if ((rc.labels == null) || (rc.labels.pointIndices.length === 0)) { + return; + } + if (this.worldSpacePointPositions == null) { + return; + } + + const lrc = rc.labels; + const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective); + const labelHeight = parseInt(this.gc.font, 10); + const dpr = window.devicePixelRatio; + + let grid: CollisionGrid; + { + const pixw = this.canvas.width * dpr; + const pixh = this.canvas.height * dpr; + const bb: BoundingBox = {loX: 0, hiX: pixw, loY: 0, hiY: pixh}; + grid = new CollisionGrid(bb, pixw / 25, pixh / 50); + } + + let opacityMap = + d3.scalePow() + .exponent(Math.E) + .domain([rc.farthestCameraSpacePointZ, rc.nearestCameraSpacePointZ]) + .range([0.1, 1]); + + const camPos = rc.camera.position; + const camToTarget = camPos.clone().sub(rc.cameraTarget); + let camToPoint = new THREE.Vector3(); + + this.gc.textBaseline = 'middle'; + this.gc.miterLimit = 2; + + // Have extra space between neighboring labels. Don't pack too tightly. + const labelMargin = 2; + // Shift the label to the right of the point circle. + const xShift = 4; + + const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length); + for (let i = 0; i < n; ++i) { + let point: THREE.Vector3; + { + const pi = lrc.pointIndices[i]; + point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi); + } + + // discard points that are behind the camera + camToPoint.copy(camPos).sub(point); + if (camToTarget.dot(camToPoint) < 0) { + continue; + } + + let [x, y] = util.vector3DToScreenCoords( + rc.camera, rc.screenWidth, rc.screenHeight, point); + x += xShift; + + // Computing the width of the font is expensive, + // so we assume width of 1 at first. Then, if the label doesn't + // conflict with other labels, we measure the actual width. + const textBoundingBox: BoundingBox = { + loX: x - labelMargin, + hiX: x + 1 + labelMargin, + loY: y - labelHeight / 2 - labelMargin, + hiY: y + labelHeight / 2 + labelMargin + }; + + if (grid.insert(textBoundingBox, true)) { + const text = lrc.labelStrings[i]; + const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr; + this.gc.font = fontSize + 'px roboto'; + + // Now, check with properly computed width. 
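+          // Second phase of the two-phase insert: the cheap 1px-wide probe
+          // above already rejected most overlaps, so the comparatively
+          // expensive measureText() call only runs for survivors.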
+ textBoundingBox.hiX += this.gc.measureText(text).width - 1; + if (grid.insert(textBoundingBox)) { + let opacity = 1; + if (sceneIs3D && (lrc.useSceneOpacityFlags[i] === 1)) { + opacity = opacityMap(camToPoint.length()); + } + this.gc.fillStyle = + this.styleStringFromPackedRgba(lrc.fillColors, i, opacity); + this.gc.strokeStyle = + this.styleStringFromPackedRgba(lrc.strokeColors, i, opacity); + this.gc.lineWidth = LABEL_STROKE_WIDTH; + this.gc.strokeText(text, x, y); + this.gc.lineWidth = LABEL_FILL_WIDTH; + this.gc.fillText(text, x, y); + } + } + } + } + + private styleStringFromPackedRgba( + packedRgbaArray: Uint8Array, colorIndex: number, + opacity: number): string { + const offset = colorIndex * 3; + const r = packedRgbaArray[offset]; + const g = packedRgbaArray[offset + 1]; + const b = packedRgbaArray[offset + 2]; + return 'rgba(' + r + ',' + g + ',' + b + ',' + opacity + ')'; + } + + onResize(newWidth: number, newHeight: number) { + let dpr = window.devicePixelRatio; + this.canvas.width = newWidth * dpr; + this.canvas.height = newHeight * dpr; + this.canvas.style.width = newWidth + 'px'; + this.canvas.style.height = newHeight + 'px'; + } + + dispose() { + this.removeAllLabels(); + this.canvas = null; + this.gc = null; + } + + onPointPositionsChanged(newPositions: Float32Array) { + this.worldSpacePointPositions = newPositions; + this.removeAllLabels(); + } + + onRender(rc: RenderContext) { + if (!this.labelsActive) { + return; + } + + this.removeAllLabels(); + this.makeLabels(rc); + } + + setScene(scene: THREE.Scene) {} + onPickingRender(renderContext: RenderContext) {} +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts new file mode 100644 index 00000000000..e6d4aeda28b --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts @@ -0,0 +1,149 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataSet} from './data'; +import {RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const RGB_NUM_ELEMENTS = 3; +const XYZ_NUM_ELEMENTS = 3; + +/** + * Renders polylines that connect multiple points in the dataset. 
+ */ +export class ScatterPlotVisualizerPolylines implements ScatterPlotVisualizer { + private dataSet: DataSet; + private scene: THREE.Scene; + private polylines: THREE.Line[]; + private polylinePositionBuffer: + {[polylineIndex: number]: THREE.BufferAttribute} = {}; + private polylineColorBuffer: + {[polylineIndex: number]: THREE.BufferAttribute} = {}; + + private updateSequenceIndicesInDataSet(ds: DataSet) { + for (let i = 0; i < ds.sequences.length; i++) { + const sequence = ds.sequences[i]; + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + ds.points[sequence.pointIndices[j]].sequenceIndex = i; + ds.points[sequence.pointIndices[j + 1]].sequenceIndex = i; + } + } + } + + private createPolylines(scene: THREE.Scene) { + if (!this.dataSet || !this.dataSet.sequences) { + return; + } + + this.updateSequenceIndicesInDataSet(this.dataSet); + this.polylines = []; + + for (let i = 0; i < this.dataSet.sequences.length; i++) { + const geometry = new THREE.BufferGeometry(); + geometry.addAttribute('position', this.polylinePositionBuffer[i]); + geometry.addAttribute('color', this.polylineColorBuffer[i]); + + const material = new THREE.LineBasicMaterial({ + linewidth: 1, // unused default, overwritten by width array. + opacity: 1.0, // unused default, overwritten by opacity array. + transparent: true, + vertexColors: THREE.VertexColors + }); + + const polyline = new THREE.LineSegments(geometry, material); + polyline.frustumCulled = false; + this.polylines.push(polyline); + scene.add(polyline); + } + } + + dispose() { + if (this.polylines == null) { + return; + } + for (let i = 0; i < this.polylines.length; i++) { + this.scene.remove(this.polylines[i]); + this.polylines[i].geometry.dispose(); + } + this.polylines = null; + this.polylinePositionBuffer = {}; + this.polylineColorBuffer = {}; + } + + setScene(scene: THREE.Scene) { + this.scene = scene; + } + + setDataSet(dataSet: DataSet) { + this.dataSet = dataSet; + } + + onPointPositionsChanged(newPositions: Float32Array) { + if ((newPositions == null) || (this.polylines != null)) { + this.dispose(); + } + if ((newPositions == null) || (this.dataSet == null)) { + return; + } + // Set up the position buffer arrays for each polyline. 
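+    // THREE.LineSegments draws disjoint vertex pairs, so a sequence of k
+    // points becomes k - 1 segments and therefore 2 * (k - 1) buffer
+    // entries; e.g. a 5-point sequence needs 8 position and color entries.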
+ for (let i = 0; i < this.dataSet.sequences.length; i++) { + let sequence = this.dataSet.sequences[i]; + const vertexCount = 2 * (sequence.pointIndices.length - 1); + + let polylines = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); + this.polylinePositionBuffer[i] = + new THREE.BufferAttribute(polylines, XYZ_NUM_ELEMENTS); + + let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); + this.polylineColorBuffer[i] = + new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS); + } + for (let i = 0; i < this.dataSet.sequences.length; i++) { + const sequence = this.dataSet.sequences[i]; + let src = 0; + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const p1Index = sequence.pointIndices[j]; + const p2Index = sequence.pointIndices[j + 1]; + const p1 = util.vector3FromPackedArray(newPositions, p1Index); + const p2 = util.vector3FromPackedArray(newPositions, p2Index); + this.polylinePositionBuffer[i].setXYZ(src, p1.x, p1.y, p1.z); + this.polylinePositionBuffer[i].setXYZ(src + 1, p2.x, p2.y, p2.z); + src += 2; + } + this.polylinePositionBuffer[i].needsUpdate = true; + } + + if (this.polylines == null) { + this.createPolylines(this.scene); + } + } + + onRender(renderContext: RenderContext) { + if (this.polylines == null) { + return; + } + for (let i = 0; i < this.polylines.length; i++) { + this.polylines[i].material.opacity = renderContext.polylineOpacities[i]; + (this.polylines[i].material as THREE.LineBasicMaterial).linewidth = + renderContext.polylineWidths[i]; + this.polylineColorBuffer[i].array = renderContext.polylineColors[i]; + this.polylineColorBuffer[i].needsUpdate = true; + } + } + + onPickingRender(renderContext: RenderContext) {} + onResize(newWidth: number, newHeight: number) {} +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts new file mode 100644 index 00000000000..8adc9a9bd23 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts @@ -0,0 +1,435 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {CameraType, RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const NUM_POINTS_FOG_THRESHOLD = 5000; +const MIN_POINT_SIZE = 5.0; +const IMAGE_SIZE = 30; + +// Constants relating to the indices of buffer arrays. +const RGB_NUM_ELEMENTS = 3; +const INDEX_NUM_ELEMENTS = 1; +const XYZ_NUM_ELEMENTS = 3; + +const VERTEX_SHADER = ` + // Index of the specific vertex (passed in as bufferAttribute), and the + // variable that will be used to pass it to the fragment shader. 
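+  // For a row-major sprite atlas, index i maps to column mod(i, spritesPerRow)
+  // and row floor(i / spritesPerRow); e.g. with 4 sprites per row, index 6
+  // falls at column 2, row 1 (see xyIndex in main() below).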
+  attribute float spriteIndex;
+  attribute vec3 color;
+  attribute float scaleFactor;
+
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+
+  uniform bool sizeAttenuation;
+  uniform float pointSize;
+  uniform float spritesPerRow;
+  uniform float spritesPerColumn;
+
+  void main() {
+    // Pass index and color values to fragment shader.
+    vColor = color;
+    xyIndex = vec2(mod(spriteIndex, spritesPerRow),
+              floor(spriteIndex / spritesPerRow));
+
+    // Transform current vertex by modelViewMatrix (model world position and
+    // camera world position matrix).
+    vec4 cameraSpacePos = modelViewMatrix * vec4(position, 1.0);
+
+    // Project vertex in camera-space to screen coordinates using the camera's
+    // projection matrix.
+    gl_Position = projectionMatrix * cameraSpacePos;
+
+    // Create size attenuation (if we're in 3D mode) by making the size of
+    // each point inversely proportional to its distance to the camera.
+    float outputPointSize = pointSize;
+    if (sizeAttenuation) {
+      outputPointSize = -pointSize / cameraSpacePos.z;
+    }
+
+    gl_PointSize =
+      max(outputPointSize * scaleFactor, ${MIN_POINT_SIZE.toFixed(1)});
+  }`;
+
+const FRAGMENT_SHADER_POINT_TEST_CHUNK = `
+  bool point_in_unit_circle(vec2 spriteCoord) {
+    vec2 centerToP = spriteCoord - vec2(0.5, 0.5);
+    return dot(centerToP, centerToP) < (0.5 * 0.5);
+  }
+
+  bool point_in_unit_equilateral_triangle(vec2 spriteCoord) {
+    vec3 v0 = vec3(0, 1, 0);
+    vec3 v1 = vec3(0.5, 0, 0);
+    vec3 v2 = vec3(1, 1, 0);
+    vec3 p = vec3(spriteCoord, 0);
+    float p_in_v0_v1 = cross(v1 - v0, p - v0).z;
+    float p_in_v1_v2 = cross(v2 - v1, p - v1).z;
+    return (p_in_v0_v1 > 0.0) && (p_in_v1_v2 > 0.0);
+  }
+
+  bool point_in_unit_square(vec2 spriteCoord) {
+    return true;
+  }
+`;
+
+const FRAGMENT_SHADER = `
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+
+  uniform sampler2D texture;
+  uniform float spritesPerRow;
+  uniform float spritesPerColumn;
+  uniform bool isImage;
+
+  ${THREE.ShaderChunk['common']}
+  ${THREE.ShaderChunk['fog_pars_fragment']}
+  ${FRAGMENT_SHADER_POINT_TEST_CHUNK}
+
+  void main() {
+    if (isImage) {
+      // Coordinates of the vertex within the entire sprite image.
+      vec2 coords =
+        (gl_PointCoord + xyIndex) / vec2(spritesPerRow, spritesPerColumn);
+      gl_FragColor = vec4(vColor, 1.0) * texture2D(texture, coords);
+    } else {
+      bool inside = point_in_unit_circle(gl_PointCoord);
+      if (!inside) {
+        discard;
+      }
+      gl_FragColor = vec4(vColor, 1);
+    }
+    ${THREE.ShaderChunk['fog_fragment']}
+  }`;
+
+const FRAGMENT_SHADER_PICKING = `
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+  uniform bool isImage;
+
+  ${FRAGMENT_SHADER_POINT_TEST_CHUNK}
+
+  void main() {
+    xyIndex; // Silence 'unused variable' warning.
+    if (isImage) {
+      gl_FragColor = vec4(vColor, 1);
+    } else {
+      bool inside = point_in_unit_circle(gl_PointCoord);
+      if (!inside) {
+        discard;
+      }
+      gl_FragColor = vec4(vColor, 1);
+    }
+  }`;
+
+/**
+ * Uses GL point sprites to render the dataset.
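+ * The same point geometry is drawn with two interchangeable shader
+ * materials: a render material for the visible pass, and a picking material
+ * that writes each point's id as its color (see onPickingRender below).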
+ */ +export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { + private scene: THREE.Scene; + private fog: THREE.Fog; + private texture: THREE.Texture = null; + private standinTextureForPoints: THREE.Texture; + private spritesPerRow: number; + private spritesPerColumn: number; + private spriteDimensions: [number, number]; + private spriteIndexBufferAttribute: THREE.BufferAttribute; + private renderMaterial: THREE.ShaderMaterial; + private pickingMaterial: THREE.ShaderMaterial; + + private points: THREE.Points; + private worldSpacePointPositions: Float32Array; + private pickingColors: Float32Array; + private renderColors: Float32Array; + + constructor() { + this.standinTextureForPoints = + util.createTexture(document.createElement('canvas')); + this.renderMaterial = this.createRenderMaterial(false); + this.pickingMaterial = this.createPickingMaterial(false); + } + + private createTextureFromSpriteAtlas( + spriteAtlas: HTMLImageElement, spriteDimensions: [number, number], + spriteIndices: Float32Array) { + this.texture = util.createTexture(spriteAtlas); + this.spritesPerRow = spriteAtlas.width / spriteDimensions[0]; + this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1]; + this.spriteDimensions = spriteDimensions; + this.spriteIndexBufferAttribute = + new THREE.BufferAttribute(spriteIndices, INDEX_NUM_ELEMENTS); + + if (this.points != null) { + (this.points.geometry as THREE.BufferGeometry) + .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); + } + } + + private createUniforms(): any { + return { + texture: {type: 't'}, + spritesPerRow: {type: 'f'}, + spritesPerColumn: {type: 'f'}, + fogColor: {type: 'c'}, + fogNear: {type: 'f'}, + fogFar: {type: 'f'}, + isImage: {type: 'bool'}, + sizeAttenuation: {type: 'bool'}, + pointSize: {type: 'f'} + }; + } + + private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER, + transparent: !haveImage, + depthTest: haveImage, + depthWrite: haveImage, + fog: true, + blending: THREE.MultiplyBlending, + }); + } + + private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER_PICKING, + transparent: true, + depthTest: true, + depthWrite: true, + fog: false, + blending: THREE.NormalBlending, + }); + } + + /** + * Create points, set their locations and actually instantiate the + * geometry. + */ + private createPointSprites(scene: THREE.Scene, positions: Float32Array) { + const pointCount = + (positions != null) ? (positions.length / XYZ_NUM_ELEMENTS) : 0; + const geometry = this.createGeometry(pointCount); + + this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten. + + this.points = new THREE.Points(geometry, this.renderMaterial); + this.points.frustumCulled = false; + if (this.spriteIndexBufferAttribute != null) { + (this.points.geometry as THREE.BufferGeometry) + .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); + } + scene.add(this.points); + } + + private calculatePointSize(sceneIs3D: boolean): number { + if (this.texture != null) { + return sceneIs3D ? IMAGE_SIZE : this.spriteDimensions[0]; + } + const n = (this.worldSpacePointPositions != null) ? 
+        (this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS) :
+        1;
+    const SCALE = 200;
+    const LOG_BASE = 8;
+    const DIVISOR = 1.5;
+    // Scale point size inverse-logarithmically to the number of points.
+    const pointSize = SCALE / Math.log(n) / Math.log(LOG_BASE);
+    return sceneIs3D ? pointSize : (pointSize / DIVISOR);
+  }
+
+  /**
+   * Set up buffer attributes to be used for the points/images.
+   */
+  private createGeometry(pointCount: number): THREE.BufferGeometry {
+    const n = pointCount;
+
+    // Fill pickingColors with each point's unique id as its color.
+    this.pickingColors = new Float32Array(n * RGB_NUM_ELEMENTS);
+    {
+      let dst = 0;
+      for (let i = 0; i < n; i++) {
+        const c = new THREE.Color(i);
+        this.pickingColors[dst++] = c.r;
+        this.pickingColors[dst++] = c.g;
+        this.pickingColors[dst++] = c.b;
+      }
+    }
+
+    const geometry = new THREE.BufferGeometry();
+    geometry.addAttribute(
+        'position', new THREE.BufferAttribute(null, XYZ_NUM_ELEMENTS));
+    geometry.addAttribute(
+        'color', new THREE.BufferAttribute(null, RGB_NUM_ELEMENTS));
+    geometry.addAttribute(
+        'scaleFactor', new THREE.BufferAttribute(null, INDEX_NUM_ELEMENTS));
+    return geometry;
+  }
+
+  private setFogDistances(
+      sceneIs3D: boolean, nearestPointZ: number, farthestPointZ: number) {
+    if (sceneIs3D) {
+      const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS;
+      this.fog.near = nearestPointZ;
+      // If there are fewer points we want less fog. We do this
+      // by making the "far" value (that is, the distance from the camera to
+      // the far edge of the fog) proportional to the number of points.
+      let multiplier =
+          2 - Math.min(n, NUM_POINTS_FOG_THRESHOLD) / NUM_POINTS_FOG_THRESHOLD;
+      this.fog.far = farthestPointZ * multiplier;
+    } else {
+      this.fog.near = Infinity;
+      this.fog.far = Infinity;
+    }
+  }
+
+  dispose() {
+    this.disposeGeometry();
+    this.disposeTextureAtlas();
+  }
+
+  private disposeGeometry() {
+    if (this.points != null) {
+      this.scene.remove(this.points);
+      this.points.geometry.dispose();
+      this.points = null;
+      this.worldSpacePointPositions = null;
+    }
+  }
+
+  private disposeTextureAtlas() {
+    if (this.texture != null) {
+      this.texture.dispose();
+    }
+    this.texture = null;
+    this.renderMaterial = null;
+    this.pickingMaterial = null;
+  }
+
+  setScene(scene: THREE.Scene) {
+    this.scene = scene;
+  }
+
+  setSpriteAtlas(
+      spriteImage: HTMLImageElement, spriteDimensions: [number, number],
+      spriteIndices: Float32Array) {
+    this.disposeTextureAtlas();
+    this.createTextureFromSpriteAtlas(
+        spriteImage, spriteDimensions, spriteIndices);
+    this.renderMaterial = this.createRenderMaterial(true);
+    this.pickingMaterial = this.createPickingMaterial(true);
+  }
+
+  clearSpriteAtlas() {
+    this.disposeTextureAtlas();
+    this.renderMaterial = this.createRenderMaterial(false);
+    this.pickingMaterial = this.createPickingMaterial(false);
+  }
+
+  onPointPositionsChanged(newPositions: Float32Array) {
+    if ((newPositions == null) || (newPositions.length === 0)) {
+      this.dispose();
+      return;
+    }
+    if (this.points != null) {
+      if (this.worldSpacePointPositions.length !== newPositions.length) {
+        this.disposeGeometry();
+      }
+    }
+
+    this.worldSpacePointPositions = newPositions;
+
+    if (this.points == null) {
+      this.createPointSprites(this.scene, newPositions);
+    }
+
+    const positions = (this.points.geometry as THREE.BufferGeometry)
+                          .getAttribute('position') as THREE.BufferAttribute;
+    positions.array = newPositions;
+    positions.needsUpdate = true;
+  }
+
+  onPickingRender(rc: RenderContext) {
+    if (this.points == null) {
+      return;
}
+
+    const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective);
+
+    this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+    this.pickingMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn;
+    this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
+    this.pickingMaterial.uniforms.pointSize.value =
+        this.calculatePointSize(sceneIs3D);
+    this.points.material = this.pickingMaterial;
+
+    let colors = (this.points.geometry as THREE.BufferGeometry)
+                     .getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.pickingColors;
+    colors.needsUpdate = true;
+
+    let scaleFactors =
+        (this.points.geometry as THREE.BufferGeometry)
+            .getAttribute('scaleFactor') as THREE.BufferAttribute;
+    scaleFactors.array = rc.pointScaleFactors;
+    scaleFactors.needsUpdate = true;
+  }
+
+  onRender(rc: RenderContext) {
+    if (!this.points) {
+      return;
+    }
+    const sceneIs3D: boolean = (rc.camera instanceof THREE.PerspectiveCamera);
+
+    this.setFogDistances(
+        sceneIs3D, rc.nearestCameraSpacePointZ, rc.farthestCameraSpacePointZ);
+
+    this.scene.fog = this.fog;
+    this.scene.fog.color = new THREE.Color(rc.backgroundColor);
+
+    this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color;
+    this.renderMaterial.uniforms.fogNear.value = this.fog.near;
+    this.renderMaterial.uniforms.fogFar.value = this.fog.far;
+    this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+    this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn;
+    this.renderMaterial.uniforms.isImage.value = (this.texture != null);
+    this.renderMaterial.uniforms.texture.value =
+        (this.texture != null) ? this.texture : this.standinTextureForPoints;
+    this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
+    this.renderMaterial.uniforms.pointSize.value =
+        this.calculatePointSize(sceneIs3D);
+    this.points.material = this.renderMaterial;
+
+    let colors = (this.points.geometry as THREE.BufferGeometry)
+                     .getAttribute('color') as THREE.BufferAttribute;
+    this.renderColors = rc.pointColors;
+    colors.array = this.renderColors;
+    colors.needsUpdate = true;
+
+    let scaleFactors =
+        (this.points.geometry as THREE.BufferGeometry)
+            .getAttribute('scaleFactor') as THREE.BufferAttribute;
+    scaleFactors.array = rc.pointScaleFactors;
+    scaleFactors.needsUpdate = true;
+  }
+
+  onResize(newWidth: number, newHeight: number) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts
new file mode 100644
index 00000000000..991369a3352
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts
@@ -0,0 +1,175 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** N-dimensional point. Usually 2D or 3D. */
+export type Point = number[];
+
+export interface BBox {
+  center: Point;
+  halfDim: number;
+}
+
+/** A node in a space-partitioning tree.
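+ * A node always stores exactly one point; children are allocated lazily as
+ * points are inserted (see SPTree.insert below).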
+ */
+export interface SPNode {
+  /** The children of this node. */
+  children?: SPNode[];
+  /** The bounding box of the region this node occupies. */
+  box: BBox;
+  /** The single point stored in this node. */
+  point: Point;
+}
+
+/**
+ * A space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning)
+ * that recursively divides the space into regions of equal size. This data
+ * structure acts as a quadtree when the data is 2-dimensional and as an
+ * octree when it is 3-dimensional. One usage is in t-SNE, to do the
+ * Barnes-Hut approximation.
+ */
+export class SPTree {
+  root: SPNode;
+
+  private masks: number[];
+  private dim: number;
+
+  /**
+   * Constructs a new tree with the provided data.
+   *
+   * @param data List of n-dimensional data points.
+   */
+  constructor(data: Point[]) {
+    if (data.length < 1) {
+      throw new Error('There should be at least 1 data point');
+    }
+    // Make a bounding box based on the extent of the data.
+    this.dim = data[0].length;
+    // Each node has 2^d children, where d is the dimension of the space.
+    // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in
+    // which child (e.g. which quadrant in 2D) a new point is placed.
+    // For more details, see the insert() method and its comments.
+    this.masks = new Array(Math.pow(2, this.dim));
+    for (let d = 0; d < this.masks.length; ++d) {
+      this.masks[d] = (1 << d);
+    }
+    let min: Point = new Array(this.dim);
+    fillArray(min, Number.POSITIVE_INFINITY);
+    let max: Point = new Array(this.dim);
+    fillArray(max, Number.NEGATIVE_INFINITY);
+
+    for (let i = 0; i < data.length; ++i) {
+      // For each dimension get the min and max.
+      // E.g. for 2-D, get x_min, x_max, y_min, y_max.
+      for (let d = 0; d < this.dim; ++d) {
+        min[d] = Math.min(min[d], data[i][d]);
+        max[d] = Math.max(max[d], data[i][d]);
+      }
+    }
+    // Create a bounding box centered on the data whose half-width is half of
+    // the largest span across all dimensions.
+    let center: Point = new Array(this.dim);
+    let halfDim = 0;
+    for (let d = 0; d < this.dim; ++d) {
+      let span = max[d] - min[d];
+      center[d] = min[d] + span / 2;
+      halfDim = Math.max(halfDim, span / 2);
+    }
+    this.root = {box: {center: center, halfDim: halfDim}, point: data[0]};
+    for (let i = 1; i < data.length; ++i) {
+      this.insert(this.root, data[i]);
+    }
+  }
+
+  /**
+   * Visits every node in the tree. Each node stores exactly one point.
+   *
+   * @param accessor Method that takes the currently visited node, and the
+   * low and high point of the region that this node occupies. E.g. in 2D,
+   * the low and high points will be the lower-left corner and the upper-right
+   * corner.
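+   *
+   * Illustrative usage sketch (editorial example; `tree` is an SPTree):
+   *   let count = 0;
+   *   tree.visit((node, low, high) => {
+   *     if (high[0] < 0.5) count++;  // region lies entirely left of x = 0.5
+   *     return false;                // false means: keep descending.
+   *   });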
+ */ + visit( + accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean, + noBox = false) { + this.visitNode(this.root, accessor, noBox); + } + + private visitNode( + node: SPNode, + accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean, + noBox: boolean) { + let skipChildren: boolean; + if (noBox) { + skipChildren = accessor(node); + } else { + let lowPoint = new Array(this.dim); + let highPoint = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + lowPoint[d] = node.box.center[d] - node.box.halfDim; + highPoint[d] = node.box.center[d] + node.box.halfDim; + } + skipChildren = accessor(node, lowPoint, highPoint); + } + if (!node.children || skipChildren) { + return; + } + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child) { + this.visitNode(child, accessor, noBox); + } + } + } + + private insert(node: SPNode, p: Point) { + // Subdivide and then add the point to whichever node will accept it. + if (node.children == null) { + node.children = new Array(this.masks.length); + } + + // Decide which child will get the new point by constructing a D-bits binary + // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th + // coordinate is greater than the node's k-th coordinate, 0 otherwise. + // Then the binary signature in decimal system gives us the index of the + // child where the new point should be. + let index = 0; + for (let d = 0; d < this.dim; ++d) { + if (p[d] > node.box.center[d]) { + index |= this.masks[d]; + } + } + if (node.children[index] == null) { + this.makeChild(node, index, p); + } else { + this.insert(node.children[index], p); + } + } + + private makeChild(node: SPNode, index: number, p: Point): void { + let oldC = node.box.center; + let h = node.box.halfDim / 2; + let newC: Point = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + newC[d] = (index & (1 << d)) ? oldC[d] + h : oldC[d] - h; + } + node.children[index] = {box: {center: newC, halfDim: h}, point: p}; + } +} + +function fillArray(arr: T[], value: T): void { + for (let i = 0; i < arr.length; ++i) { + arr[i] = value; + } +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts new file mode 100644 index 00000000000..440680bdf1e --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts @@ -0,0 +1,104 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {SPTree} from './sptree'; + +const assert = chai.assert; + +it('simple 2D data', () => { + let data = [ + [0, 1], + [1, 0], + [1, 1], + [0, 0], + ]; + let tree = new SPTree(data); + // Check that each point is within the bound. 
+ tree.visit((node, low, high) => { + assert.equal(low.length, 2); + assert.equal(high.length, 2); + let point = node.point; + assert.equal(point.length, 2); + // Each point should be in the node's bounding box. + assert.equal( + point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && + point[1] <= high[1], + true); + return false; + }); +}); + +it('simple 3D data', () => { + let data = [ + [0, 1, 0], + [1, 0.4, 2], + [1, 1, 3], + [0, 0, 5], + ]; + let tree = new SPTree(data); + // Check that each point is within the bound. + tree.visit((node, low, high) => { + assert.equal(low.length, 3); + assert.equal(high.length, 3); + let point = node.point; + assert.equal(point.length, 3); + // Each point should be in the node's bounding box. + assert.equal( + point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && + point[1] <= high[1] && point[2] >= low[2] && point[2] <= high[2], + true); + return false; + }); +}); + +it('Only visit root', () => { + let data = [ + [0, 1, 0], + [1, 0.4, 2], + [1, 1, 3], + [0, 0, 5], + ]; + let tree = new SPTree(data); + let numVisits = 0; + tree.visit((node, low, high) => { + numVisits++; + return true; + }); + assert.equal(numVisits, 1); +}); + +it('Search in random data', () => { + let N = 10000; + let data = new Array(N); + for (let i = 0; i < N; i++) { + data[i] = [Math.random(), Math.random()]; + } + let tree = new SPTree(data); + let numVisits = 0; + let query = data[Math.floor(Math.random() * N)]; + let found = false; + tree.visit((node, low, high) => { + numVisits++; + if (node.point === query) { + found = true; + return true; + } + let outOfBounds = query[0] < low[0] || query[0] > high[0] || + query[1] < low[1] || query[1] > high[1]; + return outOfBounds; + }); + assert.equal(found, true); + assert.isBelow(numVisits, N / 4); +}); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html b/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html new file mode 100644 index 00000000000..32dc984b5d6 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html @@ -0,0 +1,185 @@ + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts new file mode 100644 index 00000000000..bd6df68b1a5 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts @@ -0,0 +1,252 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataPoint} from './data'; +import * as logging from './logging'; +import {Point2D} from './vector'; + +/** + * Delay for running expensive tasks, in milliseconds. + * The duration was empirically found so that it leaves enough time for the + * browser to update its UI state before starting an expensive UI-blocking task. + */ +const TASK_DELAY_MS = 200; + +/** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. 
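+ * Illustrative example: shuffle([1, 2, 3]) permutes the same array object in
+ * place (e.g. to [3, 1, 2]) and returns it for convenience.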
+ */
+export function shuffle<T>(array: T[]): T[] {
+  let m = array.length;
+  let t: T;
+  let i: number;
+
+  // While there remain elements to shuffle.
+  while (m) {
+    // Pick a remaining element.
+    i = Math.floor(Math.random() * m--);
+    // And swap it with the current element.
+    t = array[m];
+    array[m] = array[i];
+    array[i] = t;
+  }
+  return array;
+}
+
+export function range(count: number): number[] {
+  const rangeOutput: number[] = [];
+  for (let i = 0; i < count; i++) {
+    rangeOutput.push(i);
+  }
+  return rangeOutput;
+}
+
+export function classed(
+    element: HTMLElement, className: string, enabled: boolean) {
+  const classNames = element.className.split(' ');
+  if (enabled) {
+    // Note: use indexOf for the membership test; the `in` operator would
+    // check array indices, not values.
+    if (classNames.indexOf(className) !== -1) {
+      return;
+    } else {
+      classNames.push(className);
+    }
+  } else {
+    const index = classNames.indexOf(className);
+    if (index === -1) {
+      return;
+    }
+    classNames.splice(index, 1);
+  }
+  element.className = classNames.join(' ');
+}
+
+/** Projects a 3d point into screen space. */
+export function vector3DToScreenCoords(
+    cam: THREE.Camera, w: number, h: number, v: THREE.Vector3): Point2D {
+  let dpr = window.devicePixelRatio;
+  let pv = new THREE.Vector3().copy(v).project(cam);
+
+  // The screen-space origin is at the middle of the screen, with +y up.
+  let coords: Point2D =
+      [((pv.x + 1) / 2 * w) * dpr, -((pv.y - 1) / 2 * h) * dpr];
+  return coords;
+}
+
+/** Loads 3 contiguous elements from a packed xyz array into a Vector3. */
+export function vector3FromPackedArray(
+    a: Float32Array, pointIndex: number): THREE.Vector3 {
+  const offset = pointIndex * 3;
+  return new THREE.Vector3(a[offset], a[offset + 1], a[offset + 2]);
+}
+
+/**
+ * Gets the camera-space z coordinates of the nearest and farthest points.
+ * Ignores points that are behind the camera.
+ */
+export function getNearFarPoints(
+    worldSpacePoints: Float32Array, cameraPos: THREE.Vector3,
+    cameraTarget: THREE.Vector3): [number, number] {
+  let shortestDist: number = Infinity;
+  let furthestDist: number = 0;
+  const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos);
+  const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize();
+  const n = worldSpacePoints.length / 3;
+  let src = 0;
+  let p = new THREE.Vector3();
+  let camToPoint = new THREE.Vector3();
+  for (let i = 0; i < n; i++) {
+    p.x = worldSpacePoints[src];
+    p.y = worldSpacePoints[src + 1];
+    p.z = worldSpacePoints[src + 2];
+    src += 3;
+
+    camToPoint.copy(p).sub(cameraPos);
+    const dist = camPlaneNormal.dot(camToPoint);
+    if (dist < 0) {
+      continue;
+    }
+    furthestDist = (dist > furthestDist) ? dist : furthestDist;
+    shortestDist = (dist < shortestDist) ? dist : shortestDist;
+  }
+  return [shortestDist, furthestDist];
+}
+
+/**
+ * Generates a texture for the points/images and sets some initial params.
+ */
+export function createTexture(image: HTMLImageElement|
+                              HTMLCanvasElement): THREE.Texture {
+  let tex = new THREE.Texture(image);
+  tex.needsUpdate = true;
+  // Used if the texture isn't a power of 2.
+  tex.minFilter = THREE.LinearFilter;
+  tex.generateMipmaps = false;
+  tex.flipY = false;
+  return tex;
+}
+
+/**
+ * Asserts that the condition is satisfied; if not, throws an Error with the
+ * user-specified message.
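+ * Illustrative example: assert(a.length === b.length, 'length mismatch');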
+ */ +export function assert(condition: boolean, message?: string) { + if (!condition) { + message = message || 'Assertion failed'; + throw new Error(message); + } +} + +export type SearchPredicate = (p: DataPoint) => boolean; + +export function getSearchPredicate( + query: string, inRegexMode: boolean, fieldName: string): SearchPredicate { + let predicate: SearchPredicate; + if (inRegexMode) { + let regExp = new RegExp(query, 'i'); + predicate = p => regExp.test(p.metadata[fieldName].toString()); + } else { + // Doing a case insensitive substring match. + query = query.toLowerCase(); + predicate = p => { + let label = p.metadata[fieldName].toString().toLowerCase(); + return label.indexOf(query) >= 0; + }; + } + return predicate; +} + +/** + * Runs an expensive task asynchronously with some delay + * so that it doesn't block the UI thread immediately. + * + * @param message The message to display to the user. + * @param task The expensive task to run. + * @param msgId Optional. ID of an existing message. If provided, will overwrite + * an existing message and won't automatically clear the message when the + * task is done. + * @return The value returned by the task. + */ +export function runAsyncTask( + message: string, task: () => T, msgId: string = null): Promise { + let autoClear = (msgId == null); + msgId = logging.setModalMessage(message, msgId); + return new Promise((resolve, reject) => { + setTimeout(() => { + try { + let result = task(); + // Clearing the old message. + if (autoClear) { + logging.setModalMessage(null, msgId); + } + resolve(result); + } catch (ex) { + reject(ex); + } + return true; + }, TASK_DELAY_MS); + }); +} + + +/** + * Parses the URL for query parameters, e.g. ?foo=1&bar=2 will return + * {'foo': '1', 'bar': '2'}. + * @param url The URL to parse. + * @return A map of queryParam key to its value. + */ +export function getURLParams(url: string): {[key: string]: string} { + if (!url) { + return {}; + } + + let queryString = url.indexOf('?') !== -1 ? url.split('?')[1] : url; + if (queryString.indexOf('#')) { + queryString = queryString.split('#')[0]; + } + + const queryEntries = queryString.split('&'); + let queryParams: {[key: string]: string} = {}; + for (let i = 0; i < queryEntries.length; i++) { + let queryEntryComponents = queryEntries[i].split('='); + queryParams[queryEntryComponents[0].toLowerCase()] = + decodeURIComponent(queryEntryComponents[1]); + } + return queryParams; +} + +/** List of substrings that auto generated tensors have in their name. */ +const SUBSTR_GEN_TENSORS = ['/Adagrad']; + +/** Returns true if the tensor was automatically generated by TF API calls. */ +export function tensorIsGenerated(tensorName: string): boolean { + for (let i = 0; i < SUBSTR_GEN_TENSORS.length; i++) { + if (tensorName.indexOf(SUBSTR_GEN_TENSORS[i]) >= 0) { + return true; + } + } + return false; +} + +export function xor(cond1: boolean, cond2: boolean): boolean { + return (cond1 || cond2) && !(cond1 && cond2); +} + +/** Checks to see if the browser supports webgl. 
*/ +export function hasWebGLSupport(): boolean { + try { + let c = document.createElement('canvas'); + let gl = c.getContext('webgl') || c.getContext('experimental-webgl'); + return gl != null && typeof weblas !== 'undefined'; + } catch (e) { + return false; + } +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts new file mode 100644 index 00000000000..f7c0027c81b --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts @@ -0,0 +1,42 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as util from './util'; + +describe('getURLParams', () => { + it('search query with valid param returns correct object', () => { + let urlParams = util.getURLParams('?config=http://google.com/'); + expect(urlParams).toEqual({'config': 'http://google.com/'}); + }); + + it('search query with multiple valid params returns correct object', () => { + let urlParams = util.getURLParams('?config=http://google.com/&foo=bar'); + expect(urlParams).toEqual({'config': 'http://google.com/', 'foo': 'bar'}); + }); + + it('search query with valid param with URL encoded characters', () => { + let urlParams = util.getURLParams('?config=http://google.com/%20search'); + expect(urlParams).toEqual({'config': 'http://google.com/ search'}); + }); + + it('search query with pound sign', () => { + let urlParams = util.getURLParams('?config=http://google.com/#foo'); + expect(urlParams).toEqual({'config': 'http://google.com/'}); + }); + + it('no search query returns empty object', () => { + let urlParams = util.getURLParams(''); + expect(urlParams).toEqual({}); + }); +}); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts new file mode 100644 index 00000000000..0de78ad85df --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts @@ -0,0 +1,266 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 +import {assert} from './util'; + +/** + * @fileoverview Useful vector utilities. 
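+ * Unless noted otherwise, functions below that take two Vectors assert that
+ * the inputs have equal length.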
+ */ + +export type Vector = Float32Array | number[]; +export type Point2D = [number, number]; +export type Point3D = [number, number, number]; + +/** Returns the dot product of two vectors. */ +export function dot(a: Vector, b: Vector): number { + assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * b[i]; + } + return result; +} + +/** Sums all the elements in the vector */ +export function sum(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i]; + } + return result; +} + +/** Returns the sum of two vectors, i.e. a + b */ +export function add(a: Vector, b: Vector): Float32Array { + assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = new Float32Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] + b[i]; + } + return result; +} + +/** Subtracts vector b from vector a, i.e. returns a - b */ +export function sub(a: Vector, b: Vector): Float32Array { + assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = new Float32Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] - b[i]; + } + return result; +} + +/** Returns the square norm of the vector */ +export function norm2(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * a[i]; + } + return result; +} + +/** Returns the euclidean distance between two vectors. */ +export function dist(a: Vector, b: Vector): number { + return Math.sqrt(dist2(a, b)); +} + +/** Returns the square euclidean distance between two vectors. */ +export function dist2(a: Vector, b: Vector): number { + assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + } + return result; +} + +/** Returns the square euclidean distance between two 2D points. */ +export function dist2_2D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} + +/** Returns the square euclidean distance between two 3D points. */ +export function dist2_3D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + let dZ = a[2] - b[2]; + return dX * dX + dY * dY + dZ * dZ; +} + +/** Returns the euclidean distance between 2 3D points. */ +export function dist_3D(a: Vector, b: Vector): number { + return Math.sqrt(dist2_3D(a, b)); +} + +/** + * Returns the square euclidean distance between two vectors, with an early + * exit (returns -1) if the distance is >= to the provided limit. + */ +export function dist2WithLimit(a: Vector, b: Vector, limit: number): number { + assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + if (result >= limit) { + return -1; + } + } + return result; +} + +/** Returns the square euclidean distance between two 2D points. */ +export function dist22D(a: Point2D, b: Point2D): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} + +/** Modifies the vector in-place to have unit norm. 
+ */
+export function unit(a: Vector): void {
+  let norm = Math.sqrt(norm2(a));
+  assert(norm > 0, 'Norm of the vector must be > 0');
+  for (let i = 0; i < a.length; ++i) {
+    a[i] /= norm;
+  }
+}
+
+/**
+ * Projects the vectors to a lower dimension.
+ *
+ * @param vectors Array of vectors to be projected.
+ * @param newDim The resulting dimension of the vectors.
+ */
+export function projectRandom(vectors: Float32Array[], newDim: number):
+    Float32Array[] {
+  let dim = vectors[0].length;
+  let N = vectors.length;
+  let newVectors: Float32Array[] = new Array(N);
+  for (let i = 0; i < N; ++i) {
+    newVectors[i] = new Float32Array(newDim);
+  }
+  // Make newDim random projections.
+  for (let k = 0; k < newDim; ++k) {
+    let randomVector = rn(dim);
+    for (let i = 0; i < N; ++i) {
+      newVectors[i][k] = dot(vectors[i], randomVector);
+    }
+  }
+  return newVectors;
+}
+
+/**
+ * Projects a vector onto a 2D plane specified by the two direction vectors.
+ */
+export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D {
+  return [dot(a, dir1), dot(a, dir2)];
+}
+
+/**
+ * Computes the centroid of the data points. If the provided data points are
+ * not vectors, an accessor function needs to be provided.
+ */
+export function centroid<T>(dataPoints: T[], accessor?: (a: T) => Vector):
+    Vector {
+  if (dataPoints.length === 0) {
+    return null;
+  }
+  if (accessor == null) {
+    accessor = (a: T) => a as any;
+  }
+  assert(dataPoints.length >= 1, '`dataPoints` must be of length >= 1');
+  let centroid = new Float32Array(accessor(dataPoints[0]).length);
+  for (let i = 0; i < dataPoints.length; ++i) {
+    let dataPoint = dataPoints[i];
+    let vector = accessor(dataPoint);
+    for (let j = 0; j < centroid.length; ++j) {
+      centroid[j] += vector[j];
+    }
+  }
+  for (let j = 0; j < centroid.length; ++j) {
+    centroid[j] /= dataPoints.length;
+  }
+  return centroid;
+}
+
+/**
+ * Generates a vector of the specified size where each component is drawn from
+ * a random (0, 1) gaussian distribution.
+ */
+export function rn(size: number): Float32Array {
+  const normal = d3.randomNormal();
+  let result = new Float32Array(size);
+  for (let i = 0; i < size; ++i) {
+    result[i] = normal();
+  }
+  return result;
+}
+
+/**
+ * Returns the cosine distance ([0, 2]) between two vectors
+ * that have been normalized to unit norm.
+ */
+export function cosDistNorm(a: Vector, b: Vector): number {
+  return 1 - dot(a, b);
+}
+
+/**
+ * Returns the cosine distance ([0, 2]) between two vectors.
+ */
+export function cosDist(a: Vector, b: Vector): number {
+  return 1 - cosSim(a, b);
+}
+
+/** Returns the cosine similarity ([-1, 1]) between two vectors. */
+export function cosSim(a: Vector, b: Vector): number {
+  return dot(a, b) / Math.sqrt(norm2(a) * norm2(b));
+}
+
+/**
+ * Converts a list of vectors (matrix) into a 1-dimensional
+ * typed array with row-first order.
+ */
+export function toTypedArray<T>(
+    dataPoints: T[], accessor: (dataPoint: T) => Float32Array): Float32Array {
+  let N = dataPoints.length;
+  let dim = accessor(dataPoints[0]).length;
+  let result = new Float32Array(N * dim);
+  for (let i = 0; i < N; ++i) {
+    let vector = accessor(dataPoints[i]);
+    for (let d = 0; d < dim; ++d) {
+      result[i * dim + d] = vector[d];
+    }
+  }
+  return result;
+}
+
+/**
+ * Transposes an RxC matrix represented as a flat typed array
+ * into a CxR matrix, again represented as a flat typed array.
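+ * Illustrative example (r=2, c=3):
+ *   transposeTypedArray(2, 3, Float32Array.of(1, 2, 3, 4, 5, 6))
+ *   // -> Float32Array [1, 4, 2, 5, 3, 6]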
+ */ +export function transposeTypedArray( + r: number, c: number, typedArray: Float32Array) { + let result = new Float32Array(r * c); + for (let i = 0; i < r; ++i) { + for (let j = 0; j < c; ++j) { + result[j * r + i] = typedArray[i * c + j]; + } + } + return result; +} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html new file mode 100644 index 00000000000..34aca77dde4 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html @@ -0,0 +1,105 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html new file mode 100644 index 00000000000..c37d8d9571f --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html @@ -0,0 +1,205 @@ + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts new file mode 100644 index 00000000000..53195fa47c0 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts @@ -0,0 +1,283 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {State} from './data'; +import {DataProvider, EmbeddingInfo} from './data-provider'; +import * as logging from './logging'; +import {ProjectorEventContext} from './projectorEventContext'; +import {Projector} from './vz-projector'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +// tslint:disable-next-line +export let BookmarkPanelPolymer = PolymerElement({ + is: 'vz-projector-bookmark-panel', + properties: { + savedStates: Object, + // Keep a separate polymer property because the savedStates doesn't change + // when adding and removing states. + hasStates: {type: Boolean, value: false}, + selectedState: Number + } +}); + +export class BookmarkPanel extends BookmarkPanelPolymer { + private projector: Projector; + + // A list containing all of the saved states. 
+ private savedStates: State[]; + private hasStates = false; + private selectedState: number; + private ignoreNextProjectionEvent: boolean; + + private expandLessButton: HTMLButtonElement; + private expandMoreButton: HTMLButtonElement; + + ready() { + this.savedStates = []; + this.setupUploadButton(); + this.ignoreNextProjectionEvent = false; + this.expandLessButton = + this.querySelector('#expand-less') as HTMLButtonElement; + this.expandMoreButton = + this.querySelector('#expand-more') as HTMLButtonElement; + } + + initialize( + projector: Projector, projectorEventContext: ProjectorEventContext) { + this.projector = projector; + projectorEventContext.registerProjectionChangedListener(() => { + if (this.ignoreNextProjectionEvent) { + this.ignoreNextProjectionEvent = false; + } else { + this.clearStateSelection(); + } + }); + } + + setSelectedTensor( + run: string, tensorInfo: EmbeddingInfo, dataProvider: DataProvider) { + // Clear any existing bookmarks. + this.addStates(null); + if (tensorInfo && tensorInfo.bookmarksPath) { + // Get any bookmarks that may come when the projector starts up. + dataProvider.getBookmarks(run, tensorInfo.tensorName, bookmarks => { + this.addStates(bookmarks); + this._expandMore(); + }); + } else { + this._expandLess(); + } + } + + /** Handles a click on show bookmarks tray button. */ + _expandMore() { + this.$.panel.show(); + this.expandMoreButton.style.display = 'none'; + this.expandLessButton.style.display = ''; + } + + /** Handles a click on hide bookmarks tray button. */ + _expandLess() { + this.$.panel.hide(); + this.expandMoreButton.style.display = ''; + this.expandLessButton.style.display = 'none'; + } + + /** Handles a click on the add bookmark button. */ + _addBookmark() { + let currentState = this.projector.getCurrentState(); + currentState.label = 'State ' + this.savedStates.length; + currentState.isSelected = true; + + this.selectedState = this.savedStates.length; + + for (let i = 0; i < this.savedStates.length; i++) { + this.savedStates[i].isSelected = false; + // We have to call notifyPath so that polymer knows this element was + // updated. + this.notifyPath('savedStates.' + i + '.isSelected', false, false); + } + + this.push('savedStates', currentState as any); + this.updateHasStates(); + } + + /** Handles a click on the download bookmarks button. */ + _downloadFile() { + let serializedState = this.serializeAllSavedStates(); + let blob = new Blob([serializedState], {type: 'text/plain'}); + let textFile = window.URL.createObjectURL(blob); + + // Force a download. + let a = document.createElement('a'); + document.body.appendChild(a); + a.style.display = 'none'; + a.href = textFile; + (a as any).download = 'state'; + a.click(); + + document.body.removeChild(a); + window.URL.revokeObjectURL(textFile); + } + + /** Handles a click on the upload bookmarks button. */ + _uploadFile() { + let fileInput = this.dom.select('#state-file'); + (fileInput.node() as HTMLInputElement).click(); + } + + private setupUploadButton() { + // Show and setup the load view button. + const fileInput = this.querySelector('#state-file') as HTMLInputElement; + fileInput.onchange = () => { + const file: File = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + const str: string = fileReader.result; + const savedStates = JSON.parse(str); + + // Verify the bookmarks match. 
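+        // (i.e. that they were saved against a dataset of the same shape;
+        // see savedStatesValid() below.)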
+ if (this.savedStatesValid(savedStates)) { + this.addStates(savedStates); + this.loadSavedState(0); + } else { + logging.setWarningMessage( + `Unable to load bookmarks: wrong dataset, expected dataset ` + + `with shape (${savedStates[0].dataSetDimensions}).`); + } + }; + fileReader.readAsText(file); + }; + } + + addStates(savedStates?: State[]) { + if (savedStates == null) { + this.savedStates = []; + } else { + for (let i = 0; i < savedStates.length; i++) { + savedStates[i].isSelected = false; + this.push('savedStates', savedStates[i] as any); + } + } + this.updateHasStates(); + } + + /** Deselects any selected state selection. */ + clearStateSelection() { + for (let i = 0; i < this.savedStates.length; i++) { + this.setSelectionState(i, false); + } + } + + /** Handles a radio button click on a saved state. */ + _radioButtonHandler(evt: Event) { + const index = this.getParentDataIndex(evt); + this.loadSavedState(index); + this.setSelectionState(index, true); + } + + loadSavedState(index: number) { + for (let i = 0; i < this.savedStates.length; i++) { + if (this.savedStates[i].isSelected) { + this.setSelectionState(i, false); + } else if (index === i) { + this.setSelectionState(i, true); + this.ignoreNextProjectionEvent = true; + this.projector.loadState(this.savedStates[i]); + } + } + } + + private setSelectionState(stateIndex: number, selected: boolean) { + this.savedStates[stateIndex].isSelected = selected; + const path = 'savedStates.' + stateIndex + '.isSelected'; + this.notifyPath(path, selected, false); + } + + /** + * Crawls up the DOM to find an ancestor with a data-index attribute. This is + * used to match events to their bookmark index. + */ + private getParentDataIndex(evt: Event) { + for (let i = 0; i < (evt as any).path.length; i++) { + let dataIndex = (evt as any).path[i].getAttribute('data-index'); + if (dataIndex != null) { + return +dataIndex; + } + } + return -1; + } + + /** Handles a clear button click on a bookmark. */ + _clearButtonHandler(evt: Event) { + let index = this.getParentDataIndex(evt); + this.splice('savedStates', index, 1); + this.updateHasStates(); + } + + /** Handles a label change event on a bookmark. */ + _labelChange(evt: Event) { + let index = this.getParentDataIndex(evt); + this.savedStates[index].label = (evt.target as any).value; + } + + /** + * Used to determine whether to select the radio button for a given bookmark. + */ + _isSelectedState(index: number) { + return index === this.selectedState; + } + _isNotSelectedState(index: number) { + return index !== this.selectedState; + } + + /** + * Gets all of the saved states as a serialized string. + */ + serializeAllSavedStates(): string { + return JSON.stringify(this.savedStates); + } + + /** + * Loads all of the serialized states and shows them in the list of + * viewable states. + */ + loadSavedStates(serializedStates: string) { + this.savedStates = JSON.parse(serializedStates); + this.updateHasStates(); + } + + /** + * Updates the hasState polymer property. + */ + private updateHasStates() { + this.hasStates = (this.savedStates.length !== 0); + } + + /** Sanity checks a State array to ensure it matches the current dataset. 
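+   * (Compares each state's saved dataSetDimensions against the currently
+   * loaded dataset's dimensions.)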
*/ + private savedStatesValid(states: State[]): boolean { + for (let i = 0; i < states.length; i++) { + if (states[i].dataSetDimensions[0] !== this.projector.dataSet.dim[0] || + states[i].dataSetDimensions[1] !== this.projector.dataSet.dim[1]) { + return false; + } + } + return true; + } +} +document.registerElement(BookmarkPanel.prototype.is, BookmarkPanel); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html new file mode 100644 index 00000000000..2acb570b3c1 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html @@ -0,0 +1,32 @@ + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html new file mode 100644 index 00000000000..3857113ac04 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html @@ -0,0 +1,79 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html new file mode 100644 index 00000000000..607d4467892 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html @@ -0,0 +1,399 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts new file mode 100644 index 00000000000..a6847ed3c87 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts @@ -0,0 +1,497 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 +import {ColorOption, ColumnStats, SpriteAndMetadataInfo} from './data'; +import {DataProvider, EmbeddingInfo, parseRawMetadata, parseRawTensors, ProjectorConfig} from './data-provider'; +import * as util from './util'; +import {Projector} from './vz-projector'; +import {ColorLegendRenderInfo, ColorLegendThreshold} from './vz-projector-legend'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +export let DataPanelPolymer = PolymerElement({ + is: 'vz-projector-data-panel', + properties: { + selectedTensor: {type: String, observer: '_selectedTensorChanged'}, + selectedRun: {type: String, observer: '_selectedRunChanged'}, + selectedColorOptionName: { + type: String, + notify: true, + observer: '_selectedColorOptionNameChanged' + }, + selectedLabelOption: + {type: String, notify: true, observer: '_selectedLabelOptionChanged'}, + normalizeData: Boolean, + showForceCategoricalColorsCheckbox: Boolean + } +}); + +export class DataPanel extends DataPanelPolymer { + selectedLabelOption: string; + selectedColorOptionName: string; + showForceCategoricalColorsCheckbox: boolean; + + private normalizeData: boolean; + private labelOptions: string[]; + private colorOptions: ColorOption[]; + forceCategoricalColoring: boolean = false; + + private selectedTensor: string; + private selectedRun: string; + private dataProvider: DataProvider; + private tensorNames: {name: string, shape: number[]}[]; + private runNames: string[]; + private projector: Projector; + private projectorConfig: ProjectorConfig; + private colorLegendRenderInfo: ColorLegendRenderInfo; + private spriteAndMetadata: SpriteAndMetadataInfo; + private metadataFile: string; + + ready() { + this.normalizeData = true; + } + + initialize(projector: Projector, dp: DataProvider) { + this.projector = projector; + this.dataProvider = dp; + this.setupUploadButtons(); + + // Tell the projector whenever the data normalization changes. + // Unknown why, but the polymer checkbox button stops working as soon as + // you do d3.select() on it. + this.querySelector('#normalize-data-checkbox') + .addEventListener('change', () => { + this.projector.setNormalizeData(this.normalizeData); + }); + + let forceCategoricalColoringCheckbox = + this.querySelector('#force-categorical-checkbox'); + forceCategoricalColoringCheckbox.addEventListener('change', () => { + this.setForceCategoricalColoring( + (forceCategoricalColoringCheckbox as HTMLInputElement).checked); + }); + + // Get all the runs. + this.dataProvider.retrieveRuns(runs => { + this.runNames = runs; + // Choose the first run by default. + if (this.runNames.length > 0) { + this.selectedRun = runs[0]; + } + }); + } + + setForceCategoricalColoring(forceCategoricalColoring: boolean) { + this.forceCategoricalColoring = forceCategoricalColoring; + (this.querySelector('#force-categorical-checkbox') as HTMLInputElement) + .checked = this.forceCategoricalColoring; + + this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); + + // The selected color option name doesn't change when we switch to using + // categorical coloring for stats with too many unique values, so we + // manually call this polymer observer so that we update the UI. + this._selectedColorOptionNameChanged(); + } + + getSeparatorClass(isSeparator: boolean): string { + return isSeparator ? 
'separator' : null;
+  }
+
+  metadataChanged(
+      spriteAndMetadata: SpriteAndMetadataInfo, metadataFile: string) {
+    this.spriteAndMetadata = spriteAndMetadata;
+    this.metadataFile = metadataFile;
+
+    this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile);
+    this.selectedColorOptionName = this.colorOptions[0].name;
+  }
+
+  private addWordBreaks(longString: string): string {
+    if (longString == null) {
+      return '';
+    }
+    // Insert a <wbr> tag after separator characters so that long paths can
+    // wrap when rendered via innerHTML.
+    return longString.replace(/([\/=\-_,])/g, '$1<wbr>');
+  }
+
+  private updateMetadataUI(columnStats: ColumnStats[], metadataFile: string) {
+    const metadataFileElement =
+        this.querySelector('#metadata-file') as HTMLSpanElement;
+    metadataFileElement.innerHTML = this.addWordBreaks(metadataFile);
+    metadataFileElement.title = metadataFile;
+
+    // Label by options.
+    let labelIndex = -1;
+    this.labelOptions = columnStats.map((stats, i) => {
+      // Make the default label the first non-numeric column.
+      if (!stats.isNumeric && labelIndex === -1) {
+        labelIndex = i;
+      }
+      return stats.name;
+    });
+    this.selectedLabelOption = this.labelOptions[Math.max(0, labelIndex)];
+
+    // Color by options.
+    const standardColorOption: ColorOption[] = [
+      {name: 'No color map'},
+      // TODO(smilkov): Implement this.
+      // {name: 'Distance of neighbors',
+      //    desc: 'How far is each point from its neighbors'}
+    ];
+    const metadataColorOption: ColorOption[] =
+        columnStats
+            .filter(stats => {
+              return !stats.tooManyUniqueValues || stats.isNumeric;
+            })
+            .map(stats => {
+              let map;
+              let items: {label: string, count: number}[];
+              let thresholds: ColorLegendThreshold[];
+              let isCategorical =
+                  this.forceCategoricalColoring || !stats.tooManyUniqueValues;
+              if (isCategorical) {
+                const scale = d3.scaleOrdinal(d3.schemeCategory20);
+                let range = scale.range();
+                // Re-order the range.
+                let newRange = range.map((color, i) => {
+                  let index = (i * 3) % range.length;
+                  return range[index];
+                });
+                items = stats.uniqueEntries;
+                scale.range(newRange).domain(items.map(x => x.label));
+                map = scale;
+              } else {
+                thresholds = [
+                  {color: '#ffffdd', value: stats.min},
+                  {color: '#1f2d86', value: stats.max}
+                ];
+                map = d3.scaleLinear<string, string>()
+                          .domain(thresholds.map(t => t.value))
+                          .range(thresholds.map(t => t.color));
+              }
+              let desc = !isCategorical ? 'gradient' :
+                  stats.uniqueEntries.length +
+                      ((stats.uniqueEntries.length > 20) ? ' non-unique' : '') +
+                      ' colors';
+              return {
+                name: stats.name,
+                desc: desc,
+                map: map,
+                items: items,
+                thresholds: thresholds,
+                tooManyUniqueValues: stats.tooManyUniqueValues
+              };
+            });
+
+    if (metadataColorOption.length > 0) {
+      // Add a separator line between built-in color maps
+      // and those based on metadata columns.
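+      // (The pushed entry is a dummy ColorOption flagged isSeparator; it picks
+      // up the 'separator' style via getSeparatorClass() above.)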
+ standardColorOption.push({name: 'Metadata', isSeparator: true}); + } + this.colorOptions = standardColorOption.concat(metadataColorOption); + } + + setNormalizeData(normalizeData: boolean) { + this.normalizeData = normalizeData; + } + + _selectedTensorChanged() { + this.projector.updateDataSet(null, null, null); + if (this.selectedTensor == null) { + return; + } + this.dataProvider.retrieveTensor( + this.selectedRun, this.selectedTensor, ds => { + let metadataFile = + this.getEmbeddingInfoByName(this.selectedTensor).metadataPath; + this.dataProvider.retrieveSpriteAndMetadata( + this.selectedRun, this.selectedTensor, metadata => { + this.projector.updateDataSet(ds, metadata, metadataFile); + }); + }); + this.projector.setSelectedTensor( + this.selectedRun, this.getEmbeddingInfoByName(this.selectedTensor)); + } + + _selectedRunChanged() { + this.dataProvider.retrieveProjectorConfig(this.selectedRun, info => { + this.projectorConfig = info; + let names = + this.projectorConfig.embeddings.map(e => e.tensorName) + .filter(name => { + let shape = this.getEmbeddingInfoByName(name).tensorShape; + return shape.length === 2 && shape[0] > 1 && shape[1] > 1; + }) + .sort((a, b) => { + let embA = this.getEmbeddingInfoByName(a); + let embB = this.getEmbeddingInfoByName(b); + + // Prefer tensors with metadata. + if (util.xor(!!embA.metadataPath, !!embB.metadataPath)) { + return embA.metadataPath ? -1 : 1; + } + + // Prefer non-generated tensors. + let isGenA = util.tensorIsGenerated(a); + let isGenB = util.tensorIsGenerated(b); + if (util.xor(isGenA, isGenB)) { + return isGenB ? -1 : 1; + } + + // Prefer bigger tensors. + let sizeA = embA.tensorShape[0]; + let sizeB = embB.tensorShape[0]; + if (sizeA !== sizeB) { + return sizeB - sizeA; + } + + // Sort alphabetically by tensor name. + return a <= b ? -1 : 1; + }); + this.tensorNames = names.map(name => { + return {name, shape: this.getEmbeddingInfoByName(name).tensorShape}; + }); + const wordBreakablePath = + this.addWordBreaks(this.projectorConfig.modelCheckpointPath); + const checkpointFile = + this.querySelector('#checkpoint-file') as HTMLSpanElement; + checkpointFile.innerHTML = wordBreakablePath; + checkpointFile.title = this.projectorConfig.modelCheckpointPath; + + // If in demo mode, let the order decide which tensor to load by default. + const defaultTensor = this.projector.servingMode === 'demo' ? + this.projectorConfig.embeddings[0].tensorName : + names[0]; + if (this.selectedTensor === defaultTensor) { + // Explicitly call the observer. Polymer won't call it if the previous + // string matches the current string. 
+ this._selectedTensorChanged(); + } else { + this.selectedTensor = defaultTensor; + } + }); + } + + _selectedLabelOptionChanged() { + this.projector.setSelectedLabelOption(this.selectedLabelOption); + } + + _selectedColorOptionNameChanged() { + let colorOption: ColorOption; + for (let i = 0; i < this.colorOptions.length; i++) { + if (this.colorOptions[i].name === this.selectedColorOptionName) { + colorOption = this.colorOptions[i]; + break; + } + } + if (!colorOption) { + return; + } + + this.showForceCategoricalColorsCheckbox = !!colorOption.tooManyUniqueValues; + + if (colorOption.map == null) { + this.colorLegendRenderInfo = null; + } else if (colorOption.items) { + let items = colorOption.items.map(item => { + return { + color: colorOption.map(item.label), + label: item.label, + count: item.count + }; + }); + this.colorLegendRenderInfo = {items, thresholds: null}; + } else { + this.colorLegendRenderInfo = { + items: null, + thresholds: colorOption.thresholds + }; + } + this.projector.setSelectedColorOption(colorOption); + } + + private tensorWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { + parseRawTensors(rawContents, ds => { + const checkpointFile = + this.querySelector('#checkpoint-file') as HTMLSpanElement; + checkpointFile.innerText = fileName; + checkpointFile.title = fileName; + this.projector.updateDataSet(ds); + }); + } + + private metadataWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { + parseRawMetadata(rawContents, metadata => { + this.projector.updateDataSet(this.projector.dataSet, metadata, fileName); + }); + } + + private getEmbeddingInfoByName(tensorName: string): EmbeddingInfo { + for (let i = 0; i < this.projectorConfig.embeddings.length; i++) { + const e = this.projectorConfig.embeddings[i]; + if (e.tensorName === tensorName) { + return e; + } + } + } + + private setupUploadButtons() { + // Show and setup the upload button. + const fileInput = this.querySelector('#file') as HTMLInputElement; + fileInput.onchange = () => { + const file: File = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = evt => { + const content: ArrayBuffer = fileReader.result; + this.tensorWasReadFromFile(content, file.name); + }; + fileReader.readAsArrayBuffer(file); + }; + + const uploadButton = + this.querySelector('#upload-tensors') as HTMLButtonElement; + uploadButton.onclick = () => { + fileInput.click(); + }; + + // Show and setup the upload metadata button. + const fileMetadataInput = + this.querySelector('#file-metadata') as HTMLInputElement; + fileMetadataInput.onchange = () => { + const file: File = fileMetadataInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. 
+ fileMetadataInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = evt => { + const contents: ArrayBuffer = fileReader.result; + this.metadataWasReadFromFile(contents, file.name); + }; + fileReader.readAsArrayBuffer(file); + }; + + const uploadMetadataButton = + this.querySelector('#upload-metadata') as HTMLButtonElement; + uploadMetadataButton.onclick = () => { + fileMetadataInput.click(); + }; + + if (this.projector.servingMode !== 'demo') { + (this.$$('#publish-container') as HTMLElement).style.display = 'none'; + (this.$$('#upload-tensors-step-container') as HTMLElement).style.display = + 'none'; + (this.$$('#upload-metadata-label') as HTMLElement).style.display = 'none'; + } + + (this.$$('#demo-data-buttons-container') as HTMLElement).style.display = + 'block'; + + // Fill out the projector config. + const projectorConfigTemplate = + this.$$('#projector-config-template') as HTMLTextAreaElement; + const projectorConfigTemplateJson: ProjectorConfig = { + embeddings: [{ + tensorName: 'My tensor', + tensorShape: [1000, 50], + tensorPath: 'https://raw.githubusercontent.com/.../tensors.tsv', + metadataPath: + 'https://raw.githubusercontent.com/.../optional.metadata.tsv', + }], + }; + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, projectorConfigTemplateJson); + + // Set up optional field checkboxes. + const spriteFieldCheckbox = + this.$$('#config-sprite-checkbox') as HTMLInputElement; + spriteFieldCheckbox.onchange = () => { + if ((spriteFieldCheckbox as any).checked) { + projectorConfigTemplateJson.embeddings[0].sprite = { + imagePath: 'https://github.com/.../optional.sprite.png', + singleImageDim: [32, 32] + }; + } else { + delete projectorConfigTemplateJson.embeddings[0].sprite; + } + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, projectorConfigTemplateJson); + }; + const bookmarksFieldCheckbox = + this.$$('#config-bookmarks-checkbox') as HTMLInputElement; + bookmarksFieldCheckbox.onchange = () => { + if ((bookmarksFieldCheckbox as any).checked) { + projectorConfigTemplateJson.embeddings[0].bookmarksPath = + 'https://raw.githubusercontent.com/.../bookmarks.txt'; + } else { + delete projectorConfigTemplateJson.embeddings[0].bookmarksPath; + } + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, projectorConfigTemplateJson); + }; + const metadataFieldCheckbox = + this.$$('#config-metadata-checkbox') as HTMLInputElement; + metadataFieldCheckbox.onchange = () => { + if ((metadataFieldCheckbox as HTMLInputElement).checked) { + projectorConfigTemplateJson.embeddings[0].metadataPath = + 'https://raw.githubusercontent.com/.../optional.metadata.tsv'; + } else { + delete projectorConfigTemplateJson.embeddings[0].metadataPath; + } + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, projectorConfigTemplateJson); + }; + + // Update the link and the readonly shareable URL. 
+    const projectorConfigUrlInput =
+        this.$$('#projector-config-url') as HTMLInputElement;
+    const projectorConfigDemoUrlInput = this.$$('#projector-share-url');
+    const projectorConfigDemoUrlLink = this.$$('#projector-share-url-link');
+    projectorConfigUrlInput.onchange = () => {
+      let projectorDemoUrl = location.protocol + '//' + location.host +
+          location.pathname +
+          '?config=' + (projectorConfigUrlInput as HTMLInputElement).value;
+
+      (projectorConfigDemoUrlInput as HTMLInputElement).value =
+          projectorDemoUrl;
+      (projectorConfigDemoUrlLink as HTMLLinkElement).href = projectorDemoUrl;
+    };
+  }
+
+  private setProjectorConfigTemplateJson(
+      projectorConfigTemplate: HTMLTextAreaElement, config: ProjectorConfig) {
+    projectorConfigTemplate.value =
+        JSON.stringify(config, /** replacer */ null, /** whitespace */ 2);
+  }
+
+  _getNumTensorsLabel(): string {
+    return this.tensorNames.length === 1 ? '1 tensor' :
+        this.tensorNames.length + ' tensors';
+  }
+
+  _getNumRunsLabel(): string {
+    return this.runNames.length === 1 ? '1 run' :
+        this.runNames.length + ' runs';
+  }
+
+  _hasChoices(choices: any[]): boolean {
+    return choices.length > 1;
+  }
+}
+
+document.registerElement(DataPanel.prototype.is, DataPanel);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html
new file mode 100644
index 00000000000..e77694426eb
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html
@@ -0,0 +1,64 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts
new file mode 100644
index 00000000000..e11346d327f
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts
@@ -0,0 +1,113 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let PolymerClass = PolymerElement(
+    {is: 'vz-projector-input', properties: {label: String, message: String}});
+
+export interface InputChangedListener {
+  (value: string, inRegexMode: boolean): void;
+}
+
+/** Input control with custom capabilities (e.g. regex). */
+export class ProjectorInput extends PolymerClass {
+  private textChangedListeners: InputChangedListener[];
+  private paperInput: HTMLInputElement;
+  private inRegexModeButton: HTMLButtonElement;
+  private inRegexMode: boolean;
+
+  /** Message that will be displayed at the bottom of the input control. */
+  message: string;
+
+  /** Subscribe to be called every time the input changes.
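+   * Illustrative example:
+   *   input.registerInputChangedListener(
+   *       (value, inRegexMode) => console.log(value, inRegexMode));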
*/ + registerInputChangedListener(listener: InputChangedListener) { + this.textChangedListeners.push(listener); + } + + ready() { + this.inRegexMode = false; + this.textChangedListeners = []; + this.paperInput = this.querySelector('paper-input') as HTMLInputElement; + this.inRegexModeButton = + this.querySelector('paper-button') as HTMLButtonElement; + this.paperInput.setAttribute('error-message', 'Invalid regex'); + + this.paperInput.addEventListener('input', () => { + this.onTextChanged(); + }); + + this.paperInput.addEventListener('keydown', event => { + event.stopPropagation(); + }); + + this.inRegexModeButton.addEventListener( + 'click', () => this.onClickRegexModeButton()); + this.updateRegexModeDisplaySlashes(); + this.onTextChanged(); + } + + private onClickRegexModeButton() { + this.inRegexMode = (this.inRegexModeButton as any).active; + this.updateRegexModeDisplaySlashes(); + this.onTextChanged(); + } + + private notifyInputChanged(value: string, inRegexMode: boolean) { + this.textChangedListeners.forEach(l => l(value, inRegexMode)); + } + + private onTextChanged() { + try { + if (this.inRegexMode) { + new RegExp(this.paperInput.value); + } + } catch (invalidRegexException) { + this.paperInput.setAttribute('invalid', 'true'); + this.message = ''; + this.notifyInputChanged(null, true); + return; + } + this.paperInput.removeAttribute('invalid'); + this.notifyInputChanged(this.paperInput.value, this.inRegexMode); + } + + private updateRegexModeDisplaySlashes() { + const slashes = this.paperInput.querySelectorAll('.slash'); + const display = this.inRegexMode ? '' : 'none'; + + for (let i = 0; i < slashes.length; i++) { + (slashes[i] as HTMLDivElement).style.display = display; + } + } + + getValue(): string { + return this.paperInput.value; + } + + getInRegexMode(): boolean { + return this.inRegexMode; + } + + set(value: string, inRegexMode: boolean) { + (this.inRegexModeButton as any).active = inRegexMode; + this.paperInput.value = value; + this.onClickRegexModeButton(); + } +} + +document.registerElement(ProjectorInput.prototype.is, ProjectorInput); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html new file mode 100644 index 00000000000..7554c322cef --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html @@ -0,0 +1,240 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts new file mode 100644 index 00000000000..3ee2c2165f2 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts @@ -0,0 +1,337 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +import {DistanceFunction, SpriteAndMetadataInfo, State} from './data'; +import * as knn from './knn'; +import {ProjectorEventContext} from './projectorEventContext'; +import * as adapter from './projectorScatterPlotAdapter'; +import * as util from './util'; +import * as vector from './vector'; +import {Projector} from './vz-projector'; +import {ProjectorInput} from './vz-projector-input'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +/** Limit the number of search results we show to the user. */ +const LIMIT_RESULTS = 100; + +// tslint:disable-next-line +export let PolymerClass = PolymerElement({ + is: 'vz-projector-inspector-panel', + properties: {selectedMetadataField: String, metadataFields: Array} +}); + +export class InspectorPanel extends PolymerClass { + distFunc: DistanceFunction; + numNN: number; + + private projectorEventContext: ProjectorEventContext; + + private selectedMetadataField: string; + private metadataFields: string[]; + private projector: Projector; + private selectedPointIndices: number[]; + private neighborsOfFirstPoint: knn.NearestEntry[]; + private searchBox: ProjectorInput; + + private resetFilterButton: HTMLButtonElement; + private setFilterButton: HTMLButtonElement; + private clearSelectionButton: HTMLButtonElement; + private limitMessage: HTMLDivElement; + + ready() { + this.resetFilterButton = + this.querySelector('.reset-filter') as HTMLButtonElement; + this.setFilterButton = + this.querySelector('.set-filter') as HTMLButtonElement; + this.clearSelectionButton = + this.querySelector('.clear-selection') as HTMLButtonElement; + this.limitMessage = this.querySelector('.limit-msg') as HTMLDivElement; + this.searchBox = this.querySelector('#search-box') as ProjectorInput; + // https://www.polymer-project.org/1.0/docs/devguide/styling#scope-subtree + this.scopeSubtree(this, true); + } + + initialize( + projector: Projector, projectorEventContext: ProjectorEventContext) { + this.projector = projector; + this.projectorEventContext = projectorEventContext; + this.setupUI(projector); + projectorEventContext.registerSelectionChangedListener( + (selection, neighbors) => + this.updateInspectorPane(selection, neighbors)); + } + + /** Updates the nearest neighbors list in the inspector. */ + private updateInspectorPane( + indices: number[], neighbors: knn.NearestEntry[]) { + this.neighborsOfFirstPoint = neighbors; + this.selectedPointIndices = indices; + + this.updateFilterButtons(indices.length + neighbors.length); + this.updateNeighborsList(neighbors); + if (neighbors.length === 0) { + this.updateSearchResults(indices); + } else { + this.updateSearchResults([]); + } + } + + private enableResetFilterButton(enabled: boolean) { + this.resetFilterButton.disabled = !enabled; + } + + restoreUIFromBookmark(bookmark: State) { + this.enableResetFilterButton(bookmark.filteredPoints != null); + } + + metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { + let labelIndex = -1; + this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { + if (!stats.isNumeric && labelIndex === -1) { + labelIndex = i; + } + return stats.name; + }); + labelIndex = Math.max(0, labelIndex); + // Make the default label the first non-numeric column. 
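+    // (labelIndex was clamped to 0 above, so if every column is numeric the
+    // first column is used as a fallback.)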
+ this.selectedMetadataField = spriteAndMetadata.stats[labelIndex].name; + } + + datasetChanged() { + this.enableResetFilterButton(false); + } + + private updateSearchResults(indices: number[]) { + const container = this.querySelector('.matches-list') as HTMLDivElement; + container.style.display = indices.length ? null : 'none'; + const list = container.querySelector('.list') as HTMLDivElement; + list.innerHTML = ''; + if (indices.length === 0) { + return; + } + + this.limitMessage.style.display = + indices.length <= LIMIT_RESULTS ? 'none' : null; + indices = indices.slice(0, LIMIT_RESULTS); + + for (let i = 0; i < indices.length; i++) { + const index = indices[i]; + + const row = document.createElement('div'); + row.className = 'row'; + + const label = this.getLabelFromIndex(index); + const rowLink = document.createElement('a'); + rowLink.className = 'label'; + rowLink.title = label; + rowLink.innerText = label; + + rowLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(index); + }; + rowLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); + }; + rowLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([index]); + }; + + row.appendChild(rowLink); + list.appendChild(row); + } + } + + private getLabelFromIndex(pointIndex: number): string { + const point = this.projector.dataSet.points[pointIndex]; + return point.metadata[this.selectedMetadataField].toString(); + } + + private updateNeighborsList(neighbors: knn.NearestEntry[]) { + const nnlist = this.querySelector('.nn-list') as HTMLDivElement; + nnlist.innerHTML = ''; + + (this.querySelector('.nn') as HTMLDivElement).style.display = + neighbors.length ? null : 'none'; + + if (neighbors.length === 0) { + return; + } + + this.searchBox.message = ''; + const minDist = neighbors.length > 0 ? 
neighbors[0].dist : 0; + + for (let i = 0; i < neighbors.length; i++) { + const neighbor = neighbors[i]; + + const neighborElement = document.createElement('div'); + neighborElement.className = 'neighbor'; + + const neighborElementLink = document.createElement('a'); + neighborElementLink.className = 'neighbor-link'; + neighborElementLink.title = this.getLabelFromIndex(neighbor.index); + + const labelValueElement = document.createElement('div'); + labelValueElement.className = 'label-and-value'; + + const labelElement = document.createElement('div'); + labelElement.className = 'label'; + labelElement.style.color = + adapter.dist2color(this.distFunc, neighbor.dist, minDist); + labelElement.innerText = this.getLabelFromIndex(neighbor.index); + + const valueElement = document.createElement('div'); + valueElement.className = 'value'; + valueElement.innerText = neighbor.dist.toFixed(3); + + labelValueElement.appendChild(labelElement); + labelValueElement.appendChild(valueElement); + + const barElement = document.createElement('div'); + barElement.className = 'bar'; + + const barFillElement = document.createElement('div'); + barFillElement.className = 'fill'; + barFillElement.style.borderTopColor = + adapter.dist2color(this.distFunc, neighbor.dist, minDist); + barFillElement.style.width = + adapter.normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 + + '%'; + barElement.appendChild(barFillElement); + + for (let j = 1; j < 4; j++) { + const tickElement = document.createElement('div'); + tickElement.className = 'tick'; + tickElement.style.left = j * 100 / 4 + '%'; + barElement.appendChild(tickElement); + } + + neighborElementLink.appendChild(labelValueElement); + neighborElementLink.appendChild(barElement); + neighborElement.appendChild(neighborElementLink); + nnlist.appendChild(neighborElement); + + neighborElementLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(neighbor.index); + }; + neighborElementLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); + }; + neighborElementLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([neighbor.index]); + }; + } + } + + private updateFilterButtons(numPoints: number) { + if (numPoints > 1) { + this.setFilterButton.innerText = `Isolate ${numPoints} points`; + this.setFilterButton.disabled = null; + this.clearSelectionButton.disabled = null; + } else { + this.setFilterButton.disabled = true; + this.clearSelectionButton.disabled = true; + } + } + + private setupUI(projector: Projector) { + this.distFunc = vector.cosDist; + const eucDist = + this.querySelector('.distance a.euclidean') as HTMLLinkElement; + eucDist.onclick = () => { + const links = this.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(eucDist as HTMLElement, 'selected', true); + + this.distFunc = vector.dist; + this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); + const neighbors = projector.dataSet.findNeighbors( + this.selectedPointIndices[0], this.distFunc, this.numNN); + this.updateNeighborsList(neighbors); + }; + + const cosDist = this.querySelector('.distance a.cosine') as HTMLLinkElement; + cosDist.onclick = () => { + const links = this.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(cosDist, 'selected', true); + + this.distFunc = vector.cosDist; + 
this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); + const neighbors = projector.dataSet.findNeighbors( + this.selectedPointIndices[0], this.distFunc, this.numNN); + this.updateNeighborsList(neighbors); + }; + + // Called whenever the search text input changes. + const updateInput = (value: string, inRegexMode: boolean) => { + if (value == null || value.trim() === '') { + this.searchBox.message = ''; + this.projectorEventContext.notifySelectionChanged([]); + return; + } + const indices = projector.dataSet.query( + value, inRegexMode, this.selectedMetadataField); + if (indices.length === 0) { + this.searchBox.message = '0 matches.'; + } else { + this.searchBox.message = `${indices.length} matches.`; + } + this.projectorEventContext.notifySelectionChanged(indices); + }; + this.searchBox.registerInputChangedListener((value, inRegexMode) => { + updateInput(value, inRegexMode); + }); + + // Nearest neighbors controls. + const numNNInput = this.$$('#nn-slider') as HTMLInputElement; + const updateNumNN = () => { + this.numNN = +numNNInput.value; + (this.querySelector('.num-nn .nn-count') as HTMLSpanElement).innerText = + '' + this.numNN; + if (this.selectedPointIndices != null) { + this.projectorEventContext.notifySelectionChanged( + [this.selectedPointIndices[0]]); + } + }; + numNNInput.addEventListener('change', updateNumNN); + updateNumNN(); + + // Filtering dataset. + this.setFilterButton.onclick = () => { + const indices = this.selectedPointIndices.concat( + this.neighborsOfFirstPoint.map(n => n.index)); + projector.filterDataset(indices); + this.enableResetFilterButton(true); + this.updateFilterButtons(0); + }; + + this.resetFilterButton.onclick = () => { + projector.resetFilterDataset(); + this.enableResetFilterButton(false); + }; + + this.clearSelectionButton.onclick = () => { + projector.adjustSelectionAndHover([]); + }; + this.enableResetFilterButton(false); + } +} + +document.registerElement(InspectorPanel.prototype.is, InspectorPanel); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html new file mode 100644 index 00000000000..3fc5f4db158 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html @@ -0,0 +1,76 @@ + + + + + + + + \ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts new file mode 100644 index 00000000000..1c4ddf940dc --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts @@ -0,0 +1,98 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let LegendPolymer = PolymerElement({
+  is: 'vz-projector-legend',
+  properties: {renderInfo: {type: Object, observer: '_renderInfoChanged'}}
+});
+
+export interface ColorLegendRenderInfo {
+  // To be used for categorical map.
+  items: ColorLegendItem[];
+  // To be used for gradient map.
+  thresholds: ColorLegendThreshold[];
+}
+
+/** An item in the categorical color legend. */
+export interface ColorLegendItem {
+  color: string;
+  label: string;
+  count: number;
+}
+
+/** An item in the gradient color legend. */
+export interface ColorLegendThreshold {
+  color: string;
+  value: number;
+}
+
+export class Legend extends LegendPolymer {
+  renderInfo: ColorLegendRenderInfo;
+
+  _renderInfoChanged() {
+    if (this.renderInfo == null) {
+      return;
+    }
+    if (this.renderInfo.thresholds) {
+      // The gradient svg is under a dom-if, so we should wait for it to be
+      // inserted in the dom tree using async().
+      this.async(() => this.setupLinearGradient());
+    }
+  }
+
+  _getLastThreshold(): number {
+    if (this.renderInfo == null || this.renderInfo.thresholds == null) {
+      return;
+    }
+    return this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1]
+        .value;
+  }
+
+  private getOffset(value: number): string {
+    const min = this.renderInfo.thresholds[0].value;
+    const max =
+        this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1]
+            .value;
+    return (100 * (value - min) / (max - min)).toFixed(2) + '%';
+  }
+
+  private setupLinearGradient() {
+    const linearGradient =
+        this.querySelector('#gradient') as SVGLinearGradientElement;
+
+    const width =
+        (this.querySelector('svg.gradient') as SVGElement).clientWidth;
+
+    // Set the svg to be the width of its parent.
+    (this.querySelector('svg.gradient rect') as SVGRectElement).style.width =
+        width + 'px';
+
+    // Remove all children from before.
+    linearGradient.innerHTML = '';
+
+    // Add a child in for each gradient threshold and attach it to the
+    // gradient; without appendChild the stops would never be rendered.
+    this.renderInfo.thresholds.forEach(t => {
+      const stopElement =
+          document.createElementNS('http://www.w3.org/2000/svg', 'stop');
+      stopElement.setAttribute('offset', this.getOffset(t.value));
+      stopElement.setAttribute('stop-color', t.color);
+      linearGradient.appendChild(stopElement);
+    });
+  }
+}
+
+document.registerElement(Legend.prototype.is, Legend);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html
new file mode 100644
index 00000000000..ebdcd72c77d
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html
@@ -0,0 +1,97 @@
+
+
+
+
+
+
+
+
+
+ + +
+ + +
+ +
+
+
+ + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts new file mode 100644 index 00000000000..939300f3878 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {PointMetadata} from './data'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +// tslint:disable-next-line +export let MetadataCardPolymer = PolymerElement({ + is: 'vz-projector-metadata-card', + properties: { + hasMetadata: {type: Boolean, value: false}, + metadata: {type: Array}, + label: String + } +}); + +export class MetadataCard extends MetadataCardPolymer { + hasMetadata: boolean; + metadata: Array<{key: string, value: string}>; + label: string; + + private labelOption: string; + private pointMetadata: PointMetadata; + + private expandLessButton: HTMLButtonElement; + private expandMoreButton: HTMLButtonElement; + + ready() { + this.expandLessButton = + this.querySelector('#expand-less') as HTMLButtonElement; + this.expandMoreButton = + this.querySelector('#expand-more') as HTMLButtonElement; + } + /** Handles a click on the expand more icon. */ + _expandMore() { + (this.$$('#metadata-container') as any).toggle(); + + this.expandMoreButton.style.display = 'none'; + this.expandLessButton.style.display = ''; + } + + /** Handles a click on the expand less icon. 
*/ + _expandLess() { + (this.$$('#metadata-container') as any).toggle(); + this.expandMoreButton.style.display = ''; + this.expandLessButton.style.display = 'none'; + } + + updateMetadata(pointMetadata?: PointMetadata) { + this.pointMetadata = pointMetadata; + this.hasMetadata = (pointMetadata != null); + + if (pointMetadata) { + let metadata = []; + for (let metadataKey in pointMetadata) { + if (!pointMetadata.hasOwnProperty(metadataKey)) { + continue; + } + metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); + } + + this.metadata = metadata; + this.label = '' + this.pointMetadata[this.labelOption]; + } + } + + setLabelOption(labelOption: string) { + this.labelOption = labelOption; + if (this.pointMetadata) { + this.label = '' + this.pointMetadata[this.labelOption]; + } + } +} + +document.registerElement(MetadataCard.prototype.is, MetadataCard); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html new file mode 100644 index 00000000000..cddcb2b7d08 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html @@ -0,0 +1,314 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts new file mode 100644 index 00000000000..377c6c11ad5 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts @@ -0,0 +1,589 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as data from './data'; +import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; +import * as util from './util'; +import * as vector from './vector'; +import {Vector} from './vector'; +import {Projector} from './vz-projector'; +import {ProjectorInput} from './vz-projector-input'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +const NUM_PCA_COMPONENTS = 10; + +// tslint:disable-next-line +export let ProjectionsPanelPolymer = PolymerElement({ + is: 'vz-projector-projections-panel', + properties: { + pcaIs3d: + {type: Boolean, value: true, observer: '_pcaDimensionToggleObserver'}, + tSNEis3d: + {type: Boolean, value: true, observer: '_tsneDimensionToggleObserver'}, + // PCA projection. + pcaComponents: Array, + pcaX: {type: Number, value: 0, observer: 'showPCAIfEnabled'}, + pcaY: {type: Number, value: 1, observer: 'showPCAIfEnabled'}, + pcaZ: {type: Number, value: 2, observer: 'showPCAIfEnabled'}, + // Custom projection. 
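+    // (The metadata column that the centroid text queries search over; the
+    // observer recomputes the centroids and reprojects when it changes.)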
+ customSelectedSearchByMetadataOption: { + type: String, + observer: '_customSelectedSearchByMetadataOptionChanged' + }, + } +}); + +type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown'; + +type CentroidResult = { + centroid?: Vector; numMatches?: number; +}; + +type Centroids = { + [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector; + yDown: Vector; +}; + +/** + * A polymer component which handles the projection tabs in the projector. + */ +export class ProjectionsPanel extends ProjectionsPanelPolymer { + private projector: Projector; + private pcaComponents: + Array<{id: number, componentNumber: number, percVariance: string}>; + private currentProjection: ProjectionType; + private polymerChangesTriggerReprojection: boolean; + private dataSet: DataSet; + private originalDataSet: DataSet; + private dim: number; + + /** T-SNE perplexity. Roughly how many neighbors each point influences. */ + private perplexity: number; + /** T-SNE learning rate. */ + private learningRate: number; + + private searchByMetadataOptions: string[]; + + /** Centroids for custom projections. */ + private centroidValues: any; + private centroids: Centroids; + /** The centroid across all points. */ + private allCentroid: number[]; + + /** Polymer properties. */ + // TODO(nsthorat): Move these to a separate view controller. + public tSNEis3d: boolean; + public pcaIs3d: boolean; + public pcaX: number; + public pcaY: number; + public pcaZ: number; + public customSelectedSearchByMetadataOption: string; + + /** Polymer elements. */ + private runTsneButton: HTMLButtonElement; + private stopTsneButton: HTMLButtonElement; + private perplexitySlider: HTMLInputElement; + private learningRateInput: HTMLInputElement; + private zDropdown: HTMLElement; + private iterationLabel: HTMLElement; + + private customProjectionXLeftInput: ProjectorInput; + private customProjectionXRightInput: ProjectorInput; + private customProjectionYUpInput: ProjectorInput; + private customProjectionYDownInput: ProjectorInput; + + initialize(projector: Projector) { + this.polymerChangesTriggerReprojection = true; + this.projector = projector; + + // Set up TSNE projections. + this.perplexity = 30; + this.learningRate = 10; + + // Setup Custom projections. 
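+    // (Each direction input (xLeft, xRight, yUp, yDown) maps a text query to
+    // a centroid; reprojectCustom() uses xRight - xLeft as the x axis and
+    // yUp - yDown as the y axis.)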
+ this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null}; + this.clearCentroids(); + + this.setupUIControls(); + } + + ready() { + this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement; + this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement; + this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement; + this.perplexitySlider = + this.querySelector('#perplexity-slider') as HTMLInputElement; + this.learningRateInput = + this.querySelector('#learning-rate-slider') as HTMLInputElement; + this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement; + } + + disablePolymerChangesTriggerReprojection() { + this.polymerChangesTriggerReprojection = false; + } + + enablePolymerChangesTriggerReprojection() { + this.polymerChangesTriggerReprojection = true; + } + + private updateTSNEPerplexityFromSliderChange() { + if (this.perplexitySlider) { + this.perplexity = +this.perplexitySlider.value; + } + (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText = + '' + this.perplexity; + } + + private updateTSNELearningRateFromUIChange() { + if (this.learningRateInput) { + this.learningRate = Math.pow(10, +this.learningRateInput.value); + } + (this.querySelector('.tsne-learning-rate span') as HTMLSpanElement) + .innerText = '' + this.learningRate; + } + + private setupUIControls() { + { + const self = this; + const inkTabs = this.querySelectorAll('.ink-tab'); + for (let i = 0; i < inkTabs.length; i++) { + inkTabs[i].addEventListener('click', function() { + let id = this.getAttribute('data-tab'); + self.showTab(id); + }); + } + } + + this.runTsneButton.addEventListener('click', () => this.runTSNE()); + this.stopTsneButton.addEventListener( + 'click', () => this.dataSet.stopTSNE()); + + this.perplexitySlider.value = this.perplexity.toString(); + this.perplexitySlider.addEventListener( + 'change', () => this.updateTSNEPerplexityFromSliderChange()); + this.updateTSNEPerplexityFromSliderChange(); + + this.learningRateInput.addEventListener( + 'change', () => this.updateTSNELearningRateFromUIChange()); + this.updateTSNELearningRateFromUIChange(); + + this.setupCustomProjectionInputFields(); + // TODO: figure out why `--paper-input-container-input` css mixin didn't + // work. 
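+    // (As a workaround, the font size is set imperatively on the raw inputs
+    // below.)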
+ const inputs = + this.querySelectorAll('paper-dropdown-menu paper-input input'); + for (let i = 0; i < inputs.length; i++) { + (inputs[i] as HTMLElement).style.fontSize = '14px'; + } + } + + restoreUIFromBookmark(bookmark: State) { + this.disablePolymerChangesTriggerReprojection(); + + // PCA + this.pcaX = bookmark.pcaComponentDimensions[0]; + this.pcaY = bookmark.pcaComponentDimensions[1]; + if (bookmark.pcaComponentDimensions.length === 3) { + this.pcaZ = bookmark.pcaComponentDimensions[2]; + } + this.pcaIs3d = (bookmark.pcaComponentDimensions.length === 3); + + // t-SNE + if (this.perplexitySlider) { + this.perplexitySlider.value = bookmark.tSNEPerplexity.toString(); + } + if (this.learningRateInput) { + this.learningRateInput.value = bookmark.tSNELearningRate.toString(); + } + this.tSNEis3d = bookmark.tSNEis3d; + + // custom + this.customSelectedSearchByMetadataOption = + bookmark.customSelectedSearchByMetadataOption; + if (this.customProjectionXLeftInput) { + this.customProjectionXLeftInput.set( + bookmark.customXLeftText, bookmark.customXLeftRegex); + } + if (this.customProjectionXRightInput) { + this.customProjectionXRightInput.set( + bookmark.customXRightText, bookmark.customXRightRegex); + } + if (this.customProjectionYUpInput) { + this.customProjectionYUpInput.set( + bookmark.customYUpText, bookmark.customYUpRegex); + } + if (this.customProjectionYDownInput) { + this.customProjectionYDownInput.set( + bookmark.customYDownText, bookmark.customYDownRegex); + } + this.computeAllCentroids(); + + this.setZDropdownEnabled(this.pcaIs3d); + this.updateTSNEPerplexityFromSliderChange(); + this.updateTSNELearningRateFromUIChange(); + if (this.iterationLabel) { + this.iterationLabel.innerText = bookmark.tSNEIteration.toString(); + } + if (bookmark.selectedProjection != null) { + this.showTab(bookmark.selectedProjection); + } + this.enablePolymerChangesTriggerReprojection(); + } + + populateBookmarkFromUI(bookmark: State) { + this.disablePolymerChangesTriggerReprojection(); + + // PCA + bookmark.pcaComponentDimensions = [this.pcaX, this.pcaY]; + if (this.pcaIs3d) { + bookmark.pcaComponentDimensions.push(this.pcaZ); + } + + // t-SNE + if (this.perplexitySlider != null) { + bookmark.tSNEPerplexity = +this.perplexitySlider.value; + } + if (this.learningRateInput != null) { + bookmark.tSNELearningRate = +this.learningRateInput.value; + } + bookmark.tSNEis3d = this.tSNEis3d; + + // custom + bookmark.customSelectedSearchByMetadataOption = + this.customSelectedSearchByMetadataOption; + if (this.customProjectionXLeftInput != null) { + bookmark.customXLeftText = this.customProjectionXLeftInput.getValue(); + bookmark.customXLeftRegex = + this.customProjectionXLeftInput.getInRegexMode(); + } + if (this.customProjectionXRightInput != null) { + bookmark.customXRightText = this.customProjectionXRightInput.getValue(); + bookmark.customXRightRegex = + this.customProjectionXRightInput.getInRegexMode(); + } + if (this.customProjectionYUpInput != null) { + bookmark.customYUpText = this.customProjectionYUpInput.getValue(); + bookmark.customYUpRegex = this.customProjectionYUpInput.getInRegexMode(); + } + if (this.customProjectionYDownInput != null) { + bookmark.customYDownText = this.customProjectionYDownInput.getValue(); + bookmark.customYDownRegex = + this.customProjectionYDownInput.getInRegexMode(); + } + + this.enablePolymerChangesTriggerReprojection(); + } + + // This method is marked as public as it is used as the view method that + // abstracts DOM manipulation so we can stub it in a test. 
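+  // (vz-projector-projections-panel_test.ts replaces it with a spy when
+  // exercising restoreUIFromBookmark.)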
+ // TODO(nsthorat): Move this to its own class as the glue between this class + // and the DOM. + setZDropdownEnabled(enabled: boolean) { + if (this.zDropdown) { + if (enabled) { + this.zDropdown.removeAttribute('disabled'); + } else { + this.zDropdown.setAttribute('disabled', 'true'); + } + } + } + + dataSetUpdated(dataSet: DataSet, originalDataSet: DataSet, dim: number) { + this.dataSet = dataSet; + this.originalDataSet = originalDataSet; + this.dim = dim; + const pointCount = (dataSet == null) ? 0 : dataSet.points.length; + const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4)); + this.perplexitySlider.value = perplexity.toString(); + this.updateTSNEPerplexityFromSliderChange(); + this.clearCentroids(); + + (this.querySelector('#tsne-sampling') as HTMLElement).style.display = + pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'; + const wasSampled = + (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || + dataSet.dim[1] > data.PCA_SAMPLE_DIM); + (this.querySelector('#pca-sampling') as HTMLElement).style.display = + wasSampled ? null : 'none'; + this.showTab('pca'); + } + + _pcaDimensionToggleObserver() { + this.setZDropdownEnabled(this.pcaIs3d); + this.beginProjection(this.currentProjection); + } + + _tsneDimensionToggleObserver() { + this.beginProjection(this.currentProjection); + } + + metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { + // Project by options for custom projections. + let searchByMetadataIndex = -1; + this.searchByMetadataOptions = spriteAndMetadata.stats.map((stats, i) => { + // Make the default label by the first non-numeric column. + if (!stats.isNumeric && searchByMetadataIndex === -1) { + searchByMetadataIndex = i; + } + return stats.name; + }); + this.customSelectedSearchByMetadataOption = + this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)]; + } + + public showTab(id: ProjectionType) { + this.currentProjection = id; + + const tab = + this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement; + const allTabs = this.querySelectorAll('.ink-tab'); + for (let i = 0; i < allTabs.length; i++) { + util.classed(allTabs[i] as HTMLElement, 'active', false); + } + + util.classed(tab, 'active', true); + + const allTabContent = this.querySelectorAll('.ink-panel-content'); + for (let i = 0; i < allTabContent.length; i++) { + util.classed(allTabContent[i] as HTMLElement, 'active', false); + } + + util.classed( + this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as + HTMLElement, + 'active', true); + + // guard for unit tests, where polymer isn't attached and $ doesn't exist. + if (this.$ != null) { + const main = this.$['main']; + // In order for the projections panel to animate its height, we need to + // set it explicitly. + requestAnimationFrame(() => { + this.style.height = main.clientHeight + 'px'; + }); + } + + this.beginProjection(id); + } + + private beginProjection(projection: ProjectionType) { + if (this.polymerChangesTriggerReprojection === false) { + return; + } + if (projection === 'pca') { + if (this.dataSet != null) { + this.dataSet.stopTSNE(); + } + this.showPCA(); + } else if (projection === 'tsne') { + this.showTSNE(); + } else if (projection === 'custom') { + if (this.dataSet != null) { + this.dataSet.stopTSNE(); + } + this.computeAllCentroids(); + this.reprojectCustom(); + } + } + + private showTSNE() { + const dataSet = this.dataSet; + if (dataSet == null) { + return; + } + const accessors = + data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 
2 : null]); + const dimensionality = this.tSNEis3d ? 3 : 2; + const projection = + new Projection('tsne', accessors, dimensionality, dataSet); + this.projector.setProjection(projection); + + if (!this.dataSet.hasTSNERun) { + this.runTSNE(); + } else { + this.projector.notifyProjectionPositionsUpdated(); + } + } + + private runTSNE() { + this.runTsneButton.disabled = true; + this.stopTsneButton.disabled = null; + this.dataSet.projectTSNE( + this.perplexity, this.learningRate, this.tSNEis3d ? 3 : 2, + (iteration: number) => { + if (iteration != null) { + this.iterationLabel.innerText = '' + iteration; + this.projector.notifyProjectionPositionsUpdated(); + } else { + this.runTsneButton.disabled = null; + this.stopTsneButton.disabled = true; + } + }); + } + + // tslint:disable-next-line:no-unused-variable + private showPCAIfEnabled() { + if (this.polymerChangesTriggerReprojection) { + this.showPCA(); + } + } + + private updateTotalVarianceMessage() { + let variances = this.dataSet.fracVariancesExplained; + let totalVariance = variances[this.pcaX] + variances[this.pcaY]; + let msg = 'Total variance described: '; + if (this.pcaIs3d) { + totalVariance += variances[this.pcaZ]; + } + msg += (totalVariance * 100).toFixed(1) + '%.'; + (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg; + } + + private showPCA() { + if (this.dataSet == null) { + return; + } + this.dataSet.projectPCA().then(() => { + // Polymer properties are 1-based. + const accessors = data.getProjectionComponents( + 'pca', [this.pcaX, this.pcaY, this.pcaZ]); + + const dimensionality = this.pcaIs3d ? 3 : 2; + const projection = + new Projection('pca', accessors, dimensionality, this.dataSet); + this.projector.setProjection(projection); + let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); + this.updateTotalVarianceMessage(); + this.pcaComponents = util.range(numComponents).map(i => { + let fracVariance = this.dataSet.fracVariancesExplained[i]; + return { + id: i, + componentNumber: i + 1, + percVariance: (fracVariance * 100).toFixed(1) + }; + }); + }); + } + + private reprojectCustom() { + if (this.centroids == null || this.centroids.xLeft == null || + this.centroids.xRight == null || this.centroids.yUp == null || + this.centroids.yDown == null) { + return; + } + const xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft); + this.dataSet.projectLinear(xDir, 'linear-x'); + + const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); + this.dataSet.projectLinear(yDir, 'linear-y'); + + const accessors = data.getProjectionComponents('custom', ['x', 'y']); + const projection = new Projection('custom', accessors, 2, this.dataSet); + this.projector.setProjection(projection); + } + + clearCentroids(): void { + this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null}; + this.allCentroid = null; + } + + _customSelectedSearchByMetadataOptionChanged(newVal: string, oldVal: string) { + if (this.polymerChangesTriggerReprojection === false) { + return; + } + if (this.currentProjection === 'custom') { + this.computeAllCentroids(); + this.reprojectCustom(); + } + } + + private setupCustomProjectionInputFields() { + this.customProjectionXLeftInput = + this.setupCustomProjectionInputField('xLeft'); + this.customProjectionXRightInput = + this.setupCustomProjectionInputField('xRight'); + this.customProjectionYUpInput = this.setupCustomProjectionInputField('yUp'); + this.customProjectionYDownInput = + this.setupCustomProjectionInputField('yDown'); + } + + private computeAllCentroids() { + 
this.computeCentroid('xLeft'); + this.computeCentroid('xRight'); + this.computeCentroid('yUp'); + this.computeCentroid('yDown'); + } + + private computeCentroid(name: InputControlName) { + const input = this.querySelector('#' + name) as ProjectorInput; + if (input == null) { + return; + } + const value = input.getValue(); + if (value == null) { + return; + } + let inRegexMode = input.getInRegexMode(); + let result = this.getCentroid(value, inRegexMode); + if (result.numMatches === 0) { + input.message = '0 matches. Using a random vector.'; + result.centroid = vector.rn(this.dim); + } else { + input.message = `${result.numMatches} matches.`; + } + this.centroids[name] = result.centroid; + this.centroidValues[name] = value; + } + + private setupCustomProjectionInputField(name: InputControlName): + ProjectorInput { + let input = this.querySelector('#' + name) as ProjectorInput; + input.registerInputChangedListener((input, inRegexMode) => { + if (this.polymerChangesTriggerReprojection) { + this.computeCentroid(name); + this.reprojectCustom(); + } + }); + return input; + } + + private getCentroid(pattern: string, inRegexMode: boolean): CentroidResult { + if (pattern == null || pattern === '') { + return {numMatches: 0}; + } + // Search by the original dataset since we often want to filter and project + // only the nearest neighbors of A onto B-C where B and C are not nearest + // neighbors of A. + let accessor = (i: number) => this.originalDataSet.points[i].vector; + let r = this.originalDataSet.query( + pattern, inRegexMode, this.customSelectedSearchByMetadataOption); + return {centroid: vector.centroid(r, accessor), numMatches: r.length}; + } + + getPcaSampledDimText() { + return data.PCA_SAMPLE_DIM.toLocaleString(); + } + + getPcaSampleSizeText() { + return data.PCA_SAMPLE_SIZE.toLocaleString(); + } + + getTsneSampleSizeText() { + return data.TSNE_SAMPLE_SIZE.toLocaleString(); + } +} + +document.registerElement(ProjectionsPanel.prototype.is, ProjectionsPanel); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts new file mode 100644 index 00000000000..fd1acf6f085 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts @@ -0,0 +1,109 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {State} from './data'; +import {ProjectionsPanel} from './vz-projector-projections-panel'; + +const assert = chai.assert; + +describe('restoreUIFromBookmark', () => { + let projectionsPanel: ProjectionsPanel; + beforeEach(() => { + projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as + ProjectionsPanel; + + // Set up some of the UI so the elements are found in the production code. 
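+    // (restoreUIFromBookmark indirectly writes into '.tsne-perplexity span'
+    // and '.tsne-learning-rate span', so only those nodes are stubbed.)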
+ const tsnePerplexityContainer = document.createElement('div'); + tsnePerplexityContainer.className = 'tsne-perplexity'; + const tsnePerplexity = document.createElement('span'); + tsnePerplexityContainer.appendChild(tsnePerplexity); + projectionsPanel.appendChild(tsnePerplexityContainer); + + const tsneLearningRateContainer = document.createElement('div'); + tsneLearningRateContainer.className = 'tsne-learning-rate'; + const tsneLearningRate = document.createElement('span'); + tsneLearningRateContainer.appendChild(tsneLearningRate); + projectionsPanel.appendChild(tsneLearningRateContainer); + }); + + it('sets the pcaX/Y properties when setting 2D component values', () => { + spyOn(projectionsPanel, 'setZDropdownEnabled'); + + const s = new State(); + s.pcaComponentDimensions = [0, 1]; + projectionsPanel.restoreUIFromBookmark(s); + + assert.equal(0, projectionsPanel.pcaX); + assert.equal(1, projectionsPanel.pcaY); + + expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(false); + }); + + it('sets the pcaX/Y properties when setting 3D component values', () => { + spyOn(projectionsPanel, 'setZDropdownEnabled'); + + const s = new State(); + s.pcaComponentDimensions = [0, 1, 2]; + projectionsPanel.restoreUIFromBookmark(s); + + assert.equal(0, projectionsPanel.pcaX); + assert.equal(1, projectionsPanel.pcaY); + assert.equal(2, projectionsPanel.pcaZ); + + expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(true); + }); +}); + +describe('populateBookmarkFromUI', () => { + let projectionsPanel: ProjectionsPanel; + + beforeEach(() => { + projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as + ProjectionsPanel; + + // Set up some of the UI so the elements are found in the production code. + const tsnePerplexityContainer = document.createElement('div'); + tsnePerplexityContainer.className = 'tsne-perplexity'; + const tsnePerplexity = document.createElement('span'); + tsnePerplexityContainer.appendChild(tsnePerplexity); + projectionsPanel.appendChild(tsnePerplexityContainer); + + const tsneLearningRateContainer = document.createElement('div'); + tsneLearningRateContainer.className = 'tsne-learning-rate'; + const tsneLearningRate = document.createElement('span'); + tsneLearningRateContainer.appendChild(tsneLearningRate); + projectionsPanel.appendChild(tsneLearningRateContainer); + }); + + it('gets the PCA component UI values from a 2D PCA projection', () => { + projectionsPanel.pcaX = 0; + projectionsPanel.pcaY = 1; + projectionsPanel.pcaIs3d = false; + + const s = new State(); + projectionsPanel.populateBookmarkFromUI(s); + assert.deepEqual([0, 1], s.pcaComponentDimensions); + }); + + it('gets the PCA component UI values from a 3D PCA projection', () => { + projectionsPanel.pcaX = 0; + projectionsPanel.pcaY = 1; + projectionsPanel.pcaZ = 2; + projectionsPanel.pcaIs3d = true; + + const s = new State(); + projectionsPanel.populateBookmarkFromUI(s); + assert.deepEqual([0, 1, 2], s.pcaComponentDimensions); + }); +}); diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts new file mode 100644 index 00000000000..44062062a36 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export type Spec = { + is: string; properties?: { + [key: string]: + (Function | + { + type: Function, value?: any; + readonly?: boolean; + notify?: boolean; + observer?: string; + }) + }; + observers?: string[]; +}; + +export function PolymerElement(spec: Spec) { + return Polymer.Class(spec as any) as{new (): PolymerHTMLElement}; +} + +export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html new file mode 100644 index 00000000000..d4be2f26a5d --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html @@ -0,0 +1,343 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts new file mode 100644 index 00000000000..bf98a4d4785 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts @@ -0,0 +1,570 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +import {AnalyticsLogger} from './analyticsLogger'; +import * as data from './data'; +import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; +import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; +import {DemoDataProvider} from './data-provider-demo'; +import {ProtoDataProvider} from './data-provider-proto'; +import {ServerDataProvider} from './data-provider-server'; +import * as knn from './knn'; +import * as logging from './logging'; +import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; +import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; +import {MouseMode} from './scatterPlot'; +import * as util from './util'; +import {BookmarkPanel} from './vz-projector-bookmark-panel'; +import {DataPanel} from './vz-projector-data-panel'; +import {InspectorPanel} from './vz-projector-inspector-panel'; +import {MetadataCard} from './vz-projector-metadata-card'; +import {ProjectionsPanel} from './vz-projector-projections-panel'; +// tslint:disable-next-line:no-unused-variable +import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; + +/** + * The minimum number of dimensions the data should have to automatically + * decide to normalize the data. + */ +const THRESHOLD_DIM_NORMALIZE = 50; +const POINT_COLOR_MISSING = 'black'; + +export let ProjectorPolymer = PolymerElement({ + is: 'vz-projector', + properties: { + routePrefix: String, + dataProto: {type: String, observer: '_dataProtoChanged'}, + servingMode: String, + projectorConfigJsonPath: String, + pageViewLogging: Boolean, + eventLogging: Boolean + } +}); + +const INDEX_METADATA_FIELD = '__index__'; + +export class Projector extends ProjectorPolymer implements + ProjectorEventContext { + // The working subset of the data source's original data set. + dataSet: DataSet; + servingMode: ServingMode; + // The path to the projector config JSON file for demo mode. 
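+  // (Used as the default when no ?config= URL parameter is present; see
+  // initializeDataProvider() below.)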
+ projectorConfigJsonPath: string; + + private selectionChangedListeners: SelectionChangedListener[]; + private hoverListeners: HoverListener[]; + private projectionChangedListeners: ProjectionChangedListener[]; + private distanceMetricChangedListeners: DistanceMetricChangedListener[]; + + private originalDataSet: DataSet; + private dataSetBeforeFilter: DataSet; + private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; + private dim: number; + + private dataSetFilterIndices: number[]; + private selectedPointIndices: number[]; + private neighborsOfFirstPoint: knn.NearestEntry[]; + private hoverPointIndex: number; + + private dataProvider: DataProvider; + private inspectorPanel: InspectorPanel; + + private selectedColorOption: ColorOption; + private selectedLabelOption: string; + private routePrefix: string; + private normalizeData: boolean; + private projection: Projection; + + /** Polymer component panels */ + private dataPanel: DataPanel; + private bookmarkPanel: BookmarkPanel; + private projectionsPanel: ProjectionsPanel; + private metadataCard: MetadataCard; + + private statusBar: HTMLDivElement; + private analyticsLogger: AnalyticsLogger; + private eventLogging: boolean; + private pageViewLogging: boolean; + + ready() { + logging.setDomContainer(this); + + this.analyticsLogger = + new AnalyticsLogger(this.pageViewLogging, this.eventLogging); + this.analyticsLogger.logPageView('embeddings'); + + if (!util.hasWebGLSupport()) { + this.analyticsLogger.logWebGLDisabled(); + logging.setErrorMessage( + 'Your browser or device does not have WebGL enabled. Please enable ' + + 'hardware acceleration, or use a browser that supports WebGL.'); + return; + } + + this.selectionChangedListeners = []; + this.hoverListeners = []; + this.projectionChangedListeners = []; + this.distanceMetricChangedListeners = []; + this.selectedPointIndices = []; + this.neighborsOfFirstPoint = []; + + this.dataPanel = this.$['data-panel'] as DataPanel; + this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel; + this.inspectorPanel.initialize(this, this as ProjectorEventContext); + this.projectionsPanel = this.$['projections-panel'] as ProjectionsPanel; + this.projectionsPanel.initialize(this); + this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel; + this.bookmarkPanel.initialize(this, this as ProjectorEventContext); + this.metadataCard = this.$['metadata-card'] as MetadataCard; + this.statusBar = this.querySelector('#status-bar') as HTMLDivElement; + this.scopeSubtree(this.$$('#notification-dialog'), true); + this.setupUIControls(); + this.initializeDataProvider(); + } + + setSelectedLabelOption(labelOption: string) { + this.selectedLabelOption = labelOption; + this.metadataCard.setLabelOption(this.selectedLabelOption); + this.projectorScatterPlotAdapter.setLabelPointAccessor(labelOption); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.projectorScatterPlotAdapter.render(); + } + + setSelectedColorOption(colorOption: ColorOption) { + this.selectedColorOption = colorOption; + this.projectorScatterPlotAdapter.setLegendPointColorer( + this.getLegendPointColorer(colorOption)); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.projectorScatterPlotAdapter.render(); + } + + setNormalizeData(normalizeData: boolean) { + this.normalizeData = normalizeData; + this.setCurrentDataSet(this.originalDataSet.getSubset()); + } + + updateDataSet( + ds: DataSet, spriteAndMetadata?: SpriteAndMetadataInfo, + metadataFile?: string) { + this.dataSetFilterIndices = 
null; + this.originalDataSet = ds; + if (ds != null) { + this.normalizeData = + this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; + spriteAndMetadata = spriteAndMetadata || {}; + if (spriteAndMetadata.pointsInfo == null) { + let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); + spriteAndMetadata.pointsInfo = pointsInfo; + spriteAndMetadata.stats = stats; + } + let metadataMergeSucceeded = ds.mergeMetadata(spriteAndMetadata); + if (!metadataMergeSucceeded) { + return; + } + } + if (this.projectorScatterPlotAdapter != null) { + if (ds == null) { + this.projectorScatterPlotAdapter.setLabelPointAccessor(null); + this.setProjection(null); + } else { + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.projectorScatterPlotAdapter.resize(); + this.projectorScatterPlotAdapter.render(); + } + } + if (ds != null) { + this.dataPanel.setNormalizeData(this.normalizeData); + this.setCurrentDataSet(ds.getSubset()); + this.projectorScatterPlotAdapter.setLabelPointAccessor( + this.selectedLabelOption); + this.inspectorPanel.datasetChanged(); + + this.inspectorPanel.metadataChanged(spriteAndMetadata); + this.projectionsPanel.metadataChanged(spriteAndMetadata); + this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); + // Set the container to a fixed height, otherwise in Colab the + // height can grow indefinitely. + const container = this.querySelector('#container') as HTMLDivElement; + container.style.height = container.clientHeight + 'px'; + } else { + this.setCurrentDataSet(null); + } + } + + setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { + this.bookmarkPanel.setSelectedTensor(run, tensorInfo, this.dataProvider); + } + + /** + * Registers a listener to be called any time the selected point set changes. + */ + registerSelectionChangedListener(listener: SelectionChangedListener) { + this.selectionChangedListeners.push(listener); + } + + filterDataset(pointIndices: number[]) { + const selectionSize = this.selectedPointIndices.length; + if (this.dataSetBeforeFilter == null) { + this.dataSetBeforeFilter = this.dataSet; + } + this.setCurrentDataSet(this.dataSet.getSubset(pointIndices)); + this.dataSetFilterIndices = pointIndices; + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.adjustSelectionAndHover(util.range(selectionSize)); + } + + resetFilterDataset() { + const originalPointIndices = this.selectedPointIndices.map( + filteredIndex => this.dataSet.points[filteredIndex].index); + this.setCurrentDataSet(this.dataSetBeforeFilter); + if (this.projection != null) { + this.projection.dataSet = this.dataSetBeforeFilter; + } + this.dataSetBeforeFilter = null; + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.dataSetFilterIndices = []; + this.adjustSelectionAndHover(originalPointIndices); + } + + /** + * Used by clients to indicate that a selection has occurred. 
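+   * (When exactly one point is selected, its nearest neighbors are also
+   * recomputed and the metadata card is refreshed.)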
+   */
+  notifySelectionChanged(newSelectedPointIndices: number[]) {
+    this.selectedPointIndices = newSelectedPointIndices;
+    let neighbors: knn.NearestEntry[] = [];
+
+    if (newSelectedPointIndices.length === 1) {
+      neighbors = this.dataSet.findNeighbors(
+          newSelectedPointIndices[0], this.inspectorPanel.distFunc,
+          this.inspectorPanel.numNN);
+      this.metadataCard.updateMetadata(
+          this.dataSet.points[newSelectedPointIndices[0]].metadata);
+    } else {
+      this.metadataCard.updateMetadata(null);
+    }
+
+    this.selectionChangedListeners.forEach(
+        l => l(this.selectedPointIndices, neighbors));
+  }
+
+  /**
+   * Registers a listener to be called any time the mouse hovers over a point.
+   */
+  registerHoverListener(listener: HoverListener) {
+    this.hoverListeners.push(listener);
+  }
+
+  /**
+   * Used by clients to indicate that a hover is occurring.
+   */
+  notifyHoverOverPoint(pointIndex: number) {
+    this.hoverListeners.forEach(l => l(pointIndex));
+  }
+
+  registerProjectionChangedListener(listener: ProjectionChangedListener) {
+    this.projectionChangedListeners.push(listener);
+  }
+
+  notifyProjectionChanged(projection: Projection) {
+    this.projectionChangedListeners.forEach(l => l(projection));
+  }
+
+  registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) {
+    this.distanceMetricChangedListeners.push(l);
+  }
+
+  notifyDistanceMetricChanged(distMetric: DistanceFunction) {
+    this.distanceMetricChangedListeners.forEach(l => l(distMetric));
+  }
+
+  _dataProtoChanged(dataProtoString: string) {
+    let dataProto =
+        dataProtoString ? JSON.parse(dataProtoString) as DataProto : null;
+    this.initializeDataProvider(dataProto);
+  }
+
+  private makeDefaultPointsInfoAndStats(points: DataPoint[]):
+      [PointMetadata[], ColumnStats[]] {
+    let pointsInfo: PointMetadata[] = [];
+    points.forEach(p => {
+      let pointInfo: PointMetadata = {};
+      pointInfo[INDEX_METADATA_FIELD] = p.index;
+      pointsInfo.push(pointInfo);
+    });
+    let stats: ColumnStats[] = [{
+      name: INDEX_METADATA_FIELD,
+      isNumeric: false,
+      tooManyUniqueValues: true,
+      min: 0,
+      max: pointsInfo.length - 1
+    }];
+    return [pointsInfo, stats];
+  }
+
+  private initializeDataProvider(dataProto?: DataProto) {
+    if (this.servingMode === 'demo') {
+      let projectorConfigUrl: string;
+
+      // Only in demo mode do we allow the config to be passed via URL.
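+      // The URL parameter is named 'config'; when it is absent, the
+      // projectorConfigJsonPath property is used instead.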
+      let urlParams = util.getURLParams(window.location.search);
+      if ('config' in urlParams) {
+        projectorConfigUrl = urlParams['config'];
+      } else {
+        projectorConfigUrl = this.projectorConfigJsonPath;
+      }
+      this.dataProvider = new DemoDataProvider(projectorConfigUrl);
+    } else if (this.servingMode === 'server') {
+      if (!this.routePrefix) {
+        throw new Error('route-prefix is a required parameter');
+      }
+      this.dataProvider = new ServerDataProvider(this.routePrefix);
+    } else if (this.servingMode === 'proto' && dataProto != null) {
+      this.dataProvider = new ProtoDataProvider(dataProto);
+    }
+
+    this.dataPanel.initialize(this, this.dataProvider);
+  }
+
+  private getLegendPointColorer(colorOption: ColorOption):
+      (ds: DataSet, index: number) => string {
+    if ((colorOption == null) || (colorOption.map == null)) {
+      return null;
+    }
+    const colorer = (ds: DataSet, i: number) => {
+      let value = ds.points[i].metadata[colorOption.name];
+      if (value == null) {
+        return POINT_COLOR_MISSING;
+      }
+      return colorOption.map(value);
+    };
+    return colorer;
+  }
+
+  private get3DLabelModeButton(): any {
+    return this.querySelector('#labels3DMode');
+  }
+
+  private get3DLabelMode(): boolean {
+    const label3DModeButton = this.get3DLabelModeButton();
+    return (label3DModeButton as any).active;
+  }
+
+  adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) {
+    this.notifySelectionChanged(selectedPointIndices);
+    this.notifyHoverOverPoint(hoverIndex);
+    this.setMouseMode(MouseMode.CAMERA_AND_CLICK_SELECT);
+  }
+
+  private setMouseMode(mouseMode: MouseMode) {
+    let selectModeButton = this.querySelector('#selectMode');
+    (selectModeButton as any).active = (mouseMode === MouseMode.AREA_SELECT);
+    this.projectorScatterPlotAdapter.scatterPlot.setMouseMode(mouseMode);
+  }
+
+  private setCurrentDataSet(ds: DataSet) {
+    this.adjustSelectionAndHover([]);
+    if (this.dataSet != null) {
+      this.dataSet.stopTSNE();
+    }
+    if ((ds != null) && this.normalizeData) {
+      ds.normalize();
+    }
+    this.dim = (ds == null) ? 0 : ds.dim[1];
+    (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText =
+        (ds == null) ? '0' : '' + ds.dim[0];
+    (this.querySelector('span.dim') as HTMLSpanElement).innerText =
+        (ds == null) ? '0' : '' + ds.dim[1];
+
+    this.dataSet = ds;
+
+    this.projectionsPanel.dataSetUpdated(
+        this.dataSet, this.originalDataSet, this.dim);
+
+    this.projectorScatterPlotAdapter.setDataSet(this.dataSet);
+    this.projectorScatterPlotAdapter.scatterPlot
+        .setCameraParametersForNextCameraCreation(null, true);
+  }
+
+  private setupUIControls() {
+    // View controls
+    this.querySelector('#reset-zoom').addEventListener('click', () => {
+      this.projectorScatterPlotAdapter.scatterPlot.resetZoom();
+      this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation();
+    });
+
+    let selectModeButton = this.querySelector('#selectMode');
+    selectModeButton.addEventListener('click', (event) => {
+      this.setMouseMode(
+          (selectModeButton as any).active ?
MouseMode.AREA_SELECT : + MouseMode.CAMERA_AND_CLICK_SELECT); + }); + let nightModeButton = this.querySelector('#nightDayMode'); + nightModeButton.addEventListener('click', () => { + this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode( + (nightModeButton as any).active); + }); + + const labels3DModeButton = this.get3DLabelModeButton(); + labels3DModeButton.addEventListener('click', () => { + this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode()); + }); + + window.addEventListener('resize', () => { + const container = this.querySelector('#container') as HTMLDivElement; + const parentHeight = (container.parentNode as HTMLElement).clientHeight; + container.style.height = parentHeight + 'px'; + this.projectorScatterPlotAdapter.resize(); + }); + + { + this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter( + this.getScatterContainer(), this as ProjectorEventContext); + this.projectorScatterPlotAdapter.setLabelPointAccessor( + this.selectedLabelOption); + } + + this.projectorScatterPlotAdapter.scatterPlot.onCameraMove( + (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => + this.bookmarkPanel.clearStateSelection()); + + this.registerHoverListener( + (hoverIndex: number) => this.onHover(hoverIndex)); + + this.registerSelectionChangedListener( + (selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[]) => + this.onSelectionChanged( + selectedPointIndices, neighborsOfFirstPoint)); + } + + private onHover(hoverIndex: number) { + this.hoverPointIndex = hoverIndex; + let hoverText = null; + if (hoverIndex != null) { + const point = this.dataSet.points[hoverIndex]; + if (point.metadata[this.selectedLabelOption]) { + hoverText = point.metadata[this.selectedLabelOption].toString(); + } + } + if (this.selectedPointIndices.length === 0) { + this.statusBar.style.display = hoverText ? null : 'none'; + this.statusBar.innerText = hoverText; + } + } + + private getScatterContainer(): HTMLDivElement { + return this.querySelector('#scatter') as HTMLDivElement; + } + + private onSelectionChanged( + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[]) { + this.selectedPointIndices = selectedPointIndices; + this.neighborsOfFirstPoint = neighborsOfFirstPoint; + let totalNumPoints = + this.selectedPointIndices.length + neighborsOfFirstPoint.length; + this.statusBar.innerText = `Selected ${totalNumPoints} points`; + this.statusBar.style.display = totalNumPoints > 0 ? null : 'none'; + } + + setProjection(projection: Projection) { + this.projection = projection; + if (projection != null) { + this.analyticsLogger.logProjectionChanged(projection.projectionType); + } + this.notifyProjectionChanged(projection); + } + + notifyProjectionPositionsUpdated() { + this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated(); + } + + /** + * Gets the current view of the embedding and saves it as a State object. + */ + getCurrentState(): State { + const state = new State(); + + // Save the individual datapoint projections. 
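+    // One entry per data point, mapping each projection component key to its
+    // coordinate value.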
+ state.projections = []; + for (let i = 0; i < this.dataSet.points.length; i++) { + const point = this.dataSet.points[i]; + const projections: {[key: string]: number} = {}; + const keys = Object.keys(point.projections); + for (let j = 0; j < keys.length; ++j) { + projections[keys[j]] = point.projections[keys[j]]; + } + state.projections.push(projections); + } + state.selectedProjection = this.projection.projectionType; + state.dataSetDimensions = this.dataSet.dim; + state.tSNEIteration = this.dataSet.tSNEIteration; + state.selectedPoints = this.selectedPointIndices; + state.filteredPoints = this.dataSetFilterIndices; + this.projectorScatterPlotAdapter.populateBookmarkFromUI(state); + state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; + state.forceCategoricalColoring = this.dataPanel.forceCategoricalColoring; + state.selectedLabelOption = this.selectedLabelOption; + this.projectionsPanel.populateBookmarkFromUI(state); + return state; + } + + /** Loads a State object into the world. */ + loadState(state: State) { + this.setProjection(null); + { + this.projectionsPanel.disablePolymerChangesTriggerReprojection(); + if (this.dataSetBeforeFilter != null) { + this.resetFilterDataset(); + } + if (state.filteredPoints != null) { + this.filterDataset(state.filteredPoints); + } + this.projectionsPanel.enablePolymerChangesTriggerReprojection(); + } + for (let i = 0; i < state.projections.length; i++) { + const point = this.dataSet.points[i]; + const projection = state.projections[i]; + const keys = Object.keys(projection); + for (let j = 0; j < keys.length; ++j) { + point.projections[keys[j]] = projection[keys[j]]; + } + } + this.dataSet.hasTSNERun = (state.selectedProjection === 'tsne'); + this.dataSet.tSNEIteration = state.tSNEIteration; + this.projectionsPanel.restoreUIFromBookmark(state); + this.inspectorPanel.restoreUIFromBookmark(state); + this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; + this.dataPanel.setForceCategoricalColoring( + !!state.forceCategoricalColoring); + this.selectedLabelOption = state.selectedLabelOption; + this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); + { + const dimensions = stateGetAccessorDimensions(state); + const components = + data.getProjectionComponents(state.selectedProjection, dimensions); + const projection = new Projection( + state.selectedProjection, components, dimensions.length, + this.dataSet); + this.setProjection(projection); + } + this.notifySelectionChanged(state.selectedPoints); + } +} + +document.registerElement(Projector.prototype.is, Projector); diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html index 656a69f836c..8610940ac3c 100644 --- a/tensorflow/tensorboard/dist/tf-tensorboard.html +++ b/tensorflow/tensorboard/dist/tf-tensorboard.html @@ -27140,4 +27140,4 @@ arguments[4][8][0].apply(exports,arguments) },{"dup":8}]},{},[35,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34]); - \ No newline at end of file + diff --git a/tensorflow/tensorboard/gulp_tasks/compile.js b/tensorflow/tensorboard/gulp_tasks/compile.js index 44d703501ea..01af60eba77 100644 --- a/tensorflow/tensorboard/gulp_tasks/compile.js +++ b/tensorflow/tensorboard/gulp_tasks/compile.js @@ -27,7 +27,7 @@ const concat = require('gulp-concat'); const tsProject = ts.createProject('./tsconfig.json', { typescript: typescript, - noExternalResolve: true, // opt-in for faster compilation! 
+ noExternalResolve: true, // opt-in for faster compilation! }); /** List of components (and their external deps) that are using es6 modules. */ diff --git a/tensorflow/tensorboard/gulp_tasks/vulcanize.js b/tensorflow/tensorboard/gulp_tasks/vulcanize.js index 21ea701f4ab..d2286f1d6c5 100644 --- a/tensorflow/tensorboard/gulp_tasks/vulcanize.js +++ b/tensorflow/tensorboard/gulp_tasks/vulcanize.js @@ -21,7 +21,8 @@ const replace = require('gulp-replace'); const rename = require('gulp-rename'); const header = require('gulp-header'); -const HEADER_STR = '